From 0d3cdc6ee016b6cca6a1220fcbc8f2a6ee3426cc Mon Sep 17 00:00:00 2001 From: Annarine <57247683+Annarine@users.noreply.github.com> Date: Thu, 25 Apr 2024 15:54:29 +0800 Subject: [PATCH] fix fps csrc bug (#3094) --- .../pytorch/npu/furthest_point_sampling_with_dist_npu.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp index 364d3bfa9a..24317a06bb 100644 --- a/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp @@ -6,11 +6,11 @@ void furthest_point_sampling_with_dist_npu(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, int b, int n, int m) { - auto output_size = {b, m}; - at::Tensor result = - at::empty(output_size, points_tensor.options().dtype(at::kInt)); + TORCH_CHECK( + (points_tensor.sizes()[1] >= m), + "the num of sampled points should smaller than total num of points."); EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor, - m, result); + m, idx_tensor); } void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor,