Skip to content

Commit

Permalink
[Fix] Fix chamfer_distance CUDA bug: length should be rounded down to a multiple of 4 (#3113)
Browse files Browse the repository at this point in the history
  • Loading branch information
Annarine committed May 23, 2024
1 parent c0c81ce commit 2100900
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/common/cuda/chamfer_distance_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ __global__ void chamfer_distance_forward_cuda_kernel(int b, int n,
scalar_t y1 = xyz[(i * n + j) * 2 + 1];
int best_i = 0;
scalar_t best = 1e10;
int end_ka = end_k & (~2);
int end_ka = end_k & (~3);
if (end_ka == THREADS_PER_BLOCK) {
for (int k = 0; k < THREADS_PER_BLOCK; k += 4) {
#pragma unroll
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ops/test_chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def torch_to_np_type(dtype):
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('dtype', [torch.half, torch.float32])
@pytest.mark.parametrize('shape', [(2, 600, 2), (2, 600, 2)])
@pytest.mark.parametrize('shape', [(2, 600, 2), (1, 1, 2), (7, 7, 2)])
def test_chamfer_distance_npu_dynamic_shape(dtype, device, shape):
bs = shape[0]
ns = shape[1]
Expand Down

0 comments on commit 2100900

Please sign in to comment.