diff --git a/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
new file mode 100644
index 0000000000..c4a1bcbd25
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
@@ -0,0 +1,21 @@
+#include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
+                     const Tensor new_xyz, Tensor idx, Tensor dist2) {
+  // transpose known from [B, N, 3] to [B, 3, N]
+  at::Tensor source = xyz.transpose(2, 1).contiguous();
+  at::Tensor target = new_xyz.contiguous();
+
+  bool is_from_knn = true;
+  EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
+}
+
+void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
+                      const Tensor new_xyz, Tensor idx, Tensor dist2);
+
+REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
diff --git a/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
new file mode 100644
index 0000000000..6740a731bc
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
@@ -0,0 +1,20 @@
+#include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
+                          const Tensor known, Tensor dist2, Tensor idx) {
+  at::Tensor source = known.contiguous();
+  at::Tensor target = unknown.contiguous();
+
+  bool is_from_knn = false;
+  EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
+}
+
+void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
+                           const Tensor known, Tensor dist2, Tensor idx);
+
+REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);
diff --git a/mmcv/ops/knn.py b/mmcv/ops/knn.py
index 48ce92f925..1e2a68d1d2 100644
--- a/mmcv/ops/knn.py
+++ b/mmcv/ops/knn.py
@@ -55,12 +55,31 @@ def forward(ctx,
         center_xyz_device = center_xyz.get_device()
         assert center_xyz_device == xyz.get_device(), \
             'center_xyz and xyz should be put on the same device'
-        if torch.cuda.current_device() != center_xyz_device:
-            torch.cuda.set_device(center_xyz_device)
+        if xyz.device.type != 'npu':
+            if torch.cuda.current_device() != center_xyz_device:
+                torch.cuda.set_device(center_xyz_device)
 
         B, npoint, _ = center_xyz.shape
         N = xyz.shape[1]
 
+        if xyz.device.type == 'npu':
+            # the NPU kernel writes the dense [B, npoint, N] squared
+            # distances; the k nearest neighbours are selected on the host
+            dist = center_xyz.new_zeros((B, npoint, N)).float()
+            ext_module.knn_forward(
+                xyz,
+                center_xyz,
+                torch.Tensor([]).npu(),
+                dist,
+                b=B,
+                n=N,
+                m=npoint,
+                nsample=k)
+            dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
+            zeros_idx = torch.zeros(
+                xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu()
+            # zero out indices flagged as padding by the sentinel distance
+            idx = torch.where(dist2 >= 1e10, zeros_idx, idx)
+            idx = idx.transpose(2, 1).contiguous()  # [B, k, npoint]
+            return idx.int()
+
         idx = center_xyz.new_zeros((B, npoint, k)).int()
         dist2 = center_xyz.new_zeros((B, npoint, k)).float()
diff --git a/mmcv/ops/three_nn.py b/mmcv/ops/three_nn.py
index d41b9789cf..52d504609a 100644
--- a/mmcv/ops/three_nn.py
+++ b/mmcv/ops/three_nn.py
@@ -34,6 +34,21 @@ def forward(ctx: Any, target: torch.Tensor,
         B, N, _ = target.size()
         m = source.size(1)
+        if source.device.type == 'npu':
+            # the NPU kernel is strict about fp32; cast fp16 inputs up
+            source = source.transpose(2, 1).contiguous()
+            dtype_ = source.dtype
+            if dtype_ == torch.float16:
+                target = target.float()
+                source = source.float()
+            dist = target.new_empty(B, N, m)
+            ext_module.three_nn_forward(
+                target, source, dist, torch.Tensor([]).npu(), b=B, n=N, m=m)
+            dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True)
+            dist2 = torch.sqrt(dist2)
+            if dtype_ == torch.float16:
+                dist2 = dist2.half()
+            return dist2, idx.int()
         dist2 = target.new_empty(B, N, 3)
         idx = target.new_empty(B, N, 3, dtype=torch.int32)
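Reviewer note: both wrappers rely on the same decomposition, where `aclnnKnn` only produces the dense squared-distance matrix and the nearest neighbours are recovered with `torch.topk` in Python. Below is a minimal CPU sketch of that contract for the knn path; `knn_reference` is an illustrative name, not part of this patch, and it assumes the kernel flags padded neighbours with distances at or above the 1e10 sentinel used in knn.py.

```python
import torch


def knn_reference(xyz, center_xyz, k):
    """CPU reference for the NPU knn path: dense distances + topk."""
    # xyz: [B, N, 3] support points, center_xyz: [B, npoint, 3] queries.
    # Same dense [B, npoint, N] squared-distance layout the kernel fills.
    dist = torch.cdist(center_xyz, xyz).pow(2)
    # largest=False selects the k smallest distances, as in the wrapper.
    dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
    # Zero out indices whose distance hits the padding sentinel.
    idx = torch.where(dist2 >= 1e10, torch.zeros_like(idx), idx)
    return idx.transpose(2, 1).contiguous().int()  # [B, k, npoint]


xyz = torch.rand(2, 128, 3)
center_xyz = torch.rand(2, 16, 3)
print(knn_reference(xyz, center_xyz, 5).shape)  # torch.Size([2, 5, 16])
```

The three_nn path is the same sketch with `k=3`, no index zeroing, and a final `torch.sqrt`, since ThreeNN returns actual distances rather than squared ones.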