[Enhancement] Support MultiScaleDeformableAttention with AMP (open-mmlab#2541)

* [Enhance] Support FP16 for MSDeformAttn

* [Fix] Data type mismatch

* Update mmcv/ops/multi_scale_deform_attn.py

* Add UT

Author:    nijkah <nijkah@gmail.com>

* Add cuda available condition

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
nijkah and zhouzaida committed Mar 20, 2023
1 parent 4ac6f50 commit 964afd2
Showing 3 changed files with 75 additions and 2 deletions.
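
For context, the sketch below shows the kind of call this commit makes possible: running the fused CUDA op under torch.cuda.amp.autocast with a half-precision `value`. It is not part of the diff; it assumes a CUDA build of mmcv that includes this change, the import path is taken from the file shown in this diff, and the shapes simply mirror the new unit test further down.

import torch
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction

N, M, D = 1, 2, 2    # batch, heads, head dim
Lq, L, P = 2, 2, 2   # queries, levels, sampling points
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat(
    (shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)

value = torch.rand(N, S, M, D, device='cuda')
sampling_locations = torch.rand(N, Lq, M, L, P, 2, device='cuda')
attention_weights = torch.rand(N, Lq, M, L, P, device='cuda') + 1e-5
attention_weights /= attention_weights.sum(
    -1, keepdim=True).sum(-2, keepdim=True)

with torch.cuda.amp.autocast(enabled=True):
    # In a real model `value` comes from a Linear projection that autocast
    # casts to float16; emulate that here with an explicit cast.
    output = MultiScaleDeformableAttnFunction.apply(
        value.half(), shapes, level_start_index, sampling_locations,
        attention_weights, 2)  # im2col_step
print(output.dtype)  # torch.float16
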
4 changes: 2 additions & 2 deletions mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu
@@ -255,7 +255,7 @@ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch / im2col_step_; ++n) {
auto columns = output_n.select(0, n);
- AT_DISPATCH_FLOATING_TYPES(
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(
at::cuda::getCurrentCUDAStream(),
@@ -326,7 +326,7 @@ void ms_deform_attn_cuda_backward(

for (int n = 0; n < batch / im2col_step_; ++n) {
auto grad_output_g = grad_output_n.select(0, n);
- AT_DISPATCH_FLOATING_TYPES(
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(
at::cuda::getCurrentCUDAStream(),
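
The switch from AT_DISPATCH_FLOATING_TYPES to AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the forward and backward kernels for torch.half in addition to float and double, so explicit fp16 inputs (without autocast) are accepted as well. A hedged smoke check, again assuming a CUDA build with this commit; shapes follow the unit test, and the attention weights are left unnormalized because only the dtype path is being exercised:

import torch
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction

shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat(
    (shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
value = torch.rand(1, 30, 2, 2, device='cuda').half()  # S = 6*4 + 3*2 = 30
sampling_locations = torch.rand(1, 2, 2, 2, 2, 2, device='cuda').half()
attention_weights = torch.rand(1, 2, 2, 2, 2, device='cuda').half()

out = MultiScaleDeformableAttnFunction.apply(
    value, shapes, level_start_index, sampling_locations,
    attention_weights, 2)
assert out.dtype == torch.half  # before this change, half inputs were rejected at dispatch time
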
12 changes: 12 additions & 0 deletions mmcv/ops/multi_scale_deform_attn.py
@@ -49,6 +49,18 @@ def forward(ctx, value: torch.Tensor, value_spatial_shapes: torch.Tensor,
"""

ctx.im2col_step = im2col_step

# When PyTorch >= 1.6.0, AMP is used for fp16 mode. AMP casts
# "value" to float16 but leaves "sampling_locations" and
# "attention_weights" in float32, which leads to a dtype mismatch
# inside the CUDA kernel. Since the dtype of "value" indicates
# whether fp16/AMP is active, cast "sampling_locations" and
# "attention_weights" to match "value" so that both explicit fp16
# and AMP work regardless of the PyTorch version.
sampling_locations = sampling_locations.type_as(value)
attention_weights = attention_weights.type_as(value)

output = ext_module.ms_deform_attn_forward(
value,
value_spatial_shapes,
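
The behaviour described in the comment added above can be seen directly: under autocast, tensors produced by autocast-covered ops (e.g. a Linear projection) come out as float16, while tensors created outside such ops stay float32, so a custom autograd Function can receive mixed dtypes. A minimal, hypothetical illustration (the names are illustrative, not from the diff):

import torch

proj = torch.nn.Linear(8, 8).cuda()        # stands in for a value projection
x = torch.rand(2, 4, 8, device='cuda')
loc = torch.rand(2, 4, 2, device='cuda')   # stands in for sampling_locations

with torch.cuda.amp.autocast(enabled=True):
    value = proj(x)
    print(value.dtype)        # torch.float16 -- cast by autocast
    print(loc.dtype)          # torch.float32 -- untouched by autocast
    loc = loc.type_as(value)  # the alignment that forward() now performs
    print(loc.dtype)          # torch.float16
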
61 changes: 61 additions & 0 deletions tests/test_ops/test_ms_deformable_attn.py
@@ -8,12 +8,21 @@
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

_USING_PARROTS = True
_IS_AUTOCAST_AVAILABLE = True
try:
from parrots.autograd import gradcheck
except ImportError:
from torch.autograd import gradcheck
_USING_PARROTS = False

try:
# torch.cuda.amp.autocast is available in PyTorch >= 1.6.0 and is used
# for fp16 mode, so test that our modules support it.
from torch.cuda.amp import autocast
except ImportError:
_IS_AUTOCAST_AVAILABLE = False
pass


@pytest.mark.parametrize('device', [
'cpu',
@@ -168,6 +177,58 @@ def test_forward_equal_with_pytorch_float(device):
assert max_rel_err < 1e-6


@pytest.mark.skipif(
not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_autocast():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)

torch.manual_seed(3)
value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
im2col_step = 2
output_pytorch = multi_scale_deformable_attn_pytorch(
value, shapes, sampling_locations, attention_weights).detach().cpu()

# float test
dtype = torch.float
with autocast(enabled=True):
output_device = MultiScaleDeformableAttnFunction.apply(
value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
sampling_locations.cuda(), attention_weights.cuda(),
im2col_step).detach().cpu()
assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_device - output_pytorch).abs().max()
max_rel_err = ((output_device - output_pytorch).abs() /
output_pytorch.abs()).max()
assert max_abs_err < 1e-9
assert max_rel_err < 1e-6

# half test
dtype = torch.half
with autocast(enabled=True):
output_device = MultiScaleDeformableAttnFunction.apply(
value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
sampling_locations.cuda(), attention_weights.cuda(),
im2col_step).detach().cpu()
assert torch.allclose(
output_device, output_pytorch.half(), rtol=1e-2, atol=1e-3)
max_abs_err = (output_device - output_pytorch).abs().max()
max_rel_err = ((output_device - output_pytorch).abs() /
output_pytorch.abs()).max()
assert max_abs_err < 1e-5
assert max_rel_err < 1e-2


@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
