[Features] Add logger for initialization of parameters #1150

Merged: 15 commits, Jul 23, 2021
54 changes: 54 additions & 0 deletions mmcv/cnn/utils/weight_init.py
@@ -8,6 +8,7 @@
import torch.nn as nn
from torch import Tensor

from mmcv.runner.base_module import update_init_info
from mmcv.utils import Registry, build_from_cfg, get_logger, print_log

INITIALIZERS = Registry('initializer')
@@ -122,6 +123,10 @@ def __init__(self, *, bias=0, bias_prob=None, layer=None):
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer

def _get_init_info(self):
info = f'{self.__class__.__name__}, bias={self.bias}'
return info


@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
@@ -152,6 +157,12 @@ def init(m):
constant_init(m, self.val, self.bias)

module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
return info
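
Every initializer touched by this PR follows the same two-step pattern shown above: apply the init function with `module.apply(init)`, then, if the top-level module has set up `_params_init_info` (done by `BaseModule.init_weights`), record a human-readable summary via `update_init_info`. A minimal sketch of the caller's side, assuming `ConstantInit` is exported from `mmcv.cnn` as in mmcv 1.x; the model and values are made up for illustration:

```python
import torch.nn as nn
from mmcv.cnn import ConstantInit

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 4))

# Without `_params_init_info` on the module, the initializer only sets
# the values; with it, it also records "ConstantInit: val=..., bias=..."
# for every parameter whose mean changed.
init = ConstantInit(val=0.5, bias=0.1, layer=['Conv2d', 'Linear'])
init(model)
print(model[0].weight.mean())  # tensor(0.5000)
```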


@INITIALIZERS.register_module(name='Xavier')
@@ -189,6 +200,13 @@ def init(m):
xavier_init(m, self.gain, self.bias, self.distribution)

module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info


@INITIALIZERS.register_module(name='Normal')
@@ -225,6 +243,13 @@ def init(m):
normal_init(m, self.mean, self.std, self.bias)

module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: mean={self.mean},' \
f' std={self.std}, bias={self.bias}'
return info


@INITIALIZERS.register_module(name='TruncNormal')
@@ -273,6 +298,13 @@ def init(m):
self.bias)

module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
f' mean={self.mean}, std={self.std}, bias={self.bias}'
return info


@INITIALIZERS.register_module(name='Uniform')
@@ -309,6 +341,13 @@ def init(m):
uniform_init(m, self.a, self.b, self.bias)

module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a},' \
f' b={self.b}, bias={self.bias}'
return info


@INITIALIZERS.register_module(name='Kaiming')
@@ -364,6 +403,14 @@ def init(m):
self.bias, self.distribution)

module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
f'nonlinearity={self.nonlinearity}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info


@INITIALIZERS.register_module(name='Caffe2Xavier')
@@ -422,6 +469,13 @@ def __call__(self, module):
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)

if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())

def _get_init_info(self):
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
return info
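
These initializers are usually not instantiated by hand; they are built from config dicts through the `INITIALIZERS` registry and applied with `initialize`, which is what `BaseModule.init_weights` does internally. A hedged sketch of that path, assuming `initialize` is exported from `mmcv.cnn`; the layer names and values are illustrative:

```python
import torch.nn as nn
from mmcv.cnn import initialize

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# Each dict is resolved against the INITIALIZERS registry by its `type`
# name ('Kaiming' -> KaimingInit, 'Constant' -> ConstantInit, ...).
init_cfg = [
    dict(type='Kaiming', layer='Conv2d'),
    dict(type='Constant', layer='BatchNorm2d', val=1, bias=0),
]
initialize(model, init_cfg)
```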


def _initialize(module, cfg, wholemodule=False):
func = build_from_cfg(cfg, INITIALIZERS)
146 changes: 132 additions & 14 deletions mmcv/runner/base_module.py
@@ -1,30 +1,74 @@
# Copyright (c) Open-MMLab. All rights reserved.
import copy
import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler

import torch.nn as nn

from mmcv import ConfigDict
from mmcv.runner.dist_utils import master_only
from mmcv.utils.logging import get_logger, logger_initialized, print_log


def update_init_info(module, *, init_info):
"""Update the `_params_init_info` in the module if the value of parameters
are changed.

Args:
module (obj:`nn.Module`): The PyTorch module with a user-defined
attribute `_params_init_info` which records the initialization
information.
init_info (str): The string that describes the initialization.
"""
for param in module.parameters():
mean_value = param.data.mean()
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
module._params_init_info[param]['init_info'] = init_info
module._params_init_info[param]['tmp_mean_value'] = mean_value
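
`update_init_info` uses a simple heuristic: a parameter is considered (re)initialized when its mean changes between two calls. A small illustration of the bookkeeping it expects, assuming `update_init_info` is importable from `mmcv.runner.base_module` as in the hunk above; the layer and the init value are hypothetical:

```python
from collections import defaultdict

import torch.nn as nn
from mmcv.runner.base_module import update_init_info

layer = nn.Linear(4, 4)

# Mimic what BaseModule.init_weights sets up on the top-level module.
layer._params_init_info = defaultdict(dict)
for name, param in layer.named_parameters():
    layer._params_init_info[param]['param_name'] = name
    layer._params_init_info[param]['init_info'] = 'not touched yet'
    layer._params_init_info[param]['tmp_mean_value'] = param.data.mean()

nn.init.constant_(layer.weight, 0.3)  # changes the mean of `weight` only
update_init_info(layer, init_info='constant 0.3')
# Only the entry for `weight` is updated; `bias` keeps 'not touched yet'.
```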


class BaseModule(nn.Module, metaclass=ABCMeta):
"""Base module for all modules in openmmlab."""
"""Base module for all modules in openmmlab.

def __init__(self, init_cfg=None):
"""Initialize BaseModule, inherited from `torch.nn.Module`
``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
functionality for parameter initialization. Compared with
``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.

Args:
init_cfg (dict, optional): Initialization config dict.
"""
- ``init_cfg``: the config to control the initialization.
- ``_params_init_info``: Used to track the parameter
initialization information.
- ``init_weights``: The function that initializes parameters
and records the initialization information.

Args:
init_cfg (dict, optional): Initialization config dict.
"""

def __init__(self, init_cfg=None):
"""Initialize BaseModule, inherited from `torch.nn.Module`"""

# NOTE init_cfg can be defined at different levels, but init_cfg
# at a lower level has a higher priority.

super(BaseModule, self).__init__()
# define default value of init_cfg instead of hard code
# in init_weight() function
# in init_weights() function
self._is_init = False
self.init_cfg = init_cfg

self.init_cfg = copy.deepcopy(init_cfg)

# The `_params_init_info` is used to record the initialization
# information of the parameters.
# The key is an :obj:`nn.Parameter` of the model and the value
# is a dict containing
# - param_name (str): The name of the parameter.
# - init_info (str): The string that describes the initialization.
# - tmp_mean_value (FloatTensor): The mean of the parameter,
# which indicates whether the parameter has been modified.
# This attribute will be deleted after all parameters are initialized.
self._params_init_info = defaultdict(dict)

# Backward compatibility in derived classes
# if pretrained is not None:
@@ -38,26 +82,100 @@ def is_init(self):

def init_weights(self):
"""Initialize the weights."""
from ..cnn import initialize

# check if it is a top-level module
is_top_level_module = len(self._params_init_info) == 0
if is_top_level_module:
# Initialize `_params_init_info`.
# When the `tmp_mean_value` of a parameter is detected to have
# changed, the related initialization information is updated.
for name, param in self.named_parameters():
self._params_init_info[param]['param_name'] = name
self._params_init_info[param][
'init_info'] = f'The value is the same before and ' \
f'after calling `init_weights` ' \
f'of {self.__class__.__name__} '
self._params_init_info[param][
'tmp_mean_value'] = param.data.mean()

# Pass `_params_init_info` to all submodules
# All submodules share the same `params_init_info`,
# so it will be updated when parameters are
# modified at any level of the model.
for sub_module in self.modules():
sub_module._params_init_info = self._params_init_info

# Get an already-initialized logger; if none exists,
# create a logger named `mmcv`
logger_names = list(logger_initialized.keys())
logger_name = logger_names[0] if logger_names else 'mmcv'

from ..cnn import initialize
module_name = self.__class__.__name__
if not self._is_init:
if self.init_cfg:
print_log(
f'initialize {module_name} with init_cfg {self.init_cfg}',
logger=logger_name)
initialize(self, self.init_cfg)
if isinstance(self.init_cfg, (dict, ConfigDict)):
# Avoid the parameters of the pre-training model
# being overwritten by the init_weights
# of the children.
if isinstance(self.init_cfg, dict):
# prevent the parameters of
# the pre-trained model
# from being overwritten by
# the `init_weights`
if self.init_cfg['type'] == 'Pretrained':
return

for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights()
# users may overload the `init_weights`
update_init_info(
m,
init_info=f'Initialized by '
f'user-defined `init_weights`'
f' in {m.__class__.__name__} ')

self._is_init = True
else:
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')

if is_top_level_module:
self._dump_init_info(logger_name)

for sub_module in self.modules():
del sub_module._params_init_info
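
Seen from the outside, the whole mechanism is triggered by one `init_weights()` call on the top-level module. A hedged usage sketch; the model, config, and log file name are made up, and `get_logger` is assumed to accept a `log_file` argument as in mmcv 1.x:

```python
import torch.nn as nn
from mmcv.runner import BaseModule
from mmcv.utils import get_logger

class ToyModel(BaseModule):
    def __init__(self):
        super().__init__(
            init_cfg=dict(type='Normal', layer='Conv2d', mean=0, std=0.01))
        self.conv = nn.Conv2d(3, 8, 3)
        self.fc = nn.Linear(8, 4)

# Registering a logger with a FileHandler first makes _dump_init_info
# write the per-parameter report to that file instead of only printing it.
get_logger('mmcv', log_file='initialization.log')
ToyModel().init_weights()
```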

@master_only
def _dump_init_info(self, logger_name):
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.

Args:
logger_name (str): The name of logger.
"""

logger = get_logger(logger_name)

with_file_handler = False
# dump the information to the logger file if there is a `FileHandler`
for handler in logger.handlers:
if isinstance(handler, FileHandler):
handler.stream.write(
'Name of parameter - Initialization information\n')
for item in list(self._params_init_info.values()):
handler.stream.write(
f"{item['param_name']} - {item['init_info']} \n")
handler.stream.flush()
with_file_handler = True
if not with_file_handler:
for item in list(self._params_init_info.values()):
print_log(
f"{item['param_name']} - {item['init_info']}",
logger=logger_name)
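
With a `FileHandler` attached, the dump is a plain two-column text report of parameter name and initialization info; without one, the same lines go through `print_log`. Roughly what the report would contain for the toy model sketched above (parameter names and values are illustrative):

```python
# Name of parameter - Initialization information
# conv.weight - NormalInit: mean=0, std=0.01, bias=0
# conv.bias - NormalInit: mean=0, std=0.01, bias=0
# fc.weight - The value is the same before and after calling `init_weights` of ToyModel
# fc.bias - The value is the same before and after calling `init_weights` of ToyModel
```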

def __repr__(self):
s = super().__repr__()
if self.init_cfg: