[FA-2] Add Flash Attention to Phi #27661
Conversation
@susnato can you try to run that test multiple times? Sometimes it is flaky. Apart from that, the changes look great on my end!
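For reference, one way to rerun a single test repeatedly from Python to check for flakiness; this is only a sketch, and the test file path and `-k` filter are assumptions about where the Phi flash-attention tests live in the transformers repo:

```python
# Sketch: rerun the suspected flaky test several times and stop on the first failure.
# Assumes pytest is installed and this is run from the transformers repository root.
import pytest

for run in range(30):
    exit_code = pytest.main(
        ["tests/models/phi/test_modeling_phi.py", "-k", "flash_attn", "-x", "-q"]
    )
    if exit_code != 0:
        print(f"Test failed on run {run + 1}")
        break
```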
Hi @younesbelkada, I ran that test 30 times and every time it failed! Shouldn't the inference test fail too, if the generation test fails? 😅
Hmm yes, correct. What I did for llama was to overwrite the test, as can be seen here: https://github.com/huggingface/transformers/blob/main/tests/models/llama/test_modeling_llama.py#L392 using a real checkpoint. It would be great if you could do the same and test that the next 10 tokens are the same (make sure to use
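A minimal sketch of what such an overridden test could look like, comparing the first tokens generated with and without Flash Attention 2. The `susnato/phi-1_5_dev` checkpoint name is taken from later in this thread, the test name is hypothetical, and the `attn_implementation` argument assumes a recent transformers release:

```python
import unittest

import torch
from transformers import AutoTokenizer, PhiForCausalLM
from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow


class PhiFlashAttention2IntegrationTest(unittest.TestCase):
    @require_flash_attn
    @require_torch_gpu
    @slow
    def test_flash_attn_2_matches_eager_generation(self):
        # Checkpoint name taken from later in this thread.
        checkpoint = "susnato/phi-1_5_dev"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")

        # Reference run with the default (eager) attention implementation.
        model = PhiForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16).to("cuda")
        expected_ids = model.generate(**inputs, max_new_tokens=10, do_sample=False)

        # Same checkpoint loaded with Flash Attention 2.
        model_fa = PhiForCausalLM.from_pretrained(
            checkpoint, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
        ).to("cuda")
        fa_ids = model_fa.generate(**inputs, max_new_tokens=10, do_sample=False)

        # The next 10 greedy tokens should be identical across implementations.
        self.assertTrue(torch.equal(expected_ids, fa_ids))
```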
Hi @younesbelkada, thanks a lot for the advice! All the flash attention tests are passing now. 🤗
younesbelkada
left a comment
Truly amazing work @susnato! Thanks a lot for this great contribution.
`# in fp32. (LlamaRMSNorm handles it correctly)` →
`# in fp32. (PhiRMSNorm handles it correctly)`
Hey, we don't have a `PhiRMSNorm`, only `nn.LayerNorm` in the attention layer, so I am removing this part of the line: "(PhiRMSNorm handles it correctly)".
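For context, the comment being edited is about casting the attention inputs back to half precision before calling the Flash Attention kernel. A rough standalone illustration of that logic, paraphrased rather than the exact transformers code:

```python
import torch
from torch import Tensor


def cast_qkv_for_flash_attention(
    query: Tensor, key: Tensor, value: Tensor, target_dtype: torch.dtype = torch.float16
) -> tuple[Tensor, Tensor, Tensor]:
    """Cast query/key/value back to a Flash-Attention-compatible dtype.

    Flash Attention kernels only support fp16/bf16. If the hidden states were
    silently upcast to float32 (for example by modules kept in fp32, such as the
    plain nn.LayerNorm in Phi's attention block), they must be cast back before
    calling the kernel; that is what the edited comment refers to.
    """
    if query.dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)
    return query, key, value
```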
ArthurZucker
left a comment
LGTM other than the small comments
That's not something we want to remove; we should also have flash attention support in Persimmon, and it should be the same as this.
Actually, there is another PR which adds FA support for Persimmon.
Should I add `self.causal=True` in `PersimmonAttention` so that we can keep this `# Copied from` statement?
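For context, the `# Copied from` check requires the Phi attention code to stay textually in sync with the Persimmon code it was copied from, so the flag has to exist in both classes. A heavily abbreviated sketch of the idea, with the attribute name written as in the comment above and the class bodies reduced to the relevant line:

```python
import torch.nn as nn


class PersimmonAttention(nn.Module):
    """Source class: defining the flag here keeps the copied Phi class in sync."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.causal = True  # attribute name as used in the discussion above


# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonAttention with Persimmon->Phi
class PhiAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.causal = True
```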
younesbelkada
left a comment
Thanks! We should be able to merge after accepting the suggestion below.
### Expected speedups

Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `susnato/phi-1_dev` checkpoint and the Flash Attention 2 version of the model, using a sequence length of 2048.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/phi_1_speedup_plot.jpg">
</div>
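To put the speedup comparison in context, loading the model with Flash Attention 2 roughly looks like the following. This is only a sketch: the checkpoint name comes from the suggested docs above, and the `attn_implementation` argument assumes a recent transformers release.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "susnato/phi-1_dev"  # dev checkpoint named in the suggested docs above

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("def print_prime(n):", return_tensors="pt").to("cuda")

# Flash Attention 2 needs half-precision weights and a CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```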
Hi @younesbelkada, I have pushed the suggestion you asked for.
BTW, when is the next release date for
Force-pushed from 8b9edbc to 9e22498
Just force-pushed the branch along with the changes. @younesbelkada
younesbelkada
left a comment
Thanks! We'll probably be able to make it for the next release, wdyt @ArthurZucker?
@pytest.mark.flash_attn_test
@slow
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->susnato/phi-1_5_dev
def test_flash_attn_2_generate_padding_right(self):
ArthurZucker
left a comment
LGTM. cc @fxmarty for your SDPA PR, which will need rebasing I think.
yep


What does this PR do?
This PR adds Flash Attention to Phi.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
cc: @younesbelkada, @ArthurZucker