From 987d34b0cf8d6cd8725258332fcfc8c54529b1ab Mon Sep 17 00:00:00 2001 From: MrShadowY <116326039+MrShadowY@users.noreply.github.com> Date: Wed, 19 Jul 2023 10:36:26 +0800 Subject: [PATCH] [Feature] Add the support of BoxIouRotated op for ascend device (#2842) Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- docs/en/understand_mmcv/ops.md | 2 +- docs/zh_cn/understand_mmcv/ops.md | 2 +- mmcv/ops/box_iou_rotated.py | 5 ++ .../csrc/pytorch/npu/box_iou_rotated_npu.cpp | 47 +++++++++++++++++++ tests/test_ops/test_box_iou_rotated.py | 14 ++++-- 5 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 88ac943124..552896fc30 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc. | BallQuery | | √ | √ | | | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | -| BoxIouRotated | √ | √ | √ | | | +| BoxIouRotated | √ | √ | √ | | √ | | BoxIouQuadri | √ | √ | | | | | CARAFE | | √ | √ | | | | ChamferDistance | | √ | | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 1e5c1e8b4b..8e8651aeb0 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -9,7 +9,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | BallQuery | | √ | √ | | | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | -| BoxIouRotated | √ | √ | √ | | | +| BoxIouRotated | √ | √ | √ | | √ | | BoxIouQuadri | √ | √ | | | | | CARAFE | | √ | √ | | | | ChamferDistance | | √ | | | | diff --git a/mmcv/ops/box_iou_rotated.py b/mmcv/ops/box_iou_rotated.py index 8e199d9ac8..a811531d42 100644 --- a/mmcv/ops/box_iou_rotated.py +++ b/mmcv/ops/box_iou_rotated.py @@ -142,6 +142,11 @@ def box_iou_rotated(bboxes1: torch.Tensor, flip_mat[-1] = -1 bboxes1 = bboxes1 * flip_mat bboxes2 = bboxes2 * flip_mat + if bboxes1.device.type == 'npu': + scale_mat = bboxes1.new_ones(bboxes1.shape[-1]) + scale_mat[-1] = 1.0 / 0.01745329252 + bboxes1 = bboxes1 * scale_mat + bboxes2 = bboxes2 * scale_mat bboxes1 = bboxes1.contiguous() bboxes2 = bboxes2.contiguous() ext_module.box_iou_rotated( diff --git a/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp new file mode 100644 index 0000000000..c6e6b66478 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp @@ -0,0 +1,47 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned); + +void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned) { + at::Tensor boxes = at::ones_like(boxes1); + at::Tensor query_boxes = at::ones_like(boxes2); + boxes = boxes1.transpose(0, 1).unsqueeze(0); + query_boxes = boxes2.transpose(0, 1).unsqueeze(0); + + bool is_trans = false; + string modeStr = "iou"; + if (mode_flag == 1) { + modeStr = "iof"; + } + bool is_cross = true; + if (aligned) { + is_cross = false; + } + float v_threshold = 0; + float e_threshold = 0; + + OpCommand cmd; + cmd.Name("RotatedIou") + .Input(boxes) + .Input(query_boxes) + .Output(ious) + .Attr("trans", is_trans) + .Attr("mode", modeStr) + .Attr("is_cross", is_cross) + .Attr("v_threshold", v_threshold) + .Attr("e_threshold", e_threshold) + .Run(); + + if (is_cross) { + ious = ious.view({boxes1.size(0), boxes2.size(0)}); + } else { + ious = ious.view({boxes1.size(0), 1}); + } +} + +REGISTER_NPU_IMPL(box_iou_rotated_impl, box_iou_rotated_npu); diff --git a/tests/test_ops/test_box_iou_rotated.py b/tests/test_ops/test_box_iou_rotated.py index f57e54c1e6..3af811d0fe 100644 --- a/tests/test_ops/test_box_iou_rotated.py +++ b/tests/test_ops/test_box_iou_rotated.py @@ -4,7 +4,7 @@ import torch from mmcv.ops import box_iou_rotated -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE class TestBoxIoURotated: @@ -54,7 +54,11 @@ def test_box_iou_rotated_cpu(self): pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_box_iou_rotated(self, device): np_boxes1 = np.asarray( @@ -137,7 +141,11 @@ def test_box_iou_rotated_iof_cpu(self): pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_box_iou_rotated_iof(self, device): np_boxes1 = np.asarray(