feat: add flash_attn 2 to bert #27478

Closed
chiennv2000 wants to merge 3 commits into huggingface:main from chiennv2000:main

Conversation


@chiennv2000 chiennv2000 commented Nov 14, 2023

Feat: Add flash attention option for BERT
Usage:
model = BertModel.from_pretrained('bert-base-uncased', torch_dtype=torch.bfloat16, use_flash_attention_2=True)

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@ArthurZucker and @younesbelkada

Contributor

@younesbelkada younesbelkada left a comment


Thanks a lot for your PR! In principle this looks great!
Many architectures use BertAttention with # Copied from, so all of these architectures could benefit from FA-2 for free; however, you will need to apply _supports_flash_attn_2 = True on all of them. You need to:
1- run make fix-copies
2- on the modified architectures, add the flag above and copy-paste the BertFlashAttention class into each of them (with modified names).
Would you be happy to address these changes? Otherwise I'm happy to help you!
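As a rough illustration of step 2, the per-architecture change looks roughly like the sketch below. The class names and docstring are hypothetical stand-ins, not the actual transformers source:

```python
# Hypothetical sketch: each architecture that copies BertAttention declares
# the flag on its PreTrainedModel subclass and gains a renamed copy of the
# flash-attention class. Names here are illustrative only.
class RobertaPreTrainedModel:
    # Advertises to the model-loading machinery that this architecture
    # has a FlashAttention-2 attention implementation available.
    _supports_flash_attn_2 = True


class RobertaFlashAttention2:
    """Copy of BertFlashAttention2 with Bert->Roberta renames (body omitted)."""
    pass
```

Running make fix-copies then propagates the `# Copied from` blocks so each mirrored architecture stays in sync.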

Comment thread src/transformers/models/bert/modeling_bert.py
@chiennv2000
Author

Thanks a lot for your review and your suggestions @younesbelkada.
But I'm not really familiar with the make fix-copies command. Can you guide me on how to do that?

@chiennv2000
Author

I appreciate your feedback, and I'm happy to receive your assistance in implementing these changes.
If you could help me with the other architectures, that would be fantastic. Additionally, I'm open to collaborating on extending this to the RoBERTa and XLM-R models. @younesbelkada

@younesbelkada
Contributor

Perfect, thanks!
As a first step, can you simply run make fix-copies and push the changes here? Then we'll take it from there!

@chiennv2000
Author

Thanks @younesbelkada, I did it.

@huggingface huggingface deleted a comment from github-actions Bot Dec 14, 2023
@ArthurZucker
Collaborator

cc @younesbelkada

@huggingface huggingface deleted a comment from github-actions Bot Jan 8, 2024
@younesbelkada
Contributor

I haven't had time to look into it properly, will do it ASAP!

@kevinhu
Contributor

kevinhu commented Jan 31, 2024

Any updates on getting this PR merged?

@hackyon
Contributor

hackyon commented Feb 7, 2024

Hello there!

I'm working on integrating scaled_dot_product_attention into BERT in #28802, and there might be some merge conflicts with this change. Notably, if my changes go through, we can get rid of most of the downstream fix-copies dependencies.

Let me know if you have any questions. Happy to discuss and/or chat on the best way forward if necessary.

@github-actions
Contributor

github-actions Bot commented Mar 3, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@loswald

loswald commented Apr 21, 2024

This would be a lifesaver for me, I hope merging this is prioritized! cc @younesbelkada
@chiennv2000 what kind of speedups are you observing with this?

@hackyon
Contributor

hackyon commented Apr 22, 2024

> This would be a lifesaver for me, I hope merging this is prioritized! cc @younesbelkada @chiennv2000 what kind of speedups are you observing with this?

@loswald - you can see a quick estimate of the speedups in #28802. The PyTorch SDPA implementation uses FA2 under the hood (if your hardware supports it). The PR is ready, but we're just waiting on the HF team to merge it.
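For reference, here is a minimal standalone sketch (not from the PR) of the PyTorch SDPA call that the #28802 integration dispatches to. On supported CUDA hardware, PyTorch can route this same call to a FlashAttention-2 kernel automatically; on CPU it falls back to the math backend. The tensor sizes are arbitrary illustrative values:

```python
import torch
import torch.nn.functional as F

# Arbitrary illustrative sizes: (batch, num_heads, seq_len, head_dim).
batch, heads, seq, dim = 2, 12, 128, 64
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

# One fused call replaces the manual softmax(QK^T / sqrt(d)) @ V pattern;
# the backend (math, mem-efficient, or FlashAttention-2) is picked at runtime.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 12, 128, 64])
```

Because the kernel choice is made inside PyTorch, model code using SDPA needs no per-hardware branching, which is part of why it reduces the fix-copies burden discussed above.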

6 participants