[CodeCamp2023-472] Add parameter save_begin #1271

Merged · 16 commits · Jul 25, 2023
11 changes: 10 additions & 1 deletion docs/en/tutorials/hook.md
@@ -71,10 +71,11 @@ runner.train()
- Save the best checkpoints
- Specify the path to save the checkpoints
- Make checkpoints for publish
- Control the epoch number or iteration number at which checkpoint saving begins

For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).

The four features mentioned above are described below.
The six features mentioned above are described below.

- Save checkpoints by interval, and support saving them by epoch or iteration

@@ -129,6 +130,14 @@ The four features mentioned above are described below.
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
```

- Control the epoch number or iteration number at which checkpoint saving begins

To control when checkpoint saving starts, set the `save_begin` parameter. It defaults to 0, which means checkpoints are saved from the beginning of training. For example, if you train for a total of 10 epochs and set `save_begin` to 5, the checkpoints for epochs 5, 6, 7, 8, 9, and 10 will be saved. With `interval=2`, only the checkpoints for epochs 5, 7, and 9 are saved.

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))
```
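
To make the schedule concrete, here is a small self-contained sketch (a hypothetical helper for illustration, not part of the MMEngine API) that reproduces the rule described above:

```python
# Hypothetical helper that mirrors the `save_begin`/`interval` rule:
# an epoch is saved once `save_begin` has been reached and its offset
# from `save_begin` is a multiple of `interval` (1-based epoch numbers,
# `save_last` ignored for simplicity).
def saved_epochs(max_epochs: int, interval: int, save_begin: int = 0):
    return [
        epoch for epoch in range(1, max_epochs + 1)
        if epoch - save_begin >= 0 and (epoch - save_begin) % interval == 0
    ]

print(saved_epochs(10, 2, 5))  # [5, 7, 9]
print(saved_epochs(10, 2))     # [2, 4, 6, 8, 10]
```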

[LoggerHook](mmengine.hooks.LoggerHook) collects logs from different components of `Runner` and writes them to the terminal, JSON files, TensorBoard, wandb, etc.

If we want to output (or save) the logs every 20 iterations, we can set the `interval` parameter and configure it as follows.
11 changes: 10 additions & 1 deletion docs/zh_cn/tutorials/hook.md
@@ -72,10 +72,11 @@ runner.train()
- Save the best checkpoints
- Specify the path to save the checkpoints
- Make checkpoints for publish
- Control the epoch number or iteration number at which checkpoint saving begins

For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).

The four features mentioned above are described below.
The six features mentioned above are described below.

- Save checkpoints by interval, and support saving them by epoch or iteration

@@ -130,6 +131,14 @@ runner.train()
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
```

- Control the epoch number or iteration number at which checkpoint saving begins

To control when checkpoint saving starts, set the `save_begin` parameter. It defaults to 0, which means checkpoints are saved from the beginning of training. For example, if you train for a total of 10 epochs and set `save_begin` to 5, the checkpoints for epochs 5, 6, 7, 8, 9, and 10 will be saved. With `interval=2`, only the checkpoints for epochs 5, 7, and 9 are saved.

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))
```

### LoggerHook

[LoggerHook](mmengine.hooks.LoggerHook) collects logs and writes them to the terminal, files, TensorBoard, and other backends.
18 changes: 15 additions & 3 deletions mmengine/hooks/checkpoint_hook.py
@@ -92,6 +92,11 @@ class CheckpointHook(Hook):
publish model with keys in the list after training.
Defaults to None.
`New in version 0.7.1.`
save_begin (int): Control the epoch number or iteration number
at which checkpoint saving begins. Defaults to 0, which means
saving from the beginning of training.
`New in version 0.8.3.`

Examples:
>>> # Save best based on single metric
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
@@ -139,6 +144,7 @@ def __init__(self,
filename_tmpl: Optional[str] = None,
backend_args: Optional[dict] = None,
published_keys: Union[str, List[str], None] = None,
save_begin: int = 0,
**kwargs) -> None:
self.interval = interval
self.by_epoch = by_epoch
@@ -244,6 +250,10 @@ def __init__(self,
self.published_keys = published_keys

self.last_ckpt = None
if save_begin < 0:
raise ValueError(
f'save_begin should not be less than 0, but got {save_begin}')
self.save_begin = save_begin

def before_train(self, runner) -> None:
"""Finish all operations, related to checkpoint.
@@ -326,9 +336,9 @@ def after_train_epoch(self, runner) -> None:
return

# save checkpoint for following cases:
# 1. every ``self.interval`` epochs
# 1. every ``self.interval`` epochs, starting from ``self.save_begin``
# 2. reach the last epoch of training
if self.every_n_epochs(runner, self.interval) or (
if self.every_n_epochs(runner, self.interval, self.save_begin) or (
self.save_last and self.is_last_train_epoch(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.epoch + 1} epochs')
@@ -644,8 +654,10 @@ def after_train_iter(self,

# save checkpoint for following cases:
# 1. every ``self.interval`` iterations,
#    starting from ``self.save_begin``
# 2. reach the last iteration of training
if self.every_n_train_iters(runner, self.interval) or \
if self.every_n_train_iters(runner, self.interval,
self.save_begin) or \
(self.save_last and
self.is_last_train_iter(runner)):
runner.logger.info(
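
As a usage note for the validation added in `__init__` above, a negative `save_begin` now fails fast at construction time. A minimal sketch (assuming the f-string error message above):

```python
from mmengine.hooks import CheckpointHook

# The new check rejects negative values before training ever starts.
try:
    CheckpointHook(interval=2, save_begin=-1)
except ValueError as err:
    print(err)  # save_begin should not be less than 0, but got -1
```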
14 changes: 10 additions & 4 deletions mmengine/hooks/hook.py
@@ -335,18 +335,21 @@ def _after_iter(self,
mode (str): Current mode of runner. Defaults to 'train'.
"""

def every_n_epochs(self, runner, n: int) -> bool:
def every_n_epochs(self, runner, n: int, start: int = 0) -> bool:
"""Test whether current epoch can be evenly divided by n.

Args:
runner (Runner): The runner of the training, validation or testing
process.
n (int): The interval to check against.
start (int): The epoch number at which the every-n check starts.
Defaults to 0.

Returns:
bool: Whether current epoch can be evenly divided by n.
"""
return (runner.epoch + 1) % n == 0 if n > 0 else False
dividend = runner.epoch + 1 - start
return dividend % n == 0 if dividend >= 0 and n > 0 else False

def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
"""Test whether current inner iteration can be evenly divided by n.
@@ -363,19 +366,22 @@ def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
"""
return (batch_idx + 1) % n == 0 if n > 0 else False

def every_n_train_iters(self, runner, n: int) -> bool:
def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool:
"""Test whether current training iteration can be evenly divided by n.

Args:
runner (Runner): The runner of the training, validation or testing
process.
n (int): The interval to check against.
start (int): The iteration number at which the every-n check
starts. Defaults to 0.

Returns:
bool: Return True if the current iteration can be evenly divided
by n, otherwise False.
"""
return (runner.iter + 1) % n == 0 if n > 0 else False
dividend = runner.iter + 1 - start
return dividend % n == 0 if dividend >= 0 and n > 0 else False

def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
"""Check whether the current iteration reaches the last iteration of
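
A quick sanity check of the extended helpers (a sketch; the fake runner objects only carry the attributes the helpers actually read):

```python
from unittest.mock import MagicMock

from mmengine.hooks import Hook

hook = Hook()

# `runner.epoch` is 0-based, so epoch=4 is the 5th epoch:
# dividend = 4 + 1 - 5 = 0, which is divisible by n=2.
assert hook.every_n_epochs(MagicMock(epoch=4), n=2, start=5)

# Before `start` is reached the dividend is negative and the check fails.
assert not hook.every_n_epochs(MagicMock(epoch=2), n=2, start=5)

# Same rule for iterations: iter=6 is the 7th iteration, (7 - 5) % 2 == 0.
assert hook.every_n_train_iters(MagicMock(iter=6), n=2, start=5)
```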
58 changes: 57 additions & 1 deletion tests/test_hooks/test_checkpoint_hook.py
@@ -622,7 +622,7 @@ def test_with_runner(self, training_type):
self.assertEqual(best_ckpt['meta']['epoch'], 0)
self.assertEqual(best_ckpt['meta']['iter'], 5)

# test save published keys
# Test save published keys
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']
runner = self.build_runner(cfg)
@@ -632,3 +632,59 @@ def test_with_runner(self, training_type):
any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files))

self.clear_work_dir()

# Test save_begin with interval=2, save_begin=5
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 2
cfg.default_hooks.checkpoint.save_begin = 5
runner = self.build_runner(cfg)
runner.train()

for i in range(5):
self.assertFalse(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
for i in range(5, 11):
if (i - 5) % 2 == 1:
self.assertFalse(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
else:
self.assertTrue(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
self.clear_work_dir()

# Test save_begin with interval=2, save_begin=0
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 2
runner = self.build_runner(cfg)
runner.train()

for i in range(1, 11):
if i % 2 == 1:
self.assertFalse(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
else:
self.assertTrue(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
self.clear_work_dir()

# Test save_begin with interval=2, save_begin=1
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 2
cfg.default_hooks.checkpoint.save_begin = 1
runner = self.build_runner(cfg)
runner.train()

for i in range(1, 11):
if i % 2 == 1:
self.assertTrue(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
else:
self.assertFalse(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
self.clear_work_dir()
4 changes: 4 additions & 0 deletions tests/test_hooks/test_logger_hook.py
@@ -63,6 +63,9 @@ def test_init(self):

def test_after_train_iter(self):
# Test LoggerHook by iter.
# Avoid comparing `Runner.iter` (a MagicMock) with other integers.
ori_every_n_train_iters = LoggerHook.every_n_train_iters
LoggerHook.every_n_train_iters = MagicMock(return_value=True)
runner = MagicMock()
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
@@ -112,6 +115,7 @@ def test_after_train_iter(self):
logger_hook = LoggerHook()
logger_hook.after_train_iter(runner, batch_idx=8)
runner.log_processor.get_log_after_iter.assert_called()
LoggerHook.every_n_train_iters = ori_every_n_train_iters

def test_after_val_epoch(self):
logger_hook = LoggerHook()
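
An alternative to the manual save/restore pattern in the test above is `unittest.mock.patch.object`, which restores the original attribute even if the test body raises. A sketch of the equivalent stubbing:

```python
from unittest.mock import MagicMock, patch

from mmengine.hooks import LoggerHook

# Inside the block, `every_n_train_iters` always returns True; on exit
# the original method is restored automatically.
with patch.object(LoggerHook, 'every_n_train_iters', return_value=True):
    hook = LoggerHook()
    assert hook.every_n_train_iters(MagicMock(), 10)
```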