diff --git a/docs/registry.md b/docs/registry.md index 684bcedf2f..3793224b6d 100644 --- a/docs/registry.md +++ b/docs/registry.md @@ -9,12 +9,15 @@ In MMCV, registry can be regarded as a mapping that maps a class to a string. These classes contained by a single registry usually have similar APIs but implement different algorithms or support different datasets. With the registry, users can find and instantiate the class through its corresponding string, and use the instantiated module as they want. One typical example is the config systems in most OpenMMLab projects, which use the registry to create hooks, runners, models, and datasets, through configs. +The API reference could be find [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry). To manage your modules in the codebase by `Registry`, there are three steps as below. -1. Create an registry -2. Create a build method -3. Use this registry to manage the modules +1. Create a build method (optional, in most cases you can just use the default one). +2. Create a registry. +3. Use this registry to manage the modules. + +`build_func` argument of `Registry` is to customize how to instantiate the class instance, the default one is `build_from_cfg` implemented [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg). ### A Simple Example @@ -22,27 +25,13 @@ Here we show a simple example of using registry to manage modules in a package. You can find more practical examples in OpenMMLab projects. Assuming we want to implement a series of Dataset Converter for converting different formats of data to the expected data format. -We create directory as a package named `converters`. +We create a directory as a package named `converters`. In the package, we first create a file to implement builders, named `converters/builder.py`, as below ```python from mmcv.utils import Registry - # create a registry for converters CONVERTERS = Registry('converter') - - -# create a build function -def build_converter(cfg, *args, **kwargs): - cfg_ = cfg.copy() - converter_type = cfg_.pop('type') - if converter_type not in CONVERTERS: - raise KeyError(f'Unrecognized task type {converter_type}') - else: - converter_cls = CONVERTERS.get(converter_type) - - converter = converter_cls(*args, **kwargs, **cfg_) - return converter ``` Then we can implement different converters in the package. For example, implement `Converter1` in `converters/converter1.py` @@ -51,7 +40,6 @@ Then we can implement different converters in the package. For example, implemen from .builder import CONVERTERS - # use the registry to manage the module @CONVERTERS.register_module() class Converter1(object): @@ -71,5 +59,95 @@ If the module is successfully registered, you can use this converter through con ```python converter_cfg = dict(type='Converter1', a=a_value, b=b_value) -converter = build_converter(converter_cfg) +converter = CONVERTERS.build(converter_cfg) +``` + +## Customize Build Function + +Suppose we would like to customize how `converters` are built, we could implement a customized `build_func` and pass it into the registry. + +```python +from mmcv.utils import Registry + +# create a build function +def build_converter(cfg, registry, *args, **kwargs): + cfg_ = cfg.copy() + converter_type = cfg_.pop('type') + if converter_type not in registry: + raise KeyError(f'Unrecognized converter type {converter_type}') + else: + converter_cls = registry.get(converter_type) + + converter = converter_cls(*args, **kwargs, **cfg_) + return converter + +# create a registry for converters and pass ``build_converter`` function +CONVERTERS = Registry('converter', build_func=build_converter) ``` + +Note: in this example, we demonstrate how to use the `build_func` argument to customize the way to build a class instance. +The functionality is similar to the default `build_from_cfg`. In most cases, default one would be sufficient. +`build_model_from_cfg` is also implemented to build PyTorch module in `nn.Sequentail`, you may directly use them instead of implementing by yourself. + +## Hierarchy Registry + +You could also build modules from more than one OpenMMLab frameworks, e.g. you could use all backbones in [MMClassification](https://github.com/open-mmlab/mmclassification) for object detectors in [MMDetection](https://github.com/open-mmlab/mmdetection), you may also combine an object detection model in [MMDetection](https://github.com/open-mmlab/mmdetection) and semantic segmentation model in [MMSegmentation](https://github.com/open-mmlab/mmsegmentation). + +All `MODELS` registries of downstream codebases are children registries of MMCV's `MODELS` registry. +Basically, there are two ways to build a module from child or sibling registries. + +1. Build from children registries. + + For example: + + In MMDetection we define: + + ```python + from mmcv.utils import Registry + from mmcv.cnn import MODELS as MMCV_MODELS + MODELS = Registry('model', parent=MMCV_MODELS) + + @MODELS.register_module() + class NetA(nn.Module): + def forward(self, x): + return x + ``` + + In MMClassification we define: + + ```python + from mmcv.utils import Registry + from mmcv.cnn import MODELS as MMCV_MODELS + MODELS = Registry('model', parent=MMCV_MODELS) + + @MODELS.register_module() + class NetB(nn.Module): + def forward(self, x): + return x + 1 + ``` + + We could build two net in either MMDetection or MMClassification by: + + ```python + from mmdet.models import MODELS + net_a = MODELS.build(cfg=dict(type='NetA')) + net_b = MODELS.build(cfg=dict(type='mmcls.NetB')) + ``` + + or + + ```python + from mmcls.models import MODELS + net_a = MODELS.build(cfg=dict(type='mmdet.NetA')) + net_b = MODELS.build(cfg=dict(type='NetB')) + ``` + +2. Build from parent registry. + + The shared `MODELS` registry in MMCV is the parent registry for all downstream codebases (root registry): + + ```python + from mmcv.cnn import MODELS as MMCV_MODELS + net_a = MMCV_MODELS.build(cfg=dict(type='mmdet.NetA')) + net_b = MMCV_MODELS.build(cfg=dict(type='mmcls.NetB')) + ``` diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py index 41cf85d4ca..71d2b69357 100644 --- a/mmcv/cnn/__init__.py +++ b/mmcv/cnn/__init__.py @@ -11,6 +11,7 @@ build_activation_layer, build_conv_layer, build_norm_layer, build_padding_layer, build_plugin_layer, build_upsample_layer, conv_ws_2d, is_norm) +from .builder import MODELS, build_model_from_cfg # yapf: enable from .resnet import ResNet, make_res_layer from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit, @@ -34,5 +35,5 @@ 'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', - 'Caffe2XavierInit' + 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg' ] diff --git a/mmcv/cnn/builder.py b/mmcv/cnn/builder.py new file mode 100644 index 0000000000..89e8b9e7ac --- /dev/null +++ b/mmcv/cnn/builder.py @@ -0,0 +1,30 @@ +import torch.nn as nn + +from ..utils import Registry, build_from_cfg + + +def build_model_from_cfg(cfg, registry, default_args=None): + """Build a PyTorch model from config dict(s). Different from + ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built. + + Args: + cfg (dict, list[dict]): The config of modules, is is either a config + dict or a list of config dicts. If cfg is a list, a + the built modules will be wrapped with ``nn.Sequential``. + registry (:obj:`Registry`): A registry the module belongs to. + default_args (dict, optional): Default arguments to build the module. + Defaults to None. + + Returns: + nn.Module: A built nn module. + """ + if isinstance(cfg, list): + modules = [ + build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg + ] + return nn.Sequential(*modules) + else: + return build_from_cfg(cfg, registry, default_args) + + +MODELS = Registry('model', build_func=build_model_from_cfg) diff --git a/mmcv/utils/registry.py b/mmcv/utils/registry.py index 5894ad25f2..f4e129e035 100644 --- a/mmcv/utils/registry.py +++ b/mmcv/utils/registry.py @@ -5,16 +5,107 @@ from .misc import is_seq_of +def build_from_cfg(cfg, registry, default_args=None): + """Build a module from config dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. + """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + if default_args is None or 'type' not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f'but got {cfg}\n{default_args}') + if not isinstance(registry, Registry): + raise TypeError('registry must be an mmcv.Registry object, ' + f'but got {type(registry)}') + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError('default_args must be a dict or None, ' + f'but got {type(default_args)}') + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop('type') + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError( + f'{obj_type} is not in the {registry.name} registry') + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f'{obj_cls.__name__}: {e}') + + class Registry: """A registry to map strings to classes. + Registered object could be built from registry. + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + + Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for + advanced useage. + Args: name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. """ - def __init__(self, name): + def __init__(self, name, build_func=None, parent=None, scope=None): self._name = name self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None def __len__(self): return len(self._module_dict) @@ -28,14 +119,68 @@ def __repr__(self): f'items={self._module_dict})' return format_str + @staticmethod + def infer_scope(): + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. + + Example: + # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + + + Returns: + scope (str): The inferred scope name. + """ + # inspect.stack() trace where this function is called, the index-2 + # indicates the frame where `infer_scope()` is called + filename = inspect.getmodule(inspect.stack()[2][0]).__name__ + split_filename = filename.split('.') + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Return: + scope (str, None): The first scope. + key (str): The remaining key. + """ + split_index = key.find('.') + if split_index != -1: + return key[:split_index], key[split_index + 1:] + else: + return None, key + @property def name(self): return self._name + @property + def scope(self): + return self._scope + @property def module_dict(self): return self._module_dict + @property + def children(self): + return self._children + def get(self, key): """Get the registry record. @@ -45,7 +190,45 @@ def get(self, key): Returns: class: The corresponding class. """ - return self._module_dict.get(key, None) + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. + + Example: + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> @mmdet_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(type='mmdet.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert registry.scope not in self.children, \ + f'scope {registry.scope} exists in {self.name} registry' + self.children[registry.scope] = registry def _register_module(self, module_class, module_name=None, force=False): if not inspect.isclass(module_class): @@ -131,52 +314,3 @@ def _register(cls): return cls return _register - - -def build_from_cfg(cfg, registry, default_args=None): - """Build a module from config dict. - - Args: - cfg (dict): Config dict. It should at least contain the key "type". - registry (:obj:`Registry`): The registry to search the type from. - default_args (dict, optional): Default initialization arguments. - - Returns: - object: The constructed object. - """ - if not isinstance(cfg, dict): - raise TypeError(f'cfg must be a dict, but got {type(cfg)}') - if 'type' not in cfg: - if default_args is None or 'type' not in default_args: - raise KeyError( - '`cfg` or `default_args` must contain the key "type", ' - f'but got {cfg}\n{default_args}') - if not isinstance(registry, Registry): - raise TypeError('registry must be an mmcv.Registry object, ' - f'but got {type(registry)}') - if not (isinstance(default_args, dict) or default_args is None): - raise TypeError('default_args must be a dict or None, ' - f'but got {type(default_args)}') - - args = cfg.copy() - - if default_args is not None: - for name, value in default_args.items(): - args.setdefault(name, value) - - obj_type = args.pop('type') - if isinstance(obj_type, str): - obj_cls = registry.get(obj_type) - if obj_cls is None: - raise KeyError( - f'{obj_type} is not in the {registry.name} registry') - elif inspect.isclass(obj_type): - obj_cls = obj_type - else: - raise TypeError( - f'type must be a str or valid type, but got {type(obj_type)}') - try: - return obj_cls(**args) - except Exception as e: - # Normal TypeError does not print class name. - raise type(e)(f'{obj_cls.__name__}: {e}') diff --git a/tests/test_cnn/test_model_registry.py b/tests/test_cnn/test_model_registry.py new file mode 100644 index 0000000000..86fb15b685 --- /dev/null +++ b/tests/test_cnn/test_model_registry.py @@ -0,0 +1,63 @@ +import torch.nn as nn + +import mmcv +from mmcv.cnn import MODELS, build_model_from_cfg + + +def test_build_model_from_cfg(): + BACKBONES = mmcv.Registry('backbone', build_func=build_model_from_cfg) + + @BACKBONES.register_module() + class ResNet(nn.Module): + + def __init__(self, depth, stages=4): + super().__init__() + self.depth = depth + self.stages = stages + + def forward(self, x): + return x + + @BACKBONES.register_module() + class ResNeXt(nn.Module): + + def __init__(self, depth, stages=4): + super().__init__() + self.depth = depth + self.stages = stages + + def forward(self, x): + return x + + cfg = dict(type='ResNet', depth=50) + model = BACKBONES.build(cfg) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + cfg = dict(type='ResNeXt', depth=50, stages=3) + model = BACKBONES.build(cfg) + assert isinstance(model, ResNeXt) + assert model.depth == 50 and model.stages == 3 + + cfg = [ + dict(type='ResNet', depth=50), + dict(type='ResNeXt', depth=50, stages=3) + ] + model = BACKBONES.build(cfg) + assert isinstance(model, nn.Sequential) + assert isinstance(model[0], ResNet) + assert model[0].depth == 50 and model[0].stages == 4 + assert isinstance(model[1], ResNeXt) + assert model[1].depth == 50 and model[1].stages == 3 + + # test inherit `build_func` from parent + NEW_MODELS = mmcv.Registry('models', parent=MODELS, scope='new') + assert NEW_MODELS.build_func is build_model_from_cfg + + # test specify `build_func` + def pseudo_build(cfg): + return cfg + + NEW_MODELS = mmcv.Registry( + 'models', parent=MODELS, build_func=pseudo_build) + assert NEW_MODELS.build_func is pseudo_build diff --git a/tests/test_utils/test_registry.py b/tests/test_utils/test_registry.py index 33d49c5f48..4a1cc90f6e 100644 --- a/tests/test_utils/test_registry.py +++ b/tests/test_utils/test_registry.py @@ -132,6 +132,57 @@ class NewCat2: # end: test old APIs +def test_multi_scope_registry(): + DOGS = mmcv.Registry('dogs') + assert DOGS.name == 'dogs' + assert DOGS.scope == 'test_registry' + assert DOGS.module_dict == {} + assert len(DOGS) == 0 + + @DOGS.register_module() + class GoldenRetriever: + pass + + assert len(DOGS) == 1 + assert DOGS.get('GoldenRetriever') is GoldenRetriever + + HOUNDS = mmcv.Registry('dogs', parent=DOGS, scope='hound') + + @HOUNDS.register_module() + class BloodHound: + pass + + assert len(HOUNDS) == 1 + assert HOUNDS.get('BloodHound') is BloodHound + assert DOGS.get('hound.BloodHound') is BloodHound + assert HOUNDS.get('hound.BloodHound') is BloodHound + + LITTLE_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='little_hound') + + @LITTLE_HOUNDS.register_module() + class Dachshund: + pass + + assert len(LITTLE_HOUNDS) == 1 + assert LITTLE_HOUNDS.get('Dachshund') is Dachshund + assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound + assert HOUNDS.get('little_hound.Dachshund') is Dachshund + assert DOGS.get('hound.little_hound.Dachshund') is Dachshund + + MID_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='mid_hound') + + @MID_HOUNDS.register_module() + class Beagle: + pass + + assert MID_HOUNDS.get('Beagle') is Beagle + assert HOUNDS.get('mid_hound.Beagle') is Beagle + assert DOGS.get('hound.mid_hound.Beagle') is Beagle + assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle + assert MID_HOUNDS.get('hound.BloodHound') is BloodHound + assert MID_HOUNDS.get('hound.Dachshund') is None + + def test_build_from_cfg(): BACKBONES = mmcv.Registry('backbone')