Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add the support of BallQuery op for Ascend device #2963

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/ball_query_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void ball_query_forward_npu(int b, int n, int m, float min_radius,
float max_radius, int nsample, const Tensor new_xyz,
const Tensor xyz, Tensor idx) {
int64_t nsample_i64 = nsample;

// transpose new_xyz from [B, M, 3] to [M, B, 3]
at::Tensor new_xyz_transpose = new_xyz.transpose(0, 1);

// transpose xyz from [B, N, 3] to [B, 3, N]
at::Tensor xyz_transpose = xyz.transpose(1, 2);

// transpose idx from [B, M, nsample] to [M, B, nsample]
at::Tensor idx_transpose = NpuUtils::format_contiguous(idx.transpose(0, 1));

OpCommand cmd;
cmd.Name("BallQuery")
.Input(xyz_transpose)
.Input(new_xyz_transpose)
.Output(idx_transpose)
.Attr("min_radius", min_radius)
.Attr("max_radius", max_radius)
.Attr("sample_num", nsample_i64)
.Run();

idx_transpose = NpuUtils::format_contiguous(idx_transpose.transpose(0, 1));
idx.copy_(idx_transpose);
}

void ball_query_forward_impl(int b, int n, int m, float min_radius,
float max_radius, int nsample,
const Tensor new_xyz, const Tensor xyz,
Tensor idx);

REGISTER_NPU_IMPL(ball_query_forward_impl, ball_query_forward_npu);
8 changes: 6 additions & 2 deletions tests/test_ops/test_ball_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from mmcv.ops import ball_query
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.parametrize('device', [
Expand All @@ -14,7 +14,11 @@
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_ball_query(device):
new_xyz = torch.tensor(
Expand Down
Loading