Skip to content

Commit

Permalink
Add gather_point backward npu adpater
Browse files Browse the repository at this point in the history
Add gather_point backward npu adpater.
  • Loading branch information
momo609 committed Sep 14, 2023
1 parent f92e03a commit 44f367c
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,48 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
.Run();
}

void gather_points_backward_npu(int b, int c, int n, int npoints,
const Tensor grad_out, const Tensor idx,
Tensor grad_points) {
at::Tensor indices = idx;
if (idx.scalar_type() != at::ScalarType::Int) {
indices = idx.to(at::kInt);
}
if (idx.dim() == 0) {
indices.unsqueeze_(0);
}
int64_t dim = 0;
at::SmallVector<int64_t, N> pad_size = array_to_small_vector(idx.sizes());
at::Tensor trans_grad_points = grad_points.transpose(1,2).contiguous();
at::Tensor grad_points_view = trans_grad_points.view({trans_grad_points.sizes()[0]*trans_grad_points.sizes()[1], trans_grad_points.sizes()[2]});
at::Tensor trans_grad_out = grad_out.transpose(1,2).contiguous();
trans_grad_out = trans_grad_out.view({trans_grad_out.sizes()[0]*trans_grad_out.sizes()[1], trans_grad_out.sizes()[2]});
auto index = at::arange(0,b);
index = index.to(grad_out.device());
index = at::mul(index,n);
index = index.view({b,1});
index = at::broadcast_to(index, pad_size);
indices = at::add(index, indices);
indices = indices.view({-1});
OpCommand cmd;
cmd.Name("InplaceIndexAdd")
.Input(grad_points_view)
.Input(indices)
.Input(trans_grad_out)
.Output(grad_points_view)
.Attr("axis", dim)
.Run();
at::Tensor grad_points_result = grad_points_view.view(trans_grad_points.sizes());
grad_points_result = grad_points_result.transpose(1,2);
grad_points.copy_(grad_points_result);
}

void gather_points_forward_impl(int b, int c, int n, int npoints,
const Tensor points, const Tensor idx,
Tensor out);
void gather_points_backward_impl(int b, int c, int n, int npoints,
const Tensor grad_out, const Tensor idx,
Tensor grad_points);

REGISTER_NPU_IMPL(gather_points_forward_impl, gather_points_forward_npu);
REGISTER_NPU_IMPL(gather_points_backward_impl, gather_points_backward_npu);

0 comments on commit 44f367c

Please sign in to comment.