Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor]: Inherits mmcv registry #252

Merged
merged 1 commit into from
May 14, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions mmcls/models/builder.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,34 @@
import torch.nn as nn
from mmcv.utils import Registry, build_from_cfg
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

BACKBONES = Registry('backbone')
CLASSIFIERS = Registry('classifier')
HEADS = Registry('head')
NECKS = Registry('neck')
LOSSES = Registry('loss')
MODELS = Registry('models', parent=MMCV_MODELS)


def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
CLASSIFIERS = MODELS


def build_backbone(cfg):
return build(cfg, BACKBONES)
"""Build backbone."""
return BACKBONES.build(cfg)


def build_head(cfg):
return build(cfg, HEADS)
def build_neck(cfg):
"""Build neck."""
return NECKS.build(cfg)


def build_neck(cfg):
return build(cfg, NECKS)
def build_head(cfg):
"""Build head."""
return HEADS.build(cfg)


def build_loss(cfg):
return build(cfg, LOSSES)
"""Build loss."""
return LOSSES.build(cfg)


def build_classifier(cfg):
return build(cfg, CLASSIFIERS)
return CLASSIFIERS.build(cfg)