Skip to content

Commit

Permalink
[Enhancement] Add an option to control whether to use progress bar in…
Browse files Browse the repository at this point in the history
… BaseInference (#1135)

* show_track

* Update mmengine/infer/infer.py

* Update mmengine/infer/infer.py

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
  • Loading branch information
W-ZN and zhouzaida committed May 9, 2023
1 parent d4bb561 commit 8a0fae0
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions mmengine/infer/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class BaseInferencer(metaclass=InferencerMeta):
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to None.
show_progress (bool): Control whether to display the progress bar during
the inference process. Defaults to True.
`New in version 0.7.4.`
Note:
Since ``Inferencer`` could be used to infer batch data,
Expand All @@ -139,7 +142,8 @@ def __init__(self,
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = None) -> None:
scope: Optional[str] = None,
show_progress: bool = True) -> None:
if scope is None:
default_scope = DefaultScope.get_current_instance()
if default_scope is not None:
Expand Down Expand Up @@ -178,6 +182,7 @@ def __init__(self,
self.collate_fn = self._init_collate(cfg)
self.visualizer = self._init_visualizer(cfg)
self.cfg = cfg
self.show_progress = show_progress

def __call__(
self,
Expand Down Expand Up @@ -213,7 +218,8 @@ def __call__(
inputs = self.preprocess(
ori_inputs, batch_size=batch_size, **preprocess_kwargs)
preds = []
for data in track(inputs, description='Inference'):
for data in (track(inputs, description='Inference')
if self.show_progress else inputs):
preds.extend(self.forward(data, **forward_kwargs))
visualization = self.visualize(
ori_inputs, preds,
Expand Down

0 comments on commit 8a0fae0

Please sign in to comment.