Skip to content

Commit

Permalink
[Sparse] Compact C++ API (dmlc#6334)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangyuzhi authored and DominikaJedynak committed Mar 12, 2024
1 parent f00afc5 commit 52356a2
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 1 deletion.
23 changes: 23 additions & 0 deletions dgl_sparse/include/sparse/matrix_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define SPARSE_MATRIX_OPS_H_

#include <sparse/sparse_format.h>
#include <sparse/sparse_matrix.h>

#include <tuple>

Expand All @@ -26,6 +27,28 @@ namespace sparse {
std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
const std::shared_ptr<COO>& lhs, const std::shared_ptr<COO>& rhs);

/**
* @brief Compact sparse matrix by removing rows or columns without non-zero
* elements in the sparse matrix and relabeling indices of the dimension.
*
* This function serves a dual purpose: it allows you to reorganize the
* indices within a specific dimension (rows or columns) of the sparse matrix
* and, if needed, place certain 'leading_indices' at the beginning of the
* compact dimension.
*
* @param mat The sparse matrix to be compacted.
* @param dim The dimension to compact. Should be 0 or 1. Use 0 for row-wise
* compaction and 1 for column-wise compaction.
* @param leading_indices An optional tensor containing row or column ids that
* should be placed at the beginning of the compact dimension.
*
* @return A tuple containing the compacted sparse matrix and the index mapping
* of the compact dimension from the new index to the original index.
*/
std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> Compact(
const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
torch::Tensor leading_indices);

} // namespace sparse
} // namespace dgl

Expand Down
65 changes: 65 additions & 0 deletions dgl_sparse/src/macro.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/**
* Copyright (c) 2023 by Contributors
* @file macro.h
* @brief DGL C++ sparse API macros.
*/
#ifndef DGL_SPARSE_MACRO_H_
#define DGL_SPARSE_MACRO_H_

namespace dgl {
namespace sparse {

/**
 * Dispatch an operator to a templated implementation function
 * according to its device:
 *
 *   DGL_SPARSE_XPU_SWITCH(tensor.device().type(), XPU, {
 *     // Now XPU is a placeholder for tensor.device().type()
 *     DeviceSpecificImplementation<XPU>(...);
 *   });
 *
 * Only CPU is currently supported; any other device aborts via LOG(FATAL).
 */
#define DGL_SPARSE_XPU_SWITCH(device, XPU, op, ...)             \
  do {                                                          \
    if ((device) == c10::DeviceType::CPU) {                     \
      constexpr auto XPU = c10::DeviceType::CPU;                \
      { __VA_ARGS__ }                                           \
    } else {                                                    \
      /* Parenthesize macro args so expression arguments parse  \
         correctly at every expansion site. */                  \
      LOG(FATAL) << "Operator " << (op) << " does not support " \
                 << c10::DeviceTypeName((device)) << " device."; \
    }                                                           \
  } while (0)

/**
 * Dispatch according to ID type (either int32 or int64):
 *
 *   DGL_SPARSE_ID_TYPE_SWITCH(tensor.dtype(), IdType, {
 *     // Now IdType is the type corresponding to data type of the tensor.
 *     // For instance, one can do this for a CPU array:
 *     IdType *data = static_cast<IdType *>(array.data_ptr());
 *   });
 *
 * @param dtype Expected to be torch::kInt32 or torch::kInt64; any other
 *        dtype aborts via LOG(FATAL), naming @p op in the message.
 * @param IdType Identifier introduced as a typedef (int32_t or int64_t)
 *        visible inside the dispatched body.
 */
#define DGL_SPARSE_ID_TYPE_SWITCH(dtype, IdType, op, ...)      \
  do {                                                         \
    if ((dtype) == torch::kInt32) {                            \
      typedef int32_t IdType;                                  \
      { __VA_ARGS__ }                                          \
    } else if ((dtype) == torch::kInt64) {                     \
      typedef int64_t IdType;                                  \
      { __VA_ARGS__ }                                          \
    } else {                                                   \
      LOG(FATAL) << "Operator " << (op) << " does not support " \
                 << (dtype).name() << " as ID dtype.";          \
    }                                                          \
  } while (0)

// Macro to dispatch according to both the device and the index type of a
// COO matrix. `coo` is parenthesized at every use so expression arguments
// expand correctly, and the definition carries no trailing semicolon (the
// call site supplies it — previously every expansion produced `;;`).
#define DGL_SPARSE_COO_SWITCH(coo, XPU, IdType, op, ...)              \
  DGL_SPARSE_XPU_SWITCH((coo)->indices.device().type(), XPU, op, {    \
    DGL_SPARSE_ID_TYPE_SWITCH(                                        \
        (coo)->indices.dtype(), IdType, op, {{__VA_ARGS__}});         \
  })

} // namespace sparse
} // namespace dgl

#endif // DGL_SPARSE_MACRO_H_
11 changes: 11 additions & 0 deletions dgl_sparse/src/matrix_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include <sparse/matrix_ops.h>
#include <torch/script.h>

#include "./macro.h"
#include "./matrix_ops_impl.h"

namespace dgl {
namespace sparse {

Expand Down Expand Up @@ -55,5 +58,13 @@ std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
return {ret_coo, lhs_indices, rhs_indices};
}

/**
 * @brief Compact a sparse matrix along dimension @p dim, dispatching on the
 *        device and index dtype of the matrix's COO representation.
 *
 * Note: the parameter must be `int64_t dim` to match the declaration in
 * sparse/matrix_ops.h — the previous `uint64_t dim` produced a different
 * signature, leaving the declared function undefined at link time.
 *
 * @param mat The sparse matrix to be compacted.
 * @param dim Dimension to compact (0 for rows, 1 for columns).
 * @param leading_indices Row/column ids placed first in the compacted dim.
 * @return Tuple of the compacted matrix and the new-to-original index map.
 */
std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> Compact(
    const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
    torch::Tensor leading_indices) {
  DGL_SPARSE_COO_SWITCH(mat->COOPtr(), XPU, IdType, "Compact", {
    return CompactImpl<XPU, IdType>(mat, dim, leading_indices);
  });
}

} // namespace sparse
} // namespace dgl
16 changes: 15 additions & 1 deletion dgl_sparse/src/matrix_ops_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,22 @@
#ifndef DGL_SPARSE_MATRIX_OPS_IMPL_H_
#define DGL_SPARSE_MATRIX_OPS_IMPL_H_

#include <sparse/sparse_format.h>

#include <tuple>

namespace dgl {
namespace sparse {}
namespace sparse {

/**
 * @brief Device/dtype-specialized implementation behind Compact().
 *
 * Placeholder only: performs no compaction yet — it hands back the input
 * matrix and the leading indices unchanged.
 */
template <c10::DeviceType XPU, typename IdType>
std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> CompactImpl(
    const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
    torch::Tensor leading_indices) {
  return std::make_tuple(mat, leading_indices);
}

} // namespace sparse
} // namespace dgl

#endif // DGL_SPARSE_MATRIX_OPS_IMPL_H_

0 comments on commit 52356a2

Please sign in to comment.