[DO NOT MERGE][Sparse] Support SpSpMul #5464

Open. Wants to merge 1 commit into base: master.
16 changes: 13 additions & 3 deletions dgl_sparse/include/sparse/elementwise_op.h
@@ -11,10 +11,8 @@
namespace dgl {
namespace sparse {

// TODO(zhenkun): support addition of matrices with different sparsity.
/**
* @brief Adds two sparse matrices. Currently does not support two matrices with
* different sparsity.
* @brief Adds two sparse matrices possibly with different sparsities.
*
* @param A SparseMatrix
* @param B SparseMatrix
@@ -25,6 +23,18 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const c10::intrusive_ptr<SparseMatrix>& A,
const c10::intrusive_ptr<SparseMatrix>& B);

/**
* @brief Multiplies two sparse matrices possibly with different sparsities.
*
* @param A SparseMatrix
* @param B SparseMatrix
*
* @return SparseMatrix
*/
c10::intrusive_ptr<SparseMatrix> SpSpMul(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
Collaborator: Make it consistent --> use A & B?

const c10::intrusive_ptr<SparseMatrix>& rhs_mat);

} // namespace sparse
} // namespace dgl

114 changes: 114 additions & 0 deletions dgl_sparse/src/elementwise_op.cc
@@ -18,6 +18,8 @@
namespace dgl {
namespace sparse {

using namespace torch::autograd;

c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const c10::intrusive_ptr<SparseMatrix>& A,
const c10::intrusive_ptr<SparseMatrix>& B) {
@@ -32,5 +34,117 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
return SparseMatrix::FromCOO(sum.indices(), sum.values(), A->shape());
}

class SpSpMulAutoGrad : public Function<SpSpMulAutoGrad> {
public:
static variable_list forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
torch::Tensor rhs_val);
Collaborator: ditto. A & B, A_val, B_val.


static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};

/**
* @brief Compute the intersection of the non-zero coordinates between two
sparse matrices.
* @return Sparse matrix and indices tensor. The matrix contains the coordinates
shared by both matrices and the non-zero value from the first matrix at each
coordinate. The indices tensor shows the indices of the common coordinates
based on the first matrix.
*/
Collaborator: Add * at the beginning of each row.

std::pair<c10::intrusive_ptr<SparseMatrix>, torch::Tensor>
SparseMatrixIntersection(
c10::intrusive_ptr<SparseMatrix> lhs_mat, torch::Tensor lhs_val,
c10::intrusive_ptr<SparseMatrix> rhs_mat) {
auto lhs_dgl_coo = COOToOldDGLCOO(lhs_mat->COOPtr());
torch::Tensor rhs_row, rhs_col;
std::tie(rhs_row, rhs_col) = rhs_mat->COOTensors();
auto rhs_dgl_row = TorchTensorToDGLArray(rhs_row);
auto rhs_dgl_col = TorchTensorToDGLArray(rhs_col);
auto dgl_results =
aten::COOGetDataAndIndices(lhs_dgl_coo, rhs_dgl_row, rhs_dgl_col);
auto ret_row = DGLArrayToTorchTensor(dgl_results[0]);
auto ret_col = DGLArrayToTorchTensor(dgl_results[1]);
auto ret_indices = DGLArrayToTorchTensor(dgl_results[2]);
auto ret_val = lhs_mat->value().index_select(0, ret_indices);
auto ret_mat = SparseMatrix::FromCOO(
torch::stack({ret_row, ret_col}), ret_val, lhs_mat->shape());
return {ret_mat, ret_indices};
}
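
As a rough illustration of the intersection semantics (a torch-only sketch, not the actual implementation, which goes through aten::COOGetDataAndIndices; coordinates and values are made up for the example):

import torch

# lhs has non-zeros at (0, 0), (1, 1), (2, 2); rhs at (0, 0), (2, 2), (2, 3).
lhs_row, lhs_col = torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2])
lhs_val = torch.tensor([1.0, 2.0, 3.0])
rhs_row, rhs_col = torch.tensor([0, 2, 2]), torch.tensor([0, 2, 3])
num_cols = 4

# Encode each (row, col) pair as a scalar key and keep the lhs entries whose
# coordinate also appears in rhs.
lhs_key = lhs_row * num_cols + lhs_col
rhs_key = rhs_row * num_cols + rhs_col
ret_indices = torch.nonzero(torch.isin(lhs_key, rhs_key)).flatten()  # tensor([0, 2])
ret_row, ret_col = lhs_row[ret_indices], lhs_col[ret_indices]
ret_val = lhs_val[ret_indices]  # values come from the first matrix

The returned indices are positions into the first matrix's value tensor; forward() below saves them so that backward() can scatter gradients back with index_put_.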

variable_list SpSpMulAutoGrad::forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
torch::Tensor rhs_val) {
c10::intrusive_ptr<SparseMatrix> lhs_intersect_rhs, rhs_intersect_lhs;
torch::Tensor lhs_indices, rhs_indices;
std::tie(lhs_intersect_rhs, lhs_indices) =
SparseMatrixIntersection(lhs_mat, lhs_val, rhs_mat);
std::tie(rhs_intersect_lhs, rhs_indices) =
SparseMatrixIntersection(rhs_mat, rhs_val, lhs_intersect_rhs);
Collaborator: Why do we need to call intersection twice? Can we simplify it to just one intersection calculation?

auto ret_mat = SparseMatrix::ValLike(
lhs_intersect_rhs,
lhs_intersect_rhs->value() * rhs_intersect_lhs->value());

ctx->saved_data["lhs_require_grad"] = lhs_val.requires_grad();
ctx->saved_data["rhs_require_grad"] = rhs_val.requires_grad();
if (lhs_val.requires_grad()) {
ctx->saved_data["lhs_val_shape"] = lhs_val.sizes().vec();
ctx->saved_data["rhs_intersect_lhs"] = rhs_intersect_lhs;
ctx->saved_data["lhs_indices"] = lhs_indices;
}
if (rhs_val.requires_grad()) {
ctx->saved_data["rhs_val_shape"] = rhs_val.sizes().vec();
ctx->saved_data["lhs_intersect_rhs"] = lhs_intersect_rhs;
ctx->saved_data["rhs_indices"] = rhs_indices;
}
return {ret_mat->Indices(), ret_mat->value()};
}

tensor_list SpSpMulAutoGrad::backward(
AutogradContext* ctx, tensor_list grad_outputs) {
torch::Tensor lhs_val_grad, rhs_val_grad;
auto output_grad = grad_outputs[1];
if (ctx->saved_data["lhs_require_grad"].toBool()) {
auto rhs_intersect_lhs =
ctx->saved_data["rhs_intersect_lhs"].toCustomClass<SparseMatrix>();
const auto& lhs_val_shape = ctx->saved_data["lhs_val_shape"].toIntVector();
auto lhs_indices = ctx->saved_data["lhs_indices"].toTensor();
lhs_val_grad = torch::zeros(lhs_val_shape, output_grad.options());
auto intersect_grad = rhs_intersect_lhs->value() * output_grad;
lhs_val_grad.index_put_({lhs_indices}, intersect_grad);
}
if (ctx->saved_data["rhs_require_grad"].toBool()) {
auto lhs_intersect_rhs =
ctx->saved_data["lhs_intersect_rhs"].toCustomClass<SparseMatrix>();
const auto& rhs_val_shape = ctx->saved_data["rhs_val_shape"].toIntVector();
auto rhs_indices = ctx->saved_data["rhs_indices"].toTensor();
rhs_val_grad = torch::zeros(rhs_val_shape, output_grad.options());
auto intersect_grad = lhs_intersect_rhs->value() * output_grad;
rhs_val_grad.index_put_({rhs_indices}, intersect_grad);
}
return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};
}
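
In words: the gradient w.r.t. each input's values is the other matrix's value at every shared coordinate times the output gradient, scattered into a zero tensor of that input's value shape, so stored entries outside the intersection receive zero gradient. A dense torch sketch of the same rule, with zero entries standing in for missing coordinates (illustrative only):

import torch

A_val = torch.tensor([[1.0, 0.0], [0.0, 3.0]])   # stored entries of A
B_val = torch.tensor([[10.0, 0.0], [5.0, 0.0]])  # stored entries of B
grad_C = torch.ones(2, 2)                        # incoming gradient on A * B

# dL/dA: B * grad_C, masked to A's own sparsity pattern.
grad_A = (A_val != 0).float() * B_val * grad_C
# tensor([[10., 0.], [0., 0.]]): A's entry at (1, 1) gets 0 because B has no entry there.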

c10::intrusive_ptr<SparseMatrix> SpSpMul(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
return SparseMatrix::FromDiagPointer(
lhs_mat->DiagPtr(), lhs_mat->value() * rhs_mat->value(),
lhs_mat->shape());
}
TORCH_CHECK(
!lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),
"Only support SpSpMul on sparse matrices without duplicate values")
Member: HasDuplicate() is costly, as it requires sorting and a linear scan (on GPU it will also incur a CPU-GPU synchronization). I understand that SpSpMul shall not support matrices with duplicate entries. My question is: what is the general best practice for handling those cases? I see three options:

  1. Use a heavy check like the code here.
  2. Try to design a light check.
  3. Say this is undefined behavior but make sure the operation will not crash.

cc @frozenbugs

auto results = SpSpMulAutoGrad::apply(
lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
const auto& indices = results[0];
const auto& val = results[1];
return SparseMatrix::FromCOO(indices, val, lhs_mat->shape());
}

} // namespace sparse
} // namespace dgl
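
For reference, the duplicate-entry restriction surfaces at the Python level roughly as follows (a sketch reusing the API names from this PR's tests; a failed TORCH_CHECK normally shows up as a RuntimeError in PyTorch):

import torch
import dgl.sparse as dglsp

row = torch.tensor([0, 0, 1])
col = torch.tensor([1, 1, 2])  # the coordinate (0, 1) appears twice
A = dglsp.from_coo(row, col, torch.ones(3), shape=(2, 3))
assert A.has_duplicate()

B = dglsp.from_coo(torch.tensor([0]), torch.tensor([1]), torch.ones(1), shape=(2, 3))
# Expected to fail the check above rather than silently multiply:
# dglsp.mul(A, B)
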
1 change: 1 addition & 0 deletions dgl_sparse/src/python_binding.cc
@@ -39,6 +39,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("from_csc", &SparseMatrix::FromCSC)
.def("from_diag", &SparseMatrix::FromDiag)
.def("spsp_add", &SpSpAdd)
.def("spsp_mul", &SpSpMul)
.def("reduce", &Reduce)
.def("sum", &ReduceSum)
.def("smean", &ReduceMean)
3 changes: 3 additions & 0 deletions dgl_sparse/src/sparse_matrix_coalesce.cc
@@ -23,6 +23,9 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Coalesce() {

bool SparseMatrix::HasDuplicate() {
aten::CSRMatrix dgl_csr;
if (HasDiag()) {
return false;
}
// The format for calculation will be chosen in the following order: CSR,
// CSC. CSR is created if the sparse matrix only has CSC format.
if (HasCSR() || !HasCSC()) {
19 changes: 8 additions & 11 deletions python/dgl/sparse/elementwise_op_sp.py
@@ -14,6 +14,13 @@ def spsp_add(A, B):
    )


def spsp_mul(A, B):
    """Invoke C++ sparse library for multiplication"""
    return SparseMatrix(
        torch.ops.dgl_sparse.spsp_mul(A.c_sparse_matrix, B.c_sparse_matrix)
    )


def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
"""Elementwise addition
Expand Down Expand Up @@ -119,17 +126,7 @@ def sp_mul(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
"""
if is_scalar(B):
return val_like(A, A.val * B)
if A.is_diag() and B.is_diag():
assert A.shape == B.shape, (
f"The shape of diagonal matrix A {A.shape} and B {B.shape} must"
f"match for elementwise multiplication."
)
return diag(A.val * B.val, A.shape)
# Python falls back to B.__rmul__(A) then TypeError when NotImplemented is
# returned.
# So this also handles the case of scalar * SparseMatrix since we set
# SparseMatrix.__rmul__ to be the same as SparseMatrix.__mul__.
return NotImplemented
return spsp_mul(A, B)
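
For context, a minimal usage sketch of the new code path (API names as used in this PR's tests; values are illustrative):

import torch
import dgl.sparse as dglsp

# Two 3 x 4 matrices whose sparsity patterns only partially overlap.
A = dglsp.from_coo(
    torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2]),
    torch.tensor([1.0, 2.0, 3.0]), shape=(3, 4))
B = dglsp.from_coo(
    torch.tensor([0, 2, 2]), torch.tensor([0, 2, 3]),
    torch.tensor([10.0, 20.0, 30.0]), shape=(3, 4))

C = dglsp.mul(A, B)  # equivalently A * B; previously this raised TypeError
print(C.to_dense())
# tensor([[10.,  0.,  0.,  0.],
#         [ 0.,  0.,  0.,  0.],
#         [ 0.,  0., 60.,  0.]])
# Non-zeros survive only at the shared coordinates (0, 0) and (2, 2).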


def sp_div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
58 changes: 42 additions & 16 deletions tests/python/pytorch/sparse/test_elementwise_op.py
@@ -6,7 +6,15 @@
import pytest
import torch

from dgl.sparse import diag, power
from dgl.sparse import diag, power, val_like

from .utils import (
    rand_coo,
    rand_csc,
    rand_csr,
    rand_diag,
    sparse_matrix_to_dense,
)


@pytest.mark.parametrize("opname", ["add", "sub", "mul", "truediv"])
@@ -225,18 +233,36 @@ def test_sub_sparse_diag(val_shape):
    assert torch.allclose(dense_diff, -diff4)


@pytest.mark.parametrize("op", ["mul", "truediv", "pow"])
Member: Looks like the new test omits the cases for "truediv" and "pow". Have they been covered in other test cases?

def test_error_op_sparse_diag(op):
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape).to(ctx)
    A = dglsp.from_coo(row, col, val)

    shape = (3, 4)
    D = dglsp.diag(torch.randn(row.shape[0]).to(ctx), shape=shape)

    with pytest.raises(TypeError):
        getattr(operator, op)(A, D)
    with pytest.raises(TypeError):
        getattr(operator, op)(D, A)
@pytest.mark.parametrize(
    "create_func1", [rand_coo, rand_csr, rand_csc, rand_diag]
)
@pytest.mark.parametrize(
    "create_func2", [rand_coo, rand_csr, rand_csc, rand_diag]
)
@pytest.mark.parametrize("shape", [(5, 5), (5, 3)])
@pytest.mark.parametrize("nnz1", [5, 15])
@pytest.mark.parametrize("nnz2", [1, 14])
@pytest.mark.parametrize("nz_dim", [None, 3])
def test_spspmul(create_func1, create_func2, shape, nnz1, nnz2, nz_dim):
    dev = F.ctx()
    A = create_func1(shape, nnz1, dev, nz_dim)
    B = create_func2(shape, nnz2, dev, nz_dim)
    C = dglsp.mul(A, B)
    assert not C.has_duplicate()

    DA = sparse_matrix_to_dense(A)
    DB = sparse_matrix_to_dense(B)
    DC = DA * DB

    grad = torch.rand_like(C.val)
    C.val.backward(grad)
    DC_grad = sparse_matrix_to_dense(val_like(C, grad))
    DC.backward(DC_grad)

    assert torch.allclose(sparse_matrix_to_dense(C), DC, atol=1e-05)
    assert torch.allclose(
        val_like(A, A.val.grad).to_dense(), DA.grad, atol=1e-05
    )
    assert torch.allclose(
        val_like(B, B.val.grad).to_dense(), DB.grad, atol=1e-05
    )
11 changes: 10 additions & 1 deletion tests/python/pytorch/sparse/utils.py
@@ -1,7 +1,7 @@
import numpy as np
import torch

from dgl.sparse import from_coo, from_csc, from_csr, SparseMatrix
from dgl.sparse import diag, from_coo, from_csc, from_csr, SparseMatrix

np.random.seed(42)
torch.random.manual_seed(42)
@@ -64,6 +64,15 @@ def rand_csc(shape, nnz, dev, nz_dim=None):
    return from_csc(indptr, indices, val, shape=shape)


def rand_diag(shape, nnz, dev, nz_dim=None):
    # A diagonal matrix always has min(shape) non-zero entries, so the
    # requested nnz is overridden here.
    nnz = min(shape)
    if nz_dim is None:
        val = torch.randn(nnz, device=dev, requires_grad=True)
    else:
        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)
    return diag(val, shape)


def rand_coo_uncoalesced(shape, nnz, dev):
    # Create a sparse matrix with possible duplicate entries.
    row = torch.randint(shape[0], (nnz,), device=dev)