diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 0163e74e77cb..66e684635e64 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -46,6 +46,7 @@
 from ..tokenization_utils import ExtensionsTrie
 from ..utils import (
     ModelOutput,
+    TransformersKwargs,
     is_accelerate_available,
     is_hqq_available,
     is_optimum_quanto_available,
@@ -559,8 +560,9 @@ def prepare_inputs_for_generation(
         **kwargs,
     ):
         """
-        Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
-        slicing inputs given the existing cache.
+        Prepare the model inputs for generation. Notable steps include selecting the correct input key and cloning it when appropriate,
+        creating position_ids from the attention_mask when missing, slicing inputs and converting 2D attention masks to 4D for
+        compilable caches, and finally forwarding all additional keyword arguments unchanged to the model's forward pass.

         See the forward pass in the model documentation for expected arguments (different models might have different
         requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
@@ -1592,8 +1594,9 @@ def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
                 decoder_model_args = set(inspect.signature(decoder.forward).parameters)
                 model_args |= {f"decoder_{x}" for x in decoder_model_args}

+        # TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
         for key, value in model_kwargs.items():
-            if value is not None and key not in model_args:
+            if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
                 unused_model_args.append(key)

         if unused_model_args:
@@ -1798,6 +1801,11 @@ def _prepare_generation_config(
         # Finally, apply any passed kwargs
         model_kwargs = generation_config.update(**kwargs)
+        # Also keep the variable output controls in model_kwargs
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        model_kwargs.update({"output_attentions": output_attentions} if output_attentions else {})
+        model_kwargs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

         return generation_config, model_kwargs
@@ -2761,10 +2769,6 @@ def _sample(
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            # prepare variable output controls (note: some models won't accept all output controls)
-            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
-            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
-
             if is_prefill:
                 outputs = self(**model_inputs, return_dict=True)
                 is_prefill = False
@@ -3247,10 +3251,6 @@ def _beam_search(
             flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
             model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)

-            # prepare variable output controls (note: some models won't accept all output controls)
-            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
-            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
-
             model_outputs = self(**model_inputs, return_dict=True)

             # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
@@ -3575,9 +3575,6 @@ def _assisted_decoding(
model_inputs["logits_to_keep"] = candidate_length + 1 # 2.2. Run a forward pass on the candidate sequence - # prepare variable output controls (note: some models won't accept all output controls) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) outputs = self(**model_inputs) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index f5e337e52ebd..09f00845524d 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1492,6 +1492,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index aec09861de81..52814930a172 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -1211,6 +1211,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 699a177fc6c1..605ae4f59b63 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -818,6 +818,12 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs @auto_docstring diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index ece001f9ce1f..01f227c79185 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -570,7 +570,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac input_ids = input_ids[:, remove_prefix_length:] - return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} + model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} + + # token_type_ids are computed on CTRLModel.forward() + kwargs.pop("token_type_ids", None) + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + print(f"Warning: {key} is not a recognized input.") + model_inputs[key] = value + + return model_inputs @auto_docstring( diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 865daf384b49..5f08309b2085 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1607,6 +1607,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). 
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs
diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py
index 8b00de3ab97f..c81e8967bcf2 100644
--- a/src/transformers/models/falcon_h1/modular_falcon_h1.py
+++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py
@@ -1372,6 +1372,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs
diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
index dc593c979dc7..3cdf6da7bda3 100644
--- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -862,6 +862,12 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @auto_docstring
diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py
index cdd0f622bd86..0125132718a3 100644
--- a/src/transformers/models/git/modeling_git.py
+++ b/src/transformers/models/git/modeling_git.py
@@ -1442,7 +1442,7 @@ def prepare_inputs_for_generation(
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        return {
+        model_inputs = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "pixel_values": kwargs.get("pixel_values"),
@@ -1450,5 +1450,12 @@ def prepare_inputs_for_generation(
             "use_cache": use_cache,
         }

+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs
+

 __all__ = ["GitForCausalLM", "GitModel", "GitPreTrainedModel", "GitVisionModel"]
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index e3a1e69fc861..edfbe6f7ee35 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -1829,6 +1829,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs
diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
index 4de1ff253914..6715bd939d9f 100644
--- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -383,6 +383,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index f604ffd3b72e..17246d6f1b2e 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -1448,6 +1448,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs
diff --git a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py
index c51d9109b48b..f4859f245428 100644
--- a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py
+++ b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py
@@ -1648,7 +1648,7 @@ def prepare_inputs_for_generation(
                 dim=1,
             )

-        return {
+        model_inputs = {
             "input_ids": input_ids,
             "image_embeds": image_embeds,
             "image_embeds_position_mask": image_embeds_position_mask,
@@ -1658,6 +1658,13 @@ def prepare_inputs_for_generation(
             "use_cache": use_cache,
         }

+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in model_kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs
+

 @add_start_docstrings(
     """
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 10616323e13f..4a53c47c8b4a 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -803,6 +803,12 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @auto_docstring
diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index 85cf026e49d0..738c5376c33e 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -989,6 +989,12 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @auto_docstring
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index 27c08626115d..660630726645 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -2255,7 +2255,7 @@ def prepare_inputs_for_generation(
         # we want to do it after a first token has been generated
         if model_inputs["input_ids"] is not None:
-            last_hidden_state = kwargs.get("last_hidden_state")
+            last_hidden_state = kwargs.pop("last_hidden_state")
             # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim)
             last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1])
