diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index ba2a2c9e94..8649d9aded 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -4,7 +4,8 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning, import_modules_from_strings, is_list_of, is_seq_of, is_str, is_tuple_of, iter_cast, list_cast, requires_executable, - requires_package, slice_list, tuple_cast) + requires_package, slice_list, to_1tuple, to_2tuple, + to_3tuple, to_4tuple, to_ntuple, tuple_cast) from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, scandir, symlink) from .progressbar import (ProgressBar, track_iter_progress, @@ -29,17 +30,18 @@ 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning', 'digit_version', 'get_git_hash', 'import_modules_from_strings', 'assert_dict_contains_subset', 'assert_attrs_equal', - 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script' + 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script', + 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple' ] else: from .env import collect_env from .logging import get_logger, print_log + from .parrots_jit import jit, skip_no_elena from .parrots_wrapper import ( CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config) - from .parrots_jit import jit, skip_no_elena from .registry import Registry, build_from_cfg __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', diff --git a/mmcv/utils/misc.py b/mmcv/utils/misc.py index 5e4645e37d..1d2517f02d 100644 --- a/mmcv/utils/misc.py +++ b/mmcv/utils/misc.py @@ -1,4 +1,5 @@ # Copyright (c) Open-MMLab. All rights reserved. +import collections.abc import functools import itertools import subprocess @@ -6,6 +7,25 @@ from collections import abc from importlib import import_module from inspect import getfullargspec +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple def is_str(x): diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index adcd26ea0d..29819b2faa 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -4,6 +4,31 @@ import mmcv +def test_to_ntuple(): + single_number = 2 + assert mmcv.utils.to_1tuple(single_number) == (single_number, ) + assert mmcv.utils.to_2tuple(single_number) == (single_number, + single_number) + assert mmcv.utils.to_3tuple(single_number) == (single_number, + single_number, + single_number) + assert mmcv.utils.to_4tuple(single_number) == (single_number, + single_number, + single_number, + single_number) + assert mmcv.utils.to_ntuple(5)(single_number) == (single_number, + single_number, + single_number, + single_number, + single_number) + assert mmcv.utils.to_ntuple(6)(single_number) == (single_number, + single_number, + single_number, + single_number, + single_number, + single_number) + + def test_iter_cast(): assert mmcv.list_cast([1, 2, 3], int) == [1, 2, 3] assert mmcv.list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0] @@ -105,6 +130,7 @@ def func_c(): def test_import_modules_from_strings(): # multiple imports import os.path as osp_ + import sys as sys_ osp, sys = mmcv.import_modules_from_strings(['os.path', 'sys']) assert osp == osp_