diff --git a/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp b/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp
new file mode 100644
index 0000000000..0f56ca0230
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp
@@ -0,0 +1,66 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void group_points_forward_npu(int b, int c, int n, int npoints, int nsample,
+                              const Tensor points, const Tensor idx,
+                              Tensor out) {
+  // b, c, n and npoints do not need to be passed to GatherV2 explicitly;
+  // they are inferred inside the operator from the input shapes.
+  // The Ascend GatherV2 operator requires axis = 0 and batch_dims = 0.
+  c10::SmallVector<int64_t, N> axis = {0};
+  int64_t batch_dims = 0;
+
+  auto index = at::arange(0, b);
+  index = index.to(points.device());
+  index = index.view({-1, 1, 1});
+  index = at::mul(index, n);
+  at::Tensor indices = at::add(index, idx);
+  indices = indices.view({-1});
+
+  at::Tensor trans_features = points.transpose(1, 2);
+  at::Tensor features = trans_features.contiguous();
+  features = features.view({b * n, c});
+
+  OpCommand cmd;
+  cmd.Name("GatherV2")
+      .Input(features)
+      .Input(indices)
+      .Input(axis)
+      .Output(out)
+      .Attr("batch_dims", batch_dims)
+      .Run();
+
+  at::Tensor output =
+      out.view({b, npoints, nsample, c}).transpose(1, 3).transpose(2, 3);
+  at::Tensor res = output.contiguous();
+  out.copy_(res);
+}
+
+void group_points_backward_npu(int b, int c, int n, int npoints, int nsample,
+                               const Tensor grad_out, const Tensor idx,
+                               Tensor grad_features) {
+  at::Tensor trans_idx = idx.view({b * npoints * nsample});
+  at::Tensor trans_grad_out = grad_out.permute({0, 2, 3, 1});
+  at::Tensor grad_out_tensor = trans_grad_out.contiguous();
+  grad_out_tensor = grad_out_tensor.view({b * npoints * nsample, c});
+  at::Tensor out = at::zeros({b, n, c}, grad_out.options());
+
+  EXEC_NPU_CMD(aclnnGroupPointsGrad, grad_out_tensor, trans_idx, b, c, n,
+               npoints, nsample, out);
+
+  at::Tensor grad_points = out.transpose(1, 2);
+
+  grad_features.copy_(grad_points);
+}
+
+void group_points_forward_impl(int b, int c, int n, int npoints, int nsample,
+                               const Tensor points, const Tensor idx,
+                               Tensor out);
+void group_points_backward_impl(int b, int c, int n, int npoints, int nsample,
+                                const Tensor points, const Tensor idx,
+                                Tensor out);
+
+REGISTER_NPU_IMPL(group_points_forward_impl, group_points_forward_npu);
+REGISTER_NPU_IMPL(group_points_backward_impl, group_points_backward_npu);
diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py
index 8109540cea..511e6418f1 100644
--- a/tests/test_ops/test_group_points.py
+++ b/tests/test_ops/test_group_points.py
@@ -3,16 +3,25 @@
 import torch
 
 from mmcv.ops import grouping_operation
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
 @pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
-def test_grouping_points(dtype):
+def test_grouping_points(dtype, device):
     idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0],
                          [0, 0, 0]],
                         [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
-                         [0, 0, 0]]]).int().cuda()
+                         [0, 0, 0]]]).int().to(device)
     features = torch.tensor([[[
         0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
         0.9268, 0.8414
@@ -37,9 +46,12 @@ def test_grouping_points(dtype):
         -0.6646, -0.6870, -0.1125, -0.2224, -0.3445, -1.4049, 0.4990, -0.7037,
         -0.9924, 0.0386
     ]]],
-                            dtype=dtype).cuda()
+                            dtype=dtype).to(device)
+    features.requires_grad = True
 
     output = grouping_operation(features, idx)
+    output.backward(output)
+    grad_features = features.grad
     expected_output = torch.tensor(
         [[[[0.5798, 0.5798, 0.5798], [-1.3311, -1.3311, -1.3311],
           [0.9268, 0.9268, 0.9268], [0.5798, 0.5798, 0.5798],
@@ -59,17 +71,54 @@ def test_grouping_points(dtype):
         [[-0.6646, -0.6646, -0.6646], [0.4990, 0.4990, 0.4990],
          [0.0386, 0.0386, 0.0386], [-0.6646, -0.6646, -0.6646],
          [-0.6646, -0.6646, -0.6646], [-0.6646, -0.6646, -0.6646]]]],
-        dtype=dtype).cuda()
+        dtype=dtype).to(device)
+    expected_grad_features = torch.tensor(
+        [[[
+            6.9576, 0.0000, 0.0000, -3.9933, 0.0000, 0.0000, 0.0000, 0.0000,
+            2.7804, 0.0000
+        ],
+          [
+              65.0964, 0.0000, 0.0000, 4.4220, 0.0000, 0.0000, 0.0000, 0.0000,
+              6.4743, 0.0000
+          ],
+          [
+              -19.5192, 0.0000, 0.0000, -5.0793, 0.0000, 0.0000, 0.0000,
+              0.0000, -5.0358, 0.0000
+          ]],
+         [[
+             -0.4560, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.1079, 0.0000,
+             0.0000, -5.5581
+         ],
+          [
+              14.1276, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 18.2595, 0.0000,
+              0.0000, 8.4687
+          ],
+          [
+              -7.9752, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4970, 0.0000,
+              0.0000, 0.1158
+          ]]],
+        dtype=dtype).to(device)
     assert torch.allclose(output, expected_output)
+    assert torch.allclose(grad_features, expected_grad_features)
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
 @pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
-def test_stack_grouping_points(dtype):
+def test_stack_grouping_points(dtype, device):
+    if device == 'npu' and dtype == torch.double:
+        return
     idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0],
                         [2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0],
-                        [1, 1, 1], [0, 0, 0]]).int().cuda()
+                        [1, 1, 1], [0, 0, 0]]).int().to(device)
     features = torch.tensor([[
         0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
         0.9268, 0.8414
@@ -94,9 +143,9 @@ def test_stack_grouping_points(dtype):
         -0.6646, -0.6870, -0.1125, -0.2224, -0.3445, -1.4049, 0.4990, -0.7037,
         -0.9924, 0.0386
     ]],
-                            dtype=dtype).cuda()
-    features_batch_cnt = torch.tensor([3, 3]).int().cuda()
-    indices_batch_cnt = torch.tensor([6, 6]).int().cuda()
+                            dtype=dtype).to(device)
+    features_batch_cnt = torch.tensor([3, 3]).int().to(device)
+    indices_batch_cnt = torch.tensor([6, 6]).int().to(device)
     output = grouping_operation(features, idx, features_batch_cnt,
                                 indices_batch_cnt)
     expected_output = torch.tensor(
@@ -160,5 +209,5 @@ def test_stack_grouping_points(dtype):
         [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
         [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
         [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]]],
-        dtype=dtype).cuda()
+        dtype=dtype).to(device)
     assert torch.allclose(output, expected_output)
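
For reference, a minimal PyTorch-only sketch (not part of the patch; the toy sizes and variable names simply mirror the arguments above) of the index-flattening trick used by group_points_forward_npu: offsetting each batch's indices by arange(b) * n lets a single gather over the flattened (b * n, c) features reproduce the per-batch grouping that GatherV2 performs on the NPU.

import torch

# Toy sizes only; any b, c, n, npoints, nsample with idx values in [0, n) work.
b, c, n, npoints, nsample = 2, 3, 10, 6, 3
points = torch.randn(b, c, n)                     # (b, c, n) input features
idx = torch.randint(0, n, (b, npoints, nsample))  # per-batch grouping indices

# Offset each batch's indices into the flattened point dimension.
indices = (torch.arange(b).view(-1, 1, 1) * n + idx).view(-1)
features = points.transpose(1, 2).contiguous().view(b * n, c)

# Gather along dim 0, then restore the (b, c, npoints, nsample) layout,
# mirroring the view/transpose sequence in the NPU forward above.
out = features[indices].view(b, npoints, nsample, c)
out = out.permute(0, 3, 1, 2).contiguous()        # (b, c, npoints, nsample)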