schedulefree optimizers #30079

Merged (12 commits into huggingface:main, Sep 9, 2024)

Conversation

@winglian (Contributor) commented Apr 6, 2024

What does this PR do?

Integrates Meta's https://github.com/facebookresearch/schedule_free for AdamW & SGD.

https://twitter.com/aaron_defazio/status/1776320004465582331

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @younesbelkada @pacman100

@muellerzr (Contributor) commented Apr 6, 2024

FYI, this will need huggingface/accelerate#2631, as we need to upstream into accelerate the ability to call train/eval on a wrapped optimizer.
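For context, a minimal sketch (an assumption, not accelerate's actual implementation) of how a wrapper might forward train/eval to an underlying schedule-free optimizer:

import torch


class WrappedOptimizer:
    """Toy optimizer wrapper illustrating how train()/eval() calls might be forwarded."""

    def __init__(self, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer

    def step(self, closure=None):
        return self.optimizer.step(closure)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def train(self):
        # Forward only if the wrapped optimizer actually defines the method.
        if hasattr(self.optimizer, "train"):
            self.optimizer.train()

    def eval(self):
        if hasattr(self.optimizer, "eval"):
            self.optimizer.eval()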

@danielhanchen (Contributor)

Some thoughts:

  • I tried asking Aaron et al. on Twitter whether they ran any transformer experiments, but to no avail. They said a paper will come in 1 or 2 months.
  • Aaron et al.'s past work on D-Adaptation won a best ICML paper, with their follow-up work being Prodigy, but on transformers both did similar to or worse than AdamW. https://twitter.com/danielhanchen/status/1775547139248341125
  • Superconvergence + the LR range finder + fast.ai's Ranger21 optimizer was the go-to combination for CNNs and worked fabulously well, but on transformers the learning rate range finder said 1e-3 was best whilst 1e-5 was actually better. However, the 1-cycle learning rate schedule stuck. See Learning rate finder for the trainer #16013.
  • A huge issue is that this needs tuning. But how does it compare to a well-tuned AdamW? E.g. see https://twitter.com/kellerjordan0/status/1776716388037529843, which outperformed it using a tuned SGD.

I'm just a little bit reserved for now, since the authors themselves aren't providing any transformer benchmarks, nor have they compared their CNN baselines to superconvergence, which is the go-to standard for fast CNN training. Likewise, https://parameterfree.com/2023/08/30/yet-another-icml-award-fiasco/ wasn't pleasant.

@PhilipMay (Contributor)

It should be very easy to test this on Phi-2 or TinyLlama once the implementation works, right?

@younesbelkada (Contributor) left a comment

Great work @winglian ! 🤩 I left one minor comment, wdyt?

@@ -3117,6 +3145,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
if "ScheduleFree" in self.optimizer.__class__.__name__:
Contributor:

Maybe instead of checking the class name here, we could inject an attribute _hf_schedule_free_optim to make sure we can support other schedule-free optimizers in the future. What do you think?

Contributor Author:

that would be on the Trainer class, right?

Contributor Author:

So the place that makes the most sense to set that would be in get_optimizer_cls_and_kwargs, but that is a @staticmethod and has no access to the Trainer object. We could do something along the lines of

setattr(self.optimizer, "_hf_schedule_free_optim", True)

after we instantiate the optimizer_cls, but we would still have to do some sort of class-name detection.

Alternatively, we could pass another value in the return tuple specific to schedule-free optimizers (but that feels worse).
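For illustration, a hedged sketch of that tagging approach (torch.optim.AdamW stands in for a schedule-free optimizer class, and _hf_schedule_free_optim is the hypothetical attribute from this thread, not an existing Trainer field):

import torch

# Illustrative stand-ins: in the Trainer these would come from
# get_optimizer_cls_and_kwargs() and the model's parameters.
optimizer_cls, optimizer_kwargs = torch.optim.AdamW, {"lr": 1e-3}
params = [torch.nn.Parameter(torch.zeros(2))]

optimizer = optimizer_cls(params, **optimizer_kwargs)
# Class-name detection is still needed at the point where the instance is tagged.
if "ScheduleFree" in optimizer_cls.__name__:
    setattr(optimizer, "_hf_schedule_free_optim", True)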

Contributor:

Ah, good point. In that case this is probably already fine, I would say. Thanks for investigating, @winglian!

Collaborator:

Rather than having it as a stateful attribute, could we instead move this logic out to a module-level function? E.g.:

def _is_schedule_free_optimizer(optimizer):
    return "ScheduleFree" in optimizer.__class__.__name__

This way:

  • The check is a bit more explicit within the code logic
  • We can easily adapt the check in one place, rather than throughout the code, if we end up introducing e.g. an _is_schedule_free attribute or there are schedule-free optimizers with slightly different names

@PhilipMay (Contributor)

This PR should maybe also add a few lines to the README about "how to use this".

@muellerzr (Contributor)

We've merged the accelerate portion in, so if anyone is trying this out in a distributed fashion, you can do pip install git+https://github.com/huggingface/accelerate :)

(resolved review thread on src/transformers/trainer.py)
@bratao commented Apr 14, 2024

Is there any chance of this making it into the main branch? I and others have confirmed that the results are real. Thank you @winglian

@pacman100 (Contributor) left a comment

Super useful addition of schedule-free optimizers, @winglian! It would be great to document the usage along with a minimal example.
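A minimal usage sketch along those lines (hedged: it assumes the optim names exposed by this PR are "schedule_free_adamw"/"schedule_free_sgd" and that the schedulefree package is installed):

from transformers import TrainingArguments

# Assumed optim value added by this PR; schedule-free optimizers replace the
# learning-rate schedule, so a constant scheduler is used here.
args = TrainingArguments(
    output_dir="out",
    optim="schedule_free_adamw",
    lr_scheduler_type="constant",
    learning_rate=2e-5,
    warmup_steps=100,
)
# args is then passed to Trainer(model=..., args=args, ...) as usual.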

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@CoffeeVampir3

Is there any remaining work I could contribute towards getting this PR merged?

Cheers

@winglian (Contributor Author)

@pacman100 @muellerzr @younesbelkada Can we get a new review to get this merged? Since the last check, I rebased, added some fixes and docs.

@muellerzr (Contributor) left a comment

Thanks! Overall LG2M, but let's pin schedulefree as a >= requirement.

Can you also run the quality checks? Afterwards, at least from my end, it looks good to merge.

(resolved review thread on setup.py)
@winglian (Contributor Author) commented Jun 1, 2024

@muellerzr I ran make quality/lint and also added a smoke test to the test suite for schedule-free AdamW.
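For reference, a hedged sketch of the shape such a smoke test might take (illustrative only, not the PR's actual test; it assumes the optim name "schedule_free_adamw"):

from transformers import TrainingArguments


def test_schedule_free_adamw_smoke(tmp_path):
    # Constructing TrainingArguments exercises the optim-name validation
    # without actually running training.
    args = TrainingArguments(
        output_dir=str(tmp_path),
        optim="schedule_free_adamw",
        lr_scheduler_type="constant",
    )
    assert args.optim == "schedule_free_adamw"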

@muellerzr (Contributor) left a comment

Thanks a bunch! cc @LysandreJik for final review

@younesbelkada (Contributor) left a comment

Thanks a lot!

@amyeroberts (Collaborator) left a comment

Thanks for adding!

Main comment is about the getattr logic in get_optimizer_cls_and_kwargs


additional_optim_kwargs["warmup_steps"] = args.warmup_steps
additional_optim_kwargs.update(
{
"weight_lr_power": float(getattr(torch, optim_args.get("weight_lr_power", 2.0))),
Collaborator:

This doesn't seem right:

  • If we get "weight_lr_power" from optim_args, I'm presuming it's a float as a string, e.g. "2.0"? I don't think torch.2.0 exists.
  • If optim_args doesn't have "weight_lr_power", then the second argument to getattr is a float, which isn't a valid attribute name.

Contributor:

Suggested change:
- "weight_lr_power": float(getattr(torch, optim_args.get("weight_lr_power", 2.0))),
+ "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),

additional_optim_kwargs.update(
{
"weight_lr_power": float(getattr(torch, optim_args.get("weight_lr_power", 2.0))),
"r": float(getattr(torch, optim_args.get("r", 0.0))),
Collaborator:

Same here

Contributor:

Suggested change:
- "r": float(getattr(torch, optim_args.get("r", 0.0))),
+ "r": float(optim_args.get("r", 0.0)),

github-actions (bot):

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@winglian (Contributor Author)

Will get back to this soon. Not stale 😅

github-actions (bot):

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions (bot) closed this Jul 31, 2024
@bratao commented Jul 31, 2024

@winglian please don't let it die

@winglian (Contributor Author)

Thanks for the fixes, @tmm1!

@tmm1 (Contributor) commented Aug 20, 2024

@amyeroberts I addressed your comments. LMK what else is required to push this through!

Comment on lines 3406 to 3443
if _is_schedule_free_optimizer(self.optimizer):
    self.optimizer.train()
Collaborator:

We shouldn't need to have optimizer-specific logic in the main training loops: it makes our training logic hard to maintain and will quickly become too cluttered if many optimizers need specific handling.

cc @muellerzr who's been working on refactoring a lot of similar logic to this. Ideally all optimizers would use the same API and would support calling .train or .eval on the class, if this is required, but this is a large piece of work. Is this customization as-is acceptable?

Contributor:

Would it be preferable to check if the optimizer has the eval/train methods and call them, instead of checking the class name?

cc @adefazio

@adefazio:

Yes, the way I do it in my code is by checking for the eval/train methods, i.e. duck typing.

Contributor:

Is this approach acceptable?

Suggested change:
- if _is_schedule_free_optimizer(self.optimizer):
-     self.optimizer.train()
+ if hasattr(self.optimizer, 'train') and callable(getattr(self.optimizer, 'train')):
+     self.optimizer.train()

Collaborator:

Yes, I think this would be preferable, as it's more extensible to other optimizers and puts the responsibility for implementing/checking on the optimizer implementation, rather than on us within the Trainer.
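A small sketch of the duck-typed pattern being agreed on here (a hypothetical helper, not the PR's final code):

def _set_optimizer_mode(optimizer, train: bool) -> None:
    # Call optimizer.train()/optimizer.eval() only if the optimizer defines it,
    # so the Trainer never needs optimizer-specific class checks.
    method = getattr(optimizer, "train" if train else "eval", None)
    if callable(method):
        method()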

@amyeroberts (Collaborator) left a comment

Thanks for iterating - looks good to me!

cc @muellerzr to confirm the conditional checks on the eval and train methods are OK

@amyeroberts (Collaborator)

@winglian It looks like we'll need a bit more iteration on the eval and train method checks, as they're causing failing tests at the moment. For the code quality tests, could you run make fixup and push the changes?

@tmm1 (Contributor) commented Aug 27, 2024

> we'll need a bit more iteration on the eval and train method checks as it's causing failing tests atm

Looks like we need to revisit huggingface/accelerate#2631, which I've done in huggingface/accelerate#3055.

> could you run make fixup and push the changes

👍

@fizzAI commented Aug 30, 2024

LGTM!

For what it's worth, when this PR was originally opened there was some discussion about whether schedulefree's optimizers are effective on transformers. I did some testing a while ago, and schedule-free AdamW had a better (albeit marginally) loss/accuracy landscape than plain AdamW, at least on sequence-classification fine-tuning, even after tuning hyperparameters.
[Screenshot: training curves; the purple line is ScheduleFreeAdamW, the red line is plain AdamW]

@tmm1 (Contributor) commented Sep 6, 2024

@amyeroberts Can we kick off the tests again? Does the build pick up the latest accelerate release automatically?

@muellerzr (Contributor) left a comment

Thanks! Just a request for one more additional check for users who don't have the right minimum accelerate version. cc @LysandreJik for final 🤗

(resolved review thread on src/transformers/trainer.py)
Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>
@LysandreJik (Member) left a comment

Thanks for the PR @winglian!

LGTM

@LysandreJik merged commit 62aecd8 into huggingface:main on Sep 9, 2024 (24 checks passed)
@winglian (Contributor Author)

> Thanks for the PR @winglian!

All the credit goes to @tmm1 for getting this over the line!

@winglian deleted the schedule-free-optimizer branch on September 11, 2024 at 14:43