Cache: Static cache as a standalone object #30476
Conversation
```diff
 def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-    """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
+    """Returns the sequence length of the cached states that were seen by the model."""
     # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
     # limit the check to the first batch member and head dimension.
     # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
     # https://github.com/pytorch/pytorch/issues/120248 is fixed
-    return (self.key_cache[0, 0].any(dim=-1)).sum()
+    return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
```
will remove this one
it's slow and not reliable, generate should never use it
(needs a deprecation cycle and it's easier to do after we isolate the prefill stage; I'm going to leave it off this PR)
fine by me to deprecate
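For context on the deprecation direction agreed here, a minimal sketch of the alternative mentioned in the TODO: track the number of seen tokens with a plain integer instead of scanning the cache (which misreads a slot whose real value happens to be `0.0`). The class name, the `_seen_tokens` attribute, and the `update` signature below are illustrative assumptions, not transformers' actual implementation.

```python
import torch

class StaticCacheSketch:
    def __init__(self, num_layers: int, max_cache_len: int, num_heads: int = 1, head_dim: int = 4):
        # Pre-allocated (batch, heads, max_cache_len, head_dim) buffers, one per layer.
        self.key_cache = [torch.zeros(1, num_heads, max_cache_len, head_dim) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(1, num_heads, max_cache_len, head_dim) for _ in range(num_layers)]
        self._seen_tokens = 0  # hypothetical integer counter, kept instead of scanning the cache

    def update(self, key_states, value_states, layer_idx, cache_position):
        # Write the new chunk at the given positions of the pre-allocated buffers.
        self.key_cache[layer_idx][:, :, cache_position] = key_states
        self.value_cache[layer_idx][:, :, cache_position] = value_states
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: int = 0) -> int:
        # No tensor scan: correct even when the cached values are all zeros.
        return self._seen_tokens

cache = StaticCacheSketch(num_layers=2, max_cache_len=16)
k = v = torch.zeros(1, 1, 3, 4)  # an all-zero chunk is still counted correctly
cache.update(k, v, layer_idx=0, cache_position=torch.arange(3))
print(cache.get_seq_length())  # 3
```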
```python
raise ValueError(
    "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
    "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
```
Would be compatible if we slice the q k v efficiently, but that's too much trouble
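For readers hitting this error, a hedged usage sketch of the workaround the message points to: load the model with the SDPA attention backend and still request the static cache. The model id and generation arguments are illustrative placeholders.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Use SDPA instead of flash_attention_2, per the error message.
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa", device_map="auto")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=16, do_sample=False, cache_implementation="static")
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```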
Taking this on to finish!
If you use the memory-efficient kernel it's 20% slower. That's what we use by default.
https://gist.github.com/ArthurZucker/ae0a86ef8f841c0ef69aaa52ccbc0b03 for the benchmarks
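For readers wondering which kernel is in play here, a hedged PyTorch-level sketch of toggling the SDPA backends: this is PyTorch's own switch, not a transformers API, and the toy tensors are illustrative.

```python
import torch
import torch.nn.functional as F

# Toy (batch, heads, seq, head_dim) tensors; fp16 on GPU so the fused kernels are eligible.
q = torch.randn(1, 2, 8, 16, device="cuda", dtype=torch.float16)
k = torch.randn(1, 2, 8, 16, device="cuda", dtype=torch.float16)
v = torch.randn(1, 2, 8, 16, device="cuda", dtype=torch.float16)

# Disable the memory-efficient kernel for this region and prefer the flash kernel.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```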
```python
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
```
As I understand it, once the StaticCache is initialized, there is no need to pass it in the `past_key_values` argument. That's why an additional condition is necessary. Suggestion:

```python
using_static_cache = isinstance(past_key_values, StaticCache) or isinstance(
    getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache
)
```
@poedator This PR changes precisely the assumption you wrote: we will always need to pass the cache; after this PR it is an object that does NOT live inside the model.
This change will make the transformers' team work easier 🤗
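A hedged sketch of the new contract described in this reply: the caller owns the cache object and passes it to the model on every call. The exact constructor arguments and model id below may differ from the merged API; this is an illustration, not the PR's verbatim code.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# The cache is a standalone object now: build it outside the model and pass it in.
past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=256, device=model.device, dtype=model.dtype
)

input_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(model.device)
cache_position = torch.arange(input_ids.shape[1], device=model.device)
outputs = model(input_ids, past_key_values=past_key_values, cache_position=cache_position, use_cache=True)
```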
same comment as here: #30437 (comment). Please make sure to validate these tests on the T4 and A10 runners 🙏
There was indeed a mismatch on T4 🤗
ArthurZucker left a comment:

Absolutely great work
```python
self.key_cache[layer_idx] *= 0.0
self.value_cache[layer_idx] *= 0.0
```
```diff
-self.key_cache[layer_idx] *= 0.0
-self.value_cache[layer_idx] *= 0.0
+self.key_cache[layer_idx] = 0.0
+self.value_cache[layer_idx] = 0.0
```

might be faster?
setting to a new tensor produces a graph break 💔 (I'm assuming you meant `self.key_cache[layer_idx] = torch.zeros(...)`)
No no, I think just filling them with zeros should work
That would result in `TypeError: 'float' object is not subscriptable` when indexing the cache :D
But filling them with zeros via `tensor.zero_()` works 👍
ok 👍🏻 let's go with that then!
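A minimal sketch of the in-place reset settled on above, using `Tensor.zero_()` so the pre-allocated buffers keep their identity and `torch.compile` never sees a rebinding. Shapes and layer count are illustrative.

```python
import torch

# Two layers of pre-allocated (batch, heads, seq, head_dim) buffers, illustrative sizes.
key_cache = [torch.ones(1, 2, 8, 4) for _ in range(2)]
value_cache = [torch.ones(1, 2, 8, 4) for _ in range(2)]

def reset(key_cache, value_cache):
    # In-place zeroing: same tensors, same storage, no graph break from assigning a new object.
    for layer_idx in range(len(key_cache)):
        key_cache[layer_idx].zero_()
        value_cache[layer_idx].zero_()

reset(key_cache, value_cache)
assert key_cache[0].abs().sum().item() == 0.0
```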
```python
if cache_position is None:
    if isinstance(past_key_values, StaticCache):
        raise ValueError("cache_position is a required argument when using StaticCache.")
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
```
Arf alright, let's maybe add a TODO, as we won't be initializing with `get_seq_length` later on!
Added a TODO on `get_seq_length` 👍
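Since `cache_position` becomes a required argument with `StaticCache`, here is a hedged sketch of how a caller can build it; the variable names are illustrative and the fallback mirrors the hunk quoted above.

```python
import torch

# Prefill: nothing cached yet, the new chunk is 7 tokens long.
past_seen_tokens = 0
num_new_tokens = 7
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)
print(cache_position)  # tensor([0, 1, 2, 3, 4, 5, 6])

# Decoding step: 7 tokens already written to the cache, one new token comes in.
cache_position = torch.arange(7, 8)
print(cache_position)  # tensor([7])
```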
```python
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
```
this is new, and since we pass `cache_position`, let's use `cache_position[0]`
Agreed in theory, can't do in practice: breaks torch.fx tests 💔
```python
if using_static_cache:
    target_length = past_key_values.get_max_length()
```
can't we always use get_max_length()?
get_max_length() is None in the dynamic caches
It should be seq_length
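To make this exchange concrete, a hedged sketch of the branching being discussed. The `else` path follows the suggestion that it should be the sequence length; the exact fallback expression is an assumption, not the merged code.

```python
from typing import Optional

def compute_target_length(
    using_static_cache: bool,
    max_cache_len: Optional[int],
    past_seen_tokens: int,
    sequence_length: int,
) -> int:
    # Static cache: the mask must span the full pre-allocated length (get_max_length()).
    # Dynamic caches: get_max_length() returns None, so fall back to the current sequence length.
    if using_static_cache and max_cache_len is not None:
        return max_cache_len
    return past_seen_tokens + sequence_length

print(compute_target_length(True, 256, 0, 7))    # 256: static cache
print(compute_target_length(False, None, 5, 1))  # 6: dynamic cache
```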
```diff
@@ -684,15 +683,25 @@ def test_model_13b_greedy_generation(self):
     @require_torch_gpu
     @require_read_token
     def test_compile_static_cache(self):
```
should require torch > 2.2
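A hedged sketch of the extra guard requested here ("should require torch > 2.2"); transformers ships its own decorators for this, so the helper name below is hypothetical.

```python
import unittest

import torch
from packaging import version

def require_torch_greater_than_2_2(test_case):
    """Skip the decorated test unless the installed torch is newer than 2.2 (hypothetical helper)."""
    return unittest.skipUnless(
        version.parse(torch.__version__) > version.parse("2.2"),
        "test requires torch > 2.2",
    )(test_case)
```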
```python
# Static Cache + compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
    **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
```
Commit message from a follow-up referencing this PR:

- 4d mask fixes
- Update custom 4D mask logic
- test moved to mixin
- extra tests 4d mask
- upd 4d mask and StaticCache handling
- added Mask4DTestHard to mistral tests
- post-rebase fixes
- test fixes for StaticCache
- make fix-copies
- upd 1 after #30476
- fix common tests
- rm elif attention_mask.dim() == 4:
- tests combined, fixed, mixtral supported
- bigbird style chg reverted
- rm if attention_mask.dim() == 2
- modeling_llama formatting chg

Co-authored-by: Joao Gante <joao@huggingface.co>
```python
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
```
and 2.2.1 works as well

What does this PR do?
Replaces the current format of `StaticCache` [an object living inside a model, containing the cache for one layer] with a standalone object matching the other `Cache` objects. The new format preserves the existing `torch.compile` capabilities while being easier to manipulate, especially outside a model.

In the process, removes all traces of the previous format across all models, tests, and docs.
Fixes #30417 (In place of #30437)
Fixes #30351
Benchmarks
(RTX3090, tiny-llama model, torch==2.4.0.dev20240424+cu121)

Benchmark code (collapsed)

commit == 14b19c4ef365f90797e07b2a20caaaaf3901b2d2 (benchmark plot)

v4.39.0 (benchmark plot)