Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi npu op. #3031

Merged
merged 5 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ MMCV 提供了检测、分割等任务中常用的算子
| Deformable RoIPool | | √ | √ | | √ |
| DiffIoURotated | | √ | √ | | |
| DynamicScatter | | √ | √ | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |
| FusedBiasLeakyrelu | | √ | | | √ |
| GatherPoints | | √ | | | √ |
| GroupPoints | | √ | | | |
Expand Down
23 changes: 23 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void furthest_point_sampling_forward_npu(Tensor points_tensor,
Tensor temp_tensor, Tensor idx_tensor,
int b, int n, int m) {
TORCH_CHECK(
(points_tensor.sizes()[1] >= m),
"the num of sampled points should smaller than total num of points.");
at::Tensor points_xyz = points_tensor.transpose(1, 2).contiguous();
at::Tensor nearest_dist = temp_tensor.contiguous();
EXEC_NPU_CMD(aclnnFurthestPointSampling, points_xyz, nearest_dist, m,
idx_tensor);
}

void furthest_point_sampling_forward_impl(Tensor points_tensor,
Tensor temp_tensor, Tensor idx_tensor,
int b, int n, int m);

REGISTER_NPU_IMPL(furthest_point_sampling_forward_impl,
furthest_point_sampling_forward_npu);
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;

void furthest_point_sampling_with_dist_npu(Tensor points_tensor,
Tensor temp_tensor,
Tensor idx_tensor, int b, int n,
int m) {
auto output_size = {b, m};
at::Tensor result =
at::empty(output_size, points_tensor.options().dtype(at::kInt));
EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor,
m, result);
}

void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor,
Tensor temp_tensor,
Tensor idx_tensor, int b,
int n, int m);

REGISTER_NPU_IMPL(furthest_point_sampling_with_dist_forward_impl,
furthest_point_sampling_with_dist_npu);
16 changes: 7 additions & 9 deletions mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,14 @@ void ms_deform_attn_impl_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
const int im2col_step);
Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);

void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc,
Tensor &grad_attn_weight, const int im2col_step) {
void ms_deform_attn_backward_npu(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
const int im2col_step) {
check_support(value, attn_weight);
at::Tensor value_fp32 = value;
at::Tensor spatial_shapes_int32 = spatial_shapes;
Expand Down
25 changes: 25 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;

void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep,
Tensor &keep_num,
float nms_overlap_thresh) {
int32_t box_num = boxes.size(0);
int32_t data_align = 16;
int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
at::Tensor mask =
at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort));
EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask);

keep = at::zeros({box_num}, mask.options());
keep_num = at::zeros(1, mask.options());
EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
}

void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num,
float nms_overlap_thresh);

REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl,
iou3d_nms3d_normal_forward_npu);
26 changes: 26 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

constexpr int32_t BOX_DIM = 7;

void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &keep_num,
float nms_overlap_thresh) {
TORCH_CHECK((boxes.sizes()[1] == BOX_DIM),
"Input boxes shape should be (N, 7)");
int32_t box_num = boxes.size(0);
int32_t data_align = 16;
int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
at::Tensor mask =
at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort));
EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask);
keep = at::zeros({box_num}, mask.options());
keep_num = at::zeros(1, mask.options());
EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
}

void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh);

REGISTER_NPU_IMPL(iou3d_nms3d_forward_impl, iou3d_nms3d_forward_npu);
19 changes: 19 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void points_in_boxes_part_forward_impl_npu(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points) {
c10::SmallVector<int64_t, 8> output_size = {pts.size(0), pts.size(1)};
auto boxes_trans = boxes.transpose(1, 2).contiguous();
EXEC_NPU_CMD(aclnnPointsInBox, boxes_trans, pts, box_idx_of_points);
}
void points_in_boxes_part_forward_impl(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points);
REGISTER_NPU_IMPL(points_in_boxes_part_forward_impl,
points_in_boxes_part_forward_impl_npu);
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void points_in_polygons_npu(const Tensor points, Tensor polygons, Tensor output,
"The batch of polygons tensor must be less than MAX_POLYGONS_BATCH");
at::Tensor trans_polygons = polygons.transpose(0, 1);
OpCommand cmd;
at::Tensor new_trans_polygons = NpuUtils::format_contiguous(trans_polygons);
at::Tensor new_trans_polygons = trans_polygons.contiguous();
cmd.Name("PointsInPolygons")
.Input(points, (string) "points")
.Input(new_trans_polygons, (string) "polygons")
Expand Down
8 changes: 6 additions & 2 deletions mmcv/ops/furthest_point_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ def forward(ctx, points_xyz: torch.Tensor,
assert points_xyz.is_contiguous()

B, N = points_xyz.size()[:2]
output = torch.cuda.IntTensor(B, num_points)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
if points_xyz.device.type == 'npu':
output = torch.IntTensor(B, num_points).npu()
temp = torch.FloatTensor(B, N).fill_(1e10).npu()
else:
output = torch.cuda.IntTensor(B, num_points)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

ext_module.furthest_point_sampling_forward(
points_xyz,
Expand Down
39 changes: 28 additions & 11 deletions tests/test_ops/test_furthest_point_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,55 @@
import torch

from mmcv.ops import furthest_point_sample, furthest_point_sample_with_dist
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_fps():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_fps(device):
xyz = torch.tensor([[[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681],
[-0.8070, 2.4137,
-0.5845], [-1.0001, 2.1982, -0.5859],
[0.3841, 1.8983, -0.7431]],
[[-1.0696, 3.0758,
-0.1899], [-0.2559, 3.5521, -0.1402],
[0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205],
[-0.0518, 3.7251, -0.3950]]]).cuda()
[-0.0518, 3.7251, -0.3950]]]).to(device)

idx = furthest_point_sample(xyz, 3)
expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda()
expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).to(device)
assert torch.all(idx == expected_idx)


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_fps_with_dist():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_fps_with_dist(device):
xyz = torch.tensor([[[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681],
[-0.8070, 2.4137,
-0.5845], [-1.0001, 2.1982, -0.5859],
[0.3841, 1.8983, -0.7431]],
[[-1.0696, 3.0758,
-0.1899], [-0.2559, 3.5521, -0.1402],
[0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205],
[-0.0518, 3.7251, -0.3950]]]).cuda()
[-0.0518, 3.7251, -0.3950]]]).to(device)

expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda()
expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).to(device)
xyz_square_dist = ((xyz.unsqueeze(dim=1) -
xyz.unsqueeze(dim=2))**2).sum(-1)
idx = furthest_point_sample_with_dist(xyz_square_dist, 3)
Expand All @@ -44,7 +61,7 @@ def test_fps_with_dist():
fps_idx = np.load('tests/data/for_3d_ops/fps_idx.npy')
features_for_fps_distance = np.load(
'tests/data/for_3d_ops/features_for_fps_distance.npy')
expected_idx = torch.from_numpy(fps_idx).cuda()
expected_idx = torch.from_numpy(fps_idx).to(device)
features_for_fps_distance = torch.from_numpy(
features_for_fps_distance).cuda()

Expand Down
14 changes: 11 additions & 3 deletions tests/test_ops/test_iou3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from mmcv.ops import boxes_iou3d, boxes_overlap_bev, nms3d, nms3d_normal
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.parametrize('device', [
Expand Down Expand Up @@ -77,7 +77,11 @@ def test_boxes_iou3d(device):
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_nms3d(device):
# test for 5 boxes
Expand Down Expand Up @@ -116,7 +120,11 @@ def test_nms3d(device):
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_nms3d_normal(device):
# test for 5 boxes
Expand Down
30 changes: 19 additions & 11 deletions tests/test_ops/test_roiaware_pool3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.parametrize('device', [
Expand Down Expand Up @@ -56,38 +56,46 @@ def test_RoIAwarePool3d(device, dtype):
torch.tensor(49.750, dtype=dtype).to(device), 1e-3)


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_boxes_part():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_points_in_boxes_part(device):
boxes = torch.tensor(
[[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3]],
[[-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]],
dtype=torch.float32).cuda(
) # boxes (b, t, 7) with bottom center in lidar coordinate
dtype=torch.float32).to(
device) # boxes (b, t, 7) with bottom center in lidar coordinate
pts = torch.tensor(
[[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6],
[0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3],
[4.7, 3.5, -12.2]],
[[3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], [-21.3, -52, -5],
[0, 0, 0], [6, 7, 8], [-2, -3, -4], [6, 4, 9]]],
dtype=torch.float32).cuda() # points (b, m, 3) in lidar coordinate
dtype=torch.float32).to(device) # points (b, m, 3) in lidar coordinate

point_indices = points_in_boxes_part(points=pts, boxes=boxes)
expected_point_indices = torch.tensor(
[[0, 0, 0, 0, 0, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1]],
dtype=torch.int32).cuda()
dtype=torch.int32).to(device)
assert point_indices.shape == torch.Size([2, 8])
assert (point_indices == expected_point_indices).all()

boxes = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 20.0, 1.0, 0.523598]]],
dtype=torch.float32).cuda() # 30 degrees
dtype=torch.float32).to(device) # 30 degrees
pts = torch.tensor(
[[[4, 6.928, 0], [6.928, 4, 0], [4, -6.928, 0], [6.928, -4, 0],
[-4, 6.928, 0], [-6.928, 4, 0], [-4, -6.928, 0], [-6.928, -4, 0]]],
dtype=torch.float32).cuda()
dtype=torch.float32).to(device)
point_indices = points_in_boxes_part(points=pts, boxes=boxes)
expected_point_indices = torch.tensor([[-1, -1, 0, -1, 0, -1, -1, -1]],
dtype=torch.int32).cuda()
dtype=torch.int32).to(device)
assert (point_indices == expected_point_indices).all()


Expand Down
Loading