Proper performant flex attention implementation #36103
bursteratom wants to merge 22 commits into huggingface:main
Conversation
vasqu
left a comment
A fan of flex attn :) Hope you don't mind the comments, but overall I'm pro this. Torchtune is optimized for training iirc; is creating the block mask OK for inference, speed-wise? I have no idea whether there are downsides/advantages to one or the other.
Not sure if this is relevant to the PR tbh, but benchmarks might be a good thing to look out for in the future.
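(Not a real benchmark, just a minimal sketch of how one could time the upstream `create_block_mask` call in isolation; the sequence length and device are illustrative.)

```python
import time
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    # standard causal constraint: a query attends only to earlier (or equal) positions
    return q_idx >= kv_idx

# crude timing of the BlockMask construction alone, to get a feel for whether
# paying this cost on every forward pass would hurt inference latency
torch.cuda.synchronize()
start = time.perf_counter()
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=4096, KV_LEN=4096, device="cuda")
torch.cuda.synchronize()
print(f"BlockMask creation took {(time.perf_counter() - start) * 1e3:.2f} ms")
```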
| """ | ||
| Inspired by torchtune's flex attention implementation | ||
| """ |
Nit: would move this to top of the file
Yep! And we forgot to add a licence!
@ArthurZucker can you point me to an example of how a proper licence string should be added?
I think something along these lines is meant: transformers/src/transformers/models/siglip2/modular_siglip2.py, lines 1 to 14 in d18d9c3
(you can add the torchtune or pytorch team imo, not sure how fine-grained it should be)
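For reference, the repo-wide header is the standard Apache 2.0 one, roughly along these lines (copyright year and holders to be adapted, e.g. crediting the torchtune/PyTorch authors):

```python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```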
Example for inference with flex attn: meta-pytorch/gpt-fast#196. At first glance I can spot a few things:
I think avoiding recreating the block mask is especially important here to avoid the memory/speed overhead, but I'm not sure since I haven't measured speed/memory myself. It might be more appropriate for a different PR, no idea; I just think inference in particular should be handled with more care.
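To make the "avoid recreating the block mask" point concrete, a purely hypothetical caching sketch against the upstream API (not this PR's code): the mask is built once for the maximum length and then reused at every decoding step.

```python
from torch.nn.attention.flex_attention import create_block_mask

class CachedCausalBlockMask:
    """Hypothetical helper: build the causal BlockMask once and reuse it,
    instead of calling create_block_mask on every decoding step."""

    def __init__(self, max_seq_len: int, device: str = "cuda"):
        def causal(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        # one-off construction cost, paid once at prefill time
        self.block_mask = create_block_mask(
            causal, B=None, H=None, Q_LEN=max_seq_len, KV_LEN=max_seq_len, device=device
        )

    def get(self):
        # decoding steps reuse the cached mask rather than rebuilding it
        return self.block_mask
```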
ArthurZucker
left a comment
Very much needed! Thanks a lot 🤗
cc @molbap as you had issues with this recently!
molbap
left a comment
Big fan of this work! Thanks a lot for tackling it. I'd be interested in benchmarks especially in a couple models like PaliGemma and models with bidirectional attention 👀
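(For context, with flex attention the bidirectional case is just a different `mask_mod`; a hypothetical sketch, where `seq_lengths` is an assumed per-sample valid-length tensor for a padded batch.)

```python
import torch

seq_lengths = torch.tensor([5, 8])  # hypothetical valid lengths of two padded samples

def causal_mask_mod(b, h, q_idx, kv_idx):
    # decoder-style: a query attends only to earlier (or equal) positions
    return q_idx >= kv_idx

def bidirectional_mask_mod(b, h, q_idx, kv_idx):
    # encoder-style: every valid token attends to every other valid token;
    # only padding positions are masked out
    return (q_idx < seq_lengths[b]) & (kv_idx < seq_lengths[b])
```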
Force-pushed from 516de45 to ae9a2b0
@vasqu @molbap @ArthurZucker I made some changes according to your input; wondering if you can give it another pass? Thank you!
Force-pushed from 7519b3c to 8c28c9d
vasqu
left a comment
Honestly, I think the core is fine, just a few nits and smaller things. Would leave inference for another PR :)
return create_block_causal_mask_flex(
    causal_mask_mod,
    batch_size,
    None,
    Q_LEN=total_seq_len,
    KV_LEN=total_seq_len,
    device=device,
)
I think my marking last time made it a bit confusing: passing all arguments as keyword arguments would be beneficial imo, especially the None arg (attention heads).
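i.e. something along these lines (the keyword names here mirror PyTorch's `create_block_mask` and are only assumed to match the helper's actual signature):

```python
return create_block_causal_mask_flex(
    mask_mod=causal_mask_mod,
    B=batch_size,
    H=None,  # None broadcasts the mask over attention heads
    Q_LEN=total_seq_len,
    KV_LEN=total_seq_len,
    device=device,
)
```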
Force-pushed from 432bafa to 7483314
ArthurZucker
left a comment
Very nice!
Just missing some docs / small perf comparisons!
Force-pushed from fb9c4c6 to 3d9377f
Force-pushed from 99e62c0 to c50468c
Force-pushed from 8aaeda8 to 864efb2
ArthurZucker
left a comment
LGTM! Could you just add some documentation with a perf comparison! 🤗
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@ArthurZucker thank you! I will add the doc and perf comparison shortly! I'm wondering where in the
Let's merge for now IMO and you can open a new PR for the doc!
See #36643, which was needed to fix the conflicts.
We can close, the PR is merged! 🤗
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the `Ready for review` button (at the bottom of the PR page).
What does this PR do?
The current flex attention implementation does not take advantage of the performance and memory efficiency promised in this official blog post from PyTorch.
This PR, inspired by https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py, rectifies that by always compiling flex attention and by using the sparsity-optimised BlockMask data type for attention masking in lieu of regular torch tensors. Performance and memory utilization are now comparable to flash attention.
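As a rough illustration of the underlying pattern (a minimal sketch against PyTorch's public flex attention API, not the exact code added in this PR; shapes, dtype and device are arbitrary):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# compiling flex_attention is what unlocks the fused kernel and the memory savings
flex_attention_compiled = torch.compile(flex_attention, dynamic=False)

def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 2, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# the sparsity-aware BlockMask replaces a dense boolean attention-mask tensor
block_mask = create_block_mask(causal_mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
out = flex_attention_compiled(q, k, v, block_mask=block_mask)
```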
BlockMask creation has been implemented for the following models:
Let's add support for other models in a separate PR
Before submitting
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.