Commit: chamfer_distance

Annarine committed Dec 25, 2023
1 parent 4d1e6fb · commit a42224c
Showing 5 changed files with 80 additions and 52 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
@@ -11,7 +11,7 @@ We implement common ops used in detection, segmentation, etc.
 | BorderAlign         |     | √    |     |     |        |
 | BoxIouRotated       | √   | √    | √   |     | √      |
 | BoxIouQuadri        | √   | √    |     |     |        |
 | CARAFE              |     | √    | √   |     |        |
-| ChamferDistance     |     | √    |     |     |        |
+| ChamferDistance     |     | √    |     |     | √      |
 | CrissCrossAttention |     | √    |     |     |        |
 | ContourExpand       | √   |      |     |     |        |
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
@@ -12,7 +12,7 @@ MMCV provides common operators used in detection, segmentation, and other tasks
 | BoxIouRotated       | √   | √    | √   |     | √      |
 | BoxIouQuadri        | √   | √    |     |     |        |
 | CARAFE              |     | √    | √   |     |        |
-| ChamferDistance     |     | √    |     |     |        |
+| ChamferDistance     |     | √    |     |     | √      |
 | CrissCrossAttention |     | √    |     |     |        |
 | ContourExpand       | √   |      |     |     |        |
 | ConvexIoU           |     | √    |     |     |        |
4 changes: 2 additions & 2 deletions mmcv/ops/chamfer_distance.py
@@ -81,8 +81,8 @@ def backward(ctx,
         device = grad_dist1.device
         grad_dist1 = grad_dist1.contiguous()
         grad_dist2 = grad_dist2.contiguous()
-        grad_xyz1 = torch.zeros(xyz1.size()).to(device)
-        grad_xyz2 = torch.zeros(xyz2.size()).to(device)
+        grad_xyz1 = torch.zeros(xyz1.size()).type(xyz1.dtype).to(device)
+        grad_xyz2 = torch.zeros(xyz2.size()).type(xyz2.dtype).to(device)

         ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
                                              grad_dist1, grad_dist2, grad_xyz1,
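The two-line change above matters for half-precision inputs: `torch.zeros(...)` allocates `float32` by default, so gradient buffers for fp16 point sets previously came back in the wrong dtype. A minimal sketch of the behavior (illustrative only, not part of the commit):

```python
# Sketch of the dtype mismatch the fix addresses (not MMCV code).
import torch

xyz1 = torch.rand(2, 100, 2, dtype=torch.half)  # fp16 point set

grad_old = torch.zeros(xyz1.size())                   # always float32
grad_new = torch.zeros(xyz1.size()).type(xyz1.dtype)  # follows the input

print(grad_old.dtype)  # torch.float32
print(grad_new.dtype)  # torch.float16
```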
17 changes: 15 additions & 2 deletions mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp
@@ -8,8 +8,8 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
                                   Tensor dist2, Tensor idx1, Tensor idx2) {
   at::Tensor xyz1 = at::ones_like(XYZ1);
   at::Tensor xyz2 = at::ones_like(XYZ2);
-  xyz1 = XYZ1.transpose(1, 2);
-  xyz2 = XYZ2.transpose(1, 2);
+  xyz1 = XYZ1.transpose(1, 2).transpose(0, 1);
+  xyz2 = XYZ2.transpose(1, 2).transpose(0, 1);
   OpCommand cmd;
   cmd.Name("ChamferDistance")
       .Input(xyz1)
@@ -21,7 +21,20 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
       .Run();
 }

+void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1,
+                                   Tensor idx2, Tensor grad_dist1,
+                                   Tensor grad_dist2, Tensor grad_xyz1,
+                                   Tensor grad_xyz2) {
+  EXEC_NPU_CMD(aclnnChamferDistanceBackward, xyz1, xyz2, idx1, idx2,
+               grad_dist1, grad_dist2, grad_xyz1, grad_xyz2);
+}
+
 void chamfer_distance_forward_impl(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
                                    Tensor dist2, Tensor idx1, Tensor idx2);
 REGISTER_NPU_IMPL(chamfer_distance_forward_impl,
                   chamfer_distance_forward_npu);

+void chamfer_distance_backward_impl(Tensor xyz1, Tensor xyz2, Tensor idx1,
+                                    Tensor idx2, Tensor grad_dist1,
+                                    Tensor grad_dist2, Tensor grad_xyz1,
+                                    Tensor grad_xyz2);
+REGISTER_NPU_IMPL(chamfer_distance_backward_impl,
+                  chamfer_distance_backward_npu);
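Two things change in the NPU adapter. The forward path now applies a second transpose, feeding the `ChamferDistance` op a `(C, B, N)` layout rather than `(B, C, N)`; the layout requirement is inferred from the transposes themselves. The backward path is new: `EXEC_NPU_CMD` invokes `aclnnChamferDistanceBackward`, and `REGISTER_NPU_IMPL` hooks it into MMCV's device dispatcher so autograd on Ascend devices reaches this kernel. A small shape-bookkeeping sketch of the layout change (illustrative only, not MMCV code):

```python
# Shape bookkeeping for the double transpose (illustrative only).
import torch

XYZ1 = torch.rand(4, 600, 2)  # (batch, num_points, coords) = (B, N, C)

old_layout = XYZ1.transpose(1, 2)                  # (B, C, N) = (4, 2, 600)
new_layout = XYZ1.transpose(1, 2).transpose(0, 1)  # (C, B, N) = (2, 4, 600)

print(old_layout.shape)  # torch.Size([4, 2, 600])
print(new_layout.shape)  # torch.Size([2, 4, 600])
```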
107 changes: 61 additions & 46 deletions tests/test_ops/test_chamfer_distance.py
@@ -1,57 +1,72 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
 import pytest
 import torch

 from mmcv.ops import chamfer_distance
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE

-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_chamfer_distance():
-    pointset1 = torch.tensor(
-        [[[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]],
-         [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]],
-         [[1.6, 9.99], [2.3, 9.99], [2.3, 10.39], [1.6, 10.39]]],
-        device='cuda',
-        requires_grad=True)
+def chamfer_distance_forward_golden(xyz1, xyz2, dtype):
+    bs, ns, ss = xyz1.shape
+    dist1 = np.zeros((bs, ns)).astype(torch_type_trans(dtype))
+    dist2 = np.zeros((bs, ns)).astype(torch_type_trans(dtype))
+    idx1 = np.zeros((bs, ns)).astype('int32')
+    idx2 = np.zeros((bs, ns)).astype('int32')
+    for b1 in range(bs):
+        for n1 in range(ns):
+            x1, y1 = xyz1[b1][n1]
+            dist1[b1][n1] = 10000000
+            for n2 in range(ns):
+                x2, y2 = xyz2[b1][n2]
+                dst = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)
+                if dist1[b1][n1] > dst:
+                    dist1[b1][n1] = dst
+                    idx1[b1][n1] = n2
+    for b1 in range(bs):
+        for n1 in range(ns):
+            x1, y1 = xyz2[b1][n1]
+            dist2[b1][n1] = 10000000
+            for n2 in range(ns):
+                x2, y2 = xyz1[b1][n2]
+                dst = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)
+                if dist2[b1][n1] > dst:
+                    dist2[b1][n1] = dst
+                    idx2[b1][n1] = n2
+    return [dist1, dist2, idx1, idx2]

-    pointset2 = torch.tensor(
-        [[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]],
-         [[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]],
-         [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]]],
-        device='cuda',
-        requires_grad=True)

-    expected_dist1 = torch.tensor(
-        [[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900],
-         [0.5200, 0.6500, 0.4900, 0.3600]],
-        device='cuda')
-    expected_dist2 = torch.tensor(
-        [[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900],
-         [0.7200, 0.8500, 0.4900, 0.3600]],
-        device='cuda')

+def torch_type_trans(dtype):
+    if dtype == torch.half:
+        return np.float16
+    elif dtype == torch.float32:
+        return np.float32

-    expected_pointset1_grad = torch.tensor(
-        [[[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000],
-          [0.6000, 0.0000]],
-         [[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000],
-          [-0.6000, 0.0000]],
-         [[1.2000, -0.8000], [-1.4000, -0.8000], [-1.4000, 0.0000],
-          [1.2000, 0.0000]]],
-        device='cuda')

-    expected_pointset2_grad = torch.tensor(
-        [[[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000],
-          [-0.6000, 0.0000]],
-         [[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000],
-          [0.6000, 0.0000]],
-         [[0.0000, 0.0000], [0.0000, 0.0000], [2.8000, 0.8000],
-          [-2.4000, 0.8000]]],
-        device='cuda')

-    dist1, dist2, idx1, idx2 = chamfer_distance(pointset1, pointset2)
-    dist1.backward(torch.ones_like(dist1))
-    assert torch.allclose(dist1, expected_dist1, 1e-2)
-    assert torch.allclose(dist2, expected_dist2, 1e-2)
-    assert torch.allclose(pointset1.grad.data, expected_pointset1_grad, 1e-2)
-    assert torch.allclose(pointset2.grad.data, expected_pointset2_grad, 1e-2)
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
+@pytest.mark.parametrize('dtype', [torch.half, torch.float32])
+@pytest.mark.parametrize('shape', [(2, 600, 2), (2, 600, 2)])
+def test_chamfer_distance_npu_dynamic_shape(dtype, device, shape):
+    bs = shape[0]
+    ns = shape[1]
+    xyz1 = np.random.uniform(-10.0, 10.0,
+                             (bs, ns, 2)).astype(torch_type_trans(dtype))
+    xyz2 = np.random.uniform(-10.0, 10.0,
+                             (bs, ns, 2)).astype(torch_type_trans(dtype))
+    xyz1_npu = torch.tensor(xyz1, dtype=dtype).to(device)
+    xyz2_npu = torch.tensor(xyz2, dtype=dtype).to(device)
+    expected_output = chamfer_distance_forward_golden(xyz1, xyz2, dtype)
+    output = chamfer_distance(xyz1_npu, xyz2_npu)
+    assert np.allclose(output[0].cpu().numpy(), expected_output[0], 1e-3, 1e-4)
+    assert np.allclose(output[1].cpu().numpy(), expected_output[1], 1e-3, 1e-4)
+    assert np.allclose(output[2].cpu().numpy(), expected_output[2], 1e-3, 1e-4)
+    assert np.allclose(output[3].cpu().numpy(), expected_output[3], 1e-3, 1e-4)
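The golden reference above computes, for every point in one set, the squared distance to its nearest neighbor in the other set using O(N²) Python loops. A vectorized NumPy equivalent (a sketch for intuition, not part of the commit) makes the same computation explicit:

```python
# Vectorized equivalent of chamfer_distance_forward_golden (sketch only).
import numpy as np

def chamfer_forward_vectorized(xyz1, xyz2):
    # xyz1, xyz2: (bs, ns, 2) point sets.
    diff = xyz1[:, :, None, :] - xyz2[:, None, :, :]  # (bs, ns, ns, 2)
    d2 = (diff ** 2).sum(-1)                          # pairwise squared dists
    dist1 = d2.min(axis=2)                # nearest xyz2 point for each xyz1
    idx1 = d2.argmin(axis=2).astype('int32')
    dist2 = d2.min(axis=1)                # nearest xyz1 point for each xyz2
    idx2 = d2.argmin(axis=1).astype('int32')
    return [dist1, dist2, idx1, idx2]
```

On a machine with a supported device, the parametrized test can be run with, e.g., `pytest tests/test_ops/test_chamfer_distance.py`.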
