
[VLMs] support attention backends #37576

Merged
zucchini-nlp merged 25 commits into huggingface:main from zucchini-nlp:new-attn-interface-vlms
May 8, 2025

Conversation

@zucchini-nlp (Member) commented Apr 17, 2025

What does this PR do?

As per title, another step closer to vLLM + transformers

What was done:

  • Support the attention API for VLM-related models where not yet done
  • Pass kwargs so vLLM can forward its attention instances
  • Replace all loss computations with self.loss_fn (see Paligemma: fix generation with Gemma2 #36044 (comment))
  • Minor cleanup so new models can copy the prettified version; update the return block with can_return_tuple

Fixes #36557, #35634, #36904 and fixes #33963
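
A minimal sketch of the attention-interface dispatch this PR rolls out to the VLMs (illustrative, not the exact diff; the eager fallback below mirrors the common pattern and omits GQA handling): the layer picks a backend from ALL_ATTENTION_FUNCTIONS via config._attn_implementation and forwards **kwargs, which is what lets callers like vLLM inject their own attention instances.

from typing import Callable

import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Plain eager attention: QK^T, optional additive mask, fp32 softmax, dropout.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

# Inside the attention module's forward (schematic):
#     attention_interface: Callable = eager_attention_forward
#     if self.config._attn_implementation != "eager":
#         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
#     attn_output, attn_weights = attention_interface(
#         self, query, key, value, attention_mask, scaling=self.scaling, dropout=self.dropout, **kwargs
#     )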

github-actions bot marked this pull request as draft April 17, 2025 09:32
@github-actions (Contributor):

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@zucchini-nlp marked this pull request as ready for review April 17, 2025 09:40
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines -3354 to -3435
# Unset attn implementation so it can be set to another one when loading back
model_to_save.config._attn_implementation_autoset = False

zucchini-nlp (Member, Author):

This was moved to configuration_utils.py, where it's deleted from all sub-configs; otherwise we'd be unsetting it only in the base config.
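
An illustrative sketch of that move (hypothetical helper; assumes the config exposes its nested configs via a sub_configs mapping, as recent PretrainedConfig classes do):

def unset_attn_autoset(config):
    # Unset on the base config so the implementation can be re-selected on load ...
    config._attn_implementation_autoset = False
    # ... and on every sub-config (e.g. text_config, vision_config); composite
    # models would otherwise keep the stale flag in their sub-configs.
    for name in getattr(config, "sub_configs", {}):
        sub_config = getattr(config, name, None)
        if sub_config is not None:
            sub_config._attn_implementation_autoset = False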

Comment thread src/transformers/modeling_utils.py Outdated
fx_compatible = True
test_pruning = False
test_missing_keys = False
test_head_masking = False # new attn API doesn't support head mask
zucchini-nlp (Member, Author):

We'll be removing head_mask in v5, and it was discussed in this PR that we can deprecate it for now. Thus there's no need to fix the test or add support in the new interface.

@zucchini-nlp (Member, Author):

@qubvel if you have time to give it an initial review while Arthur is off :)

@qubvel (Contributor) left a comment:

Thanks! Just a few review questions

Comment thread src/transformers/models/aria/modular_aria.py Outdated
Comment thread src/transformers/models/blip_2/modeling_blip_2.py
Comment thread tests/models/kosmos2/test_modeling_kosmos2.py
Comment thread src/transformers/modeling_utils.py Outdated
Comment thread tests/test_modeling_common.py Outdated
@zucchini-nlp (Member, Author) commented Apr 22, 2025

@qubvel comments addressed. The skipped Kosmos test can be run now (I discovered your additional_model_inputs for each tester), but it sometimes has a larger diff than expected. I'm skipping it for now and will investigate the source of the flakiness before merging.

EDIT: Kosmos apparently doesn't support padding; this test is skipped in the new Kosmos as well.

@zucchini-nlp (Member, Author):

@qubvel @ArthurZucker updated the PR after the latest big refactor on VLMs. Can you review when you have time?

@ArthurZucker (Collaborator) left a comment:

Nice work 🧼
A lot of this refactoring would be easier if we also applied modular to idefics and the like!

@ArthurZucker (Collaborator) left a comment:

Careful with the 2-3 places where the new attention softmax is in float32.
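
For context, the pattern being flagged is the float32 softmax upcast used in eager attention (a generic sketch, not the PR's exact code):

import torch

def softmax_fp32(attn_scores: torch.Tensor) -> torch.Tensor:
    # Upcast to float32 for numerical stability, then cast back to the input
    # dtype. Backends that skip the upcast can produce slightly different
    # outputs, so it matters where this is (and isn't) applied.
    return torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(attn_scores.dtype)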

Comment thread src/transformers/models/instructblip/modeling_instructblip.py
Comment thread src/transformers/models/instructblip/modeling_instructblip.py
@zucchini-nlp (Member, Author):

Yeah, agreed. We can do modular in a separate PR so as not to bloat up this one.

@ArthurZucker (Collaborator):

completely!

@zucchini-nlp zucchini-nlp merged commit d23aae2 into huggingface:main May 8, 2025
20 checks passed
# Since we use packing, if Flash-Attn 2 is selected we rely on position_ids
if self.config._attn_implementation == "flash_attention_2":
    kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True)
    attention_mask = None
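    # (Context from the thread below: with the mask dropped, FA2 takes its
    # packed/varlen path and derives cu_seqlens from position_ids; keeping
    # the mask is what triggers the cu_seqlens_q shape error discussed next.)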
@uminaty (Contributor) commented May 13, 2025:

Hi @zucchini-nlp, why did you delete this line? The attention mask is now used by default, and there is an error when using Flash Attention 2 in my tests. Maybe I missed something, but I think this was important. At least I don't get the RuntimeError: cu_seqlens_q must have shape (batch_size + 1) error when adding attention_mask = None back for FA2.

zucchini-nlp (Member, Author):

Looks like a faulty rebase; I don't remember any tests failing because of this line. We probably need a new test, or to expand an existing one, to test FA2 packing correctly. Thanks for flagging!

uminaty (Contributor):

Np! And yes, I don't see how the tests could pass on this; I'm not familiar enough with the tests yet.

@uminaty (Contributor) commented May 14, 2025:

Should I open a PR for adding this line back or do you take care of it?

zucchini-nlp (Member, Author):

feel free to open a PR, I am a bit stuck on other tasks :)

And for the test, I realized that the one we have is for generative models only. It would be nice to add something for all models if possible:

def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
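
As a rough illustration of what such a test checks (a hypothetical sketch, not the actual transformers test; assumes a right-padding tokenizer and a model loaded with attn_implementation="flash_attention_2"): a packed, unpadded batch driven by position_ids should produce the same per-sequence outputs as a padded batch driven by an attention mask.

import torch

def check_padding_free_matches_padded(model, tokenizer, texts):
    # Padded batch: one sequence per row plus an attention mask.
    padded = tokenizer(texts, return_tensors="pt", padding=True)
    with torch.no_grad():
        padded_logits = model(**padded).logits

    # Packed batch: all tokens concatenated into a single row, with
    # per-sequence position_ids so FA2 can derive cu_seqlens.
    ids = [tokenizer(t, return_tensors="pt").input_ids[0] for t in texts]
    packed_ids = torch.cat(ids).unsqueeze(0)
    position_ids = torch.cat([torch.arange(len(x)) for x in ids]).unsqueeze(0)
    with torch.no_grad():
        packed_logits = model(input_ids=packed_ids, position_ids=position_ids).logits

    # Each sequence's slice of the packed logits should match its padded row.
    offset = 0
    for i, x in enumerate(ids):
        torch.testing.assert_close(
            packed_logits[0, offset : offset + len(x)],
            padded_logits[i, : len(x)],
            atol=1e-3,
            rtol=1e-3,
        )
        offset += len(x)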

uminaty (Contributor):

@zucchini-nlp I opened a PR for a hotfix; I don't have time right now to look at the tests, sorry. But I'll try to get familiar with them, I just can't guarantee I'll have time soon for this.

zucchini-nlp added a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* update models
* why rename
* return attn weights when sdpa
* fixes
* fix attn implementation composite
* fix moshi
* add message
* add typings
* use explicitly all flags for each attn type
* fix some tests
* import what is needed
* kosmos on main has new attention already, yay
* new models in main, run fixup
* won't fix kosmos yet
* fix-copies
* clean up after rebasing
* fix tests
* style
* dont cast attns to fp32
* did we update ruff? oke, let's just do what it asks
* fix pixtral after rebase
@ydshieh mentioned this pull request Jun 25, 2025


Development

Successfully merging this pull request may close these issues:

  • Error during processing: MllamaForCausalLM does not support Flash Attention 2.0 yet
  • Flash attention 2 support for PaliGemma model

5 participants