Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix pickling the Python style config #1241

Merged
merged 10 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 60 additions & 29 deletions mmengine/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,24 @@ def _merge_a_into_b(a, b):
for key, value in merged.items():
self[key] = value

def __getstate__(self):
state = {}
for key, value in super().items():
state[key] = value
return state

def __setstate__(self, state):
for key, value in state.items():
self[key] = value

def __eq__(self, other):
if isinstance(other, ConfigDict):
return other.to_dict() == self.to_dict()
elif isinstance(other, dict):
return {k: v for k, v in self.items()} == other
else:
return False

def _to_lazy_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and keep
the ``LazyObject`` or ``LazyAttr`` object not built."""
Expand All @@ -281,8 +299,8 @@ def _to_dict(data):
return _to_dict(self)

def to_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and keep
the ``LazyObject`` or ``LazyAttr`` object not built."""
"""Convert the ConfigDict to a normal dictionary recursively, and
convert the ``LazyObject`` or ``LazyAttr`` to string."""
return _lazy2string(self, dict_type=dict)


Expand Down Expand Up @@ -363,12 +381,14 @@ class Config:
.. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html
""" # noqa: E501

def __init__(self,
cfg_dict: dict = None,
cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None,
format_python_code: bool = True):
def __init__(
self,
cfg_dict: dict = None,
cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None,
format_python_code: bool = True,
):
filename = str(filename) if isinstance(filename, Path) else filename
if cfg_dict is None:
cfg_dict = dict()
Expand All @@ -384,6 +404,9 @@ def __init__(self,
super().__setattr__('_cfg_dict', cfg_dict)
super().__setattr__('_filename', filename)
super().__setattr__('_format_python_code', format_python_code)
if not hasattr(self, '_imported_names'):
super().__setattr__('_imported_names', set())

if cfg_text:
text = cfg_text
elif filename:
Expand Down Expand Up @@ -445,7 +468,8 @@ def fromfile(filename: Union[str, Path],
cfg_dict,
cfg_text=cfg_text,
filename=filename,
env_variables=env_variables)
env_variables=env_variables,
)
else:
# Enable lazy import when parsing the config.
# Using try-except to make sure ``ConfigDict.lazy`` will be reset
Expand All @@ -457,15 +481,10 @@ def fromfile(filename: Union[str, Path],
except Exception as e:
raise e
finally:
# disable lazy import to get the real type. See more details
# about lazy in the docstring of ConfigDict
ConfigDict.lazy = False

# delete builtin imported objects
for key, value in list(cfg_dict._to_lazy_dict().items()):
if isinstance(value, (types.FunctionType, types.ModuleType)):
cfg_dict.pop(key)

# disable lazy import to get the real type. See more details about
# lazy in the docstring of ConfigDict
cfg = Config(
cfg_dict,
filename=filename,
Expand Down Expand Up @@ -996,7 +1015,7 @@ def _parse_lazy_import(filename: str) -> Tuple[ConfigDict, set]:
# accessed, but will not be dumped by default.

with open(filename, encoding='utf-8') as f:
global_dict = {'LazyObject': LazyObject}
global_dict = {'LazyObject': LazyObject, '__file__': filename}
base_dict = {}

parsed_codes = ast.parse(f.read())
Expand Down Expand Up @@ -1470,9 +1489,13 @@ def __setitem__(self, name, value):
def __iter__(self):
return iter(self._cfg_dict)

def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], dict]:
return (self._cfg_dict, self._filename, self._text,
self._env_variables)
def __getstate__(
self
) -> Tuple[dict, Optional[str], Optional[str], dict, bool, set]:
state = (self._cfg_dict, self._filename, self._text,
self._env_variables, self._format_python_code,
self._imported_names)
return state

def __deepcopy__(self, memo):
cls = self.__class__
Expand All @@ -1495,12 +1518,13 @@ def __copy__(self):
copy = __copy__

def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str],
dict]):
_cfg_dict, _filename, _text, _env_variables = state
super().__setattr__('_cfg_dict', _cfg_dict)
super().__setattr__('_filename', _filename)
super().__setattr__('_text', _text)
super().__setattr__('_text', _env_variables)
dict, bool, set]):
super().__setattr__('_cfg_dict', state[0])
super().__setattr__('_filename', state[1])
super().__setattr__('_text', state[2])
super().__setattr__('_env_variables', state[3])
super().__setattr__('_format_python_code', state[4])
super().__setattr__('_imported_names', state[5])

def dump(self, file: Optional[Union[str, Path]] = None):
"""Dump config to file or return config text.
Expand Down Expand Up @@ -1616,8 +1640,8 @@ def _is_lazy_import(filename: str) -> bool:
return False

