
[FA-2] Add Flash Attention to Phi #27661

Merged
ArthurZucker merged 6 commits into huggingface:main from susnato:flash_attn_phi
Dec 7, 2023
Conversation

@susnato
Contributor

@susnato susnato commented Nov 22, 2023

What does this PR do?

This PR adds Flash Attention to Phi.
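For context, a rough usage sketch of what this enables, assuming the `susnato/phi-1_5_dev` checkpoint referenced later in this thread and the `attn_implementation="flash_attention_2"` loading argument (the exact flag may differ depending on the transformers version):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "susnato/phi-1_5_dev"  # dev checkpoint mentioned in this thread
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Flash Attention 2 kernels require fp16/bf16 weights and a CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0]))
```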

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc: @younesbelkada, @ArthurZucker

@susnato
Contributor Author

susnato commented Nov 22, 2023

All the FA tests pass except the test_flash_attn_2_generate_padding_right.

(Screenshot, 2023-11-23: test run showing test_flash_attn_2_generate_padding_right failing)

This is odd given that test_flash_attn_2_inference_padding_right passes, as does test_flash_attn_2_generate_left_padding.

@younesbelkada
Contributor

@susnato can you try running that test multiple times? Sometimes it is flaky. Apart from that, the changes look great on my end!

@susnato
Contributor Author

susnato commented Nov 22, 2023

Hi @younesbelkada, I ran that test 30 times and every time it failed!

Shouldn't the inference test fail too, if the generation test fails? 😅

@younesbelkada
Contributor

Hmm yes, correct. What I did for Llama was to overwrite the test, as can be seen here: https://github.com/huggingface/transformers/blob/main/tests/models/llama/test_modeling_llama.py#L392, using a real checkpoint. It would be great if you could do the same and check that the next 10 tokens are the same (make sure to use do_sample=False).
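A hedged sketch of what such an override could look like (the checkpoint name and PhiForCausalLM come from elsewhere in this thread; this is not the exact test that was merged):

```python
import unittest

import torch
from transformers import AutoTokenizer, PhiForCausalLM


class PhiFlashAttention2Test(unittest.TestCase):
    def test_flash_attn_2_generate_padding_right(self):
        checkpoint = "susnato/phi-1_5_dev"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token

        texts = ["hi", "Hello this is a very long sentence"]
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to("cuda")

        # Reference generation with the default attention implementation.
        model = PhiForCausalLM.from_pretrained(
            checkpoint, torch_dtype=torch.float16
        ).to("cuda")
        output_native = model.generate(**inputs, max_new_tokens=10, do_sample=False)

        # Same checkpoint and inputs, loaded with Flash Attention 2.
        model = PhiForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
        ).to("cuda")
        output_fa_2 = model.generate(**inputs, max_new_tokens=10, do_sample=False)

        # Greedy decoding, so the next 10 tokens should match exactly.
        self.assertListEqual(
            tokenizer.batch_decode(output_native),
            tokenizer.batch_decode(output_fa_2),
        )
```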

@susnato
Contributor Author

susnato commented Nov 23, 2023

Hi @younesbelkada, thanks a lot for the advice! All the flash attention tests are passing now. 🤗

(Screenshot, 2023-11-23: test run showing all flash attention tests passing)

Contributor

@younesbelkada younesbelkada left a comment


Truly amazing work @susnato! Thanks a lot for this great contribution.

Contributor


Suggested change
# in fp32. (LlamaRMSNorm handles it correctly)
# in fp32. (PhiRMSNorm handles it correctly)

Contributor Author


Hey, we don't have a PhiRMSNorm, only nn.LayerNorm in the attention layer, so I'm removing that part of the line (PhiRMSNorm handles it correctly).

Contributor


ok makes sense!
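For context, the comment being edited sits in the flash-attention forward's dtype handling. A minimal sketch of the idea, adapted from the Llama-style pattern rather than the exact merged Phi code (the helper name here is illustrative): flash-attn kernels only run in fp16/bf16, but a layernorm kept in fp32 (e.g. under PEFT training) silently upcasts the hidden states, so the attention states are cast back down before calling the kernel.

```python
import torch


def _cast_states_for_flash_attn(query_states, key_states, value_states, target_dtype):
    """Illustrative helper: if layernorm upcast the hidden states to fp32,
    cast q/k/v back to the model's compute dtype (fp16/bf16) before flash-attn."""
    if query_states.dtype == torch.float32:
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states
```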

Collaborator

@ArthurZucker ArthurZucker left a comment


LGTM other than the small comments

Comment thread docs/source/en/model_doc/phi.md Outdated
Collaborator


That's not something we want to remove; we should also have flash attention support in Persimmon, which should be the same as this.

Contributor Author


Actually, there is another PR which is adding the FA support for Persimmon.

Should I add the self.causal=True in the PersimmonAttention so that we can keep this # Copied from statement?

Collaborator


Yeah! 🤗

Contributor

@younesbelkada younesbelkada left a comment


Thanks! We should be able to merge after accepting the suggestion below.

Comment thread docs/source/en/model_doc/phi.md Outdated
Contributor


Suggested change
### Expected speedups
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `susnato/phi-1_dev` checkpoint and the Flash Attention 2 version of the model, using a sequence length of 2048.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/phi_1_speedup_plot.jpg">
</div>
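For reference, a rough sketch of how such a speedup comparison could be measured (the checkpoint name and the 2048 sequence length come from the suggested text above; `attn_implementation="eager"` as the native path is an assumption, and this is not the script that produced the plot):

```python
import time

import torch
from transformers import AutoModelForCausalLM


def time_forward(attn_implementation, seq_len=2048, iters=10):
    model = AutoModelForCausalLM.from_pretrained(
        "susnato/phi-1_dev",
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,
    ).to("cuda")
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device="cuda")
    with torch.no_grad():
        model(input_ids)  # warm-up
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(input_ids)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


print("native:", time_forward("eager"))
print("flash_attention_2:", time_forward("flash_attention_2"))
```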

@susnato
Contributor Author

susnato commented Dec 6, 2023

Hi @younesbelkada, I have pushed the change you suggested.

@susnato
Contributor Author

susnato commented Dec 6, 2023

BTW when is the next release date for transformers?

Contributor

@younesbelkada younesbelkada left a comment


Hi @susnato, sorry, one last thing before merging: can you apply changes similar to this commit: 93fe356 and merge with the upstream main branch?

@susnato
Contributor Author

susnato commented Dec 6, 2023

Just force-pushed the branch along with the changes. @younesbelkada

Contributor

@younesbelkada younesbelkada left a comment


Thanks!
We'll probably be able to make it for the next release, wdyt @ArthurZucker?

@pytest.mark.flash_attn_test
@slow
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->susnato/phi-1_5_dev
def test_flash_attn_2_generate_padding_right(self):
Contributor


Nice!

Collaborator

@ArthurZucker ArthurZucker left a comment


LGTM cc @fxmarty for your sdpa PR that will need rebasing I think

@ArthurZucker ArthurZucker merged commit f84d85b into huggingface:main Dec 7, 2023
@fxmarty
Contributor

fxmarty commented Dec 7, 2023

yep

