Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Save optimizer.state in cpu by default. #966

Merged
merged 5 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/api/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,4 @@ Miscellaneous
requires_executable
requires_package
check_time
apply_to
1 change: 1 addition & 0 deletions docs/zh_cn/api/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,4 @@ Miscellaneous
requires_executable
requires_package
check_time
apply_to
12 changes: 6 additions & 6 deletions mmengine/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from mmengine.fileio import load as load_file
from mmengine.logging import print_log
from mmengine.model import BaseTTAModel, is_model_wrapper
from mmengine.utils import deprecated_function, digit_version, mkdir_or_exist
from mmengine.utils import (apply_to, deprecated_function, digit_version,
mkdir_or_exist)
from mmengine.utils.dl_utils import load_url

# `MMENGINE_HOME` is the highest priority directory to save checkpoints
Expand Down Expand Up @@ -622,12 +623,11 @@ def weights_to_cpu(state_dict):
Returns:
OrderedDict: Model weights on CPU.
"""
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'),
lambda x: x.cpu())
# Keep metadata in state_dict
state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
return state_dict_cpu
state_dict._metadata = getattr(state_dict, '_metadata', OrderedDict())
return state_dict


@deprecated_function(
Expand Down
20 changes: 13 additions & 7 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
RUNNERS, VISUALIZERS, DefaultScope)
from mmengine.utils import digit_version, get_git_hash, is_seq_of
from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
find_latest_checkpoint, get_state_dict,
save_checkpoint, weights_to_cpu)
find_latest_checkpoint, save_checkpoint,
weights_to_cpu)
from .log_processor import LogProcessor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
Expand Down Expand Up @@ -2139,14 +2139,20 @@ def save_checkpoint(
model = self.model

checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model)),
'message_hub': self.message_hub.state_dict()
'meta':
meta,
'state_dict':
weights_to_cpu(model.state_dict()),
'message_hub':
apply_to(self.message_hub.state_dict(),
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()),
}
# save optimizer state dict to checkpoint
if save_optimizer:
if isinstance(self.optim_wrapper, OptimWrapper):
checkpoint['optimizer'] = self.optim_wrapper.state_dict()
checkpoint['optimizer'] = apply_to(
self.optim_wrapper.state_dict(),
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu())
else:
raise TypeError(
'self.optim_wrapper should be an `OptimWrapper` '
Expand Down
7 changes: 4 additions & 3 deletions mmengine/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .manager import ManagerMeta, ManagerMixin
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
deprecated_function, has_method,
from .misc import (apply_to, check_prerequisites, concat_list,
deprecated_api_warning, deprecated_function, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
Expand All @@ -27,5 +27,6 @@
'is_abs', 'is_method_overridden', 'has_method', 'digit_version',
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress', 'deprecated_function'
'track_parallel_progress', 'track_progress', 'deprecated_function',
'apply_to'
]
41 changes: 41 additions & 0 deletions mmengine/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,47 @@ def concat_list(in_list):
return list(itertools.chain(*in_list))


def apply_to(data: Any, expr: Callable, apply_func: Callable):
    """Recursively transform the matching leaves of a nested container.

    ``data`` is walked through dicts, lists, tuples and namedtuples. Every
    element that is not one of those containers and for which ``expr``
    returns a truthy value is replaced by ``apply_func(element)``; all other
    elements are returned unchanged. Containers are rebuilt as new objects
    of their original type.

    Examples:
        >>> from mmengine.utils import apply_to
        >>> import numpy as np
        >>> import torch
        >>> data = dict(array=[np.array(1)])  # {'array': [array(1)]}
        >>> result = apply_to(data,
        ...                   lambda x: isinstance(x, np.ndarray),
        ...                   lambda x: torch.from_numpy(x))
        >>> print(result)  # {'array': [tensor(1)]}

    Args:
        data (Any): Possibly nested data to process.
        expr (Callable): Predicate called on each non-container element; it
            should return a boolean telling whether ``apply_func`` must be
            applied to that element.
        apply_func (Callable): Function applied to each matching element.

    Returns:
        Any: A structure mirroring ``data`` with matching elements replaced.
    """

    def _convert(item):
        # Recurse with the same predicate / transform pair.
        return apply_to(item, expr, apply_func)

    if isinstance(data, dict):
        # Start from an empty instance so the mapping type is preserved.
        converted = type(data)()
        for name, item in data.items():
            converted[name] = _convert(item)
        return converted

    if isinstance(data, (tuple, list)):
        children = map(_convert, data)
        if isinstance(data, tuple) and hasattr(data, '_fields'):
            # namedtuple constructors take their fields as positional args.
            return type(data)(*children)
        return type(data)(children)

    return apply_func(data) if expr(data) else data


def check_prerequisites(
prerequisites,
checker,
Expand Down
46 changes: 45 additions & 1 deletion tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import namedtuple

import numpy as np
import pytest
import torch

from mmengine import MMLogger
# yapf: disable
from mmengine.utils.misc import (concat_list, deprecated_api_warning,
from mmengine.utils.misc import (apply_to, concat_list, deprecated_api_warning,
deprecated_function, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_tuple_of,
Expand Down Expand Up @@ -283,3 +287,43 @@ def deprecated_demo1():
Short summary.""" # noqa: E122
assert expected_docstring.strip(' ') == deprecated_demo1.__doc__


def test_apply_to():
    """Check ``apply_to`` on flat, nested, tuple and namedtuple inputs."""
    # Only int leaves are incremented; the float is left untouched.
    data = dict(a=1, b=2.0)
    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
    assert result == dict(a=2, b=2.0)

    # Nested containers are traversed recursively.
    data = dict(a=[dict(c=1)], b=2.0)
    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
    assert result == dict(a=[dict(c=2)], b=2.0)

    # Tensor to numpy
    data = dict(a=[dict(c=torch.tensor(1))], b=torch.tensor(2))
    result = apply_to(data, lambda x: isinstance(x, torch.Tensor),
                      lambda x: x.numpy())
    assert isinstance(result['b'], np.ndarray)
    assert isinstance(result['a'][0]['c'], np.ndarray)

    # Tuple input: ints become tensors, the string 'test' becomes 'train'.
    # `torch.tensor(x)` wraps the value; `torch.Tensor(x)` with an int would
    # instead allocate an uninitialised tensor of length `x`.
    data = (1, dict(a=[dict(b=2.0)]), 'test')
    result = apply_to(
        data, lambda x: isinstance(x, int) or x == 'test',
        lambda x: torch.tensor(x) if isinstance(x, int) else 'train')
    assert isinstance(result, tuple)
    assert isinstance(result[0], torch.Tensor)
    assert result[0].item() == 1
    assert isinstance(result[1]['a'][0]['b'], float)
    assert result[2] == 'train'

    # Named tuple: the namedtuple type must be preserved.
    Data = namedtuple('Data', ['a', 'b'])
    data = Data('test', dict(a=[dict(c=1)], b=2.0))
    result = apply_to(
        data, lambda x: isinstance(x, int) or x == 'test',
        lambda x: torch.tensor(x) if isinstance(x, int) else 'train')
    assert isinstance(result, Data)
    assert result[0] == 'train'
    assert isinstance(result.b['a'][0]['c'], torch.Tensor)
    assert result.b['a'][0]['c'].item() == 1
    assert isinstance(result.b['b'], float)