From 5289e49284f74859aea48b7c9b96eef07bcb753e Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 09:48:25 +0800 Subject: [PATCH 01/14] [CodeCamp2023-472] Add parameter save_begin --- docs/en/tutorials/hook.md | 11 ++++++++++- docs/zh_cn/tutorials/hook.md | 11 ++++++++++- mmengine/hooks/checkpoint_hook.py | 12 ++++++++++++ tests/test_hooks/test_checkpoint_hook.py | 22 +++++++++++++++++++++- 4 files changed, 53 insertions(+), 3 deletions(-) diff --git a/docs/en/tutorials/hook.md b/docs/en/tutorials/hook.md index 912f7502b8..c450b93133 100644 --- a/docs/en/tutorials/hook.md +++ b/docs/en/tutorials/hook.md @@ -71,10 +71,11 @@ runner.train() - Save the best checkpoints - Specify the path to save the checkpoints - Make checkpoints for publish +- Control the epoch number or iteration number at which checkpoint saving begins For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook). -The four features mentioned above are described below. +The six features mentioned above are described below. - Save checkpoints by interval, and support saving them by epoch or iteration @@ -129,6 +130,14 @@ The four features mentioned above are described below. default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict'])) ``` +- Control the epoch number or iteration number at which checkpoint saving begins + + If you want to set the number of epochs or iterations to control the start of saving weights, you can set the `save_begin` parameter, defaults to 1, which means saving checkpoints from the beginning of training. For example, if you train for a total of 10 epochs, and `save_begin` is set to 6, then the checkpoints for epochs 6, 7, 8, 9, and 10 will be saved. If `interval=2`, only save checkpoints for epochs 6, 8 and 10. + + ```python + default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=6)) + ``` + [LoggerHook](mmengine.hooks.LoggerHook) collects logs from different components of `Runner` and write them to terminal, JSON file, tensorboard and wandb .etc. If we want to output (or save) the logs every 20 iterations, we can set the `interval` parameter and configure it as follows. diff --git a/docs/zh_cn/tutorials/hook.md b/docs/zh_cn/tutorials/hook.md index 63e6380477..a8e9ee636c 100644 --- a/docs/zh_cn/tutorials/hook.md +++ b/docs/zh_cn/tutorials/hook.md @@ -72,10 +72,11 @@ runner.train() - 保存最优权重 - 指定保存权重的路径 - 制作发布用的权重 +- 设置开始保存权重的 epoch 数或者 iteration 数 如需了解其他功能,请阅读 [CheckpointHook API 文档](mmengine.hooks.CheckpointHook)。 -下面介绍上面提到的 4 个功能。 +下面介绍上面提到的 6 个功能。 - 按照间隔保存权重,支持按 epoch 数或者 iteration 数保存权重 @@ -130,6 +131,14 @@ runner.train() default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict'])) ``` +- 设置开始保存权重的 epoch 数或者 iteration 数 + + 如果想要设置控制开始保存权重的 epoch 数或者 iteration 数,可以设置 `save_begin` 参数,默认为1,表示从训练开始就保存权重。例如,如果总共训练 10 个 epoch,并且 `save_begin` 设置为 6,则将保存第 6、7、8、9 和 10 个 epoch 的权重。如果`interval=2`,则仅保存epoch 6、8和10的权重。 + + ```python + default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=6)) + ``` + ### LoggerHook [LoggerHook](mmengine.hooks.LoggerHook) 负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端。 diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 0b0f8d9a6e..1c4c0299a1 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -92,6 +92,10 @@ class CheckpointHook(Hook): publish model with keys in the list after training. Defaults to None. `New in version 0.7.1.` + save_begin (int) : Control the epoch number or iteration number + at which checkpoint saving begins. Defaults to 1, which means + saving at the beginning. + Examples: >>> # Save best based on single metric >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc', @@ -139,6 +143,7 @@ def __init__(self, filename_tmpl: Optional[str] = None, backend_args: Optional[dict] = None, published_keys: Union[str, List[str], None] = None, + save_begin: int = 1, **kwargs) -> None: self.interval = interval self.by_epoch = by_epoch @@ -244,6 +249,7 @@ def __init__(self, self.published_keys = published_keys self.last_ckpt = None + self.save_begin = save_begin def before_train(self, runner) -> None: """Finish all operations, related to checkpoint. @@ -325,6 +331,9 @@ def after_train_epoch(self, runner) -> None: if not self.by_epoch: return + if runner.epoch + 1 < self.save_begin: + return + # save checkpoint for following cases: # 1. every ``self.interval`` epochs # 2. reach the last epoch of training @@ -642,6 +651,9 @@ def after_train_iter(self, if self.by_epoch: return + if runner.iter + 1 < self.save_begin: + return + # save checkpoint for following cases: # 1. every ``self.interval`` iterations # 2. reach the last iteration of training diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index e4e982721d..7d1813615b 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -622,7 +622,7 @@ def test_with_runner(self, training_type): self.assertEqual(best_ckpt['meta']['epoch'], 0) self.assertEqual(best_ckpt['meta']['iter'], 5) - # test save published keys + # Test save published keys cfg = copy.deepcopy(common_cfg) cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict'] runner = self.build_runner(cfg) @@ -632,3 +632,23 @@ def test_with_runner(self, training_type): any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files)) self.clear_work_dir() + + # Test save_begin with interval=2 + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.checkpoint.interval = 2 + cfg.default_hooks.checkpoint.save_begin = 6 + runner = self.build_runner(cfg) + runner.train() + for i in range(6): + self.assertFalse( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + for i in range(6, 11): + if i % 2: + self.assertFalse( + osp.isfile( + osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + else: + self.assertTrue( + osp.isfile( + osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.clear_work_dir() From 8030023cfc16d7cc360bfb039fc03c740e285dd0 Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 11:11:42 +0800 Subject: [PATCH 02/14] Update docs/zh_cn/tutorials/hook.md Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> --- docs/zh_cn/tutorials/hook.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/zh_cn/tutorials/hook.md b/docs/zh_cn/tutorials/hook.md index a8e9ee636c..14bf922a95 100644 --- a/docs/zh_cn/tutorials/hook.md +++ b/docs/zh_cn/tutorials/hook.md @@ -133,7 +133,7 @@ runner.train() - 设置开始保存权重的 epoch 数或者 iteration 数 - 如果想要设置控制开始保存权重的 epoch 数或者 iteration 数,可以设置 `save_begin` 参数,默认为1,表示从训练开始就保存权重。例如,如果总共训练 10 个 epoch,并且 `save_begin` 设置为 6,则将保存第 6、7、8、9 和 10 个 epoch 的权重。如果`interval=2`,则仅保存epoch 6、8和10的权重。 + 如果想要设置控制开始保存权重的 epoch 数或者 iteration 数,可以设置 `save_begin` 参数,默认为 1,表示从训练开始就保存权重。例如,如果总共训练 10 个 epoch,并且 `save_begin` 设置为 6,则将保存第 6、7、8、9 和 10 个 epoch 的权重。如果 `interval=2`,则仅保存第 6、8 和 10 个 epoch 的权重。 ```python default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=6)) From c54944b76ee9393db9dc96aff1aff8c84d39f109 Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 11:11:49 +0800 Subject: [PATCH 03/14] Update tests/test_hooks/test_checkpoint_hook.py Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> --- tests/test_hooks/test_checkpoint_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index 7d1813615b..b702bae875 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -643,7 +643,7 @@ def test_with_runner(self, training_type): self.assertFalse( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) for i in range(6, 11): - if i % 2: + if i % 2 == 1: self.assertFalse( osp.isfile( osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) From eeaa117cbd4f3fdb8f89c631d6778ee0d29418ed Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 15:30:11 +0800 Subject: [PATCH 04/14] update control range of save_begin --- docs/en/tutorials/hook.md | 4 +-- docs/zh_cn/tutorials/hook.md | 4 +-- mmengine/hooks/checkpoint_hook.py | 12 ++++--- mmengine/hooks/hook.py | 12 ++++--- tests/test_hooks/test_checkpoint_hook.py | 44 +++++++++++++++++++++--- 5 files changed, 59 insertions(+), 17 deletions(-) diff --git a/docs/en/tutorials/hook.md b/docs/en/tutorials/hook.md index c450b93133..ac969bf984 100644 --- a/docs/en/tutorials/hook.md +++ b/docs/en/tutorials/hook.md @@ -132,10 +132,10 @@ The six features mentioned above are described below. - Control the epoch number or iteration number at which checkpoint saving begins - If you want to set the number of epochs or iterations to control the start of saving weights, you can set the `save_begin` parameter, defaults to 1, which means saving checkpoints from the beginning of training. For example, if you train for a total of 10 epochs, and `save_begin` is set to 6, then the checkpoints for epochs 6, 7, 8, 9, and 10 will be saved. If `interval=2`, only save checkpoints for epochs 6, 8 and 10. + If you want to set the number of epochs or iterations to control the start of saving weights, you can set the `save_begin` parameter, defaults to 0, which means saving checkpoints from the beginning of training. For example, if you train for a total of 10 epochs, and `save_begin` is set to 5, then the checkpoints for epochs 5, 6, 7, 8, 9, and 10 will be saved. If `interval=2`, only save checkpoints for epochs 5, 7 and 9. ```python - default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=6)) + default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5)) ``` [LoggerHook](mmengine.hooks.LoggerHook) collects logs from different components of `Runner` and write them to terminal, JSON file, tensorboard and wandb .etc. diff --git a/docs/zh_cn/tutorials/hook.md b/docs/zh_cn/tutorials/hook.md index 14bf922a95..36f931446d 100644 --- a/docs/zh_cn/tutorials/hook.md +++ b/docs/zh_cn/tutorials/hook.md @@ -133,10 +133,10 @@ runner.train() - 设置开始保存权重的 epoch 数或者 iteration 数 - 如果想要设置控制开始保存权重的 epoch 数或者 iteration 数,可以设置 `save_begin` 参数,默认为 1,表示从训练开始就保存权重。例如,如果总共训练 10 个 epoch,并且 `save_begin` 设置为 6,则将保存第 6、7、8、9 和 10 个 epoch 的权重。如果 `interval=2`,则仅保存第 6、8 和 10 个 epoch 的权重。 + 如果想要设置控制开始保存权重的 epoch 数或者 iteration 数,可以设置 `save_begin` 参数,默认为 0,表示从训练开始就保存权重。例如,如果总共训练 10 个 epoch,并且 `save_begin` 设置为 5,则将保存第 5、6、7、8、9 和 10 个 epoch 的权重。如果 `interval=2`,则仅保存第 5、7 和 9 个 epoch 的权重。 ```python - default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=6)) + default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5)) ``` ### LoggerHook diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 1c4c0299a1..655bb73765 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -93,7 +93,7 @@ class CheckpointHook(Hook): Defaults to None. `New in version 0.7.1.` save_begin (int) : Control the epoch number or iteration number - at which checkpoint saving begins. Defaults to 1, which means + at which checkpoint saving begins. Defaults to 0, which means saving at the beginning. Examples: @@ -143,7 +143,7 @@ def __init__(self, filename_tmpl: Optional[str] = None, backend_args: Optional[dict] = None, published_keys: Union[str, List[str], None] = None, - save_begin: int = 1, + save_begin: int = 0, **kwargs) -> None: self.interval = interval self.by_epoch = by_epoch @@ -335,9 +335,9 @@ def after_train_epoch(self, runner) -> None: return # save checkpoint for following cases: - # 1. every ``self.interval`` epochs + # 1. every ``self.interval`` epochs which start at ``self.save_begin`` # 2. reach the last epoch of training - if self.every_n_epochs(runner, self.interval) or ( + if self.every_n_epochs(runner, self.interval, self.save_begin) or ( self.save_last and self.is_last_train_epoch(runner)): runner.logger.info( f'Saving checkpoint at {runner.epoch + 1} epochs') @@ -656,8 +656,10 @@ def after_train_iter(self, # save checkpoint for following cases: # 1. every ``self.interval`` iterations + # which start at ``self.save_begin`` # 2. reach the last iteration of training - if self.every_n_train_iters(runner, self.interval) or \ + if self.every_n_train_iters(runner, self.interval, + self.save_begin) or \ (self.save_last and self.is_last_train_iter(runner)): runner.logger.info( diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 941eb86537..5ab723d2fc 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -335,18 +335,20 @@ def _after_iter(self, mode (str): Current mode of runner. Defaults to 'train'. """ - def every_n_epochs(self, runner, n: int) -> bool: + def every_n_epochs(self, runner, n: int, st: int = 0) -> bool: """Test whether current epoch can be evenly divided by n. Args: runner (Runner): The runner of the training, validation or testing process. n (int): Whether current epoch can be evenly divided by n. + st(int): The parameter of save_begin, controlling the + epoch number at which checkpoint saving begins. Returns: bool: Whether current epoch can be evenly divided by n. """ - return (runner.epoch + 1) % n == 0 if n > 0 else False + return (runner.epoch + 1 - st) % n == 0 if n > 0 else False def every_n_inner_iters(self, batch_idx: int, n: int) -> bool: """Test whether current inner iteration can be evenly divided by n. @@ -363,19 +365,21 @@ def every_n_inner_iters(self, batch_idx: int, n: int) -> bool: """ return (batch_idx + 1) % n == 0 if n > 0 else False - def every_n_train_iters(self, runner, n: int) -> bool: + def every_n_train_iters(self, runner, n: int, st: int = 0) -> bool: """Test whether current training iteration can be evenly divided by n. Args: runner (Runner): The runner of the training, validation or testing process. n (int): Whether current iteration can be evenly divided by n. + st(int): The parameter of save_begin, controlling the + iteration number at which checkpoint saving begins. Returns: bool: Return True if the current iteration can be evenly divided by n, otherwise False. """ - return (runner.iter + 1) % n == 0 if n > 0 else False + return (runner.iter + 1 - st) % n == 0 if n > 0 else False def end_of_epoch(self, dataloader, batch_idx: int) -> bool: """Check whether the current iteration reaches the last iteration of diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index b702bae875..d731a42b76 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -633,16 +633,34 @@ def test_with_runner(self, training_type): self.clear_work_dir() - # Test save_begin with interval=2 + # Test save_begin with interval=2, save_begin=5 cfg = copy.deepcopy(common_cfg) cfg.default_hooks.checkpoint.interval = 2 - cfg.default_hooks.checkpoint.save_begin = 6 + cfg.default_hooks.checkpoint.save_begin = 5 runner = self.build_runner(cfg) runner.train() - for i in range(6): + + for i in range(5): self.assertFalse( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) - for i in range(6, 11): + for i in range(5, 11): + if (i - 5) % 2 == 1: + self.assertFalse( + osp.isfile( + osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + else: + self.assertTrue( + osp.isfile( + osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.clear_work_dir() + + # Test save_begin with interval=2, save_begin=0 + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.checkpoint.interval = 2 + runner = self.build_runner(cfg) + runner.train() + + for i in range(1, 11): if i % 2 == 1: self.assertFalse( osp.isfile( @@ -652,3 +670,21 @@ def test_with_runner(self, training_type): osp.isfile( osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) self.clear_work_dir() + + # Test save_begin with interval=2, save_begin=1 + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.checkpoint.interval = 2 + cfg.default_hooks.checkpoint.save_begin = 1 + runner = self.build_runner(cfg) + runner.train() + + for i in range(1, 11): + if i % 2 == 1: + self.assertTrue( + osp.isfile( + osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + else: + self.assertFalse( + osp.isfile( + osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.clear_work_dir() From 3e04557a4753eb4bb26990556aea5355bf328049 Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:16:45 +0800 Subject: [PATCH 05/14] Update mmengine/hooks/hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/hooks/hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 5ab723d2fc..38d4bf91cf 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -335,7 +335,7 @@ def _after_iter(self, mode (str): Current mode of runner. Defaults to 'train'. """ - def every_n_epochs(self, runner, n: int, st: int = 0) -> bool: + def every_n_epochs(self, runner, n: int, start: int = 0) -> bool: """Test whether current epoch can be evenly divided by n. Args: From dc0ba7ca741d6ee9f96a58877884a71ec0b0fd54 Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:16:52 +0800 Subject: [PATCH 06/14] Update mmengine/hooks/checkpoint_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/hooks/checkpoint_hook.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 655bb73765..5ccf472959 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -92,9 +92,10 @@ class CheckpointHook(Hook): publish model with keys in the list after training. Defaults to None. `New in version 0.7.1.` - save_begin (int) : Control the epoch number or iteration number + save_begin (int): Control the epoch number or iteration number at which checkpoint saving begins. Defaults to 0, which means saving at the beginning. + `New in version 0.8.3.` Examples: >>> # Save best based on single metric From 20cc16cdf611ac3b101f4b9d4c46797b15d1455c Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:17:07 +0800 Subject: [PATCH 07/14] Update mmengine/hooks/checkpoint_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/hooks/checkpoint_hook.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 5ccf472959..2c3f9387fe 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -250,6 +250,8 @@ def __init__(self, self.published_keys = published_keys self.last_ckpt = None + if save_begin < 0: + raise ValueError('save_begin should not less than 0, but got {save_begin}') self.save_begin = save_begin def before_train(self, runner) -> None: From 24e5a400cbaa357919759d5c229e68147ea4e126 Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:17:29 +0800 Subject: [PATCH 08/14] Update mmengine/hooks/hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/hooks/hook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 38d4bf91cf..97657a0544 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -342,8 +342,8 @@ def every_n_epochs(self, runner, n: int, start: int = 0) -> bool: runner (Runner): The runner of the training, validation or testing process. n (int): Whether current epoch can be evenly divided by n. - st(int): The parameter of save_begin, controlling the - epoch number at which checkpoint saving begins. + start (int): Starting from `start` to check the logic for every n epochs. + Defaults to 0. Returns: bool: Whether current epoch can be evenly divided by n. From 7132692c5587b1d8162cadbabbaa30583446e6df Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:17:50 +0800 Subject: [PATCH 09/14] Update mmengine/hooks/hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/hooks/hook.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 97657a0544..ec33403531 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -348,7 +348,8 @@ def every_n_epochs(self, runner, n: int, start: int = 0) -> bool: Returns: bool: Whether current epoch can be evenly divided by n. """ - return (runner.epoch + 1 - st) % n == 0 if n > 0 else False + dividend = runner.epoch + 1 - start + return dividend % n == 0 if dividend >=0 and n > 0 else False def every_n_inner_iters(self, batch_idx: int, n: int) -> bool: """Test whether current inner iteration can be evenly divided by n. From fa3aeac1fdb3133ced008fe92edc5763b9265b87 Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:25:40 +0800 Subject: [PATCH 10/14] fix code mention above and do pre-commit check again --- mmengine/hooks/checkpoint_hook.py | 6 ++---- mmengine/hooks/hook.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 2c3f9387fe..b7f6cc0c4a 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -251,7 +251,8 @@ def __init__(self, self.last_ckpt = None if save_begin < 0: - raise ValueError('save_begin should not less than 0, but got {save_begin}') + raise ValueError( + 'save_begin should not less than 0, but got {save_begin}') self.save_begin = save_begin def before_train(self, runner) -> None: @@ -654,9 +655,6 @@ def after_train_iter(self, if self.by_epoch: return - if runner.iter + 1 < self.save_begin: - return - # save checkpoint for following cases: # 1. every ``self.interval`` iterations # which start at ``self.save_begin`` diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index ec33403531..4e1c4ce8bd 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -342,14 +342,14 @@ def every_n_epochs(self, runner, n: int, start: int = 0) -> bool: runner (Runner): The runner of the training, validation or testing process. n (int): Whether current epoch can be evenly divided by n. - start (int): Starting from `start` to check the logic for every n epochs. - Defaults to 0. + start (int): Starting from `start` to check the logic for + every n epochs. Defaults to 0. Returns: bool: Whether current epoch can be evenly divided by n. """ dividend = runner.epoch + 1 - start - return dividend % n == 0 if dividend >=0 and n > 0 else False + return dividend % n == 0 if dividend >= 0 and n > 0 else False def every_n_inner_iters(self, batch_idx: int, n: int) -> bool: """Test whether current inner iteration can be evenly divided by n. @@ -366,21 +366,22 @@ def every_n_inner_iters(self, batch_idx: int, n: int) -> bool: """ return (batch_idx + 1) % n == 0 if n > 0 else False - def every_n_train_iters(self, runner, n: int, st: int = 0) -> bool: + def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool: """Test whether current training iteration can be evenly divided by n. Args: runner (Runner): The runner of the training, validation or testing process. n (int): Whether current iteration can be evenly divided by n. - st(int): The parameter of save_begin, controlling the - iteration number at which checkpoint saving begins. + start (int): Starting from `start` to check the logic for + every n iterations. Defaults to 0. Returns: bool: Return True if the current iteration can be evenly divided by n, otherwise False. """ - return (runner.iter + 1 - st) % n == 0 if n > 0 else False + dividend = runner.iter + 1 - start + return dividend % n == 0 if dividend >= 0 and n > 0 else False def end_of_epoch(self, dataloader, batch_idx: int) -> bool: """Check whether the current iteration reaches the last iteration of From 28c339ba5f347f1d199c766a404041f273939d37 Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:32:54 +0800 Subject: [PATCH 11/14] Update mmengine/hooks/checkpoint_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/hooks/checkpoint_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index b7f6cc0c4a..94082182c3 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -252,7 +252,7 @@ def __init__(self, self.last_ckpt = None if save_begin < 0: raise ValueError( - 'save_begin should not less than 0, but got {save_begin}') + 'save_begin should not be less than 0, but got {save_begin}') self.save_begin = save_begin def before_train(self, runner) -> None: From 9a084d133d1b3b2a0856cfe796b55d20b7ab6dc9 Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:40:39 +0800 Subject: [PATCH 12/14] delete --- mmengine/hooks/checkpoint_hook.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index b7f6cc0c4a..03e00b0829 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -335,9 +335,6 @@ def after_train_epoch(self, runner) -> None: if not self.by_epoch: return - if runner.epoch + 1 < self.save_begin: - return - # save checkpoint for following cases: # 1. every ``self.interval`` epochs which start at ``self.save_begin`` # 2. reach the last epoch of training From dc8b1adf7f7e4431ebe4ae89a1923a6830997cce Mon Sep 17 00:00:00 2001 From: KerwinKai <101576779+KerwinKai@users.noreply.github.com> Date: Tue, 25 Jul 2023 18:05:23 +0800 Subject: [PATCH 13/14] try to fix unit test failed --- mmengine/hooks/hook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 4e1c4ce8bd..3aa18ac239 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -348,7 +348,7 @@ def every_n_epochs(self, runner, n: int, start: int = 0) -> bool: Returns: bool: Whether current epoch can be evenly divided by n. """ - dividend = runner.epoch + 1 - start + dividend = int(runner.epoch + 1 - start) return dividend % n == 0 if dividend >= 0 and n > 0 else False def every_n_inner_iters(self, batch_idx: int, n: int) -> bool: @@ -380,7 +380,7 @@ def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool: bool: Return True if the current iteration can be evenly divided by n, otherwise False. """ - dividend = runner.iter + 1 - start + dividend = int(runner.iter + 1 - start) return dividend % n == 0 if dividend >= 0 and n > 0 else False def end_of_epoch(self, dataloader, batch_idx: int) -> bool: From d6256cf8c232ef6ded3a5258ea5068e96500c9ac Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 25 Jul 2023 19:05:14 +0800 Subject: [PATCH 14/14] Avoid to compare (MagicMock) with other integers. --- mmengine/hooks/hook.py | 4 ++-- tests/test_hooks/test_logger_hook.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 3aa18ac239..4e1c4ce8bd 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -348,7 +348,7 @@ def every_n_epochs(self, runner, n: int, start: int = 0) -> bool: Returns: bool: Whether current epoch can be evenly divided by n. """ - dividend = int(runner.epoch + 1 - start) + dividend = runner.epoch + 1 - start return dividend % n == 0 if dividend >= 0 and n > 0 else False def every_n_inner_iters(self, batch_idx: int, n: int) -> bool: @@ -380,7 +380,7 @@ def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool: bool: Return True if the current iteration can be evenly divided by n, otherwise False. """ - dividend = int(runner.iter + 1 - start) + dividend = runner.iter + 1 - start return dividend % n == 0 if dividend >= 0 and n > 0 else False def end_of_epoch(self, dataloader, batch_idx: int) -> bool: diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py index d2670b5b4b..52b8bc1fa3 100644 --- a/tests/test_hooks/test_logger_hook.py +++ b/tests/test_hooks/test_logger_hook.py @@ -63,6 +63,9 @@ def test_init(self): def test_after_train_iter(self): # Test LoggerHook by iter. + # Avoid to compare `Runner.iter` (MagicMock) with other integers. + ori_every_n_train_iters = LoggerHook.every_n_train_iters + LoggerHook.every_n_train_iters = MagicMock(return_value=True) runner = MagicMock() runner.log_processor.get_log_after_iter = MagicMock( return_value=(dict(), 'log_str')) @@ -112,6 +115,7 @@ def test_after_train_iter(self): logger_hook = LoggerHook() logger_hook.after_train_iter(runner, batch_idx=8) runner.log_processor.get_log_after_iter.assert_called() + LoggerHook.every_n_train_iters = ori_every_n_train_iters def test_after_val_epoch(self): logger_hook = LoggerHook()