diff --git a/mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu
index fd191ee9c9..8e1e62df51 100644
--- a/mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu
@@ -255,7 +255,7 @@ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
   auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
   for (int n = 0; n < batch / im2col_step_; ++n) {
     auto columns = output_n.select(0, n);
-    AT_DISPATCH_FLOATING_TYPES(
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
         value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
           ms_deformable_im2col_cuda(
               at::cuda::getCurrentCUDAStream(),
@@ -326,7 +326,7 @@ void ms_deform_attn_cuda_backward(
 
   for (int n = 0; n < batch / im2col_step_; ++n) {
     auto grad_output_g = grad_output_n.select(0, n);
-    AT_DISPATCH_FLOATING_TYPES(
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
         value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
           ms_deformable_col2im_cuda(
               at::cuda::getCurrentCUDAStream(),
diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py
index c1d415621a..546fb8b805 100644
--- a/mmcv/ops/multi_scale_deform_attn.py
+++ b/mmcv/ops/multi_scale_deform_attn.py
@@ -49,6 +49,18 @@ def forward(ctx, value: torch.Tensor, value_spatial_shapes: torch.Tensor,
 
         """
         ctx.im2col_step = im2col_step
+
+        # When PyTorch version >= 1.6.0, amp is adopted for fp16 mode;
+        # amp won't cast the type of sampling_locations, attention_weights
+        # (float32), but "value" is cast to float16, leading to a type
+        # mismatch with the input (when it is float32) or the weight.
+        # The flag for whether to use fp16 or amp is the type of "value",
+        # so we cast sampling_locations and attention_weights to the type
+        # of "value" to temporarily support fp16 and amp, whatever the
+        # PyTorch version is.
+        sampling_locations = sampling_locations.type_as(value)
+        attention_weights = attention_weights.type_as(value)
+
         output = ext_module.ms_deform_attn_forward(
             value,
             value_spatial_shapes,
diff --git a/tests/test_ops/test_ms_deformable_attn.py b/tests/test_ops/test_ms_deformable_attn.py
index a29380552d..8e9f1af8c0 100644
--- a/tests/test_ops/test_ms_deformable_attn.py
+++ b/tests/test_ops/test_ms_deformable_attn.py
@@ -8,12 +8,21 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
 
 _USING_PARROTS = True
+_IS_AUTOCAST_AVAILABLE = True
 try:
     from parrots.autograd import gradcheck
 except ImportError:
     from torch.autograd import gradcheck
 
     _USING_PARROTS = False
 
+try:
+    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+    # would be imported and used; we should test if our modules support it.
+    from torch.cuda.amp import autocast
+except ImportError:
+    _IS_AUTOCAST_AVAILABLE = False
+    pass
+
 
 @pytest.mark.parametrize('device', [
     'cpu',
@@ -168,6 +177,58 @@ def test_forward_equal_with_pytorch_float(device):
     assert max_rel_err < 1e-6
 
 
+@pytest.mark.skipif(
+    not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
+@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
+def test_forward_equal_with_autocast():
+    N, M, D = 1, 2, 2
+    Lq, L, P = 2, 2, 2
+    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
+    level_start_index = torch.cat((shapes.new_zeros(
+        (1, )), shapes.prod(1).cumsum(0)[:-1]))
+    S = sum((H * W).item() for H, W in shapes)
+
+    torch.manual_seed(3)
+    value = torch.rand(N, S, M, D) * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
+    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
+    attention_weights /= attention_weights.sum(
+        -1, keepdim=True).sum(
+            -2, keepdim=True)
+    im2col_step = 2
+    output_pytorch = multi_scale_deformable_attn_pytorch(
+        value, shapes, sampling_locations, attention_weights).detach().cpu()
+
+    # float test
+    dtype = torch.float
+    with autocast(enabled=True):
+        output_device = MultiScaleDeformableAttnFunction.apply(
+            value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
+            sampling_locations.cuda(), attention_weights.cuda(),
+            im2col_step).detach().cpu()
+    assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
+    max_abs_err = (output_device - output_pytorch).abs().max()
+    max_rel_err = ((output_device - output_pytorch).abs() /
+                   output_pytorch.abs()).max()
+    assert max_abs_err < 1e-9
+    assert max_rel_err < 1e-6
+
+    # half test
+    dtype = torch.half
+    with autocast(enabled=True):
+        output_device = MultiScaleDeformableAttnFunction.apply(
+            value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
+            sampling_locations.cuda(), attention_weights.cuda(),
+            im2col_step).detach().cpu()
+    assert torch.allclose(
+        output_device, output_pytorch.half(), rtol=1e-2, atol=1e-3)
+    max_abs_err = (output_device - output_pytorch).abs().max()
+    max_rel_err = ((output_device - output_pytorch).abs() /
+                   output_pytorch.abs()).max()
+    assert max_abs_err < 1e-5
+    assert max_rel_err < 1e-2
+
+
 @pytest.mark.parametrize('device', [
     pytest.param(
         'cuda',
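
For reference, the sketch below (not part of the patch) shows how the amp code path exercised by the new test can be driven from the module-level API. The MultiScaleDeformableAttention configuration (embed_dims, num_heads, num_levels, num_points, batch_first) and the tensor shapes are assumed example values, not taken from this diff.

# Illustrative sketch only: run MultiScaleDeformableAttention under
# torch.cuda.amp.autocast, the situation the type_as() casts above handle.
import torch
from torch.cuda.amp import autocast

from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention

# Assumed example configuration; batch_first is assumed to be available
# in the installed mmcv version.
embed_dims, num_heads, num_levels, num_points = 32, 4, 2, 2
attn = MultiScaleDeformableAttention(
    embed_dims=embed_dims,
    num_heads=num_heads,
    num_levels=num_levels,
    num_points=num_points,
    batch_first=True).cuda()

spatial_shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
num_value = int(spatial_shapes.prod(1).sum())

query = torch.rand(1, 10, embed_dims).cuda()
value = torch.rand(1, num_value, embed_dims).cuda()
# Normalized (x, y) reference points in [0, 1] per query and level.
reference_points = torch.rand(1, 10, num_levels, 2).cuda()

with autocast(enabled=True):
    # Under autocast the projected "value" reaches the CUDA kernel as
    # float16 while sampling_locations / attention_weights end up float32;
    # the casts added in this patch keep the extension inputs consistent.
    output = attn(
        query,
        value=value,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)
print(output.shape)  # torch.Size([1, 10, 32])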