[Docs / BetterTransformer] Added more details about flash attention + SDPA #25265
younesbelkada merged 14 commits into huggingface:main
Conversation
The documentation is not available anymore as the PR was closed or merged.
stevhliu left a comment
Thanks for adding these additional details! 😄
> As of PyTorch 2.0, the attention fastpath is supported for both encoders and decoders. The list of supported architectures can be found [here](https://huggingface.co/docs/optimum/bettertransformer/overview#supported-models).
>
> For decoder-based models (e.g. GPT, T5, Llama, etc.), the `BetterTransformer` API will convert all attention operations to use the [`torch.nn.functional.scaled_dot_product_attention` method](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) (SDPA), which is available only from PyTorch 2.0 onwards.
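For context, a minimal sketch of the conversion described above (the checkpoint name is illustrative, and `optimum` must be installed):

```python
import torch
from transformers import AutoModelForCausalLM

# Load a supported decoder model (checkpoint name is just an example)
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to("cuda")

# Convert the attention layers to use torch.nn.functional.scaled_dot_product_attention
model = model.to_bettertransformer()
```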
Same comments for the rest of this section as in perf_infer_gpu_many.md (you can probably copy the changes over) :)
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Thanks a lot for the extensive review @stevhliu! 🎉
Thanks a lot, that is much better.
I'll release on the Optimum side to include huggingface/optimum#1225, which allows training with encoder models + SDPA as well.
It could be worth noting that a few models (Falcon, M4) are starting to have native SDPA support in transformers (though they may not dispatch to flash); see these discussions:
> For encoder models, the [`~PreTrainedModel.reverse_bettertransformer`] method reverts to the original model, which should be used before saving the model to use the canonical transformers modeling:
>
> ```python
> model = model.reverse_bettertransformer()
> model.save_pretrained("saved_model")
> ```
I think we should not make the distinction between encoder / decoder models when it comes to using `reverse_bettertransformer`.
For example, for encoder-decoder models (e.g. t5), both SDPA (in the decoder) and nested tensors (in the encoder) are used. So if one wants to save the model, they'll need to use `reverse_bettertransformer`.
To me, the distinction is rather that you can get speedups for inference with encoder models (since nested tensors are used), while for decoder models the speedup / dispatch to flash will only come (in PyTorch 2.0) for training, and for inference only at batch size = 1.
Thanks for the suggestion! I refactored that section a bit and removed the `reverse_bettertransformer` part, as it is relevant only for training (that section is for inference only).
> ```python
> # Use it for training or inference
> ```
>
> SDPA can also call [Flash-Attention](https://arxiv.org/abs/2205.14135) kernels under the hood. If you want to force the usage of Flash Attention, use [`torch.backends.cuda.sdp_kernel(enable_flash=True)`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel):
`torch.backends.cuda.sdp_kernel(enable_flash=True)` is not enough. You need `torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)` as below.
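Concretely, forcing the Flash Attention backend looks like this (a minimal sketch; `model` and `inputs` are assumed to be defined, e.g. as in the conversion example above):

```python
import torch

# Restrict SDPA to the Flash Attention kernel only, by disabling the
# math and memory-efficient fallback backends (PyTorch 2.0 API)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)
```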
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
stevhliu left a comment
Looks awesome! I added some minor comments to make it a bit easier to read, and if you could also copy the changes from perf_infer_gpu_many to their corresponding sections in perf_infer_gpu_one, that'd be great 🤗
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
ArthurZucker left a comment
Thanks for working on this! 🚀
What does this PR do?
As discussed offline with @LysandreJik.
This PR clarifies to users how it is possible to use Flash Attention as a backend for the most-used models in transformers. We have seen some questions from users asking whether it is possible to integrate Flash Attention into HF models, whereas you can already benefit from it when using `model.to_bettertransformer()`, leveraging the `BetterTransformer` API from 🤗 optimum. The information is based on the official documentation of `torch.nn.functional.scaled_dot_product_attention`. In the near future, we could also have a small blogpost explaining this as well.
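Putting the two pieces together, the pattern this PR documents is roughly the following (a hedged sketch; the checkpoint, tokenizer, and prompt are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any BetterTransformer-supported decoder model works
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to("cuda")

# Convert attention layers to SDPA via the BetterTransformer API from optimum
model = model.to_bettertransformer()

inputs = tokenizer("Hello, my llama is cute", return_tensors="pt").to("cuda")

# Force dispatch to the Flash Attention kernel
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```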
To do list / To clarify list:
Let me know if I missed anything else
cc @fxmarty @MKhalusova @stevhliu