Skip to content

Commit

Permalink
Merge pull request #32 from huhongsun/rc41.x
Browse files Browse the repository at this point in the history
repair nms_rotated npu bug
  • Loading branch information
momo609 committed Jun 20, 2024
2 parents 0852bb2 + 85d0ce4 commit 0356569
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions mmcv/ops/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,23 +452,37 @@ def nms_rotated(dets: Tensor,
flip_mat[-1] = -1
dets_cw = dets * flip_mat
else:
dets_cw = dets.clone()
dets_cw = dets
multi_label = labels is not None
if labels is None:
input_labels = scores.new_empty(0, dtype=torch.int)
else:
input_labels = labels
if dets.device.type in ('npu', 'mlu'):

if dets.device.type == 'mlu':
order = scores.new_empty(0, dtype=torch.long)
keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
input_labels, iou_threshold,
multi_label)
dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
dim=1)
return dets, keep_inds

if dets.device.type == 'npu':
order = scores.new_empty(0, dtype=torch.long)
if dets.device.type == 'npu':
coefficient = 57.29578 # 180 / PI
coefficient = 57.29578 # 180 / PI
if dets.dtype == torch.float16:
dets_cw = dets_cw.float()
scores = scores.float()
for i in range(dets.size()[0]):
dets_cw[i][4] *= coefficient # radians to angle
else:
dets_cw = dets_cw.clone()
for i in range(dets.size()[0]):
dets_cw[i][4] *= coefficient # radians to angle
scores = scores.float()
keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
input_labels, iou_threshold,
multi_label)
if dets.dtype == torch.float16:
scores = scores.half()
dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
dim=1)
return dets, keep_inds
Expand Down

0 comments on commit 0356569

Please sign in to comment.