diff --git a/mmcv/ops/box_iou_rotated.py b/mmcv/ops/box_iou_rotated.py index a811531d42..8e199d9ac8 100644 --- a/mmcv/ops/box_iou_rotated.py +++ b/mmcv/ops/box_iou_rotated.py @@ -142,11 +142,6 @@ def box_iou_rotated(bboxes1: torch.Tensor, flip_mat[-1] = -1 bboxes1 = bboxes1 * flip_mat bboxes2 = bboxes2 * flip_mat - if bboxes1.device.type == 'npu': - scale_mat = bboxes1.new_ones(bboxes1.shape[-1]) - scale_mat[-1] = 1.0 / 0.01745329252 - bboxes1 = bboxes1 * scale_mat - bboxes2 = bboxes2 * scale_mat bboxes1 = bboxes1.contiguous() bboxes2 = bboxes2.contiguous() ext_module.box_iou_rotated( diff --git a/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp index c6e6b66478..d8b0bbaa67 100644 --- a/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp @@ -8,40 +8,15 @@ void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious, const int mode_flag, const bool aligned) { - at::Tensor boxes = at::ones_like(boxes1); - at::Tensor query_boxes = at::ones_like(boxes2); - boxes = boxes1.transpose(0, 1).unsqueeze(0); - query_boxes = boxes2.transpose(0, 1).unsqueeze(0); - bool is_trans = false; - string modeStr = "iou"; - if (mode_flag == 1) { - modeStr = "iof"; - } - bool is_cross = true; - if (aligned) { - is_cross = false; - } - float v_threshold = 0; - float e_threshold = 0; + TORCH_CHECK(boxes1.size(1) == 5, "boxes1 must be 2D tensor (N, 5)"); + TORCH_CHECK(boxes1.size(1) == 5, "boxes1 must be 2D tensor (N, 5)"); - OpCommand cmd; - cmd.Name("RotatedIou") - .Input(boxes) - .Input(query_boxes) - .Output(ious) - .Attr("trans", is_trans) - .Attr("mode", modeStr) - .Attr("is_cross", is_cross) - .Attr("v_threshold", v_threshold) - .Attr("e_threshold", e_threshold) - .Run(); - - if (is_cross) { - ious = ious.view({boxes1.size(0), boxes2.size(0)}); - } else { - ious = ious.view({boxes1.size(0), 1}); - } + auto trans = false; + auto is_clockwise = false; + EXEC_NPU_CMD(aclnnBoxesOverlapBev, boxes1, boxes2, trans, is_clockwise, + aligned, mode_flag, ious); + return; } REGISTER_NPU_IMPL(box_iou_rotated_impl, box_iou_rotated_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/boxes_overlap_bev_npu.cpp b/mmcv/ops/csrc/pytorch/npu/boxes_overlap_bev_npu.cpp new file mode 100644 index 0000000000..6bc6273083 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/boxes_overlap_bev_npu.cpp @@ -0,0 +1,25 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a, + const int num_b, const Tensor boxes_b, + Tensor ans_overlap); + +void iou3d_boxes_overlap_bev_forward_npu(const int num_a, const Tensor boxes_a, + const int num_b, const Tensor boxes_b, + Tensor ans_overlap) { + + TORCH_CHECK(boxes_a.size(1) == 7, "boxes_a must be 2D tensor (N, 7)"); + TORCH_CHECK(boxes_b.size(1) == 7, "boxes_b must be 2D tensor (N, 7)"); + + auto trans = false; + auto is_clockwise = false; + auto aligned = false; + auto mode_flag = 2; + EXEC_NPU_CMD(aclnnBoxesOverlapBev, boxes_a, boxes_b, trans, is_clockwise, aligned, mode_flag, ans_overlap); + return; +} + +REGISTER_NPU_IMPL(iou3d_boxes_overlap_bev_forward_impl, iou3d_boxes_overlap_bev_forward_npu); diff --git a/tests/test_ops/test_iou3d.py b/tests/test_ops/test_iou3d.py index 27a09eb361..6b2456e8b9 100644 --- a/tests/test_ops/test_iou3d.py +++ b/tests/test_ops/test_iou3d.py @@ -11,7 +11,11 @@ pytest.param( 'cuda', marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_boxes_overlap_bev(device): np_boxes1 = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0],