From fc53018a9c3ea94b622b856147b86b3e68fbd767 Mon Sep 17 00:00:00 2001 From: Junjun2016 Date: Wed, 11 Aug 2021 23:09:52 +0800 Subject: [PATCH 1/3] update digit_version --- mmseg/__init__.py | 53 +++++++++++++++++++++++++++++++-------- mmseg/datasets/builder.py | 4 ++- requirements/runtime.txt | 1 + setup.cfg | 2 +- 4 files changed, 47 insertions(+), 13 deletions(-) diff --git a/mmseg/__init__.py b/mmseg/__init__.py index dbdebf9943..317622c924 100644 --- a/mmseg/__init__.py +++ b/mmseg/__init__.py @@ -1,4 +1,7 @@ +import warnings + import mmcv +from packaging.version import parse from .version import __version__, version_info @@ -6,16 +9,44 @@ MMCV_MAX = '1.4.0' -def digit_version(version_str): - digit_version = [] - for x in version_str.split('.'): - if x.isdigit(): - digit_version.append(int(x)) - elif x.find('rc') != -1: - patch_version = x.split('rc') - digit_version.append(int(patch_version[0]) - 1) - digit_version.append(int(patch_version[1])) - return digit_version +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) mmcv_min_version = digit_version(MMCV_MIN) @@ -27,4 +58,4 @@ def digit_version(version_str): f'MMCV=={mmcv.__version__} is used but incompatible. ' \ f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' -__all__ = ['__version__', 'version_info'] +__all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py index 5994ab233b..e9093916f1 100644 --- a/mmseg/datasets/builder.py +++ b/mmseg/datasets/builder.py @@ -10,6 +10,8 @@ from mmcv.utils import Registry, build_from_cfg from torch.utils.data import DataLoader, DistributedSampler +from mmseg import digit_version + if platform.system() != 'Windows': # https://github.com/pytorch/pytorch/issues/973 import resource @@ -133,7 +135,7 @@ def build_dataloader(dataset, worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None - if torch.__version__ >= '1.8.0': + if digit_version(torch.__version__) >= digit_version('1.8.0'): data_loader = DataLoader( dataset, batch_size=batch_size, diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 47048d029a..2712f504c7 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,3 +1,4 @@ matplotlib numpy +packaging prettytable diff --git a/setup.cfg b/setup.cfg index 0dbe479fa7..0c80b37ce7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmseg -known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts +known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,packaging,prettytable,pytest,scipy,seaborn,torch,ts no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY From 45a8d99a6da5ad5f16e8e2784e0e6b9f1984611f Mon Sep 17 00:00:00 2001 From: Junjun2016 Date: Wed, 11 Aug 2021 23:13:29 +0800 Subject: [PATCH 2/3] add unittest --- tests/test_digital_version.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/test_digital_version.py diff --git a/tests/test_digital_version.py b/tests/test_digital_version.py new file mode 100644 index 0000000000..4d6649005c --- /dev/null +++ b/tests/test_digital_version.py @@ -0,0 +1,20 @@ +from mmseg import digit_version + + +def test_digit_version(): + assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0) + assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0) + assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0) + assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1) + assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0) + assert digit_version('1.0') == digit_version('1.0.0') + assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5') + assert digit_version('1.0.0dev') < digit_version('1.0.0a') + assert digit_version('1.0.0a') < digit_version('1.0.0a1') + assert digit_version('1.0.0a') < digit_version('1.0.0b') + assert digit_version('1.0.0b') < digit_version('1.0.0rc') + assert digit_version('1.0.0rc1') < digit_version('1.0.0') + assert digit_version('1.0.0') < digit_version('1.0.0post') + assert digit_version('1.0.0post') < digit_version('1.0.0post1') + assert digit_version('v1') == (1, 0, 0, 0, 0, 0) + assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0) From 82f79a845607aded6a179e0ffdf88d6e70829818 Mon Sep 17 00:00:00 2001 From: Junjun2016 Date: Thu, 12 Aug 2021 13:56:32 +0800 Subject: [PATCH 3/3] fix import --- mmseg/datasets/builder.py | 4 +--- tests/{test_digital_version.py => test_digit_version.py} | 0 2 files changed, 1 insertion(+), 3 deletions(-) rename tests/{test_digital_version.py => test_digit_version.py} (100%) diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py index e9093916f1..82f6f460fb 100644 --- a/mmseg/datasets/builder.py +++ b/mmseg/datasets/builder.py @@ -7,11 +7,9 @@ import torch from mmcv.parallel import collate from mmcv.runner import get_dist_info -from mmcv.utils import Registry, build_from_cfg +from mmcv.utils import Registry, build_from_cfg, digit_version from torch.utils.data import DataLoader, DistributedSampler -from mmseg import digit_version - if platform.system() != 'Windows': # https://github.com/pytorch/pytorch/issues/973 import resource diff --git a/tests/test_digital_version.py b/tests/test_digit_version.py similarity index 100% rename from tests/test_digital_version.py rename to tests/test_digit_version.py