From f0497a6495fd9d08084ff98690d8416f13c7beb2 Mon Sep 17 00:00:00 2001 From: zhouzhengkai Date: Thu, 14 Mar 2024 18:05:13 +0800 Subject: [PATCH] fix precision problem of msdag check --- .../csrc/pytorch/npu/ms_deform_attn_npu.cpp | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp index da6f291048..26aa847af5 100644 --- a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp @@ -91,6 +91,12 @@ void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shap Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step) { check_support(value, attn_weight); + int64_t num_heads = value.size(2); + int64_t embed_dims = value.size(3); + int64_t num_points = attn_weight.size(4); + TORCH_CHECK(embed_dims % 32 == 0, "embed_dims must be a multiple of 32, but embed_dims is ", embed_dims, "."); + TORCH_CHECK(num_points % 4 == 0, "num_points must be a multiple of four, but num_points is ", num_points, "."); + TORCH_CHECK(num_heads % 4 == 0, "num_heads must be a multiple of four, but num_heads is ", num_heads, "."); at::Tensor value_fp32 = value; at::Tensor spatial_shapes_int32 = spatial_shapes; at::Tensor level_start_index_int32 = level_start_index; @@ -115,6 +121,10 @@ void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shap if (grad_output.scalar_type() != at::kFloat) { grad_output_fp32 = grad_output.to(at::kFloat); } + ori_type = value.scalar_type(); + at::Tensor grad_value_temp = at::zeros(value_fp32.sizes(), value_fp32.options()); + at::Tensor grad_sampling_loc_temp = at::zeros(sampling_loc_fp32.sizes(), sampling_loc_fp32.options()); + at::Tensor grad_attn_weight_temp = at::zeros(attn_weight_fp32.sizes(), attn_weight_fp32.options()); OpCommand cmd; cmd.Name("MultiScaleDeformableAttentionGrad") @@ -124,11 +134,16 @@ void 
ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shap .Input(sampling_loc_fp32) .Input(attn_weight_fp32) .Input(grad_output_fp32) - .Output(grad_value) - .Output(grad_sampling_loc) - .Output(grad_attn_weight) + .Output(grad_value_temp) + .Output(grad_sampling_loc_temp) + .Output(grad_attn_weight_temp) .Run(); - grad_sampling_loc = grad_sampling_loc.transpose(4, 5).contiguous(); + grad_value_temp = grad_value_temp.to(ori_type); + grad_sampling_loc_temp = grad_sampling_loc_temp.transpose(4, 5).contiguous().to(ori_type); + grad_attn_weight_temp = grad_attn_weight_temp.to(ori_type); + grad_value.copy_(grad_value_temp); + grad_sampling_loc.copy_(grad_sampling_loc_temp); + grad_attn_weight.copy_(grad_attn_weight_temp); } REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_backward_npu);