[generate] Completely stop relying on cache_position to prepare inputs #44130
Cyrilvallez merged 29 commits into main
I love seeing hacky code gone and everything simplified! My only concern is BC with users' expectations. As I said below, there was a point when we got new issues on GH every day saying that "cache doesn't work", or users struggling to re-use a system-prompt cache. So let's make sure that we communicate it well, and add and update the docs with small snippets.
Also, if we are to deprecate `cache_position` in the next few releases, I'd prefer a tiny warning in `_prepare_inputs`. We can check if users passed `cache_position` explicitly and warn/raise if they did. We need to explain how to craft their inputs correctly without cache position and link to the docs.
Finally, let's run the slow integration tests in `test_utils.py`. I know that some of those tests are kinda outdated, gotta find a day to fix/clean them 🫠 On my list, and I will start today since it's quite important to keep generation working.
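A minimal sketch of the kind of check being suggested here, assuming a hypothetical `_prepare_inputs` hook on the generation mixin (the signature and exact wording are illustrative, not the actual implementation):

```python
import warnings


def _prepare_inputs(self, input_ids, cache_position=None, **model_kwargs):
    # Hypothetical deprecation check: if the caller still passes `cache_position`
    # explicitly, warn and point to the docs explaining how to prepare inputs
    # (e.g. how to continue from an existing cache) without it.
    if cache_position is not None:
        warnings.warn(
            "`cache_position` is no longer used to prepare inputs and may be removed "
            "in a future release. See the generation docs for how to continue from an "
            "existing cache without passing it.",
            FutureWarning,
        )
    ...
```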
past_length = cache.get_seq_length()
input_ids = input_ids[:, past_length:]
if "inputs_embeds" in model_kwargs:
This will fail for cases when users pass already sliced input ids and a custom cache position. If we slice by past length, we end up with zero inputs. See `test_generate_custom_cache_position` in `test_utils.py`.
I don't mind changing it, but we really need to communicate it well and change slowly. We already had a lot of issues and questions when `cache_position` appeared, about how and why to continue from an existing cache. If something "breaks", we might get another wave of issues.
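A toy illustration of the failure mode described above (the shapes and the `past_length` value are made up; this is not the real preparation code):

```python
import torch

past_length = 8                              # what cache.get_seq_length() would return
input_ids = torch.tensor([[101, 102, 103]])  # user already sliced: only the 3 new tokens

# Slicing by past length again, as in the snippet above, drops everything:
sliced = input_ids[:, past_length:]
print(sliced.shape)  # torch.Size([1, 0]) -> zero new inputs left to prefill
```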
Yeah I agree, in general our way of doing this here is very weird - at first I assumed we were restarting from only the new tokens as well, then thought that we only support restarting from full inputs. It can make sense, since with full inputs we also keep the information about the past potential padding tokens.
And I wanted to change that later to also support just new tokens haha - did not know it could already be done.
The only issue with restarting from only the new tokens is that we may "lose" the information about previously padded tokens.
However, it looks like we restart from only the new tokens but a fully incremented `attention_mask` in this test @zucchini-nlp, so it's not an issue at all and we can slice based on the mask as well. Can you confirm that?
(The test is not easy to understand, they do weird stuff 😅 For example, they seem to cat the new tokens in the wrong order, as it does `cat(new_token, old_token)` instead of `cat(old_tokens, new_tokens)` 😅)
Ok, so added support for already sliced inputs + FULL mask, and completely rewrote the test to make it understandable 🤗
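A rough sketch of the "already sliced inputs + FULL mask" convention agreed on here, using toy tensors rather than the real generation API (token values and lengths are arbitrary):

```python
import torch

past, new = 8, 3                                  # tokens already in the cache / new tokens
new_input_ids = torch.randint(0, 1000, (1, new))  # only the new tokens are passed
full_attention_mask = torch.ones(1, past + new, dtype=torch.long)  # mask covers past + new

# With this convention, the number of genuinely new tokens can be recovered from
# the full mask and the cache length, without any user-provided cache_position:
num_new_tokens = full_attention_mask.shape[1] - past
assert num_new_tokens == new_input_ids.shape[1]
```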
Thanks for the review!!
For next reviews: this was updated (it was only a small edge case), and there are no regressions at all on what the inputs should look like from a user perspective!
[For maintainers] Suggested jobs to run (before merge)
run-slow: audioflamingo3, bart, csm, gemma, glmasr, gpt2, gptj, kosmos2, llama, llava, mamba2, mistral, openai, opt, rag, t5
run-slow: llama, mixtral, llava, whisper, mistral, qwen3_vl, gpt2
This comment contains models: ["models/gpt2", "models/llama", "models/llava", "models/mistral", "models/mixtral", "models/qwen3_vl", "models/whisper"]
CI Results
Model CI Report: ❌ 4 new failed tests from this PR 😭
The pixtral test was actually failing and is now passing, not sure why it gets flagged as failing. Merging!
…chmarks Rewrite static_sample_investigation.md with:
- Context: goal is to determine neuron-only vs general static path
- Methodology: align on newest _sample algorithm, not neuron_sample fork
- Full comparison table: _static_sample vs neuron_sample (14 items)
- Benchmark results for Items A (output_ids CPU) and B (4D mask)
- Recent PRs affecting _sample (#44226, #44130, #44181, #44126)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
What does this PR do?
As per the title.
This PR is the first big step towards removing `cache_position` everywhere, as it is not needed in general and everything can be inferred from the cache itself. The major changes are the following:
- `prepare_inputs_for_generation` now receives the already sliced `input_ids` or `inputs_embeds` according to the cache -> this sequence length is absolute and everything else is sliced based on it. The idea is that in every generation loop, we always know how many tokens we need to keep for the next step with cache (usually a single one), so let's use this information instead of very convoluted logic. Currently, it takes as input the full sequence and then relies on `cache_position` to slice, but this is extremely brittle versus taking the correct new input all the time as immediate input.
- Align `position_ids` and `cache_position` to contain the values for all past tokens. They are then only sliced to the input length in `prepare_inputs_for_generation`. This was done because the `attention_mask` needs to contain the values of all tokens all the time (because we have padding), so it makes more sense to align the behavior of all important common inputs (`position_ids`, `cache_position`, `attention_mask`), so that we don't have to always remember what is already sliced and what is not -- this also fixes some broken logic for that in `_update_model_kwargs_for_generation` (a schematic sketch follows this list).
- Revert most of "Prefill-related logic in input preparation for generation" #42088, as it's not needed to avoid running prefill for assistants - prefill can simply take a kwarg `is_first_iteration` to know if it's the assistant calling it several times. This is needed because, when restarting from cache, we have to pass all the preceding `input_ids` as well, and `_prefill` is now responsible for checking the cache and slicing the inputs to prefill with only the new tokens. We cannot easily slice before calling `_prefill`, as we potentially need to recreate the full `attention_mask` and `position_ids` based on the FULL SEQUENCE (including preceding tokens).
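A schematic sketch of the convention described in the bullets above, with toy shapes rather than the actual implementation: `position_ids` and `cache_position` are kept for the full sequence, like `attention_mask`, and are only cut down to the already-sliced input length inside `prepare_inputs_for_generation`.

```python
import torch

past, new = 8, 1
input_ids = torch.randint(0, 1000, (1, new))                  # already sliced to the new tokens
attention_mask = torch.ones(1, past + new, dtype=torch.long)  # always covers the full sequence

# Kept aligned with the mask, i.e. covering all tokens seen so far:
cache_position = torch.arange(past + new)
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)   # accounts for left padding

# Inside prepare_inputs_for_generation, everything is sliced to the input length:
cache_position = cache_position[-input_ids.shape[1]:]
position_ids = position_ids[:, -input_ids.shape[1]:]
```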
Next steps
Now that input preparation is decorrelated from `cache_position`, the next PR will decorrelate both the cache update and the masking logic, and `cache_position` will not be useful anymore.
Note that currently, `cache_position` is not used anymore for general input preparation, but it is still created and forwarded to the `forward`s for downstream use in the cache and masking.
cc @zucchini-nlp @vasqu @ArthurZucker