Commit

schedulefree optimizers
winglian committed May 31, 2024
1 parent 96eb062 commit 9e654bb
Showing 5 changed files with 44 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -161,6 +161,7 @@
"sacremoses",
"safetensors>=0.4.1",
"sagemaker>=2.31.0",
"schedulefree==1.2.1",
"scikit-learn",
"scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`)
"sentencepiece>=0.1.91,!=0.1.92",
35 changes: 35 additions & 0 deletions src/transformers/trainer.py
@@ -151,6 +151,7 @@
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_schedulefree_available,
    is_in_notebook,
    is_ipex_available,
    is_lomo_available,
@@ -1409,6 +1410,33 @@ def optimizer_hook(param):
                optimizer_cls = Lomo

            optimizer_kwargs.update({"model": model})
        elif args.optim in [
            OptimizerNames.SCHEDULE_FREE_ADAMW_32BIT,
            OptimizerNames.SCHEDULE_FREE_SGD_32BIT,
        ]:
            if not is_schedulefree_available():
                raise ImportError(
                    "You need to install `schedulefree` in order to use schedulefree optimizers"
                    " install it with `pip install schedulefree`"
                )
            from schedulefree import AdamWScheduleFree, SGDScheduleFree
            additional_optim_kwargs = {}
            if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW_32BIT:
                optimizer_cls = AdamWScheduleFree
                additional_optim_kwargs = adam_kwargs
            elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD_32BIT:
                optimizer_cls = SGDScheduleFree
            else:
                raise ValueError("Invalid schedulefree optimizer")
            additional_optim_kwargs["weight_decay"] = args.weight_decay
            additional_optim_kwargs["warmup_steps"] = args.warmup_steps
            additional_optim_kwargs.update(
                {
                    "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
                    "r": float(optim_args.get("r", 0.0)),
                }
            )
            optimizer_kwargs.update(additional_optim_kwargs)
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
@@ -3266,6 +3294,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
if "ScheduleFree" in self.optimizer.__class__.__name__:
self.optimizer.eval()

        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
@@ -3767,6 +3798,8 @@ def evaluation_loop(
logger.info(f" Batch size = {batch_size}")

model.eval()
if "ScheduleFree" in self.optimizer.__class__.__name__:
self.optimizer.eval()

self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping.
@@ -4367,6 +4400,8 @@ def prediction_loop(
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()
        if "ScheduleFree" in self.optimizer.__class__.__name__:
            self.optimizer.eval()

        if args.past_index >= 0:
            self._past = None
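The optimizer-mode hooks added to training_step, evaluation_loop, and prediction_loop exist because schedule-free optimizers maintain an averaged parameter set and must be switched between train and eval modes along with the model. A minimal sketch of the train/eval protocol the schedulefree package expects (illustration only, not part of this commit; the tiny linear model and random batch are placeholders):

# Minimal sketch of the schedulefree train/eval protocol (illustration only).
import torch
from schedulefree import AdamWScheduleFree

model = torch.nn.Linear(16, 2)
optimizer = AdamWScheduleFree(model.parameters(), lr=1e-3, warmup_steps=100)

model.train()
optimizer.train()  # the optimizer must be in train mode while stepping
for _ in range(100):
    x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

model.eval()
optimizer.eval()  # swap in the averaged weights before evaluation or checkpointing

The Trainer hunks above apply the same toggling, keyed off the optimizer class name, so evaluation and prediction always run against the averaged weights.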
2 changes: 2 additions & 0 deletions src/transformers/training_args.py
@@ -173,6 +173,8 @@ class OptimizerNames(ExplicitEnum):
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    LOMO = "lomo"
    ADALOMO = "adalomo"
    SCHEDULE_FREE_ADAMW_32BIT = "schedule_free_adamw_32bit"
    SCHEDULE_FREE_SGD_32BIT = "schedule_free_sgd_32bit"


# Sometimes users will pass in a `str` repr of a dict in the CLI
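With these enum values registered, the new optimizers are reachable through the existing `optim`/`optim_args` plumbing in `TrainingArguments`. A rough usage sketch (the model, dataset, and output path are placeholders, and pairing with a constant schedule is only a convention, since schedule-free optimizers supply their own schedule):

# Hypothetical usage sketch; my_model and my_dataset are placeholders, not part of the commit.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    optim="schedule_free_adamw_32bit",       # string value defined in OptimizerNames above
    optim_args="weight_lr_power=2.0,r=0.0",  # parsed into optim_args by the Trainer
    learning_rate=1e-3,
    warmup_steps=100,                        # forwarded to the schedulefree optimizer
    lr_scheduler_type="constant",
)
trainer = Trainer(model=my_model, args=args, train_dataset=my_dataset)
trainer.train()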
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -166,6 +166,7 @@
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_schedulefree_available,
    is_scipy_available,
    is_sentencepiece_available,
    is_seqio_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -100,6 +100,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_eetq_available = _is_package_available("eetq")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_schedulefree_available = _is_package_available("schedulefree")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -345,6 +346,10 @@ def is_lomo_available():
    return _lomo_available


def is_schedulefree_available():
    return _schedulefree_available


def is_pyctcdecode_available():
    return _pyctcdecode_available

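The new `is_schedulefree_available` helper follows the same lazy availability-check pattern as the other optional backends: probe the package once at import time, then expose a cheap accessor. A simplified sketch of that pattern (`_package_exists` is a hypothetical stand-in; the real `_is_package_available` also handles version lookup and other edge cases):

# Simplified sketch of the availability-check pattern used above.
import importlib.metadata
import importlib.util


def _package_exists(pkg_name: str) -> bool:
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        importlib.metadata.version(pkg_name)  # confirm installed distribution metadata
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


_schedulefree_available = _package_exists("schedulefree")


def is_schedulefree_available():
    return _schedulefree_available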
