Skip to content

Commit

Permalink
modify chamfer
Browse files Browse the repository at this point in the history
  • Loading branch information
Annarine committed Jun 19, 2024
1 parent c94c723 commit 0852bb2
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ using namespace std;
void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
Tensor dist2, Tensor idx1, Tensor idx2) {
bool is_half = XYZ1.scalar_type() == at::kHalf;
at::Tensor xyz1 = at::ones_like(XYZ1);
at::Tensor xyz2 = at::ones_like(XYZ2);
at::Tensor distf1 = at::ones_like(dist1);
at::Tensor distf2 = at::ones_like(dist2);
at::Tensor xyz1 = XYZ1;
at::Tensor xyz2 = XYZ2;
at::Tensor distf1 = dist1;
at::Tensor distf2 = dist2;
xyz1 = XYZ1.transpose(1, 2).transpose(0, 1);
xyz2 = XYZ2.transpose(1, 2).transpose(0, 1);
if (is_half) {
xyz1 = xyz1.to(at::kFloat);
xyz2 = xyz2.to(at::kFloat);
distf1 = dist1.to(at::kFloat);
distf2 = dist2.to(at::kFloat);
distf1 = distf1.to(at::kFloat);
distf2 = distf2.to(at::kFloat);
}
OpCommand cmd;
cmd.Name("ChamferDistance")
Expand All @@ -31,8 +31,8 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
distf1 = distf1.to(at::kHalf);
distf2 = distf2.to(at::kHalf);
}
dist1.copy_(distf1);
dist2.copy_(distf2);
dist1 = distf1;
dist2 = distf2;
}

void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1,
Expand Down

0 comments on commit 0852bb2

Please sign in to comment.