Skip to content

Commit

Permalink
add npu op
Browse files Browse the repository at this point in the history
  • Loading branch information
momo609 committed Apr 15, 2024
1 parent a5c38d2 commit 8a8c8ad
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 2 deletions.
19 changes: 19 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

/**
 * NPU backend of furthest point sampling.
 *
 * Validates that dimension 1 of `points_tensor` is at least `m`, prepares
 * contiguous operands and hands them to the aclnnFurthestPointSampling kernel
 * through EXEC_NPU_CMD.
 *
 * @param points_tensor input points; dimensions 1 and 2 are swapped before the
 *                      kernel call (presumably (B, N, 3) -> (B, 3, N); confirm
 *                      against the caller).
 * @param temp_tensor   scratch tensor, forwarded in contiguous form.
 * @param idx_tensor    last kernel operand; presumably receives the sampled
 *                      indices (semantics defined by the aclnn op).
 * @param b             unused here; kept for signature compatibility.
 * @param n             unused here; kept for signature compatibility.
 * @param m             number of points to sample.
 */
void furthest_point_sampling_forward_npu(Tensor points_tensor,
                                         Tensor temp_tensor,
                                         Tensor idx_tensor, int b, int n,
                                         int m) {
  // Note: the condition accepts m equal to the size of dimension 1.
  TORCH_CHECK(
      (points_tensor.sizes()[1] >= m),
      "the num of sampled points needs to be smaller than total num of points.");

  // The kernel is given dense copies/views of both inputs.
  at::Tensor dist_buffer = temp_tensor.contiguous();
  at::Tensor xyz_transposed = points_tensor.transpose(1, 2).contiguous();

  EXEC_NPU_CMD(aclnnFurthestPointSampling, xyz_transposed, dist_buffer, m,
               idx_tensor);
}

// Declaration only: the body of furthest_point_sampling_forward_impl is not in
// this file. It is declared here so it can be named in the registration below.
void furthest_point_sampling_forward_impl(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor,
int b, int n, int m);

// Pairs the `_impl` entry point with furthest_point_sampling_forward_npu.
// REGISTER_NPU_IMPL is provided by pytorch_npu_helper.hpp (not visible here);
// presumably it registers the NPU function as the device backend -- confirm
// against the macro definition.
REGISTER_NPU_IMPL(furthest_point_sampling_forward_impl, furthest_point_sampling_forward_npu);
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;

/**
 * NPU backend of furthest point sampling driven by a precomputed distance
 * input.
 *
 * Forwards the operands to the aclnnFurthestPointSamplingWithDist kernel
 * through EXEC_NPU_CMD, with `idx_tensor` as the output operand.
 *
 * @param points_tensor first kernel operand (presumably the pairwise distance
 *                      tensor; confirm against the caller).
 * @param temp_tensor   second kernel operand (scratch buffer).
 * @param idx_tensor    caller-provided output tensor. Assumed to be shaped
 *                      (b, m) with a 32-bit integer dtype -- TODO confirm
 *                      against the caller.
 * @param b             unused here; kept for signature compatibility.
 * @param n             unused here; kept for signature compatibility.
 * @param m             number of points to sample.
 */
void furthest_point_sampling_with_dist_npu(Tensor points_tensor,
                                           Tensor temp_tensor,
                                           Tensor idx_tensor, int b, int n,
                                           int m) {
  // The function returns void, so the only way the caller can observe the
  // result is through idx_tensor. Writing into a function-local tensor would
  // discard the kernel output when the function returns.
  EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor,
               m, idx_tensor);
}

// Declaration only: the body of furthest_point_sampling_with_dist_forward_impl
// is not in this file. It is declared here so it can be named in the
// registration below.
void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor,
Tensor temp_tensor,
Tensor idx_tensor, int b, int n,
int m);

// Pairs the `_impl` entry point with furthest_point_sampling_with_dist_npu.
// REGISTER_NPU_IMPL is provided by pytorch_npu_helper.hpp (not visible here);
// presumably it registers the NPU function as the device backend -- confirm
// against the macro definition.
REGISTER_NPU_IMPL(furthest_point_sampling_with_dist_forward_impl,
furthest_point_sampling_with_dist_npu);
21 changes: 21 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;

/**
 * NPU backend of the "normal" 3D NMS variant.
 *
 * Runs in two kernel stages: aclnnNms3dNormal fills a per-box mask, then
 * aclnnGatherNms3dMask reduces that mask into `keep` / `keep_num`.
 *
 * @param boxes              input boxes; dimension 0 is taken as the box count.
 * @param keep               rebound to a new zero-filled int16 tensor of
 *                           length box count, then passed to the gather kernel.
 * @param keep_num           rebound to a new zero-filled int16 tensor with one
 *                           element, then passed to the gather kernel.
 * @param nms_overlap_thresh overlap threshold forwarded to the kernel.
 *
 * NOTE(review): `keep` and `keep_num` are re-assigned rather than written in
 * place, so any tensor the caller pre-allocated is replaced (and the dtype
 * becomes int16). Confirm that callers read the rebound handles.
 */
void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep,
                                    Tensor &keep_num, float nms_overlap_thresh) {
  constexpr int32_t kMaskAlign = 16;
  const int32_t num_boxes = boxes.size(0);
  // Round the mask width up to a multiple of kMaskAlign. Because integer
  // division truncates toward zero, zero boxes still yields a width of 16.
  const int32_t mask_width = ((num_boxes - 1) / kMaskAlign + 1) * kMaskAlign;

  // Stage 1: pairwise suppression mask, stored as int16.
  at::Tensor mask = at::empty({num_boxes, mask_width},
                              boxes.options().dtype(at::kShort));
  EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask);

  // Stage 2: fresh outputs sharing the mask's tensor options.
  const auto mask_options = mask.options();
  keep = at::zeros({num_boxes}, mask_options);
  keep_num = at::zeros(1, mask_options);
  EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
}

