diff --git a/mmcv/runner/base_runner.py b/mmcv/runner/base_runner.py index f78f1582f3..e28179d404 100644 --- a/mmcv/runner/base_runner.py +++ b/mmcv/runner/base_runner.py @@ -330,6 +330,20 @@ def resume(self, self._epoch = checkpoint['meta']['epoch'] self._iter = checkpoint['meta']['iter'] + + # Re-calculate the number of iterations when resuming + # models with different number of GPUs + if 'config' in checkpoint['meta']: + config = mmcv.Config.fromstring( + checkpoint['meta']['config'], file_format='.py') + previous_gpu_ids = config.get('gpu_ids', None) + if previous_gpu_ids and len(previous_gpu_ids) > 0 and len( + previous_gpu_ids) != self.world_size: + self._iter = int(self._iter * len(previous_gpu_ids) / + self.world_size) + self.logger.info('the iteration number is changed due to ' + 'change of GPU number') + if 'optimizer' in checkpoint and resume_optimizer: if isinstance(self.optimizer, Optimizer): self.optimizer.load_state_dict(checkpoint['optimizer']) diff --git a/mmcv/utils/config.py b/mmcv/utils/config.py index 17e44e7cff..ab4eb02aa6 100644 --- a/mmcv/utils/config.py +++ b/mmcv/utils/config.py @@ -5,6 +5,7 @@ import shutil import sys import tempfile +import warnings from argparse import Action, ArgumentParser from collections import abc from importlib import import_module @@ -253,6 +254,31 @@ def fromfile(filename, import_modules_from_strings(**cfg_dict['custom_imports']) return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + @staticmethod + def fromstring(cfg_str, file_format): + """Generate config from config str. + + Args: + cfg_str (str): Config str. + file_format (str): Config file format corresponding to the + config str. Only py/yml/yaml/json type are supported now! + + Returns: + obj:`Config`: Config obj. + """ + if file_format not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + if file_format != '.py' and 'dict(' in cfg_str: + # check if users specify a wrong suffix for python + warnings.warn( + 'Please check "file_format", the file format may be .py') + + with tempfile.NamedTemporaryFile('w', suffix=file_format) as temp_file: + temp_file.write(cfg_str) + temp_file.flush() + cfg = Config.fromfile(temp_file.name) + return cfg + @staticmethod def auto_argparser(description=None): """Generate argparser from config file automatically (experimental)""" diff --git a/tests/test_utils/test_config.py b/tests/test_utils/test_config.py index 013bb2f5c5..5abafe80b8 100644 --- a/tests/test_utils/test_config.py +++ b/tests/test_utils/test_config.py @@ -161,6 +161,31 @@ def test_fromfile(): Config.fromfile(osp.join(data_path, 'color.jpg')) +def test_fromstring(): + for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']: + cfg_file = osp.join(data_path, 'config', filename) + file_format = osp.splitext(filename)[-1] + in_cfg = Config.fromfile(cfg_file) + + out_cfg = Config.fromstring(in_cfg.pretty_text, '.py') + assert in_cfg._cfg_dict == out_cfg._cfg_dict + + cfg_str = open(cfg_file, 'r').read() + out_cfg = Config.fromstring(cfg_str, file_format) + assert in_cfg._cfg_dict == out_cfg._cfg_dict + + # test pretty_text only supports py file format + cfg_file = osp.join(data_path, 'config', 'b.json') + in_cfg = Config.fromfile(cfg_file) + with pytest.raises(Exception): + Config.fromstring(in_cfg.pretty_text, '.json') + + # test file format error + cfg_str = open(cfg_file, 'r').read() + with pytest.raises(Exception): + Config.fromstring(cfg_str, '.py') + + def test_merge_from_base(): cfg_file = osp.join(data_path, 'config/d.py') cfg = Config.fromfile(cfg_file)