From 47825b194e226d17cfb6ffa4a9f6c8409ac563ee Mon Sep 17 00:00:00 2001
From: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Date: Sat, 10 Apr 2021 21:20:53 +0800
Subject: [PATCH] [Refactoring] Revise init_weight in BaseModule (#905)

* [Refactoring] Add deprecated API warning

* revise test

* fix lint

* fix lint
---
 mmcv/runner/base_module.py           | 6 +++---
 tests/test_runner/test_basemodule.py | 8 +++++++-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py
index 418caa1a19..eea0a8b768 100644
--- a/mmcv/runner/base_module.py
+++ b/mmcv/runner/base_module.py
@@ -42,9 +42,9 @@ def init_weight(self):
         if not self._is_init:
             if hasattr(self, 'init_cfg'):
                 initialize(self, self.init_cfg)
-            for module in self.children():
-                if 'init_weight' in dir(module):
-                    module.init_weight()
+            for m in self.children():
+                if hasattr(m, 'init_weight'):
+                    m.init_weight()
             self._is_init = True
         else:
             warnings.warn(f'init_weight of {self.__class__.__name__} has '
diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py
index b57752b888..807e8acc3c 100644
--- a/tests/test_runner/test_basemodule.py
+++ b/tests/test_runner/test_basemodule.py
@@ -336,10 +336,12 @@ def test_sequential_model_weight_init():
     assert torch.equal(seq_model[1].conv2d.bias,
                        torch.full(seq_model[1].conv2d.bias.shape, 3.))
     # inner init_cfg has highter priority
+    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
     seq_model = Sequential(
         *layers,
         init_cfg=dict(
             type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+    seq_model.init_weight()
     assert torch.equal(seq_model[0].conv1d.weight,
                        torch.full(seq_model[0].conv1d.weight.shape, 0.))
     assert torch.equal(seq_model[0].conv1d.bias,
@@ -371,8 +373,12 @@ def test_modulelist_weight_init():
     assert torch.equal(modellist[1].conv2d.bias,
                        torch.full(modellist[1].conv2d.bias.shape, 3.))
     # inner init_cfg has highter priority
+    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
     modellist = ModuleList(
-        layers, init_cfg=dict(type='Constant', val=4., bias=5.))
+        layers,
+        init_cfg=dict(
+            type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+    modellist.init_weight()
     assert torch.equal(modellist[0].conv1d.weight,
                        torch.full(modellist[0].conv1d.weight.shape, 0.))
     assert torch.equal(modellist[0].conv1d.bias,
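
Usage note (not part of the patch): a minimal sketch of the behaviour the revised tests exercise, assuming an mmcv of this vintage that exports BaseModule and Sequential from mmcv.runner. The ConvBlock class and the constant values are illustrative, modelled on the test components above, not taken from this patch.

import torch
import torch.nn as nn
from mmcv.runner import BaseModule, Sequential


class ConvBlock(BaseModule):
    """Toy component; its own init_cfg should win over the outer one."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv2d = nn.Conv2d(3, 8, 3)


inner0 = dict(type='Constant', layer=['Conv2d'], val=0., bias=1.)
inner1 = dict(type='Constant', layer=['Conv2d'], val=2., bias=3.)
outer = dict(type='Constant', layer=['Conv2d'], val=4., bias=5.)

# The outer init_cfg is applied first via initialize(self, self.init_cfg);
# each child then runs its own init_weight(), so the inner cfgs take
# priority, matching the assertions in the revised tests.
model = Sequential(ConvBlock(inner0), ConvBlock(inner1), init_cfg=outer)

# As the added test lines show, nothing is initialized until this call.
model.init_weight()

assert torch.equal(model[0].conv2d.weight,
                   torch.full(model[0].conv2d.weight.shape, 0.))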