// Declaration only: the body of iou3d_nms3d_normal_forward_impl is not in this
// file. It is declared here so it can be named in the registration below.
void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh);

// Pairs the `_impl` entry point with iou3d_nms3d_normal_forward_npu.
// REGISTER_NPU_IMPL is provided by pytorch_npu_helper.hpp (not visible here);
// presumably it registers the NPU function as the device backend -- confirm
// against the macro definition.
REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl, iou3d_nms3d_normal_forward_npu);
26 changes: 26 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

// Required size of dimension 1 of the `boxes` input.
constexpr int32_t BOX_DIM = 7;

/**
 * NPU backend of 3D NMS.
 *
 * Checks that dimension 1 of `boxes` equals BOX_DIM, then runs two kernel
 * stages: aclnnNms3d fills a per-box mask and aclnnGatherNms3dMask reduces
 * that mask into `keep` / `keep_num`.
 *
 * @param boxes              input boxes, expected shape (N, 7).
 * @param keep               rebound to a new zero-filled int16 tensor of
 *                           length N, then passed to the gather kernel.
 * @param keep_num           rebound to a new zero-filled int16 tensor with one
 *                           element, then passed to the gather kernel.
 * @param nms_overlap_thresh overlap threshold forwarded to the kernel.
 *
 * NOTE(review): `keep` and `keep_num` are re-assigned rather than written in
 * place, so any tensor the caller pre-allocated is replaced (and the dtype
 * becomes int16). Confirm that callers read the rebound handles.
 */
void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep,
                             Tensor &keep_num, float nms_overlap_thresh) {
  TORCH_CHECK((boxes.sizes()[1] == BOX_DIM), "Input boxes shape should be (N, 7)");

  constexpr int32_t kMaskAlign = 16;
  const int32_t num_boxes = boxes.size(0);
  // Round the mask width up to a multiple of kMaskAlign. Because integer
  // division truncates toward zero, zero boxes still yields a width of 16.
  const int32_t mask_width = ((num_boxes - 1) / kMaskAlign + 1) * kMaskAlign;

  // Stage 1: pairwise suppression mask, stored as int16.
  at::Tensor mask = at::empty({num_boxes, mask_width},
                              boxes.options().dtype(at::kShort));
  EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask);

  // Stage 2: fresh outputs sharing the mask's tensor options.
  const auto mask_options = mask.options();
  keep = at::zeros({num_boxes}, mask_options);
  keep_num = at::zeros(1, mask_options);
  EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
}

// Declaration only: the body of iou3d_nms3d_forward_impl is not in this file.
// It is declared here so it can be named in the registration below.
void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh);

// Pairs the `_impl` entry point with iou3d_nms3d_forward_npu.
// REGISTER_NPU_IMPL is provided by pytorch_npu_helper.hpp (not visible here);
// presumably it registers the NPU function as the device backend -- confirm
// against the macro definition.
REGISTER_NPU_IMPL(iou3d_nms3d_forward_impl, iou3d_nms3d_forward_npu);
18 changes: 18 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

/**
 * NPU backend of the "part" points-in-boxes query.
 *
 * Swaps dimensions 1 and 2 of `boxes`, makes the result contiguous and
 * forwards it with `pts` and `box_idx_of_points` to the aclnnPointsInBox
 * kernel through EXEC_NPU_CMD.
 *
 * @param batch_size        unused here; kept for signature compatibility.
 * @param boxes_num         unused here; kept for signature compatibility.
 * @param pts_num           unused here; kept for signature compatibility.
 * @param boxes             input boxes; must have at least 3 dimensions for
 *                          the transpose(1, 2) below.
 * @param pts               input points, forwarded unchanged.
 * @param box_idx_of_points last kernel operand; presumably receives the box
 *                          index per point (semantics defined by the aclnn op).
 */
void points_in_boxes_part_forward_impl_npu(int batch_size, int boxes_num,
                                           int pts_num, const Tensor boxes,
                                           const Tensor pts,
                                           Tensor box_idx_of_points) {
  // The kernel is given the boxes with dimensions 1 and 2 swapped, densely
  // laid out. No output shape needs to be computed here: the output tensor is
  // supplied by the caller.
  auto boxes_trans = boxes.transpose(1, 2).contiguous();
  EXEC_NPU_CMD(aclnnPointsInBox, boxes_trans, pts, box_idx_of_points);
}
// Declaration only: the body of points_in_boxes_part_forward_impl is not in
// this file. It is declared here so it can be named in the registration below.
void points_in_boxes_part_forward_impl(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points);
// Pairs the `_impl` entry point with points_in_boxes_part_forward_impl_npu.
// REGISTER_NPU_IMPL is provided by pytorch_npu_helper.hpp (not visible here);
// presumably it registers the NPU function as the device backend -- confirm
// against the macro definition.
REGISTER_NPU_IMPL(points_in_boxes_part_forward_impl, points_in_boxes_part_forward_impl_npu);
8 changes: 6 additions & 2 deletions mmcv/ops/furthest_point_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ def forward(ctx, points_xyz: torch.Tensor,
assert points_xyz.is_contiguous()

B, N = points_xyz.size()[:2]
output = torch.cuda.IntTensor(B, num_points)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
if points_xyz.device.type == 'npu':
output = torch.IntTensor(B, num_points).npu()
temp = torch.FloatTensor(B, N).fill_(1e10).npu()
else:
output = torch.cuda.IntTensor(B, num_points)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

ext_module.furthest_point_sampling_forward(
points_xyz,
Expand Down

0 comments on commit 8a8c8ad

Please sign in to comment.