Skip to content

Commit

Permalink
test DefaultOptimWrapperConstructor when the params in shared modules…
Browse files Browse the repository at this point in the history
… do not require grad
  • Loading branch information
HIT-cwh committed Feb 3, 2023
1 parent bad1c33 commit 18f417d
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion tests/test_optim/test_optimizer/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self):

class ExampleDuplicateModel(nn.Module):

def __init__(self):
def __init__(self, duplicate_model_require_grad: bool = True):
super().__init__()
self.param1 = nn.Parameter(torch.ones(1))
self.conv1 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False))
Expand All @@ -51,6 +51,9 @@ def __init__(self):
self.sub = SubModel()
self.conv3 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False))
self.conv3[0] = self.conv1[0]
if not duplicate_model_require_grad:
self.conv1[0].requires_grad_(False)
self.conv3[0].requires_grad_(False)
if MMCV_FULL_AVAILABLE:
from mmcv.ops import DeformConv2dPack
self.dcn = DeformConv2dPack(
Expand Down Expand Up @@ -576,6 +579,30 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
self._check_sgd_optimizer(optim_wrapper.optimizer, model,
**paramwise_cfg)

model = ExampleDuplicateModel(duplicate_model_require_grad=False)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1,
flat_decay_mult=0.3,
bypass_duplicate=True)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)

self.assertWarnsRegex(
Warning,
'conv3.0 is duplicate. It is skipped since bypass_duplicate=True',
lambda: optim_constructor(model))
optim_wrapper = optim_constructor(model)
model_parameters = list(model.parameters())
num_params = 14 if MMCV_FULL_AVAILABLE else 11
assert len(optim_wrapper.optimizer.param_groups) == len(
model_parameters) == num_params
self._check_sgd_optimizer(optim_wrapper.optimizer, model,
**paramwise_cfg)

def test_default_optimizer_constructor_custom_key(self):
# test DefaultOptimWrapperConstructor with custom_keys and
# ExampleModel
Expand Down

0 comments on commit 18f417d

Please sign in to comment.