Fix FA2 attention for models supporting sliding window #34093

Merged
Cyrilvallez merged 1 commit into main from fix-sliding-fa2 on Oct 22, 2024

Conversation

@Cyrilvallez (Member) commented Oct 11, 2024

What does this PR do?

This PR fixes FA2 attention for models using sliding_window. The snippet removed here did nothing except assign unused variables and slice the potential attention_mask, which resulted in wrong attention computations: as soon as the attention_mask was not all 1s (thus not None) and the sequence length was longer than the sliding window, the mask was sliced incorrectly.

The fix is either to correctly slice the key/value states, or to do nothing and rely on the sliding_window argument of _flash_attention_forward. I took the second option because it is simpler and avoids unnecessary code.
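
For reference, here is a minimal sketch of how the window can be enforced at the kernel level through flash-attn's flash_attn_func (an illustration of the mechanism, not the exact transformers internals):

import torch
from flash_attn import flash_attn_func

# (batch, seq_len, num_heads, head_dim); 8 K/V heads for Mistral-style GQA
q = torch.randn(1, 4200, 32, 128, dtype=torch.bfloat16, device='cuda')
k = torch.randn(1, 4200, 8, 128, dtype=torch.bfloat16, device='cuda')
v = torch.randn(1, 4200, 8, 128, dtype=torch.bfloat16, device='cuda')

sliding_window = 4096
out = flash_attn_func(
    q, k, v,
    causal=True,
    # (left, right) context: each query attends to at most `sliding_window`
    # tokens to its left (itself included) and none to its right, so no
    # manual K/V slicing is needed on the modeling side
    window_size=(sliding_window - 1, 0),
)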

If you run the following:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = 3  # GPU index

model_name = 'mistralai/Mistral-7B-v0.1'
dtype = torch.bfloat16
attn = 'flash_attention_2'

model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation=attn,
                                             torch_dtype=dtype, low_cpu_mem_usage=True).cuda(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token


with open('text.txt') as file:
    long_text = file.read()


generation_kwargs = {
    "max_new_tokens": 50,
    "eos_token_id": None,
    "do_sample": False,
}


# Get a prompt slightly longer than the sliding window (4096 tokens for Mistral)
inputs0 = tokenizer.encode(long_text, return_tensors='pt')[:, :4200]
text = tokenizer.batch_decode(inputs0, skip_special_tokens=True)[0]
# Batch it with a short prompt so that the attention_mask contains 0s (padding)
inputs = tokenizer([text, 'My favorite condiment is undeniably ketchup because it is very nice and'], padding=True, return_tensors='pt')

input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

outputs = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
# Decode only the continuation (the long prompt is ~4200 tokens)
tokenizer.batch_decode(outputs[:, 4200:], skip_special_tokens=True)

we would previously get gibberish:

>>> ['and behavioralongraphyeterms\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n', 't2222222222222222222222222222222222222222222222222']

and now we get the following (which is the same as with SDPA attention):

>>> ['and physical differences. Dogs have undergone a process of domestication and artificial selection, which has led to a range of morphological and behavioral changes.\n\nPhysical Differences\nDogs exhibit a wide range of physical',
 'tasty. I love to eat it with my fries, burgers, and other foods. I also love to use it as a dip for my fries.\n\nI have been using ketchup for a long time now,']

Slow tests for Mistral are all good (two are failing, but they also fail on main).

@ArthurZucker

@Cyrilvallez marked this pull request as ready for review on October 11, 2024
@Cyrilvallez changed the title from "Fix FA2" to "Fix FA2 attention for models supporting sliding window" on Oct 11, 2024
@ArthurZucker (Collaborator) commented

#30642 should have been using SlidingWindow + testing this, no?

@ArthurZucker (Collaborator) commented

Could you add a test using your script?
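
Such a test could, for example, check that FA2 and SDPA produce identical greedy generations for a prompt longer than the sliding window; a rough sketch (illustrative only, not the test that eventually landed):

import torch
from transformers import AutoModelForCausalLM

def greedy_generate(attn_implementation, input_ids, attention_mask):
    # Load the model with the requested attention backend and generate greedily
    model = AutoModelForCausalLM.from_pretrained(
        'mistralai/Mistral-7B-v0.1',
        attn_implementation=attn_implementation,
        torch_dtype=torch.bfloat16,
    ).cuda()
    return model.generate(input_ids, attention_mask=attention_mask,
                          max_new_tokens=30, do_sample=False)

# Given a padded batch whose longest prompt exceeds the 4096-token window:
# assert torch.equal(greedy_generate('flash_attention_2', ids, mask),
#                    greedy_generate('sdpa', ids, mask))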

@Cyrilvallez mentioned this pull request on Oct 11, 2024
@Cyrilvallez (Member, Author) commented

Actually, it does not depend on the Cache class used. The lines

slicing_tokens = 1 - self.config.sliding_window

past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]

past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()

are simply not doing anything: they do not modify the Cache in-place, and the sliced tensors are never reused later. I have no clue why they are still here.
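
A quick illustration of the Python semantics at play: slicing creates a new tensor, and the assignment only rebinds the local name, so the cache contents are untouched.

import torch

cache = [torch.arange(10)]            # stands in for the layer's cached keys
past_key = cache[0]                   # local reference to the cached tensor
past_key = past_key[5:].contiguous()  # new tensor; only rebinds the local name

print(cache[0].shape)  # torch.Size([10]) -- the cache itself is unchanged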

SlidingWindowCache does not even have __getitem__, so doing

from transformers import SlidingWindowCache, MistralConfig

past_key_value = SlidingWindowCache(MistralConfig(), batch_size=1, max_cache_len=100)
past_key = past_key_value[0][0]

raises TypeError: 'SlidingWindowCache' object is not subscriptable...

@ArthurZucker (Collaborator) left a comment

Nice thanks

@Cyrilvallez merged commit 51e395d into main on Oct 22, 2024
@Cyrilvallez deleted the fix-sliding-fa2 branch on October 22, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024