[VLMs] support attention backends#37576
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
# Unset attn implementation so it can be set to another one when loading back
model_to_save.config._attn_implementation_autoset = False
This was moved to configuration_utils.py, where the flag is deleted from all sub-configs. Otherwise we would be unsetting it only on the base config.
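For context, a minimal sketch of what "deleting it from all sub-configs" amounts to (illustrative only — the helper name is made up and this is not the actual `configuration_utils.py` code):

```python
from transformers import PretrainedConfig


def unset_attn_autoset(config: PretrainedConfig) -> None:
    # Illustrative helper (not the real transformers code): clear the flag on the
    # top-level config and recurse into nested configs, so composite VLM configs
    # (text_config, vision_config, ...) don't keep a stale value when reloaded.
    if hasattr(config, "_attn_implementation_autoset"):
        config._attn_implementation_autoset = False
    for value in vars(config).values():
        if isinstance(value, PretrainedConfig):
            unset_attn_autoset(value)
```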
fx_compatible = True
test_pruning = False
test_missing_keys = False
test_head_masking = False  # new attn API doesn't support head mask
We'll be removing head_mask in v5, and it was discussed in this PR that we can deprecate it for now. So there is no need to fix the test or to support head masking with the new interface.
@qubvel could you give this an initial review if you have time, while Arthur is off? :)
qubvel
left a comment
Thanks! Just a few review questions
@qubvel comments addressed. The skipped Kosmos test can be run now. EDIT: Kosmos apparently doesn't support padding; this test is skipped in the new Kosmos as well.
@qubvel @ArthurZucker I updated the PR after the latest big refactor on VLMs. Can you review when you have time?
ArthurZucker
left a comment
Nice work 🧼
A lot of this refactoring would be easier if we also applied modular to idefix etc.!
ArthurZucker
left a comment
Careful with the 2-3 places where the new attention softmax is in float32.
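For reference, the pattern being flagged looks roughly like this in eager attention paths (a simplified sketch, not the exact modeling code of any model touched in this PR):

```python
import torch
from torch import nn


def eager_attention(query, key, value, attention_mask=None, scaling=None):
    # Simplified eager attention: the softmax is computed in float32 for numerical
    # stability and cast back to the query dtype so downstream matmuls stay in
    # half precision. The review point is to double-check where this upcast
    # should (or should not) survive the refactor.
    if scaling is None:
        scaling = query.shape[-1] ** -0.5
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value), attn_weights
```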
Yeah, agreed. We can do modular in a separate PR so we don't bloat this one up.
Completely!
# Since we use packing, if Flash-Attn 2 is selected we rely on position_ids
if self.config._attn_implementation == "flash_attention_2":
    kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True)
    attention_mask = None
Hi @zucchini-nlp, why did you delete this line? Now the attention mask is used by default and there is an error when using Flash Attention 2 in my tests. Maybe I missed something, but I think this was important. At least I don't get the `RuntimeError: cu_seqlens_q must have shape (batch_size + 1)` error when setting `attention_mask = None` if we use FA2.
Looks like a faulty rebase; I don't remember any tests failing because of this line. We probably need a new test, or to expand an existing one, to test FA2 packing correctly. Thanks for flagging!
Np! And yes, I don't see how the tests could pass on this; I'm not familiar enough with the tests yet.
Should I open a PR to add this line back, or will you take care of it?
Feel free to open a PR, I am a bit stuck on other tasks :)
And for the test, I realized that the one we have is for generative models only. It would be nice to add something for all models if possible:
transformers/tests/test_modeling_common.py, line 4106 at 1b00966
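A rough sketch of the kind of FA2 packing check discussed here (hypothetical, not a test from the repository; `model` is assumed to be a generative model loaded with `attn_implementation="flash_attention_2"` on GPU, and `seq_a` / `seq_b` are 1D tensors of token ids):

```python
import torch


def check_fa2_packing(model, seq_a, seq_b):
    # Pack two sequences into one row; position_ids restart at the boundary so
    # Flash Attention 2 can infer the split. attention_mask stays None, otherwise
    # cu_seqlens get built from the 2D mask and the shapes no longer match.
    device = next(model.parameters()).device
    packed_ids = torch.cat([seq_a, seq_b], dim=-1).unsqueeze(0).to(device)
    position_ids = torch.cat(
        [torch.arange(seq_a.shape[-1]), torch.arange(seq_b.shape[-1])]
    ).unsqueeze(0).to(device)

    with torch.no_grad():
        packed = model(input_ids=packed_ids, position_ids=position_ids, attention_mask=None)
        out_a = model(input_ids=seq_a.unsqueeze(0).to(device))
        out_b = model(input_ids=seq_b.unsqueeze(0).to(device))

    # Each packed segment should match the corresponding unpacked forward pass.
    torch.testing.assert_close(packed.logits[:, : seq_a.shape[-1]], out_a.logits, atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(packed.logits[:, seq_a.shape[-1] :], out_b.logits, atol=1e-3, rtol=1e-3)
```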
@zucchini-nlp I opened a PR for a hotfix; I don't have time right now to look at the tests, sorry. But I'll try to get familiar with them, I just can't guarantee I'll have time soon for this.
* update models
* why rename
* return attn weights when sdpa
* fixes
* fix attn implementation composite
* fix moshi
* add message
* add typings
* use explicitly all flags for each attn type
* fix some tests
* import what is needed
* kosmos on main has new attention already, yay
* new models in main, run fixup
* won't fix kosmos yet
* fix-copies
* clean up after rebasing
* fix tests
* style
* dont cast attns to fp32
* did we update ruff? oke, let's just do what it asks
* fix pixtral after rebase
What does this PR do?
As per title, another step closer to vLLM + transformers
What was done:
- `kwargs` so vLLM can forward its attention instances
- `self.loss_fn` (see Paligemma: fix generation with Gemma2 #36044 (comment))
- `can_return_tuple`

Fixes #36557, #35634, #36904 and fixes #33963
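For illustration, the user-facing side of attention-backend support in a VLM looks roughly like this (the checkpoint is only an example, and the per-sub-model dict form follows the composite-config behaviour described in the transformers docs):

```python
from transformers import AutoModelForImageTextToText

# Pick one backend for the whole composite model.
model = AutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)

# Or choose a backend per sub-model (vision tower vs. language model).
model = AutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    attn_implementation={"vision_config": "sdpa", "text_config": "flash_attention_2"},
)
```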