diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index d47fd3d748..60d71a735b 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -652,10 +652,11 @@ def weights_to_cpu(state_dict): Returns: OrderedDict: Model weights on GPU. """ + # stash metadata to put in state_dict later + metadata = getattr(state_dict, '_metadata', OrderedDict()) state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()) - # Keep metadata in state_dict - state_dict._metadata = getattr(state_dict, '_metadata', OrderedDict()) + state_dict._metadata = metadata return state_dict