Skip to content

Commit

Permalink
Add multi_scale_deform_attn op adapter for NPU (#3032)
Browse files Browse the repository at this point in the history
  • Loading branch information
DaGaiBa committed Mar 1, 2024
1 parent d9e10e1 commit cd05d71
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ We implement common ops used in detection, segmentation, etc.
| MergeCells | || | | |
| MinAreaPolygon | || | | |
| ModulatedDeformConv2d |||| ||
| MultiScaleDeformableAttn | ||| | |
| MultiScaleDeformableAttn | ||| | |
| NMS |||| ||
| NMSRotated |||| ||
| NMSQuadri ||| | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| MergeCells | || | | |
| MinAreaPolygon | || | | |
| ModulatedDeformConv2d |||| ||
| MultiScaleDeformableAttn | ||| | |
| MultiScaleDeformableAttn | ||| | |
| NMS |||| ||
| NMSRotated |||| ||
| NMSQuadri ||| | | |
Expand Down
77 changes: 77 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

Tensor ms_deform_attn_impl_forward(const Tensor &value,
const Tensor &value_spatial_shapes,
const Tensor &value_level_start_index,
const Tensor &sampling_locations,
const Tensor &attention_weights,
const int im2col_step);

void check_support(const Tensor &value, const Tensor &attention_weights) {
TORCH_CHECK(
(value.scalar_type() == at::kFloat || value.scalar_type() == at::kHalf),
"Dtype of value should be float32 or float16.");
int64_t num_heads = value.size(2);
int64_t embed_dims = value.size(3);
int64_t num_points = attention_weights.size(4);
TORCH_CHECK((num_heads >= 4 && num_heads <= 8),
"num_heads should be in the range of [4, 8]");
TORCH_CHECK((embed_dims >= 32 && embed_dims <= 256),
"embed_dims should be in the range of [32, 256]");
TORCH_CHECK((num_points >= 4 && num_points <= 8),
"num_points should be in the range of [4, 8]");
}

Tensor ms_deform_attn_forward_npu(const Tensor &value,
const Tensor &value_spatial_shapes,
const Tensor &value_level_start_index,
const Tensor &sampling_locations,
const Tensor &attention_weights,
const int im2col_step) {
check_support(value, attention_weights);
at::Tensor value_fp32 = value;
at::Tensor value_spatial_shapes_int32 = value_spatial_shapes;
at::Tensor value_level_start_index_int32 = value_level_start_index;
at::Tensor sampling_locations_fp32 = sampling_locations;
at::Tensor attention_weights_fp32 = attention_weights;
if (value.scalar_type() != at::kFloat) {
value_fp32 = value.to(at::kFloat);
}
if (value_spatial_shapes.scalar_type() != at::kInt) {
value_spatial_shapes_int32 = value_spatial_shapes.to(at::kInt);
}
if (value_level_start_index.scalar_type() != at::kInt) {
value_level_start_index_int32 = value_level_start_index.to(at::kInt);
}
if (sampling_locations.scalar_type() != at::kFloat) {
sampling_locations_fp32 = sampling_locations.to(at::kFloat);
}
if (attention_weights.scalar_type() != at::kFloat) {
attention_weights_fp32 = attention_weights.to(at::kFloat);
}

c10::SmallVector<int64_t, 3> output_size = {
value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)};
at::Tensor output = at::empty(output_size, value_fp32.options());

OpCommand cmd;
cmd.Name("MultiScaleDeformableAttnFunction")
.Input(value_fp32)
.Input(value_spatial_shapes_int32)
.Input(value_level_start_index_int32)
.Input(sampling_locations_fp32)
.Input(attention_weights_fp32)
.Output(output)
.Run();

at::Tensor real_output = output;
if (value.scalar_type() != at::kFloat) {
real_output = output.to(value.scalar_type());
}
return real_output;
}

REGISTER_NPU_IMPL(ms_deform_attn_impl_forward, ms_deform_attn_forward_npu);
7 changes: 4 additions & 3 deletions mmcv/ops/multi_scale_deform_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mmengine.utils import deprecated_api_warning
from torch.autograd.function import Function, once_differentiable

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from ..utils import ext_loader

ext_module = ext_loader.load_ext(
Expand Down Expand Up @@ -84,7 +84,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple:
Returns:
tuple[Tensor]: Gradient of input tensors in forward.
"""
value, value_spatial_shapes, value_level_start_index,\
value, value_spatial_shapes, value_level_start_index, \
sampling_locations, attention_weights = ctx.saved_tensors
grad_value = torch.zeros_like(value)
grad_sampling_loc = torch.zeros_like(sampling_locations)
Expand Down Expand Up @@ -364,7 +364,8 @@ def forward(self,
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if ((IS_CUDA_AVAILABLE and value.is_cuda)
or (IS_MLU_AVAILABLE and value.is_mlu)):
or (IS_MLU_AVAILABLE and value.is_mlu)
or (IS_NPU_AVAILABLE and value.device.type == 'npu')):
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
Expand Down
36 changes: 35 additions & 1 deletion tests/test_ops/test_ms_deformable_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

_USING_PARROTS = True
_IS_AUTOCAST_AVAILABLE = True
Expand Down Expand Up @@ -136,6 +136,40 @@ def test_forward_equal_with_pytorch_double():
assert max_rel_err < 1e-15


@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_forward_equal_with_pytorch_npu():
N, M, D = 6, 4, 8
Lq, L, P = 10000, 4, 8
shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
dtype=torch.int32)
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)

torch.manual_seed(3)
value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
im2col_step = 2
output_pytorch = multi_scale_deformable_attn_pytorch(
value.float(), shapes, sampling_locations.float(),
attention_weights.float()).detach().cpu()

output_npu = MultiScaleDeformableAttnFunction.apply(
value.npu().float(), shapes.npu(), level_start_index.npu(),
sampling_locations.npu().float(),
attention_weights.npu().float(), im2col_step).detach().cpu()
assert torch.allclose(output_npu, output_pytorch)
max_abs_err = (output_npu - output_pytorch).abs().max()
max_rel_err = ((output_npu - output_pytorch).abs() /
output_pytorch.abs()).max()
assert max_abs_err < 1e-18
assert max_rel_err < 1e-15


@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
Expand Down

0 comments on commit cd05d71

Please sign in to comment.