Skip to content

Commit

Permalink
npu knn/tnn bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
lizekai committed Jun 17, 2024
1 parent e4e8f50 commit 95af193
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 2 deletions.
21 changes: 21 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "pytorch_npu_helper.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"

using namespace NPU_NAME_SPACE;
using namespace std;

void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2) {
// transpose known from [B, N, 3] to [B, 3, N]
at::Tensor source = xyz.transpose(2, 1).contiguous();
at::Tensor target = new_xyz.contiguous();

bool is_from_knn = true;
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
}

// Declaration only: knn_forward_impl is defined in another translation unit.
// Its parameter list matches knn_forward_npu above.
void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2);

// Ties knn_forward_npu to knn_forward_impl via REGISTER_NPU_IMPL
// (presumably so calls on NPU tensors are routed to it -- confirm against
// the macro definition in pytorch_npu_helper.hpp).
REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
20 changes: 20 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "pytorch_npu_helper.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"

using namespace NPU_NAME_SPACE;
using namespace std;

void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx) {
at::Tensor source = known.contiguous();
at::Tensor target = unknown.contiguous();

bool is_from_knn = false;
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
}

// Declaration only: three_nn_forward_impl is defined in another translation
// unit. Its parameter list matches three_nn_forward_npu above.
void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx);

// Ties three_nn_forward_npu to three_nn_forward_impl via REGISTER_NPU_IMPL
// (presumably so calls on NPU tensors are routed to it -- confirm against
// the macro definition in pytorch_npu_helper.hpp).
REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);
23 changes: 21 additions & 2 deletions mmcv/ops/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,31 @@ def forward(ctx,
center_xyz_device = center_xyz.get_device()
assert center_xyz_device == xyz.get_device(), \
'center_xyz and xyz should be put on the same device'
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)
if xyz.device.type != 'npu':
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)

B, npoint, _ = center_xyz.shape
N = xyz.shape[1]

if xyz.device.type == 'npu':
dist = center_xyz.new_zeros((B, npoint, N)).float()
ext_module.knn_forward(
xyz,
center_xyz,
torch.Tensor([]).npu(),
dist,
b=B,
n=N,
m=npoint,
nsample=k)
dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
zeros_idx = torch.zeros(
xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu()
idx.where(dist2 >= 1e10, zeros_idx)
idx = idx.transpose(2, 1).contiguous() # [B, k, npoint]
return idx.int()

idx = center_xyz.new_zeros((B, npoint, k)).int()
dist2 = center_xyz.new_zeros((B, npoint, k)).float()

Expand Down
15 changes: 15 additions & 0 deletions mmcv/ops/three_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ def forward(ctx: Any, target: torch.Tensor,

B, N, _ = target.size()
m = source.size(1)
if source.device.type == 'npu':
# strict to fp32
source = source.transpose(2, 1).contiguous()
dtype_ = source.dtype
if dtype_ == torch.float16:
target = target.float()
source = source.float()
dist = target.new_empty(B, N, m)
ext_module.three_nn_forward(
target, source, dist, torch.Tensor([]).npu(), b=B, n=N, m=m)
dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True)
dist2 = torch.sqrt(dist2)
if dtype_ == torch.float16:
dist2 = dist2.half()
return dist2, idx.int()
dist2 = target.new_empty(B, N, 3)
idx = target.new_empty(B, N, 3, dtype=torch.int32)

Expand Down

0 comments on commit 95af193

Please sign in to comment.