diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp
index 5d812fe047..6d2588a01d 100644
--- a/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp
@@ -3,23 +3,24 @@
 using namespace NPU_NAME_SPACE;
 
 void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep,
-                                    Tensor &keep_num,
-                                    float nms_overlap_thresh) {
+                                    Tensor &num_out, float nms_overlap_thresh) {
   int32_t box_num = boxes.size(0);
   int32_t data_align = 16;
   int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
+  const double iou_threshold = nms_overlap_thresh;
   at::Tensor mask =
       at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort));
-  EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask);
+  EXEC_NPU_CMD(aclnnNms3dNormal, boxes, iou_threshold, mask);
 
-  keep = at::zeros({box_num}, mask.options());
-  keep_num = at::zeros(1, mask.options());
-  EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
+  Tensor keep_t = at::zeros({box_num}, mask.options());
+  Tensor num_out_t = at::zeros(1, mask.options());
+  EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep_t, num_out_t);
+  num_out.fill_(num_out_t.item().toLong());
+  keep.copy_(keep_t);
 }
 
 void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep,
-                                     Tensor &keep_num,
-                                     float nms_overlap_thresh);
+                                     Tensor &num_out, float nms_overlap_thresh);
 
 REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl,
                   iou3d_nms3d_normal_forward_npu);
diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp
index 13fe6db860..a143ed07b5 100644
--- a/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp
@@ -5,22 +5,26 @@
 using namespace std;
 
 constexpr int32_t BOX_DIM = 7;
-void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &keep_num,
+void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &num_out,
                              float nms_overlap_thresh) {
   TORCH_CHECK((boxes.sizes()[1] == BOX_DIM),
               "Input boxes shape should be (N, 7)");
   int32_t box_num = boxes.size(0);
   int32_t data_align = 16;
   int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
+  const double iou_threshold = nms_overlap_thresh;
   at::Tensor mask =
       at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort));
-  EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask);
-  keep = at::zeros({box_num}, mask.options());
-  keep_num = at::zeros(1, mask.options());
-  EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
+  EXEC_NPU_CMD(aclnnNms3d, boxes, iou_threshold, mask);
+
+  Tensor keep_t = at::zeros({box_num}, mask.options());
+  Tensor num_out_t = at::zeros(1, mask.options());
+  EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep_t, num_out_t);
+  num_out.fill_(num_out_t.item().toLong());
+  keep.copy_(keep_t);
 }
 
-void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep,
-                              Tensor &keep_num, float nms_overlap_thresh);
+void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep, Tensor &num_out,
+                              float nms_overlap_thresh);
 
 REGISTER_NPU_IMPL(iou3d_nms3d_forward_impl, iou3d_nms3d_forward_npu);