
[RoBERTa-based] Add support for sdpa #30510

Merged: 13 commits into huggingface:main from the sdpa-roberta branch on Aug 28, 2024
Conversation

@hackyon (Contributor) commented on Apr 26, 2024

What does this PR do?

Adding support for SDPA (scaled dot product attention) for RoBERTa-based models. More context in #28005 and #28802.

Models: camembert, roberta, xlm_roberta, xlm_roberta_xl.
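
Once this lands in a release, SDPA can be opted into the same way as for other supported architectures, via the `attn_implementation` argument of `from_pretrained`. A minimal usage sketch (checkpoint, dtype, and device are illustrative):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModel.from_pretrained(
    "roberta-base",
    attn_implementation="sdpa",  # explicit opt-in; "eager" keeps the original attention path
    torch_dtype=torch.float16,   # half precision is where the memory savings show up
).to("cuda")

inputs = tokenizer("SDPA matches eager outputs up to numerical precision.", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```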

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@fxmarty @ArthurZucker @amyeroberts

@hackyon (Contributor, author) commented on Apr 26, 2024

I ran slow tests for the affected models, and verified that they all pass except XLMRobertaXLModelTest::test_eager_matches_sdpa_generate(). I suspect it's just some numerical computation error, but I'll take a quick look to see if I can find anything.

I'll also try to run some of the perf benchmarks on RoBERTa over the weekend to see how they behave.

@hackyon (Contributor, author) commented on Apr 27, 2024

Preliminary perf numbers for RoBERTa (using "roberta-base" with AutoModel/AutoTokenizer).

Training

| num_training_steps | batch_size | seq_len | is cuda | Time per batch, eager (s) | Time per batch, SDPA (s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|
| 1000 | 1 | 256 | True | 0.018 | 0.015 | 24.411 | 731.752 | 736.471 | -0.641 |
| 1000 | 1 | 512 | True | 0.019 | 0.016 | 17.819 | 823.792 | 757.096 | 8.809 |
| 1000 | 2 | 256 | True | 0.020 | 0.016 | 29.890 | 760.504 | 757.096 | 0.450 |
| 1000 | 2 | 512 | True | 0.020 | 0.016 | 25.317 | 1283.793 | 907.688 | 41.435 |
| 1000 | 4 | 256 | True | 0.020 | 0.016 | 28.907 | 1094.001 | 907.289 | 20.579 |
| 1000 | 4 | 512 | True | 0.025 | 0.021 | 19.153 | 2205.299 | 1446.666 | 52.440 |

Inference

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency, eager (ms) | Per token latency, SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 2 | 64 | True | True | True | 5.357 | 5.067 | 5.716 | 333.956 | 333.956 | 0 |
| 50 | 2 | 128 | True | True | True | 5.534 | 5.181 | 6.812 | 360.089 | 360.089 | 0 |
| 50 | 2 | 256 | True | True | True | 5.823 | 5.516 | 5.577 | 412.355 | 412.355 | 0 |
| 50 | 4 | 64 | True | True | True | 5.632 | 5.344 | 5.381 | 385.611 | 385.611 | 0 |
| 50 | 4 | 128 | True | True | True | 6.101 | 5.849 | 4.304 | 437.895 | 437.877 | 0.004 |
| 50 | 4 | 256 | True | True | True | 6.91 | 6.529 | 5.824 | 542.598 | 542.598 | 0 |
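
These numbers were produced with the usual SDPA benchmarking approach from the linked PRs; the sketch below shows the general shape of such a latency comparison, not the exact script (model name, sizes, and timing loop are illustrative):

```python
import time
import torch
from transformers import AutoModel, AutoTokenizer

def avg_forward_time(attn_implementation, batch_size=2, seq_len=256, n_iters=50):
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    model = AutoModel.from_pretrained(
        "roberta-base", attn_implementation=attn_implementation, torch_dtype=torch.float16
    ).to("cuda").eval()
    inputs = tokenizer(
        ["hello world"] * batch_size, padding="max_length", max_length=seq_len, return_tensors="pt"
    ).to("cuda")
    with torch.no_grad():
        for _ in range(5):          # warm-up iterations
            model(**inputs)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            model(**inputs)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

eager = avg_forward_time("eager")
sdpa = avg_forward_time("sdpa")
print(f"eager {eager * 1e3:.2f} ms | sdpa {sdpa * 1e3:.2f} ms | speedup {eager / sdpa - 1:.1%}")
```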

@hackyon (Contributor, author) commented on Apr 27, 2024

It turns out XLMRobertaXLModelTest::test_eager_matches_sdpa_generate() doesn't always fail; it's flaky and depends on the random number generator. I think it comes down to numerical stability in the computation, which can produce slightly different results between the two implementations.

EDIT: I added a set_seed(0) to XLMRobertaXLModelTest::test_eager_matches_sdpa_generate(), and the flake seems to have gone away.
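
For context, `transformers.set_seed` pins all the RNGs so the eager and SDPA runs see identical weights and inputs. A minimal sketch of that kind of equivalence check (checkpoint and tolerances are illustrative, and a transformers version with SDPA support for these models is assumed):

```python
import torch
from transformers import AutoModel, set_seed

set_seed(0)  # fix Python/NumPy/torch RNGs so both runs are deterministic

model_eager = AutoModel.from_pretrained("roberta-base", attn_implementation="eager").eval()
model_sdpa = AutoModel.from_pretrained("roberta-base", attn_implementation="sdpa").eval()

input_ids = torch.randint(0, model_eager.config.vocab_size, (1, 16))
with torch.no_grad():
    out_eager = model_eager(input_ids).last_hidden_state
    out_sdpa = model_sdpa(input_ids).last_hidden_state

# Small numerical differences are expected, so compare within a tolerance rather than exactly.
torch.testing.assert_close(out_eager, out_sdpa, atol=1e-4, rtol=1e-4)
```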

@hackyon force-pushed the sdpa-roberta branch 2 times, most recently from c39f457 to 41537e3 on April 29, 2024 17:42
@hackyon (Contributor, author) commented on Apr 29, 2024

@fxmarty @ArthurZucker @amyeroberts

This is ready for review! With the exception of the changes to the test and check_support_list.py, all of the changes come from "Copied from" statements. Please let me know if you have any questions!

@hackyon marked this pull request as ready for review on April 29, 2024 17:50
@hackyon mentioned this pull request on May 8, 2024
@michaelshekasta commented
@hackyon, I'm curious whether implementing flash_attn is essential when writing an SDPA integration. I came across claims that flash_attn can offer up to a 4x efficiency boost (roughly) compared to native PyTorch. However, your numbers in #30510 suggest the actual improvement is less than 50%. Could you help shed some light on this apparent difference?

@hackyon (Contributor, author) commented on May 19, 2024

@michaelshekasta I believe the 4x improvement only applies to certain models, usually larger models with more computationally expensive attention computations.
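
For anyone who wants to see how much of the gap comes from the kernel itself: PyTorch's SDPA can dispatch to a FlashAttention kernel internally, so the backends can be compared directly. A rough sketch (requires a recent PyTorch with `torch.nn.attention`; whether the flash kernel is available depends on GPU, dtype, and head size):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim) in fp16, which the flash kernel supports.
q = torch.randn(2, 12, 512, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to the FlashAttention backend (errors if it is unavailable on this setup).
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Plain math backend, roughly equivalent to the eager attention path.
with sdpa_kernel(SDPBackend.MATH):
    out_math = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print((out_flash - out_math).abs().max())
```

The kernel-level gap grows with sequence length and head count, which is part of why encoder-sized models like roberta-base see more modest end-to-end gains than large decoder models.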

@ArthurZucker (Collaborator) commented

@fxmarty can you have a look and ping me for the final review? 🤗

@nbroad1881 (Contributor) commented

@fxmarty, gentle bump

@fxmarty (Contributor) left a review:

LGTM, there is just one fix I think needs to be made to the is_causal param handling for cross-attention.

The test here https://github.com/huggingface/transformers/pull/30138/files#diff-681c988a50a31869d1756f2db71904939c639617569a5168d7b3167fe8da0b48 could also be extended for extra safety, but that's up to you.
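
For readers following along, the concern is roughly the guard below: `is_causal` may only be true for decoder self-attention with no explicit mask, and never for cross-attention. This is a hedged sketch with illustrative names, not the exact diff in this PR:

```python
import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, attention_mask=None,
                   is_decoder=False, is_cross_attention=False, dropout_p=0.0):
    # is_causal only applies to decoder self-attention over more than one query
    # position with no explicit mask; cross-attention must never be causal.
    is_causal = (
        is_decoder
        and not is_cross_attention
        and attention_mask is None
        and query.shape[2] > 1
    )
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )

# Cross-attention: query and key/value lengths differ, and is_causal stays False.
q = torch.randn(1, 8, 4, 64)
kv = torch.randn(1, 8, 16, 64)
print(sdpa_attention(q, kv, kv, is_decoder=True, is_cross_attention=True).shape)
```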


@michaelshekasta commented
@fxmarty what's left? How can I help?

@michaelshekasta commented
@fxmarty you are amazing! If I can help, please write to me

@kiszk (Contributor) commented on Jul 10, 2024

@fxmarty Thank you very much. I would appreciate it if you could re-add gpt_neox for consistency. Or can I do it?
I am not sure why it was dropped.

https://app.circleci.com/pipelines/github/huggingface/transformers/97500/workflows/4facc164-8c3b-4ad0-9387-be9de636e686/jobs/1291191?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-checks-link&utm_content=summary

Traceback (most recent call last):
  File "/root/transformers/utils/check_support_list.py", line 97, in <module>
    check_sdpa_support_list()
  File "/root/transformers/utils/check_support_list.py", line 90, in check_sdpa_support_list
    raise ValueError(
ValueError: gpt_neox should be in listed in the SDPA documentation but is not. Please update the documentation.

Exited with code exit status 1
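
Roughly, that check cross-references the architectures that declare SDPA support in the code against the model list in the SDPA documentation page. A simplified sketch of the idea (the parsing, paths, and names here are illustrative, not the actual utils/check_support_list.py):

```python
import re

def check_sdpa_support_list(doc_text: str, supported_in_code: set) -> None:
    # Model names appear as markdown bullets in the SDPA section of the docs.
    documented = {m.lower() for m in re.findall(r"^\*\s*\[?([A-Za-z0-9_\-]+)", doc_text, flags=re.MULTILINE)}
    missing = sorted(m for m in supported_in_code if m.lower() not in documented)
    if missing:
        raise ValueError(
            f"{missing} should be listed in the SDPA documentation but are not. "
            "Please update the documentation."
        )

doc_excerpt = """
* [Camembert](...)
* [RoBERTa](...)
* [XLM-RoBERTa](...)
"""
try:
    # gpt_neox supports SDPA in the code but is missing from the excerpt above, so this raises.
    check_sdpa_support_list(doc_excerpt, {"camembert", "roberta", "xlm-roberta", "gpt_neox"})
except ValueError as err:
    print(err)
```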

@fxmarty (Contributor) commented on Jul 11, 2024

Thanks @kiszk, missed it when reordering the lists.

@fxmarty requested a review from amyeroberts on July 11, 2024 12:03
@fxmarty (Contributor) commented on Jul 12, 2024

gentle ping @ArthurZucker @amyeroberts

@fxmarty (Contributor) commented on Jul 16, 2024

@ArthurZucker @amyeroberts

@kiszk (Contributor) commented on Jul 22, 2024

@fxmarty You may want to resolve conflicts.

@fxmarty requested review from amyeroberts and ArthurZucker and removed the review requests for amyeroberts and ArthurZucker on July 22, 2024 09:44
@ArthurZucker (Collaborator) commented

Sorry, I did not have time before; I will try to get to it today or next week. It's a big PR with lots of changes, so I need to be extra careful!

@kiszk (Contributor) commented on Aug 13, 2024

@ArthurZucker would you have time for this review?

@hotchpotch commented
I've also experienced approximately 20% faster training with XLMRoberta using this PR on an RTX4090. I've been testing it for over a week now, and it's been working without any issues. I sincerely hope this can be merged.

@kiszk (Contributor) commented on Aug 27, 2024

@ArthurZucker Can we help with anything in reviewing this PR?

@ArthurZucker (Collaborator) left a review:

I kept pushing this back, it's on me! I'll solve whatever comes up with this merge.
Thanks @hackyon for your hard work, LGTM!

@ArthurZucker merged commit f1a385b into huggingface:main on Aug 28, 2024 (22 checks passed)
@michaelshekasta commented on Aug 28, 2024

@ArthurZucker when do you think this change will appear in the transformers package? The next version?

P.S. You are so amazing guys!

@ArthurZucker (Collaborator) commented

It should be there in at most 2 weeks! 🤗

@hotchpotch commented
I would like to thank everyone involved in this Pull Request from the bottom of my heart! 🎉

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024
* Adding SDPA support for RoBERTa-based models

* add not is_cross_attention

* fix copies

* fix test

* add minimal test for camembert and xlm_roberta as their test class does not inherit from ModelTesterMixin

* address some review comments

* use copied from

* style

* consistency

* fix lists

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@michaelshekasta commented

> It should be there in at most 2 weeks! 🤗

@ArthurZucker A gentle reminder ;-)
