fix weight_init.py #825

Merged
merged 2 commits on Feb 7, 2021
54 changes: 26 additions & 28 deletions mmcv/cnn/utils/weight_init.py
@@ -77,7 +77,7 @@ def bias_init_with_prob(prior_prob):

class BaseInit(object):

-def __init__(self, bias, bias_prob, layer):
+def __init__(self, *, bias=0, bias_prob=None, layer=None):
if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a number, but got a {type(bias)}')

@@ -88,7 +88,7 @@ def __init__(self, bias, bias_prob, layer):

if layer is not None:
if not isinstance(layer, (str, list)):
-raise TypeError(f'layer must be str or list[str], \
+raise TypeError(f'layer must be a str or a list of str, \
but got a {type(layer)}')

if bias_prob is not None:
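
The bare `*` in the new signature makes `bias`, `bias_prob`, and `layer` keyword-only, so subclasses can forward `**kwargs` without repeating the shared parameter list or risking positional mix-ups. A minimal sketch of the pattern (simplified; `BaseInitSketch` is a stand-in name, and the real class also converts `bias_prob` via `bias_init_with_prob`):

```python
class BaseInitSketch:
    """Simplified stand-in for BaseInit to show the keyword-only pattern."""

    def __init__(self, *, bias=0, bias_prob=None, layer=None):
        # Everything after the bare `*` must be passed by keyword, so a
        # call like BaseInitSketch(1) fails fast instead of silently
        # binding 1 to the wrong parameter.
        if not isinstance(bias, (int, float)):
            raise TypeError(f'bias must be a number, but got a {type(bias)}')
        self.bias = bias
        self.bias_prob = bias_prob
        self.layer = [layer] if isinstance(layer, str) else layer
```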
@@ -112,8 +112,8 @@ class ConstantInit(BaseInit):
Defaults to None.
"""

-def __init__(self, val, bias=0, bias_prob=None, layer=None):
-super().__init__(bias, bias_prob, layer)
+def __init__(self, val, **kwargs):
+super().__init__(**kwargs)
self.val = val

def __call__(self, module):
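
With the shared arguments keyword-only, a subclass such as `ConstantInit` keeps only `val` positional, and a stray positional argument now fails immediately instead of silently binding to `bias`. Illustrative usage (not part of the PR):

```python
from mmcv.cnn.utils.weight_init import ConstantInit

init = ConstantInit(val=1, bias=2, layer='Conv2d')  # OK: keywords only

try:
    ConstantInit(1, 2)  # 2 can no longer bind to `bias` positionally
except TypeError as e:
    print(e)  # e.g. "__init__() takes 2 positional arguments but 3 were given"
```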
@@ -149,13 +149,8 @@ class XavierInit(BaseInit):
Defaults to None.
"""

-def __init__(self,
-gain=1,
-bias=0,
-bias_prob=None,
-distribution='normal',
-layer=None):
-super().__init__(bias, bias_prob, layer)
+def __init__(self, gain=1, distribution='normal', **kwargs):
+super().__init__(**kwargs)
self.gain = gain
self.distribution = distribution

@@ -191,8 +186,8 @@ class NormalInit(BaseInit):

"""

-def __init__(self, mean=0, std=1, bias=0, bias_prob=None, layer=None):
-super().__init__(bias, bias_prob, layer)
+def __init__(self, mean=0, std=1, **kwargs):
+super().__init__(**kwargs)
self.mean = mean
self.std = std

@@ -228,8 +223,8 @@ class UniformInit(BaseInit):
Defaults to None.
"""

-def __init__(self, a=0, b=1, bias=0, bias_prob=None, layer=None):
-super().__init__(bias, bias_prob, layer)
+def __init__(self, a=0, b=1, **kwargs):
+super().__init__(**kwargs)
self.a = a
self.b = b

@@ -279,11 +274,9 @@ def __init__(self,
a=0,
mode='fan_out',
nonlinearity='relu',
-bias=0,
-bias_prob=None,
distribution='normal',
-layer=None):
-super().__init__(bias, bias_prob, layer)
+**kwargs):
+super().__init__(**kwargs)
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
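
After the change, `KaimingInit` declares only its Kaiming-specific parameters, while `bias`, `bias_prob`, and `layer` reach `BaseInit` through `**kwargs`. An illustrative call (the values are arbitrary):

```python
from mmcv.cnn.utils.weight_init import KaimingInit

init = KaimingInit(
    a=0.1,                     # negative slope, used with leaky_relu
    mode='fan_in',
    nonlinearity='leaky_relu',
    distribution='uniform',
    bias_prob=0.01,            # forwarded to BaseInit via **kwargs
    layer='Conv2d')
```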
@@ -307,10 +300,15 @@ def init(m):

@INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object):
"""Initialize module by loading a pretrained model
"""Initialize module by loading a pretrained model.

Args:
-checkpoint (str): the file should be load
-prefix (str, optional): the prefix to indicate the sub-module.
+checkpoint (str): the checkpoint file of the pretrained model to be
+loaded.
+prefix (str, optional): the prefix of a sub-module in the pretrained
+model. It is used to load only part of the pretrained model to
+initialize the module. For example, to load only the backbone of a
+detector, set ``prefix='backbone.'``. Defaults to None.
"""

@@ -347,8 +345,8 @@ def _initialize(module, cfg):

def _initialize_override(module, override):
if not isinstance(override, (dict, list)):
-raise TypeError(
-f'override must be a dict or list, but got {type(override)}')
+raise TypeError(f'override must be a dict or a list of dict, \
+but got {type(override)}')

override = [override] if isinstance(override, dict) else override

@@ -366,10 +364,9 @@ def initialize(module, init_cfg):
Args:
module (``torch.nn.Module``): the module will be initialized.
init_cfg (dict | list[dict]): initialization configuration dict to
-define initializer. OpenMMLab has implemented 7 initializers
+define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
-``Kaiming``, ``Pretrained`` and ``BiasProb`` for bias
-initialization.
+``Kaiming``, and ``Pretrained``.

Example:
>>> module = nn.Linear(2, 3, bias=True)
@@ -415,7 +412,8 @@ def initialize(module, init_cfg):
checkpoint=url, prefix='backbone.')
"""
if not isinstance(init_cfg, (dict, list)):
-raise TypeError(f'init_cfg must be a dict, but got {type(init_cfg)}')
+raise TypeError(f'init_cfg must be a dict or a list of dict, \
+but got {type(init_cfg)}')

if isinstance(init_cfg, dict):
init_cfg = [init_cfg]
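
Taken together, `init_cfg` may be a single dict or a list of dicts, and anything else now raises the clarified `TypeError`. A sketch of end-to-end usage with the registered initializer names above:

```python
import torch.nn as nn

from mmcv.cnn.utils.weight_init import initialize

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 2))

# A list of dicts applies a different initializer per layer type.
initialize(model, [
    dict(type='Kaiming', layer='Conv2d'),
    dict(type='Constant', layer='Linear', val=1, bias=2),
])

try:
    initialize(model, 'Kaiming')  # not a dict or a list of dict
except TypeError as e:
    print(e)  # "init_cfg must be a dict or a list of dict, ..."
```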