Llama: fix custom 4D masks #29930
Conversation
reordered the logic: custom 4D masks are now a superset of the default mask, so we don't need to create the default mask first :)
This `else` branch has no changes. Only the `if attention_mask is not None and attention_mask.dim() == 4:` branch is different.
Added this test case (we can now pass full custom 4D attention masks)
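For readers following along, here is a minimal sketch of what passing a full custom 4D mask can look like. It is **not** the test added in this PR: the checkpoint name is an assumption, and so is the 0/1 convention (1 = "may attend") -- some `transformers` versions expect the 4D mask already in inverted/additive form, so check the version you are on.

```python
# Hypothetical usage sketch, not the PR's test code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint; any recent Llama-family model should do
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

input_ids = tokenizer("custom 4D masks", return_tensors="pt").input_ids
seq_len = input_ids.shape[1]

# Full custom 4D mask of shape [batch, 1, seq_len, seq_len]; here it is simply the
# plain causal (lower-triangular) mask, assuming a 0/1 convention where 1 = "may attend".
custom_4d_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long))[None, None]

with torch.no_grad():
    out = model(input_ids, attention_mask=custom_4d_mask)
print(out.logits.shape)
```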
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
Thanks, LGTM! I just want to always trigger the tests.
Thanks @gante!
@poedator would you like to open a PR with that? As a user, you'll probably have cool examples in mind!
Will try, but not this week...
This set of slow tests was moved to the llama test file -> if we run the slow llama tests, which we often request, this will now be triggered
This set of tests is now:
- part of the mixin, so they run on all push commits
- a fast test, using `model = model_class(config)` from the test config
- triggered by `model_class._supports_cache_class == True` -- recent LLMs [llama, cohere, gemma, mistral, mixtral, starcoder2, ...] have this attribute set to `True` and are 4D mask-compatible. Older models are often not compatible. Over time, as we spread the cache refactor, this test will be run on those classes as well 👀 (see the rough sketch of the gating after this list)
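Roughly, the gating described above could look like the sketch below. It is illustrative only, not the actual transformers mixin: the helper attributes (`all_generative_model_classes`, `model_tester`, `prepare_config_and_inputs_for_common`) are assumptions about what the surrounding test class provides, and the 0/1 mask convention is also an assumption.

```python
# Illustrative sketch of gating a mixin test on _supports_cache_class (not the real mixin code).
import torch

class Custom4DMaskTestMixin:
    def test_custom_4d_attention_mask(self):
        for model_class in self.all_generative_model_classes:
            if not getattr(model_class, "_supports_cache_class", False):
                continue  # older models without the cache refactor are skipped for now
            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config).eval()  # tiny random weights from the test config -> fast test
            input_ids = inputs["input_ids"]
            bsz, seq_len = input_ids.shape
            # full custom 4D mask: here just the causal mask (assumed 0/1 convention, 1 = attend)
            mask_4d = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long))
            mask_4d = mask_4d[None, None].expand(bsz, 1, seq_len, seq_len)
            with torch.no_grad():
                model(input_ids, attention_mask=mask_4d)
```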
@ArthurZucker ready for a re-review (test rework) -- we now have on-push tests for all recent models + custom 4D mask :)
@gante, I made the cache longer than the masks and padded the masks to the cache length. Is this the correct way?
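(For context, a shape-only sketch of the padding being asked about -- pure tensor manipulation, no model. Whether the padding goes on the left or the right, and whether 0 is the right fill value, are assumptions here.)

```python
# Shape-only sketch: padding a [batch, 1, q_len, kv_len] mask up to the cache length.
import torch
import torch.nn.functional as F

q_len, kv_len, cache_len = 3, 3, 8  # cache is longer than the mask

mask = torch.tril(torch.ones(q_len, kv_len))[None, None]        # [1, 1, 3, 3]
# Pad the key/value (last) dimension up to the cache length; 0 = "do not attend" here.
# The right-padding choice below is illustrative -- it depends on where the cached tokens sit.
mask_padded = F.pad(mask, (0, cache_len - kv_len), value=0.0)    # [1, 1, 3, 8]
print(mask_padded.shape)
```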
Sorry for the delay, let's rebase on main as well |
ArthurZucker
left a comment
Very good! Let's rebase on main (#30047 was merged) and run the slow tests!
Please, please merge this PR - I need it for my speculative decoding paper project! The 4D masks are essential for it.
Sorry, just got back to GitHub 😓 could you rebase?
I rebased this PR in a new one, #30348, and added a few important changes.
Closing in favor of #30348 |
What does this PR do?
Fixes the issue raised by @poedator in this comment.
The causal mask is now of shape `[..., seq_len, full_len]`, as opposed to `[..., full_len, full_len]`. This means a custom 4D attention mask now covers the whole causal mask, so we don't need a sliced copy -- we can copy the whole thing :)

This PR also expands the support for custom 4D attention masks: we can pass either the full mask (`[..., full_len, full_len]`) or the partial mask (`[..., seq_len, full_len]`).
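A small shape sketch of the two accepted layouts (pure torch, with illustrative values):

```python
# Shape sketch: full vs. partial custom 4D masks (illustrative; 1 = attend).
import torch

full_len, seq_len = 6, 2  # e.g. 4 cached tokens + 2 new tokens, batch size 1

# Full mask: [..., full_len, full_len] -- one query row per position of the full sequence.
full_mask = torch.tril(torch.ones(full_len, full_len))[None, None]  # [1, 1, 6, 6]

# Partial mask: [..., seq_len, full_len] -- only the query rows for the new tokens.
partial_mask = full_mask[..., -seq_len:, :]                          # [1, 1, 2, 6]

print(full_mask.shape, partial_mask.shape)
```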