Merged

Changes from all commits (40 commits)
331a87a  Squashed commit of the following: (manueldeprada, Sep 1, 2025)
47580e4  ops (manueldeprada, Sep 1, 2025)
3b09223  fix (manueldeprada, Sep 1, 2025)
2a74c4b  ops (manueldeprada, Sep 1, 2025)
ddea5ed  Merge branch 'main' into fix-custom-gen-from-function2 (manueldeprada, Sep 1, 2025)
18277b4  Merge branch 'main' into fix-custom-gen-from-function2 (manueldeprada, Sep 2, 2025)
bfd41f1  review (manueldeprada, Sep 2, 2025)
494c9a8  fix (manueldeprada, Sep 2, 2025)
da72717  fix dia (manueldeprada, Sep 2, 2025)
27582a1  unify assisted generate to common decoding method signature (manueldeprada, Sep 3, 2025)
fcff2f3  Merge branch 'main' of github.com:huggingface/transformers into unify… (manueldeprada, Sep 3, 2025)
fcfc23d  move checks to validate steps where possible (manueldeprada, Sep 3, 2025)
814b2ec  Merge branch 'main' into unify-assisted-generate (manueldeprada, Sep 3, 2025)
35fc116  fix csm and other models that override _sample (manueldeprada, Sep 3, 2025)
a255758  Merge branch 'unify-assisted-generate' of github.com:manueldeprada/tr… (manueldeprada, Sep 3, 2025)
26919bd  ops dia you again (manueldeprada, Sep 3, 2025)
a3f7be3  opsie (manueldeprada, Sep 4, 2025)
cf84b1e  Merge branch 'main' into unify-assisted-generate (manueldeprada, Sep 4, 2025)
a580344  joao review (manueldeprada, Sep 4, 2025)
58ffe60  ops (manueldeprada, Sep 4, 2025)
2940468  ops2 (manueldeprada, Sep 4, 2025)
ceaaf68  dia (manueldeprada, Sep 4, 2025)
273b6d9  Move variable output controls to `prepare_inputs_for_generation` (manueldeprada, Sep 5, 2025)
ac7efcb  fix xlstm (manueldeprada, Sep 5, 2025)
9a25f88  skip on args check (manueldeprada, Sep 8, 2025)
315247a  fix xlm roberta, zamba (manueldeprada, Sep 8, 2025)
453ee1e  fix moshi, rwkv (manueldeprada, Sep 8, 2025)
9a3a826  fix mamba2 (manueldeprada, Sep 8, 2025)
18e1ec1  fix a bunch of models (manueldeprada, Sep 8, 2025)
2e03248  fix (manueldeprada, Sep 8, 2025)
c4f6257  review (manueldeprada, Sep 10, 2025)
190dfb5  ops (manueldeprada, Sep 10, 2025)
aa93797  better comment (manueldeprada, Sep 10, 2025)
f668884  back to basics (manueldeprada, Sep 14, 2025)
69b21a8  Merge commit '4cbca0d1af4a362a803abe05779837e327cda54b' into proper-p… (manueldeprada, Sep 14, 2025)
7983f3c  Merge commit '16b821c5423c9ac7567cc75ec7810bb5ab5b3772' into proper-p… (manueldeprada, Sep 14, 2025)
b7d4f19  Merge branch 'unify-assisted-generate' into proper-preparate-inputs (manueldeprada, Sep 14, 2025)
1a74ba4  Merge branch 'main' of github.com:huggingface/transformers into prope… (manueldeprada, Sep 14, 2025)
a4e750a  final touches (manueldeprada, Sep 15, 2025)
712f39a  ops (manueldeprada, Sep 15, 2025)
25 changes: 11 additions & 14 deletions src/transformers/generation/utils.py
@@ -46,6 +46,7 @@
 from ..tokenization_utils import ExtensionsTrie
 from ..utils import (
     ModelOutput,
+    TransformersKwargs,
     is_accelerate_available,
     is_hqq_available,
     is_optimum_quanto_available,
@@ -559,8 +560,9 @@ def prepare_inputs_for_generation(
         **kwargs,
     ):
         """
-        Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
-        slicing inputs given the existing cache.
+        Prepare the model inputs for generation. Notable steps include selecting the correct input key and cloning it when
+        appropriate, creating position_ids from the attention_mask when missing, slicing inputs and converting 2D attention
+        masks to 4D for compilable caches, and finally forwarding all additional keyword arguments unchanged to the model's forward pass.

         See the forward pass in the model documentation for expected arguments (different models might have different
         requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
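For readers skimming the diff, here is what the refreshed contract looks like from the outside. A minimal sketch, assuming a transformers release that already contains this PR and using gpt2 purely as a small stand-in:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = torch.tensor([[464, 3290]])
attention_mask = torch.ones_like(input_ids)
cache_position = torch.arange(input_ids.shape[1])

# position_ids are derived from the attention_mask, and the extra kwarg
# `output_attentions` should now be forwarded untouched instead of dropped.
model_inputs = model.prepare_inputs_for_generation(
    input_ids,
    attention_mask=attention_mask,
    cache_position=cache_position,
    use_cache=True,
    output_attentions=True,
)
print(sorted(model_inputs))  # expect 'output_attentions' next to the standard keys
```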
@@ -1592,8 +1594,9 @@ def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
             decoder_model_args = set(inspect.signature(decoder.forward).parameters)
             model_args |= {f"decoder_{x}" for x in decoder_model_args}

+        # TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
         for key, value in model_kwargs.items():
-            if value is not None and key not in model_args:
+            if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
                 unused_model_args.append(key)

         if unused_model_args:
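The exemption works because TransformersKwargs is a typed dict whose members are all optional, so `__optional_keys__` enumerates every model-agnostic control. A self-contained sketch with a stand-in TypedDict (not the real class):

```python
from typing import TypedDict

class FakeTransformersKwargs(TypedDict, total=False):  # stand-in, not transformers' own class
    output_attentions: bool
    output_hidden_states: bool

model_args = {"input_ids", "attention_mask", "past_key_values"}
model_kwargs = {"output_attentions": True, "typo_kwarg": 1}

# Mirrors the check above: unknown kwargs are flagged, model-agnostic controls pass.
unused = [
    key for key, value in model_kwargs.items()
    if value is not None
    and key not in model_args
    and key not in FakeTransformersKwargs.__optional_keys__
]
print(unused)  # ['typo_kwarg']
```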
@@ -1798,6 +1801,11 @@ def _prepare_generation_config(

         # Finally, apply any passed kwargs
         model_kwargs = generation_config.update(**kwargs)
+        # And keep the variable output controls in model_kwargs
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        model_kwargs.update({"output_attentions": output_attentions} if output_attentions else {})
+        model_kwargs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

         return generation_config, model_kwargs
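A quick end-to-end check that the controls still reach the model after the move; a sketch assuming gpt2 weights are available:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("Hello", return_tensors="pt")

# output_attentions now rides in model_kwargs from _prepare_generation_config
# down to prepare_inputs_for_generation, instead of being re-attached inside
# every decoding loop.
out = model.generate(
    **inputs, max_new_tokens=5, output_attentions=True, return_dict_in_generate=True
)
print(len(out.attentions))  # one attention tuple per generated token
```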

@@ -2761,10 +2769,6 @@ def _sample(
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            # prepare variable output controls (note: some models won't accept all output controls)
-            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
-            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
-
             if is_prefill:
                 outputs = self(**model_inputs, return_dict=True)
                 is_prefill = False
@@ -3247,10 +3251,6 @@ def _beam_search(
             flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
             model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)

-            # prepare variable output controls (note: some models won't accept all output controls)
-            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
-            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
-
             model_outputs = self(**model_inputs, return_dict=True)

             # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
@@ -3575,9 +3575,6 @@ def _assisted_decoding(
             model_inputs["logits_to_keep"] = candidate_length + 1

             # 2.2. Run a forward pass on the candidate sequence
-            # prepare variable output controls (note: some models won't accept all output controls)
-            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
-            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

             outputs = self(**model_inputs)
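Net effect in utils.py: the per-step re-attachment of output controls disappears from all three decoding loops, since _prepare_generation_config now folds the controls into model_kwargs once and prepare_inputs_for_generation forwards them along with everything else.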
6 changes: 6 additions & 0 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -1492,6 +1492,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

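The same six-line forwarding block recurs in every per-model override below. A standalone sketch of why it is needed (toy function, no transformers imports):

```python
# Toy version of the recurring block: overrides that rebuilt model_inputs from an
# explicit dict used to silently drop extra kwargs; this loop forwards whatever
# the override did not already set.
def prepare_inputs_for_generation(input_ids, **kwargs):
    model_inputs = {"input_ids": input_ids}

    # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
    for key, value in kwargs.items():
        if key not in model_inputs:
            model_inputs[key] = value

    return model_inputs

inputs = prepare_inputs_for_generation([1, 2, 3], use_cache=True, output_attentions=True)
print(sorted(inputs))  # ['input_ids', 'output_attentions', 'use_cache']
```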
6 changes: 6 additions & 0 deletions src/transformers/models/bamba/modular_bamba.py
@@ -1211,6 +1211,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

6 changes: 6 additions & 0 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -818,6 +818,12 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @auto_docstring
12 changes: 11 additions & 1 deletion src/transformers/models/ctrl/modeling_ctrl.py
@@ -570,7 +570,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):

         input_ids = input_ids[:, remove_prefix_length:]

-        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
+        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
+
+        # token_type_ids are computed on CTRLModel.forward()
+        kwargs.pop("token_type_ids", None)
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs


 @auto_docstring(
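Note the companion idiom introduced here and reused in Reformer, ProphetNet, and Moshi: kwargs the model computes itself or cannot accept (token_type_ids above; attention_mask, cache_position, and last_hidden_state below) are popped before the catch-all loop, so the forwarding step never passes them to forward().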
6 changes: 6 additions & 0 deletions src/transformers/models/falcon_h1/modeling_falcon_h1.py
@@ -1607,6 +1607,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

6 changes: 6 additions & 0 deletions src/transformers/models/falcon_h1/modular_falcon_h1.py
@@ -1372,6 +1372,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

6 changes: 6 additions & 0 deletions src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -862,6 +862,12 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @auto_docstring
9 changes: 8 additions & 1 deletion src/transformers/models/git/modeling_git.py
@@ -1442,13 +1442,20 @@ def prepare_inputs_for_generation(
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        return {
+        model_inputs = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "pixel_values": kwargs.get("pixel_values"),
             "past_key_values": past_key_values,
             "use_cache": use_cache,
         }

+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs


 __all__ = ["GitForCausalLM", "GitModel", "GitPreTrainedModel", "GitVisionModel"]
@@ -1829,6 +1829,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

@@ -383,6 +383,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

6 changes: 6 additions & 0 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -1448,6 +1448,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

9 changes: 8 additions & 1 deletion src/transformers/models/kosmos2_5/modeling_kosmos2_5.py
@@ -1648,7 +1648,7 @@ def prepare_inputs_for_generation(
             dim=1,
         )

-        return {
+        model_inputs = {
             "input_ids": input_ids,
             "image_embeds": image_embeds,
             "image_embeds_position_mask": image_embeds_position_mask,
@@ -1658,6 +1658,13 @@ def prepare_inputs_for_generation(
             "use_cache": use_cache,
         }

+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in model_kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs
+

 @add_start_docstrings(
     """
6 changes: 6 additions & 0 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -803,6 +803,12 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @auto_docstring
6 changes: 6 additions & 0 deletions src/transformers/models/mamba2/modeling_mamba2.py
@@ -989,6 +989,12 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @auto_docstring
7 changes: 6 additions & 1 deletion src/transformers/models/moshi/modeling_moshi.py
@@ -2255,7 +2255,7 @@ def prepare_inputs_for_generation(

         # we want to do it after a first token has been generated
         if model_inputs["input_ids"] is not None:
-            last_hidden_state = kwargs.get("last_hidden_state")
+            last_hidden_state = kwargs.pop("last_hidden_state")
             # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim)
             last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1])

@@ -2287,6 +2287,11 @@ def prepare_inputs_for_generation(
             model_inputs["input_ids"] = None
             model_inputs["inputs_embeds"] = inputs_embeds

+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     def _update_model_kwargs_for_generation(
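The switch from kwargs.get to kwargs.pop is load-bearing here: last_hidden_state is consumed into inputs_embeds, and popping it keeps the new catch-all loop from forwarding it to the model a second time.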
9 changes: 8 additions & 1 deletion src/transformers/models/openai/modeling_openai.py
@@ -602,7 +602,14 @@ def forward(

     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict[str, Any]:
         # Overwritten -- old model with reduced inputs
-        return {"input_ids": input_ids}
+        model_inputs = {"input_ids": input_ids}
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs


 @auto_docstring(
12 changes: 11 additions & 1 deletion src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -2010,14 +2010,24 @@ def prepare_inputs_for_generation(
         if past_key_values is not None and past_key_values.get_seq_length() > 0:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
-        return {
+        model_inputs = {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
             "head_mask": head_mask,
             "past_key_values": past_key_values,
             "use_cache": use_cache,
         }

+        # Prophetnet does not support cache_position
+        kwargs.pop("cache_position", None)
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs


 class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
     """
12 changes: 10 additions & 2 deletions src/transformers/models/reformer/modeling_reformer.py
@@ -2345,14 +2345,22 @@ def prepare_inputs_for_generation(
         if past_key_values is not None:
             input_ids = input_ids[:, -1:]

-        inputs_dict = {
+        model_inputs = {
             "input_ids": input_ids,
             "past_buckets_states": past_key_values,
             "use_cache": use_cache,
             "num_hashes": num_hashes,
         }

-        return inputs_dict
+        # Attention mask is computed on ReformerModel.forward()
+        kwargs.pop("attention_mask", None)
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs

     def _reorder_cache(self, past_key_values, beam_idx):
         reord_past_buckets_states = []
6 changes: 6 additions & 0 deletions src/transformers/models/rwkv/modeling_rwkv.py
@@ -719,6 +719,12 @@ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs):

         model_inputs["state"] = state
         model_inputs["use_cache"] = use_cache
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @auto_docstring