Skip to content

Commit

Permalink
[Fix] Fix deform_conv.py for torch_npu v2.1 (#2967)
Browse files Browse the repository at this point in the history
  • Loading branch information
shun001 committed Oct 24, 2023
1 parent 51a9943 commit 00e92ab
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmcv/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ def symbolic(g,

@staticmethod
def _npu_backward(ctx, grad_output):
import torch_npu
input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \
ctx.saved_tensors
grad_input, grad_weight, grad_offset_all, grad_bias = \
torch.npu_deformable_conv2dbk(
torch_npu.npu_deformable_conv2dbk(
input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
Expand Down

0 comments on commit 00e92ab

Please sign in to comment.