
[Features] Add logger for initialization of parameters #1150

Merged · 15 commits · Jul 23, 2021
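For orientation: this PR makes `BaseModule.init_weights()` log, for every parameter, which initializer last set its value. A minimal usage sketch follows; the toy module and its `init_cfg` are illustrative rather than taken from this diff, and the printed lines only approximate the strings built by `_get_init_info`:

```python
import torch.nn as nn
from mmcv.runner import BaseModule


class ToyNet(BaseModule):
    """Toy module whose Conv2d layers are initialized via `init_cfg`."""

    def __init__(self):
        super().__init__(
            init_cfg=dict(type='Xavier', layer='Conv2d', distribution='uniform'))
        self.conv = nn.Conv2d(3, 8, 3)


model = ToyNet()
model.init_weights()
# With this change, one line per parameter is logged, roughly:
#   conv.weight - XavierInit: gain=1, distribution=uniform, bias=0
#   conv.bias - XavierInit: gain=1, distribution=uniform, bias=0
```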
52 changes: 52 additions & 0 deletions mmcv/cnn/utils/weight_init.py
@@ -8,6 +8,7 @@
import torch.nn as nn
from torch import Tensor

from mmcv.runner.base_module import update_init_infos
from mmcv.utils import Registry, build_from_cfg, get_logger, print_log

INITIALIZERS = Registry('initializer')
@@ -122,6 +123,10 @@ def __init__(self, *, bias=0, bias_prob=None, layer=None):
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer

def _get_init_info(self):
info = f', bias={self.bias}'
return info


@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
@@ -152,6 +157,12 @@ def init(m):
constant_init(m, self.val, self.bias)

module.apply(init)
if hasattr(module, 'params_init_info'):
update_init_infos(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: val={self.val}'
return info + super()._get_init_info()


@INITIALIZERS.register_module(name='Xavier')
@@ -189,6 +200,13 @@ def init(m):
xavier_init(m, self.gain, self.bias, self.distribution)

module.apply(init)
if hasattr(module, 'params_init_info'):
update_init_infos(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
f'distribution={self.distribution}'
return info + super()._get_init_info()


@INITIALIZERS.register_module(name='Normal')
@@ -225,6 +243,12 @@ def init(m):
normal_init(m, self.mean, self.std, self.bias)

module.apply(init)
if hasattr(module, 'params_init_info'):
update_init_infos(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: mean={self.mean}, std={self.std}'
return info + super()._get_init_info()


@INITIALIZERS.register_module(name='TruncNormal')
@@ -273,6 +297,13 @@ def init(m):
self.bias)

module.apply(init)
if hasattr(module, 'params_init_info'):
update_init_infos(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
f' mean={self.mean}, std={self.std}'
return info + super()._get_init_info()


@INITIALIZERS.register_module(name='Uniform')
@@ -309,6 +340,12 @@ def init(m):
uniform_init(m, self.a, self.b, self.bias)

module.apply(init)
if hasattr(module, 'params_init_info'):
update_init_infos(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, b={self.b}'
return info + super()._get_init_info()


@INITIALIZERS.register_module(name='Kaiming')
@@ -364,6 +401,14 @@ def init(m):
self.bias, self.distribution)

module.apply(init)
if hasattr(module, 'params_init_info'):
update_init_infos(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
f'nonlinearity={self.nonlinearity}, ' \
f'distribution ={self.distribution}'
return info + super()._get_init_info()


@INITIALIZERS.register_module(name='Caffe2Xavier')
@@ -422,6 +467,13 @@ def __call__(self, module):
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)

if hasattr(module, 'params_init_info'):
update_init_infos(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
return info


def _initialize(module, cfg, wholemodule=False):
func = build_from_cfg(cfg, INITIALIZERS)
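Every built-in initializer above now follows the same added pattern: after `module.apply(init)`, it reports what it did via `update_init_infos` whenever the module carries a `params_init_info` record. As a hedged illustration, a hypothetical custom initializer registered against the same `INITIALIZERS` registry could participate in the logging as sketched below; the `ZeroInit` name is invented for the example and its layer matching is simplified compared to the built-ins (which also match base-class names):

```python
from mmcv.cnn import constant_init
from mmcv.cnn.utils.weight_init import INITIALIZERS, BaseInit
from mmcv.runner.base_module import update_init_infos


@INITIALIZERS.register_module(name='Zero')
class ZeroInit(BaseInit):
    """Zero out the weights (and biases) of the matched layer types."""

    def __call__(self, module):

        def init(m):
            # Simplified matching: compare against the class names listed
            # in `self.layer` (None means "apply to every submodule").
            if self.layer is None or m.__class__.__name__ in self.layer:
                constant_init(m, 0, self.bias)

        module.apply(init)
        # The hook every built-in initializer gains in this PR: record
        # which parameters this call actually changed.
        if hasattr(module, 'params_init_info'):
            update_init_infos(module, init_info=self._get_init_info())

    def _get_init_info(self):
        return f'{self.__class__.__name__}: val=0' + super()._get_init_info()
```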
56 changes: 53 additions & 3 deletions mmcv/runner/base_module.py
@@ -1,10 +1,18 @@
# Copyright (c) Open-MMLab. All rights reserved.
import warnings
from abc import ABCMeta
from collections import defaultdict

import torch.nn as nn

from mmcv import ConfigDict
from mmcv.utils.logging import logger_initialized, print_log


def update_init_infos(module, *, init_info):
for param in module.parameters():
if module.params_init_info[param]['tmp_sum_value'] != param.data.sum():
module.params_init_info[param]['init_info'] = init_info
module.params_init_info[param]['tmp_sum_value'] = param.data.sum()


class BaseModule(nn.Module, metaclass=ABCMeta):
@@ -26,6 +34,14 @@ def __init__(self, init_cfg=None):
self._is_init = False
self.init_cfg = init_cfg

# `params_init_info` records the initialization information of the
# parameters: each key is an `nn.Parameter` of the model and each value
# is a dict containing `params_name`, `init_info` and `tmp_sum_value`.
# This attribute is deleted once all parameters have been initialized.
self.params_init_info = defaultdict(dict)

# Backward compatibility in derived classes
# if pretrained is not None:
# warnings.warn('DeprecationWarning: pretrained is a deprecated \
@@ -38,12 +54,33 @@ def is_init(self):

def init_weights(self):
"""Initialize the weights."""
from ..cnn import initialize

# Check whether this is the top-most module
top_most_module = len(self.params_init_info) == 0
if top_most_module:
for name, param in self.named_parameters():
self.params_init_info[param]['params_name'] = name
self.params_init_info[param][
'init_info'] = 'The value is the same as ' \
'the default initialization of PyTorch'
self.params_init_info[param]['tmp_sum_value'] = param.data.sum(
)
# Pass `params_init_info` to all submodules
for sub_module in self.modules():
sub_module.params_init_info = self.params_init_info

loggernames = list(logger_initialized.keys())
loggername = loggernames[0] if len(loggernames) > 0 else 'mmcv'

from ..cnn import initialize
modulename = self.__class__.__name__
if not self._is_init:
if self.init_cfg:
print_log(
f'initialize {modulename} with init_cfg {self.init_cfg}',
logger=loggername)
initialize(self, self.init_cfg)
if isinstance(self.init_cfg, (dict, ConfigDict)):
if isinstance(self.init_cfg, dict):
# Avoid the parameters of the pre-training model
# being overwritten by the init_weights
# of the children.
@@ -53,11 +90,24 @@ def init_weights(self):
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights()
# The user may have overridden `init_weights`
update_init_infos(
m,
init_info='Initialized by user-defined `init_weights`')

self._is_init = True
else:
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')

if top_most_module:
for item in list(self.params_init_info.values()):
print_log(
f"{item['params_name']} - {item['init_info']}",
logger=loggername)
for sub_module in self.modules():
del sub_module.params_init_info

def __repr__(self):
s = super().__repr__()
if self.init_cfg:
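A closing note on how attribution works: `update_init_infos` does not hook the initializers themselves. It infers which parameters were touched by caching each parameter's element sum before any initializer runs (done once by the top-most module in `init_weights`) and attributing a parameter to the most recent initializer whenever its current sum differs from the cached one. A standalone sketch of that idea, independent of mmcv and with illustrative message strings:

```python
from collections import defaultdict

import torch.nn as nn

model = nn.Linear(4, 2)

# Mirror of what the top-most BaseModule does before any initializer runs:
# cache each parameter's name, a default message, and its current sum.
params_init_info = defaultdict(dict)
for name, param in model.named_parameters():
    params_init_info[param]['params_name'] = name
    params_init_info[param]['init_info'] = 'default PyTorch initialization'
    params_init_info[param]['tmp_sum_value'] = param.data.sum()

# Re-initialize only the weight; the bias is left untouched.
nn.init.constant_(model.weight, 0.5)

# Mirror of update_init_infos: a changed sum is taken to mean the most
# recent initializer set this parameter, so message and cache are updated.
for param in model.parameters():
    if params_init_info[param]['tmp_sum_value'] != param.data.sum():
        params_init_info[param]['init_info'] = 'ConstantInit: val=0.5, bias=0'
        params_init_info[param]['tmp_sum_value'] = param.data.sum()

for item in params_init_info.values():
    print(f"{item['params_name']} - {item['init_info']}")
# weight - ConstantInit: val=0.5, bias=0
# bias - default PyTorch initialization
```

One consequence of the sum heuristic, implicit in the code above: an initializer that happens to leave a parameter's element sum unchanged would not be attributed. The PR accepts that edge case in exchange for not threading extra state through every initializer.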