diff --git a/mmengine/hooks/profiler_hook.py b/mmengine/hooks/profiler_hook.py index 6a242bd2b1..6339a5da92 100644 --- a/mmengine/hooks/profiler_hook.py +++ b/mmengine/hooks/profiler_hook.py @@ -48,12 +48,16 @@ class ProfilerHook(Hook): of generating handler. Defaults to None, which means profiling without an on_trace_ready.The Callable type needs to construct its own function that can handle 'torch.autograd.profiler.profile'. - Two officially recommended ways are provided, namely terminal - display or tensorboard display. The terminal display content can be - adjusted through 'EventList.table()' - from 'torch.autograd.profiler_util.py'. - If using tensorboard, save to '{work_dir}/tf_tracing_logs' - by default. + Two officially recommended ways are provided: + + - ``on_trace_ready=dict(type='log_trace')``: Print the profiling result + in the terminal. See more details in the `PyTorch official tutorial`_. + The configurable arguments are the same as + ``prof.key_averages().table``. + - ``on_trace_ready=dict(type='tb_trace')``: Profile the performance + with tensorboard. See more details in the tutorial + `profile with tensorboard`_. + record_shapes (bool): Save information about operator's input shapes. Defaults to False. profile_memory (bool): Track tensor memory allocation/deallocation. @@ -67,11 +71,20 @@ class ProfilerHook(Hook): JSON format. Chrome use 'chrome://tracing' view json file. Defaults to None, which means profiling does not store json files. + Warnings: + The profiler will be closed after ``profile_times`` iterations + automatically. Please make sure the configuration of your scheduler + will not close the profiler before the iteration reaches the value of + ``profile_times``. + Examples: >>> # tensorboard trace >>> trace_config = dict(type='tb_trace') >>> profiler_hook_cfg = dict(on_trace_ready=trace_config) - """ + + .. _PyTorch official tutorial: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-profiler-to-analyze-execution-time + .. 
_profile with tensorboard: https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html#pytorch-profiler-with-tensorboard + """ # noqa: E501 priority = 'VERY_LOW' def __init__(self, @@ -135,8 +148,8 @@ def __init__(self, self.with_flops = with_flops self.json_trace_path = json_trace_path + self._closed = False - @master_only def before_run(self, runner): """Initialize the profiler. @@ -212,23 +225,23 @@ def _log_handler(_profile): f'but got {self.on_trace_ready}') return _on_trace_ready - @master_only def after_train_epoch(self, runner): """Determine if the content is exported.""" - if self.by_epoch and runner.epoch == self.profile_times - 1: + # `after_train_epoch` will also be called in IterBasedTrainLoop. + # Here we check `self._closed` to avoid exiting twice. + if not self._closed: self._export_chrome_trace(runner) - @master_only def after_train_iter(self, runner, batch_idx, data_batch, outputs): - """Update the content according to the schedule, and determine if the - content is exported.""" - if self.schedule is None: + """Profiler will call ``step`` method if it is not closed.""" + if not self._closed: self.profiler.step() - if not self.by_epoch and runner.iter == self.profile_times - 1: + if runner.iter == self.profile_times - 1 and not self.by_epoch: self._export_chrome_trace(runner) def _export_chrome_trace(self, runner): """Exporting content.""" + self._closed = True runner.logger.info('profiler may take a few minutes...') self.profiler.__exit__(None, None, None) if self.json_trace_path is not None: diff --git a/tests/test_hooks/test_profiler_hook.py b/tests/test_hooks/test_profiler_hook.py index 974d0f48c6..2db6df01b6 100644 --- a/tests/test_hooks/test_profiler_hook.py +++ b/tests/test_hooks/test_profiler_hook.py @@ -85,7 +85,7 @@ def test_parse_trace_config_tensorboard(self): dict( type='ProfilerHook', on_trace_ready=dict( - type='tb_trace', dir_name='/home/baymax/RunTime/tb')) + type='tb_trace', dir_name=self.temp_dir.name)) ] runner = 
self.build_runner(self.epoch_based_cfg) runner.train() @@ -143,9 +143,6 @@ def test_after_train_iter(self): runner.iter = 9 hook = ProfilerHook(by_epoch=False, profile_times=10, schedule=None) - hook.before_run(runner) - hook.profiler.__exit__(None, None, None) - hook.profiler = MagicMock() hook.after_train_iter(runner, 1, 1, 1) hook.profiler.__exit__.assert_called_once() @@ -154,12 +151,9 @@ def test_after_train_iter(self): hook = ProfilerHook( by_epoch=False, schedule=dict(wait=1, warmup=1, active=3, repeat=1)) - hook.before_run(runner) - hook.profiler.__exit__(None, None, None) - hook.profiler = MagicMock() hook.after_train_iter(runner, 1, 1, 1) - hook.profiler.step.assert_not_called() + hook.profiler.step.assert_called_once() def test_with_runner(self): self.epoch_based_cfg['custom_hooks'] = [