
Open to contribution: adding torch.nn.functional.scaled_dot_product_attention support for more architectures #28005

Open · 1 of 6 tasks
fxmarty opened this issue Dec 13, 2023 · 33 comments · Fixed by #30555

@fxmarty (Contributor) commented Dec 13, 2023

Feature request

In Transformers 4.36, we started adding native support for torch.nn.functional.scaled_dot_product_attention (SDPA), enabled by default in Transformers: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention

SDPA can dispatch to memory-efficient attention and flash attention kernels on supported GPUs (currently NVIDIA-only), and it also brings speedups on Intel CPUs. A minimal usage sketch is shown below.
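As a minimal usage sketch (assuming a CUDA GPU and the example Llama 2 checkpoint; any SDPA-enabled model works the same way), SDPA is picked automatically when available, and the PyTorch backend can be pinned for benchmarking:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
# "sdpa" is the default attention implementation since v4.36 when available,
# but it can also be requested explicitly.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")

# Optionally restrict which fused kernel SDPA may dispatch to (useful for benchmarking).
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(out[0], skip_special_tokens=True))
```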

For the record, here's a benchmark on some currently supported models:

Training benchmark, run on A100-SXM4-80GB.

| Model | Batch size | Sequence length | Time per batch ("eager", s) | Time per batch ("sdpa", s) | Speedup | Peak memory ("eager", MB) | Peak memory ("sdpa", MB) | Memory savings |
|---|---|---|---|---|---|---|---|---|
| llama2 7b | 4 | 1024 | 1.065 | 0.90 | 19.4% | 73878.28 | 45977.81 | 60.7% |
| llama2 7b | 4 | 2048 | OOM | 1.87 | / | OOM | 78394.58 | SDPA does not OOM |
| llama2 7b | 1 | 2048 | 0.64 | 0.48 | 32.0% | 55557.01 | 29795.63 | 86.4% |
| llama2 7b | 1 | 3072 | OOM | 0.75 | / | OOM | 37916.08 | SDPA does not OOM |
| llama2 7b | 1 | 4096 | OOM | 1.03 | / | OOM | 46028.14 | SDPA does not OOM |
| llama2 7b | 2 | 4096 | OOM | 2.05 | / | OOM | 78428.14 | SDPA does not OOM |

Inference benchmark, run on A100-SXM4-80GB.

| Model | Batch size | Prompt length | Num new tokens | Per token latency "eager" (ms) | Per token latency "sdpa" (ms) | Speedup |
|---|---|---|---|---|---|---|
| llama2 13b | 1 | 1024 | 1 (prefill) | 178.66 | 159.36 | 12.11% |
| llama2 13b | 1 | 100 | 100 | 40.35 | 37.62 | 7.28% |
| llama2 13b | 8 | 100 | 100 | 40.55 | 38.06 | 6.53% |
| Whisper v3 large | 1 | / | 62 | 20.05 | 18.90 | 6.10% |
| Whisper v3 large | 8 | / | 77 | 25.42 | 24.77 | 2.59% |
| Whisper v3 large | 16 | / | 77 | 28.51 | 26.32 | 8.34% |

Previously, we had partial support for SDPA in Optimum BetterTransformer, but we are now looking to slowly deprecate it in favor of upstream support for SDPA directly in Transformers.

Here are the architectures for which support has been requested:

The integration could take inspiration from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/decoder_models.py & https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py

Motivation

Faster training & inference, lower memory requirement

Your contribution

I may work on some at some point, but contributions are most welcome.

You should refer to #26572 to add SDPA support for a model, roughly following these steps (a minimal sketch of the attention class is shown after this list):

  • Create an XxxSdpaAttention class inheriting from XxxAttention and implement the attention logic using SDPA
  • Use _prepare_4d_causal_attention_mask_for_sdpa instead of _prepare_4d_causal_attention_mask for SDPA
  • Use _prepare_4d_attention_mask_for_sdpa instead of _prepare_4d_attention_mask for SDPA
  • Add _supports_sdpa = True to XxxPreTrainedModel
  • Add an "sdpa" key to XXX_ATTENTION_CLASSES in the model's modeling file
@ENate (Contributor) commented Dec 14, 2023

Hi @fxmarty, I can take a look at this issue. I can ask questions if necessary. Or has anyone taken it already?

@davidan5

Does someone know if LongT5 and all T5 models are blocked by bias support in flash attention?

Dao-AILab/flash-attention#617

@ENate (Contributor) commented Dec 19, 2023

Hi @davidan5, are you working on the implementation?

@davidan5

@ENate I was trying to understand the status and estimate the size of the code change to see if I can contribute.

@ENate (Contributor) commented Dec 19, 2023

I see.

@hackyon (Contributor) commented Jan 29, 2024

I'm interested in taking a look at this for the Mistral model if that's still needed. Otherwise, please let me know if there are any other models that still need some work. Thanks!

@ENate (Contributor) commented Jan 29, 2024

Is LongT5 still open?

@ArthurZucker (Collaborator)

Mistral is already covered! As for LongT5, if it is like T5 and has an attention bias, that might not be supported.

@hackyon (Contributor) commented Jan 30, 2024

Oh yea, looks like you added support for Mistral/Mixtral last month.

It doesn't seem to be supported for BERT yet (I think someone else is working on FA2 but not SDPA), so I'll take a crack at it. It looks like there is a config for relative position embeddings for BERT, so I'll just have it fall back to the original attention for configs using relative position embeddings (a rough sketch of that fallback is shown at the end of this comment).

@ArthurZucker - Please let me know if someone else is already working on SDPA for BERT and I can look for something else to do. Thanks!
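As a rough, assumption-laden sketch of that fallback pattern (the subclass name and exact guard are illustrative; the implementation later merged in #28802 may differ in its details):

```python
import torch
import torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertSelfAttention


class BertSdpaSelfAttentionSketch(BertSelfAttention):
    """Hypothetical sketch: use SDPA when possible, otherwise fall back to eager attention."""

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None,
                past_key_value=None, output_attentions=False):
        # SDPA cannot return attention probabilities and has no hook for BERT's
        # relative-position scores, so fall back to the eager implementation there.
        if output_attentions or head_mask is not None or self.position_embedding_type != "absolute":
            return super().forward(hidden_states, attention_mask, head_mask,
                                   encoder_hidden_states, encoder_attention_mask,
                                   past_key_value, output_attentions)

        bsz, seq_len, _ = hidden_states.size()
        query = self.transpose_for_scores(self.query(hidden_states))
        key = self.transpose_for_scores(self.key(hidden_states))
        value = self.transpose_for_scores(self.value(hidden_states))

        # attention_mask is expected as an additive 4D float mask
        # (built with _prepare_4d_attention_mask_for_sdpa for encoder-only models).
        attn_output = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=self.dropout.p if self.training else 0.0,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.all_head_size)
        return (attn_output,)
```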

@ArthurZucker (Collaborator)

Not sure anyone is working on that, but BERT is already so small that I doubt it will have a lot of impact on perf!

@abdulfatir

@ArthurZucker for the T5 family of models, attention bias is required, so flash attention won't work for now, but torch SDPA can still use the memory-efficient kernel from xformers, right? I did some benchmarking with Chronos models (based on the T5 architecture) here (amazon-science/chronos-forecasting#33) and there's a clear speedup when using torch SDPA.
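For readers following along, a short sketch of why this works: F.scaled_dot_product_attention accepts an arbitrary additive float attn_mask, so a T5-style position bias can be folded into it; flash attention kernels cannot take such a bias, but the memory-efficient backend can. Shapes and values below are purely illustrative:

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

bsz, n_heads, q_len, kv_len, head_dim = 2, 8, 128, 128, 64
query = torch.randn(bsz, n_heads, q_len, head_dim, device=device, dtype=dtype)
key = torch.randn(bsz, n_heads, kv_len, head_dim, device=device, dtype=dtype)
value = torch.randn(bsz, n_heads, kv_len, head_dim, device=device, dtype=dtype)

# T5-style attention adds a learned relative position bias to the attention scores.
# Folding it into attn_mask keeps the math identical to the eager path while still
# allowing dispatch to the memory-efficient kernel.
position_bias = torch.randn(bsz, n_heads, q_len, kv_len, device=device, dtype=dtype)

attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=position_bias)
print(attn_output.shape)  # (2, 8, 128, 64)
```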

@fxmarty (Contributor, Author) commented Apr 2, 2024

@abdulfatir That's correct.

@abdulfatir

I can open a PR for T5 with SDPA then. Are there specific things that I should know of, or a reference that I can look at?

@fxmarty (Contributor, Author) commented Apr 2, 2024

@abdulfatir For sure, some specific things that are good to know:

  • pytorch/pytorch#108108 (is_causal=True may not do what you expect; see the sketch below)
  • pytorch/pytorch#110213 (You need …)

Example of a PR: #29108
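To make the first caveat concrete, here is a small illustrative sketch (variable names are made up): during decoding with a KV cache, the query length is 1 while the key length is longer, and is_causal=True would align the causal mask to the top-left corner and hide most of the cached keys, so an explicit mask (or no mask at all) is used instead.

```python
import torch
import torch.nn.functional as F

bsz, n_heads, head_dim = 1, 8, 64
kv_len = 16          # cached keys/values
q_len = 1            # single new token during decoding

query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_heads, kv_len, head_dim)
value = torch.randn(bsz, n_heads, kv_len, head_dim)
attention_mask = None  # e.g. no padding

# Only rely on is_causal when there is no explicit mask and more than one query token;
# otherwise is_causal=True would mask out most of the cached keys for the new token.
is_causal = attention_mask is None and q_len > 1

attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, is_causal=is_causal
)
```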

@ArthurZucker (Collaborator)

FYI, going forward we should rather use `def _update_causal_mask(self, attention_mask, input_tensor, cache_position):`, as it is more self-contained and easier to debug and maintain than the many paths in the attn_mask utils (a simplified sketch of what such a helper produces is shown below).
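Purely as an illustration, here is a self-contained sketch of the kind of additive 4D mask such a helper builds; the real method in the modeling code also handles SDPA-specific shortcuts, static caches, and more, so treat the function name and logic as an approximation, not the actual implementation:

```python
import torch


def _update_causal_mask_sketch(attention_mask, input_tensor, cache_position):
    """Build an additive 4D float causal mask of shape (bsz, 1, q_len, kv_len)."""
    bsz, q_len = input_tensor.shape[:2]
    kv_len = int(cache_position[-1]) + 1
    dtype = input_tensor.dtype
    min_value = torch.finfo(dtype).min

    # Start fully masked, then open up key positions at or before each query's cache position.
    mask = torch.full((q_len, kv_len), min_value, dtype=dtype)
    mask = mask.masked_fill(torch.arange(kv_len) <= cache_position.reshape(-1, 1), 0.0)
    mask = mask[None, None, :, :].expand(bsz, 1, q_len, kv_len).clone()

    if attention_mask is not None:  # 2D padding mask of shape (bsz, kv_len)
        mask = mask.masked_fill(attention_mask[:, None, None, :] == 0, min_value)
    return mask


# Example: a 5-token prefill with no padding.
hidden = torch.randn(2, 5, 16)
causal_mask = _update_causal_mask_sketch(None, hidden, torch.arange(5))
print(causal_mask.shape)  # (2, 1, 5, 5)
```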

@sayakpaul (Member)

Hey @abdulfatir, just wanted to check in: are you still working on a PR to add SDPA support for T5? It would tremendously help accelerate diffusion models that use T5.

@hackyon (Contributor) commented Apr 26, 2024

@His-Wardship - Thanks for your input. We have finally managed to merge #28802 (SDPA for BERT), so hopefully this will unlock SDPA implementations for other models. I plan to open a PR for RoBERTa and do some testing to see if there are any performance gains there.

Regarding DeBERTa, I'm somewhat skeptical about whether SDPA supports its disentangled attention mechanism. I think disentangled attention requires some extra operations that may not be supported in SDPA.

@abdulfatir

@amyeroberts Why was this issue closed? I believe not everything here has been addressed. Can we reopen it, please?

@fxmarty reopened this Jun 10, 2024
@amyeroberts (Collaborator)

@abdulfatir The issue was closed automatically when #30555 was merged because it had "Fixes #28005" in the PR description.

For any future PRs, writing something like "Address #PR_NUMBER" or "Part of #PR_NUMBER" will avoid this.

@OmarManzoor (Contributor) commented Jul 18, 2024

@fxmarty @amyeroberts Could you recommend a model to work on?

@amyeroberts (Collaborator)

@OmarManzoor (Contributor)

Okay, working on Albert.

@Bocchi-Chan2023

Is it possible to support moondream2? It's not a large model, but it is one of the best vision models.

@amyeroberts (Collaborator)

Hi @Bocchi-Chan2023, there is an implementation of moondream2 available on the Hub: https://huggingface.co/vikhyatk/moondream2

@avishaiElmakies

I will start working on OPT.

@AaronZLT (Contributor) commented Sep 4, 2024

Just a quick question: we are adding native SDPA support in the modeling code because PyTorch now provides the fused SDPA, right?

@EFord36 commented Sep 9, 2024

Hi,

Not sure if this is the best place to raise this, but I think it should be fairly straightforward to add support for DinoV2, since DinoV2 uses code copied from ViT, and ViT has since been updated (see here for comments about copied code, and #30555 for adding SDPA to ViT).

I'm using DinoV2, so I would be keen for any speedups!

Should I just flag that here, or is it worth opening a separate issue?

I'm also not sure whether I should be requesting BetterTransformer support for DinoV2 in Optimum instead/as well, but my impression is that native support here is preferable, if straightforward, so that it's used by default?

@amyeroberts (Collaborator)

@EFord36 Regarding adding SDPA, you can create a separate issue to track it, but we can regard this comment as a request for it to be added :) If you or anyone else in the community would like to open a PR to add this to the model, we'd be very happy to review it!

Yes, it'd be preferable for this to be added in transformers, as this would mean SDPA is available outside of optimum usage.

@avishaiElmakies

I can handle DinoV2.
