Proper performant flex attention implementation #36103

Closed

bursteratom wants to merge 22 commits into huggingface:main from bursteratom:proper_flex

Conversation

@bursteratom (Contributor) commented Feb 8, 2025

What does this PR do?

The current flex attention implementation does not take advantage of the performance and memory efficiency promised in this official blog post from PyTorch.

This PR, inspired by https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py, rectifies that by always compiling flex attention and using the sparsity-optimised BlockMask data type for attention masking in lieu of a regular torch tensor. Performance and memory utilization are now comparable to flash attention.
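For reference, here is a minimal sketch of the approach, assuming PyTorch >= 2.5 (where torch.nn.attention.flex_attention is available); the exact helper names in this PR may differ:

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Compile once at import time so every call hits the fused kernel
# instead of the eager fallback.
flex_attention_compiled = torch.compile(flex_attention, dynamic=False)

def causal_mask_mod(b, h, q_idx, kv_idx):
    # mask_mod returns True where attention is allowed.
    return q_idx >= kv_idx

# BlockMask tracks only the non-empty tiles of the mask, unlike a dense
# boolean tensor of shape (B, H, Q_LEN, KV_LEN).
block_mask = create_block_mask(
    causal_mask_mod, B=None, H=None, Q_LEN=2048, KV_LEN=2048, device="cuda"
)

# q, k, v of shape (batch, num_heads, seq_len, head_dim):
# out = flex_attention_compiled(q, k, v, block_mask=block_mask)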

BlockMask creation has been implemented for the following models:

  • Llama

Let's add support for other models in a separate PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@vasqu (Contributor) left a comment

A fan of flex attn :) Hope you don't mind the comments, but overall I'm pro this. Torchtune is optimized for training IIRC; is creating the block mask OK for inference, speed-wise? I have no idea if there are any downsides/advantages to one or the other.

Not sure if this is relevant to the PR tbh, but benchmarks might be a good thing to look out for in the future.


"""
Inspired by torchtune's flex attention implementation
"""
Contributor:

Nit: would move this to the top of the file.

Collaborator:

Yep! And we forgot to add a licence!

Contributor Author:

@ArthurZucker can you point me to an example of how a proper licence string should be added?

Contributor:

I think something along these lines is meant:

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

(you can add the torchtune or pytorch team IMO, not sure how fine-grained it should be)

@vasqu (Contributor) commented Feb 9, 2025

Example for inference with flex attn: meta-pytorch/gpt-fast#196

At first glance I can spot a few things:

  • offset wrapper function
  • create initial bigger sparse block mask and index as necessary
  • compile block mask creation function

I think avoiding recreating the block mask is especially important here to avoid the memory/speed overhead, but I'm not sure as I haven't measured speed/memory myself. Might be more appropriate for a different PR, no idea; I just think inference especially should be handled with more care. (See the sketch below for a rough illustration of the three points.)
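A minimal sketch, assuming BlockMask indexing and mask_mod reassignment behave as in the linked gpt-fast example; all helper names here are illustrative:

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def offset_mask_mod(mask_mod, offset):
    # Offset wrapper: shift q_idx by the current decode position so a mask
    # built for the full sequence stays correct for one new token.
    def _mask_mod(b, h, q_idx, kv_idx):
        return mask_mod(b, h, q_idx + offset, kv_idx)
    return _mask_mod

MAX_SEQ_LEN = 4096

# Compiling the mask-creation function amortizes its cost across calls.
create_block_mask_compiled = torch.compile(create_block_mask)

# Build one big sparse mask up front...
block_mask = create_block_mask_compiled(
    causal_mask_mod, B=None, H=None, Q_LEN=MAX_SEQ_LEN, KV_LEN=MAX_SEQ_LEN
)

def mask_for_decode_step(input_pos: torch.Tensor):
    # ...then index into it per step instead of recreating it. input_pos is
    # a 1-element tensor so tensor indexing keeps the Q-block dim intact.
    q_block = input_pos // block_mask.BLOCK_SIZE[0]
    step_mask = block_mask[:, :, q_block]
    step_mask.mask_mod = offset_mask_mod(causal_mask_mod, input_pos)
    return step_mask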

@ArthurZucker (Collaborator) left a comment

Very much needed! Thanks a lot 🤗
cc @molbap as you had issues with this recently!

@molbap (Contributor) left a comment

Big fan of this work! Thanks a lot for tackling it. I'd be interested in benchmarks, especially on a couple of models like PaliGemma and models with bidirectional attention 👀

@bursteratom (Contributor Author) commented

@vasqu @molbap @ArthurZucker I made some changes based on your input; could you give it another pass? Thank you!

@bursteratom force-pushed the proper_flex branch 2 times, most recently from 7519b3c to 8c28c9d on February 12, 2025 at 16:25
@vasqu (Contributor) left a comment

Honestly, I think the core is fine, just a few nits and smaller things. I would leave inference for another PR :)

Comment on lines +97 to +127
return create_block_causal_mask_flex(
    causal_mask_mod,
    batch_size,
    None,
    Q_LEN=total_seq_len,
    KV_LEN=total_seq_len,
    device=device,
)
Contributor:

I think my marking last time made it a bit confusing: kwargs on all args would be beneficial IMO, especially on the None arg (attention heads).
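For illustration, with keyword arguments throughout the call might look like this (assuming the helper mirrors the parameter names of torch's create_block_mask: mask_mod, B, H, Q_LEN, KV_LEN):

return create_block_causal_mask_flex(
    mask_mod=causal_mask_mod,
    B=batch_size,
    H=None,  # None broadcasts the mask across all attention heads
    Q_LEN=total_seq_len,
    KV_LEN=total_seq_len,
    device=device,
)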

@bursteratom force-pushed the proper_flex branch 2 times, most recently from 432bafa to 7483314 on February 13, 2025 at 05:05
@ArthurZucker (Collaborator) left a comment

Very nice!
Just missing some docs and small perf comparisons!

@ArthurZucker (Collaborator) left a comment

LGTM! Could you just add some documentation with a perf comparison! 🤗

@ArthurZucker mentioned this pull request on Mar 1, 2025
@bursteratom (Contributor Author) commented Mar 4, 2025

@ArthurZucker thank you! I will add the doc and perf comparison shortly! I'm wondering where in the docs/ subdirectory I should add the doc pertaining to flex attention, and what it should entail?

@ArthurZucker (Collaborator)

Let's merge for now IMO and you can open a new PR for the doc!

@ArthurZucker added the "flex attention" and "Compilation" (issues related to torchdynamo and torchinductor) labels on Mar 11, 2025
@ArthurZucker (Collaborator)

See #36643, needed to fix the conflicts.

@ArthurZucker (Collaborator)

We can close, the PR is merged! 🤗

@bursteratom reopened this on Mar 11, 2025
@github-actions (bot)

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@github-actions bot marked this pull request as draft on March 11, 2025 at 12:58