From c26a8d5b4fa303ea563293b6611e523675253ca5 Mon Sep 17 00:00:00 2001 From: ckirchhoff <515629648@qq.com> Date: Sat, 28 Jan 2023 21:08:22 +0800 Subject: [PATCH] [Feature] Add support for Ascend devices with nms_rotated (#2550) * [Feature]: add nms_rotated npu adapter code * [BugFix]: modify param in nms_rotated_npu.cpp * [clean code]: nms_rotated_npu.cpp * [clean code]: nms_rotated_npu.cpp * [clean code]: nms_rotated_npu.cpp * [clean code]: nms_rotated.cpp * [Doc]: add nms_rotated op in supported op list at ops.md * [Test]: add nms_rotated unit_test * [Bug]: remove device parameter in test_batched_nms function --- docs/en/understand_mmcv/ops.md | 2 +- docs/zh_cn/understand_mmcv/ops.md | 2 +- mmcv/ops/csrc/pytorch/nms_rotated.cpp | 15 ++++++-- mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp | 32 +++++++++++++++++ mmcv/ops/csrc/pytorch/pybind.cpp | 6 ++-- mmcv/ops/nms.py | 17 ++++++++- tests/test_ops/test_nms_rotated.py | 35 ++++++++++++++----- 7 files changed, 93 insertions(+), 16 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 93a77d122a..dedcdb4ed4 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -35,7 +35,7 @@ We implement common ops used in detection, segmentation, etc. 
| ModulatedDeformConv2d | √ | √ | √ | | √ | | MultiScaleDeformableAttn | | √ | √ | | | | NMS | √ | √ | √ | | √ | -| NMSRotated | √ | √ | | | | +| NMSRotated | √ | √ | | | √ | | NMSQuadri | √ | √ | | | | | PixelGroup | √ | | | | | | PointsInBoxes | √ | √ | | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 4e19a7f87e..ddf68d79f5 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -35,7 +35,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | ModulatedDeformConv2d | √ | √ | √ | | √ | | MultiScaleDeformableAttn | | √ | √ | | | | NMS | √ | √ | √ | | √ | -| NMSRotated | √ | √ | | | | +| NMSRotated | √ | √ | | | √ | | NMSQuadri | √ | √ | | | | | PixelGroup | √ | | | | | | PointsInBoxes | √ | √ | | | | diff --git a/mmcv/ops/csrc/pytorch/nms_rotated.cpp b/mmcv/ops/csrc/pytorch/nms_rotated.cpp index e4ef676a9d..ed669169a3 100644 --- a/mmcv/ops/csrc/pytorch/nms_rotated.cpp +++ b/mmcv/ops/csrc/pytorch/nms_rotated.cpp @@ -12,12 +12,17 @@ Tensor nms_rotated_cuda(const Tensor dets, const Tensor scores, const float iou_threshold, const int multi_label); #endif +#ifdef MMCV_WITH_NPU +Tensor nms_rotated_npu(const Tensor dets, const Tensor scores, + const Tensor labels, const float iou_threshold); +#endif + // Interface for Python // inline is needed to prevent multiple function definitions when this header is // included by different cpps Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order, - const Tensor dets_sorted, const float iou_threshold, - const int multi_label) { + const Tensor dets_sorted, const Tensor labels, + const float iou_threshold, const int multi_label) { assert(dets.device().is_cuda() == scores.device().is_cuda()); if (dets.device().is_cuda()) { #ifdef MMCV_WITH_CUDA @@ -25,6 +30,12 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order, multi_label); #else AT_ERROR("Not compiled with GPU support"); +#endif + } else if (dets.device().type() == at::kXLA) { 
+#ifdef MMCV_WITH_NPU + return nms_rotated_npu(dets, scores, labels, iou_threshold); +#else + AT_ERROR("Not compiled with NPU support"); #endif } diff --git a/mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp new file mode 100644 index 0000000000..b82ae585cd --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp @@ -0,0 +1,32 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; + +Tensor nms_rotated_npu(const Tensor dets, const Tensor scores, + const Tensor labels, const float iou_threshold) { + auto originDtype = dets.scalar_type(); + at::Tensor detsCast = dets; + at::Tensor scoresCast = scores; + if (originDtype != at::ScalarType::Float) { + detsCast = NPUNativeFunctions::npu_dtype_cast(dets, at::kFloat); + scoresCast = NPUNativeFunctions::npu_dtype_cast(scores, at::kFloat); + } + c10::SmallVector<int64_t, SIZE> selectedIndexSize = {dets.size(0)}; + at::Tensor selectedBox = OpPreparation::ApplyTensor(dets); + at::Tensor selectedIndex = OpPreparation::ApplyTensor( + selectedIndexSize, dets.options().dtype(at::kInt), dets); + + c10::SmallVector<int64_t, N> output_sync_idx = {0, 1}; + OpCommand cmd; + cmd.Sync(output_sync_idx) + .Name("RotatedNMS") + .Input(detsCast) + .Input(scoresCast) + .Input(labels) + .Output(selectedBox) + .Output(selectedIndex) + .Attr("iou_threshold", (float)iou_threshold) + .Run(); + selectedIndex = NPUNativeFunctions::npu_dtype_cast(selectedIndex, at::kLong); + return selectedIndex; +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 4947b72152..74666d00e5 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -309,8 +309,8 @@ void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious, const int mode_flag, const bool aligned); Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order, - const Tensor dets_sorted, const float iou_threshold, - const int multi_label); + const Tensor 
dets_sorted, const Tensor labels, + const float iou_threshold, const int multi_label); Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, @@ -748,7 +748,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("mode_flag"), py::arg("aligned")); m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"), py::arg("scores"), py::arg("order"), py::arg("dets_sorted"), - py::arg("iou_threshold"), py::arg("multi_label")); + py::arg("labels"), py::arg("iou_threshold"), py::arg("multi_label")); m.def("ball_query_forward", &ball_query_forward, "ball_query_forward", py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"), py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"), diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index 5d3e70b672..2e2d9fa005 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -454,6 +454,19 @@ def nms_rotated(dets: Tensor, else: dets_cw = dets multi_label = labels is not None + if labels is None: + input_labels = scores.new_empty(0, dtype=torch.int) + else: + input_labels = labels + if dets.device.type == 'npu': + order = scores.new_empty(0, dtype=torch.long) + keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, + input_labels, iou_threshold, + multi_label) + dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)), + dim=1) + return dets, keep_inds + if multi_label: dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1) # type: ignore else: @@ -467,11 +480,13 @@ def nms_rotated(dets: Tensor, scores, order, dets_sorted, + input_labels, iou_threshold=iou_threshold, multi_label=multi_label) else: keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted, - iou_threshold, multi_label) + input_labels, iou_threshold, + multi_label) dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)), dim=1) return dets, keep_inds diff --git a/tests/test_ops/test_nms_rotated.py 
b/tests/test_ops/test_nms_rotated.py index 1b7f3607b0..bee562a6f1 100644 --- a/tests/test_ops/test_nms_rotated.py +++ b/tests/test_ops/test_nms_rotated.py @@ -3,13 +3,22 @@ import pytest import torch +from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE + -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason='GPU is required to test NMSRotated op') class TestNmsRotated: - def test_ml_nms_rotated(self): + @pytest.mark.parametrize('device', [ + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + ]) + def test_ml_nms_rotated(self, device): from mmcv.ops import nms_rotated np_boxes = np.array( [[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8], @@ -24,8 +33,8 @@ def test_ml_nms_rotated(self): dtype=np.float32) np_expect_keep_inds = np.array([3, 1, 0], dtype=np.int64) - boxes = torch.from_numpy(np_boxes).cuda() - labels = torch.from_numpy(np_labels).cuda() + boxes = torch.from_numpy(np_boxes).to(device) + labels = torch.from_numpy(np_labels).to(device) # test cw angle definition dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5, labels) @@ -41,7 +50,17 @@ def test_ml_nms_rotated(self): assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets) assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds) - def test_nms_rotated(self): + @pytest.mark.parametrize('device', [ + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + ]) + def test_nms_rotated(self, device): from mmcv.ops import nms_rotated np_boxes = np.array( [[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8], @@ -55,7 +74,7 @@ def test_nms_rotated(self): dtype=np.float32) np_expect_keep_inds = np.array([3, 1, 0], 
dtype=np.int64) - boxes = torch.from_numpy(np_boxes).cuda() + boxes = torch.from_numpy(np_boxes).to(device) # test cw angle definition dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5)