Skip to content

Commit

Permalink
adapt npu_dynamic_voxelization & fix rotated_iou precision
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuweichen committed Jun 24, 2024
1 parent c94c723 commit 98c794d
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 4 deletions.
9 changes: 5 additions & 4 deletions mmcv/ops/box_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,11 @@ def box_iou_rotated(bboxes1: torch.Tensor,
bboxes1 = bboxes1 * flip_mat
bboxes2 = bboxes2 * flip_mat
if bboxes1.device.type == 'npu':
scale_mat = bboxes1.new_ones(bboxes1.shape[-1])
scale_mat[-1] = 1.0 / 0.01745329252
bboxes1 = bboxes1 * scale_mat
bboxes2 = bboxes2 * scale_mat
if (mode_flag == 1 or aligned or not clockwise):
scale_mat = bboxes1.new_ones(bboxes1.shape[-1])
scale_mat[-1] = 1.0 / 0.01745329252
bboxes1 = bboxes1 * scale_mat
bboxes2 = bboxes2 * scale_mat
bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous()
ext_module.box_iou_rotated(
Expand Down
8 changes: 8 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,

void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
if (mode_flag == 0 && aligned == false) {
auto trans = false;
auto is_clockwise = false;
auto need_iou = true;
EXEC_NPU_CMD(aclnnBoxesOverlapBev, boxes1, boxes2, trans, is_clockwise,
need_iou, ious);
return;
}
at::Tensor boxes = at::ones_like(boxes1);
at::Tensor query_boxes = at::ones_like(boxes2);
boxes = boxes1.transpose(0, 1).unsqueeze(0);
Expand Down
35 changes: 35 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
const int max_points, const int max_voxels,
const int NDim = 3);

// Forward declaration of the dynamic voxelization entry point; its definition
// is not in this file. The NPU implementation below is registered against it
// via REGISTER_NPU_IMPL.
//   points      - input point tensor
//   coors       - output tensor that receives the per-point voxel coordinates
//   voxel_size  - voxel extent per axis (x, y, z)
//   coors_range - {min_x, min_y, min_z, max_x, max_y, max_z}
//   NDim        - number of spatial dimensions (defaults to 3)
void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3);

int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
Expand Down Expand Up @@ -53,4 +58,34 @@ int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels,
return voxel_num_int;
}

/**
 * NPU implementation of dynamic voxelization.
 *
 * Transposes `points`, runs the `aclnnDynamicVoxelization` kernel on the
 * result, and copies the computed per-point voxel coordinates into `coors`.
 *
 * @param points      Input point tensor; dimension 0 is used as the number of
 *                    points (presumably shaped (N, C) -- confirm with caller).
 * @param coors       Output tensor; receives an (N, 3) int32 result through
 *                    `copy_`.
 * @param voxel_size  Voxel extent along x, y, z (elements 0..2 are read).
 * @param coors_range {min_x, min_y, min_z, max_x, max_y, max_z}
 *                    (elements 0..5 are read).
 * @param NDim        Not used by this implementation.
 *
 * NOTE(review): `voxel_size` and `coors_range` are indexed without a size
 * check; callers must pass at least 3 and 6 elements respectively.
 */
void dynamic_voxelize_forward_npu(const at::Tensor &points, at::Tensor &coors,
                                  const std::vector<float> voxel_size,
                                  const std::vector<float> coors_range,
                                  const int NDim = 3) {
  // Keep the full 64-bit extent returned by Tensor::size() instead of
  // narrowing it to 32 bits.
  int64_t num_points = points.size(0);

  // The kernel is given the transposed view of the input points.
  at::Tensor points_t = at::transpose(points, 0, 1);

  // Lower corner of the voxel grid.
  double coors_min_x = coors_range[0];
  double coors_min_y = coors_range[1];
  double coors_min_z = coors_range[2];
  // Upper corner of the voxel grid (only needed to derive the grid size).
  double coors_max_x = coors_range[3];
  double coors_max_y = coors_range[4];
  double coors_max_z = coors_range[5];

  // Size of a single voxel along each axis.
  double voxel_x = voxel_size[0];
  double voxel_y = voxel_size[1];
  double voxel_z = voxel_size[2];

  // Number of voxels per axis, rounded to the nearest integer.
  int grid_x =
      static_cast<int>(std::round((coors_max_x - coors_min_x) / voxel_x));
  int grid_y =
      static_cast<int>(std::round((coors_max_y - coors_min_y) / voxel_y));
  int grid_z =
      static_cast<int>(std::round((coors_max_z - coors_min_z) / voxel_z));

  // The kernel writes into a (3, N) int32 buffer.
  at::Tensor tmp_coors =
      at::zeros({3, num_points}, points.options().dtype(at::kInt));
  EXEC_NPU_CMD(aclnnDynamicVoxelization, points_t, coors_min_x, coors_min_y,
               coors_min_z, voxel_x, voxel_y, voxel_z, grid_x, grid_y, grid_z,
               tmp_coors);

  // Transpose back to (N, 3) in place, then copy into the caller's tensor.
  tmp_coors.transpose_(0, 1);
  coors.copy_(tmp_coors);
}

// Bind each voxelization *_impl entry point to its NPU implementation defined
// in this file.
REGISTER_NPU_IMPL(hard_voxelize_forward_impl, hard_voxelize_forward_npu);
REGISTER_NPU_IMPL(dynamic_voxelize_forward_impl, dynamic_voxelize_forward_npu);
15 changes: 15 additions & 0 deletions tests/test_ops/test_voxelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def test_voxelization_npu(device_type):
points = voxel_dict['points']

points = torch.tensor(points)
max_num_points = -1
dynamic_voxelization = Voxelization(voxel_size, point_cloud_range,
max_num_points)
max_num_points = 1000
hard_voxelization = Voxelization(voxel_size, point_cloud_range,
max_num_points)
Expand All @@ -207,3 +210,15 @@ def test_voxelization_npu(device_type):
assert np.all(coors == expected_coors)
assert np.all(voxels == expected_voxels)
assert np.all(num_points_per_voxel == expected_num_points_per_voxel)

# test dynamic_voxelization on npu
coors = dynamic_voxelization.forward(points)
coors = coors.cpu().detach().numpy()
points = points.cpu().detach().numpy()
for i in range(expected_voxels.shape[0]):
indices = _get_voxel_points_indices(points, coors, expected_voxels[i])
num_points_current_voxel = points[indices].shape[0]
assert num_points_current_voxel > 0
assert np.all(
points[indices] == expected_coors[i][:num_points_current_voxel])
assert num_points_current_voxel == expected_num_points_per_voxel[i]

0 comments on commit 98c794d

Please sign in to comment.