Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Ensure metrics is not empty when saving best ckpts #849

Merged
merged 3 commits into from
Dec 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions mmengine/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from collections import OrderedDict
from math import inf
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union
Expand Down Expand Up @@ -294,20 +293,13 @@ def after_val_epoch(self, runner, metrics):
runner (Runner): The runner of the training process.
metrics (dict): Evaluation results of all metrics
"""
self._save_best_checkpoint(runner, metrics)

def _get_metric_score(self, metrics, key_indicator):
eval_res = OrderedDict()
if metrics is not None:
eval_res.update(metrics)

if len(eval_res) == 0:
warnings.warn(
'Since `eval_res` is an empty dict, the behavior to save '
if len(metrics) == 0:
runner.logger.warning(
'Since `metrics` is an empty dict, the behavior to save '
'the best checkpoint will be skipped in this evaluation.')
return None
return

return eval_res[key_indicator]
self._save_best_checkpoint(runner, metrics)

def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.
Expand Down Expand Up @@ -385,7 +377,7 @@ def _save_best_checkpoint(self, runner, metrics) -> None:
# save best logic
# get score from messagehub
for key_indicator, rule in zip(self.key_indicators, self.rules):
key_score = self._get_metric_score(metrics, key_indicator)
key_score = metrics[key_indicator]

if len(self.key_indicators) == 1:
best_score_key = 'best_score'
Expand Down
142 changes: 68 additions & 74 deletions tests/test_hooks/test_checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_after_val_epoch(self, tmp_path):
runner.work_dir = tmp_path
runner.epoch = 9
runner.model = Mock()
runner.logger.warning = Mock()
runner.message_hub = MessageHub.get_instance('test_after_val_epoch')

with pytest.raises(ValueError):
Expand All @@ -159,22 +160,11 @@ def test_after_val_epoch(self, tmp_path):
CheckpointHook(
interval=2, by_epoch=True, save_best='auto', rule='unsupport')

# if eval_res is an empty dict, print a warning information
with pytest.warns(UserWarning) as record_warnings:
eval_hook = CheckpointHook(
interval=2, by_epoch=True, save_best='auto')
eval_hook._get_metric_score(None, None)
# Since there will be many warnings thrown, we just need to check
# if the expected exceptions are thrown
expected_message = (
'Since `eval_res` is an empty dict, the behavior to '
'save the best checkpoint will be skipped in this '
'evaluation.')
for warning in record_warnings:
if str(warning.message) == expected_message:
break
else:
assert False
# if metrics is an empty dict, print a warning information
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, save_best='auto')
checkpoint_hook.after_val_epoch(runner, {})
runner.logger.warning.assert_called_once()

# test error when number of rules and metrics are not same
with pytest.raises(AssertionError) as assert_error:
Expand All @@ -187,93 +177,97 @@ def test_after_val_epoch(self, tmp_path):
'"save_best", but got 3.')
assert error_message in str(assert_error.value)

# if save_best is None,no best_ckpt meta should be stored
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
eval_hook.before_train(runner)
eval_hook.after_val_epoch(runner, None)
# if save_best is None, no best_ckpt meta should be stored
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, save_best=None)
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, {})
assert 'best_score' not in runner.message_hub.runtime_info
assert 'best_ckpt' not in runner.message_hub.runtime_info

# when `save_best` is set to `auto`, first metric will be used.
metrics = {'acc': 0.5, 'map': 0.3}
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
eval_hook.before_train(runner)
eval_hook.after_val_epoch(runner, metrics)
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, save_best='auto')
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
best_ckpt_name = 'best_acc_epoch_9.pth'
best_ckpt_path = eval_hook.file_client.join_path(
eval_hook.out_dir, best_ckpt_name)
assert eval_hook.key_indicators == ['acc']
assert eval_hook.rules == ['greater']
best_ckpt_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_ckpt_name)
assert checkpoint_hook.key_indicators == ['acc']
assert checkpoint_hook.rules == ['greater']
assert 'best_score' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score') == 0.5
assert 'best_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_ckpt') == best_ckpt_path

# # when `save_best` is set to `acc`, it should update greater value
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
eval_hook.before_train(runner)
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, save_best='acc')
checkpoint_hook.before_train(runner)
metrics['acc'] = 0.8
eval_hook.after_val_epoch(runner, metrics)
checkpoint_hook.after_val_epoch(runner, metrics)
assert 'best_score' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score') == 0.8

# # when `save_best` is set to `loss`, it should update less value
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
eval_hook.before_train(runner)
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, save_best='loss')
checkpoint_hook.before_train(runner)
metrics['loss'] = 0.8
eval_hook.after_val_epoch(runner, metrics)
checkpoint_hook.after_val_epoch(runner, metrics)
metrics['loss'] = 0.5
eval_hook.after_val_epoch(runner, metrics)
checkpoint_hook.after_val_epoch(runner, metrics)
assert 'best_score' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score') == 0.5

# when `rule` is set to `less`,then it should update less value
# no matter what `save_best` is
eval_hook = CheckpointHook(
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, save_best='acc', rule='less')
eval_hook.before_train(runner)
checkpoint_hook.before_train(runner)
metrics['acc'] = 0.3
eval_hook.after_val_epoch(runner, metrics)
checkpoint_hook.after_val_epoch(runner, metrics)
assert 'best_score' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score') == 0.3

# # when `rule` is set to `greater`,then it should update greater value
# # no matter what `save_best` is
eval_hook = CheckpointHook(
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, save_best='loss', rule='greater')
eval_hook.before_train(runner)
checkpoint_hook.before_train(runner)
metrics['loss'] = 1.0
eval_hook.after_val_epoch(runner, metrics)
checkpoint_hook.after_val_epoch(runner, metrics)
assert 'best_score' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score') == 1.0

# test multi `save_best` with one rule
eval_hook = CheckpointHook(
checkpoint_hook = CheckpointHook(
interval=2, save_best=['acc', 'mIoU'], rule='greater')
assert eval_hook.key_indicators == ['acc', 'mIoU']
assert eval_hook.rules == ['greater', 'greater']
assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
assert checkpoint_hook.rules == ['greater', 'greater']

# test multi `save_best` with multi rules
eval_hook = CheckpointHook(
checkpoint_hook = CheckpointHook(
interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
assert eval_hook.key_indicators == ['FID', 'IS']
assert eval_hook.rules == ['less', 'greater']
assert checkpoint_hook.key_indicators == ['FID', 'IS']
assert checkpoint_hook.rules == ['less', 'greater']

# test multi `save_best` with default rule
eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
assert eval_hook.key_indicators == ['acc', 'mIoU']
assert eval_hook.rules == ['greater', 'greater']
checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
assert checkpoint_hook.rules == ['greater', 'greater']
runner.message_hub = MessageHub.get_instance(
'test_after_val_epoch_save_multi_best')
eval_hook.before_train(runner)
checkpoint_hook.before_train(runner)
metrics = dict(acc=0.5, mIoU=0.6)
eval_hook.after_val_epoch(runner, metrics)
checkpoint_hook.after_val_epoch(runner, metrics)
best_acc_name = 'best_acc_epoch_9.pth'
best_acc_path = eval_hook.file_client.join_path(
eval_hook.out_dir, best_acc_name)
best_acc_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_acc_name)
best_mIoU_name = 'best_mIoU_epoch_9.pth'
best_mIoU_path = eval_hook.file_client.join_path(
eval_hook.out_dir, best_mIoU_name)
best_mIoU_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_mIoU_name)
assert 'best_score_acc' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score_acc') == 0.5
assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
Expand All @@ -293,26 +287,26 @@ def test_after_val_epoch(self, tmp_path):

# check best ckpt name and best score
metrics = {'acc': 0.5, 'map': 0.3}
eval_hook = CheckpointHook(
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=False, save_best='acc', rule='greater')
eval_hook.before_train(runner)
eval_hook.after_val_epoch(runner, metrics)
assert eval_hook.key_indicators == ['acc']
assert eval_hook.rules == ['greater']
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
assert checkpoint_hook.key_indicators == ['acc']
assert checkpoint_hook.rules == ['greater']
best_ckpt_name = 'best_acc_iter_9.pth'
best_ckpt_path = eval_hook.file_client.join_path(
eval_hook.out_dir, best_ckpt_name)
best_ckpt_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_ckpt_name)
assert 'best_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
assert 'best_score' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score') == 0.5

# check best score updating
metrics['acc'] = 0.666
eval_hook.after_val_epoch(runner, metrics)
checkpoint_hook.after_val_epoch(runner, metrics)
best_ckpt_name = 'best_acc_iter_9.pth'
best_ckpt_path = eval_hook.file_client.join_path(
eval_hook.out_dir, best_ckpt_name)
best_ckpt_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_ckpt_name)
assert 'best_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
assert 'best_score' in runner.message_hub.runtime_info and \
Expand All @@ -326,21 +320,21 @@ def test_after_val_epoch(self, tmp_path):
interval=2, save_best='acc', rule=['greater', 'less'])

# check best checkpoint name with `by_epoch` is False
eval_hook = CheckpointHook(
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
assert eval_hook.key_indicators == ['acc', 'mIoU']
assert eval_hook.rules == ['greater', 'greater']
assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
assert checkpoint_hook.rules == ['greater', 'greater']
runner.message_hub = MessageHub.get_instance(
'test_after_val_epoch_save_multi_best_by_epoch_is_false')
eval_hook.before_train(runner)
checkpoint_hook.before_train(runner)
metrics = dict(acc=0.5, mIoU=0.6)
eval_hook.after_val_epoch(runner, metrics)
checkpoint_hook.after_val_epoch(runner, metrics)
best_acc_name = 'best_acc_iter_9.pth'
best_acc_path = eval_hook.file_client.join_path(
eval_hook.out_dir, best_acc_name)
best_acc_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_acc_name)
best_mIoU_name = 'best_mIoU_iter_9.pth'
best_mIoU_path = eval_hook.file_client.join_path(
eval_hook.out_dir, best_mIoU_name)
best_mIoU_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_mIoU_name)
assert 'best_score_acc' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score_acc') == 0.5
assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
Expand Down