
schedulefree optimizers #30079

Merged
12 commits merged on Sep 9, 2024
45 changes: 45 additions & 0 deletions docs/source/en/trainer.md
@@ -518,6 +518,51 @@ trainer.train()

This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.

## Schedule Free Optimizer

The Schedule Free optimizers were introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, which removes the need to anneal the learning rate with a traditional schedule.
The supported values are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install the schedulefree package from PyPI with `pip install schedulefree`.

Below is a simple script demonstrating how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset in full precision:

```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-schedulefree",
max_steps=1000,
per_device_train_batch_size=4,
optim="schedule_free_adamw",
gradient_checkpointing=True,
logging_strategy="steps",
logging_steps=1,
learning_rate=2e-6,
save_strategy="no",
run_name="sfo-imdb",
)

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=1024,
)

trainer.train()
```
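
Note that schedule-free optimizers track two parameter sequences and expose `train()` and `eval()` methods to switch between them; [`Trainer`] calls these for you (see the `trainer.py` changes further down in this PR). If you write a manual training loop instead, you need to switch the optimizer's mode yourself. Below is a minimal sketch of such a loop, assuming the `AdamWScheduleFree` class from the `schedulefree` package and a toy model chosen purely for illustration:

```python
import torch
from schedulefree import AdamWScheduleFree

# toy model and data, purely for illustration
model = torch.nn.Linear(10, 2)
optimizer = AdamWScheduleFree(model.parameters(), lr=2e-6, warmup_steps=100)

inputs = torch.randn(8, 10)
labels = torch.randint(0, 2, (8,))

model.train()
optimizer.train()  # put the optimizer in training mode before taking steps
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    optimizer.step()

model.eval()
optimizer.eval()  # switch to the averaged parameters before evaluation
with torch.no_grad():
    predictions = model(inputs).argmax(dim=-1)
```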

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
1 change: 1 addition & 0 deletions setup.py
@@ -163,6 +163,7 @@
"sacremoses",
"safetensors>=0.4.1",
"sagemaker>=2.31.0",
"schedulefree>=1.2.6",
"scikit-learn",
"scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`)
"sentencepiece>=0.1.91,!=0.1.92",
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
@@ -69,6 +69,7 @@
"sacremoses": "sacremoses",
"safetensors": "safetensors>=0.4.1",
"sagemaker": "sagemaker>=2.31.0",
"schedulefree": "schedulefree>=1.2.6",
"scikit-learn": "scikit-learn",
"scipy": "scipy<1.13.0",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -103,6 +103,7 @@
is_rjieba_available,
is_sacremoses_available,
is_safetensors_available,
is_schedulefree_available,
is_scipy_available,
is_sentencepiece_available,
is_seqio_available,
@@ -370,6 +371,14 @@ def require_grokadamw(test_case):
return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)


def require_schedulefree(test_case):
"""
Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed.
https://github.com/facebookresearch/schedule_free
"""
return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case)


def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.
36 changes: 36 additions & 0 deletions src/transformers/trainer.py
@@ -161,6 +161,7 @@
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_schedulefree_available,
is_torch_compile_available,
is_torch_mlu_available,
is_torch_mps_available,
@@ -1488,6 +1489,34 @@ def optimizer_hook(param):

optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
elif args.optim in [
OptimizerNames.SCHEDULE_FREE_ADAMW,
OptimizerNames.SCHEDULE_FREE_SGD,
]:
if not is_schedulefree_available():
raise ImportError(
"You need to install `schedulefree` in order to use schedulefree optimizers"
" install it with `pip install schedulefree`"
)
from schedulefree import AdamWScheduleFree, SGDScheduleFree

additional_optim_kwargs = {}
if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
optimizer_cls = AdamWScheduleFree
additional_optim_kwargs = adam_kwargs
elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
optimizer_cls = SGDScheduleFree
else:
raise ValueError("Invalid schedulefree optimizer")
additional_optim_kwargs["weight_decay"] = args.weight_decay
additional_optim_kwargs["warmup_steps"] = args.warmup_steps
additional_optim_kwargs.update(
{
"weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
"r": float(optim_args.get("r", 0.0)),
}
)
optimizer_kwargs.update(additional_optim_kwargs)
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
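
The hunk above also reads two optional hyperparameters, `weight_lr_power` and `r`, from `optim_args`. A hedged sketch of how a user could pass them through `TrainingArguments` (the values shown are illustrative, not tuned recommendations):

```python
from transformers import TrainingArguments

# weight_lr_power and r are forwarded to the schedule-free optimizer by the
# hunk above; weight_decay and warmup_steps come from the standard
# TrainingArguments fields
args = TrainingArguments(
    output_dir="./out",
    optim="schedule_free_adamw",
    optim_args="weight_lr_power=2.0,r=0.0",
    learning_rate=2e-6,
    warmup_steps=100,
    weight_decay=0.01,
)
```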
@@ -3410,6 +3439,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()

inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
@@ -3960,6 +3992,8 @@ def evaluation_loop(
logger.info(f" Batch size = {batch_size}")

model.eval()
if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
self.optimizer.eval()

self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping.
@@ -4573,6 +4607,8 @@ def prediction_loop(
inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

model.eval()
if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
self.optimizer.eval()

if args.past_index >= 0:
self._past = None
2 changes: 2 additions & 0 deletions src/transformers/training_args.py
@@ -178,6 +178,8 @@ class OptimizerNames(ExplicitEnum):
LOMO = "lomo"
ADALOMO = "adalomo"
GROKADAMW = "grokadamw"
SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
SCHEDULE_FREE_SGD = "schedule_free_sgd"


# Sometimes users will pass in a `str` repr of a dict in the CLI
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -175,6 +175,7 @@
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_schedulefree_available,
is_scipy_available,
is_sentencepiece_available,
is_seqio_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -103,6 +103,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_grokadamw_available = _is_package_available("grokadamw")
_schedulefree_available = _is_package_available("schedulefree")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -364,6 +365,10 @@ def is_grokadamw_available():
return _grokadamw_available


def is_schedulefree_available():
return _schedulefree_available


def is_pyctcdecode_available():
return _pyctcdecode_available

22 changes: 22 additions & 0 deletions tests/trainer/test_trainer.py
@@ -70,6 +70,7 @@
require_peft,
require_ray,
require_safetensors,
require_schedulefree,
require_sentencepiece,
require_sigopt,
require_tensorboard,
@@ -1442,6 +1443,27 @@ def test_grokadamw():
# Check this works
_ = trainer.train()

@require_schedulefree
@require_torch_gpu
def test_schedulefree_adam(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="schedule_free_adamw",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()
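
The diff adds a test for the AdamW variant only. A hedged sketch of what a companion test for `schedule_free_sgd` could look like, mirroring `test_schedulefree_adam` above and reusing the same imports and helpers from this test file:

```python
@require_schedulefree
@require_torch_gpu
def test_schedulefree_sgd(self):
    # mirrors test_schedulefree_adam above, but exercises the SGD variant
    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir:
        args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="schedule_free_sgd",
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        _ = trainer.train()
```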

def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
