Skip to content

Commit

Permalink
add multi npu op
Browse files Browse the repository at this point in the history
  • Loading branch information
huaweiZJX authored and momo609 committed Apr 23, 2024
1 parent 98393a3 commit 56cda46
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 29 deletions.
4 changes: 2 additions & 2 deletions docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ MMCV 提供了检测、分割等任务中常用的算子
| Deformable RoIPool | ||| ||
| DiffIoURotated | ||| | |
| DynamicScatter | ||| | |
| FurthestPointSample | || | | |
| FurthestPointSampleWithDist | || | | |
| FurthestPointSample | || | | |
| FurthestPointSampleWithDist | || | | |
| FusedBiasLeakyrelu | || | ||
| GatherPoints | || | ||
| GroupPoints | || | | |
Expand Down
19 changes: 19 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void furthest_point_sampling_forward_npu(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor,
int b, int n, int m) {
TORCH_CHECK(
(points_tensor.sizes()[1] >= m),
"the num of sampled points needs to be smaller than total num of points.");
at::Tensor points_xyz = points_tensor.transpose(1, 2).contiguous();
at::Tensor nearest_dist = temp_tensor.contiguous();
EXEC_NPU_CMD(aclnnFurthestPointSampling, points_xyz, nearest_dist, m, idx_tensor);
}

void furthest_point_sampling_forward_impl(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor,
int b, int n, int m);

REGISTER_NPU_IMPL(furthest_point_sampling_forward_impl, furthest_point_sampling_forward_npu);
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;

void furthest_point_sampling_with_dist_npu(Tensor points_tensor,
Tensor temp_tensor,
Tensor idx_tensor, int b, int n,
int m) {
auto output_size = {b, m};
at::Tensor result = at::empty(output_size, points_tensor.options().dtype(at::kInt));
EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor, m, result);
}

void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor,
Tensor temp_tensor,
Tensor idx_tensor, int b, int n,
int m);

REGISTER_NPU_IMPL(furthest_point_sampling_with_dist_forward_impl,
furthest_point_sampling_with_dist_npu);
21 changes: 21 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;

void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh) {
int32_t box_num = boxes.size(0);
int32_t data_align = 16;
int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
at::Tensor mask = at::empty({ box_num, mask_num }, boxes.options().dtype(at::kShort));
EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask);

keep = at::zeros({ box_num }, mask.options());
keep_num = at::zeros(1, mask.options());
EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
}

void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh);

REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl, iou3d_nms3d_normal_forward_npu);
26 changes: 26 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

constexpr int32_t BOX_DIM = 7;

void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh)
{
TORCH_CHECK((boxes.sizes()[1] == BOX_DIM), "Input boxes shape should be (N, 7)");
int32_t box_num = boxes.size(0);
int32_t data_align = 16;
int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
at::Tensor mask = at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort));
EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask);

keep = at::zeros({box_num}, mask.options());
keep_num = at::zeros(1, mask.options());
EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
}

void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh);

REGISTER_NPU_IMPL(iou3d_nms3d_forward_impl, iou3d_nms3d_forward_npu);
8 changes: 6 additions & 2 deletions mmcv/ops/furthest_point_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ def forward(ctx, points_xyz: torch.Tensor,
assert points_xyz.is_contiguous()

B, N = points_xyz.size()[:2]
output = torch.cuda.IntTensor(B, num_points)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
if points_xyz.device.type == 'npu':
output = torch.IntTensor(B, num_points).npu()
temp = torch.FloatTensor(B, N).fill_(1e10).npu()
else:
output = torch.cuda.IntTensor(B, num_points)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

ext_module.furthest_point_sampling_forward(
points_xyz,
Expand Down
39 changes: 28 additions & 11 deletions tests/test_ops/test_furthest_point_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,55 @@
import torch

from mmcv.ops import furthest_point_sample, furthest_point_sample_with_dist
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_fps():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_fps(device):
xyz = torch.tensor([[[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681],
[-0.8070, 2.4137,
-0.5845], [-1.0001, 2.1982, -0.5859],
[0.3841, 1.8983, -0.7431]],
[[-1.0696, 3.0758,
-0.1899], [-0.2559, 3.5521, -0.1402],
[0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205],
[-0.0518, 3.7251, -0.3950]]]).cuda()
[-0.0518, 3.7251, -0.3950]]]).to(device)

