Skip to content

Commit

Permalink
fix three_interpolate bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
momo609 committed Jun 17, 2024
1 parent e4e8f50 commit 38da280
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,21 @@ void three_interpolate_forward_npu(int b, int c, int m, int n,
TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
"three_interpolate_forward ascend only support fp32 and fp16.");

auto point_c_trans = points.transpose(1, 2);

auto point_c_trans = points.transpose(1, 2).to(at::kFloat);
auto weight_cast = weight.to(at::kFloat);
auto out_cast = out.to(at::kFloat);
OpCommand cmd;
cmd.Name("ThreeInterpolate")
.Input(point_c_trans)
.Input(idx)
.Input(weight)
.Output(out)
.Input(weight_cast)
.Output(out_cast)
.Run();

auto output = out.view({b, n, c}).transpose(1, 2);
if (originDtype == at::kHalf) {
out_cast = out_cast.to(at::kHalf);
}
auto output = out_cast.view({b, n, c}).transpose(1, 2);
auto res = output.contiguous();
out.copy_(res);
}
Expand All @@ -34,12 +38,17 @@ void three_interpolate_backward_npu(int b, int c, int n, int m,
TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
"three_interpolate_backward ascend only support fp32 and fp16.");

auto grad_x = at::unsqueeze(grad_out, 3);
auto grad_y = at::unsqueeze(grad_points, 3);

EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y);
auto grad_x = at::unsqueeze(grad_out, 3).to(at::kFloat);
auto grad_y = at::unsqueeze(grad_points, 3).to(at::kFloat);
auto weight_cast = weight.to(at::kFloat);
EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight_cast, m,
grad_y);

auto output = at::squeeze(grad_y, 3);
auto grad_y_cast = grad_y;
if (originDtype == at::kHalf) {
grad_y_cast = grad_y.to(at::kHalf);
}
auto output = at::squeeze(grad_y_cast, 3);
auto res = output.contiguous();
grad_points.copy_(res);
}
Expand Down

0 comments on commit 38da280

Please sign in to comment.