Skip to content

Commit

Permalink
Apply lint with clang-format
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonC committed Apr 17, 2024
1 parent fdc5108 commit a8ae77c
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 148 deletions.
24 changes: 11 additions & 13 deletions src/array/cuda/spmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
* @file array/cuda/spmm.cu
* @brief SPMM C APIs and definitions.
*/
#include <cstdlib>
#include <dgl/array.h>

#include <cstdlib>

#include "../../runtime/cuda/cuda_common.h"
#include "./functor.cuh"
#include "./ge_spmm.cuh"
Expand All @@ -23,15 +24,14 @@ namespace aten {
* no broadcast, use dgl's kernel in other cases.
*/
template <int XPU, typename IdType, typename DType>

Check warning on line 26 in src/array/cuda/spmm.cu

View workflow job for this annotation

GitHub Actions / lintrunner

CLANGFORMAT format

See https://clang.llvm.org/docs/ClangFormat.html. Run `lintrunner -a` to apply this patch.
void SpMMCsr(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];
bool use_efeat = op != "copy_lhs";
bool use_deterministic_alg_only = false;
if (NULL != std::getenv("USE_DETERMINISTIC_ALG"))
use_deterministic_alg_only = true;
use_deterministic_alg_only = true;

if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
Expand All @@ -42,9 +42,8 @@ void SpMMCsr(
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr, static_cast<DType*>(ufeat->data), nullptr,
static_cast<DType*>(out->data), x_length, use_deterministic_alg_only);
} else if (
op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(more_nnz)) {
} else if (op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(more_nnz)) {
// cusparse
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
Expand Down Expand Up @@ -80,10 +79,9 @@ void SpMMCsr(
* @brief CUDA implementation of g-SpMM on Coo format.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCoo(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(
Expand Down
Loading

0 comments on commit a8ae77c

Please sign in to comment.