
Commit 4147e97

[Fix] ZeroRedundancyOptimizer ambiguous error with param groups when pytorch < 1.12.0 (#818)

* fix zero_optimizer error with param groups when pytorch < 1.12.0

* add docstring

* fix docstring

* add unittest

* change ut to use a valid paramwise_cfg

* modify ut

* fix as comments
C1rN09 committed Dec 19, 2022
1 parent 89477b5 commit 4147e97
Showing 2 changed files with 56 additions and 21 deletions.
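
For context: building `ZeroRedundancyOptimizer` through an optim wrapper config that also sets `paramwise_cfg` makes the constructor hand over dict-style param groups instead of bare tensors, and torch's `ZeroRedundancyOptimizer` only accepts those from 1.12.0 onward; on older PyTorch this previously surfaced as an ambiguous error. A rough sketch of such a config (keys mirror the unit test added below; the `custom_keys` content is illustrative):

```python
# Sketch of a config that now triggers the explicit assertion on PyTorch < 1.12.
optim_wrapper = dict(
    optimizer=dict(
        type='ZeroRedundancyOptimizer',  # mmengine wrapper around torch ZeRO
        optimizer_type='SGD',
        lr=0.01,
        momentum=0.0001,
        weight_decay=0.9),
    # Any paramwise_cfg makes the constructor emit dict param groups,
    # which torch's ZeroRedundancyOptimizer accepts only from 1.12.0.
    paramwise_cfg=dict(custom_keys={'conv1': dict(lr_mult=0.1)}))
```
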
mmengine/optim/optimizer/zero_optimizer.py (13 additions, 0 deletions)
@@ -39,6 +39,10 @@ class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
     Warnings:
         ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.
 
+    Warnings:
+        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.12 to enable param
+        groups.
+
     Args:
         params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
             or :class:`dict` s giving all parameters, which will be sharded
@@ -53,6 +57,15 @@ def __init__(self, params, optimizer_type: str, **kwargs):
                 '`torch.distributed.optim.ZeroReundancyOptimizer` is only '
                 'available when pytorch version >= 1.8.')
         assert is_available(), 'torch.distributed.rpc is not available.'
+        # Avoid the generator becoming empty after the following check
+        params = list(params)
+        assert (
+            all(isinstance(p, torch.Tensor) for p in params)
+            or digit_version(TORCH_VERSION) >= digit_version('1.12.0')), (
+                'PyTorch ZeroRedundancyOptimizer started to support param '
+                'groups since 1.12.0. Please update your pytorch version to '
+                'enable this feature, or disable param groups by deleting '
+                '`paramwise_cfg` field in config file.')
         optimizer_class = getattr(torch.optim, optimizer_type)
         # TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
         # Currently only `overlap_with_ddp=False` is supported. For more
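
The guard above materializes `params` before inspecting it because it may arrive as a generator (e.g. `model.parameters()`); scanning it inside `all(...)` would consume it and leave nothing for the optimizer to shard. Below is a minimal standalone sketch of that behaviour, with a hypothetical helper name and a boolean standing in for the `digit_version` check:

```python
import torch


def materialize_and_check(params, supports_param_groups):
    """Hypothetical stand-in for the guard added in this commit."""
    # Turn a possible generator into a list first; otherwise the all(...)
    # scan below would exhaust it and nothing would reach the optimizer.
    params = list(params)
    assert (all(isinstance(p, torch.Tensor) for p in params)
            or supports_param_groups), (
        'param groups need PyTorch >= 1.12; upgrade, or drop `paramwise_cfg` '
        'so that only plain tensors are passed')
    return params


model = torch.nn.Linear(2, 2)
# Plain tensors are accepted regardless of the version flag.
print(len(materialize_and_check(model.parameters(), False)))  # 2
# Dict-style param groups are rejected when the flag mimics torch < 1.12.
try:
    materialize_and_check([{'params': [model.weight], 'lr': 0.1}], False)
except AssertionError as err:
    print('rejected:', err)
```
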
tests/test_optim/test_optimizer/test_optimizer.py (43 additions, 21 deletions)
@@ -9,6 +9,7 @@
 import torch.nn as nn
 from torch.distributed.rpc import is_available
 
+from mmengine.dist import get_rank
 from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                             DefaultOptimWrapperConstructor, OptimWrapper,
                             build_optim_wrapper)
@@ -740,28 +741,23 @@ def _check_default_optimizer(self, optimizer, model):
         self.assertEqual(optimizer.defaults['lr'], self.base_lr)
         self.assertEqual(optimizer.defaults['momentum'], self.momentum)
         self.assertEqual(optimizer.defaults['weight_decay'], self.base_wd)
-        param_groups = optimizer.param_groups[0]
-        if MMCV_FULL_AVAILABLE:
-            param_names = [
-                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
-                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
-                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
-                'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
-            ]
+        param_groups = optimizer.param_groups
+        params_set = set(model.parameters())
+        self.assertEqual(
+            sum(len(param_group['params']) for param_group in param_groups),
+            len(params_set))
+        self.assertTrue(
+            all(param in params_set for param_group in param_groups
+                for param in param_group['params']))
+        state_dict = optimizer.state_dict()
+        if get_rank() == 0:
+            self.assertEqual(
+                sum(len(pg['params']) for pg in state_dict['param_groups']),
+                len(params_set))
         else:
-            param_names = [
-                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
-                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
-                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
-            ]
-        param_dict = dict(model.named_parameters())
-        self.assertEqual(len(param_groups['params']), len(param_names))
-        for i in range(len(param_groups['params'])):
-            assert torch.equal(param_groups['params'][i],
-                               param_dict[param_names[i]])
+            self.assertEqual(state_dict, {})
 
-    def test_build_zero_redundancy_optimizer(self):
-        from torch.distributed.optim import ZeroRedundancyOptimizer
+    def test_zero_redundancy_optimizer(self):
         self._init_dist_env(self.rank, self.world_size)
         model = ExampleModel()
         self.base_lr = 0.01
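
The reworked `_check_default_optimizer` no longer compares against a hard-coded list of parameter names. Instead it asserts that the param groups jointly cover every model parameter and that the consolidated optimizer state dict is populated only on rank 0 (and empty elsewhere), which is what the test expects from mmengine's ZeRO wrapper. A single-process sketch of the rank-0 behaviour against torch's own `ZeroRedundancyOptimizer` (gloo backend; address and port values are arbitrary):

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

# One-process "distributed" setup so the sharded optimizer can be built.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29510')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
optim = ZeroRedundancyOptimizer(
    model.parameters(), optimizer_class=torch.optim.SGD, lr=0.01)
model(torch.randn(2, 4)).sum().backward()
optim.step()

# Shards must be gathered onto rank 0 before a full state dict exists; the
# test above relies on the mmengine wrapper doing this step by itself.
optim.consolidate_state_dict()
state_dict = optim.state_dict()
print(len(state_dict['param_groups'][0]['params']))  # 2: weight and bias

dist.destroy_process_group()
```
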
@@ -777,7 +773,6 @@ def test_build_zero_redundancy_optimizer(self):
                 weight_decay=self.base_wd,
                 momentum=self.momentum))
         optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
-        self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
         self._check_default_optimizer(optim_wrapper.optimizer, model)
 
         # test build optimizer without ``optimizer_type``
@@ -790,6 +785,33 @@ def test_build_zero_redundancy_optimizer(self):
                 momentum=self.momentum))
         optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
 
+    @unittest.skipIf(
+        digit_version(TORCH_VERSION) < digit_version('1.12.0'),
+        reason='ZeRO started to support param groups since pytorch 1.12.0')
+    def test_zero_redundancy_optimizer_with_paramwise_cfg(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ExampleModel()
+        self.base_lr = 0.01
+        self.momentum = 0.0001
+        self.base_wd = 0.9
+
+        # test build function
+        paramwise_cfg = dict(
+            custom_keys={
+                'conv1': dict(lr_mult=0.0, decay_mult=0.0),
+                'conv2': dict(lr_mult=1.0, decay_mult=2.0)
+            })
+        optim_wrapper_cfg = dict(
+            optimizer=dict(
+                type='ZeroRedundancyOptimizer',
+                optimizer_type='SGD',
+                lr=self.base_lr,
+                weight_decay=self.base_wd,
+                momentum=self.momentum),
+            paramwise_cfg=paramwise_cfg)
+        optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
+        self._check_default_optimizer(optim_wrapper.optimizer, model)
+
     def _init_dist_env(self, rank, world_size):
         """Initialize the distributed environment."""
         os.environ['MASTER_ADDR'] = '127.0.0.1'
