System Info
On main
Who can help?
@zucchini-nlp @gante
Reproduction
We currently cannot use `cache_implementation='sliding_window'` with FA2. The following snippet

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = 3  # CUDA device index
model_name = 'mistralai/Mistral-7B-v0.1'
dtype = torch.bfloat16
attn = 'flash_attention_2'

model = AutoModelForCausalLM.from_pretrained(
    model_name, attn_implementation=attn, torch_dtype=dtype, low_cpu_mem_usage=True
).cuda(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

generation_kwargs = {
    "max_new_tokens": 50,
    "eos_token_id": None,
    "do_sample": False,
    "return_dict_in_generate": True,
}

# A single prompt, so padding=True does not actually add any pad tokens
inputs = tokenizer('Hello who are', padding=True, return_tensors='pt')
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

outputs = model.generate(
    input_ids,
    attention_mask=attention_mask,
    cache_implementation='sliding_window',
    **generation_kwargs,
)
```
fails with:

```
ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input.
```
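The message is misleading, since the prompt is a single unpadded sequence. Below is a minimal sketch of how a right-padding check written for a 2D 0/1 mask can misfire once it is handed a 4D additive mask. It assumes (not confirmed in this issue) that the FA2 path tests padding with something of the form `attention_mask[:, -1].sum() != batch_size`. Plain nested lists stand in for tensors, and `MASKED` is a hypothetical stand-in for the large negative fill value of an additive mask.

```python
MASKED = -1e9  # hypothetical stand-in for the large negative fill value of an additive mask


def causal_4d(mask_2d):
    """Expand a (batch, seq) 0/1 mask into a (batch, 1, seq, seq) additive causal mask."""
    out = []
    for row in mask_2d:
        n = len(row)
        grid = [[0.0 if (j <= i and row[j] == 1) else MASKED for j in range(n)]
                for i in range(n)]
        out.append([grid])  # singleton head dimension
    return out


def total(x):
    """Recursively sum a nested list of numbers."""
    return sum(total(v) for v in x) if isinstance(x, list) else x


def looks_right_padded(mask):
    """Mimic `mask[:, -1].sum() != batch_size` on nested lists of rank >= 2."""
    return total([row[-1] for row in mask]) != len(mask)


mask_2d = [[1, 1, 1]]  # toy stand-in for one unpadded prompt
print(looks_right_padded(mask_2d))             # False: last column sums to batch size
print(looks_right_padded(causal_4d(mask_2d)))  # True: sums a whole (seq, seq) slice
```

On the 2D mask, `[:, -1]` picks the last token of each row; on the 4D mask, the same indexing selects the entire attention grid, whose sum is dominated by the masked entries and never equals the batch size.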
Expected behavior
This comes from the fact that `prepare_inputs_for_generation` creates a 4D mask `if isinstance(past_key_values, StaticCache)`, but FA2 does not support 4D masks. Since `SlidingWindowCache` subclasses `StaticCache`, this branch is also taken with `cache_implementation='sliding_window'`. I believe we need a check for FA2 here, as well as a correct expansion of the 2D mask for `SlidingWindowCache`. I found this while opening #34093.
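One possible shape for that check, as a sketch only: the helper `prepare_generation_mask`, its arguments, and the string labels it compares against are hypothetical and are not transformers APIs. The idea is to keep the 2D padding mask for FA2, which works from that mask directly, and to build a sliding-window-aware 4D mask only for backends that consume 4D masks. Nested lists again stand in for tensors.

```python
MASKED = -1e9  # hypothetical stand-in for the large negative fill value of an additive mask


def sliding_window_4d(mask_2d, window):
    """Expand a (batch, seq) 0/1 mask into a (batch, 1, seq, seq) additive mask.

    Query i may attend to key j only if j is not padding, j <= i (causal)
    and i - j < window (sliding window).
    """
    out = []
    for row in mask_2d:
        n = len(row)
        grid = [[0.0 if (row[j] == 1 and j <= i and i - j < window) else MASKED
                 for j in range(n)] for i in range(n)]
        out.append([grid])  # singleton head dimension
    return out


def prepare_generation_mask(mask_2d, attn_implementation, cache_implementation, window):
    """Hypothetical dispatch: static-style caches get a precomputed 4D mask,
    except under FlashAttention-2, which keeps the 2D padding mask."""
    needs_4d = cache_implementation in ("static", "sliding_window")
    if needs_4d and attn_implementation != "flash_attention_2":
        return sliding_window_4d(mask_2d, window)
    return mask_2d


mask = [[1, 1, 1]]
print(prepare_generation_mask(mask, "flash_attention_2", "sliding_window", 2))  # [[1, 1, 1]]
grid = prepare_generation_mask(mask, "sdpa", "sliding_window", 2)[0][0]
print(grid[2])  # [-1000000000.0, 0.0, 0.0]: with window 2, query 2 no longer sees key 0
```

The window term matters for the non-FA2 path: a plain causal expansion would let every query see all earlier keys, which is not what a sliding-window cache holds.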