[DO NOT MERGE][Sparse] Support SpSpMul #5464
base: master
Conversation
To trigger regression tests:
shared by both matrices and the non-zero value from the first matrix at each
coordinate. The indices tensor shows the indices of the common coordinates
based on the first matrix.
*/
Add * at the beginning of each row
 * @return SparseMatrix
 */
c10::intrusive_ptr<SparseMatrix> SpSpMul(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
Make the naming consistent --> use A & B?
static variable_list forward(
    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
    torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
    torch::Tensor rhs_val);
ditto.
A & B, A_val, B_val
std::tie(lhs_intersect_rhs, lhs_indices) =
    SparseMatrixIntersection(lhs_mat, lhs_val, rhs_mat);
std::tie(rhs_intersect_lhs, rhs_indices) =
    SparseMatrixIntersection(rhs_mat, rhs_val, lhs_intersect_rhs);
Why do we need to call intersection twice?
Can we simplify it to just one intersection calculation?
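For intuition on the reviewer's question: when both matrices are duplicate-free, the common coordinates and the matching value positions on *both* sides can be obtained in one pass. A minimal NumPy sketch (not the actual `SparseMatrixIntersection` implementation; COO layout and non-empty, duplicate-free indices are assumed) of elementwise multiply via a single intersection:

```python
import numpy as np

def intersect_coo(idx_a, val_a, idx_b, val_b):
    """Elementwise multiply of two COO matrices at their common coordinates.

    idx_*: (2, nnz) int arrays of row/col coordinates (duplicate-free).
    val_*: (nnz,) value arrays.
    Returns (idx, val) of the product at the shared coordinates.
    """
    # Linearize (row, col) pairs into single integer keys.
    ncols = 1 + max(idx_a[1].max(), idx_b[1].max())
    key_a = idx_a[0] * ncols + idx_a[1]
    key_b = idx_b[0] * ncols + idx_b[1]
    # One intersection yields positions into both operands at once.
    common, pos_a, pos_b = np.intersect1d(key_a, key_b, return_indices=True)
    idx = np.stack([common // ncols, common % ncols])
    return idx, val_a[pos_a] * val_b[pos_b]
```

Whether the C++ helper can be collapsed the same way depends on what else `SparseMatrixIntersection` returns (e.g. indices needed for the backward pass), so this is only a sketch of the simplification being asked about.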
Generally LGTM. Two comments.
}
TORCH_CHECK(
    !lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),
    "Only support SpSpMul on sparse matrices without duplicate values")
The HasDuplicate() check is costly, as it requires sorting and a linear scan (on GPU it will also incur CPU-GPU synchronization). I understand that SpSpMul should not support matrices with duplicate entries. My question is: what is the general best practice for handling those cases? I see three options:
- Use a heavy check like the code here.
- Try to design a light check.
- Say this is an undefined behavior but make sure the operation will not crash.
cc @frozenbugs
@@ -225,18 +233,36 @@ def test_sub_sparse_diag(val_shape):
    assert torch.allclose(dense_diff, -diff4)

@pytest.mark.parametrize("op", ["mul", "truediv", "pow"])
Looks like the new test omits the cases for "truediv" and "pow". Have they been covered in other test cases?
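If "truediv" and "pow" are indeed uncovered, each parametrized case could be backed by a dense reference check at the shared coordinates. A self-contained NumPy sketch (a stand-in for such a test, not the dgl.sparse API; `check_elementwise_op` is a hypothetical name):

```python
import numpy as np

def check_elementwise_op(op):
    """Dense reference for elementwise sparse ops at shared coordinates."""
    a = np.array([[2.0, 6.0], [0.0, 3.0]])
    b = np.array([[4.0, 0.0], [0.0, 5.0]])
    fn = {"mul": np.multiply, "truediv": np.true_divide, "pow": np.power}[op]
    mask = (a != 0) & (b != 0)  # coordinates present in both matrices
    result = fn(a[mask], b[mask])
    expected = {"mul": [8.0, 15.0],
                "truediv": [0.5, 0.6],
                "pow": [16.0, 243.0]}[op]
    assert np.allclose(result, expected)

for op in ("mul", "truediv", "pow"):
    check_elementwise_op(op)
```

In the actual suite this would slot into the `@pytest.mark.parametrize("op", ...)` pattern shown in the diff above, with the sparse result compared against the dense computation.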
Description
Resolve #5368.
Checklist
Changes