Fix static generation when compiling! #28937
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I'm not sure about adding a new argument. The following works on `main`:

```python
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-random-LlamaForCausalLM", attn_implementation="eager"
)
tokenizer = LlamaTokenizer.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")

# random input ids
inputs = tokenizer("Hey there", return_tensors="pt", return_attention_mask=True)
position_ids = inputs.attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(inputs.attention_mask == 0, 1)

with torch.no_grad():
    logits = model.forward(**inputs, position_ids=position_ids).logits
```

If we run the same code on this PR, we get a size-mismatch error (full traceback elided). The cause is in `transformers/src/transformers/models/llama/modeling_llama.py`, lines 352 to 353 at 56768a0: when `cache_position` is not provided, instead of slicing to `[:, :, cache_position, : key_states.shape[-2]]`, we index with `[:, :, None, : key_states.shape[-2]]`. So instead of slicing, we insert an extra dimension! This produces the size mismatch when we add the attention mask to the attention weights. The user needs to pass `cache_position` as an argument to the forward call in order for this to work.
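To make the indexing issue concrete, here is a minimal standalone illustration (toy shapes, not the actual modeling code):

```python
import torch

causal_mask = torch.zeros(1, 2, 8, 8)  # [batch, heads, max_positions, max_positions]
key_states = torch.randn(1, 2, 4, 16)  # kv length 4
cache_position = torch.arange(4)

# indexing with a position tensor slices the mask down to the current queries:
print(causal_mask[:, :, cache_position, : key_states.shape[-2]].shape)  # torch.Size([1, 2, 4, 4])

# indexing with None instead inserts a new axis, leaving all 8 positions in place:
print(causal_mask[:, :, None, : key_states.shape[-2]].shape)  # torch.Size([1, 2, 1, 8, 4])
```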
Overall, I think we should avoid adding extra arguments that require code changes from the user, especially to the top-level modules, which are already highly used. What about a design more like Flax, where we keep track of the cache index internally?
We can make it BC! This PR is not ready yet, but `generate` should check the past key values class and, if the forward signature can take `cache_position`, pass it. Something like that. I'll work on making it BC! :)
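A hedged sketch of that signature check (hypothetical helper name, not the final implementation):

```python
import inspect

import torch

def maybe_add_cache_position(model, model_kwargs, past_length, seq_length):
    """Hypothetical helper: only pass `cache_position` to models whose forward accepts it."""
    if "cache_position" in inspect.signature(model.forward).parameters:
        model_kwargs["cache_position"] = torch.arange(past_length, past_length + seq_length)
    return model_kwargs
```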
```python
past_seen_tokens = 0
if use_cache and not isinstance(past_key_values, Cache):
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_seen_tokens = past_key_values.get_usable_length(inputs_embeds.shape[1])  # kept for BC (cache positions)

if cache_position is None:
    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
```
Has to be kept for BC
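For intuition, this BC fallback simply enumerates the absolute positions of the tokens in the current forward pass (toy sizes assumed: a 5-token prompt followed by one decode step):

```python
import torch

# prefill: nothing in the cache yet, 5 new tokens
past_seen_tokens, seq_length = 0, 5
print(torch.arange(past_seen_tokens, past_seen_tokens + seq_length))  # tensor([0, 1, 2, 3, 4])

# first decode step: 5 tokens already cached, 1 new token
past_seen_tokens, seq_length = 5, 1
print(torch.arange(past_seen_tokens, past_seen_tokens + seq_length))  # tensor([5])
```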
```python
if attention_mask is None:
    return None
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and torch.all(attention_mask == 1):
    return None
if is_tracing and seq_length == 1:
    return None
```
All of this was failing generations; will deal with it later.
cc @fxmarty, I am warning you in advance 🥶: you might have to do something similar to the `prepare_4d_sdpa` path, but this is a lot simpler, so for the better.
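For reference, a toy demonstration of the `torch.all(attention_mask == 1)` early return in the hunk above (assumed shapes): when nothing is padded, the causal structure alone suffices and no 4D mask is needed:

```python
import torch

full_mask = torch.ones(2, 6, dtype=torch.long)    # no padding anywhere
print(torch.all(full_mask == 1))                  # tensor(True) -> return None, skip the 4D mask

padded_mask = torch.tensor([[0, 0, 1, 1, 1, 1]])  # left padding present
print(torch.all(padded_mask == 1))                # tensor(False) -> the 4D mask must be built
```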
```python
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
    cache_position = torch.arange(past_length, past_length + input_ids.shape[1])
```
Kept for BC as well; `generate` should handle cache positions IMO.
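A hedged sketch of the TODO's `+= 1` idea (a hypothetical loop with toy sizes, not the merged implementation): `generate` would create `cache_position` once for the prefill and advance it by one per decode step:

```python
import torch

prompt_length, max_new_tokens = 5, 3          # assumed toy sizes
cache_position = torch.arange(prompt_length)  # prefill covers every prompt position
for _ in range(max_new_tokens):
    # ... the model forward would receive `cache_position` here ...
    cache_position = cache_position[-1:] + 1  # next step processes one new position
print(cache_position)  # tensor([7])
```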
gante left a comment
Pre-approving, as the overall PR shape looks good to me 👍
(btw, this PR is blocking further work on generate, as llama + generate + dynamic cache is not correct at the moment and I want to standardize the interface of the different cache classes to match the static cache)
Thanks, merging asap!
```diff
- bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
- non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
+ bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
+ keys_to_ignore = ["cache_position", "encoder_outputs"]
+ non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
```
beam search will split the cache positions otherwise
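A toy illustration of the failure mode (assumed shapes, not the library's splitting code): beam search splits batched tensors along dim 0, but `cache_position` is a 1D vector of positions shared by all beams, so splitting it would give each beam a single wrong position:

```python
import torch

input_ids = torch.tensor([[7, 8], [7, 9]])  # two beams; dim 0 is the beam axis
cache_position = torch.arange(2)            # positions [0, 1], shared by every beam

print(input_ids.split(1, dim=0))       # fine: one row per beam
print(cache_position.split(1, dim=0))  # wrong: (tensor([0]), tensor([1])), positions torn apart
```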
younesbelkada left a comment
Thanks for the huge work! I left some minor comments that should be addressed before merging IMO; otherwise we might introduce breaking changes for users who call our public classes with positional rather than keyword arguments.

Here is an example of a breaking behaviour I introduced while working on FA2: #25598 (comment). We should be careful when adding new args in our modules.
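A tiny sketch of that hazard (hypothetical signatures, only to illustrate; not from the library):

```python
# hypothetical signatures, only to illustrate the positional-argument hazard
def forward_v1(hidden_states, attention_mask, position_ids=None):
    return position_ids

def forward_v2(hidden_states, attention_mask, cache_position=None, position_ids=None):
    # a new argument was inserted before `position_ids`
    return cache_position

print(forward_v1("h", "m", "pos_ids"))  # 'pos_ids', as intended
print(forward_v2("h", "m", "pos_ids"))  # 'pos_ids' lands in cache_position: silently misrouted
```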
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
…hub.com:huggingface/transformers into fix-static-kv-cache
younesbelkada left a comment
Thank you very much!
Hey @ArthurZucker, I discovered that this change actually breaks TPU... TPU training with FSDPv2 now produces NaN loss. I haven't looked into your PR, so I'm not sure why; I just bisected down to this change.
Mmm, this might be a RoPE issue? #29109 might also play a role.
Hi @ArthurZucker, I ran your benchmark script with both transformers 4.38.0 and 4.38.2 but got an error:
It is probably out of date! I'll update it.
We'll actually push a full benchmark in

What does this PR do?
Fixes static cache generation. Goes together with #27931.
Thanks @OlivierDehaene for the insight.
Benchmark script: https://gist.github.com/ArthurZucker/ae0a86ef8f841c0ef69aaa52ccbc0b03
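For context, the feature this PR fixes can be exercised roughly like this (a sketch assuming the 4.38-era static-cache API; the tiny checkpoint from the repro above is reused so it runs anywhere):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")

# opt in to the static KV cache so shapes stay fixed across decode steps, then compile
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```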
`generate`, because the first forward will be fully causal. FA2 potential fix if compile worked:

but I have slowdowns:

[benchmark plots elided: "Slicing" vs. "no Slicing"]