[Enhance] Ensure metrics is not empty when saving best ckpts
zhouzaida committed Dec 27, 2022
1 parent eb803f8 commit 508e561
Showing 2 changed files with 74 additions and 80 deletions.
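In short: the empty-`metrics` guard moves out of `_get_metric_score`, which warned via `warnings.warn` and returned `None`, and into `after_val_epoch`, which now warns through `runner.logger.warn` and returns before `_save_best_checkpoint` is called. The tests replace the `pytest.warns` check with a mocked `runner.logger.warn` and rename the local `eval_hook` variable to `checkpoint_hook` throughout.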
12 changes: 6 additions & 6 deletions mmengine/hooks/checkpoint_hook.py
@@ -294,19 +294,19 @@ def after_val_epoch(self, runner, metrics):
             runner (Runner): The runner of the training process.
             metrics (dict): Evaluation results of all metrics
         """
+        if len(metrics) == 0:
+            runner.logger.warn(
+                'Since `metrics` is an empty dict, the behavior to save '
+                'the best checkpoint will be skipped in this evaluation.')
+            return
+
         self._save_best_checkpoint(runner, metrics)
 
     def _get_metric_score(self, metrics, key_indicator):
         eval_res = OrderedDict()
         if metrics is not None:
             eval_res.update(metrics)
 
-        if len(eval_res) == 0:
-            warnings.warn(
-                'Since `eval_res` is an empty dict, the behavior to save '
-                'the best checkpoint will be skipped in this evaluation.')
-            return None
-
         return eval_res[key_indicator]
 
     def _save_checkpoint(self, runner) -> None:
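For quick reference, here is a minimal sketch of the new guard in action, mirroring the updated test below; the bare `Mock` runner and the hook arguments are illustrative, not part of the commit:

from unittest.mock import Mock

from mmengine.hooks import CheckpointHook

# A bare Mock stands in for a real Runner; attributes such as
# runner.logger.warn are auto-created by Mock.
runner = Mock()

checkpoint_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')

# An empty metrics dict now triggers runner.logger.warn and an early
# return, so no best checkpoint is saved for this evaluation.
checkpoint_hook.after_val_epoch(runner, {})
runner.logger.warn.assert_called_once()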
142 changes: 68 additions & 74 deletions tests/test_hooks/test_checkpoint_hook.py
@@ -148,6 +148,7 @@ def test_after_val_epoch(self, tmp_path):
         runner.work_dir = tmp_path
         runner.epoch = 9
         runner.model = Mock()
+        runner.logger.warn = Mock()
         runner.message_hub = MessageHub.get_instance('test_after_val_epoch')
 
         with pytest.raises(ValueError):
@@ -159,22 +160,11 @@ def test_after_val_epoch(self, tmp_path):
             CheckpointHook(
                 interval=2, by_epoch=True, save_best='auto', rule='unsupport')
 
-        # if eval_res is an empty dict, print a warning information
-        with pytest.warns(UserWarning) as record_warnings:
-            eval_hook = CheckpointHook(
-                interval=2, by_epoch=True, save_best='auto')
-            eval_hook._get_metric_score(None, None)
-        # Since there will be many warnings thrown, we just need to check
-        # if the expected exceptions are thrown
-        expected_message = (
-            'Since `eval_res` is an empty dict, the behavior to '
-            'save the best checkpoint will be skipped in this '
-            'evaluation.')
-        for warning in record_warnings:
-            if str(warning.message) == expected_message:
-                break
-        else:
-            assert False
+        # if metrics is an empty dict, print a warning information
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='auto')
+        checkpoint_hook.after_val_epoch(runner, {})
+        runner.logger.warn.assert_called_once()
 
         # test error when number of rules and metrics are not same
         with pytest.raises(AssertionError) as assert_error:
@@ -187,93 +177,97 @@ def test_after_val_epoch(self, tmp_path):
             '"save_best", but got 3.')
         assert error_message in str(assert_error.value)
 
-        # if save_best is None,no best_ckpt meta should be stored
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, None)
+        # if save_best is None, no best_ckpt meta should be stored
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best=None)
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, {})
         assert 'best_score' not in runner.message_hub.runtime_info
         assert 'best_ckpt' not in runner.message_hub.runtime_info
 
         # when `save_best` is set to `auto`, first metric will be used.
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='auto')
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_epoch_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
 
         # # when `save_best` is set to `acc`, it should update greater value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='acc')
+        checkpoint_hook.before_train(runner)
         metrics['acc'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.8
 
         # # when `save_best` is set to `loss`, it should update less value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='loss')
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         metrics['loss'] = 0.5
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
 
         # when `rule` is set to `less`,then it should update less value
         # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='acc', rule='less')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
        metrics['acc'] = 0.3
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.3
 
         # # when `rule` is set to `greater`,then it should update greater value
         # # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='loss', rule='greater')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 1.0
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 1.0
 
         # test multi `save_best` with one rule
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['acc', 'mIoU'], rule='greater')
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
 
         # test multi `save_best` with multi rules
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
-        assert eval_hook.key_indicators == ['FID', 'IS']
-        assert eval_hook.rules == ['less', 'greater']
+        assert checkpoint_hook.key_indicators == ['FID', 'IS']
+        assert checkpoint_hook.rules == ['less', 'greater']
 
         # test multi `save_best` with default rule
-        eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_epoch_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_epoch_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
@@ -293,26 +287,26 @@ def test_after_val_epoch(self, tmp_path):
 
         # check best ckpt name and best score
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best='acc', rule='greater')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
 
         # check best score updating
         metrics['acc'] = 0.666
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
@@ -326,21 +320,21 @@ def test_after_val_epoch(self, tmp_path):
                 interval=2, save_best='acc', rule=['greater', 'less'])
 
         # check best checkpoint name with `by_epoch` is False
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best_by_epoch_is_false')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_iter_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_iter_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
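To verify the change locally, the updated test can be run on its own with a standard pytest invocation (the `-k` filter is just a convenience):

pytest tests/test_hooks/test_checkpoint_hook.py -k test_after_val_epoch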
