Fix FA2 attention for models supporting sliding window #34093
Merged
Cyrilvallez merged 1 commit into main on Oct 22, 2024
Conversation
Collaborator

#30642 should have been using …
Collaborator

Could you add a test using your script?
Member (Author)
Actually, it does not depend on the Cache class used. The lines

```python
slicing_tokens = 1 - self.config.sliding_window

past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]

past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
```

are just not doing anything: they do not modify the Cache in-place, and the sliced tensors are never reused afterwards. I have no clue why they are still here.
Moreover,

```python
from transformers import SlidingWindowCache, MistralConfig

past_key_value = SlidingWindowCache(MistralConfig(), batch_size=1, max_cache_len=100)
past_key = past_key_value[0][0]
```

raises an error.
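To make the no-op concrete, here is a minimal, self-contained sketch (toy tensors standing in for the legacy tuple cache, not the actual model code): slicing returns a new tensor, so nothing is ever written back into the cache.

```python
import torch

# Toy stand-in for the legacy tuple cache: one (key, value) pair per layer,
# each of shape (batch, num_heads, seq_len, head_dim).
past_key_value = [(torch.randn(1, 8, 10, 64), torch.randn(1, 8, 10, 64))]
sliding_window = 4
layer_idx = 0

slicing_tokens = 1 - sliding_window
past_key = past_key_value[layer_idx][0]
past_value = past_key_value[layer_idx][1]

# Slicing creates new tensors; past_key_value itself is untouched.
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()

# The cache still holds all 10 positions, so the "slicing" had no effect on it.
assert past_key_value[layer_idx][0].shape[-2] == 10
```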
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request on Dec 5, 2024.
What does this PR do?
This PR fixes FA2 attention for models using `sliding_window`. Currently, the removed snippet was not doing anything except assigning unused variables and slicing the potential `attention_mask`, which would result in wrong attention computations. Indeed, as soon as the `attention_mask` was not only 1s (and thus not `None`) and the sequence length was longer than the sliding window, we would incorrectly slice it.

The fix is either to correctly slice the K-V, or to do nothing and rely on the `sliding_window` arg of `_flash_attention_forward`. I took the second option because it is easier and avoids unnecessary code.
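As a rough illustration of the helper the fix relies on (`_flash_attention_forward` is a private utility, so its exact signature may differ across versions; running this requires flash-attn and a GPU, and the shapes here are arbitrary):

```python
import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

# FA2 expects (batch, seq_len, num_heads, head_dim) in fp16/bf16 on GPU.
q = torch.randn(1, 8, 4, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 4, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 4, 64, dtype=torch.float16, device="cuda")

out = _flash_attention_forward(
    q, k, v,
    attention_mask=None,
    query_length=8,
    is_causal=True,
    sliding_window=4,  # the kernel masks beyond the window; no manual K-V slicing needed
)
```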
If you run the following:
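(The exact snippet did not survive on this page; a plausible reconstruction, in which the checkpoint, prompts, and generation settings are assumptions, needs FA2, real padding in the `attention_mask`, and a prompt longer than the sliding window:)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # assumption: any sliding-window model works
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)

# Padding makes the attention_mask contain 0s (so it is not None), and the long
# prompt must exceed the model's sliding window to trigger the wrong slicing.
prompts = ["<prompt longer than the sliding window>", "short prompt"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```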
we would previously get gibberish:
and now we get the expected output (which is the same as when using `sdpa` attention):

Slow tests for Mistral are all good (two failing, but they fail on main as well).
@ArthurZucker