SlidingWindowCache issue #34097

@Cyrilvallez

System Info

On main

Who can help?

@zucchini-nlp @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

We currently cannot use cache_implementation='sliding_window' with Flash Attention 2 (FA2). The following snippet

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# GPU index used when the issue was reported; adjust for your machine.
device = 3

model_name = 'mistralai/Mistral-7B-v0.1'
dtype = torch.bfloat16
attn = 'flash_attention_2'

model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation=attn,
                                            torch_dtype=dtype, low_cpu_mem_usage=True).cuda(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Mistral ships without a pad token, so reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

# Greedy decoding of exactly 50 new tokens (EOS-based stopping is disabled).
generation_kwargs = {
    "max_new_tokens": 50,
    "eos_token_id": None,
    "do_sample": False,
    "return_dict_in_generate": True,
}

# Single, unbatched prompt: the attention mask contains no padding at all.
inputs = tokenizer('Hello who are', padding=True, return_tensors='pt')

input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

outputs = model.generate(input_ids, attention_mask=attention_mask,
                         cache_implementation='sliding_window', **generation_kwargs)

fails with

ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to  call `tokenizer.padding_side  = 'left'` before tokenizing the input.

Expected behavior

This comes from the fact that prepare_inputs_for_generation creates a 4D mask if isinstance(past_key_values, StaticCache) is true, but FA2 does not support 4D masks. I believe we need a check, and a correct expansion of the 2D mask, for SlidingWindowCache as well. Found this while opening #34093.
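To make the proposed check concrete, here is a minimal pure-Python sketch of the idea. It is not the transformers implementation: the names `prepare_mask`, `expand_to_4d`, and `uses_static_cache` are hypothetical, masks are plain nested lists rather than tensors, and the window rule (a query at position q attends to keys k with q - k < sliding_window) is an assumption made for illustration.

```python
from typing import List

Mask2D = List[List[int]]                # [batch, kv_len]; 1 = attend, 0 = padding
Mask4D = List[List[List[List[bool]]]]   # [batch, 1, q_len, kv_len]; True = attend


def expand_to_4d(mask_2d: Mask2D, sliding_window: int) -> Mask4D:
    """Expand a 2D padding mask into a 4D causal sliding-window mask."""
    expanded = []
    for row in mask_2d:
        kv_len = len(row)
        per_query = [
            [
                # Attend only to non-padded keys that are not in the future
                # and that fall inside the (assumed) sliding window.
                bool(row[k]) and k <= q and q - k < sliding_window
                for k in range(kv_len)
            ]
            for q in range(kv_len)
        ]
        expanded.append([per_query])  # singleton head dimension
    return expanded


def prepare_mask(mask_2d: Mask2D, attn_implementation: str,
                 uses_static_cache: bool, sliding_window: int):
    """Hypothetical guard: only build a 4D mask when the backend accepts one."""
    if attn_implementation == "flash_attention_2":
        # FA2 does not support 4D masks, so hand it the 2D padding mask as-is.
        return mask_2d
    if uses_static_cache:
        return expand_to_4d(mask_2d, sliding_window)
    return mask_2d


if __name__ == "__main__":
    mask = [[1, 1, 1]]
    print(prepare_mask(mask, "flash_attention_2", True, 2))
    print(prepare_mask(mask, "sdpa", True, 2))
```

The point of the sketch is the order of the checks: the attention backend is inspected before the cache type, so a static-style cache such as SlidingWindowCache no longer forces a 4D mask onto FA2.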
