Skip to content

Commit

Permalink
add model registry (#760)
Browse files Browse the repository at this point in the history
* add model registry

* fixed infer scoep

* fixed build func

* add docstring

* add md

* support multi level

* clean comments

* add docs

* fixed parent

* add more doc

* add value error, add docstring

* fixed docs

* change to local/global search

* resolve comments

* fixed test

* update some docstring

* update docs (minior)

* update docs

* update docs
  • Loading branch information
xvjiarui committed Apr 10, 2021
1 parent 47825b1 commit 375605f
Show file tree
Hide file tree
Showing 6 changed files with 429 additions and 72 deletions.
118 changes: 98 additions & 20 deletions docs/registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,29 @@ In MMCV, registry can be regarded as a mapping that maps a class to a string.
These classes contained by a single registry usually have similar APIs but implement different algorithms or support different datasets.
With the registry, users can find and instantiate the class through its corresponding string, and use the instantiated module as they want.
One typical example is the config systems in most OpenMMLab projects, which use the registry to create hooks, runners, models, and datasets, through configs.
The API reference could be find [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry).

To manage your modules in the codebase by `Registry`, there are three steps as below.

1. Create an registry
2. Create a build method
3. Use this registry to manage the modules
1. Create a build method (optional, in most cases you can just use the default one).
2. Create a registry.
3. Use this registry to manage the modules.

`build_func` argument of `Registry` is to customize how to instantiate the class instance, the default one is `build_from_cfg` implemented [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg).

### A Simple Example

Here we show a simple example of using registry to manage modules in a package.
You can find more practical examples in OpenMMLab projects.

Assuming we want to implement a series of Dataset Converter for converting different formats of data to the expected data format.
We create directory as a package named `converters`.
We create a directory as a package named `converters`.
In the package, we first create a file to implement builders, named `converters/builder.py`, as below

```python
from mmcv.utils import Registry

# create a registry for converters
CONVERTERS = Registry('converter')


# create a build function
def build_converter(cfg, *args, **kwargs):
cfg_ = cfg.copy()
converter_type = cfg_.pop('type')
if converter_type not in CONVERTERS:
raise KeyError(f'Unrecognized task type {converter_type}')
else:
converter_cls = CONVERTERS.get(converter_type)

converter = converter_cls(*args, **kwargs, **cfg_)
return converter
```

Then we can implement different converters in the package. For example, implement `Converter1` in `converters/converter1.py`
Expand All @@ -51,7 +40,6 @@ Then we can implement different converters in the package. For example, implemen

from .builder import CONVERTERS


# use the registry to manage the module
@CONVERTERS.register_module()
class Converter1(object):
Expand All @@ -71,5 +59,95 @@ If the module is successfully registered, you can use this converter through con

```python
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_converter(converter_cfg)
converter = CONVERTERS.build(converter_cfg)
```

## Customize Build Function

Suppose we would like to customize how `converters` are built, we could implement a customized `build_func` and pass it into the registry.

```python
from mmcv.utils import Registry

# create a build function
def build_converter(cfg, registry, *args, **kwargs):
cfg_ = cfg.copy()
converter_type = cfg_.pop('type')
if converter_type not in registry:
raise KeyError(f'Unrecognized converter type {converter_type}')
else:
converter_cls = registry.get(converter_type)

converter = converter_cls(*args, **kwargs, **cfg_)
return converter

# create a registry for converters and pass ``build_converter`` function
CONVERTERS = Registry('converter', build_func=build_converter)
```

Note: in this example, we demonstrate how to use the `build_func` argument to customize the way to build a class instance.
The functionality is similar to the default `build_from_cfg`. In most cases, default one would be sufficient.
`build_model_from_cfg` is also implemented to build PyTorch module in `nn.Sequentail`, you may directly use them instead of implementing by yourself.

## Hierarchy Registry

You could also build modules from more than one OpenMMLab frameworks, e.g. you could use all backbones in [MMClassification](https://github.com/open-mmlab/mmclassification) for object detectors in [MMDetection](https://github.com/open-mmlab/mmdetection), you may also combine an object detection model in [MMDetection](https://github.com/open-mmlab/mmdetection) and semantic segmentation model in [MMSegmentation](https://github.com/open-mmlab/mmsegmentation).

All `MODELS` registries of downstream codebases are children registries of MMCV's `MODELS` registry.
Basically, there are two ways to build a module from child or sibling registries.

1. Build from children registries.

For example:

In MMDetection we define:

```python
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
MODELS = Registry('model', parent=MMCV_MODELS)

@MODELS.register_module()
class NetA(nn.Module):
def forward(self, x):
return x
```

In MMClassification we define:

```python
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
MODELS = Registry('model', parent=MMCV_MODELS)

@MODELS.register_module()
class NetB(nn.Module):
def forward(self, x):
return x + 1
```

We could build two net in either MMDetection or MMClassification by:

```python
from mmdet.models import MODELS
net_a = MODELS.build(cfg=dict(type='NetA'))
net_b = MODELS.build(cfg=dict(type='mmcls.NetB'))
```

or

```python
from mmcls.models import MODELS
net_a = MODELS.build(cfg=dict(type='mmdet.NetA'))
net_b = MODELS.build(cfg=dict(type='NetB'))
```

2. Build from parent registry.

The shared `MODELS` registry in MMCV is the parent registry for all downstream codebases (root registry):

```python
from mmcv.cnn import MODELS as MMCV_MODELS
net_a = MMCV_MODELS.build(cfg=dict(type='mmdet.NetA'))
net_b = MMCV_MODELS.build(cfg=dict(type='mmcls.NetB'))
```
3 changes: 2 additions & 1 deletion mmcv/cnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, conv_ws_2d, is_norm)
from .builder import MODELS, build_model_from_cfg
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
Expand All @@ -34,5 +35,5 @@
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit'
'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
]
30 changes: 30 additions & 0 deletions mmcv/cnn/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch.nn as nn

from ..utils import Registry, build_from_cfg


def build_model_from_cfg(cfg, registry, default_args=None):
"""Build a PyTorch model from config dict(s). Different from
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
Args:
cfg (dict, list[dict]): The config of modules, is is either a config
dict or a list of config dicts. If cfg is a list, a
the built modules will be wrapped with ``nn.Sequential``.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)
Loading

0 comments on commit 375605f

Please sign in to comment.