Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add check for NNZ in COOToCSR #7459

Merged
merged 13 commits into from
Jun 17, 2024
45 changes: 45 additions & 0 deletions src/array/cpu/spmat_op_impl_coo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <limits>
#include <numeric>
#include <tuple>
#include <unordered_map>
Expand Down Expand Up @@ -323,6 +324,17 @@ template <class IdType>
CSRMatrix SortedCOOToCSR(const COOMatrix &coo) {
const int64_t N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0];

// TODO(Mingbang): This is just a temporary check to ensure that NNZ does not
Rhett-Ying marked this conversation as resolved.
Show resolved Hide resolved
// exceed INT32_MAX, preventing overflow issues that could lead to undefined
// behavior or incorrect results. Later we need to suppoort larger values of
// NNZ.
if (std::is_same<IdType, int32_t>::value &&
NNZ > std::numeric_limits<int32_t>::max()) {
LOG(FATAL) << "Number of non zero elements exceeds the maximum value that "
"can be represented by int32_t for IdType int32_t.";
}

const IdType *const row_data = static_cast<IdType *>(coo.row->data);
const IdType *const data =
COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;
Expand Down Expand Up @@ -418,6 +430,17 @@ CSRMatrix UnSortedSparseCOOToCSR(const COOMatrix &coo) {

const UIdType N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0];

// TODO(Mingbang): This is just a temporary check to ensure that NNZ does not
// exceed INT32_MAX, preventing overflow issues that could lead to undefined
// behavior or incorrect results. Later we need to suppoort larger values of
// NNZ.
if (std::is_same<IdType, int32_t>::value &&
NNZ > std::numeric_limits<int32_t>::max()) {
LOG(FATAL) << "Number of non zero elements exceeds the maximum value that "
"can be represented by int32_t for IdType int32_t.";
}

const IdType *const row_data = static_cast<IdType *>(coo.row->data);
const IdType *const col_data = static_cast<IdType *>(coo.col->data);
const IdType *const data =
Expand Down Expand Up @@ -542,6 +565,17 @@ CSRMatrix UnSortedDenseCOOToCSR(const COOMatrix &coo) {

const UIdType N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0];

// TODO(Mingbang): This is just a temporary check to ensure that NNZ does not
// exceed INT32_MAX, preventing overflow issues that could lead to undefined
// behavior or incorrect results. Later we need to suppoort larger values of
// NNZ.
if (std::is_same<IdType, int32_t>::value &&
NNZ > std::numeric_limits<int32_t>::max()) {
LOG(FATAL) << "Number of non zero elements exceeds the maximum value that "
"can be represented by int32_t for IdType int32_t.";
}

const IdType *const row_data = static_cast<IdType *>(coo.row->data);
const IdType *const col_data = static_cast<IdType *>(coo.col->data);
const IdType *const data =
Expand Down Expand Up @@ -637,6 +671,17 @@ template <typename IdType>
CSRMatrix UnSortedSmallCOOToCSR(COOMatrix coo) {
const int64_t N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0];

// TODO(Mingbang): This is just a temporary check to ensure that NNZ does not
// exceed INT32_MAX, preventing overflow issues that could lead to undefined
// behavior or incorrect results. Later we need to suppoort larger values of
// NNZ.
if (std::is_same<IdType, int32_t>::value &&
NNZ > std::numeric_limits<int32_t>::max()) {
LOG(FATAL) << "Number of non zero elements exceeds the maximum value that "
"can be represented by int32_t for IdType int32_t.";
}

const IdType *row_data = static_cast<IdType *>(coo.row->data);
const IdType *col_data = static_cast<IdType *>(coo.col->data);
const IdType *data =
Expand Down
Loading