Skip to content

Commit

Permalink
fix chamfer_distance cuda bug len should div to 4
Browse files Browse the repository at this point in the history
  • Loading branch information
Annarine committed May 23, 2024
1 parent 768d6cb commit 14f628d
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions tests/test_ops/test_chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def torch_to_np_type(dtype):
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('dtype', [torch.half, torch.float32])
@pytest.mark.parametrize('shape', [(2, 600, 2), (2, 600, 2)],
[(1, 1, 2), (1, 1, 2)],
[(1, 7, 2), (1, 7, 2)])
@pytest.mark.parametrize('shape', [(2, 600, 2), (1, 1, 2), (7, 7, 2)])
def test_chamfer_distance_npu_dynamic_shape(dtype, device, shape):
bs = shape[0]
ns = shape[1]
Expand Down

0 comments on commit 14f628d

Please sign in to comment.