[Sparse] Support SpSpMul

czkkkkkk committed Mar 17, 2023
1 parent 5cab423 commit 81705d5

Showing 7 changed files with 192 additions and 16 deletions.
16 changes: 13 additions & 3 deletions dgl_sparse/include/sparse/elementwise_op.h
@@ -11,10 +11,8 @@
namespace dgl {
namespace sparse {

// TODO(zhenkun): support addition of matrices with different sparsity.
/**
* @brief Adds two sparse matrices. Currently does not support two matrices with
* different sparsity.
* @brief Adds two sparse matrices possibly with different sparsities.
*
* @param A SparseMatrix
* @param B SparseMatrix
@@ -25,6 +23,18 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const c10::intrusive_ptr<SparseMatrix>& A,
const c10::intrusive_ptr<SparseMatrix>& B);

/**
* @brief Multiplies two sparse matrices possibly with different sparsities.
*
* @param lhs_mat SparseMatrix
* @param rhs_mat SparseMatrix
*
* @return SparseMatrix
*/
c10::intrusive_ptr<SparseMatrix> SpSpMul(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat);

} // namespace sparse
} // namespace dgl

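As a quick orientation before the implementation diff, here is a minimal Python-level usage sketch of the new operator, written against the dgl.sparse API that this commit's tests exercise (from_coo, dglsp.mul, to_dense); the matrices and values below are illustrative only and are not part of the commit.

import torch
import dgl.sparse as dglsp

# Two 3x3 matrices with different sparsity patterns.
A = dglsp.from_coo(
    row=torch.tensor([0, 1, 2]),
    col=torch.tensor([1, 0, 2]),
    val=torch.tensor([1.0, 2.0, 3.0]),
    shape=(3, 3),
)
B = dglsp.from_coo(
    row=torch.tensor([0, 2]),
    col=torch.tensor([1, 2]),
    val=torch.tensor([10.0, 20.0]),
    shape=(3, 3),
)

# Elementwise product; only coordinates present in both matrices survive.
C = dglsp.mul(A, B)  # equivalently A * B
print(C.to_dense())  # non-zeros at (0, 1) = 10.0 and (2, 2) = 60.0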
114 changes: 114 additions & 0 deletions dgl_sparse/src/elemenwise_op.cc
@@ -18,6 +18,8 @@
namespace dgl {
namespace sparse {

using namespace torch::autograd;

c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const c10::intrusive_ptr<SparseMatrix>& A,
const c10::intrusive_ptr<SparseMatrix>& B) {
@@ -32,5 +34,117 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
return SparseMatrix::FromCOO(sum.indices(), sum.values(), A->shape());
}

class SpSpMulAutoGrad : public Function<SpSpMulAutoGrad> {
public:
static variable_list forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
torch::Tensor rhs_val);

static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};

/**
* @brief Compute the intersection of the non-zero coordinates between two
* sparse matrices.
* @return Sparse matrix and indices tensor. The matrix contains the
* coordinates shared by both matrices and the non-zero value from the first
* matrix at each coordinate. The indices tensor shows the indices of the
* common coordinates based on the first matrix.
*/
std::pair<c10::intrusive_ptr<SparseMatrix>, torch::Tensor>
SparseMatrixIntersection(
c10::intrusive_ptr<SparseMatrix> lhs_mat, torch::Tensor lhs_val,
c10::intrusive_ptr<SparseMatrix> rhs_mat) {
auto lhs_dgl_coo = COOToOldDGLCOO(lhs_mat->COOPtr());
torch::Tensor rhs_row, rhs_col;
std::tie(rhs_row, rhs_col) = rhs_mat->COOTensors();
auto rhs_dgl_row = TorchTensorToDGLArray(rhs_row);
auto rhs_dgl_col = TorchTensorToDGLArray(rhs_col);
auto dgl_results =
aten::COOGetDataAndIndices(lhs_dgl_coo, rhs_dgl_row, rhs_dgl_col);
auto ret_row = DGLArrayToTorchTensor(dgl_results[0]);
auto ret_col = DGLArrayToTorchTensor(dgl_results[1]);
auto ret_indices = DGLArrayToTorchTensor(dgl_results[2]);
auto ret_val = lhs_mat->value().index_select(0, ret_indices);
auto ret_mat = SparseMatrix::FromCOO(
torch::stack({ret_row, ret_col}), ret_val, lhs_mat->shape());
return {ret_mat, ret_indices};
}
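To make the intersection semantics described above concrete, the following is a hedged pure-PyTorch sketch of the same computation on COO coordinates. The helper name coo_intersection is hypothetical and the logic is dense-tensor illustration only; the function above instead dispatches to the DGL kernel aten::COOGetDataAndIndices.

import torch

def coo_intersection(lhs_row, lhs_col, lhs_val, rhs_row, rhs_col, num_cols):
    # Encode each (row, col) coordinate as a single scalar key.
    lhs_keys = lhs_row * num_cols + lhs_col
    rhs_keys = rhs_row * num_cols + rhs_col
    # Keep the lhs non-zeros whose coordinate also appears in rhs.
    mask = torch.isin(lhs_keys, rhs_keys)
    indices = mask.nonzero(as_tuple=True)[0]  # positions w.r.t. the lhs matrix
    # Shared coordinates, the lhs values at them, and their lhs-based indices.
    return lhs_row[indices], lhs_col[indices], lhs_val[indices], indices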

variable_list SpSpMulAutoGrad::forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
torch::Tensor rhs_val) {
c10::intrusive_ptr<SparseMatrix> lhs_intersect_rhs, rhs_intersect_lhs;
torch::Tensor lhs_indices, rhs_indices;
std::tie(lhs_intersect_rhs, lhs_indices) =
SparseMatrixIntersection(lhs_mat, lhs_val, rhs_mat);
std::tie(rhs_intersect_lhs, rhs_indices) =
SparseMatrixIntersection(rhs_mat, rhs_val, lhs_intersect_rhs);
auto ret_mat = SparseMatrix::ValLike(
lhs_intersect_rhs,
lhs_intersect_rhs->value() * rhs_intersect_lhs->value());

ctx->saved_data["lhs_require_grad"] = lhs_val.requires_grad();
ctx->saved_data["rhs_require_grad"] = rhs_val.requires_grad();
if (lhs_val.requires_grad()) {
ctx->saved_data["lhs_val_shape"] = lhs_val.sizes().vec();
ctx->saved_data["rhs_intersect_lhs"] = rhs_intersect_lhs;
ctx->saved_data["lhs_indices"] = lhs_indices;
}
if (rhs_val.requires_grad()) {
ctx->saved_data["rhs_val_shape"] = rhs_val.sizes().vec();
ctx->saved_data["lhs_intersect_rhs"] = lhs_intersect_rhs;
ctx->saved_data["rhs_indices"] = rhs_indices;
}
return {ret_mat->Indices(), ret_mat->value()};
}

tensor_list SpSpMulAutoGrad::backward(
AutogradContext* ctx, tensor_list grad_outputs) {
torch::Tensor lhs_val_grad, rhs_val_grad;
auto output_grad = grad_outputs[1];
if (ctx->saved_data["lhs_require_grad"].toBool()) {
auto rhs_intersect_lhs =
ctx->saved_data["rhs_intersect_lhs"].toCustomClass<SparseMatrix>();
const auto& lhs_val_shape = ctx->saved_data["lhs_val_shape"].toIntVector();
auto lhs_indices = ctx->saved_data["lhs_indices"].toTensor();
lhs_val_grad = torch::zeros(lhs_val_shape, output_grad.options());
auto intersect_grad = rhs_intersect_lhs->value() * output_grad;
lhs_val_grad.index_put_({lhs_indices}, intersect_grad);
}
if (ctx->saved_data["rhs_require_grad"].toBool()) {
auto lhs_intersect_rhs =
ctx->saved_data["lhs_intersect_rhs"].toCustomClass<SparseMatrix>();
const auto& rhs_val_shape = ctx->saved_data["rhs_val_shape"].toIntVector();
auto rhs_indices = ctx->saved_data["rhs_indices"].toTensor();
rhs_val_grad = torch::zeros(rhs_val_shape, output_grad.options());
auto intersect_grad = lhs_intersect_rhs->value() * output_grad;
rhs_val_grad.index_put_({rhs_indices}, intersect_grad);
}
return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};
}
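The backward pass above is the product rule restricted to the shared non-zero coordinates: the gradient of each lhs value is the co-located rhs value times the incoming gradient, scattered back with index_put_, and zero where the coordinates do not intersect. A tiny dense PyTorch check of that rule (values made up for illustration):

import torch

# lhs non-zero values, and rhs values sampled at the same coordinates
# (0.0 marks a coordinate that rhs does not share).
lhs_val = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
rhs_at_lhs_coords = torch.tensor([4.0, 0.0, 6.0])

out = lhs_val * rhs_at_lhs_coords
out.backward(torch.ones_like(out))
print(lhs_val.grad)  # tensor([4., 0., 6.]) -- the co-located rhs values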

c10::intrusive_ptr<SparseMatrix> SpSpMul(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
return SparseMatrix::FromDiagPointer(
lhs_mat->DiagPtr(), lhs_mat->value() * rhs_mat->value(),
lhs_mat->shape());
}
TORCH_CHECK(
!lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),
"Only support SpSpMul on sparse matrices without duplicate values")
auto results = SpSpMulAutoGrad::apply(
lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
const auto& indices = results[0];
const auto& val = results[1];
return SparseMatrix::FromCOO(indices, val, lhs_mat->shape());
}

} // namespace sparse
} // namespace dgl
1 change: 1 addition & 0 deletions dgl_sparse/src/python_binding.cc
@@ -39,6 +39,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("from_csc", &SparseMatrix::FromCSC)
.def("from_diag", &SparseMatrix::FromDiag)
.def("spsp_add", &SpSpAdd)
.def("spsp_mul", &SpSpMul)
.def("reduce", &Reduce)
.def("sum", &ReduceSum)
.def("smean", &ReduceMean)
3 changes: 3 additions & 0 deletions dgl_sparse/src/sparse_matrix_coalesce.cc
@@ -23,6 +23,9 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Coalesce() {

bool SparseMatrix::HasDuplicate() {
aten::CSRMatrix dgl_csr;
if (HasDiag()) {
return false;
}
// The format for calculation will be chosen in the following order: CSR,
// CSC. CSR is created if the sparse matrix only has CSC format.
if (HasCSR() || !HasCSC()) {
19 changes: 8 additions & 11 deletions python/dgl/sparse/elementwise_op_sp.py
@@ -14,6 +14,13 @@ def spsp_add(A, B):
)


def spsp_mul(A, B):
"""Invoke C++ sparse library for multiplication"""
return SparseMatrix(
torch.ops.dgl_sparse.spsp_mul(A.c_sparse_matrix, B.c_sparse_matrix)
)


def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
"""Elementwise addition
@@ -119,17 +126,7 @@ def sp_mul(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
"""
if is_scalar(B):
return val_like(A, A.val * B)
if A.is_diag() and B.is_diag():
assert A.shape == B.shape, (
f"The shape of diagonal matrix A {A.shape} and B {B.shape} must"
f"match for elementwise multiplication."
)
return diag(A.val * B.val, A.shape)
# Python falls back to B.__rmul__(A) then TypeError when NotImplemented is
# returned.
# So this also handles the case of scalar * SparseMatrix since we set
# SparseMatrix.__rmul__ to be the same as SparseMatrix.__mul__.
return NotImplemented
return spsp_mul(A, B)


def sp_div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
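For reference, a short sketch of how sp_mul dispatches after this change, using the operators and API seen elsewhere in this diff (matrices illustrative only): scalars keep the existing val_like fast path, while SparseMatrix operands are now routed to the new spsp_mul.

import torch
import dgl.sparse as dglsp

A = dglsp.from_coo(
    row=torch.tensor([0, 1]),
    col=torch.tensor([1, 2]),
    val=torch.tensor([1.0, 2.0]),
    shape=(3, 3),
)
B = dglsp.from_coo(
    row=torch.tensor([0, 2]),
    col=torch.tensor([1, 0]),
    val=torch.tensor([3.0, 4.0]),
    shape=(3, 3),
)

A * 2.0  # scalar path: val_like(A, A.val * 2.0), sparsity unchanged
A * B    # sparse path: spsp_mul -> torch.ops.dgl_sparse.spsp_mul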
44 changes: 43 additions & 1 deletion tests/python/pytorch/sparse/test_elementwise_op.py
@@ -6,7 +6,14 @@
import pytest
import torch

from dgl.sparse import diag, power
from dgl.sparse import diag, power, val_like
from .utils import (
rand_coo,
rand_csc,
rand_csr,
rand_diag,
sparse_matrix_to_dense,
)


@pytest.mark.parametrize("opname", ["add", "sub", "mul", "truediv"])
@@ -240,3 +247,38 @@ def test_error_op_sparse_diag(op):
getattr(operator, op)(A, D)
with pytest.raises(TypeError):
getattr(operator, op)(D, A)


@pytest.mark.parametrize(
"create_func1", [rand_coo, rand_csr, rand_csc, rand_diag]
)
@pytest.mark.parametrize(
"create_func2", [rand_coo, rand_csr, rand_csc, rand_diag]
)
@pytest.mark.parametrize("shape", [(5, 5), (5, 3)])
@pytest.mark.parametrize("nnz1", [5, 15])
@pytest.mark.parametrize("nnz2", [1, 14])
@pytest.mark.parametrize("nz_dim", [None, 3])
def test_spspmul(create_func1, create_func2, shape, nnz1, nnz2, nz_dim):
dev = F.ctx()
A = create_func1(shape, nnz1, dev, nz_dim)
B = create_func2(shape, nnz2, dev, nz_dim)
C = dglsp.mul(A, B)
assert not C.has_duplicate()

DA = sparse_matrix_to_dense(A)
DB = sparse_matrix_to_dense(B)
DC = DA * DB

grad = torch.rand_like(C.val)
C.val.backward(grad)
DC_grad = sparse_matrix_to_dense(val_like(C, grad))
DC.backward(DC_grad)

assert torch.allclose(sparse_matrix_to_dense(C), DC, atol=1e-05)
assert torch.allclose(
val_like(A, A.val.grad).to_dense(), DA.grad, atol=1e-05
)
assert torch.allclose(
val_like(B, B.val.grad).to_dense(), DB.grad, atol=1e-05
)
11 changes: 10 additions & 1 deletion tests/python/pytorch/sparse/utils.py
@@ -1,7 +1,7 @@
import numpy as np
import torch

from dgl.sparse import from_coo, from_csc, from_csr, SparseMatrix
from dgl.sparse import diag, from_coo, from_csc, from_csr, SparseMatrix

np.random.seed(42)
torch.random.manual_seed(42)
@@ -64,6 +64,15 @@ def rand_csc(shape, nnz, dev, nz_dim=None):
return from_csc(indptr, indices, val, shape=shape)


def rand_diag(shape, nnz, dev, nz_dim=None):
nnz = min(shape)
if nz_dim is None:
val = torch.randn(nnz, device=dev, requires_grad=True)
else:
val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)
return diag(val, shape)


def rand_coo_uncoalesced(shape, nnz, dev):
# Create a sparse matrix with possible duplicate entries.
row = torch.randint(shape[0], (nnz,), device=dev)
