[generate] Always pass full input_ids in prepare_inputs_for_generation #44226

Cyrilvallez merged 14 commits into main
Force-pushed 74727d9 to 766a86e
    use_cache = model_kwargs.get("use_cache", True)
    new_inputs_ids = input_ids[:, -1:] if use_cache else input_ids
    model_inputs = self.prepare_inputs_for_generation(new_inputs_ids, **model_kwargs)
    next_sequence_length = 1 if model_kwargs.get("use_cache", True) else None
nit: is it okay to default to use_cache=True? Previously we had a fallback to the config attribute.
It's actually guaranteed to exist in the model_kwargs. I updated the code to show it explicitly.
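For context, a minimal runnable sketch of the contract in the diff above (simplified, hypothetical names; not the actual transformers implementation): the full input_ids are always passed together with how many tokens are new, and slicing happens inside prepare_inputs_for_generation:

```python
import torch

def prepare_inputs_for_generation(input_ids, next_sequence_length=None, **kwargs):
    # Hypothetical stand-in for the real method: the full ids come in, and the
    # slicing down to the new tokens happens here rather than in the caller
    if next_sequence_length is not None:
        input_ids = input_ids[:, -next_sequence_length:]
    return {"input_ids": input_ids}

# Decode step with a cache: only the last token is new, but the full ids are passed
input_ids = torch.tensor([[5, 6, 7, 8]])
use_cache = True
next_sequence_length = 1 if use_cache else None
print(prepare_inputs_for_generation(input_ids, next_sequence_length))
# {'input_ids': tensor([[8]])}
```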
run-slow: kosmos2_5

This comment contains models: ["models/kosmos2_5"]
zucchini-nlp left a comment
The only question I have is about input embeddings. It's not really clear why we need to slice them before prefill; otherwise LGTM.
    # The cache is already taken into account in `_get_initial_cache_position`, so the length is only the new tokens if we slice
    effective_input_length = next_sequence_length if next_sequence_length is not None else input_ids.shape[1]
I'm lost about this part: why can't inputs_embeds be sliced as-is, and does this work when both ids and embeds are passed?
We never use them both simultaneously in prepare_inputs_for_generation. The issue is that _get_initial_cache_position will override the given sequence_length if inputs_embeds are present, so if we want the correct position we need to slice!
ahh I see, it's surprising that _get_initial_cache_position treats embeds and ids differently
Yes, way too surprising IMO. It should only take the sequence_length and/or the cache and that's it, but that's for another time.
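To make the override described in this thread concrete, here is a hypothetical simplification (not the actual transformers code) of how unsliced embeds can produce the wrong positions:

```python
import torch

def get_initial_cache_position(sequence_length, past_length=0, inputs_embeds=None):
    # Hypothetical simplification of the surprising behavior: when embeds are
    # present, their length overrides the sequence_length the caller provided
    if inputs_embeds is not None:
        sequence_length = inputs_embeds.shape[1]
    return torch.arange(past_length, past_length + sequence_length)

# Decode step: 9 tokens already cached, 1 new token
full_embeds = torch.zeros(1, 10, 8)
print(get_initial_cache_position(1, past_length=9, inputs_embeds=full_embeds))
# tensor([ 9, 10, ..., 18]) -- wrong, all 10 embeds are treated as new
print(get_initial_cache_position(1, past_length=9, inputs_embeds=full_embeds[:, -1:]))
# tensor([9]) -- correct once the embeds are sliced first
```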
    next_sequence_length: int | None = None,
    past_key_values: Cache | None = None,
    attention_mask: torch.LongTensor | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    cache_position: torch.LongTensor | None = None,
IMO we now have next_sequence_length, which has functionality similar to cache_position (or even just past_key_values in simple cases). Let's clean up the redundant args and deprecate them properly in subsequent PRs.
Yeah, the whole goal of those PRs was to remove cache_position. Even though the two kwargs are not really the same (next_sequence_length is not intended to be used in modeling code at all), I've already started removing cache_position everywhere in #44181 🤗
Force-pushed 8b4e2ae to ac30867
    input_ids = input_ids[:, -next_sequence_length:] if next_sequence_length is not None else input_ids
    model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
    batch_size, sequence_length = input_ids.shape[:2]  # we slice here as some models may have them 3D
Not super relevant to this PR, but isn't this a limitation for audio models? We can have 3D input ids with a codebook channel dim, i.e. something like [bsz, channels, seq_len].
It's actually not, as omitting a dim in the slice is equivalent to slicing it entirely with :!
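For reference, a quick torch demonstration of that slicing equivalence (shapes are illustrative):

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)  # some 3D input_ids

# Omitting trailing dims in a slice is the same as slicing them fully with `:`
assert torch.equal(x[:, -2:], x[:, -2:, :])
```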
    if model_input is not None and model_input.shape[-1] != sequence_length:
        # Input can be 2D or 3D, and we always slice on `seq-length` (last dim)
        model_input = model_input[..., -sequence_length:].clone(memory_format=torch.contiguous_format)
        model_inputs[model_input_name] = model_input
Suggested change:

    -if model_input is not None and model_input.shape[-1] != sequence_length:
    +if model_input is not None:
         # Input can be 2D or 3D, and we always slice on `seq-length` (last dim)
         model_input = model_input[..., -sequence_length:].clone(memory_format=torch.contiguous_format)
         model_inputs[model_input_name] = model_input
Not a fan of the data-dependent control flow here; it should work in any case, no?
I don't think shape checks are data-dependent flows in terms of dynamo. We can remove the check anyway, but then we always clone, which is unnecessary (though not really a big deal).
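To make the trade-off concrete, a small torch sketch (illustrative shapes) of why the guard avoids an unnecessary copy:

```python
import torch

x = torch.ones(2, 5)
sequence_length = 5

# Guarded: when the input already has the target length, no copy is made
y = x if x.shape[-1] == sequence_length else x[..., -sequence_length:].clone(memory_format=torch.contiguous_format)
assert y.data_ptr() == x.data_ptr()  # same storage

# Unguarded: the slice is a no-op, but the clone still allocates a full copy
z = x[..., -sequence_length:].clone(memory_format=torch.contiguous_format)
assert z.data_ptr() != x.data_ptr()  # new storage
```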
[For maintainers] Suggested jobs to run (before merge): run-slow: csm, higgs_audio_v2, kosmos2_5, paligemma, rag, xglm, xlm
Well, technically #44130 was the pre-step for removing cache_position; this is a fix after said PR, as audio models need to access the full input_ids in prepare_inputs_for_generation.
…chmarks

Rewrite static_sample_investigation.md with:
- Context: goal is to determine neuron-only vs general static path
- Methodology: align on newest _sample algorithm, not neuron_sample fork
- Full comparison table: _static_sample vs neuron_sample (14 items)
- Benchmark results for Items A (output_ids CPU) and B (4D mask)
- Recent PRs affecting _sample (#44226, #44130, #44181, #44126)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
What does this PR do?
As per the title. It looks like some models (xlnet and kosmos2_5) and most audio models sometimes rely on the full previous input_ids to prepare inputs. Note that this cannot be compatible with restarting generation from a previously filled cache and new inputs, so those models are not well-behaved in general (if they use a cache, they should be able to restart from it with any input). However, it's a simple fix to forward the full inputs and slice inside prepare_inputs_for_generation for those models. This should fix all audio models, cc @eustlb. Please note my comment about how that means audio models cannot restart from an old cache; not sure if it's known/intended/fixable.
As for the only 2 other non-audio models requiring past input_ids (xlnet and kosmos2_5), it's the same issue in general. Kosmos could be fixed; the other, not sure.
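As an illustration of the fix described above, a hypothetical sketch (not the actual modeling code of any of these models) of a model-side override that needs the full history before slicing:

```python
import torch

def prepare_inputs_for_generation(input_ids, next_sequence_length=None, **kwargs):
    # Hypothetical sketch of a model that genuinely needs the full history:
    # positions are derived from the *full* ids before slicing to the new tokens
    position_ids = torch.arange(input_ids.shape[-1]).unsqueeze(0)
    if next_sequence_length is not None:
        input_ids = input_ids[..., -next_sequence_length:]
        position_ids = position_ids[..., -next_sequence_length:]
    return {"input_ids": input_ids, "position_ids": position_ids}

print(prepare_inputs_for_generation(torch.tensor([[1, 2, 3, 4]]), next_sequence_length=1))
# {'input_ids': tensor([[4]]), 'position_ids': tensor([[3]])}
```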
See https://huggingface.co/datasets/hf-internal-testing/transformers_daily_ci/raw/8785954cca2fdca181de0b9567059471bcadb959/2026-02-21/ci_results_run_models_gpu/new_failures_with_bad_commit_grouped_by_authors.json for details on the failing tests.
cc @zucchini-nlp @vasqu