Skip to content

Commit

Permalink
Fix the iter error when the number of GPUs is different during resume (
Browse files Browse the repository at this point in the history
…#844)

* Fix the iter error when the number of GPUs is different during resume

* Add fromstring and unit test

* Remove is_pretty_text

* Fix comment

* Add log info

* Add py format check

* Remove SyntaxError check
  • Loading branch information
hhaAndroid committed Feb 25, 2021
1 parent ba30d98 commit 58a8483
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 0 deletions.
14 changes: 14 additions & 0 deletions mmcv/runner/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,20 @@ def resume(self,

self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter']

# Re-calculate the number of iterations when resuming
# models with different number of GPUs
if 'config' in checkpoint['meta']:
config = mmcv.Config.fromstring(
checkpoint['meta']['config'], file_format='.py')
previous_gpu_ids = config.get('gpu_ids', None)
if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
previous_gpu_ids) != self.world_size:
self._iter = int(self._iter * len(previous_gpu_ids) /
self.world_size)
self.logger.info('the iteration number is changed due to '
'change of GPU number')

if 'optimizer' in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint['optimizer'])
Expand Down
26 changes: 26 additions & 0 deletions mmcv/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import sys
import tempfile
import warnings
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
Expand Down Expand Up @@ -253,6 +254,31 @@ def fromfile(filename,
import_modules_from_strings(**cfg_dict['custom_imports'])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

@staticmethod
def fromstring(cfg_str, file_format):
"""Generate config from config str.
Args:
cfg_str (str): Config str.
file_format (str): Config file format corresponding to the
config str. Only py/yml/yaml/json type are supported now!
Returns:
obj:`Config`: Config obj.
"""
if file_format not in ['.py', '.json', '.yaml', '.yml']:
raise IOError('Only py/yml/yaml/json type are supported now!')
if file_format != '.py' and 'dict(' in cfg_str:
# check if users specify a wrong suffix for python
warnings.warn(
'Please check "file_format", the file format may be .py')

with tempfile.NamedTemporaryFile('w', suffix=file_format) as temp_file:
temp_file.write(cfg_str)
temp_file.flush()
cfg = Config.fromfile(temp_file.name)
return cfg

@staticmethod
def auto_argparser(description=None):
"""Generate argparser from config file automatically (experimental)"""
Expand Down
25 changes: 25 additions & 0 deletions tests/test_utils/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,31 @@ def test_fromfile():
Config.fromfile(osp.join(data_path, 'color.jpg'))


def test_fromstring():
for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']:
cfg_file = osp.join(data_path, 'config', filename)
file_format = osp.splitext(filename)[-1]
in_cfg = Config.fromfile(cfg_file)

out_cfg = Config.fromstring(in_cfg.pretty_text, '.py')
assert in_cfg._cfg_dict == out_cfg._cfg_dict

cfg_str = open(cfg_file, 'r').read()
out_cfg = Config.fromstring(cfg_str, file_format)
assert in_cfg._cfg_dict == out_cfg._cfg_dict

# test pretty_text only supports py file format
cfg_file = osp.join(data_path, 'config', 'b.json')
in_cfg = Config.fromfile(cfg_file)
with pytest.raises(Exception):
Config.fromstring(in_cfg.pretty_text, '.json')

# test file format error
cfg_str = open(cfg_file, 'r').read()
with pytest.raises(Exception):
Config.fromstring(cfg_str, '.py')


def test_merge_from_base():
cfg_file = osp.join(data_path, 'config/d.py')
cfg = Config.fromfile(cfg_file)
Expand Down

0 comments on commit 58a8483

Please sign in to comment.