[core] 🚨 Completely remove cache positions #44181

Merged
Cyrilvallez merged 51 commits into main from remove-cache-pos on Mar 4, 2026

Conversation

@Cyrilvallez (Member) commented Feb 20, 2026

What does this PR do?

As per the title! Follow-up of #44130 and #44226.
This finally removes cache_position everywhere (not from ALL models, but from all the recent/most important models that use modular). We still create them and pass them around in generate, but the models ignore them. I will gradually remove them from all models, and remove them from generate once this is done.

This PR basically tweaks the logic of cache_utils.py and masking_utils.py to not rely on cache_position, and then removes the argument from the models' forwards (not a regression, as it is absorbed by the **kwargs in all those models).

Motivation

cache_position is basically not needed, as our Cache classes and masking primitives already contain all the necessary information. We can simply recreate it quickly in the StaticCache when needed. Removing it will make all modeling files much easier to read and understand. Indeed, when reading modeling files, people are often confused by cache_position, which looks like a mere 1D version of position_ids (and thus fully redundant), even though that's not exactly the case: cache_position does not take padding into account. It will also make input preparation in generate much easier once it is removed from all models.
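To make the distinction concrete, here is a minimal sketch (illustrative only, using a left-padded batch) of how the two differ:

```python
import torch

# Left-padded batch: two pad tokens before a 3-token prompt.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])

# cache_position indexes physical slots in the cache: it counts every
# position, padding included.
cache_position = torch.arange(5)  # tensor([0, 1, 2, 3, 4])

# position_ids are the semantic token positions (used by e.g. RoPE):
# padding is skipped, so they are NOT just a 2D copy of cache_position.
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)
# tensor([[0, 0, 0, 1, 2]])
```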

Compile compatibility

I made EXTRA SURE that we do not have any regressions in terms of compile compatibility: we keep exactly the same scope of compatibility as before, i.e. full compatibility (fullgraph=True, dynamic=False, cudagraphs) for StaticLayer without any recompiles, and (fullgraph=True, cudagraphs) for StaticSlidingWindowLayer. The latter can work with fullgraph=False, but will then recompile every iteration; otherwise it recompiles only once to make the internal int a dynamic SymInt, and once again if it changes cache regime (i.e. the cache becomes full, etc.). There is no way around that, as StaticSlidingWindowLayer has data-dependent control flow that is unavoidable. It's exactly the same currently on main.
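For reference, a sketch of the compile setup this refers to (`model` is a placeholder; mode names per torch.compile, where "reduce-overhead" enables cudagraphs):

```python
import torch

# StaticLayer: fully static shapes -> no recompiles at all.
compiled_model = torch.compile(model, fullgraph=True, dynamic=False, mode="reduce-overhead")

# StaticSlidingWindowLayer: the same call works, but expect one recompile turning
# the internal int into a dynamic SymInt, and one more when the cache regime changes.
```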

Breaking changes

This PR is not really breaking. The 🚨 marker is only for the following detail:

The only breaking change in this PR is in the masking_utils API: cache_position becomes optional everywhere in the create_xxx_mask functions (it is not used anymore), and thus past_key_values now needs to be passed as a kwarg, due to the position of the args in the signature: we cannot have an arg without a default value following an arg with a default value (passing it as a kwarg was already the case everywhere in Transformers).

The internal functions sdpa_mask, eager_mask, etc. also move from having cache_position in their signature to having q_length and q_offset instead. Those should be private anyway, but you never know.
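Roughly, for a function like create_causal_mask (argument names illustrative):

```python
# Before: past_key_values could be passed positionally, after cache_position.
mask = create_causal_mask(config, input_embeds, attention_mask, cache_position, past_key_values)

# After: cache_position is optional (and ignored), so past_key_values must be
# passed as a keyword argument.
mask = create_causal_mask(config, input_embeds, attention_mask, past_key_values=past_key_values)
```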

Review pointers

The easiest way to review this PR is to only look at the very few files that are not modeling files, especially cache_utils.py and masking_utils.py. Then, all modeling files are basically the same: just remove the arg cache_position everywhere (it is absorbed in the **kwargs so there is no regression; it can still be passed externally, and it is still technically passed by generate, just not used).

@HuggingFaceDocBuilderDev (bot):

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu (Contributor) left a comment:

Since I'm on the train it's a bit hard to review, but IIUC we're gradually removing cache positions, focusing on a subset of important models for now.

Can you run slow for bert/bart? Just me being a bit too anxious about these.

Comment on lines +526 to +530
```python
# Since StaticSlidingWindow have dynamic control flow that cannot be avoided, we have to replace them here by
# simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
for i, layer in enumerate(self.static_cache.layers):
    if isinstance(layer, StaticSlidingWindowLayer):
        self.static_cache.layers[i] = StaticLayer(layer.max_cache_len)
```
Contributor:

Do we add something in the docs to clarify? This seems like something that's surprisingly hard to catch for outsiders - and I don't think we will remove this limitation any time soon 😓

Member Author:

We could, but note that this is ALREADY the case on main, and it looks like nobody noticed/raised an issue... It would generate garbage beyond the sliding window; it's just explicit in the code now!
To solve it, we could either have a minimal version of the sliding cache for export, which would work ONLY for decoding (i.e. 1 new token at a time, without being able to feed more than 1 token after prefill), or we could just use a full cache with proper sliding masking, but then we lose some compute.
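A minimal sketch of what the decode-only option could look like (hypothetical, not part of this PR): with exactly one new token per step, the update is a static roll-and-write with no data-dependent control flow:

```python
import torch

class DecodeOnlySlidingLayer:
    """Hypothetical export-friendly sliding-window layer: it accepts exactly
    one new token per call, so the update needs no data-dependent branching."""

    def __init__(self, window: int, num_heads: int, head_dim: int):
        self.keys = torch.zeros(1, num_heads, window, head_dim)
        self.values = torch.zeros(1, num_heads, window, head_dim)

    def update(self, key: torch.Tensor, value: torch.Tensor):
        # Shift the window left by one slot and write the new token at the end.
        self.keys = torch.roll(self.keys, shifts=-1, dims=2)
        self.values = torch.roll(self.values, shifts=-1, dims=2)
        self.keys[:, :, -1:] = key
        self.values[:, :, -1:] = value
        return self.keys, self.values
```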

Contributor:

Yeah, I still think it would be nice to clarify somewhere. True, it's also broken on main, so maybe it's out of scope for this PR.

```diff
 @abstractmethod
 def update(
-    self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: dict[str, Any] | None = None
+    self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
```
Contributor:

Do we really need the args/kwargs signature here? Would **kwargs alone suffice for BC?

Just not a fan of the *args tbh

Member Author:

Yes, we need it - a lot of models unfortunately pass cache_kwargs as a positional arg, and others pass it as a keyword, like cache_kwargs=cache_kwargs 🥲
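To illustrate (call sites paraphrased), both styles must keep working against the new base signature:

```python
class Layer:
    # Permissive base signature: **kwargs alone would break positional calls.
    def update(self, key_states, value_states, *args, **kwargs):
        cache_kwargs = args[0] if args else kwargs.get("cache_kwargs")
        return key_states, value_states, cache_kwargs

layer = Layer()
# Both call styles exist across models, so both must keep working:
layer.update("k", "v", {"cache_position": None})               # positional
layer.update("k", "v", cache_kwargs={"cache_position": None})  # keyword
```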

Comment on lines +75 to 80
```python
# It can either be an int for dynamic layers, or a tensor for static layers
if isinstance(self.cumulative_length, int):
    self.cumulative_length = 0
else:
    self.cumulative_length.zero_()
```

Contributor:

Could we default to a tensor instead, or is that too bothersome?

Contributor:

Or Python ints always? I'm not sure about the impacts.

More of a nit tbh

Member Author:

It's easier with ints for dynamic caches, and static ones REALLY need tensors for compile with cudagraphs!
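A minimal sketch of the constraint (assuming standard torch.compile/cudagraph behavior): cudagraph replays read fixed memory addresses, so persistent state must live in a tensor that is updated in place, while a plain Python int gets baked into the captured graph:

```python
import torch

num_new_tokens = 1

# Static layer: state as a tensor, updated in place. Cudagraph replays see the
# new value because the tensor's memory address never changes.
cumulative_length = torch.zeros(1, dtype=torch.long)
cumulative_length += num_new_tokens  # in-place update, cudagraph-safe

# Dynamic layer: a plain int is simpler and fine, since cudagraphs are
# impossible anyway (the kv tensors change shape every step).
cumulative_length_int = 0
cumulative_length_int += num_new_tokens
```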

Collaborator:

what about dynamic + compile 🙃

Member Author:

Dynamic layers will never be able to use cudagraphs anyway, as the kv tensors change shape dynamically! However, I have strong hopes that we will be able to make them compile-compatible with dynamic=True/None and mode="default"!


```python
def reset(self):
    super().reset()
    self.cumulative_length_int = 0
```
Contributor:

The _int suffix is a bit weird; I know you mean to make it explicit, but then it won't match a lot of the other layer types.

Member Author:

It's because this LayerCache has both: the base cumulative_length is a tensor (to be able to use cudagraphs in the regime where the cache is not yet full), and cumulative_length_int is the equivalent as a Python int, to avoid data-dependent branching.

Contributor:

Ah gotcha, it didn't really come through in the diff for me, but that makes sense.

Comment threads on src/transformers/masking_utils.py
```diff
     attention_mask: torch.Tensor | None,
-    cache_position: torch.Tensor,
+    cache_position: torch.Tensor | None = None,  # not used anymore but kept for BC
     *,
```
Contributor:

Suggested change: remove the `*,`

I don't see a reason to add it?

Contributor:

The same happens elsewhere; keeping the comment here only.

@Cyrilvallez (Member Author) commented Feb 27, 2026:

It's because we now want cache_position to be optional, but the next argument past_key_values is not optional, and we don't really want to give it a default (args with a default value cannot come before args without one 🥲) - the keyword-only marker makes it clear that `past_key_values` now has to be passed as a keyword.
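Concretely (a minimal illustration of the Python rule):

```python
# SyntaxError: a required positional parameter cannot follow one with a default:
# def f(cache_position=None, past_key_values): ...

# The keyword-only marker solves it: past_key_values stays required, but must
# now be passed by keyword.
def f(cache_position=None, *, past_key_values):
    return past_key_values

f(past_key_values="kv")  # ok
# f(None, "kv")          # TypeError: too many positional arguments
```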

Comment on lines +1018 to +1019
embeds = encoder_hidden_states if encoder_hidden_states is not None else inputs_embeds
batch_size, dtype, device = embeds.shape[0], embeds.dtype, embeds.device
Contributor:

I don't think we need this ternary at all anymore? Both embeds should have the same factory data, e.g. device, dtype, batch size.

Member Author:

Probably not; I wasn't entirely sure, as sometimes with the device_map they can end up on different devices - let's keep it for now and see after the PR is merged if we can remove it!

Contributor:

I love it, feels good to see

@Cyrilvallez (Member Author):

run-slow: bert, bart

@Cyrilvallez changed the title from "[core] Completely remove cache positions" to "[core] 🚨 Completely remove cache positions" on Feb 27, 2026
@github-actions (bot):

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/bart", "models/bert"]
quantizations: []

@github-actions (bot):

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit | Description |
| --- | --- | --- |
| RUN | ab1b03f4 | workflow commit (merge commit) |
| PR | c2d8180f | branch commit (from PR) |
| main | fe3cb66e | base commit (on main) |

✅ No failing test specific to this PR 🎉 👏 !

Comment thread on src/transformers/models/audioflamingo3/modeling_audioflamingo3.py (outdated)
@zucchini-nlp (Member):

Can we also update/delete docs where cache position is mentioned, such as https://huggingface.co/docs/transformers/v5.2.0/en/cache_explanation#cache-position?

@Cyrilvallez (Member Author):

Updated audioflamingo3, thanks for the heads-up @zucchini-nlp! For the doc, I think it's best to wait for more models to have it removed before deleting!

@ArthurZucker (Collaborator) left a comment:

I like this, but let's update the PR description so the community understands why we are doing this breaking change (it is kinda breaking for the mask API). Let's add our strong motivations, please!

Also, we could deprecate without breaking for some of the changes.

Comment thread on src/transformers/models/t5gemma/modular_t5gemma.py
Comment threads on src/transformers/cache_utils.py
```python
else:
    # Note: very important to use the tensor version of the cumulative length here, as otherwise cudagraphs
    # (triggered by mode="reduce-overhead") will lead to random crashes, as the int would be overwritten
    cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
```
Collaborator:

it's annoying to me that we have to allocate new memory all the time here..... 😿

Member Author:

It's basically free - it's a tensor of 1 element 99% of the time, and even when it's not, it's always super small!

"""
# The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
_ = kwargs.pop("allow_is_causal_skip", None)
_ = kwargs.pop("allow_torch_fix", None)
Collaborator:

Should we just deprecate it (cache_position as an arg) for at least 2 releases?

Member Author:

Done! Until v5.6!

Comment on lines -125 to -135
```python
def update_conv_state(
    self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
    conv_state = self.conv_states[layer_idx]
    cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

    conv_state = conv_state.roll(shifts=-1, dims=-1)
    conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
    self.conv_states[layer_idx].zero_()
    self.conv_states[layer_idx] += conv_state
    return self.conv_states[layer_idx]
```
Collaborator:

This is a good catch, but unrelated, no?

Member Author:

Yes, fully unrelated, but I stumbled upon it and it's never used anywhere...

@ArthurZucker (Collaborator) left a comment:

LGTM 🤗

@github-actions (bot) commented Mar 4, 2026:

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe, apertus, arcee, aria, audioflamingo3, bamba, bitnet, cohere, cohere2, csm, cwm, deepseek_v2, deepseek_v3, dia, diffllama

@Cyrilvallez Cyrilvallez merged commit 421c7f6 into main Mar 4, 2026
21 of 27 checks passed
@Cyrilvallez Cyrilvallez deleted the remove-cache-pos branch March 4, 2026 18:08
dacorvo added a commit that referenced this pull request Mar 18, 2026
…chmarks

Rewrite static_sample_investigation.md with:
- Context: goal is to determine neuron-only vs general static path
- Methodology: align on newest _sample algorithm, not neuron_sample fork
- Full comparison table: _static_sample vs neuron_sample (14 items)
- Benchmark results for Items A (output_ids CPU) and B (4D mask)
- Recent PRs affecting _sample (#44226, #44130, #44181, #44126)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>