Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Determinism] Enable environment var to use cusparse spmm deterministic algorithm #7310

Merged
merged 8 commits into from
Apr 19, 2024
28 changes: 15 additions & 13 deletions src/array/cuda/spmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
*/
#include <dgl/array.h>

#include <cstdlib>

#include "../../runtime/cuda/cuda_common.h"
#include "./functor.cuh"
#include "./ge_spmm.cuh"
Expand All @@ -21,13 +23,15 @@
* @note use cusparse if the reduce operator is `sum` and there is
* no broadcast, use dgl's kernel in other cases.
*/
template <int XPU, typename IdType, typename DType>

Check warning on line 26 in src/array/cuda/spmm.cu

View workflow job for this annotation

GitHub Actions / lintrunner

CLANGFORMAT format

See https://clang.llvm.org/docs/ClangFormat.html. Run `lintrunner -a` to apply this patch.
void SpMMCsr(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];
bool use_efeat = op != "copy_lhs";
bool use_deterministic_alg_only = false;
if (NULL != std::getenv("USE_DETERMINISTIC_ALG"))
use_deterministic_alg_only = true;

if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
Expand All @@ -37,10 +41,9 @@
for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr, static_cast<DType*>(ufeat->data), nullptr,
static_cast<DType*>(out->data), x_length);
} else if (
op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(more_nnz)) {
static_cast<DType*>(out->data), x_length, use_deterministic_alg_only);
} else if (op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(more_nnz)) {
// cusparse
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
Expand All @@ -50,7 +53,7 @@
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
static_cast<DType*>(efeat->data), static_cast<DType*>(out->data),
x_length);
x_length, use_deterministic_alg_only);
} else { // general kernel
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
Expand All @@ -76,10 +79,9 @@
* @brief CUDA implementation of g-SpMM on Coo format.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCoo(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(
Expand Down
Loading
Loading