diff --git a/mmcv/ops/bbox.py b/mmcv/ops/bbox.py index bf6bd43bbb..4ba93d6b22 100644 --- a/mmcv/ops/bbox.py +++ b/mmcv/ops/bbox.py @@ -106,25 +106,17 @@ def bbox_overlaps(bboxes1: torch.Tensor, rows = bboxes1.size(0) cols = bboxes2.size(0) + if aligned: assert rows == cols + ious = bboxes1.new_zeros(rows) + else: + ious = bboxes1.new_zeros((rows, cols)) if rows * cols == 0: - return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols) - - if bboxes1.device.type == 'cpu': - return _bbox_overlaps_cpu( - bboxes1, bboxes2, mode=mode, aligned=aligned, offset=offset) - else: - if aligned: - ious = bboxes1.new_zeros(rows) - else: - ious = bboxes1.new_zeros((rows, cols)) - ext_module.bbox_overlaps( - bboxes1, - bboxes2, - ious, - mode=mode_flag, - aligned=aligned, - offset=offset) return ious + + ext_module.bbox_overlaps( + bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset) + + return ious diff --git a/mmcv/ops/csrc/pytorch/cpu/bbox_overlaps_cpu.cpp b/mmcv/ops/csrc/pytorch/cpu/bbox_overlaps_cpu.cpp new file mode 100644 index 0000000000..4498895489 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cpu/bbox_overlaps_cpu.cpp @@ -0,0 +1,65 @@ +// Copyright(c) OpenMMLab.All rights reserved. +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +using torch::indexing::None; +using torch::indexing::Slice; + +void bbox_overlaps_cpu_kernel(const Tensor boxes1, const Tensor boxes2, + Tensor ious, const int mode_flag, + const bool aligned, const int offset) { + Tensor temp_ious; + if (aligned) { + Tensor lt = torch::max(boxes1.index({Slice(None), Slice({None, 2})}), + boxes2.index({Slice(None), Slice({None, 2})})); + Tensor rb = torch::min(boxes1.index({Slice(None), Slice(2)}), + boxes2.index({Slice(None), Slice(2)})); + Tensor wh = (rb - lt + offset).clamp(0.f, INT_MAX * 1.f); + Tensor overlap = wh.index({Slice(None), 0}) * wh.index({Slice(None), 1}); + Tensor area1 = (boxes1.index({Slice(None), 2}) - + boxes1.index({Slice(None), 0}) + offset) * + (boxes1.index({Slice(None), 3}) - + boxes1.index({Slice(None), 1}) + offset); + if (mode_flag == 0) { + Tensor area2 = (boxes2.index({Slice(None), 2}) - + boxes2.index({Slice(None), 0}) + offset) * + (boxes2.index({Slice(None), 3}) - + boxes2.index({Slice(None), 1}) + offset); + temp_ious = overlap / (area1 + area2 - overlap); + } else { + temp_ious = overlap / area1; + } + } else { + Tensor lt = torch::max(boxes1.index({Slice(None), None, Slice({None, 2})}), + boxes2.index({Slice(None), Slice({None, 2})})); + Tensor rb = torch::min(boxes1.index({Slice(None), None, Slice(2)}), + boxes2.index({Slice(None), Slice(2)})); + Tensor wh = (rb - lt + offset).clamp(0.f, INT_MAX * 1.f); + Tensor overlap = wh.index({"...", 0}) * wh.index({"...", 1}); + Tensor area1 = (boxes1.index({Slice(None), 2}) - + boxes1.index({Slice(None), 0}) + offset) * + (boxes1.index({Slice(None), 3}) - + boxes1.index({Slice(None), 1}) + offset); + if (mode_flag == 0) { + Tensor area2 = (boxes2.index({Slice(None), 2}) - + boxes2.index({Slice(None), 0}) + offset) * + (boxes2.index({Slice(None), 3}) - + boxes2.index({Slice(None), 1}) + offset); + temp_ious = + overlap / (area1.index({Slice(None), None}) + area2 - overlap); + } else { + temp_ious = overlap / area1.index({Slice(None), None}); + } + } + ious.copy_(temp_ious); +} + +void bbox_overlaps_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode, const bool aligned, const int offset) { + bbox_overlaps_cpu_kernel(boxes1, boxes2, ious, mode, aligned, offset); +} + +void bbox_overlaps_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode, const bool aligned, const int offset); + +REGISTER_DEVICE_IMPL(bbox_overlaps_impl, CPU, bbox_overlaps_cpu); diff --git a/tests/test_ops/test_bbox.py b/tests/test_ops/test_bbox.py index 06f2d8f83b..3d1486eb01 100644 --- a/tests/test_ops/test_bbox.py +++ b/tests/test_ops/test_bbox.py @@ -35,6 +35,14 @@ def _test_bbox_overlaps(self, device='cpu', dtype=torch.float): out = bbox_overlaps(b1, b2, offset=1) assert np.allclose(out.cpu().numpy(), should_output, 1e-2) + b1 = torch.tensor([[10.0 + i, 10.0 + i, 30.0 + i, 30.0 + i] + for i in range(1000)]).to(device).type(dtype) + b2 = torch.tensor([[20.0 + i, 20.0 + i, 40.0 + i, 40.0 + i] + for i in range(1000)]).to(device).type(dtype) + should_output = np.array([1 / 7] * 1000) + out = bbox_overlaps(b1, b2, aligned=True) + assert np.allclose(out.cpu().numpy(), should_output, 1e-2) + @pytest.mark.parametrize('device', [ 'cpu', pytest.param(