From 54bd43aa14a20d3a2d805653087b86632ce3824d Mon Sep 17 00:00:00 2001 From: nbei Date: Tue, 5 Jan 2021 13:35:31 +0800 Subject: [PATCH 1/5] allow register multi-name for a module simultaneously --- mmcv/utils/registry.py | 11 +++++++---- tests/test_utils/test_registry.py | 3 +++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mmcv/utils/registry.py b/mmcv/utils/registry.py index 64b83f1d75..fb17508b7a 100644 --- a/mmcv/utils/registry.py +++ b/mmcv/utils/registry.py @@ -54,10 +54,13 @@ def _register_module(self, module_class, module_name=None, force=False): if module_name is None: module_name = module_class.__name__ - if not force and module_name in self._module_dict: - raise KeyError(f'{module_name} is already registered ' - f'in {self.name}') - self._module_dict[module_name] = module_class + if is_str(module_name): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f'{name} is already registered ' + f'in {self.name}') + self._module_dict[name] = module_class def deprecated_register_module(self, cls=None, force=False): warnings.warn( diff --git a/tests/test_utils/test_registry.py b/tests/test_utils/test_registry.py index 104cc1964c..196cae41d3 100644 --- a/tests/test_utils/test_registry.py +++ b/tests/test_utils/test_registry.py @@ -58,6 +58,9 @@ class SphynxCat: CATS.register_module(name='Sphynx', module=SphynxCat) assert CATS.get('Sphynx') is SphynxCat + CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat) + assert CATS.get('Sphynx2') is SphynxCat + repr_str = 'Registry(name=cat, items={' repr_str += ("'BritishShorthair': .BritishShorthair'>, ") From 3b6f7e82e5bd46d14bfbe715b06247bd46924b8d Mon Sep 17 00:00:00 2001 From: nbei Date: Tue, 5 Jan 2021 13:46:23 +0800 Subject: [PATCH 2/5] add assertion for name type --- mmcv/utils/registry.py | 7 ++++++- tests/test_utils/test_registry.py | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mmcv/utils/registry.py b/mmcv/utils/registry.py index fb17508b7a..3371b79bd4 100644 --- a/mmcv/utils/registry.py +++ b/mmcv/utils/registry.py @@ -2,7 +2,7 @@ import warnings from functools import partial -from .misc import is_str +from .misc import is_seq_of, is_str class Registry: @@ -56,6 +56,11 @@ def _register_module(self, module_class, module_name=None, force=False): module_name = module_class.__name__ if is_str(module_name): module_name = [module_name] + else: + assert is_seq_of( + module_name, + str), ('module_name should be either of None, an ' + f'instance of str or list, but got {type(module_name)}') for name in module_name: if not force and name in self._module_dict: raise KeyError(f'{name} is already registered ' diff --git a/tests/test_utils/test_registry.py b/tests/test_utils/test_registry.py index 196cae41d3..1b80a4da79 100644 --- a/tests/test_utils/test_registry.py +++ b/tests/test_utils/test_registry.py @@ -73,6 +73,10 @@ class SphynxCat: repr_str += '})' assert repr(CATS) == repr_str + # name type + with pytest.raises(AssertionError): + CATS.register_module(name=7474741, module=SphynxCat) + # the registered module should be a class with pytest.raises(TypeError): CATS.register_module(0) From b60992fda97066a45b035a4c3dc0072c557e1125 Mon Sep 17 00:00:00 2001 From: nbei Date: Tue, 5 Jan 2021 15:09:46 +0800 Subject: [PATCH 3/5] use isintance intead of is_str --- mmcv/utils/registry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mmcv/utils/registry.py b/mmcv/utils/registry.py index 3371b79bd4..7f96b6cc67 100644 --- a/mmcv/utils/registry.py +++ b/mmcv/utils/registry.py @@ -2,7 +2,7 @@ import warnings from functools import partial -from .misc import is_seq_of, is_str +from .misc import is_seq_of class Registry: @@ -54,7 +54,7 @@ def _register_module(self, module_class, module_name=None, force=False): if module_name is None: module_name = module_class.__name__ - if is_str(module_name): + if isinstance(module_name, str): module_name = [module_name] else: assert is_seq_of( @@ -165,7 +165,7 @@ def build_from_cfg(cfg, registry, default_args=None): args.setdefault(name, value) obj_type = args.pop('type') - if is_str(obj_type): + if isinstance(obj_type, str): obj_cls = registry.get(obj_type) if obj_cls is None: raise KeyError( From 39f0b67fde6548607bd0f9d786fc49ef59983035 Mon Sep 17 00:00:00 2001 From: nbei Date: Wed, 6 Jan 2021 10:25:16 +0800 Subject: [PATCH 4/5] fix bug in unit test --- tests/test_utils/test_registry.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_utils/test_registry.py b/tests/test_utils/test_registry.py index 1b80a4da79..d155b05722 100644 --- a/tests/test_utils/test_registry.py +++ b/tests/test_utils/test_registry.py @@ -70,6 +70,10 @@ class SphynxCat: ".SiameseCat'>, ") repr_str += ("'Sphynx': .SphynxCat'>") + repr_str += ("'Sphynx1': .SphynxCat'>") + repr_str += ("'Sphynx2': .SphynxCat'>") repr_str += '})' assert repr(CATS) == repr_str From dc211c0151472cb124928809dd36d2189dcb3d08 Mon Sep 17 00:00:00 2001 From: nbei Date: Thu, 7 Jan 2021 13:27:11 +0800 Subject: [PATCH 5/5] fix unit test --- tests/test_utils/test_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils/test_registry.py b/tests/test_utils/test_registry.py index d155b05722..3106c39e56 100644 --- a/tests/test_utils/test_registry.py +++ b/tests/test_utils/test_registry.py @@ -69,9 +69,9 @@ class SphynxCat: repr_str += ("'Siamese': .SiameseCat'>, ") repr_str += ("'Sphynx': .SphynxCat'>") + ".SphynxCat'>, ") repr_str += ("'Sphynx1': .SphynxCat'>") + ".SphynxCat'>, ") repr_str += ("'Sphynx2': .SphynxCat'>") repr_str += '})'