Skip to content

Commit

Permalink
ultralytics 8.2.43 enable classes filter for end2end models (#13971)
Browse files Browse the repository at this point in the history
Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
Laughing-q and glenn-jocher committed Jun 25, 2024
1 parent b11f043 commit 87dba19
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 馃殌, AGPL-3.0 license

__version__ = "8.2.42"
__version__ = "8.2.43"

import os

Expand Down
14 changes: 10 additions & 4 deletions ultralytics/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def non_max_suppression(
max_nms (int): The maximum number of boxes into torchvision.ops.nms().
max_wh (int): The maximum box width and height in pixels.
in_place (bool): If True, the input prediction tensor will be modified in place.
rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
Returns:
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
Expand All @@ -212,11 +213,16 @@ def non_max_suppression(
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
if classes is not None:
classes = torch.tensor(classes, device=prediction.device)

if prediction.shape[-1] == 6: # end-to-end model
return [pred[pred[:, 4] > conf_thres] for pred in prediction]
if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6)
output = [pred[pred[:, 4] > conf_thres] for pred in prediction]
if classes is not None:
output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
return output

bs = prediction.shape[0] # batch size
bs = prediction.shape[0] # batch size (BCN, i.e. 1,84,6300)
nc = nc or (prediction.shape[1] - 4) # number of classes
nm = prediction.shape[1] - nc - 4 # number of masks
mi = 4 + nc # mask start index
Expand Down Expand Up @@ -265,7 +271,7 @@ def non_max_suppression(

# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
x = x[(x[:, 5:6] == classes).any(1)]

# Check shape
n = x.shape[0] # number of boxes
Expand Down

0 comments on commit 87dba19

Please sign in to comment.