ModernBERT FlexAttention #35423
Conversation
I think this is very interesting, but only if the performance rivals that of e.g. SDPA. I see your issue here: meta-pytorch/attention-gym#95, which also shows that with compilation, Flex Attention outperforms SDPA in a lot of common cases, so we would have to introduce the compilation option for Flex Attention in transformers. I'd love to see that implemented, though!
Yeah, when compiled, FlexAttention is generally faster than SDPA and uses much less memory in my tests.
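A rough, illustrative way to check this kind of claim locally (not from this PR; shapes, dtype, and iteration counts are arbitrary):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 4, 16, 4096, 64
q, k, v = [torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3)]

flex_compiled = torch.compile(flex_attention)

def bench(fn, iters=20):
    # Simple CUDA-event timing: warm up first (this also triggers compilation
    # for the flex path), then average over a few iterations.
    for _ in range(3):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("SDPA ms:", bench(lambda: F.scaled_dot_product_attention(q, k, v)))
print("Flex ms:", bench(lambda: flex_compiled(q, k, v)))
```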
Any ideas on how transformers plans to support FlexAttention compilation? I see that no models have it implemented for now (I might be wrong here).
ArthurZucker
left a comment
Looks alright apart from the nightly requirement!
Regarding compilation, we do support compile, and it is used by default when you call generate with cache_implementation="static".
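For reference, the flag Arthur mentions looks like this in practice (a minimal sketch; the checkpoint is an arbitrary decoder-only example, unrelated to this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Arbitrary small decoder-only checkpoint, just to illustrate the flag; ModernBERT
# itself is an encoder and does not go through generate().
model_id = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("FlexAttention is", return_tensors="pt").to(model.device)
# Per the comment above, cache_implementation="static" is the path where transformers
# switches to a static KV cache and can torch.compile the decoding forward pass.
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```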
I think the nightly requirement might be resolved by specifying the
Torch 2.6 has an RC release now, so the dependency issues should not be a big problem anymore. Looking forward to this being merged.
```python
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

    flex_attention = torch.compile(flex_attention)
```
These lines are needed to actually use the FlexAttention kernels, but utils/modular_model_converter.py does not allow them in the converted file.
Let me know if there is a better way to do this.
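For illustration only, the compiled function typically ends up being used in an attention forward along these lines (the helper name and signature here are a sketch, not the PR's exact code):

```python
import torch
from torch.nn.attention.flex_attention import BlockMask, flex_attention

# Compile once at import time so every attention layer reuses the same compiled kernel.
flex_attention = torch.compile(flex_attention, dynamic=False)

def flex_attention_forward(query, key, value, block_mask: BlockMask):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    attn_output = flex_attention(query, key, value, block_mask=block_mask)
    # Return (batch, seq_len, num_heads, head_dim), the layout HF modeling code usually expects.
    return attn_output.transpose(1, 2).contiguous()
```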
cc @ArthurZucker @tomaarsen (sorry for the ping)
any ideas on how to support compiling the flex_attention function in transformers?
The failing CI is related to these two lines; there is no clear way to support compile cleanly.
Will have a look in a bit, sorry about that!
With the release of PyTorch 2.6, it is now possible to use FlexAttention with ModernBERT without a nightly requirement.
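Once merged, usage might look roughly like the sketch below; attn_implementation="flex_attention" is assumed to be the switch exposed by this PR:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flex_attention",  # assumed name of the backend added by this PR
).to("cuda")

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits

# Decode the prediction at the [MASK] position.
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
print(tokenizer.decode(logits[0, mask_pos].argmax()))
```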
ArthurZucker
left a comment
Okay! Sorry for the delay, pieces of the puzzle are coming one at a time 🤣
#36103 should add something similar, so we can take inspiration from it IMO.
If possible we want the least amount of changes in the ModernBERT modeling file, but otherwise this sounds great! Let's follow a bit what is done in the linked PR, potentially using the same function to create the block mask!
Sorry @staghado for the delay!
I skimmed through the linked PR; flex_attention does not get compiled automatically, so the compilation call is needed, as you can see here. There is also hardcoded handling for GQA, softcapping, etc., which is not universal across implementations. I would prefer keeping the ModernBERT implementation separate for now, at least until the integration is more mature.
Hi @staghado, my PR https://github.com/huggingface/transformers/pull/36103/files does add torch compile for flex attention, it is done in the
I believe both score mod and block mask are needed: the score mod can modify attention scores in ways that can't be done with a block mask. The intention is to perform attention masking solely using the block mask, while the score mod is only used for non-masking score modification. Would love to hear more about your thoughts!
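To make the block mask vs. score mod distinction concrete, here is an illustrative (not PR-specific) sketch using the public torch.nn.attention.flex_attention API; the bidirectional sliding window and tanh soft-cap are arbitrary examples:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 8, 1024, 64
q, k, v = [torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3)]

WINDOW = 128  # e.g. one of ModernBERT's local-attention layers

def sliding_window_mask(b, h, q_idx, kv_idx):
    # Pure masking: which positions may attend to which. Bidirectional window,
    # i.e. no causal constraint, as an encoder would use.
    return (q_idx - kv_idx).abs() <= WINDOW

block_mask = create_block_mask(sliding_window_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

def soft_cap(score, b, h, q_idx, kv_idx):
    # Non-masking score modification: rescaling scores is something a block mask cannot express.
    return 20.0 * torch.tanh(score / 20.0)

out = flex_attention(q, k, v, score_mod=soft_cap, block_mask=block_mask)
```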
Hi @bursteratom, another thing is that to get the performance promised by FlexAttention, we need to compile both
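A sketch of what "compile both" could look like, assuming the two pieces meant here are flex_attention itself and the block-mask construction (that reading is my assumption):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Compile the attention kernel itself...
flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
# ...and wrap the (otherwise fairly expensive) block-mask construction as well.
create_block_mask_compiled = torch.compile(create_block_mask)

def sliding_window_mask(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx).abs() <= 128

S = 2048
block_mask = create_block_mask_compiled(sliding_window_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
q = k = v = torch.randn(1, 8, S, 64, device="cuda", dtype=torch.float16)
out = flex_attention_compiled(q, k, v, block_mask=block_mask)
```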
Hey, #36643 was merged; it just needs to integrate the refactor points IMO and I'll have another look!
Looks like the integration is too focused on decoder-only models (GQA=True at the API level): create_block_mask is renamed to create_block_causal_mask_flex, and a score_mod is always passed even if not needed. This PR is out of scope.
What does this PR do?
This PR adds FlexAttention support for ModernBERT.
Note:
- The current version requires one of the latest torch nightlies (e.g. 2.6.0.dev20241112).
- Currently, transformers does not allow compiling the flex_attention function, IIUC.