[core] 🚨 Completely remove cache positions #44181
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
81f8086 to c2fde0a
vasqu
left a comment
Since I'm on the train it's a bit hard to review, but IIUC we're gradually removing cache positions and focusing on a subset of important models for now.
Can you run slow for bert/bart? Just me being a bit too anxious about these.
```python
# Since StaticSlidingWindow have dynamic control flow that cannot be avoided, we have to replace them here by
# simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
for i, layer in enumerate(self.static_cache.layers):
    if isinstance(layer, StaticSlidingWindowLayer):
        self.static_cache.layers[i] = StaticLayer(layer.max_cache_len)
```
Do we add something in the docs to clarify this? It seems like something that would be hard for outsiders to catch, and I don't think we will remove this limitation any time soon 😓
We could, but note that this is ALREADY the case on main, and it looks like nobody noticed/raised an issue... It would generate garbage beyond the sliding window; it's just explicit in the code now!
To solve it, we could either have a minimal version of the sliding cache for export, which would work ONLY for decoding (i.e. 1 new token at a time, without being able to feed more than 1 token after prefill), or we could just use the full cache with proper sliding masking, but we lose some compute.
Yea, still think it would be nice to clarify somewhere. True, it's also broken in main, maybe out of scope for this PR
```diff
 @abstractmethod
 def update(
-    self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: dict[str, Any] | None = None
+    self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
```
Do we really need the args/kwargs signature here? Does kwargs suffice to be BC?
Just not a fan of the *args tbh
Yes, we need it - a lot of models unfortunately pass `cache_kwargs` as a positional arg, and others as a keyword like `cache_kwargs=cache_kwargs` 🥲
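To make the BC constraint concrete, here is a minimal sketch (hypothetical `DummyLayer`, not the actual cache layer implementation) of why `*args, **kwargs` is needed to absorb both call styles:

```python
import torch

# Hypothetical layer: the signature must tolerate cache_kwargs arriving either positionally or by keyword.
class DummyLayer:
    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs):
        # Normalize: take the first positional extra if present, otherwise look for the keyword.
        cache_kwargs = args[0] if args else kwargs.get("cache_kwargs")
        return key_states, value_states, cache_kwargs

k = v = torch.zeros(1, 2, 3, 4)
layer = DummyLayer()
layer.update(k, v, {"some_key": 1})                # positional style used by some models
layer.update(k, v, cache_kwargs={"some_key": 1})   # keyword style used by others
```

With `**kwargs` alone in the signature, the first (positional) call would raise a `TypeError`, which is exactly the BC break being avoided.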
```python
# It can either be an int for dynamic layers, or a tensor for static layers
if isinstance(self.cumulative_length, int):
    self.cumulative_length = 0
else:
    self.cumulative_length.zero_()
```
Could we default to a tensor instead, or is that too bothersome?
Or python ints always, not sure about the impacts.
More of a nit tbh
It's easier with ints for dynamic caches, and static ones REALLY need tensors for compile with cudagraphs!
what about dynamic + compile 🙃
Dynamic layers will never be able to use cudagraphs anyway as the kv tensors will change shape dynamically! However, I have strong hopes that we will be able to make them compile-compatible with dynamic=True/None and mode="default"!
```python
def reset(self):
    super().reset()
    self.cumulative_length_int = 0
```
The _int suffix is a bit weird, I know you mean to make it explicit but then it won't match with a lot of the other layer types.
It's because this LayerCache has both: the base `cumulative_length` is a Tensor (to be able to use cudagraphs in the regime where the cache is not yet full), and `cumulative_length_int` is the equivalent as a python int, to avoid data-dependent branching (see the sketch below).
Ah gotcha, it didn't really come through in the diff for me, but makes sense
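For illustration, a hedged sketch of the dual-counter idea (assumed names and simplified logic, not the actual transformers code):

```python
import torch

class SlidingLayerSketch:
    """Keeps a tensor counter for cudagraph-safe arithmetic and an int twin for python control flow."""

    def __init__(self, max_cache_len: int):
        self.max_cache_len = max_cache_len
        self.cumulative_length = torch.tensor(0, dtype=torch.long)  # used inside compiled/captured regions
        self.cumulative_length_int = 0                              # used for branching, never read from the tensor

    def update_lengths(self, num_new_tokens: int) -> None:
        # Both counters advance together; only the int is ever used in `if` statements.
        self.cumulative_length += num_new_tokens
        self.cumulative_length_int += num_new_tokens

    def is_full(self) -> bool:
        # Branching on the python int avoids data-dependent control flow on tensor data.
        return self.cumulative_length_int >= self.max_cache_len

    def reset(self) -> None:
        self.cumulative_length.zero_()
        self.cumulative_length_int = 0
```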
```diff
     attention_mask: torch.Tensor | None,
-    cache_position: torch.Tensor,
+    cache_position: torch.Tensor | None = None,  # not used anymore but kept for BC
+    *,
```
Don't see a reason to add the `*,`?
This happens elsewhere too, keeping the comment here only.
It's because now we want `cache_position` to be optional, but the next argument `past_key_values` was not optional, and we don't really want to make it optional (args with a default value cannot come before args without one 🥲) - the `*,` makes it explicit that `past_key_values` now has to be passed as a keyword argument.
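A minimal illustration of the Python rule at play (hypothetical `create_mask_sketch`, simplified from the real mask helpers):

```python
# Invalid: a parameter with a default cannot precede one without a default.
# def create_mask_sketch(attention_mask, cache_position=None, past_key_values):  # SyntaxError
#     ...

# Valid: make everything after the `*` keyword-only instead of giving past_key_values a default.
def create_mask_sketch(attention_mask, cache_position=None, *, past_key_values):
    return attention_mask, cache_position, past_key_values

# Callers now have to name it explicitly:
create_mask_sketch(None, past_key_values={})
```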
```python
embeds = encoder_hidden_states if encoder_hidden_states is not None else inputs_embeds
batch_size, dtype, device = embeds.shape[0], embeds.dtype, embeds.device
```
I don't think we need this ternary at all anymore? Both embeds should have the same factory data, e.g. device, dtype, batch
Probably not; I wasn't entirely sure, as sometimes with device_map they can end up on different devices - let's keep it for now and see after the PR is merged if we can remove it!
I love it, feels good to see
run-slow: bert, bart
This comment contains models: ["models/bart", "models/bert"]
Can we also update/delete docs where cache position is mentioned, such as https://huggingface.co/docs/transformers/v5.2.0/en/cache_explanation#cache-position?
Updated audioflamingo, thanks for the heads-up @zucchini-nlp! For the doc, I think it's best to wait for more models to remove them before deleting!
ArthurZucker
left a comment
I like this, but let's update the PR description so the community understands why we are doing this breaking change (it is kinda breaking for the mask API). Let's add our strong motivations please!
Also, we could deprecate without breaking for some of the changes.
```python
else:
    # Note: very important to use the tensor version of the cumulative length here, as otherwise cudagraphs
    # (triggered by mode="reduced_overhead") will lead to random crashes, as the int would be overwritten
    cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
```
It's annoying to me that we have to allocate new memory all the time here... 😿
It's basically free - it's a tensor of 1 element 99% of the time, and even when it's not it's always super small!
| """ | ||
| # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf | ||
| _ = kwargs.pop("allow_is_causal_skip", None) | ||
| _ = kwargs.pop("allow_torch_fix", None) |
Should we just deprecate it (cache_position as an arg) for 2 releases at least?
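For reference, a possible soft-deprecation pattern (a sketch only - this is not what the PR currently does, and the function name is just an example):

```python
import warnings

def eager_mask_sketch(*args, cache_position=None, **kwargs):
    # Accept the old argument but warn that it is ignored and will disappear.
    if cache_position is not None:
        warnings.warn(
            "`cache_position` is ignored and will be removed in a future release.",
            FutureWarning,
        )
    ...
```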
```python
def update_conv_state(
    self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
    conv_state = self.conv_states[layer_idx]
    cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

    conv_state = conv_state.roll(shifts=-1, dims=-1)
    conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
    self.conv_states[layer_idx].zero_()
    self.conv_states[layer_idx] += conv_state
    return self.conv_states[layer_idx]
```
this is a good catch but unrelated no?
Yes, fully unrelated, but I stumbled upon it and it's never used anywhere...
[For maintainers] Suggested jobs to run (before merge) run-slow: afmoe, apertus, arcee, aria, audioflamingo3, bamba, bitnet, cohere, cohere2, csm, cwm, deepseek_v2, deepseek_v3, dia, diffllama
What does this PR do?

As per the title! Follow-up of #44130 and #44226.

Finally remove `cache_position` everywhere (not ALL models, but all the recent/most important models that use modular). This means we still create them and pass them around in `generate`, but they are ignored by the models. I will gradually remove them from all models, and remove them from `generate` once this is done.

This PR basically tweaks the logic of `cache_utils.py` and `masking_utils.py` to not rely on `cache_position`, and then removes the argument from the models' `forward`s (not a regression, as they are absorbed by the `**kwargs` for all those models).

Motivation

`cache_position` is basically not needed, as our Cache classes and masking primitives already contain all the necessary information. We can simply recreate it quickly in the StaticCache when needed. Removing it will make all modeling files much easier to read and understand. Indeed, when reading modeling files people are often confused by `cache_position`, which looks like a mere 1D version of `position_ids` (and thus seems fully redundant) - even though that's not exactly the case, since `cache_position` doesn't take padding into account. It will also make input preparation in `generate` much easier once it is removed from all models.
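For intuition, here is a hedged sketch of how positions can be recreated from the cache state when needed (names taken from the hunk discussed above; the real logic lives in `cache_utils.py` and may differ):

```python
import torch

def recreate_cache_position(cumulative_length: torch.Tensor, kv_length: int, device: torch.device) -> torch.Tensor:
    # Positions of the incoming tokens are just an offset arange over the tokens already cached.
    return torch.arange(kv_length, device=device) + cumulative_length

# e.g. after 10 cached tokens, a single decoding step gets tensor([10])
print(recreate_cache_position(torch.tensor(10), 1, torch.device("cpu")))
```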
Compile compatibility

I made EXTRA SURE that we do not have any regressions in terms of compile compatibility: we have exactly the same scope of compatibility as before, i.e. full compatibility (fullgraph=True, dynamic=False, cudagraphs) for StaticLayer without any recompiles, and (fullgraph=True, cudagraphs) for StaticSlidingWindowLayer (it can work with fullgraph=False, but will recompile every iteration - otherwise it recompiles only once to make the internal int a dynamic SymInt, and once again if it changes cache regime, i.e. the cache becomes full, etc.). There is no way around that, as StaticSlidingWindowLayer has data-dependent control flow that is unavoidable. It's exactly the same currently on main.
Breaking changes

This PR is not really breaking. The 🚨 marker is only for the following detail:

The only breaking change in this PR is in the `masking_utils` API, as `cache_position` becomes optional everywhere in the `create_xxx_mask` functions (it is not used anymore), and thus `past_key_values` needs to be passed as a kwarg now due to the position of the args in the signature, since we cannot have an arg without a default value following an arg with a default value (this was already the case everywhere in Transformers).

The internal functions `sdpa_mask`, `eager_mask`, etc. also move from having `cache_position` in their signature to having `q_length` and `q_offset` instead. Those should be private anyway, but you never know.
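Hedged illustration of the call-site change for the public `create_xxx_mask` functions (argument list abbreviated and assumed, check `masking_utils.py` for the exact signatures):

```python
# Before this PR: cache_position was a required positional argument.
#   mask = create_causal_mask(config, input_embeds, attention_mask, cache_position, past_key_values)

# After this PR: cache_position is optional (and ignored), so past_key_values must be passed by name.
#   mask = create_causal_mask(config, input_embeds, attention_mask, past_key_values=past_key_values)
```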
Review pointers

The easiest way to review this PR is to only look at the very few files that are not modeling files, especially `cache_utils.py` and `masking_utils.py`. Then, all modeling files are basically the same: just remove the arg `cache_position` everywhere (they are absorbed in the `**kwargs` so there's no regression, they can still be externally passed, and they are still technically passed by `generate`, just not used).