Skip to content

Commit

Permalink
fix precision problem of msdag
Browse files Browse the repository at this point in the history
check

fix

fix

fix
  • Loading branch information
RRaoyzee committed Mar 20, 2024
1 parent 265531f commit f0497a6
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shap
Tensor &grad_sampling_loc,
Tensor &grad_attn_weight, const int im2col_step) {
check_support(value, attn_weight);
int64_t num_heads = value.size(2);
int64_t embed_dims = value.size(3);
int64_t num_points = attn_weight.size(4);
TORCH_CHECK(embed_dims % 32 == 0, "embed_dims must be a multiple of 32, but embed_dims is", embed_dims, ".");
TORCH_CHECK(num_points % 4 == 0, "num_points must be a multiple of four, but num_points is", num_points, ".");
TORCH_CHECK(num_heads % 4 == 0, "num_heads must be a multiple of four, but num_heads is", num_heads, ".");
at::Tensor value_fp32 = value;
at::Tensor spatial_shapes_int32 = spatial_shapes;
at::Tensor level_start_index_int32 = level_start_index;
Expand All @@ -115,6 +121,10 @@ void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shap
if (grad_output.scalar_type() != at::kFloat) {
grad_output_fp32 = grad_output.to(at::kFloat);
}
ori_type = value.scalar_type();
at::Tensor grad_value_temp = at::zeros(value_fp32.sizes(), value_fp32.options());
at::Tensor grad_sampling_loc_temp = at::zeros(sampling_loc_fp32.sizes(), sampling_loc_fp32.options());
at::Tensor grad_attn_weight_temp = at::zeros(attn_weight_fp32.sizes(), attn_weight_fp32.options());

OpCommand cmd;
cmd.Name("MultiScaleDeformableAttentionGrad")
Expand All @@ -124,11 +134,16 @@ void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shap
.Input(sampling_loc_fp32)
.Input(attn_weight_fp32)
.Input(grad_output_fp32)
.Output(grad_value)
.Output(grad_sampling_loc)
.Output(grad_attn_weight)
.Output(grad_value_temp)
.Output(grad_sampling_loc_temp)
.Output(grad_attn_weight_temp)
.Run();
grad_sampling_loc = grad_sampling_loc.transpose(4, 5).contiguous();
grad_value_temp = grad_value_temp.to(ori_type);
grad_sampling_loc_temp = grad_sampling_loc_temp.transpose(4, 5).contiguous().to(ori_type);
grad_attn_weight_temp = grad_attn_weight_temp.to(ori_type);
grad_value.copy_(grad_value_temp);
grad_sampling_loc.copy_(grad_sampling_loc_temp);
grad_attn_weight.copy_(grad_attn_weight_temp);
}

REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_backward_npu);

0 comments on commit f0497a6

Please sign in to comment.