Skip to content

Commit

Permalink
fix precision problem of msdag
Browse files Browse the repository at this point in the history
check

fix
  • Loading branch information
RRaoyzee committed Mar 14, 2024
1 parent 265531f commit 63f331e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shap
if (grad_output.scalar_type() != at::kFloat) {
grad_output_fp32 = grad_output.to(at::kFloat);
}
at::Tensor grad_sampling_loc_temp = at::zeros(sampling_loc_fp32.sizes(), sampling_loc_fp32.options());

OpCommand cmd;
cmd.Name("MultiScaleDeformableAttentionGrad")
Expand All @@ -125,10 +126,11 @@ void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shap
.Input(attn_weight_fp32)
.Input(grad_output_fp32)
.Output(grad_value)
.Output(grad_sampling_loc)
.Output(grad_sampling_loc_temp)
.Output(grad_attn_weight)
.Run();
grad_sampling_loc = grad_sampling_loc.transpose(4, 5).contiguous();
grad_sampling_loc_temp = grad_sampling_loc_temp.transpose(4, 5).contiguous();
grad_sampling_loc.copy_(grad_sampling_loc_temp);
}

REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_backward_npu);

0 comments on commit 63f331e

Please sign in to comment.