From 4de31250594524484f61fdfaa9f56702a5054ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=8C=AF=E8=B1=AA?= Date: Thu, 22 Feb 2024 19:18:43 +0800 Subject: [PATCH] Add multi_scale_deform_attn op adapter for NPU. --- docs/en/understand_mmcv/ops.md | 2 +- docs/zh_cn/understand_mmcv/ops.md | 2 +- .../csrc/pytorch/npu/ms_deform_attn_npu.cpp | 77 +++++++++++++++++++ mmcv/ops/multi_scale_deform_attn.py | 7 +- tests/test_ops/test_ms_deformable_attn.py | 36 ++++++++- 5 files changed, 118 insertions(+), 6 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 80738684e6..2e127932ed 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -33,7 +33,7 @@ We implement common ops used in detection, segmentation, etc. | MergeCells | | √ | | | | | MinAreaPolygon | | √ | | | | | ModulatedDeformConv2d | √ | √ | √ | | √ | -| MultiScaleDeformableAttn | | √ | √ | | | +| MultiScaleDeformableAttn | | √ | √ | | √ | | NMS | √ | √ | √ | | √ | | NMSRotated | √ | √ | √ | | √ | | NMSQuadri | √ | √ | | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index efa79d2f4c..7f4d7ea63b 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -33,7 +33,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | MergeCells | | √ | | | | | MinAreaPolygon | | √ | | | | | ModulatedDeformConv2d | √ | √ | √ | | √ | -| MultiScaleDeformableAttn | | √ | √ | | | +| MultiScaleDeformableAttn | | √ | √ | | √ | | NMS | √ | √ | √ | | √ | | NMSRotated | √ | √ | √ | | √ | | NMSQuadri | √ | √ | | | | diff --git a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp new file mode 100644 index 0000000000..fa83f17547 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp @@ -0,0 +1,77 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +Tensor ms_deform_attn_impl_forward(const Tensor &value, + const Tensor &value_spatial_shapes, + const Tensor &value_level_start_index, + const Tensor &sampling_locations, + const Tensor &attention_weights, + const int im2col_step); + +void check_support(const Tensor &value, const Tensor &attention_weights) { + TORCH_CHECK( + (value.scalar_type() == at::kFloat || value.scalar_type() == at::kHalf), + "Dtype of value should be float32 or float16."); + int64_t num_heads = value.size(2); + int64_t embed_dims = value.size(3); + int64_t num_points = attention_weights.size(4); + TORCH_CHECK((num_heads >= 4 && num_heads <= 8), + "num_heads should be in the range of [4, 8]"); + TORCH_CHECK((embed_dims >= 32 && embed_dims <= 256), + "embed_dims should be in the range of [32, 256]"); + TORCH_CHECK((num_points >= 4 && num_points <= 8), + "num_points should be in the range of [4, 8]"); +} + +Tensor ms_deform_attn_forward_npu(const Tensor &value, + const Tensor &value_spatial_shapes, + const Tensor &value_level_start_index, + const Tensor &sampling_locations, + const Tensor &attention_weights, + const int im2col_step) { + check_support(value, attention_weights); + at::Tensor value_fp32 = value; + at::Tensor value_spatial_shapes_int32 = value_spatial_shapes; + at::Tensor value_level_start_index_int32 = value_level_start_index; + at::Tensor sampling_locations_fp32 = sampling_locations; + at::Tensor attention_weights_fp32 = attention_weights; + if (value.scalar_type() != at::kFloat) { + value_fp32 = value.to(at::kFloat); + } + if (value_spatial_shapes.scalar_type() != at::kInt) { + value_spatial_shapes_int32 = value_spatial_shapes.to(at::kInt); + } + if (value_level_start_index.scalar_type() != at::kInt) { + value_level_start_index_int32 = value_level_start_index.to(at::kInt); + } + if (sampling_locations.scalar_type() != at::kFloat) { + sampling_locations_fp32 = sampling_locations.to(at::kFloat); + } + if (attention_weights.scalar_type() != at::kFloat) { + attention_weights_fp32 = attention_weights.to(at::kFloat); + } + + c10::SmallVector output_size = { + value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)}; + at::Tensor output = at::empty(output_size, value_fp32.options()); + + OpCommand cmd; + cmd.Name("MultiScaleDeformableAttnFunction") + .Input(value_fp32) + .Input(value_spatial_shapes_int32) + .Input(value_level_start_index_int32) + .Input(sampling_locations_fp32) + .Input(attention_weights_fp32) + .Output(output) + .Run(); + + at::Tensor real_output = output; + if (value.scalar_type() != at::kFloat) { + real_output = output.to(value.scalar_type()); + } + return real_output; +} + +REGISTER_NPU_IMPL(ms_deform_attn_impl_forward, ms_deform_attn_forward_npu); diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py index 8073ebb198..8c09cd2aa0 100644 --- a/mmcv/ops/multi_scale_deform_attn.py +++ b/mmcv/ops/multi_scale_deform_attn.py @@ -13,7 +13,7 @@ from mmcv.cnn import constant_init, xavier_init from mmcv.cnn.bricks.registry import ATTENTION from mmcv.runner import BaseModule -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE from ..utils import ext_loader ext_module = ext_loader.load_ext( @@ -85,7 +85,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple: Returns: tuple[Tensor]: Gradient of input tensors in forward. """ - value, value_spatial_shapes, value_level_start_index,\ + value, value_spatial_shapes, value_level_start_index, \ sampling_locations, attention_weights = ctx.saved_tensors grad_value = torch.zeros_like(value) grad_sampling_loc = torch.zeros_like(sampling_locations) @@ -361,7 +361,8 @@ def forward(self, f'Last dim of reference_points must be' f' 2 or 4, but get {reference_points.shape[-1]} instead.') if ((IS_CUDA_AVAILABLE and value.is_cuda) - or (IS_MLU_AVAILABLE and value.is_mlu)): + or (IS_MLU_AVAILABLE and value.is_mlu) + or (IS_NPU_AVAILABLE and value.device.type == 'npu')): output = MultiScaleDeformableAttnFunction.apply( value, spatial_shapes, level_start_index, sampling_locations, attention_weights, self.im2col_step) diff --git a/tests/test_ops/test_ms_deformable_attn.py b/tests/test_ops/test_ms_deformable_attn.py index 6bf91a3fdb..a2711a3d28 100644 --- a/tests/test_ops/test_ms_deformable_attn.py +++ b/tests/test_ops/test_ms_deformable_attn.py @@ -5,7 +5,7 @@ from mmcv.ops.multi_scale_deform_attn import ( MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch) -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE _USING_PARROTS = True _IS_AUTOCAST_AVAILABLE = True @@ -116,6 +116,40 @@ def test_forward_equal_with_pytorch_double(): assert max_rel_err < 1e-15 +@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support') +def test_forward_equal_with_pytorch_npu(): + N, M, D = 6, 4, 8 + Lq, L, P = 10000, 4, 8 + shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)], + dtype=torch.int32) + level_start_index = torch.cat((shapes.new_zeros( + (1, )), shapes.prod(1).cumsum(0)[:-1])) + S = sum((H * W).item() for H, W in shapes) + + torch.manual_seed(3) + value = torch.rand(N, S, M, D) * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2) + attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5 + attention_weights /= attention_weights.sum( + -1, keepdim=True).sum( + -2, keepdim=True) + im2col_step = 2 + output_pytorch = multi_scale_deformable_attn_pytorch( + value.float(), shapes, sampling_locations.float(), + attention_weights.float()).detach().cpu() + + output_npu = MultiScaleDeformableAttnFunction.apply( + value.npu().float(), shapes.npu(), level_start_index.npu(), + sampling_locations.npu().float(), + attention_weights.npu().float(), im2col_step).detach().cpu() + assert torch.allclose(output_npu, output_pytorch) + max_abs_err = (output_npu - output_pytorch).abs().max() + max_rel_err = ((output_npu - output_pytorch).abs() / + output_pytorch.abs()).max() + assert max_abs_err < 1e-18 + assert max_rel_err < 1e-15 + + @pytest.mark.parametrize('device', [ pytest.param( 'cuda',