-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add model registry * fixed infer scoep * fixed build func * add docstring * add md * support multi level * clean comments * add docs * fixed parent * add more doc * add value error, add docstring * fixed docs * change to local/global search * resolve comments * fixed test * update some docstring * update docs (minior) * update docs * update docs
- Loading branch information
Showing
6 changed files
with
429 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import torch.nn as nn | ||
|
||
from ..utils import Registry, build_from_cfg | ||
|
||
|
||
def build_model_from_cfg(cfg, registry, default_args=None): | ||
"""Build a PyTorch model from config dict(s). Different from | ||
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built. | ||
Args: | ||
cfg (dict, list[dict]): The config of modules, is is either a config | ||
dict or a list of config dicts. If cfg is a list, a | ||
the built modules will be wrapped with ``nn.Sequential``. | ||
registry (:obj:`Registry`): A registry the module belongs to. | ||
default_args (dict, optional): Default arguments to build the module. | ||
Defaults to None. | ||
Returns: | ||
nn.Module: A built nn module. | ||
""" | ||
if isinstance(cfg, list): | ||
modules = [ | ||
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg | ||
] | ||
return nn.Sequential(*modules) | ||
else: | ||
return build_from_cfg(cfg, registry, default_args) | ||
|
||
|
||
MODELS = Registry('model', build_func=build_model_from_cfg) |
Oops, something went wrong.