From ef10dc7c0f8fb58fff22965fd23ecac3cadf6bda Mon Sep 17 00:00:00 2001 From: ZYF-Annarine Date: Wed, 19 Jun 2024 14:47:57 +0800 Subject: [PATCH] modify chamfer --- .../csrc/pytorch/npu/chamfer_distance_npu.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index 4f5c32dbec..170a5fa72a 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -6,17 +6,17 @@ using namespace std; void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, Tensor dist2, Tensor idx1, Tensor idx2) { bool is_half = XYZ1.scalar_type() == at::kHalf; - at::Tensor xyz1 = at::ones_like(XYZ1); - at::Tensor xyz2 = at::ones_like(XYZ2); - at::Tensor distf1 = at::ones_like(dist1); - at::Tensor distf2 = at::ones_like(dist2); + at::Tensor xyz1 = XYZ1; + at::Tensor xyz2 = XYZ2; + at::Tensor distf1 = dist1; + at::Tensor distf2 = dist2; xyz1 = XYZ1.transpose(1, 2).transpose(0, 1); xyz2 = XYZ2.transpose(1, 2).transpose(0, 1); if (is_half) { xyz1 = xyz1.to(at::kFloat); xyz2 = xyz2.to(at::kFloat); - distf1 = dist1.to(at::kFloat); - distf2 = dist2.to(at::kFloat); + distf1 = distf1.to(at::kFloat); + distf2 = distf2.to(at::kFloat); } OpCommand cmd; cmd.Name("ChamferDistance") @@ -31,8 +31,8 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, distf1 = distf1.to(at::kHalf); distf2 = distf2.to(at::kHalf); } - dist1.copy_(distf1); - dist2.copy_(distf2); + dist1 = distf1; + dist2 = distf2; } void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1,