Skip to content

Commit

Permalink
[Feature] Add support for Ascend devices with gather_points (#2555)
Browse files Browse the repository at this point in the history
  • Loading branch information
zcc-zjut committed Jan 18, 2023
1 parent 4032b83 commit 84f60c1
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 43 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ We implement common ops used in detection, segmentation, etc.
| FurthestPointSample | || | | |
| FurthestPointSampleWithDist | || | | |
| FusedBiasLeakyrelu | || | ||
| GatherPoints | || | | |
| GatherPoints | || | | |
| GroupPoints | || | | |
| Iou3d | ||| | |
| KNN | || | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| FurthestPointSample | || | | |
| FurthestPointSampleWithDist | || | | |
| FusedBiasLeakyrelu | || | ||
| GatherPoints | || | | |
| GatherPoints | || | | |
| GroupPoints | || | | |
| Iou3d | ||| | |
| KNN | || | | |
Expand Down
29 changes: 29 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void gather_points_forward_npu(int b, int c, int n, int npoints,
const Tensor points, const Tensor idx,
Tensor out) {
// b, c, n, and npoints do not need to be passed into gatherv2,
// b, c, n, and npoints are calculated inside the operator
// gatherv2 operator in ascend needs to set axis to 2, batch_dims is 1
c10::SmallVector<int64_t, N> axis = {2};
int64_t batch_dims = 1;

OpCommand cmd;
cmd.Name("GatherV2")
.Input(points)
.Input(idx)
.Input(axis)
.Output(out)
.Attr("batch_dims", batch_dims)
.Run();
}

void gather_points_forward_impl(int b, int c, int n, int npoints,
const Tensor points, const Tensor idx,
Tensor out);

REGISTER_NPU_IMPL(gather_points_forward_impl, gather_points_forward_npu);
98 changes: 57 additions & 41 deletions tests/test_ops/test_gather_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,65 @@
import torch

from mmcv.ops import gather_points
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_gather_points():
features = torch.tensor([[[
-1.6095, -0.1029, -0.8876, -1.2447, -2.4031, 0.3708, -1.1586, -1.4967,
-0.4800, 0.2252
],
[
1.9138, 3.4979, 1.6854, 1.5631, 3.6776,
3.1154, 2.1705, 2.5221, 2.0411, 3.1446
],
[
-1.4173, 0.3073, -1.4339, -1.4340, -1.2770,
-0.2867, -1.4162, -1.4044, -1.4245, -1.4074
]],
[[
0.2160, 0.0842, 0.3661, -0.2749, -0.4909,
-0.6066, -0.8773, -0.0745, -0.9496, 0.1434
],
[
1.3644, 1.8087, 1.6855, 1.9563, 1.2746,
1.9662, 0.9566, 1.8778, 1.1437, 1.3639
],
[
-0.7172, 0.1692, 0.2241, 0.0721, -0.7540,
0.0462, -0.6227, 0.3223, -0.6944, -0.5294
]]]).cuda()
class TestGatherPoints:

idx = torch.tensor([[0, 1, 4, 0, 0, 0], [0, 5, 6, 0, 0, 0]]).int().cuda()
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_gather_points_all_close(self, device):
features = torch.tensor(
[[[
-1.6095, -0.1029, -0.8876, -1.2447, -2.4031, 0.3708, -1.1586,
-1.4967, -0.4800, 0.2252
],
[
1.9138, 3.4979, 1.6854, 1.5631, 3.6776, 3.1154, 2.1705,
2.5221, 2.0411, 3.1446
],
[
-1.4173, 0.3073, -1.4339, -1.4340, -1.2770, -0.2867, -1.4162,
-1.4044, -1.4245, -1.4074
]],
[[
0.2160, 0.0842, 0.3661, -0.2749, -0.4909, -0.6066, -0.8773,
-0.0745, -0.9496, 0.1434
],
[
1.3644, 1.8087, 1.6855, 1.9563, 1.2746, 1.9662, 0.9566,
1.8778, 1.1437, 1.3639
],
[
-0.7172, 0.1692, 0.2241, 0.0721, -0.7540, 0.0462, -0.6227,
0.3223, -0.6944, -0.5294
]]],
dtype=torch.float,
device=device)
idx = torch.tensor([[0, 1, 4, 0, 0, 0], [0, 5, 6, 0, 0, 0]],
dtype=torch.int32,
device=device)
output = gather_points(features, idx)
expected_output = torch.tensor(
[[[-1.6095, -0.1029, -2.4031, -1.6095, -1.6095, -1.6095],
[1.9138, 3.4979, 3.6776, 1.9138, 1.9138, 1.9138],
[-1.4173, 0.3073, -1.2770, -1.4173, -1.4173, -1.4173]],
[[0.2160, -0.6066, -0.8773, 0.2160, 0.2160, 0.2160],
[1.3644, 1.9662, 0.9566, 1.3644, 1.3644, 1.3644],
[-0.7172, 0.0462, -0.6227, -0.7172, -0.7172, -0.7172]]],
dtype=torch.float,
device=device)

output = gather_points(features, idx)
expected_output = torch.tensor(
[[[-1.6095, -0.1029, -2.4031, -1.6095, -1.6095, -1.6095],
[1.9138, 3.4979, 3.6776, 1.9138, 1.9138, 1.9138],
[-1.4173, 0.3073, -1.2770, -1.4173, -1.4173, -1.4173]],
[[0.2160, -0.6066, -0.8773, 0.2160, 0.2160, 0.2160],
[1.3644, 1.9662, 0.9566, 1.3644, 1.3644, 1.3644],
[-0.7172, 0.0462, -0.6227, -0.7172, -0.7172, -0.7172]]]).cuda()
assert torch.allclose(output, expected_output)

assert torch.allclose(output, expected_output)

# test fp16
output_half = gather_points(features.half(), idx)
assert torch.allclose(output_half, expected_output.half())
# test fp16
output_half = gather_points(features.half(), idx)
assert torch.allclose(output_half, expected_output.half())

0 comments on commit 84f60c1

Please sign in to comment.