diff --git a/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp b/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp index f52789bbcc..0f56ca0230 100644 --- a/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp @@ -38,8 +38,29 @@ void group_points_forward_npu(int b, int c, int n, int npoints, int nsample, out.copy_(res); } +void group_points_backward_npu(int b, int c, int n, int npoints, int nsample, + const Tensor grad_out, const Tensor idx, + Tensor grad_features) { + at::Tensor trans_idx = idx.view({b * npoints * nsample}); + at::Tensor trans_grad_out = grad_out.permute({0, 2, 3, 1}); + at::Tensor grad_out_tensor = trans_grad_out.contiguous(); + grad_out_tensor = grad_out_tensor.view({b * npoints * nsample, c}); + at::Tensor out = at::zeros({b, n, c}, grad_out.options()); + + EXEC_NPU_CMD(aclnnGroupPointsGrad, grad_out_tensor, trans_idx, b, c, n, + npoints, nsample, out); + + at::Tensor grad_points = out.transpose(1, 2); + + grad_features.copy_(grad_points); +} + void group_points_forward_impl(int b, int c, int n, int npoints, int nsample, const Tensor points, const Tensor idx, Tensor out); +void group_points_backward_impl(int b, int c, int n, int npoints, int nsample, + const Tensor points, const Tensor idx, + Tensor out); REGISTER_NPU_IMPL(group_points_forward_impl, group_points_forward_npu); +REGISTER_NPU_IMPL(group_points_backward_impl, group_points_backward_npu); diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py index e8eaa66164..511e6418f1 100644 --- a/tests/test_ops/test_group_points.py +++ b/tests/test_ops/test_group_points.py @@ -47,8 +47,11 @@ def test_grouping_points(dtype, device): -1.4049, 0.4990, -0.7037, -0.9924, 0.0386 ]]], dtype=dtype).to(device) + features.requires_grad = True output = grouping_operation(features, idx) + output.backward(output) + grad_features = features.grad expected_output = torch.tensor( [[[[0.5798, 0.5798, 0.5798], [-1.3311, -1.3311, -1.3311], [0.9268, 0.9268, 0.9268], [0.5798, 0.5798, 0.5798], @@ -69,7 +72,34 @@ def test_grouping_points(dtype, device): [0.0386, 0.0386, 0.0386], [-0.6646, -0.6646, -0.6646], [-0.6646, -0.6646, -0.6646], [-0.6646, -0.6646, -0.6646]]]], dtype=dtype).to(device) + expected_grad_features = torch.tensor( + [[[ + 6.9576, 0.0000, 0.0000, -3.9933, 0.0000, 0.0000, 0.0000, 0.0000, + 2.7804, 0.0000 + ], + [ + 65.0964, 0.0000, 0.0000, 4.4220, 0.0000, 0.0000, 0.0000, 0.0000, + 6.4743, 0.0000 + ], + [ + -19.5192, 0.0000, 0.0000, -5.0793, 0.0000, 0.0000, 0.0000, + 0.0000, -5.0358, 0.0000 + ]], + [[ + -0.4560, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.1079, 0.0000, + 0.0000, -5.5581 + ], + [ + 14.1276, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 18.2595, 0.0000, + 0.0000, 8.4687 + ], + [ + -7.9752, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4970, 0.0000, + 0.0000, 0.1158 + ]]], + dtype=dtype).to(device) assert torch.allclose(output, expected_output) + assert torch.allclose(grad_features, expected_grad_features) @pytest.mark.parametrize('device', [