
[generate] Always pass full input_ids in prepare_inputs_for_generation #44226

Merged
Cyrilvallez merged 14 commits into main from send-full-inputs on Feb 24, 2026

Conversation

@Cyrilvallez (Member) commented Feb 23, 2026

What does this PR do?

As per the title. It looks like some models (xlnet and kosmos2_5) and most audio models sometimes rely on the full previous input_ids to prepare inputs. Note that this is not compatible with restarting generation from a previously filled cache with new inputs, so those models are not well-behaved in general (if they use a cache, they should be able to restart from it with any input). However, forwarding the full inputs and slicing inside prepare_inputs_for_generation is a simple fix for those models.
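As a minimal, hedged sketch of the idea (simplified, and not the exact transformers implementation): generate() now always forwards the full input_ids, and a model's prepare_inputs_for_generation can look at the whole history before slicing down to the tokens it actually needs to run.

```python
# Minimal sketch, assuming a standard `Cache` object with `get_seq_length()`;
# the names and structure are illustrative, not the real transformers code.
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    # `input_ids` always holds the full sequence generated so far, so models
    # that depend on the history can inspect it here.
    if past_key_values is not None and past_key_values.get_seq_length() > 0:
        # Keep only the tokens not yet present in the cache for the forward pass.
        input_ids = input_ids[:, past_key_values.get_seq_length() :]
    return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
```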

This should fix all audio models, cc @eustlb. Please note my comment about how that means audio models cannot restart from an old cache; not sure if that's known/intended/fixable.
As for the only two other non-audio models requiring past input_ids (xlnet and kosmos2_5), it's in general the same issue. Kosmos could be fixed; for the other I'm not sure.

See https://huggingface.co/datasets/hf-internal-testing/transformers_daily_ci/raw/8785954cca2fdca181de0b9567059471bcadb959/2026-02-21/ci_results_run_models_gpu/new_failures_with_bad_commit_grouped_by_authors.json for details on the failing tests.

cc @zucchini-nlp @vasqu

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread: src/transformers/generation/utils.py (Outdated)

```python
use_cache = model_kwargs.get("use_cache", True)
new_inputs_ids = input_ids[:, -1:] if use_cache else input_ids
model_inputs = self.prepare_inputs_for_generation(new_inputs_ids, **model_kwargs)
next_sequence_length = 1 if model_kwargs.get("use_cache", True) else None
```
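As a rough, hedged sketch of how that next_sequence_length value is meant to be consumed (the real slicing lives inside prepare_inputs_for_generation, see the model-side code quoted further down; the tensor below is only an example):

```python
import torch

# Hedged sketch: with a cache, only the newly sampled token still needs to be
# processed (next_sequence_length = 1); without a cache, None means "keep everything".
input_ids = torch.tensor([[5, 9, 2, 7]])  # full sequence generated so far
use_cache = True
next_sequence_length = 1 if use_cache else None
input_ids_for_model = (
    input_ids[:, -next_sequence_length:] if next_sequence_length is not None else input_ids
)
# use_cache=True  -> tensor([[7]]) (last token only)
# use_cache=False -> the full input_ids
```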
Member
nit: is it okay to default to use_cache=True? Previously we had a fallback to a config attribute.

Member Author

It's actually guaranteed to exist in the model_kwargs. I updated the code to show it explicitly.

@Cyrilvallez (Member Author)

run-slow: kosmos2_5

@github-actions (Contributor)

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/kosmos2_5"]
quantizations: []

@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit   | Description                    |
|---------|----------|--------------------------------|
| RUN     | 2e262fed | workflow commit (merge commit) |
| PR      | 3bd0a68b | branch commit (from PR)        |
| main    | 6ed9ee36 | base commit (on main)          |

✅ No failing test specific to this PR 🎉 👏 !

@zucchini-nlp (Member) left a comment

The only question I have is about input embeddings; it's not really clear why we need to slice them before prefill. Otherwise LGTM.

Comment on lines +3762 to +3763

```python
# The cache is already taken into account in `_get_initial_cache_position`, so the length is only the new tokens if we slice
effective_input_length = next_sequence_length if next_sequence_length is not None else input_ids.shape[1]
```
Member

I'm lost about this part: why can't inputs_embeds be sliced as is, and does this work when both ids and embeds are passed?

Member Author

We never use them both simultaneously in prepare_inputs_for_generation. The issue is that _get_initial_cache_position will override the given sequence_length if inputs_embeds are present, so if we want the correct position we need to slice!

Member

ahh I see, it's surprising that _get_initial_cache_position treats embeds and ids differently

Member Author

Yes, way too surprising IMO. It should only take the sequence_length and/or the cache, and that's it, but that's for another time.
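For readers following this thread, here is a hedged illustration of the point above (simplified; not the actual _get_initial_cache_position code): with a filled cache, the positions should continue after the cached tokens, so deriving the length from full, unsliced inputs_embeds would yield the wrong positions.

```python
import torch

# Illustrative numbers only.
past_len = 10  # tokens already in the cache
n_new = 1      # newly sampled token

# Length inferred from the sliced inputs: positions continue after the cache.
positions_correct = torch.arange(past_len, past_len + n_new)             # tensor([10])
# Length inferred from the full, unsliced inputs: positions span the whole
# sequence again, which is wrong once a cache is present.
positions_wrong = torch.arange(past_len, past_len + (past_len + n_new))  # tensor([10, ..., 20])
```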

Comment thread: src/transformers/models/kosmos2_5/modeling_kosmos2_5.py
Comment on lines +496 to 500

```python
next_sequence_length: int | None = None,
past_key_values: Cache | None = None,
attention_mask: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
cache_position: torch.LongTensor | None = None,
```
Member

IMO we now have next_sequence_length, which has similar functionality to cache_position (or even just past_key_values in simple cases). Let's clean up the redundant args and deprecate them properly in subsequent PRs.

Member Author

Yeah, the whole goal of those PRs was to remove cache_position. Even though those two kwargs are not really the same (next_sequence_length is not intended to be used in modeling code at all), I've already started removing cache_position everywhere in #44181 🤗

@vasqu (Contributor) left a comment

LGTM, just some smaller comments. I guess this is a pre-step to removing cache positions in #44181?

Comment on lines +527 to 530

```python
input_ids = input_ids[:, -next_sequence_length:] if next_sequence_length is not None else input_ids
model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
batch_size, sequence_length = input_ids.shape[:2]  # we slice here as some models may have them 3D
```
Contributor

Not super relevant to this PR, but this is probably a limitation for audio models, no? We can have 3D input ids with the codebook channel dim, so something along the lines of [bsz, channels, seq_len].

@Cyrilvallez (Member Author) commented Feb 24, 2026

It's actually not, as omitting the dim is equivalent to slicing it all with :!
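A quick check of that claim in plain PyTorch (the shape below is illustrative, not taken from a specific model):

```python
import torch

# Omitting trailing dims in an index is the same as slicing them fully with `:`.
x = torch.arange(2 * 5 * 4).reshape(2, 5, 4)  # e.g. a 3D ids tensor
n = 3
assert torch.equal(x[:, -n:], x[:, -n:, :])   # identical results
```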

Comment on lines +544 to +547

```python
if model_input is not None and model_input.shape[-1] != sequence_length:
    # Input can be 2D or 3D, and we always slice on `seq-length` (last dim)
    model_input = model_input[..., -sequence_length:].clone(memory_format=torch.contiguous_format)
    model_inputs[model_input_name] = model_input
```
Contributor

Suggested change:

```diff
-if model_input is not None and model_input.shape[-1] != sequence_length:
+if model_input is not None:
     # Input can be 2D or 3D, and we always slice on `seq-length` (last dim)
     model_input = model_input[..., -sequence_length:].clone(memory_format=torch.contiguous_format)
     model_inputs[model_input_name] = model_input
```

Not a fan of data-dependent control flow here; it should work in any case, no?

Member Author

I don't think shape checks count as data-dependent control flow as far as dynamo is concerned. We can remove it anyway, but then we always clone, which is unnecessary (though not really a big deal).

Comment thread: src/transformers/generation/utils.py
Comment thread: src/transformers/models/paligemma/modeling_paligemma.py
@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: csm, higgs_audio_v2, kosmos2_5, paligemma, rag, xglm, xlm

@Cyrilvallez (Member Author)

Well, technically #44130 was the pre-step for removing cache_position; this is a fix after said PR, as audio models unfortunately need to access the full input_ids in prepare_inputs_for_generation, so we have to delay the slicing a little bit!

@Cyrilvallez Cyrilvallez merged commit 3c52b78 into main Feb 24, 2026
26 checks passed
@Cyrilvallez Cyrilvallez deleted the send-full-inputs branch February 24, 2026 10:45
dacorvo added a commit that referenced this pull request Mar 18, 2026
…chmarks

Rewrite static_sample_investigation.md with:
- Context: goal is to determine neuron-only vs general static path
- Methodology: align on newest _sample algorithm, not neuron_sample fork
- Full comparison table: _static_sample vs neuron_sample (14 items)
- Benchmark results for Items A (output_ids CPU) and B (4D mask)
- Recent PRs affecting _sample (#44226, #44130, #44181, #44126)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>