Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi_scale_deform_attn op adapter for NPU. #3034

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ We implement common ops used in detection, segmentation, etc.
| MergeCells | | √ | | | |
| MinAreaPolygon | | √ | | | |
| ModulatedDeformConv2d | √ | √ | √ | | √ |
| MultiScaleDeformableAttn | | √ | √ | | |
| MultiScaleDeformableAttn | | √ | √ | | |
| NMS | √ | √ | √ | | √ |
| NMSRotated | √ | √ | √ | | √ |
| NMSQuadri | √ | √ | | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| MergeCells | | √ | | | |
| MinAreaPolygon | | √ | | | |
| ModulatedDeformConv2d | √ | √ | √ | | √ |
| MultiScaleDeformableAttn | | √ | √ | | |
| MultiScaleDeformableAttn | | √ | √ | | |
| NMS | √ | √ | √ | | √ |
| NMSRotated | √ | √ | √ | | √ |
| NMSQuadri | √ | √ | | | |
Expand Down
77 changes: 77 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

// Forward declaration only: this function is not defined in this file.
// It is named as the first argument of the REGISTER_NPU_IMPL call at the
// bottom of the file, which pairs it with ms_deform_attn_forward_npu.
// NOTE(review): presumably this is the device-agnostic dispatch entry point
// defined elsewhere in the project -- confirm against the dispatcher source.
Tensor ms_deform_attn_impl_forward(const Tensor &value,
const Tensor &value_spatial_shapes,
const Tensor &value_level_start_index,
const Tensor &sampling_locations,
const Tensor &attention_weights,
const int im2col_step);

void check_support(const Tensor &value, const Tensor &attention_weights) {
TORCH_CHECK(
(value.scalar_type() == at::kFloat || value.scalar_type() == at::kHalf),
"Dtype of value should be float32 or float16.");
int64_t num_heads = value.size(2);
int64_t embed_dims = value.size(3);
int64_t num_points = attention_weights.size(4);
TORCH_CHECK((num_heads >= 4 && num_heads <= 8),
"num_heads should be in the range of [4, 8]");
TORCH_CHECK((embed_dims >= 32 && embed_dims <= 256),
"embed_dims should be in the range of [32, 256]");
TORCH_CHECK((num_points >= 4 && num_points <= 8),
"num_points should be in the range of [4, 8]");
}

/**
 * NPU forward for multi-scale deformable attention.
 *
 * Steps:
 *   1. Validate `value` / `attention_weights` with check_support().
 *   2. Bring every input to the dtype handed to the operator: float32 for
 *      `value`, `sampling_locations` and `attention_weights`; int32 for
 *      `value_spatial_shapes` and `value_level_start_index`. Inputs already in
 *      the target dtype are passed through without conversion.
 *   3. Allocate a float32 output of shape
 *      (value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)).
 *   4. Run the "MultiScaleDeformableAttnFunction" operator through OpCommand.
 *   5. Cast the result back to the dtype of `value` when that is not float32.
 *
 * @param im2col_step Part of the shared signature; not referenced in this body.
 * @return The attention output, in the same dtype as `value`.
 */
Tensor ms_deform_attn_forward_npu(const Tensor &value,
                                  const Tensor &value_spatial_shapes,
                                  const Tensor &value_level_start_index,
                                  const Tensor &sampling_locations,
                                  const Tensor &attention_weights,
                                  const int im2col_step) {
  check_support(value, attention_weights);

  // Hands back `t` itself when it already has `dtype`, otherwise a converted
  // tensor.
  const auto cast_if_needed = [](const at::Tensor &t, at::ScalarType dtype) {
    return t.scalar_type() == dtype ? t : t.to(dtype);
  };

  const at::Tensor value_f32 = cast_if_needed(value, at::kFloat);
  const at::Tensor spatial_shapes_i32 =
      cast_if_needed(value_spatial_shapes, at::kInt);
  const at::Tensor level_start_index_i32 =
      cast_if_needed(value_level_start_index, at::kInt);
  const at::Tensor sampling_locations_f32 =
      cast_if_needed(sampling_locations, at::kFloat);
  const at::Tensor attention_weights_f32 =
      cast_if_needed(attention_weights, at::kFloat);

  c10::SmallVector<int64_t, 3> result_shape = {
      value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)};
  at::Tensor result = at::empty(result_shape, value_f32.options());

  OpCommand cmd;
  cmd.Name("MultiScaleDeformableAttnFunction")
      .Input(value_f32)
      .Input(spatial_shapes_i32)
      .Input(level_start_index_i32)
      .Input(sampling_locations_f32)
      .Input(attention_weights_f32)
      .Output(result)
      .Run();

  // The operator ran in float32; only convert when the caller used another
  // dtype for `value`.
  if (value.scalar_type() == at::kFloat) {
    return result;
  }
  return result.to(value.scalar_type());
}

