feat: add flash_attn 2 to bert #27478
Conversation
younesbelkada left a comment
Thanks a lot for your PR! In principle this looks great!
Many architectures use BertAttention with # Copied from, so all of those architectures could benefit from FA-2 for free; however, you will need to set _supports_flash_attn_2 = True on each of them. You need to:
1- run make fix-copies
2- on the modified architectures, add the flag above and copy-paste the BertFlashAttention class into each of them (with modified names).
Would you be happy to address these changes? Otherwise happy to help you!
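To illustrate the pattern being requested, here is a minimal sketch using stand-in classes (not the real transformers classes, which are not reproduced here): the `_supports_flash_attn_2` flag is a class-level attribute that the library checks before dispatching to the Flash Attention 2 code path, and the attention class would be duplicated per architecture by `make fix-copies` with renamed prefixes.

```python
# Hedged sketch: stand-in classes showing the shape of the change,
# not the actual transformers implementation.

class BertPreTrainedModelSketch:
    # Class-level flag the library checks before enabling the FA-2 path.
    _supports_flash_attn_2 = True


class BertFlashAttention2Sketch:
    """Stand-in for the per-architecture flash-attention class that
    `make fix-copies` would propagate (with modified names)."""
    pass


print(BertPreTrainedModelSketch._supports_flash_attn_2)  # True
```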
Thanks a lot for your review and your suggestions @younesbelkada.
Perfect, thanks!
Thanks @younesbelkada, I did it.
Didn't have time to properly look into it, will do it ASAP!
Any updates on getting this PR merged?
Hello there! I'm working on integrating scaled_dot_product_attention into BERT in #28802, and there might be some merge conflicts with this change. Mostly, if my changes go through, we can get rid of most of the downstream dependencies from fix-copies. Let me know if you have any questions. Happy to discuss and/or chat about the best way forward if necessary.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This would be a lifesaver for me, I hope merging this is prioritized! cc @younesbelkada
@loswald - you can see a quick estimate of the speedups in #28802. The PyTorch SDPA implementation uses FA2 under the hood (if your hardware supports it). The PR is ready but we're just waiting on the HF team to merge it.
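For context on what SDPA computes, here is a minimal NumPy sketch of the scaled-dot-product-attention math; PyTorch's torch.nn.functional.scaled_dot_product_attention evaluates the same expression, dispatching to a FlashAttention-2 kernel when hardware and dtype allow. The shapes and helper name below are illustrative only.

```python
import numpy as np

def sdpa(q, k, v):
    # Attention scores (seq_q, seq_k), scaled by 1/sqrt(head_dim).
    scores = q @ k.T / np.sqrt(q.shape[-1])
    # Numerically stable softmax over the key dimension.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
out = sdpa(q, k, v)
print(out.shape)  # (4, 8)
```

Flash attention computes this same result without materializing the full (seq_q, seq_k) score matrix, which is where the memory and speed gains come from.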
Feat: Add flash attention option for BERT
Usage:
model = BertModel.from_pretrained('bert-base-uncased', torch_dtype=torch.bfloat16, use_flash_attention_2=True)
@ArthurZucker and @younesbelkada