
[Bug] Fix the bug when the params in shared modules do not require grad #903

Merged
merged 2 commits into open-mmlab:main from fix_DefaultOptimWrapperConstructor
Feb 15, 2023

Conversation

HIT-cwh
Contributor

@HIT-cwh HIT-cwh commented Feb 3, 2023

How to reproduce this bug

from collections import OrderedDict

import torch.nn as nn
from torch.optim import Optimizer

from mmcv.cnn import ConvModule
from mmengine import Config, ConfigDict
from mmengine.optim import (OptimWrapper, OptimWrapperDict,
                            build_optim_wrapper)
from mmengine.registry import OPTIM_WRAPPERS


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        toy_modules = nn.ModuleList()
        toy_modules.append(ConvModule(1, 1, 1))
        toy_modules.append(ConvModule(1, 1, 1))
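        # share the first conv with the second module, so its weight and bias
        # are visited twice when traversing named_parameters()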
        toy_modules[-1].conv = toy_modules[0].conv
        self.toy_modules = toy_modules

    def forward(self, x):
        for module in self.toy_modules:
            x = module(x)
        return x

# modified from https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py#L950-L1097
def my_build_optim_wrapper(optim_wrapper, model):
    if isinstance(optim_wrapper, OptimWrapper):
        return optim_wrapper
    if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
        optimizer = optim_wrapper.get('optimizer', None)

        if isinstance(optimizer, Optimizer):
            optim_wrapper.setdefault('type', 'OptimWrapper')
            return OPTIM_WRAPPERS.build(optim_wrapper)  # type: ignore

        if optimizer is not None or 'constructor' in optim_wrapper:
            return build_optim_wrapper(model, optim_wrapper)
        else:
            optim_wrappers = OrderedDict()
            for name, optim in optim_wrapper.items():
                if not isinstance(optim, OptimWrapper):
                    raise ValueError(
                        'each item must be an optimizer object when '
                        '"type" and "constructor" are not in '
                        f'optimizer, but got {name}={optim}')
                optim_wrappers[name] = optim
            return OptimWrapperDict(**optim_wrappers)
    else:
        raise TypeError('optimizer wrapper should be an OptimWrapper '
                        f'object or dict, but got {optim_wrapper}')

model = Net()
for name, param in model.named_parameters():
    param.requires_grad_(False)

optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=1., weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
my_build_optim_wrapper(optim_wrapper, model)

Error message: ValueError: some parameters appear in more than one parameter group
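The error itself is raised by torch.optim.Optimizer.add_param_group, which rejects the same parameter object appearing in more than one parameter group. A minimal sketch in plain PyTorch (independent of mmengine) that triggers the same ValueError:

import torch
import torch.nn as nn
from torch.optim import AdamW

# a frozen parameter placed in two groups, mimicking the shared conv above
shared = nn.Parameter(torch.zeros(1, 1, 1, 1), requires_grad=False)
AdamW([{'params': [shared]}, {'params': [shared]}], lr=1.)
# ValueError: some parameters appear in more than one parameter group

In the reproduce script, the shared conv's weight and bias are frozen, so the constructor appends them to params on every visit before the bypass_duplicate check can skip them.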

Modification

The only change is in mmengine/optim/optimizer/default_constructor.py, from

if not param.requires_grad:
    params.append(param_group)
    continue
if bypass_duplicate and self._is_in(param_group, params):
    warnings.warn(f'{prefix} is duplicate. It is skipped since '
                  f'bypass_duplicate={bypass_duplicate}')
    continue

to

if bypass_duplicate and self._is_in(param_group, params):
    warnings.warn(f'{prefix} is duplicate. It is skipped since '
                  f'bypass_duplicate={bypass_duplicate}')
    continue
if not param.requires_grad:
    params.append(param_group)
    continue
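
With this order, the bypass_duplicate check runs before the requires_grad shortcut, so a shared parameter is appended only once whether or not it is frozen. A standalone sketch of the patched traversal (hypothetical function name, not the actual DefaultOptimWrapperConstructor.add_params code):

def collect_params(module, params, bypass_duplicate=True):
    # `params` accumulates one param_group dict per unique parameter
    for name, param in module.named_parameters(recurse=False):
        param_group = {'params': [param]}
        # 1) skip parameters already collected, e.g. a conv shared by two modules
        if bypass_duplicate and any(param is p for g in params for p in g['params']):
            continue
        # 2) frozen parameters are kept in a plain group without paramwise options
        if not param.requires_grad:
            params.append(param_group)
            continue
        # 3) otherwise apply paramwise_cfg (decay multipliers, ...) and append
        params.append(param_group)

With the patched constructor, the reproduce script above builds the AdamW optim wrapper without raising.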

@CLAassistant

CLAassistant commented Feb 3, 2023

CLA assistant check
All committers have signed the CLA.

@codecov

codecov bot commented Feb 3, 2023

Codecov Report

❗ No coverage uploaded for pull request base (main@6dc1d70).
Patch has no changes to coverable lines.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #903   +/-   ##
=======================================
  Coverage        ?   78.12%           
=======================================
  Files           ?      132           
  Lines           ?    10031           
  Branches        ?     2004           
=======================================
  Hits            ?     7837           
  Misses          ?     1853           
  Partials        ?      341           
Flag Coverage Δ
unittests 78.12% <0.00%> (?)




@@ -576,6 +579,30 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
        self._check_sgd_optimizer(optim_wrapper.optimizer, model,
                                  **paramwise_cfg)

model = ExampleDuplicateModel(duplicate_model_require_grad=False)
Collaborator

@HAOCHENYE HAOCHENYE Feb 3, 2023


Suggested change
model = ExampleDuplicateModel(duplicate_model_require_grad=False)
# `DefaultOptimWrapperConstructor` can build an optimizer when the model has duplicated and non-grad parameters.
model = ExampleDuplicateModel(duplicate_model_require_grad=False)

@HIT-cwh HIT-cwh force-pushed the fix_DefaultOptimWrapperConstructor branch from 18f417d to e807e8b Compare February 3, 2023 10:22
@zhouzaida zhouzaida merged commit a5f48f7 into open-mmlab:main Feb 15, 2023