From 98c794d1e9411bf1dc588431a2de592106fd4644 Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Mon, 24 Jun 2024 10:01:00 +0800 Subject: [PATCH] adapt npu_dynamic_voxelization & fix rotated_iou precision --- mmcv/ops/box_iou_rotated.py | 9 ++--- .../csrc/pytorch/npu/box_iou_rotated_npu.cpp | 8 +++++ .../ops/csrc/pytorch/npu/voxelization_npu.cpp | 35 +++++++++++++++++++ tests/test_ops/test_voxelization.py | 15 ++++++++ 4 files changed, 63 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/box_iou_rotated.py b/mmcv/ops/box_iou_rotated.py index a811531d42..0f8a8e298b 100644 --- a/mmcv/ops/box_iou_rotated.py +++ b/mmcv/ops/box_iou_rotated.py @@ -143,10 +143,11 @@ def box_iou_rotated(bboxes1: torch.Tensor, bboxes1 = bboxes1 * flip_mat bboxes2 = bboxes2 * flip_mat if bboxes1.device.type == 'npu': - scale_mat = bboxes1.new_ones(bboxes1.shape[-1]) - scale_mat[-1] = 1.0 / 0.01745329252 - bboxes1 = bboxes1 * scale_mat - bboxes2 = bboxes2 * scale_mat + if (mode_flag == 1 or aligned or not clockwise): + scale_mat = bboxes1.new_ones(bboxes1.shape[-1]) + scale_mat[-1] = 1.0 / 0.01745329252 + bboxes1 = bboxes1 * scale_mat + bboxes2 = bboxes2 * scale_mat bboxes1 = bboxes1.contiguous() bboxes2 = bboxes2.contiguous() ext_module.box_iou_rotated( diff --git a/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp index c6e6b66478..14df358080 100644 --- a/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp @@ -8,6 +8,14 @@ void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious, const int mode_flag, const bool aligned) { + if (mode_flag == 0 && aligned == false) { + auto trans = false; + auto is_clockwise = false; + auto need_iou = true; + EXEC_NPU_CMD(aclnnBoxesOverlapBev, boxes1, boxes2, trans, is_clockwise, + need_iou, ious); + return; + } at::Tensor boxes = at::ones_like(boxes1); 
at::Tensor query_boxes = at::ones_like(boxes2); boxes = boxes1.transpose(0, 1).unsqueeze(0); diff --git a/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp index ffd9b4c43b..2abe7c8f95 100644 --- a/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp @@ -11,6 +11,11 @@ int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels, const int max_points, const int max_voxels, const int NDim = 3); +void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors, + const std::vector<float> voxel_size, + const std::vector<float> coors_range, + const int NDim = 3); + int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, at::Tensor &num_points_per_voxel, @@ -53,4 +58,34 @@ int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels, return voxel_num_int; } +void dynamic_voxelize_forward_npu(const at::Tensor &points, at::Tensor &coors, + const std::vector<float> voxel_size, + const std::vector<float> coors_range, + const int NDim = 3) { + uint32_t ptsNum = points.size(0); + uint32_t ptsFeature = points.size(1); + at::Tensor ptsTrans = at::transpose(points, 0, 1); + double coors_min_x = coors_range[0]; + double coors_min_y = coors_range[1]; + double coors_min_z = coors_range[2]; + double coors_max_x = coors_range[3]; + double coors_max_y = coors_range[4]; + double coors_max_z = coors_range[5]; + double voxel_x = voxel_size[0]; + double voxel_y = voxel_size[1]; + double voxel_z = voxel_size[2]; + int grid_x = std::round((coors_max_x - coors_min_x) / voxel_x); + int grid_y = std::round((coors_max_y - coors_min_y) / voxel_y); + int grid_z = std::round((coors_max_z - coors_min_z) / voxel_z); + + at::Tensor tmp_coors = + at::zeros({3, ptsNum}, points.options().dtype(at::kInt)); + EXEC_NPU_CMD(aclnnDynamicVoxelization, ptsTrans, coors_min_x, coors_min_y, + coors_min_z, voxel_x, voxel_y, voxel_z, grid_x, grid_y, grid_z, + tmp_coors);
+ tmp_coors.transpose_(0, 1); + coors.copy_(tmp_coors); +} + REGISTER_NPU_IMPL(hard_voxelize_forward_impl, hard_voxelize_forward_npu); +REGISTER_NPU_IMPL(dynamic_voxelize_forward_impl, dynamic_voxelize_forward_npu); diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py index cd01eb46e6..f62fcfb9e3 100644 --- a/tests/test_ops/test_voxelization.py +++ b/tests/test_ops/test_voxelization.py @@ -192,6 +192,9 @@ def test_voxelization_npu(device_type): points = voxel_dict['points'] points = torch.tensor(points) + max_num_points = -1 + dynamic_voxelization = Voxelization(voxel_size, point_cloud_range, + max_num_points) max_num_points = 1000 hard_voxelization = Voxelization(voxel_size, point_cloud_range, max_num_points) @@ -207,3 +210,15 @@ def test_voxelization_npu(device_type): assert np.all(coors == expected_coors) assert np.all(voxels == expected_voxels) assert np.all(num_points_per_voxel == expected_num_points_per_voxel) + + # test dynamic_voxelization on npu + coors = dynamic_voxelization.forward(points) + coors = coors.cpu().detach().numpy() + points = points.cpu().detach().numpy() + for i in range(expected_voxels.shape[0]): + indices = _get_voxel_points_indices(points, coors, expected_voxels[i]) + num_points_current_voxel = points[indices].shape[0] + assert num_points_current_voxel > 0 + assert np.all( + points[indices] == expected_coors[i][:num_points_current_voxel]) + assert num_points_current_voxel == expected_num_points_per_voxel[i]