Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Update digit_version #778

Merged
merged 3 commits into from
Aug 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions mmseg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,52 @@
import warnings

import mmcv
from packaging.version import parse

from .version import __version__, version_info

MMCV_MIN = '1.3.7'
MMCV_MAX = '1.4.0'


def digit_version(version_str):
digit_version = []
for x in version_str.split('.'):
if x.isdigit():
digit_version.append(int(x))
elif x.find('rc') != -1:
patch_version = x.split('rc')
digit_version.append(int(patch_version[0]) - 1)
digit_version.append(int(patch_version[1]))
return digit_version
def digit_version(version_str: str, length: int = 4):
"""Convert a version string into a tuple of integers.

This method is usually used for comparing two versions. For pre-release
versions: alpha < beta < rc.

Args:
version_str (str): The version string.
length (int): The maximum number of version levels. Default: 4.

Returns:
tuple[int]: The version info in digits (integers).
"""
version = parse(version_str)
assert version.release, f'failed to parse version {version_str}'
release = list(version.release)
release = release[:length]
if len(release) < length:
release = release + [0] * (length - len(release))
if version.is_prerelease:
mapping = {'a': -3, 'b': -2, 'rc': -1}
val = -4
# version.pre can be None
if version.pre:
if version.pre[0] not in mapping:
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
'version checking may go wrong')
else:
val = mapping[version.pre[0]]
release.extend([val, version.pre[-1]])
else:
release.extend([val, 0])

elif version.is_postrelease:
release.extend([1, version.post])
else:
release.extend([0, 0])
return tuple(release)


mmcv_min_version = digit_version(MMCV_MIN)
Expand All @@ -27,4 +58,4 @@ def digit_version(version_str):
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'

__all__ = ['__version__', 'version_info']
__all__ = ['__version__', 'version_info', 'digit_version']
4 changes: 2 additions & 2 deletions mmseg/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from mmcv.utils import Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader, DistributedSampler

if platform.system() != 'Windows':
Expand Down Expand Up @@ -133,7 +133,7 @@ def build_dataloader(dataset,
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None

if torch.__version__ >= '1.8.0':
if digit_version(torch.__version__) >= digit_version('1.8.0'):
data_loader = DataLoader(
dataset,
batch_size=batch_size,
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
matplotlib
numpy
packaging
prettytable
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmseg
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,packaging,prettytable,pytest,scipy,seaborn,torch,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
20 changes: 20 additions & 0 deletions tests/test_digit_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from mmseg import digit_version


def test_digit_version():
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
assert digit_version('1.0') == digit_version('1.0.0')
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
assert digit_version('1.0.0a') < digit_version('1.0.0b')
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
assert digit_version('1.0.0') < digit_version('1.0.0post')
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)