schedulefree optimizers (#30079)
* schedulefree optimizers

* fix train instead of eval for optimizer

* fixes and update docs

* chore: lint

* add tests and drop overly-verbose _32bit suffix

* chore: lint

* fix for docs

* fix code review issues

* use duck-typing to avoid per-optimizer patches

* fixup style

* fixup style

* warn if incorrect accelerate version with schedule free

Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
winglian and tmm1 committed Sep 9, 2024
1 parent 60226fd commit 62aecd8
Showing 9 changed files with 124 additions and 0 deletions.
45 changes: 45 additions & 0 deletions docs/source/en/trainer.md
@@ -518,6 +518,51 @@ trainer.train()

This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.

## Schedule Free Optimizer

The schedule-free optimizers were introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
Schedule-free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, removing the need to anneal the learning rate with a traditional schedule.
The supported schedule-free optimizers are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install the `schedulefree` package from PyPI: `pip install schedulefree`.

Below is a simple script demonstrating how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset in full precision:

```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-schedulefree",
    max_steps=1000,
    per_device_train_batch_size=4,
    optim="schedule_free_adamw",
    gradient_checkpointing=True,
    logging_strategy="steps",
    logging_steps=1,
    learning_rate=2e-6,
    save_strategy="no",
    run_name="sfo-imdb",
)

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=1024,
)

trainer.train()
```
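
A schedule-free optimizer itself has to be switched between training and evaluation modes, not just the model; the [`Trainer`] does this automatically by duck-typing the optional `optimizer.train()` / `optimizer.eval()` methods (see the `trainer.py` changes below). As a rough sketch of what this means when driving the optimizer by hand (a toy model, not part of this commit):

```python
import torch
from schedulefree import AdamWScheduleFree

# Toy setup purely for illustration.
model = torch.nn.Linear(10, 2)
optimizer = AdamWScheduleFree(model.parameters(), lr=2e-6, warmup_steps=100)

model.train()
optimizer.train()  # put the schedule-free optimizer into training mode
for _ in range(10):
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

model.eval()
optimizer.eval()  # switch parameters to the averaged point before evaluating or saving
```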

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
1 change: 1 addition & 0 deletions setup.py
@@ -163,6 +163,7 @@
"sacremoses",
"safetensors>=0.4.1",
"sagemaker>=2.31.0",
"schedulefree>=1.2.6",
"scikit-learn",
"scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`)
"sentencepiece>=0.1.91,!=0.1.92",
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
@@ -69,6 +69,7 @@
"sacremoses": "sacremoses",
"safetensors": "safetensors>=0.4.1",
"sagemaker": "sagemaker>=2.31.0",
"schedulefree": "schedulefree>=1.2.6",
"scikit-learn": "scikit-learn",
"scipy": "scipy<1.13.0",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -103,6 +103,7 @@
    is_rjieba_available,
    is_sacremoses_available,
    is_safetensors_available,
    is_schedulefree_available,
    is_scipy_available,
    is_sentencepiece_available,
    is_seqio_available,
@@ -370,6 +371,14 @@ def require_grokadamw(test_case):
    return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)


def require_schedulefree(test_case):
    """
    Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed.
    https://github.com/facebookresearch/schedule_free
    """
    return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case)
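
A minimal usage sketch (hypothetical test, not part of the diff) showing how the decorator is applied; the real test added in `tests/trainer/test_trainer.py` below follows the same pattern:

```python
import unittest

from transformers.testing_utils import require_schedulefree


class ScheduleFreeDependencyTest(unittest.TestCase):
    @require_schedulefree
    def test_schedulefree_importable(self):
        # Only runs when the schedulefree package is installed; otherwise the test is skipped.
        import schedulefree

        self.assertTrue(hasattr(schedulefree, "AdamWScheduleFree"))
```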


def require_cv2(test_case):
    """
    Decorator marking a test that requires OpenCV.
38 changes: 38 additions & 0 deletions src/transformers/trainer.py
@@ -161,6 +161,7 @@
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_schedulefree_available,
    is_torch_compile_available,
    is_torch_mlu_available,
    is_torch_mps_available,
@@ -1488,6 +1489,36 @@ def optimizer_hook(param):

            optimizer_cls = AdamW4bit
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim in [
            OptimizerNames.SCHEDULE_FREE_ADAMW,
            OptimizerNames.SCHEDULE_FREE_SGD,
        ]:
            if not is_schedulefree_available():
                raise ImportError(
                    "You need to install `schedulefree` in order to use schedulefree optimizers."
                    " Install it with `pip install schedulefree`."
                )
            if not is_accelerate_available("0.30.0"):
                raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
            from schedulefree import AdamWScheduleFree, SGDScheduleFree

            additional_optim_kwargs = {}
            if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
                optimizer_cls = AdamWScheduleFree
                additional_optim_kwargs = adam_kwargs
            elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
                optimizer_cls = SGDScheduleFree
            else:
                raise ValueError("Invalid schedulefree optimizer")
            additional_optim_kwargs["weight_decay"] = args.weight_decay
            additional_optim_kwargs["warmup_steps"] = args.warmup_steps
            additional_optim_kwargs.update(
                {
                    "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
                    "r": float(optim_args.get("r", 0.0)),
                }
            )
            optimizer_kwargs.update(additional_optim_kwargs)
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
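
The schedule-free-specific hyperparameters (`weight_lr_power` and `r`) are read from `TrainingArguments.optim_args` as comma-separated `key=value` pairs, while `weight_decay` and `warmup_steps` are forwarded from the standard training arguments. A hedged usage sketch (values are illustrative only, not part of the diff):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    optim="schedule_free_adamw",
    optim_args="weight_lr_power=2.0,r=0.0",  # optional; these are the defaults used above
    weight_decay=0.01,             # forwarded to the optimizer
    warmup_steps=100,              # consumed by the optimizer instead of an LR schedule
    lr_scheduler_type="constant",  # illustrative: schedule-free removes the need for an LR schedule
)
```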
@@ -3410,6 +3441,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
@@ -3960,6 +3994,8 @@ def evaluation_loop(
logger.info(f" Batch size = {batch_size}")

model.eval()
if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
self.optimizer.eval()

self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping.
@@ -4573,6 +4609,8 @@ def prediction_loop(
        inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()
        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
            self.optimizer.eval()

        if args.past_index >= 0:
            self._past = None
2 changes: 2 additions & 0 deletions src/transformers/training_args.py
@@ -178,6 +178,8 @@ class OptimizerNames(ExplicitEnum):
    LOMO = "lomo"
    ADALOMO = "adalomo"
    GROKADAMW = "grokadamw"
    SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
    SCHEDULE_FREE_SGD = "schedule_free_sgd"


# Sometimes users will pass in a `str` repr of a dict in the CLI
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -175,6 +175,7 @@
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_schedulefree_available,
    is_scipy_available,
    is_sentencepiece_available,
    is_seqio_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -103,6 +103,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_grokadamw_available = _is_package_available("grokadamw")
_schedulefree_available = _is_package_available("schedulefree")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -364,6 +365,10 @@ def is_grokadamw_available():
    return _grokadamw_available


def is_schedulefree_available():
    return _schedulefree_available
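
A small sketch (not part of the diff) of how such an availability check typically guards the optional import:

```python
from transformers.utils import is_schedulefree_available

if is_schedulefree_available():
    from schedulefree import AdamWScheduleFree  # optional dependency is present
else:
    AdamWScheduleFree = None  # callers should raise a helpful ImportError instead
```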


def is_pyctcdecode_available():
    return _pyctcdecode_available

22 changes: 22 additions & 0 deletions tests/trainer/test_trainer.py
@@ -70,6 +70,7 @@
    require_peft,
    require_ray,
    require_safetensors,
    require_schedulefree,
    require_sentencepiece,
    require_sigopt,
    require_tensorboard,
@@ -1442,6 +1443,27 @@ def test_grokadamw():
            # Check this works
            _ = trainer.train()

    @require_schedulefree
    @require_torch_gpu
    def test_schedulefree_adam(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)
        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(
                tmpdir,
                learning_rate=1e-9,
                logging_steps=5,
                optim="schedule_free_adamw",
            )
            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

            # Check this works
            _ = trainer.train()

    def test_galore_matched_modules(self):
        regex_patterns = [r".*.attn.*", r".*.mlp.*"]