def _to_lazy_dict(self, keep_imported: bool = False) -> dict:
"""Convert config object to dictionary and filter the imported
object."""
"""Convert config object to dictionary with lazy object, and filter the
imported object."""
res = self._cfg_dict._to_lazy_dict()
if hasattr(self, '_imported_names') and not keep_imported:
res = {
Expand All @@ -1637,7 +1661,14 @@ def to_dict(self, keep_imported: bool = False):
If you import third-party objects in the config file, all imported
objects will be converted to a string like ``torch.optim.SGD``
"""
return self._cfg_dict.to_dict()
cfg_dict = self._cfg_dict.to_dict()
if hasattr(self, '_imported_names') and not keep_imported:
cfg_dict = {
key: value
for key, value in cfg_dict.items()
if key not in self._imported_names
}
return cfg_dict


class DictAction(Action):
Expand Down
20 changes: 20 additions & 0 deletions mmengine/config/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ def __str__(self) -> str:

__repr__ = __str__

# `pickle.dump` will try to get the `__getstate__` and `__setstate__`
# methods of the dumped object. If these two methods are not defined,
# LazyObject will return a `__getstate__` LazyObject` or `__setstate__`
# LazyObject.
def __getstate__(self):
return self.__dict__

def __setstate__(self, state):
self.__dict__ = state


class LazyAttr:
"""The attribute of the LazyObject.
Expand Down Expand Up @@ -219,3 +229,13 @@ def __str__(self) -> str:
return self.name

__repr__ = __str__

# `pickle.dump` will try to get the `__getstate__` and `__setstate__`
# methods of the dumped object. If these two methods are not defined,
# LazyAttr will return a `__getstate__` LazyAttr` or `__setstate__`
# LazyAttr.
def __getstate__(self):
return self.__dict__

def __setstate__(self, state):
self.__dict__ = state
13 changes: 13 additions & 0 deletions mmengine/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def _is_builtin_module(module_name: str) -> bool:
return False
if module_name.startswith('mmengine.config'):
return True
if module_name in sys.builtin_module_names:
return True
spec = find_spec(module_name.split('.')[0])
# Module not found
if spec is None:
Expand Down Expand Up @@ -314,6 +316,15 @@ class docstring says.
# Built-in modules will not be parsed as LazyObject
module = f'{node.level*"."}{node.module}'
if _is_builtin_module(module):
# Make sure builtin module will be added into `self.imported_obj`
for alias in node.names:
if alias.asname is not None:
self.imported_obj.add(alias.asname)
elif alias.name == '*':
raise ConfigParsingError(
'Cannot import * from non-base config')
else:
self.imported_obj.add(alias.name)
return node

if module in self.base_dict:
Expand Down Expand Up @@ -409,6 +420,8 @@ def visit_Import(self, node) -> Union[ast.Assign, ast.Import]:
alias = alias_list[0]
if alias.asname is not None:
self.imported_obj.add(alias.asname)
if _is_builtin_module(alias.name.split('.')[0]):
return node
return ast.parse( # type: ignore
f'{alias.asname} = LazyObject('
f'"{alias.name}",'
Expand Down
16 changes: 16 additions & 0 deletions tests/data/config/lazy_module_config/test_mix_builtin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from functools import partial
from itertools import chain
from os.path import basename
from os.path import exists as ex
from os.path import splitext

import numpy as np

path = osp.join('a', 'b')
name, suffix = splitext('a/b.py')
chained = list(chain([1, 2], [3, 4]))
existed = ex(__file__)
cfgname = partial(basename, __file__)()

34 changes: 34 additions & 0 deletions tests/test_config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import os
import os.path as osp
import pickle
import platform
import sys
import tempfile
Expand Down Expand Up @@ -951,6 +952,19 @@ class ToyModel:
assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss)
DefaultScope._instance_dict.pop('test1')

def test_pickle(self):
# Text style config
cfg_path = osp.join(self.data_path, 'config/py_config/test_py_base.py')
cfg = Config.fromfile(cfg_path)
pickled = pickle.loads(pickle.dumps(cfg))
assert pickled.__dict__ == cfg.__dict__

cfg_path = osp.join(self.data_path,
'config/lazy_module_config/toy_model.py')
cfg = Config.fromfile(cfg_path)
pickled = pickle.loads(pickle.dumps(cfg))
assert pickled.__dict__ == cfg.__dict__

def test_lazy_import(self, tmp_path):
lazy_import_cfg_path = osp.join(
self.data_path, 'config/lazy_module_config/toy_model.py')
Expand Down Expand Up @@ -1036,6 +1050,26 @@ def _compare_dict(a, b):
osp.join(self.data_path,
'config/lazy_module_config/error_mix_using2.py'))

cfg = Config.fromfile(
osp.join(self.data_path,
'config/lazy_module_config/test_mix_builtin.py'))
assert cfg.path == osp.join('a', 'b')
assert cfg.name == 'a/b'
assert cfg.suffix == '.py'
assert cfg.chained == [1, 2, 3, 4]
assert cfg.existed
assert cfg.cfgname == 'test_mix_builtin.py'

cfg_dict = cfg.to_dict()
dumped_cfg_path = tmp_path / 'test_dump_lazy.py'
cfg.dump(dumped_cfg_path)
dumped_cfg = Config.fromfile(dumped_cfg_path)

assert set(dumped_cfg.keys()) == {
'path', 'name', 'suffix', 'chained', 'existed', 'cfgname'
}
assert dumped_cfg.to_dict() == cfg.to_dict()


class TestConfigDict(TestCase):

Expand Down