Skip to content

Commit

Permalink
1
Browse files Browse the repository at this point in the history
  • Loading branch information
Skeleton003 committed Jun 13, 2024
1 parent 94b691b commit 57b1572
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/array/cpu/spmat_op_impl_coo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <limits>
#include <numeric>
#include <tuple>
#include <unordered_map>
Expand Down Expand Up @@ -323,6 +324,17 @@ template <class IdType>
CSRMatrix SortedCOOToCSR(const COOMatrix &coo) {
const int64_t N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0];

// TODO[Mingbang]: This is just a temporary check to ensure that NNZ does not
// exceed INT32_MAX, preventing overflow issues that could lead to undefined
// behavior or incorrect results. Later we need to suppoort larger values of
// NNZ.
if (std::is_same<IdType, int32_t>::value &&
NNZ > std::numeric_limits<int32_t>::max()) {
LOG(FATAL) << "Number of non zero elements exceeds the maximum value that "
"can be represented by int32_t for IdType int32_t.";
}

const IdType *const row_data = static_cast<IdType *>(coo.row->data);
const IdType *const data =
COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;
Expand Down Expand Up @@ -418,6 +430,17 @@ CSRMatrix UnSortedSparseCOOToCSR(const COOMatrix &coo) {

const UIdType N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0];

// TODO[Mingbang]: This is just a temporary check to ensure that NNZ does not
// exceed INT32_MAX, preventing overflow issues that could lead to undefined
// behavior or incorrect results. Later we need to suppoort larger values of
// NNZ.
if (std::is_same<IdType, int32_t>::value &&
NNZ > std::numeric_limits<int32_t>::max()) {
LOG(FATAL) << "Number of non zero elements exceeds the maximum value that "
"can be represented by int32_t for IdType int32_t.";
}

const IdType *const row_data = static_cast<IdType *>(coo.row->data);
const IdType *const col_data = static_cast<IdType *>(coo.col->data);
const IdType *const data =
Expand Down Expand Up @@ -542,6 +565,17 @@ CSRMatrix UnSortedDenseCOOToCSR(const COOMatrix &coo) {

const UIdType N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0];

// TODO[Mingbang]: This is just a temporary check to ensure that NNZ does not
// exceed INT32_MAX, preventing overflow issues that could lead to undefined
// behavior or incorrect results. Later we need to suppoort larger values of
// NNZ.
if (std::is_same<IdType, int32_t>::value &&
NNZ > std::numeric_limits<int32_t>::max()) {
LOG(FATAL) << "Number of non zero elements exceeds the maximum value that "
"can be represented by int32_t for IdType int32_t.";
}

const IdType *const row_data = static_cast<IdType *>(coo.row->data);
const IdType *const col_data = static_cast<IdType *>(coo.col->data);
const IdType *const data =
Expand Down Expand Up @@ -637,6 +671,17 @@ template <typename IdType>
CSRMatrix UnSortedSmallCOOToCSR(COOMatrix coo) {
const int64_t N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0];

// TODO[Mingbang]: This is just a temporary check to ensure that NNZ does not
// exceed INT32_MAX, preventing overflow issues that could lead to undefined
// behavior or incorrect results. Later we need to suppoort larger values of
// NNZ.
if (std::is_same<IdType, int32_t>::value &&
NNZ > std::numeric_limits<int32_t>::max()) {
LOG(FATAL) << "Number of non zero elements exceeds the maximum value that "
"can be represented by int32_t for IdType int32_t.";
}

const IdType *row_data = static_cast<IdType *>(coo.row->data);
const IdType *col_data = static_cast<IdType *>(coo.col->data);
const IdType *data =
Expand Down

0 comments on commit 57b1572

Please sign in to comment.