From 27df2b3b9c874129b1ac3d8a51d2ef7cde96f06c Mon Sep 17 00:00:00 2001
From: RRaoyzee <162255573+RRaoyzee@users.noreply.github.com>
Date: Thu, 16 May 2024 17:21:13 +0800
Subject: [PATCH] Refactor the implementation of ms_deform_attn for Ascend (#3057)

---
 .../csrc/pytorch/npu/ms_deform_attn_npu.cpp | 41 ++-----------------
 1 file changed, 3 insertions(+), 38 deletions(-)

diff --git a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
index 7e943ca12f..453be7034c 100644
--- a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
@@ -89,44 +89,9 @@ void ms_deform_attn_backward_npu(
     Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
     const int im2col_step) {
   check_support(value, attn_weight);
-  at::Tensor value_fp32 = value;
-  at::Tensor spatial_shapes_int32 = spatial_shapes;
-  at::Tensor level_start_index_int32 = level_start_index;
-  at::Tensor sampling_loc_fp32 = sampling_loc.transpose(4, 5).contiguous();
-  at::Tensor attn_weight_fp32 = attn_weight;
-  at::Tensor grad_output_fp32 = grad_output;
-  if (value.scalar_type() != at::kFloat) {
-    value_fp32 = value.to(at::kFloat);
-  }
-  if (spatial_shapes.scalar_type() != at::kInt) {
-    spatial_shapes_int32 = spatial_shapes.to(at::kInt);
-  }
-  if (level_start_index.scalar_type() != at::kInt) {
-    level_start_index_int32 = level_start_index.to(at::kInt);
-  }
-  if (sampling_loc.scalar_type() != at::kFloat) {
-    sampling_loc_fp32 = sampling_loc_fp32.to(at::kFloat);
-  }
-  if (attn_weight.scalar_type() != at::kFloat) {
-    attn_weight_fp32 = attn_weight.to(at::kFloat);
-  }
-  if (grad_output.scalar_type() != at::kFloat) {
-    grad_output_fp32 = grad_output.to(at::kFloat);
-  }
-
-  OpCommand cmd;
-  cmd.Name("MultiScaleDeformableAttentionGrad")
-      .Input(value_fp32)
-      .Input(spatial_shapes_int32)
-      .Input(level_start_index_int32)
-      .Input(sampling_loc_fp32)
-      .Input(attn_weight_fp32)
-      .Input(grad_output_fp32)
-      .Output(grad_value)
-      .Output(grad_sampling_loc)
-      .Output(grad_attn_weight)
-      .Run();
-  grad_sampling_loc = grad_sampling_loc.transpose(4, 5).contiguous();
+  EXEC_NPU_CMD(aclnnMultiScaleDeformableAttentionGrad, value, spatial_shapes,
+               level_start_index, sampling_loc, attn_weight, grad_output,
+               grad_value, grad_sampling_loc, grad_attn_weight);
 }
 
 REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_backward_npu);