
[generate] Completely stop relying on cache_position to prepare inputs #44130

Merged
Cyrilvallez merged 29 commits into main from generate-input-prep on Feb 20, 2026
Conversation

@Cyrilvallez
Member

@Cyrilvallez Cyrilvallez commented Feb 18, 2026

What does this PR do?

As per the title.

This PR is the first big step towards removing cache_position everywhere, as it is not needed in general and everything can be inferred from the cache itself.
The major changes are the following:

  • prepare_inputs_for_generation now receives input_ids or inputs_embeds already sliced according to the cache -> this sequence length is the absolute reference, and everything else is sliced based on it. The idea is that in every generation loop we always know how many tokens need to go through the model at the next step (usually a single one when using a cache), so let's use this information directly instead of very convoluted logic. Currently, it takes the full sequence as input and then relies on cache_position to slice it, which is extremely brittle compared to passing the correct new input directly (see the sketch after this list)

  • align position_ids and cache_position so that they contain the values for all past tokens, and only slice them to the input length in prepare_inputs_for_generation. This was done because the attention_mask already needs to contain the values for all tokens at all times (because of padding), so it makes more sense to align the behavior of all the important common inputs (position_ids, cache_position, attention_mask) and not have to always remember what is already sliced and what is not -- this also fixes some broken logic in _update_model_kwargs_for_generation

  • Revert most of #42088 (Prefill-related logic in input preparation for generation), as it's not needed to avoid running prefill for assistants - prefill can simply take an is_first_iteration kwarg to know whether the assistant is calling it several times. This is needed because when restarting from a cache we have to pass all the preceding input_ids as well, and _prefill is now responsible for checking the cache and slicing the inputs so that only the new tokens are prefilled. We cannot easily slice before calling _prefill, as we potentially need to recreate the full attention_mask and position_ids based on the FULL SEQUENCE (including the preceding tokens)
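
A minimal sketch of the slicing contract described above (not the actual transformers implementation; names and shapes are simplified): the model step receives input_ids already restricted to the new tokens, position_ids and cache_position keep the values for the full sequence and are only sliced to the input length inside prepare_inputs_for_generation, while attention_mask always keeps covering the full sequence.

```python
def prepare_inputs_for_generation(input_ids, attention_mask=None, position_ids=None, cache_position=None, **kwargs):
    # input_ids already contains only the new tokens (usually a single one after prefill),
    # so its length is the absolute reference for slicing everything else.
    new_length = input_ids.shape[1]
    if position_ids is not None:
        position_ids = position_ids[:, -new_length:]
    if cache_position is not None:
        cache_position = cache_position[-new_length:]
    # attention_mask is intentionally NOT sliced: it must keep covering all tokens,
    # including padding from previous steps.
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "cache_position": cache_position,
        **kwargs,
    }
```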

Next steps

Now that input preparation is decoupled from cache_position, the next PR will decouple both the cache update and the masking logic from it, and cache_position will no longer be useful.

Note that cache_position is currently not used anymore for general input preparation, but it is still created and forwarded to the model forward passes for downstream use in the cache and masking.
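
For reference, a hedged sketch (not the exact transformers code) of how cache_position can still be derived from the cache state for those downstream uses, assuming a cache object that exposes get_seq_length() as DynamicCache does:

```python
import torch

def make_cache_position(cache, input_ids):
    # cache_position simply enumerates the absolute positions of the tokens fed at this
    # step, starting right after the ones already stored in the cache.
    past_length = cache.get_seq_length()
    return torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device)
```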

cc @zucchini-nlp @vasqu @ArthurZucker

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez Cyrilvallez changed the title from "Input preparation" to "[generate] Stop relying on cache_position to slice inputs and simplify inputs preparation" on Feb 19, 2026
@Cyrilvallez Cyrilvallez changed the title from "[generate] Stop relying on cache_position to slice inputs and simplify inputs preparation" to "[generate] Completely stop relying on cache_position to prepare inputs" on Feb 19, 2026
Member

@zucchini-nlp zucchini-nlp left a comment


I love seeing hacky code gone and everything simplified! My only concern is BC with users' expectations. As I said below, there was a point when we got new GH issues every day saying that "cache doesn't work", or from users struggling to re-use a system-prompt cache. So let's make sure that we communicate it well, and add and update the docs with small snippets
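
For context, a minimal sketch of the kind of doc snippet being asked for here: re-using a prefilled system-prompt cache across calls. The model name is a placeholder, and this assumes the standard DynamicCache workflow rather than this PR's exact final API.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "meta-llama/Llama-3.2-1B-Instruct"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Prefill the system prompt once and keep the resulting KV cache.
system_ids = tokenizer("You are a helpful assistant.\n", return_tensors="pt").input_ids.to(model.device)
cache = DynamicCache()
with torch.no_grad():
    model(input_ids=system_ids, past_key_values=cache)

# Later: append the user turn and generate, passing the full sequence plus the existing cache.
user_ids = tokenizer("What is the capital of France?", return_tensors="pt").input_ids.to(model.device)
full_ids = torch.cat([system_ids, user_ids], dim=-1)
output = model.generate(full_ids, past_key_values=cache, max_new_tokens=20)
print(tokenizer.decode(output[0, full_ids.shape[1]:]))
```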

Also, if we are to deprecate cache_position in the next few releases, I'd prefer a tiny warning in _prepare_inputs. We can check whether users passed cache_position explicitly and warn if they did. We need to explain how to craft their inputs correctly without cache_position, and link to the docs
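
A hedged sketch of what such a warning could look like (purely illustrative, not actual transformers code; the name of the kwargs dict is an assumption):

```python
import warnings

def check_explicit_cache_position(model_kwargs):
    # Illustrative only: warn when the user still passes cache_position explicitly,
    # since inputs are now prepared from the cache state directly.
    if model_kwargs.get("cache_position") is not None:
        warnings.warn(
            "Passing `cache_position` to `generate` is no longer needed; inputs are prepared "
            "from the cache itself. See the KV-cache docs for how to continue generation from "
            "an existing cache.",
            FutureWarning,
        )
```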

Finally, let's run the slow integration tests in test_utils.py. I know that some of those tests are kinda outdated, gotta find a day to fix/clean them 🫠 It's on my list, and I will start today since it's quite important to keep generation working

Comment thread src/transformers/models/csm/generation_csm.py
Comment thread src/transformers/generation/utils.py
Comment thread src/transformers/generation/utils.py Outdated
Comment thread src/transformers/generation/utils.py
Comment thread src/transformers/generation/utils.py Outdated
Comment on lines +3736 to +3738
past_length = cache.get_seq_length()
input_ids = input_ids[:, past_length:]
if "inputs_embeds" in model_kwargs:
Member


this will fail for cases when users pass already-sliced input_ids and a custom cache_position. If we slice by past length, we end up with zero inputs. See test_generate_custom_cache_position in test_utils.py

I don't mind changing it, but we really need to communicate it well and change it slowly. We already had a lot of issues and questions when cache_position appeared, about how and why to continue from an existing cache. If something "breaks", we might get another wave of issues
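
A tiny, hypothetical illustration of the failure mode described here (values are made up): if the user already passes only the new tokens, slicing again by the cached length leaves an empty input.

```python
import torch

past_length = 10                                 # tokens already stored in the cache
new_token_ids = torch.tensor([[101, 102, 103]])  # user passes only the 3 new token ids

# Slicing the already-sliced input by past_length leaves nothing to feed the model.
sliced = new_token_ids[:, past_length:]
print(sliced.shape)  # torch.Size([1, 0]) -> zero inputs
```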

Member Author


Yeah I agree, in general our way of doing things here is very weird - at first I assumed we were restarting from only the new tokens as well, then thought that we only support restarting from the full inputs. That can make sense, since with full inputs we also have the information about past potential padding tokens.
And I wanted to change that later to also support just the new tokens haha - I did not know it could already be done
The only issue with restarting from only the new tokens is that we may "lose" the information about previous padding tokens
However, it looks like in this test we restart from only the new tokens but with a fully incremented attention_mask @zucchini-nlp, so it's not an issue at all and we can slice based on the mask as well. Can you confirm that?

Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(The test is not easy to understand, they do weird stuff 😅 For example, they seem to cat the new tokens in the wrong order, as it does cat(new_token, old_token) instead of cat(old_tokens, new_tokens) 😅)

Member Author


Ok, so I added support for already-sliced inputs + a FULL mask, and completely rewrote the test to make it understandable 🤗
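
A minimal sketch of the idea (not the PR's exact code; it assumes the cache exposes get_seq_length()): since the attention_mask always covers the full sequence, full vs. already-sliced inputs can be told apart by comparing lengths.

```python
def slice_new_tokens(input_ids, attention_mask, cache):
    # If input_ids is as long as the full mask, the whole sequence was passed:
    # keep only the tokens that are not yet in the cache.
    past_length = cache.get_seq_length()
    if attention_mask is not None and input_ids.shape[1] == attention_mask.shape[1]:
        return input_ids[:, past_length:]
    # Otherwise the user already passed only the new tokens; nothing to slice.
    return input_ids
```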

@Cyrilvallez
Member Author

Thanks for the review!!

As I said below, there was a point when we got new issues in GH every day saying that "cache doesn't work" or users struggling to re-use system-prompt cache. So let's make sure that we communicate it well, add and update the docs with small snippets

For the next reviews: this was updated (it was only a small edge case), and there are no regressions at all in what the inputs should look like from a user perspective!

Member

@zucchini-nlp zucchini-nlp left a comment


Nice, thanks!

@huggingface huggingface deleted a comment from github-actions Bot Feb 20, 2026
@huggingface huggingface deleted a comment from github-actions Bot Feb 20, 2026
@huggingface huggingface deleted a comment from github-actions Bot Feb 20, 2026
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: audioflamingo3, bart, csm, gemma, glmasr, gpt2, gptj, kosmos2, llama, llava, mamba2, mistral, openai, opt, rag, t5

@Cyrilvallez
Member Author

run-slow: llama, mixtral, llava, whisper, mistral, qwen3_vl, gpt2

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/gpt2", "models/llama", "models/llava", "models/mistral", "models/mixtral", "models/qwen3_vl", "models/whisper"]
quantizations: []

@huggingface huggingface deleted a comment from github-actions Bot Feb 20, 2026
@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context | Commit   | Description
RUN     | b06a1b8b | workflow commit (merge commit)
PR      | ab8ffcfa | branch commit (from PR)
main    | 708d3e12 | base commit (on main)

Model CI Report

4 new failed tests from this PR 😭

  • llava:
    tests/models/llava/test_modeling_llava.py::LlavaForConditionalGenerationIntegrationTest::test_pixtral (❌ ⟹ ❌)

  • whisper:
    tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_tiny_static_generation_long_form (✅ ⟹ ❌)
    tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch_hard_prev_cond (✅ ⟹ ❌)
    tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_no_speech_detection (✅ ⟹ ❌)

@Cyrilvallez
Member Author

The pixtral test was actually failing before and is now passing, not sure why it gets flagged as failing.
The whisper longform tests are known to be failing and will be checked by @eustlb! (thanks a lot haha)

Merging!

@Cyrilvallez Cyrilvallez merged commit ecf79eb into main Feb 20, 2026
25 of 27 checks passed
@Cyrilvallez Cyrilvallez deleted the generate-input-prep branch February 20, 2026 18:46
dacorvo added a commit that referenced this pull request Mar 18, 2026
…chmarks

Rewrite static_sample_investigation.md with:
- Context: goal is to determine neuron-only vs general static path
- Methodology: align on newest _sample algorithm, not neuron_sample fork
- Full comparison table: _static_sample vs neuron_sample (14 items)
- Benchmark results for Items A (output_ids CPU) and B (4D mask)
- Recent PRs affecting _sample (#44226, #44130, #44181, #44126)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>