// Pair the NPU forward defined above with the ms_deform_attn_impl_forward
// entry point declared at the top of this file.
// NOTE(review): the macro itself is not defined in this file; presumably it
// comes from pytorch_npu_helper.hpp or a header it includes -- confirm.
REGISTER_NPU_IMPL(ms_deform_attn_impl_forward, ms_deform_attn_forward_npu);
7 changes: 4 additions & 3 deletions mmcv/ops/multi_scale_deform_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner import BaseModule
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from ..utils import ext_loader

ext_module = ext_loader.load_ext(
Expand Down Expand Up @@ -85,7 +85,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple:
Returns:
tuple[Tensor]: Gradient of input tensors in forward.
"""
value, value_spatial_shapes, value_level_start_index,\
value, value_spatial_shapes, value_level_start_index, \
sampling_locations, attention_weights = ctx.saved_tensors
grad_value = torch.zeros_like(value)
grad_sampling_loc = torch.zeros_like(sampling_locations)
Expand Down Expand Up @@ -361,7 +361,8 @@ def forward(self,
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if ((IS_CUDA_AVAILABLE and value.is_cuda)
or (IS_MLU_AVAILABLE and value.is_mlu)):
or (IS_MLU_AVAILABLE and value.is_mlu)
or (IS_NPU_AVAILABLE and value.device.type == 'npu')):
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
Expand Down
36 changes: 35 additions & 1 deletion tests/test_ops/test_ms_deformable_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

_USING_PARROTS = True
_IS_AUTOCAST_AVAILABLE = True
Expand Down Expand Up @@ -116,6 +116,40 @@ def test_forward_equal_with_pytorch_double():
assert max_rel_err < 1e-15


@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_forward_equal_with_pytorch_npu():
    """Compare the NPU forward against the pure-PyTorch reference in float32.

    Builds random inputs for 4 feature levels, runs
    ``multi_scale_deformable_attn_pytorch`` on CPU and
    ``MultiScaleDeformableAttnFunction`` on NPU, then checks that both outputs
    agree within single-precision tolerances.
    """
    N, M, D = 6, 4, 8
    Lq, L, P = 10000, 4, 8
    shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
                             dtype=torch.int32)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    # Normalise so the weights of every (batch, query, head) sum to 1 over
    # levels and points.
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.float(), shapes, sampling_locations.float(),
        attention_weights.float()).detach().cpu()

    output_npu = MultiScaleDeformableAttnFunction.apply(
        value.npu().float(), shapes.npu(), level_start_index.npu(),
        sampling_locations.npu().float(),
        attention_weights.npu().float(), im2col_step).detach().cpu()
    assert torch.allclose(output_npu, output_pytorch)
    max_abs_err = (output_npu - output_pytorch).abs().max()
    max_rel_err = ((output_npu - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    # Both outputs are float32 (machine epsilon ~1.19e-7). Bounds of
    # 1e-18 / 1e-15 are double-precision tolerances and could only pass on a
    # bit-exact match, so single-precision bounds are used here.
    # NOTE(review): confirm 1e-9 / 1e-6 suit the NPU kernel's accumulation
    # order; loosen further if the hardware result differs by a few ulps.
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6


@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
Expand Down
Loading