
Add ALL_ATTENTION_FUNCTIONS compatibility for Pixtral model#37960

Merged
ArthurZucker merged 7 commits into huggingface:main from uminaty:pixtral-all-attn
May 8, 2025

Conversation

@uminaty
Contributor

@uminaty uminaty commented May 5, 2025

What does this PR do?

This PR adds support for ALL_ATTENTION_FUNCTIONS to the Pixtral model’s attention mechanism. I added and verified compatibility with sdpa, flash_attention_2, and flex_attention. Since Pixtral also serves as the vision tower in Mistral 3.1, users can now set the entire model to use flash_attention_2.
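The registry the PR plugs into can be sketched without the library itself. The following is an illustrative stand-in for how an `ALL_ATTENTION_FUNCTIONS`-style mapping lets a layer dispatch on the configured backend name; the function bodies and the `attention_forward` helper are simplified placeholders, not the actual transformers code.

```python
# Hedged sketch of the attention-backend registry pattern: the model
# looks its implementation up by name instead of hard-coding one path.

def eager_attention(query, key, value, **kwargs):
    """Placeholder for the eager (pure framework ops) attention path."""
    return f"eager({len(query)} tokens)"

def sdpa_attention(query, key, value, **kwargs):
    """Placeholder for a scaled_dot_product_attention-backed path."""
    return f"sdpa({len(query)} tokens)"

# In transformers the real registry also maps names such as
# "flash_attention_2" and "flex_attention" to kernel wrappers.
ALL_ATTENTION_FUNCTIONS = {
    "eager": eager_attention,
    "sdpa": sdpa_attention,
}

def attention_forward(attn_implementation, query, key, value):
    # The layer dispatches on the configured name, falling back to eager.
    attn_fn = ALL_ATTENTION_FUNCTIONS.get(attn_implementation, eager_attention)
    return attn_fn(query, key, value)

print(attention_forward("sdpa", [0, 1, 2], [0, 1, 2], [0, 1, 2]))
```

With this shape, adding a backend is a registry entry rather than a new attention class, which is what makes a single `attn_implementation` setting apply uniformly across the language model and the Pixtral vision tower.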

I tried to follow the implementation pattern of other models using this interface. For flash_attention_2, I reused position_ids because the existing attention mask shape isn’t supported. Since Pixtral uses sequence packing and already generates position_ids, we leverage prepare_fa2_from_position_ids instead of a mask.
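The packing trick above can be illustrated in isolation. FlashAttention-2's variable-length API wants cumulative sequence lengths rather than a 2-D mask; with packed sequences, a `position_ids` value of 0 marks the start of each sequence, so the boundaries are recoverable from the ids alone. The helper below is a simplified stand-in for what `prepare_fa2_from_position_ids` derives, not the library implementation.

```python
# Hedged sketch: recover packed-sequence boundaries from position_ids,
# as needed by FlashAttention-2's varlen interface (cu_seqlens).

def cu_seqlens_from_position_ids(position_ids):
    """Return cumulative sequence offsets for a packed batch row.

    A position id of 0 marks the start of a new sequence, so the list
    of such indices, plus the total length, delimits every sequence.
    """
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    return starts + [len(position_ids)]

# Two sequences of lengths 4 and 3 packed into one row:
packed = [0, 1, 2, 3, 0, 1, 2]
print(cu_seqlens_from_position_ids(packed))  # [0, 4, 7]
```

This is why the PR can drop the unsupported mask shape entirely: everything the varlen kernel needs is already implied by the `position_ids` Pixtral generates for packing.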

I tested these changes in both training and inference: losses match very closely, and we observe a 10–25% throughput improvement depending on the setup.

Who can review?

@github-actions github-actions Bot marked this pull request as draft May 5, 2025 13:43
@github-actions
Contributor

github-actions Bot commented May 5, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@uminaty uminaty marked this pull request as ready for review May 5, 2025 13:43
@uminaty uminaty force-pushed the pixtral-all-attn branch from 7c43f75 to 54b71c2 Compare May 5, 2025 13:57
@uminaty uminaty force-pushed the pixtral-all-attn branch from 20f777d to 71827ac Compare May 5, 2025 16:52
Contributor

@qubvel qubvel left a comment


@uminaty thanks for the PR! please use library-defined functions as much as possible 🤗 Thank you!

Comment thread src/transformers/models/pixtral/modeling_pixtral.py Outdated
Comment thread src/transformers/models/pixtral/modeling_pixtral.py Outdated
Comment thread src/transformers/models/pixtral/modeling_pixtral.py
Comment thread src/transformers/models/pixtral/modeling_pixtral.py
@uminaty
Contributor Author

uminaty commented May 5, 2025

Thanks @qubvel for the review 🙏! I made the changes you suggested, let me know if anything else is needed.

Collaborator

@ArthurZucker ArthurZucker left a comment


Great addition, thanks! We don't really need the position ids (should be kwargs imo!)

Comment thread src/transformers/models/pixtral/modeling_pixtral.py Outdated
@uminaty uminaty force-pushed the pixtral-all-attn branch from 9e78cee to 50cc674 Compare May 6, 2025 09:30
Member

@zucchini-nlp zucchini-nlp left a comment


Oh cool, I also had a PR for attention in VLMs in #37576 😄

Comment thread src/transformers/models/pixtral/modeling_pixtral.py
Comment thread src/transformers/models/pixtral/modeling_pixtral.py Outdated
@uminaty uminaty force-pushed the pixtral-all-attn branch from 50cc674 to 9503c77 Compare May 6, 2025 10:32
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@uminaty
Contributor Author

uminaty commented May 6, 2025

Thanks everyone for your reviews! Let me know if anything else is needed before merging 😊

@ArthurZucker ArthurZucker merged commit f6664ee into huggingface:main May 8, 2025
14 checks passed
@ArthurZucker
Collaborator

Thanks for the contrib!

@uminaty uminaty deleted the pixtral-all-attn branch May 8, 2025 20:47
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
…ace#37960)

* Add ALL_ATTENTION_FUNCTIONS compatibility for Pixtral model

* Fix invalid operand type

* Allow image_sizes to be optional in forward pass to fit tests

Disallow using sdpa and output_attentions

* Disallow using sdpa with output_attentions

* Delete useless comments, use eager attention from smolvlm, use pattern from mistral

* add _supports_attention_backend

* use kwargs instead of position_ids

---------

Co-authored-by: aurelien.lac <aurelien.lac@lighton.ai>

5 participants