Skip to content

Commit

Permalink
Merge b8db0e2 into 2df5bc1
Browse files Browse the repository at this point in the history
  • Loading branch information
C1rN09 committed Oct 11, 2022
2 parents 2df5bc1 + b8db0e2 commit 6a3ee9c
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 26 deletions.
12 changes: 9 additions & 3 deletions mmengine/registry/build_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def build_from_cfg(
'can be found at '
'https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
)
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
# this will include classes, functions, partial functions and more
elif callable(obj_type):
obj_cls = obj_type
else:
raise TypeError(
Expand All @@ -120,10 +121,15 @@ def build_from_cfg(
else:
obj = obj_cls(**args) # type: ignore

# For some rare cases (e.g. obj_cls is a partial function), obj_cls
# doesn't have the following attributes. Use default value to
# prevent error
cls_name = getattr(obj_cls, '__name__', str(obj_cls))
cls_module = getattr(obj_cls, '__module__', 'unknown')
print_log(
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
f'An `{cls_name}` instance is built from ' # type: ignore # noqa: E501
'registry, its implementation can be found in '
f'{obj_cls.__module__}', # type: ignore
f'{cls_module}', # type: ignore
logger='current',
level=logging.DEBUG)
return obj
Expand Down
13 changes: 8 additions & 5 deletions mmengine/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,11 @@ def get(self, key: str) -> Optional[Type]:
obj_cls = root.get(key)

if obj_cls is not None:
# For some rare cases (e.g. obj_cls is a partial function), obj_cls
# doesn't have `__name__`. Use default value to prevent error
cls_name = getattr(obj_cls, '__name__', str(obj_cls))
print_log(
f'Get class `{obj_cls.__name__}` from "{registry_name}"'
f'Get class `{cls_name}` from "{registry_name}"'
f' registry in "{scope_name}"',
logger='current',
level=logging.DEBUG)
Expand Down Expand Up @@ -441,16 +444,16 @@ def _register_module(self,
"""Register a module.
Args:
module (type): Module class or function to be registered.
module (type): Module to be registered. Typically a class or a
function, but generally all ``Callable`` are acceptable.
module_name (str or list of str, optional): The module name to be
registered. If not specified, the class name will be used.
Defaults to None.
force (bool): Whether to override an existing class with the same
name. Defaults to False.
"""
if not inspect.isclass(module) and not inspect.isfunction(module):
raise TypeError('module must be a class or a function, '
f'but got {type(module)}')
if not callable(module):
raise TypeError(f'module must be Callable, but got {type(module)}')

if module_name is None:
module_name = module.__name__
Expand Down
52 changes: 34 additions & 18 deletions tests/test_registry/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import time

import pytest
Expand Down Expand Up @@ -59,23 +60,12 @@ def test_register_module(self):
CATS = Registry('cat')

@CATS.register_module()
def muchkin():
def muchkin(size):
pass

assert CATS.get('muchkin') is muchkin
assert 'muchkin' in CATS

# can only decorate a class or a function
with pytest.raises(TypeError):

class Demo:

def some_method(self):
pass

method = Demo().some_method
CATS.register_module(name='some_method', module=method)

# test `name` parameter which must be either of None, a string or a
# sequence of string
# `name` is None
Expand Down Expand Up @@ -146,7 +136,7 @@ class BritishShorthair:
# decorator, which must be a class
with pytest.raises(
TypeError,
match='module must be a class or a function,'
match='module must be Callable,'
" but got <class 'str'>"):
CATS.register_module(module='string')

Expand All @@ -166,6 +156,14 @@ class SphynxCat:
assert CATS.get('Sphynx3') is SphynxCat
assert len(CATS) == 9

# partial functions can be registered
muchkin0 = functools.partial(muchkin, size=0)
CATS.register_module('muchkin0', False, muchkin0)

assert CATS.get('muchkin0') is muchkin0
assert 'muchkin0' in CATS
assert len(CATS) == 10

def _build_registry(self):
"""A helper function to build a Hierarchical Registry."""
# Hierarchical Registry
Expand Down Expand Up @@ -227,12 +225,21 @@ def test_get(self):
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]

@DOGS.register_module()
def bark(word, times):
return [word] * times

dog_bark = functools.partial(bark, 'woof')
DOGS.register_module('dog_bark', False, dog_bark)

@DOGS.register_module()
class GoldenRetriever:
pass

assert len(DOGS) == 1
assert len(DOGS) == 3
assert DOGS.get('GoldenRetriever') is GoldenRetriever
assert DOGS.get('bark') is bark
assert DOGS.get('dog_bark') is dog_bark

@HOUNDS.register_module()
class BloodHound:
Expand All @@ -249,6 +256,8 @@ class BloodHound:
# If the key is not found in the current registry, then look for its
# parent
assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
assert HOUNDS.get('bark') is bark
assert HOUNDS.get('dog_bark') is dog_bark

@LITTLE_HOUNDS.register_module()
class Dachshund:
Expand Down Expand Up @@ -340,18 +349,23 @@ def test_build(self, cfg_type):
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]

@DOGS.register_module()
def bark(times=1):
return ' '.join(['woof'] * times)
def bark(word, times):
return ' '.join([word] * times)

dog_bark = functools.partial(bark, word='woof')
DOGS.register_module('dog_bark', False, dog_bark)

bark_cfg = cfg_type(dict(type='bark', times=3))
assert DOGS.build(bark_cfg) == 'woof woof woof'
bark_cfg = cfg_type(dict(type='bark', word='meow', times=3))
dog_bark_cfg = cfg_type(dict(type='dog_bark', times=3))

@DOGS.register_module()
class GoldenRetriever:
pass

gr_cfg = cfg_type(dict(type='GoldenRetriever'))
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
assert DOGS.build(bark_cfg) == 'meow meow meow'
assert DOGS.build(dog_bark_cfg) == 'woof woof woof'

@HOUNDS.register_module()
class BloodHound:
Expand All @@ -360,6 +374,8 @@ class BloodHound:
bh_cfg = cfg_type(dict(type='BloodHound'))
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
assert HOUNDS.build(bark_cfg) == 'meow meow meow'
assert HOUNDS.build(dog_bark_cfg) == 'woof woof woof'

@LITTLE_HOUNDS.register_module()
class Dachshund:
Expand Down

0 comments on commit 6a3ee9c

Please sign in to comment.