diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index d0d5e3acb6..f2712e6f81 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -7,6 +7,7 @@ from mmengine.logging import print_log from mmengine.model import is_model_wrapper from mmengine.registry import HOOKS, MODELS +from mmengine.runner.checkpoint import _load_checkpoint_to_model from .hook import DATA_BATCH, Hook @@ -171,7 +172,7 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: Args: runner (Runner): The runner of the testing process. """ - if 'ema_state_dict' in checkpoint: + if 'ema_state_dict' in checkpoint and runner._resume: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) @@ -180,11 +181,13 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: # Support load checkpoint without ema state dict. else: - print_log( - 'There is no `ema_state_dict` in checkpoint. ' - '`EMAHook` will make a copy of `state_dict` as the ' - 'initial `ema_state_dict`', 'current', logging.WARNING) - self.ema_model.module.load_state_dict( + if runner._resume: + print_log( + 'There is no `ema_state_dict` in checkpoint. 
' + '`EMAHook` will make a copy of `state_dict` as the ' + 'initial `ema_state_dict`', 'current', logging.WARNING) + _load_checkpoint_to_model( + self.ema_model.module, copy.deepcopy(checkpoint['state_dict']), strict=self.strict_load) diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py index 4b7e7d7bca..3952033c60 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -56,6 +56,16 @@ def forward(self, *args, **kwargs): return super(BaseModel, self).forward(*args, **kwargs) +class ToyModel3(BaseModel, ToyModel): + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(2, 2) + + def forward(self, *args, **kwargs): + return super(BaseModel, self).forward(*args, **kwargs) + + @DATASETS.register_module() class DummyDataset(Dataset): METAINFO = dict() # type: ignore @@ -203,6 +213,25 @@ def forward(self, *args, **kwargs): experiment_name='test5') runner.test() + # Test does not load ckpt strictly. + # Test load checkpoint without ema_state_dict + # Test with different size head. + runner = Runner( + model=ToyModel3(), + test_dataloader=dict( + dataset=dict(type='DummyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + test_evaluator=evaluator, + test_cfg=dict(), + work_dir=self.temp_dir.name, + load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'), + default_hooks=dict(logger=None), + custom_hooks=[dict(type='EMAHook', strict_load=False)], + experiment_name='test5') + runner.test() + # Test enable ema at 5 epochs. runner = Runner( model=model,