diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py index 26446af5af..9c28c571c0 100644 --- a/mmengine/registry/build_functions.py +++ b/mmengine/registry/build_functions.py @@ -104,7 +104,8 @@ def build_from_cfg( 'can be found at ' 'https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501 ) - elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): + # this will include classes, functions, partial functions and more + elif callable(obj_type): obj_cls = obj_type else: raise TypeError( @@ -120,10 +121,15 @@ def build_from_cfg( else: obj = obj_cls(**args) # type: ignore + # For some rare cases (e.g. obj_cls is a partial function), obj_cls + # doesn't have the following attributes. Use default value to + # prevent error + cls_name = getattr(obj_cls, '__name__', str(obj_cls)) + cls_module = getattr(obj_cls, '__module__', 'unknown') print_log( - f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501 + f'An `{cls_name}` instance is built from ' # type: ignore # noqa: E501 'registry, its implementation can be found in ' - f'{obj_cls.__module__}', # type: ignore + f'{cls_module}', # type: ignore logger='current', level=logging.DEBUG) return obj diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 35fd75de25..024809fad1 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -363,8 +363,11 @@ def get(self, key: str) -> Optional[Type]: obj_cls = root.get(key) if obj_cls is not None: + # For some rare cases (e.g. obj_cls is a partial function), obj_cls + # doesn't have `__name__`. Use default value to prevent error + cls_name = getattr(obj_cls, '__name__', str(obj_cls)) print_log( - f'Get class `{obj_cls.__name__}` from "{registry_name}"' + f'Get class `{cls_name}` from "{registry_name}"' f' registry in "{scope_name}"', logger='current', level=logging.DEBUG) @@ -441,16 +444,16 @@ def _register_module(self, """Register a module. Args: - module (type): Module class or function to be registered. + module (type): Module to be registered. Typically a class or a + function, but generally all ``Callable`` are acceptable. module_name (str or list of str, optional): The module name to be registered. If not specified, the class name will be used. Defaults to None. force (bool): Whether to override an existing class with the same name. Defaults to False. """ - if not inspect.isclass(module) and not inspect.isfunction(module): - raise TypeError('module must be a class or a function, ' - f'but got {type(module)}') + if not callable(module): + raise TypeError(f'module must be Callable, but got {type(module)}') if module_name is None: module_name = module.__name__ diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 1829920f96..9ea4e4e22f 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import functools import time import pytest @@ -59,23 +60,12 @@ def test_register_module(self): CATS = Registry('cat') @CATS.register_module() - def muchkin(): + def muchkin(size): pass assert CATS.get('muchkin') is muchkin assert 'muchkin' in CATS - # can only decorate a class or a function - with pytest.raises(TypeError): - - class Demo: - - def some_method(self): - pass - - method = Demo().some_method - CATS.register_module(name='some_method', module=method) - # test `name` parameter which must be either of None, a string or a # sequence of string # `name` is None @@ -146,7 +136,7 @@ class BritishShorthair: # decorator, which must be a class with pytest.raises( TypeError, - match='module must be a class or a function,' + match='module must be Callable,' " but got "): CATS.register_module(module='string') @@ -166,6 +156,14 @@ class SphynxCat: assert CATS.get('Sphynx3') is SphynxCat assert len(CATS) == 9 + # partial functions can be registered + muchkin0 = functools.partial(muchkin, size=0) + CATS.register_module('muchkin0', False, muchkin0) + + assert CATS.get('muchkin0') is muchkin0 + assert 'muchkin0' in CATS + assert len(CATS) == 10 + def _build_registry(self): """A helper function to build a Hierarchical Registry.""" # Hierarchical Registry @@ -227,12 +225,21 @@ def test_get(self): DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3] MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:] + @DOGS.register_module() + def bark(word, times): + return [word] * times + + dog_bark = functools.partial(bark, 'woof') + DOGS.register_module('dog_bark', False, dog_bark) + @DOGS.register_module() class GoldenRetriever: pass - assert len(DOGS) == 1 + assert len(DOGS) == 3 assert DOGS.get('GoldenRetriever') is GoldenRetriever + assert DOGS.get('bark') is bark + assert DOGS.get('dog_bark') is dog_bark @HOUNDS.register_module() class BloodHound: @@ -249,6 +256,8 @@ class BloodHound: # If the key is not found in the current registry, then look for its # parent assert HOUNDS.get('GoldenRetriever') is GoldenRetriever + assert HOUNDS.get('bark') is bark + assert HOUNDS.get('dog_bark') is dog_bark @LITTLE_HOUNDS.register_module() class Dachshund: @@ -340,11 +349,14 @@ def test_build(self, cfg_type): DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5] @DOGS.register_module() - def bark(times=1): - return ' '.join(['woof'] * times) + def bark(word, times): + return ' '.join([word] * times) + + dog_bark = functools.partial(bark, word='woof') + DOGS.register_module('dog_bark', False, dog_bark) - bark_cfg = cfg_type(dict(type='bark', times=3)) - assert DOGS.build(bark_cfg) == 'woof woof woof' + bark_cfg = cfg_type(dict(type='bark', word='meow', times=3)) + dog_bark_cfg = cfg_type(dict(type='dog_bark', times=3)) @DOGS.register_module() class GoldenRetriever: @@ -352,6 +364,8 @@ class GoldenRetriever: gr_cfg = cfg_type(dict(type='GoldenRetriever')) assert isinstance(DOGS.build(gr_cfg), GoldenRetriever) + assert DOGS.build(bark_cfg) == 'meow meow meow' + assert DOGS.build(dog_bark_cfg) == 'woof woof woof' @HOUNDS.register_module() class BloodHound: @@ -360,6 +374,8 @@ class BloodHound: bh_cfg = cfg_type(dict(type='BloodHound')) assert isinstance(HOUNDS.build(bh_cfg), BloodHound) assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever) + assert HOUNDS.build(bark_cfg) == 'meow meow meow' + assert HOUNDS.build(dog_bark_cfg) == 'woof woof woof' @LITTLE_HOUNDS.register_module() class Dachshund: