
Commit

fix npu bug
hust17yixuan committed Aug 28, 2024
1 parent 5916fbd commit c40d772
Showing 5 changed files with 56 additions and 24 deletions.

mmcv/ops/csrc/pytorch/npu/knn_npu.cpp (2 additions, 2 deletions)

@@ -8,11 +8,11 @@ using namespace std;
 void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
                      const Tensor new_xyz, Tensor idx, Tensor dist2) {
   // transpose known from [B, N, 3] to [B, 3, N]
-  at::Tensor source = xyz.transpose(1, 2).contiguous();
+  at::Tensor source = xyz.transpose(2, 1).contiguous();
   at::Tensor target = new_xyz.contiguous();
 
   bool is_from_knn = true;
-  EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
+  EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, idx, dist2);
 }
 
 void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
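
As an aside (not part of the commit), the k-NN result that this kernel and the Python-side torch.topk in mmcv/ops/knn.py are expected to produce together can be sketched in pure PyTorch; knn_reference and its exact shapes are assumptions inferred from the knn.py hunk further down:

# Hypothetical reference, for illustration only.
import torch


def knn_reference(k: int, xyz: torch.Tensor,
                  center_xyz: torch.Tensor) -> torch.Tensor:
    """xyz: (B, N, 3), center_xyz: (B, npoint, 3) -> idx: (B, k, npoint)."""
    # Squared pairwise distances between every center point and every source point.
    dist = torch.cdist(center_xyz, xyz) ** 2  # (B, npoint, N)
    _, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
    return idx.transpose(2, 1).contiguous().int()  # (B, k, npoint)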

mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp (19 additions, 10 deletions)

@@ -12,17 +12,21 @@ void three_interpolate_forward_npu(int b, int c, int m, int n,
   TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
               "three_interpolate_forward ascend only support fp32 and fp16.");
 
-  auto point_c_trans = points.transpose(1, 2);
-
+  auto point_c_trans = points.transpose(1, 2).to(at::kFloat);
+  auto weight_cast = weight.to(at::kFloat);
+  auto out_cast = out.to(at::kFloat);
   OpCommand cmd;
   cmd.Name("ThreeInterpolate")
       .Input(point_c_trans)
       .Input(idx)
-      .Input(weight)
-      .Output(out)
+      .Input(weight_cast)
+      .Output(out_cast)
       .Run();
 
-  auto output = out.view({b, n, c}).transpose(1, 2);
+  if (originDtype == at::kHalf) {
+    out_cast = out_cast.to(at::kHalf);
+  }
+  auto output = out_cast.view({b, n, c}).transpose(1, 2);
   auto res = output.contiguous();
   out.copy_(res);
 }
@@ -34,12 +38,17 @@ void three_interpolate_backward_npu(int b, int c, int n, int m,
   TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
               "three_interpolate_backward ascend only support fp32 and fp16.");
 
-  auto grad_x = at::unsqueeze(grad_out, 3);
-  auto grad_y = at::unsqueeze(grad_points, 3);
-
-  EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y);
+  auto grad_x = at::unsqueeze(grad_out, 3).to(at::kFloat);
+  auto grad_y = at::unsqueeze(grad_points, 3).to(at::kFloat);
+  auto weight_cast = weight.to(at::kFloat);
+  EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight_cast, m,
+               grad_y);
 
-  auto output = at::squeeze(grad_y, 3);
+  auto grad_y_cast = grad_y;
+  if (originDtype == at::kHalf) {
+    grad_y_cast = grad_y.to(at::kHalf);
+  }
+  auto output = at::squeeze(grad_y_cast, 3);
   auto res = output.contiguous();
   grad_points.copy_(res);
 }
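
For context only (not in the commit), a rough pure-PyTorch sketch of the forward computation ThreeInterpolate performs, including the fp32 round trip the change above introduces for half-precision inputs; three_interpolate_reference and its shapes are assumptions:

# Hypothetical reference, for illustration only.
import torch


def three_interpolate_reference(points: torch.Tensor, idx: torch.Tensor,
                                weight: torch.Tensor) -> torch.Tensor:
    """points: (B, C, m), idx: (B, n, 3) int, weight: (B, n, 3) -> (B, C, n)."""
    origin_dtype = points.dtype
    points_f, weight_f = points.float(), weight.float()  # compute in fp32
    b, c, _ = points_f.shape
    n = idx.shape[1]
    # Gather the three neighbour features for every target point.
    gather_idx = idx.view(b, 1, n * 3).expand(-1, c, -1).long()  # (B, C, n*3)
    neighbours = points_f.gather(2, gather_idx).view(b, c, n, 3)
    out = (neighbours * weight_f.view(b, 1, n, 3)).sum(-1)  # (B, C, n)
    return out.to(origin_dtype)  # cast back, mirroring the kHalf branch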

mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp (3 additions, 12 deletions)

@@ -7,21 +7,12 @@ using namespace std;
 
 void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
                           const Tensor known, Tensor dist2, Tensor idx) {
-  // transpose known [B, N, 3] -> [B, 3, N]
-  at::Tensor source = known.transpose(1, 2).contiguous();
+  at::Tensor source = known.contiguous();
   at::Tensor target = unknown.contiguous();
-  auto originDtype = source.scalar_type();
-  if (originDtype == at::kHalf) {
-    source = source.to(at::kFloat);
-    target = target.to(at::kFloat);
-  }
 
   bool is_from_knn = false;
-  uint32_t nsample = 3;
-  EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
-  if (originDtype == at::kHalf) {
-    dist2 = dist2.to(at::kHalf);
-  }
+  EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
 }
 
 void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
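
For three_nn, half-precision handling moves out of this C++ wrapper and into mmcv/ops/three_nn.py (hunk below). A hypothetical helper, not part of the commit, showing that cast-to-fp32-and-back idiom in isolation:

# Hypothetical helper, for illustration only.
import torch


def run_in_fp32(op, *tensors):
    """Call `op` on fp32 copies of `tensors`, return the result in the original dtype."""
    origin_dtype = tensors[0].dtype
    out = op(*(t.float() if t.is_floating_point() else t for t in tensors))
    return out.half() if origin_dtype == torch.float16 else out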

mmcv/ops/knn.py (17 additions, 0 deletions)

@@ -61,6 +61,23 @@ def forward(ctx,
 
         B, npoint, _ = center_xyz.shape
         N = xyz.shape[1]
+        if xyz.device.type == 'npu':
+            dist = center_xyz.new_zeros((B, npoint, N)).float()
+            ext_module.knn_forward(
+                xyz,
+                center_xyz,
+                torch.Tensor([]).npu(),
+                dist,
+                b=B,
+                n=N,
+                m=npoint,
+                nsample=k)
+            dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
+            zeros_idx = torch.zeros(
+                xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu()
+            idx.where(dist2 >= 1e10, zeros_idx)
+            idx = idx.transpose(2, 1).contiguous()  # [B, k, npoint]
+            return idx.int()
 
         idx = center_xyz.new_zeros((B, npoint, k)).int()
         dist2 = center_xyz.new_zeros((B, npoint, k)).float()
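
A minimal usage sketch (assuming the public mmcv.ops.knn entry point keeps its usual signature and that a torch_npu/Ascend environment is set up); on an NPU device the new branch above runs, and the kernel's full distance matrix is reduced with torch.topk on the Python side:

# Illustrative only; shapes follow the usual mmcv.ops.knn conventions.
import torch
from mmcv.ops import knn

xyz = torch.rand(2, 256, 3).npu()        # (B, N, 3) source point cloud
center_xyz = torch.rand(2, 64, 3).npu()  # (B, npoint, 3) query points
idx = knn(16, xyz, center_xyz)           # (B, 16, npoint) int32 neighbour indices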

mmcv/ops/three_nn.py (15 additions, 0 deletions)

@@ -34,6 +34,21 @@ def forward(ctx: Any, target: torch.Tensor,
 
         B, N, _ = target.size()
         m = source.size(1)
+        if source.device.type == 'npu':
+            # strict to fp32
+            source = source.transpose(2, 1).contiguous()
+            dtype_ = source.dtype
+            if dtype_ == torch.float16:
+                target = target.float()
+                source = source.float()
+            dist = target.new_empty(B, N, m)
+            ext_module.three_nn_forward(
+                target, source, dist, torch.Tensor([]).npu(), b=B, n=N, m=m)
+            dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True)
+            dist2 = torch.sqrt(dist2)
+            if dtype_ == torch.float16:
+                dist2 = dist2.half()
+            return dist2, idx.int()
 
         dist2 = target.new_empty(B, N, 3)
         idx = target.new_empty(B, N, 3, dtype=torch.int32)
 
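
A minimal usage sketch (assuming the public mmcv.ops.three_nn and mmcv.ops.three_interpolate APIs keep their usual shapes): the typical PointNet++-style feature propagation in which the two ops are used together:

# Illustrative only; shape comments are assumptions, not taken from the commit.
import torch
from mmcv.ops import three_interpolate, three_nn

known = torch.rand(2, 64, 3).npu()         # (B, m, 3) points carrying features
known_feats = torch.rand(2, 16, 64).npu()  # (B, C, m) features on those points
unknown = torch.rand(2, 256, 3).npu()      # (B, n, 3) points to interpolate onto

dist, idx = three_nn(unknown, known)               # both (B, n, 3)
weight = 1.0 / (dist + 1e-8)
weight = weight / weight.sum(dim=2, keepdim=True)  # inverse-distance weights
new_feats = three_interpolate(known_feats, idx, weight)  # (B, C, n)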
