Skip to content

Commit

Permalink
rename DSOptimWrapper to DeepSpeedOptimWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida committed Jun 20, 2023
1 parent 9088abc commit 8436d00
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/distributed_training_with_flexible_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def main():
contiguous_gradients=True,
cpu_offload=False))
optim_wrapper = dict(
type='DSOptimWrapper',
type='DeepSpeedOptimWrapper',
optimizer=dict(type=SGD, lr=0.001, momentum=0.9))
else:
strategy = None
Expand Down
2 changes: 2 additions & 0 deletions mmengine/_strategy/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class DeepSpeedStrategy(BaseStrategy):
https://www.deepspeed.ai/docs/config-json/.
Args:
config (str or dict, optional): If it is a string, it is the path to a
DeepSpeed config file to load. Defaults to None.
zero_optimization (dict, optional): Enabling and configuring ZeRO
memory optimizations. Defaults to None.
gradient_clipping (float): Enable gradient clipping with value.
Expand Down
4 changes: 2 additions & 2 deletions mmengine/model/wrappers/_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from deepspeed.runtime.engine import DeepSpeedEngine

from mmengine.optim.optimizer._deepspeed import DSOptimWrapper
from mmengine.optim.optimizer._deepspeed import DeepSpeedOptimWrapper
from mmengine.registry import MODEL_WRAPPERS


Expand All @@ -29,7 +29,7 @@ def __getattr__(self, name):
def train_step(
self,
data: Union[dict, tuple, list],
optim_wrapper: DSOptimWrapper,
optim_wrapper: DeepSpeedOptimWrapper,
) -> Dict[str, torch.Tensor]:
data = self.model.module.data_preprocessor(data, training=True)
data = self._cast_inputs_half(data)
Expand Down
4 changes: 2 additions & 2 deletions mmengine/optim/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
]

if is_installed('deepspeed'):
from ._deepspeed import DSOptimWrapper # noqa:F401
__all__.append('DSOptimWrapper')
from ._deepspeed import DeepSpeedOptimWrapper # noqa:F401
__all__.append('DeepSpeedOptimWrapper')
2 changes: 1 addition & 1 deletion mmengine/optim/optimizer/_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@OPTIM_WRAPPERS.register_module()
class DSOptimWrapper(BaseOptimWrapper):
class DeepSpeedOptimWrapper(BaseOptimWrapper):

def __init__(self, optimizer):
self.optimizer = optimizer
Expand Down

0 comments on commit 8436d00

Please sign in to comment.