From c533081306c0b151b1b76038f4307b24c566756f Mon Sep 17 00:00:00 2001 From: Rist115 Date: Fri, 2 Sep 2022 15:18:02 +0900 Subject: [PATCH 1/5] fix ema load_state_dict --- mmengine/hooks/ema_hook.py | 12 +++++++----- tests/test_hooks/test_ema_hook.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index d0d5e3acb6..a940ce2203 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -7,6 +7,7 @@ from mmengine.logging import print_log from mmengine.model import is_model_wrapper from mmengine.registry import HOOKS, MODELS +from mmengine.runner import load_state_dict from .hook import DATA_BATCH, Hook @@ -175,8 +176,9 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) - self.ema_model.load_state_dict( - checkpoint['ema_state_dict'], strict=self.strict_load) + load_state_dict(self.ema_model.module, + checkpoint['ema_state_dict'], + strict=self.strict_load) # Support load checkpoint without ema state dict. else: @@ -184,9 +186,9 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: 'There is no `ema_state_dict` in checkpoint. 
' '`EMAHook` will make a copy of `state_dict` as the ' 'initial `ema_state_dict`', 'current', logging.WARNING) - self.ema_model.module.load_state_dict( - copy.deepcopy(checkpoint['state_dict']), - strict=self.strict_load) + load_state_dict(self.ema_model.module, + copy.deepcopy(checkpoint['state_dict']), + strict=self.strict_load) def _swap_ema_parameters(self) -> None: """Swap the parameter of model with ema_model.""" diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py index 4b7e7d7bca..3952033c60 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -56,6 +56,16 @@ def forward(self, *args, **kwargs): return super(BaseModel, self).forward(*args, **kwargs) +class ToyModel3(BaseModel, ToyModel): + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(2, 2) + + def forward(self, *args, **kwargs): + return super(BaseModel, self).forward(*args, **kwargs) + + @DATASETS.register_module() class DummyDataset(Dataset): METAINFO = dict() # type: ignore @@ -203,6 +213,25 @@ def forward(self, *args, **kwargs): experiment_name='test5') runner.test() + # Test loading the checkpoint non-strictly (strict_load=False). + # Test load checkpoint without ema_state_dict + # Test with different size head. + runner = Runner( + model=ToyModel3(), + test_dataloader=dict( + dataset=dict(type='DummyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + test_evaluator=evaluator, + test_cfg=dict(), + work_dir=self.temp_dir.name, + load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'), + default_hooks=dict(logger=None), + custom_hooks=[dict(type='EMAHook', strict_load=False)], + experiment_name='test5') + runner.test() + # Test enable ema at 5 epochs. 
runner = Runner( model=model, From e5c4c9c43980ea369e3cc727390af49a26ecd629 Mon Sep 17 00:00:00 2001 From: Rist115 Date: Fri, 2 Sep 2022 15:18:39 +0900 Subject: [PATCH 2/5] fix ema load_state_dict --- mmengine/hooks/ema_hook.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index a940ce2203..4145b3ce1b 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -176,9 +176,10 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) - load_state_dict(self.ema_model.module, - checkpoint['ema_state_dict'], - strict=self.strict_load) + load_state_dict( + self.ema_model.module, + checkpoint['ema_state_dict'], + strict=self.strict_load) # Support load checkpoint without ema state dict. else: @@ -186,9 +187,10 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: 'There is no `ema_state_dict` in checkpoint. 
' '`EMAHook` will make a copy of `state_dict` as the ' 'initial `ema_state_dict`', 'current', logging.WARNING) - load_state_dict(self.ema_model.module, - copy.deepcopy(checkpoint['state_dict']), - strict=self.strict_load) + load_state_dict( + self.ema_model.module, + copy.deepcopy(checkpoint['state_dict']), + strict=self.strict_load) def _swap_ema_parameters(self) -> None: """Swap the parameter of model with ema_model.""" From 4ab2badae20a11b68ac209653224c4bb4a10093e Mon Sep 17 00:00:00 2001 From: Rist115 Date: Mon, 5 Sep 2022 09:52:34 +0900 Subject: [PATCH 3/5] fix for test --- mmengine/hooks/ema_hook.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index 4145b3ce1b..8acdbe5a5d 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -7,7 +7,7 @@ from mmengine.logging import print_log from mmengine.model import is_model_wrapper from mmengine.registry import HOOKS, MODELS -from mmengine.runner import load_state_dict +from mmengine.runner.checkpoint import _load_checkpoint_to_model from .hook import DATA_BATCH, Hook @@ -166,7 +166,10 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None: # parameters. self._swap_ema_state_dict(checkpoint) - def after_load_checkpoint(self, runner, checkpoint: dict) -> None: + def after_load_checkpoint(self, + runner, + checkpoint: dict, + revise_keys: list = [(r'^module.', '')]) -> None: """Resume ema parameters from checkpoint. Args: @@ -176,10 +179,11 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) - load_state_dict( + _load_checkpoint_to_model( self.ema_model.module, checkpoint['ema_state_dict'], - strict=self.strict_load) + strict=self.strict_load, + revise_keys=revise_keys) # Support load checkpoint without ema state dict. 
else: @@ -187,10 +191,11 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: 'There is no `ema_state_dict` in checkpoint. ' '`EMAHook` will make a copy of `state_dict` as the ' 'initial `ema_state_dict`', 'current', logging.WARNING) - load_state_dict( + _load_checkpoint_to_model( self.ema_model.module, copy.deepcopy(checkpoint['state_dict']), - strict=self.strict_load) + strict=self.strict_load, + revise_keys=revise_keys) def _swap_ema_parameters(self) -> None: """Swap the parameter of model with ema_model.""" From 54c37f9c62bfd460c0af29475912b09fda3291ad Mon Sep 17 00:00:00 2001 From: Rist115 Date: Tue, 6 Sep 2022 13:23:58 +0900 Subject: [PATCH 4/5] fix by review --- mmengine/hooks/ema_hook.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index 8acdbe5a5d..b9b70cc905 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -33,6 +33,13 @@ class EMAHook(Hook): Defaults to 0. begin_epoch (int): The number of epoch to enable ``EMAHook``. Defaults to 0. + revise_keys (list): A list of customized keywords to modify the + state_dict in checkpoint. Each item is a (pattern, replacement) + pair of the regular expression operations. Default: strip + the prefix 'module.' by [(r'^module\\.', '')]. + resume (bool): Whether to resume ema model. If ``resume`` is True and + ``ema_state_dict`` is not found in checkpoint, resuming does + nothing. Defaults to True. **kwargs: Keyword arguments passed to subclasses of :obj:`BaseAveragedModel` """ @@ -44,6 +51,8 @@ def __init__(self, strict_load: bool = True, begin_iter: int = 0, begin_epoch: int = 0, + revise_keys: list = [(r'^module.', '')], + resume: bool = True, **kwargs): self.strict_load = strict_load self.ema_cfg = dict(type=ema_type, **kwargs) @@ -59,6 +68,9 @@ def __init__(self, # enabled at 0 iteration. 
self.enabled_by_epoch = self.begin_epoch > 0 + self.revise_keys = revise_keys + self.resume = resume + def before_run(self, runner) -> None: """Create an ema copy of the model. @@ -166,36 +178,31 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None: # parameters. self._swap_ema_state_dict(checkpoint) - def after_load_checkpoint(self, - runner, - checkpoint: dict, - revise_keys: list = [(r'^module.', '')]) -> None: + def after_load_checkpoint(self, runner, checkpoint: dict) -> None: """Resume ema parameters from checkpoint. Args: runner (Runner): The runner of the testing process. """ - if 'ema_state_dict' in checkpoint: + if 'ema_state_dict' in checkpoint and self.resume: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) - _load_checkpoint_to_model( - self.ema_model.module, - checkpoint['ema_state_dict'], - strict=self.strict_load, - revise_keys=revise_keys) + self.ema_model.load_state_dict( + checkpoint['ema_state_dict'], strict=self.strict_load) # Support load checkpoint without ema state dict. else: - print_log( - 'There is no `ema_state_dict` in checkpoint. ' - '`EMAHook` will make a copy of `state_dict` as the ' - 'initial `ema_state_dict`', 'current', logging.WARNING) + if self.resume: + print_log( + 'There is no `ema_state_dict` in checkpoint. 
' + '`EMAHook` will make a copy of `state_dict` as the ' + 'initial `ema_state_dict`', 'current', logging.WARNING) _load_checkpoint_to_model( self.ema_model.module, copy.deepcopy(checkpoint['state_dict']), strict=self.strict_load, - revise_keys=revise_keys) + revise_keys=self.revise_keys) def _swap_ema_parameters(self) -> None: """Swap the parameter of model with ema_model.""" From e109073d81aa6ab529f8dcfae432269fe6822b3c Mon Sep 17 00:00:00 2001 From: Rist115 Date: Thu, 8 Sep 2022 09:25:33 +0900 Subject: [PATCH 5/5] fix resume and keys --- mmengine/hooks/ema_hook.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index b9b70cc905..f2712e6f81 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -33,13 +33,6 @@ class EMAHook(Hook): Defaults to 0. begin_epoch (int): The number of epoch to enable ``EMAHook``. Defaults to 0. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Default: strip - the prefix 'module.' by [(r'^module\\.', '')]. - resume (bool): Whether to resume ema model. If ``resume`` is True and - ``ema_state_dict`` is not found in checkpoint, resuming does - nothing. Defaults to True. **kwargs: Keyword arguments passed to subclasses of :obj:`BaseAveragedModel` """ @@ -51,8 +44,6 @@ def __init__(self, strict_load: bool = True, begin_iter: int = 0, begin_epoch: int = 0, - revise_keys: list = [(r'^module.', '')], - resume: bool = True, **kwargs): self.strict_load = strict_load self.ema_cfg = dict(type=ema_type, **kwargs) @@ -68,9 +59,6 @@ def __init__(self, # enabled at 0 iteration. self.enabled_by_epoch = self.begin_epoch > 0 - self.revise_keys = revise_keys - self.resume = resume - def before_run(self, runner) -> None: """Create an ema copy of the model. 
@@ -184,7 +172,7 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: Args: runner (Runner): The runner of the testing process. """ - if 'ema_state_dict' in checkpoint and self.resume: + if 'ema_state_dict' in checkpoint and runner._resume: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) @@ -193,7 +181,7 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: # Support load checkpoint without ema state dict. else: - if self.resume: + if runner._resume: print_log( 'There is no `ema_state_dict` in checkpoint. ' '`EMAHook` will make a copy of `state_dict` as the ' @@ -201,8 +189,7 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: _load_checkpoint_to_model( self.ema_model.module, copy.deepcopy(checkpoint['state_dict']), - strict=self.strict_load, - revise_keys=self.revise_keys) + strict=self.strict_load) def _swap_ema_parameters(self) -> None: """Swap the parameter of model with ema_model."""