Skip to content

Commit

Permalink
[Fix] Fix Config cannot parse base config when there is . in tmp …
Browse files Browse the repository at this point in the history
…path (#856)

* [Fix] Fix config cannot parse tmp path like

* Add comments

* Add comments

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
  • Loading branch information
HAOCHENYE and zhouzaida committed Dec 30, 2022
1 parent 6af8878 commit ad1b43f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
6 changes: 3 additions & 3 deletions mmengine/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,8 @@ def _get_base_files(filename: str) -> list:
Returns:
list: A list of base config.
"""
file_format = filename.partition('.')[-1]
if file_format == 'py':
file_format = osp.splitext(filename)[1]
if file_format == '.py':
Config._validate_py_syntax(filename)
with open(filename, encoding='utf-8') as f:
codes = ast.parse(f.read()).body
Expand All @@ -568,7 +568,7 @@ def is_base_line(c):
base_files = eval(compile(base_code, '', mode='eval'))
else:
base_files = []
elif file_format in ('yml', 'yaml', 'json'):
elif file_format in ('.yml', '.yaml', '.json'):
import mmengine
cfg_dict = mmengine.load(filename)
base_files = cfg_dict.get(BASE_KEY, [])
Expand Down
17 changes: 17 additions & 0 deletions tests/test_config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import os.path as osp
import platform
import sys
import tempfile
from importlib import import_module
from pathlib import Path
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -715,6 +717,21 @@ def _base_variables(self):
cfg = Config._file2dict(cfg_file)[0]
assert cfg == dict(item1=dict(a=1))

# Simulate the case that the temporary directory includes `.`, etc.
# /tmp/test.axsgr12/. This patch is to check the issue
# https://github.com/open-mmlab/mmengine/issues/788 has been solved.
class PatchedTempDirectory(tempfile.TemporaryDirectory):

def __init__(self, *args, prefix='test.', **kwargs):
super().__init__(*args, prefix=prefix, **kwargs)

with patch('mmengine.config.config.tempfile.TemporaryDirectory',
PatchedTempDirectory):
cfg_file = osp.join(self.data_path,
'config/py_config/test_py_modify_key.py')
cfg = Config._file2dict(cfg_file)[0]
assert cfg == dict(item1=dict(a=1))

def _merge_recursive_bases(self):
cfg_file = osp.join(self.data_path,
'config/py_config/test_merge_recursive_bases.py')
Expand Down

0 comments on commit ad1b43f

Please sign in to comment.