Skip to content

Commit

Permalink
[Enhance] Support skipping initialization in BaseModule (#1263)
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE committed Jul 25, 2023
1 parent 6187595 commit 3871881
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
7 changes: 6 additions & 1 deletion mmengine/model/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def __init__(self, init_cfg: Union[dict, List[dict], None] = None):
def is_init(self):
return self._is_init

@is_init.setter
def is_init(self, value):
self._is_init = value

def init_weights(self):
"""Initialize the weights."""

Expand Down Expand Up @@ -127,7 +131,8 @@ def init_weights(self):
for m in self.children():
if is_model_wrapper(m) and not hasattr(m, 'init_weights'):
m = m.module
if hasattr(m, 'init_weights'):
if hasattr(m, 'init_weights') and not getattr(
m, 'is_init', False):
m.init_weights()
# users may overload the `init_weights`
update_init_info(
Expand Down
26 changes: 25 additions & 1 deletion tests/test_model/test_base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import patch
from unittest.mock import Mock, patch

import torch
from torch import nn
Expand Down Expand Up @@ -238,6 +238,30 @@ def __init__(self, module) -> None:
self.assertTrue((model.ddp.module.linear.weight == 1).all())
self.assertTrue((model.ddp.module.linear.bias == 2).all())

# Test submodule.init_weights will be skipped if `is_init` is set
# to True in root model
model: FooModel = FOOMODELS.build(copy.deepcopy(self.model_cfg))
for child in model.children():
child.init_weights = Mock()
model.is_init = True
model.init_weights()
for child in model.children():
child.init_weights.assert_not_called()

# Test submodule.init_weights will be skipped if submodule's `is_init`
# is set to True
model: FooModel = FOOMODELS.build(copy.deepcopy(self.model_cfg))
for child in model.children():
child.init_weights = Mock()
model.component1.is_init = True
model.reg.is_init = True
model.init_weights()
model.component1.init_weights.assert_not_called()
model.component2.init_weights.assert_called_once()
model.component3.init_weights.assert_called_once()
model.component4.init_weights.assert_called_once()
model.reg.init_weights.assert_not_called()

def test_dump_init_info(self):
import os
import shutil
Expand Down

0 comments on commit 3871881

Please sign in to comment.