diff --git a/dgl_sparse/src/sddmm.cc b/dgl_sparse/src/sddmm.cc index b32f1640de45..810583e01a99 100644 --- a/dgl_sparse/src/sddmm.cc +++ b/dgl_sparse/src/sddmm.cc @@ -61,8 +61,9 @@ void _SDDMMSanityCheck( mat1.dtype() == mat2.dtype(), "SDDMM: the two dense matrices should have the same dtype."); TORCH_CHECK( - mat1.device() == mat2.device(), - "SDDMM: the two dense matrices should on the same device."); + mat1.device() == mat2.device() && sparse_mat->device() == mat2.device(), + "SDDMM: the two dense matrices and sparse matrix should on the same " + "device."); } torch::Tensor SDDMMAutoGrad::forward(