Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add test of get_hooks_info() #672

Merged
merged 14 commits into from
Nov 22, 2022
34 changes: 34 additions & 0 deletions mmengine/hooks/hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Union

from mmengine import is_method_overridden

DATA_BATCH = Optional[Union[dict, tuple, list]]


Expand All @@ -11,6 +13,13 @@ class Hook:
"""

priority = 'NORMAL'
stages = ('before_run', 'after_load_checkpoint', 'before_train',
'before_train_epoch', 'before_train_iter', 'after_train_iter',
'after_train_epoch', 'before_val', 'before_val_epoch',
'before_val_iter', 'after_val_iter', 'after_val_epoch',
'after_val', 'before_save_checkpoint', 'after_train',
'before_test', 'before_test_epoch', 'before_test_iter',
'after_test_iter', 'after_test_epoch', 'after_test', 'after_run')

def before_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand Down Expand Up @@ -416,3 +425,28 @@ def is_last_train_iter(self, runner) -> bool:
bool: Whether current iteration is the last train iteration.
"""
return runner.iter + 1 == runner.max_iters

def get_triggered_stages(self) -> list:
trigger_stages = set()
for stage in Hook.stages:
if is_method_overridden(stage, Hook, self):
trigger_stages.add(stage)

# some methods will be triggered in multi stages
# use this dict to map method to stages.
method_stages_map = {
'_before_epoch':
['before_train_epoch', 'before_val_epoch', 'before_test_epoch'],
'_after_epoch':
['after_train_epoch', 'after_val_epoch', 'after_test_epoch'],
'_before_iter':
['before_train_iter', 'before_val_iter', 'before_test_iter'],
'_after_iter':
['after_train_iter', 'after_val_iter', 'after_test_iter'],
}

for method, map_stages in method_stages_map.items():
if is_method_overridden(method, Hook, self):
trigger_stages.update(map_stages)

return list(trigger_stages)
26 changes: 26 additions & 0 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,9 @@ def __init__(
self._hooks: List[Hook] = []
# register hooks to `self._hooks`
self.register_hooks(default_hooks, custom_hooks)
# log hooks information
self.logger.info(f'Hooks will be executed in the following '
f'order:\n{self.get_hooks_info()}')

# dump `cfg` to `work_dir`
self.dump_config()
Expand Down Expand Up @@ -1577,6 +1580,29 @@ def build_log_processor(

return log_processor # type: ignore

def get_hooks_info(self) -> str:
# Get hooks info in each stage
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
for hook in self.hooks:
try:
priority = Priority(hook.priority).name # type: ignore
except ValueError:
priority = hook.priority # type: ignore
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
for trigger_stage in hook.get_triggered_stages():
stage_hook_map[trigger_stage].append(hook_info)

stage_hook_infos = []
for stage in Hook.stages:
hook_infos = stage_hook_map[stage]
if len(hook_infos) > 0:
info = f'{stage}:\n'
info += '\n'.join(hook_infos)
info += '\n -------------------- '
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)

def load_or_resume(self) -> None:
"""load or resume checkpoint."""
if self._has_loaded:
Expand Down
16 changes: 16 additions & 0 deletions tests/test_runner/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2306,3 +2306,19 @@ def test_build_runner(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_runner2'
assert isinstance(RUNNERS.build(cfg), Runner)

def test_get_hooks_info(self):
# test get_hooks_info() function
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_get_hooks_info_from_test_runner_py'
cfg.runner_type = 'Runner'
runner = RUNNERS.build(cfg)
self.assertIsInstance(runner, Runner)
target_str = ('after_train_iter:\n'
'(VERY_HIGH ) RuntimeInfoHook \n'
'(NORMAL ) IterTimerHook \n'
'(BELOW_NORMAL) LoggerHook \n'
'(LOW ) ParamSchedulerHook \n'
'(VERY_LOW ) CheckpointHook \n')
self.assertIn(target_str, runner.get_hooks_info(),
'target string is not in logged hooks information.')