Skip to content

Commit

Permalink
Fix gather_points fp16 (Half) handling bug in the NPU backward pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
momo609 committed May 30, 2024
1 parent abf8ca7 commit 13e0881
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
// Backward pass of gather_points on Ascend NPU: scatters grad_out back into
// grad_points at the positions given by idx.
// NOTE(review): this span is a rendered GitHub diff — hunk headers
// ("Expand All @@ ...") and deleted lines are still interleaved below, and
// some context lines are collapsed/hidden; it is not compilable as shown.
void gather_points_backward_npu(int b, int c, int n, int npoints,
const Tensor grad_out, const Tensor idx,
Tensor grad_points) {
// Half inputs are upcast to Float before invoking the NPU op — presumably
// the underlying scatter kernel lacks fp16 support. TODO confirm against
// the op's dtype support matrix.
at::Tensor grad_out_cast = grad_out;
at::Tensor grad_points_cast = grad_points;
if (grad_out.scalar_type() == at::ScalarType::Half) {
grad_out_cast = at_npu::native::custom_ops::npu_dtype_cast(grad_out, at::ScalarType::Float);
grad_points_cast = at_npu::native::custom_ops::npu_dtype_cast(grad_points, at::ScalarType::Float);
}
// The NPU op requires int32 indices; convert if the caller passed
// anything else (e.g. int64).
at::Tensor indices = idx;
if (idx.scalar_type() != at::ScalarType::Int) {
indices = idx.to(at::kInt);
Expand All @@ -37,11 +43,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints,
for (uint64_t i = 0; i < shape.size(); i++) {
pad_size.emplace_back(shape[i]);
}
// Work on the (B, N, C) layout: transpose dims 1 and 2, then flatten the
// leading two dims so the op sees a 2-D view. Uses the Float-cast tensors
// introduced by this commit (old lines retained above each replacement by
// the diff rendering).
at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous();
at::Tensor trans_grad_points = grad_points_cast.transpose(1, 2).contiguous();
at::Tensor grad_points_view = trans_grad_points.view(
{trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1],
trans_grad_points.sizes()[2]});
at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous();
at::Tensor trans_grad_out = grad_out_cast.transpose(1, 2).contiguous();
trans_grad_out = trans_grad_out.view(
{trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1],
trans_grad_out.sizes()[2]});
Expand All @@ -63,7 +69,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints,
// Restore the original (B, C, N) layout of the accumulated gradient.
at::Tensor grad_points_result =
grad_points_view.view(trans_grad_points.sizes());
grad_points_result = grad_points_result.transpose(1, 2);
grad_points.copy_(grad_points_result);
// NOTE(review): when grad_out was Half, grad_points_result is already
// Float at this point, so casting it to Float is a no-op — the intended
// target dtype was presumably at::ScalarType::Half (cast back before
// writing into the Half-typed grad_points). Behavior is rescued only
// because Tensor::copy_() converts dtype on assignment; the cast target
// should be fixed to Half or the branch removed. TODO confirm.
at::Tensor grad_points_result_cast = grad_points_result;
if (grad_out.scalar_type() == at::ScalarType::Half) {
grad_points_result_cast = at_npu::native::custom_ops::npu_dtype_cast(grad_points_result, at::ScalarType::Float);
}
grad_points.copy_(grad_points_result_cast);
}

void gather_points_forward_impl(int b, int c, int n, int npoints,
Expand Down

0 comments on commit 13e0881

Please sign in to comment.