Skip to content

Commit

Permalink
[fix] EMAHook load state dict (#507)
Browse files Browse the repository at this point in the history
* fix ema load_state_dict

* fix ema load_state_dict

* fix for test

* fix by review

* fix resume and keys
  • Loading branch information
okotaku committed Sep 9, 2022
1 parent cfb884c commit a6f5297
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
15 changes: 9 additions & 6 deletions mmengine/hooks/ema_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS, MODELS
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from .hook import DATA_BATCH, Hook


Expand Down Expand Up @@ -171,7 +172,7 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
Args:
runner (Runner): The runner of the testing process.
"""
if 'ema_state_dict' in checkpoint:
if 'ema_state_dict' in checkpoint and runner._resume:
# The original model parameters are actually saved in the ema
# field, so swap the weights back to resume the ema state.
self._swap_ema_state_dict(checkpoint)
Expand All @@ -180,11 +181,13 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None:

# Support load checkpoint without ema state dict.
else:
print_log(
'There is no `ema_state_dict` in checkpoint. '
'`EMAHook` will make a copy of `state_dict` as the '
'initial `ema_state_dict`', 'current', logging.WARNING)
self.ema_model.module.load_state_dict(
if runner._resume:
print_log(
'There is no `ema_state_dict` in checkpoint. '
'`EMAHook` will make a copy of `state_dict` as the '
'initial `ema_state_dict`', 'current', logging.WARNING)
_load_checkpoint_to_model(
self.ema_model.module,
copy.deepcopy(checkpoint['state_dict']),
strict=self.strict_load)

Expand Down
29 changes: 29 additions & 0 deletions tests/test_hooks/test_ema_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ def forward(self, *args, **kwargs):
return super(BaseModel, self).forward(*args, **kwargs)


class ToyModel3(BaseModel, ToyModel):
    """Toy model variant that carries an additional ``linear1`` layer.

    On top of whatever its parent classes build, it registers
    ``linear1 = nn.Linear(2, 2)``.
    """

    def __init__(self):
        """Run the parent initialisers, then attach the ``linear1`` layer."""
        super().__init__()
        self.linear1 = nn.Linear(2, 2)

    def forward(self, *args, **kwargs):
        """Forward all arguments to the ``forward`` found after ``BaseModel``.

        ``super(BaseModel, self)`` starts the method lookup past
        ``BaseModel`` in the MRO, so ``BaseModel.forward`` itself is skipped.
        """
        next_forward = super(BaseModel, self).forward
        return next_forward(*args, **kwargs)


@DATASETS.register_module()
class DummyDataset(Dataset):
METAINFO = dict() # type: ignore
Expand Down Expand Up @@ -203,6 +213,25 @@ def forward(self, *args, **kwargs):
experiment_name='test5')
runner.test()

# Test loading a checkpoint non-strictly (`strict_load=False`).
# Test loading a checkpoint without `ema_state_dict`.
# Test with a differently sized head.
runner = Runner(
model=ToyModel3(),
test_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
test_evaluator=evaluator,
test_cfg=dict(),
work_dir=self.temp_dir.name,
load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
default_hooks=dict(logger=None),
custom_hooks=[dict(type='EMAHook', strict_load=False)],
experiment_name='test5')
runner.test()

# Test enable ema at 5 epochs.
runner = Runner(
model=model,
Expand Down

0 comments on commit a6f5297

Please sign in to comment.