idx = furthest_point_sample(xyz, 3)
expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda()
expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).to(device)
assert torch.all(idx == expected_idx)


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_fps_with_dist():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_fps_with_dist(device):
xyz = torch.tensor([[[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681],
[-0.8070, 2.4137,
-0.5845], [-1.0001, 2.1982, -0.5859],
[0.3841, 1.8983, -0.7431]],
[[-1.0696, 3.0758,
-0.1899], [-0.2559, 3.5521, -0.1402],
[0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205],
[-0.0518, 3.7251, -0.3950]]]).cuda()
[-0.0518, 3.7251, -0.3950]]]).to(device)

expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda()
expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).to(device)
xyz_square_dist = ((xyz.unsqueeze(dim=1) -
xyz.unsqueeze(dim=2))**2).sum(-1)
idx = furthest_point_sample_with_dist(xyz_square_dist, 3)
Expand All @@ -44,7 +61,7 @@ def test_fps_with_dist():
fps_idx = np.load('tests/data/for_3d_ops/fps_idx.npy')
features_for_fps_distance = np.load(
'tests/data/for_3d_ops/features_for_fps_distance.npy')
expected_idx = torch.from_numpy(fps_idx).cuda()
expected_idx = torch.from_numpy(fps_idx).to(device)
features_for_fps_distance = torch.from_numpy(
features_for_fps_distance).cuda()

Expand Down
14 changes: 11 additions & 3 deletions tests/test_ops/test_iou3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from mmcv.ops import boxes_iou3d, boxes_overlap_bev, nms3d, nms3d_normal
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.parametrize('device', [
Expand Down Expand Up @@ -77,7 +77,11 @@ def test_boxes_iou3d(device):
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_nms3d(device):
# test for 5 boxes
Expand Down Expand Up @@ -116,7 +120,11 @@ def test_nms3d(device):
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_nms3d_normal(device):
# test for 5 boxes
Expand Down
30 changes: 19 additions & 11 deletions tests/test_ops/test_roiaware_pool3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.parametrize('device', [
Expand Down Expand Up @@ -56,38 +56,46 @@ def test_RoIAwarePool3d(device, dtype):
torch.tensor(49.750, dtype=dtype).to(device), 1e-3)


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_boxes_part():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_points_in_boxes_part(device):
boxes = torch.tensor(
[[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3]],
[[-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]],
dtype=torch.float32).cuda(
) # boxes (b, t, 7) with bottom center in lidar coordinate
dtype=torch.float32).to(
device) # boxes (b, t, 7) with bottom center in lidar coordinate
pts = torch.tensor(
[[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6],
[0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3],
[4.7, 3.5, -12.2]],
[[3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], [-21.3, -52, -5],
[0, 0, 0], [6, 7, 8], [-2, -3, -4], [6, 4, 9]]],
dtype=torch.float32).cuda() # points (b, m, 3) in lidar coordinate
dtype=torch.float32).to(device) # points (b, m, 3) in lidar coordinate

point_indices = points_in_boxes_part(points=pts, boxes=boxes)
expected_point_indices = torch.tensor(
[[0, 0, 0, 0, 0, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1]],
dtype=torch.int32).cuda()
dtype=torch.int32).to(device)
assert point_indices.shape == torch.Size([2, 8])
assert (point_indices == expected_point_indices).all()

boxes = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 20.0, 1.0, 0.523598]]],
dtype=torch.float32).cuda() # 30 degrees
dtype=torch.float32).to(device) # 30 degrees
pts = torch.tensor(
[[[4, 6.928, 0], [6.928, 4, 0], [4, -6.928, 0], [6.928, -4, 0],
[-4, 6.928, 0], [-6.928, 4, 0], [-4, -6.928, 0], [-6.928, -4, 0]]],
dtype=torch.float32).cuda()
dtype=torch.float32).to(device)
point_indices = points_in_boxes_part(points=pts, boxes=boxes)
expected_point_indices = torch.tensor([[-1, -1, 0, -1, 0, -1, -1, -1]],
dtype=torch.int32).cuda()
dtype=torch.int32).to(device)
assert (point_indices == expected_point_indices).all()


Expand Down

0 comments on commit 56cda46

Please sign in to comment.