
ModernBERT FlexAttention #35423

Closed
staghado wants to merge 13 commits into huggingface:main from staghado:flexattn-modernbert

Conversation

@staghado
Contributor

What does this PR do?

This PR adds FlexAttention support for ModernBERT:

  • Combines sliding window and document masking to implement the alternating local/global attention pattern in ModernBERT (a sketch of the combined mask follows below)
  • Mask creation is expensive, so the two masks are cached at the model level and then re-used across layers.
  • Similar to the FA2 path, it works directly on the unpadded sequences.
  • Re-uses the existing ModernBertRotaryEmbedding to avoid requiring FA2.

Note:
The current version requires one of the latest torch nightlies (e.g. 2.6.0.dev20241112).
Currently, transformers does not allow compiling the flex_attention function, IIUC.
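
As a minimal, hypothetical sketch (not the code in this PR) of the combined mask described above: a document mask keeps attention inside each concatenated sequence, a sliding-window mask restricts the local layers, and both block masks are built once and cached at the model level. document_ids, LOCAL_WINDOW, and total_tokens are illustrative placeholders.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

LOCAL_WINDOW = 128     # ModernBERT-style local attention window
total_tokens = 4096    # length of the unpadded, concatenated batch
# one document id per token, e.g. [0, 0, 0, 1, 1, 2, ...]
document_ids = torch.zeros(total_tokens, dtype=torch.long, device="cuda")

def document_mask(b, h, q_idx, kv_idx):
    # tokens may only attend within their own document; this is what lets
    # FlexAttention work on unpadded sequences, similar to the FA2 path
    return document_ids[q_idx] == document_ids[kv_idx]

def sliding_window_document_mask(b, h, q_idx, kv_idx):
    # local layers: same document AND within +/- LOCAL_WINDOW // 2 positions
    half = LOCAL_WINDOW // 2
    in_window = (q_idx - kv_idx <= half) & (kv_idx - q_idx <= half)
    return in_window & document_mask(b, h, q_idx, kv_idx)

# build both masks once and re-use them across global / local layers
global_mask = create_block_mask(document_mask, None, None, total_tokens, total_tokens)
local_mask = create_block_mask(sliding_window_document_mask, None, None, total_tokens, total_tokens)

# q, k, v have shape (1, num_heads, total_tokens, head_dim) for the unpadded batch
# out = flex_attention(q, k, v, block_mask=local_mask)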

@tomaarsen
Member

I think this is very interesting, but only if the performance rivals that of e.g. SDPA. I see your issue here: meta-pytorch/attention-gym#95, which also shows that with compilation, FlexAttention outperforms SDPA in a lot of common cases, so we would have to introduce the compilation option for FlexAttention in transformers before this is viable, I think.

I'd love to see that implemented, though!

  • Tom Aarsen

@staghado
Contributor Author

Yeah, when compiled, FlexAttention is generally faster than SDPA and has much lower memory usage in my tests.
How would we go about adding the compilation option? Would a flag suffice in this case?

@staghado
Contributor Author

staghado commented Jan 7, 2025

Any ideas on how transformers plans to support FlexAttention compilation? I see that no models have it implemented for now (I might be wrong here).
cc @tomaarsen @ArthurZucker

@ArthurZucker
Collaborator

Looks alright apart from the nightly requirement!
Regarding compilation, we do support compile, and it is used by default when you use generate with cache_implementation="static".
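
For context, a hedged sketch of the compile path being referred to; it applies to generative models rather than to ModernBERT itself (which is an encoder), and the checkpoint name is only a placeholder.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder causal LM, any generative model works
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

inputs = tok("FlexAttention is", return_tensors="pt").to(model.device)
# a static KV cache keeps shapes fixed, which is what makes the compiled path usable
out = model.generate(**inputs, max_new_tokens=16, cache_implementation="static")
print(tok.decode(out[0], skip_special_tokens=True))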

Comment threads on src/transformers/models/modernbert/modeling_modernbert.py (outdated)
@staghado
Contributor Author

staghado commented Jan 9, 2025

I think the nightly requirement might be resolved by specifying the BLOCK_SIZE argument in create_block_mask; I will verify that.
For compilation, I meant supporting flex_attention = torch.compile(flex_attention, dynamic=False), which is very important for performance, as mentioned here.
It also looks like the FlexAttention in Gemma was removed?
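
To illustrate both points above (a hypothetical sketch, not the PR's final code), the block size can be pinned explicitly and the attention entry point compiled; seq_len and document_ids are placeholders.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# compiling the public entry point is what gives the fused kernel its speed
flex_attention = torch.compile(flex_attention, dynamic=False)

seq_len = 4096
document_ids = torch.zeros(seq_len, dtype=torch.long, device="cuda")

def document_mask(b, h, q_idx, kv_idx):
    return document_ids[q_idx] == document_ids[kv_idx]

# pass BLOCK_SIZE explicitly instead of relying on the default
block_mask = create_block_mask(
    document_mask, None, None, seq_len, seq_len, BLOCK_SIZE=128
)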

@neavo

neavo commented Jan 11, 2025


Torch 2.6 now has an RC release, so the dependency issues should not be a big problem anymore. Looking forward to this being merged.
Before that, could you please update this PR's branch so that its code is compatible with the current main branch?
I would like to give it a try.

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

    flex_attention = torch.compile(flex_attention)

Contributor Author

These are needed to actually use the FlexAttention kernels, but utils/modular_model_converter.py does not allow them in the converted file.
Let me know if there is a better way to do this.

Contributor Author

cc @ArthurZucker @tomaarsen (sorry for the ping)

Contributor Author

Any ideas on how to support compiling the flex_attention function in transformers?

Contributor Author

The failing CI is related to these two lines; there is no clear way to support compile cleanly.

Collaborator

Will have a look in a bit, sorry about that!

@staghado staghado requested a review from ArthurZucker January 15, 2025 15:50
@staghado
Contributor Author

staghado commented Jan 30, 2025

With the release of PyTorch 2.6, it is now possible to use FlexAttention with ModernBERT without a nightly requirement.

@ArthurZucker
Collaborator

Okay! Sorry for the delay, the pieces of the puzzle are coming one at a time 🤣
#36103 should add something similar, so we can take inspiration from it IMO.
If possible, we want the least amount of changes in the ModernBERT modeling file, but otherwise this sounds great! Let's follow a bit what is done in the linked PR, potentially using the same function to create the block mask!

@ArthurZucker
Collaborator

Sorry @staghado for the delay!

@staghado
Contributor Author

staghado commented Feb 10, 2025

I skimmed through the linked PR; flex_attention does not get compiled automatically, so the compilation call is needed, as you can see here. And there is hardcoded logic for GQA, softcapping, etc. which is not universal across implementations.
Another thing is that the integration uses score_mod, which is less performant than a block mask, as mentioned here.

I would prefer keeping the ModernBERT implementation separate for now, at least until the integration is more mature.

@bursteratom
Contributor

bursteratom commented Feb 10, 2025

Hi @staghado, my PR https://github.com/huggingface/transformers/pull/36103/files does add torch.compile for flex attention. It is done in integrations/flex_attention.py so that minimal modifications are required in the various modeling_{model type}.py files for them to take advantage of FlexAttention.

I believe both score mod and block mask are needed - the score mod can modify attention scores in ways that can't be done using block mask. The intention is to perform attention masking solely using block mask, while score mod is only used for non-masking score modification.
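
As a hypothetical side-by-side of the two hooks being discussed (placeholder values): a score_mod rescales attention scores and cannot skip any compute, while a mask_mod compiled into a BlockMask lets FlexAttention skip fully masked blocks.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

SOFTCAP, WINDOW, SEQ_LEN = 50.0, 128, 1024

def softcap_score_mod(score, b, h, q_idx, kv_idx):
    # non-masking score modification: soft-cap the attention logits
    return SOFTCAP * torch.tanh(score / SOFTCAP)

def sliding_window_mask(b, h, q_idx, kv_idx):
    # masking expressed as a BlockMask, so masked blocks are never computed
    return (q_idx - kv_idx <= WINDOW) & (kv_idx - q_idx <= WINDOW)

block_mask = create_block_mask(sliding_window_mask, None, None, SEQ_LEN, SEQ_LEN)
q = k = v = torch.randn(1, 8, SEQ_LEN, 64, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, score_mod=softcap_score_mod, block_mask=block_mask)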

Would love to hear more about your thoughts!

@staghado
Contributor Author

Hi @bursteratom,
One big distinction between score mod and mask mod is that the latter is much more efficient in practice; e.g. with ModernBERT it allows implementing unpadded attention similar to flash attention.
The integration seems focused on causal decoder-only models, with causal attention and GQA always set to true? I think if we want it to be re-usable across architectures, it should be as general as the PyTorch API, with all the necessary knobs.

Another thing is that, to get the performance promised by FlexAttention, we need to compile both flex_attention and create_block_mask.
@ArthurZucker, for the reasons above I think it's better to keep this separate for now.
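
A brief sketch of that point, under the assumption that both entry points are wrapped once at import time; the masks can then be built once per batch, cached at the model level, and consumed by the compiled attention call.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# compile both the attention call and the block-mask construction
flex_attention = torch.compile(flex_attention, dynamic=False)
create_block_mask = torch.compile(create_block_mask)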

@ArthurZucker
Collaborator

Hey, #36643 was merged; this just needs to integrate the refactor points IMO, and I'll have another look!

@staghado staghado closed this Mar 23, 2025
@staghado
Contributor Author

Looks like the integration is too focused on decoder-only models (gqa=true at the API level), create_block_mask is renamed to create_block_causal_mask_flex, and a score_mod is always passed even when not needed. This PR is out of scope.
