diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 34604c05f43..f4d5251ebf7 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -39,7 +39,7 @@ We implement common ops used in detection, segmentation, etc. | NMSQuadri | √ | √ | | | | | PixelGroup | √ | | | | | | PointsInBoxes | √ | √ | | | | -| PointsInPolygons | | √ | | | | +| PointsInPolygons | | √ | | | √ | | PSAMask | √ | √ | √ | | √ | | RotatedFeatureAlign | √ | √ | √ | | | | RoIPointPool3d | | √ | √ | | | diff --git a/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp b/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp new file mode 100644 index 00000000000..43e3213cfea --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp @@ -0,0 +1,28 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +constexpr int32_t MAX_POLYGONS_BATCH = 2800; + + +void points_in_polygons_npu(const Tensor points, Tensor polygons, + Tensor output, const int rows, + const int cols) { + TORCH_CHECK((polygons.sizes()[0] <= MAX_POLYGONS_BATCH), + "The batch of polygons tensor must be less than MAX_POLYGONS_BATCH"); + at::Tensor trans_polygons = polygons.transpose(0, 1); + OpCommand cmd; + at::Tensor new_trans_polygons = NpuUtils::format_contiguous(trans_polygons); + cmd.Name("PointsInPolygons") + .Input(points, (string)"points") + .Input(new_trans_polygons, (string)"polygons") + .Output(output) + .Run(); +} + +void points_in_polygons_forward_impl(const Tensor points, Tensor polygons, + Tensor output, const int rows, + const int cols); + +REGISTER_NPU_IMPL(points_in_polygons_forward_impl, points_in_polygons_npu); diff --git a/mmcv/ops/points_in_polygons.py b/mmcv/ops/points_in_polygons.py index 62d0dbdc908..e54b5a896df 100644 --- a/mmcv/ops/points_in_polygons.py +++ b/mmcv/ops/points_in_polygons.py @@ -31,8 +31,11 @@ def points_in_polygons(points: Tensor, polygons: Tensor) -> Tensor: assert polygons.shape[1] == 8, \ 'polygons dimension should be 8, ' \ f'but got unexpected shape {polygons.shape[1]}' - output = torch.full([points.shape[0], polygons.shape[0]], - 0.).cuda().float() + output = torch.zeros( + points.shape[0], + polygons.shape[0], + dtype=torch.float32, + device=points.device) ext_module.points_in_polygons_forward(points.contiguous(), polygons.contiguous(), output) return output diff --git a/tests/test_ops/test_points_in_polygons.py b/tests/test_ops/test_points_in_polygons.py index dde8ab02391..d224d1593ad 100644 --- a/tests/test_ops/test_points_in_polygons.py +++ b/tests/test_ops/test_points_in_polygons.py @@ -4,20 +4,29 @@ import torch from mmcv.ops import points_in_polygons +from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_points_in_polygons(): +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) +]) +def test_points_in_polygons(device): points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250], [100, 0]]) polygons = np.array([[200., 200., 400., 400., 500., 200., 400., 100.], [400., 400., 500., 500., 600., 300., 500., 200.], [300., 300., 600., 700., 700., 700., 700., 100.]]) expected_output = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 0.], - [1., 0., 0.], [0., 0., 0.]]) - points = torch.from_numpy(points).cuda().float() - polygons = torch.from_numpy(polygons).cuda().float() - expected_output = torch.from_numpy(expected_output).cuda().float() - assert torch.allclose( - points_in_polygons(points, polygons), expected_output, 1e-3) + [1., 0., 0.], [0., 0., 0.]]).astype(np.float32) + points = torch.tensor(points, dtype=torch.float32, device=device) + polygons = torch.tensor(polygons, dtype=torch.float32, device=device) + assert np.allclose( + points_in_polygons(points, polygons).cpu().numpy(), expected_output, + 1e-3)