@@ -2287,6 +2287,11 @@ def prepare_inputs_for_generation(
             model_inputs["input_ids"] = None
             model_inputs["inputs_embeds"] = inputs_embeds

+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     def _update_model_kwargs_for_generation(
diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py
index 27c84910cb43..44fa05227ff8 100644
--- a/src/transformers/models/openai/modeling_openai.py
+++ b/src/transformers/models/openai/modeling_openai.py
@@ -602,7 +602,14 @@ def forward(
     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict[str, Any]:
         # Overwritten -- old model with reduced inputs
-        return {"input_ids": input_ids}
+        model_inputs = {"input_ids": input_ids}
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs

 @auto_docstring(
diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index a79c803e77b9..5beefe79b314 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -2010,7 +2010,7 @@ def prepare_inputs_for_generation(
         if past_key_values is not None and past_key_values.get_seq_length() > 0:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
-        return {
+        model_inputs = {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -2018,6 +2018,16 @@ def prepare_inputs_for_generation(
             "use_cache": use_cache,
         }

+        # ProphetNet does not support cache_position
+        kwargs.pop("cache_position", None)
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs
+

 class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
     """
diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py
index 367af6692357..990f21359bc0 100755
--- a/src/transformers/models/reformer/modeling_reformer.py
+++ b/src/transformers/models/reformer/modeling_reformer.py
@@ -2345,14 +2345,22 @@ def prepare_inputs_for_generation(
         if past_key_values is not None:
             input_ids = input_ids[:, -1:]

-        inputs_dict = {
+        model_inputs = {
             "input_ids": input_ids,
             "past_buckets_states": past_key_values,
             "use_cache": use_cache,
             "num_hashes": num_hashes,
         }

-        return inputs_dict
+        # Attention mask is computed in ReformerModel.forward()
+        kwargs.pop("attention_mask", None)
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                logger.warning(f"{key} is not a recognized input.")
+                model_inputs[key] = value
+
+        return model_inputs

     def _reorder_cache(self, past_key_values, beam_idx):
         reord_past_buckets_states = []
diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py
index 0b16af278946..d86d4d0f8707 100644
--- a/src/transformers/models/rwkv/modeling_rwkv.py
+++ b/src/transformers/models/rwkv/modeling_rwkv.py
@@ -719,6 +719,12 @@ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=Non
         model_inputs["state"] = state
         model_inputs["use_cache"] = use_cache
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @auto_docstring
diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py
index 4e7316fb781b..a73b4a51cea4 100755
--- a/src/transformers/models/xlm/modeling_xlm.py
+++ b/src/transformers/models/xlm/modeling_xlm.py
@@ -994,7 +994,18 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs):
             langs = torch.full_like(input_ids, lang_id)
         else:
             langs = None
-        return {"input_ids": input_ids, "langs": langs}
+        model_inputs = {"input_ids": input_ids, "langs": langs}
+
+        # token_type_ids and attention_mask are computed on the fly in XLMModel.forward()
+        kwargs.pop("token_type_ids", None)
+        kwargs.pop("attention_mask", None)
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs

     @auto_docstring
     def forward(
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index c625ce7b53ea..8d68f448f663 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -1013,13 +1013,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if position_ids is not None:
             position_ids = position_ids[:, remove_prefix_length:]

-        return {
+        model_inputs = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "position_ids": position_ids,
             "past_key_values": past_key_values,
         }

+        # token_type_ids are computed on the fly in XLMRobertaXLModel.forward()
+        model_kwargs.pop("token_type_ids", None)
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in model_kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs
+

 @auto_docstring
 class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel):
diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py
index 736521ee9561..0c6b9f76eade 100755
--- a/src/transformers/models/xlnet/modeling_xlnet.py
+++ b/src/transformers/models/xlnet/modeling_xlnet.py
@@ -1472,7 +1472,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_mem
         )
         target_mapping[:, 0, -1] = 1.0

-        inputs = {
+        model_inputs = {
             "input_ids": input_ids,
             "perm_mask": perm_mask,
             "target_mapping": target_mapping,
@@ -1481,9 +1481,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_mem
         # if past is defined in model kwargs then use it for faster decoding
         if past_key_values:
-            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values)
-
-        return inputs
+            model_inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values)
+
+        # Attention mask is computed on the fly in XLNetModel.forward()
+        kwargs.pop("attention_mask", None)
+        # TODO: Ignoring use_cache should not happen, fixme.
+        kwargs.pop("use_cache", None)
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs

     @auto_docstring
     def forward(
diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py
index b77ec26d2b31..7e2fce997683 100644
--- a/src/transformers/models/xlstm/modeling_xlstm.py
+++ b/src/transformers/models/xlstm/modeling_xlstm.py
@@ -1556,6 +1556,12 @@ def prepare_inputs_for_generation(
             model_inputs = {"input_ids": input_ids}

         model_inputs.update({"cache_params": cache_params, "use_cache": use_cache})
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs

     @can_return_tuple
diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 9c0f86ea4489..2f9edb1e113c 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -1185,6 +1185,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index ddd4f6f69079..33e7e4b5a351 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -1606,6 +1606,12 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
             }
         )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
         return model_inputs
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index dbeade214410..3b828cd8313a 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1803,6 +1803,41 @@ def test_inherits_generation_mixin(self):
         for model_class in self.all_generative_model_classes:
             self.assertTrue("GenerationMixin" in str(model_class.__bases__))

+    @pytest.mark.generate
+    def test_prepare_inputs_for_generation_kwargs_forwards(self, **extra_kwargs):
+        """Tests that prepare_inputs_for_generation forwards arbitrary kwargs."""
+        for model_class in self.all_generative_model_classes:
+            config, _ = self.prepare_config_and_inputs_for_generate()
+
+            model = model_class(config).to(torch_device).eval()
+
+            input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
+
+            input_args = {
+                "input_ids": input_ids,
+                "cache_position": torch.tensor([9]).to(torch_device),
+                "position_ids": torch.tensor([[0, 1, 2], [0, 1, 2]]).to(torch_device),
+            }
+            arbitrary_kwargs = {
+                "output_attentions": True,
+                "output_hidden_states": True,
+                "custom_arg": "test_value",
+                "numeric_arg": 42,
+            }
+
+            model_inputs = model.prepare_inputs_for_generation(**input_args, **arbitrary_kwargs, **extra_kwargs)
+
+            # Verify that input_ids has the proper name
+            if config.is_encoder_decoder:
+                self.assertTrue("decoder_input_ids" in model_inputs)
+            else:
+                self.assertTrue("input_ids" in model_inputs)
+
+            # Verify that arbitrary kwargs are forwarded
+            for key, value in arbitrary_kwargs.items():
+                self.assertTrue(key in model_inputs)
+                self.assertTrue(model_inputs[key] == value)
+
     def _test_attention_implementation(self, attn_implementation):
         """
         Compares the output of generate with the eager attention implementation against other implementations.
diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py
index 989608d686ea..5ac321c5a753 100644
--- a/tests/models/dia/test_modeling_dia.py
+++ b/tests/models/dia/test_modeling_dia.py
@@ -517,6 +517,10 @@ def test_generate_continue_from_past_key_values(self):
                 )
             )

+    @pytest.mark.generate
+    def test_prepare_inputs_for_generation_kwargs_forwards(self):
+        super().test_prepare_inputs_for_generation_kwargs_forwards(encoder_outputs=torch.randn(2, 2, 32))
+
     @unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
     def test_hidden_states_output(self):
         pass
diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py
index 6df9393f8041..21f56e1bc56d 100644
--- a/tests/models/moshi/test_modeling_moshi.py
+++ b/tests/models/moshi/test_modeling_moshi.py
@@ -868,6 +868,14 @@ def test_generate_continue_from_inputs_embeds(self):
     def test_save_load(self):
         super().test_save_load()

+    @pytest.mark.generate
+    @unittest.skip(reason="Moshi requires setting `model.generated_audio_codes` in generate() before preparing inputs")
+    def test_prepare_inputs_for_generation_kwargs_forwards(self):
+        # If in the future `model.generated_audio_codes` is not required, this test can be re-enabled
+        super().test_prepare_inputs_for_generation_kwargs_forwards(
+            last_hidden_state=torch.randn(2, 3, 32), kwargs_depth_decoder={}
+        )
+

 def place_dict_on_device(dict_to_place, device):
     for key in dict_to_place: