From 331a87a711578d765091627cb47031e25c0b7ecc Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 1 Sep 2025 17:20:57 +0200 Subject: [PATCH 01/30] Squashed commit of the following: commit beb2b5f7a04ea9e12876696db66f3589fbae10c5 Author: Manuel de Prada Corral Date: Mon Sep 1 16:03:25 2025 +0200 also standardize _get_stopping_criteria commit 15c25663fa991e0a215a7f3cdcf13a9d3a989faa Author: Manuel de Prada Corral Date: Mon Sep 1 15:48:38 2025 +0200 watch super.generate() usages commit 67dd845be2202d191a54b2872f1cb3f71b74b7d6 Author: Manuel de Prada Corral Date: Mon Sep 1 14:44:32 2025 +0200 ops commit 4655dfa28fd59d5dc083a41d8396de042d99858c Author: Manuel de Prada Corral Date: Mon Sep 1 14:41:36 2025 +0200 wrong merge commit 46478143994e7b27d51c972a7881e0fea3cb6e3c Merge: a72c2c4b2f 8564e210ca Author: Manuel de Prada Corral Date: Mon Sep 1 14:36:15 2025 +0200 Merge branch 'main' of github.com:huggingface/transformers into fix-custom-gen-from-function2 commit a72c2c4b2f9c0e09fe6ec7992d4d02bfa279da2a Author: Manuel de Prada Corral Date: Mon Sep 1 14:04:59 2025 +0200 ops5 commit e72f91411b961979bb3d271810f57905cee5b577 Author: Manuel de Prada Corral Date: Mon Sep 1 12:06:19 2025 +0200 ops4 commit 12ca97b1078a42167143e0243036f6ef87d5fdac Author: Manuel de Prada Corral Date: Mon Sep 1 11:58:59 2025 +0200 ops3 commit 8cac6c60a318dd381793d4bf1ef3775823f3c95b Author: Manuel de Prada Corral Date: Mon Sep 1 11:43:03 2025 +0200 ops2 commit 4681a7d5dc6c8b96a515d9d79f06380c096b9a9f Author: Manuel de Prada Corral Date: Mon Sep 1 11:40:51 2025 +0200 ops commit 0d72aa6cbd99a5933c5a95a39bea9088ee21e50f Merge: e0d47e980e 5bb6186b8e Author: Manuel de Prada Corral Date: Mon Sep 1 11:37:28 2025 +0200 Merge branch 'remove-constrained-bs' into fix-custom-gen-from-function2 commit 5bb6186b8efbd5fdb8e3464a22f958343b9c450c Merge: 44973dac7d b0db5a02f3 Author: Manuel de Prada Corral Date: Mon Sep 1 11:36:30 2025 +0200 Merge branch 'main' into remove-constrained-bs commit 44973dac7df4b4e2111c71f5fac918be21f3de52 Merge: 1ddab4bee1 893d89e5e6 Author: Manuel de Prada Corral Date: Mon Sep 1 11:29:48 2025 +0200 Merge commit '893d89e5e6fac7279fe4292bfa3b027172287162' into remove-constrained-bs commit e0d47e980e26d32b028c2b402ccb71262637a7a7 Merge: 88128e4563 1ddab4bee1 Author: Manuel de Prada Corral Date: Mon Sep 1 10:52:50 2025 +0200 Merge branch 'remove-constrained-bs' into fix-custom-gen-from-function2 commit 88128e4563c0be583728e1d3c639bc93143c4029 Author: Manuel de Prada Corral Date: Mon Sep 1 10:44:38 2025 +0200 fix custom generate args, refactor gen mode args commit 1ddab4bee159f6c20722e7ff5cd41d5041fab0aa Author: Manuel de Prada Corral Date: Sun Aug 31 21:03:53 2025 +0200 fix commit 6095fdda677ef7fbeb06c05f4f914a11b45257b4 Merge: 4a8b6d2ce1 04addbc9ec Author: Manuel de Prada Corral Date: Thu Aug 28 17:49:16 2025 +0200 Merge branch 'remove-constrained-bs' of github.com:manueldeprada/transformers into remove-constrained-bs commit 4a8b6d2ce18b3a8b52c5261fea427e2416f65187 Author: Manuel de Prada Corral Date: Thu Aug 28 17:48:25 2025 +0200 restore and deprecate beam obkects commit 04addbc9ec62dd4f59d15128e8cd9499e2cda3bb Merge: e800c7841e becab2c601 Author: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Thu Aug 28 14:38:29 2025 +0200 Merge branch 'main' into remove-constrained-bs commit e800c7841e5c46ce5698fc9be309d0808f85d23c Author: Manuel de Prada Corral Date: Thu Aug 28 14:38:10 2025 +0200 tests gone after green commit 33971d21ac40aef76a7e1122f4a98ef28beadbe8 Author: Manuel 
de Prada Corral Date: Thu Aug 28 14:07:11 2025 +0200 tests green, changed handling of deprecated methods commit ab303835c184d0a87789da7aed7d8de5ba85d867 Author: Manuel de Prada Corral Date: Thu Aug 28 12:58:01 2025 +0200 tests fix commit ec74274ca52a6aa0b5f300374fda838609680506 Author: Manuel de Prada Corral Date: Thu Aug 28 12:32:05 2025 +0200 ops commit 0fb19004ccd285dcad485fce0865b355ce5493e0 Author: Manuel de Prada Corral Date: Thu Aug 28 11:45:16 2025 +0200 whoops commit c946bea5e45aea021c8878c57fcabc2a13f06fe5 Author: Manuel de Prada Corral Date: Thu Aug 28 11:35:36 2025 +0200 testing... commit 924c0dec6d9ea6b4890644fe7f711dc778f820bb Author: Manuel de Prada Corral Date: Thu Aug 28 11:22:46 2025 +0200 sweeep ready for tests commit b05aa771d3994b07cd460cda74b274c9e4f315e6 Author: Manuel de Prada Corral Date: Thu Aug 28 11:13:01 2025 +0200 restore and deprecate constraints commit 9c7962d10efa7178b69d3c99e69663756e1cd979 Merge: fceeb383f9 c17bf304d5 Author: Manuel de Prada Corral Date: Wed Aug 27 20:44:21 2025 +0200 Merge branch 'remove-group-bs' into remove-constrained-bs commit c17bf304d5cf33af7f34f9f6057915d5f5821dae Author: Manuel de Prada Corral Date: Wed Aug 27 17:00:50 2025 +0200 fix test commit d579aeec6706b77fcc24c1f6806cd7277d7db56e Merge: 822efd8c3c ed5dd2999c Author: Manuel de Prada Corral Date: Wed Aug 27 16:04:31 2025 +0200 Merge branch 'main' of github.com:huggingface/transformers into remove-group-bs commit 822efd8c3cf475d079e64293aa06e4ab59740fd7 Author: Manuel de Prada Corral Date: Wed Aug 27 15:59:51 2025 +0200 aaand remove tests after all green!! commit 62cb274a4acb9f24201902242f1b0dc4e46daac1 Author: Manuel de Prada Corral Date: Wed Aug 27 11:48:19 2025 +0200 fix commit c89c892e7b24a7d71831f2b35264456005030925 Author: Manuel de Prada Corral Date: Wed Aug 27 11:45:20 2025 +0200 testing that hub works the same commit fceeb383f99e4a836679d67b1d2a8520152eaf49 Author: Manuel de Prada Corral Date: Tue Aug 26 20:06:59 2025 +0200 draft commit 6a9b384078f3798587ba865ac7ddfefc9a79e41c Merge: 8af3af13ab 58cebc848b Author: Manuel de Prada Corral Date: Tue Aug 26 15:00:05 2025 +0200 Merge branch 'main' of github.com:huggingface/transformers into remove-group-bs commit 8af3af13abb85ca60e795d0390832f398a56c34f Author: Manuel de Prada Corral Date: Tue Aug 26 11:55:45 2025 +0200 Squashed commit remove-constrastive-search --- src/transformers/generation/utils.py | 107 +++++++++++------- src/transformers/models/dia/generation_dia.py | 15 ++- .../modeling_kyutai_speech_to_text.py | 2 +- .../modular_kyutai_speech_to_text.py | 2 +- .../models/musicgen/modeling_musicgen.py | 2 +- .../modeling_musicgen_melody.py | 2 +- src/transformers/models/rag/modeling_rag.py | 2 +- 7 files changed, 83 insertions(+), 49 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f118cf5276b0..1cd949e46239 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -999,11 +999,11 @@ def _get_candidate_generator( generation_config: GenerationConfig, input_ids: torch.LongTensor, inputs_tensor: torch.Tensor, - assistant_model: "PreTrainedModel", logits_processor: LogitsProcessorList, - target_tokenizer: "PreTrainedTokenizerBase", - assistant_tokenizer: "PreTrainedTokenizerBase", model_kwargs: dict, + assistant_model: Optional["PreTrainedModel"] = None, + target_tokenizer: Optional["PreTrainedTokenizerBase"] = None, + assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None, ) -> CandidateGenerator: """ Returns the 
candidate generator to be used in `assisted_generation` @@ -1300,7 +1300,6 @@ def _get_stopping_criteria( generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList], tokenizer: Optional["PreTrainedTokenizerBase"] = None, - **kwargs, ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if generation_config.max_length is not None: @@ -1869,7 +1868,7 @@ def _prepare_cache_for_generation( self, generation_config: GenerationConfig, model_kwargs: dict, - assistant_model: "PreTrainedModel", + generation_mode: GenerationMode, batch_size: int, max_cache_length: int, ) -> bool: @@ -1923,7 +1922,10 @@ def _prepare_cache_for_generation( # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, # which is only supported in dynamic caches atm - if assistant_model is not None and generation_config.cache_implementation is not None: + if ( + generation_mode == GenerationMode.ASSISTED_GENERATION + and generation_config.cache_implementation is not None + ): logger.warning_once( "An assistant model is provided, using a dynamic cache instead of a cache of type=" f"'{generation_config.cache_implementation}'." @@ -1933,7 +1935,6 @@ def _prepare_cache_for_generation( # Assisted decoding and contrastive search require cache rollback, which is incompatible with sliding layers. # To handle this, we skip passing the model config to DynamicCache (forcing a full-layer cache). # The "dynamic_full" option is a shortcut for generate() users to avoid sliding layers on their own. - generation_mode = generation_config.get_generation_mode(assistant_model) if ( generation_mode in (GenerationMode.ASSISTED_GENERATION, GenerationMode.CONTRASTIVE_SEARCH) or generation_config.cache_implementation == "dynamic_full" @@ -2125,15 +2126,13 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: Ge def _get_deprecated_gen_repo( self, - generation_config: GenerationConfig, + generation_mode: GenerationMode, trust_remote_code: bool, custom_generate: Optional[str] = None, - assistant_model: Optional["PreTrainedModel"] = None, ) -> Optional[str]: """ - Returns the Hub repo for a deprecated generation strategy, if any. + Returns the Hub repo for a deprecated generation mode, if any. """ - generation_mode = generation_config.get_generation_mode(assistant_model) moved_to_hub_modes = { GenerationMode.DOLA_GENERATION: "transformers-community/dola", GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search", @@ -2156,6 +2155,36 @@ def _get_deprecated_gen_repo( ) return repo + def _get_mode_processor_kwargs( + self, + custom_generate, + kwargs, + assistant_model, + negative_prompt_ids, + negative_prompt_attention_mask, + prefix_allowed_tokens_fn, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Extracts and returns the generation mode and logit processor related keyword arguments from the provided kwargs. 
+ """ + gen_mode_kwargs = { + "tokenizer": kwargs.pop("tokenizer", None), + "assistant_model": assistant_model, + "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None), + } + logits_processor_kwargs = { + "negative_prompt_ids": negative_prompt_ids, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn, + } + # If custom_generate is a callable, we need to extract its parameters + if isinstance(custom_generate, Callable): + usual_mode_kwargs = inspect.signature(GenerationMixin._sample).parameters.keys() + custom_generate_kwargs = inspect.signature(custom_generate).parameters.keys() + new_custom_keys = custom_generate_kwargs - usual_mode_kwargs + gen_mode_kwargs = {k: kwargs.pop(k) for k in new_custom_keys if k in kwargs} + return gen_mode_kwargs, logits_processor_kwargs + @torch.no_grad() def generate( self, @@ -2292,23 +2321,29 @@ def generate( ) return custom_generate_function(model=self, **generate_arguments) - # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call - tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria - assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation + # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode + gen_mode_kwargs, logits_processor_kwargs = self._get_mode_processor_kwargs( + custom_generate, + kwargs, + assistant_model, + negative_prompt_ids, + negative_prompt_attention_mask, + prefix_allowed_tokens_fn, + ) generation_config, model_kwargs = self._prepare_generation_config( generation_config, use_model_defaults, **kwargs ) self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) + self._validate_assistant(assistant_model=assistant_model) + + generation_mode = generation_config.get_generation_mode(assistant_model) # Deprecation-related step: set Hub repo for deprecated strategies. # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode. # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps. # TODO joao, manuel: remove this in v4.62.0 - if deprecate_mode_repo := self._get_deprecated_gen_repo( - generation_config, trust_remote_code, custom_generate, assistant_model - ): + if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate): return GenerationMixin.generate( self, inputs, @@ -2324,8 +2359,7 @@ def generate( use_model_defaults, custom_generate=deprecate_mode_repo, trust_remote_code=trust_remote_code, - tokenizer=tokenizer, - assistant_tokenizer=assistant_tokenizer, + **gen_mode_kwargs, **kwargs, ) @@ -2406,7 +2440,7 @@ def generate( ) if generation_config.token_healing: - input_ids = self.heal_tokens(input_ids, tokenizer) + input_ids = self.heal_tokens(input_ids, gen_mode_kwargs.get("tokenizer")) if streamer is not None: streamer.put(input_ids.cpu()) @@ -2444,13 +2478,10 @@ def generate( ): max_cache_length += inputs_tensor.shape[1] self._prepare_cache_for_generation( - generation_config, model_kwargs, assistant_model, batch_size, max_cache_length + generation_config, model_kwargs, generation_mode, batch_size, max_cache_length ) - # 8. 
determine generation mode - generation_mode = generation_config.get_generation_mode(assistant_model) - - if streamer is not None and (generation_config.num_beams > 1): + if streamer is not None and generation_mode == GenerationMode.BEAM_SEARCH: raise ValueError( "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." ) @@ -2466,26 +2497,26 @@ def generate( UserWarning, ) - # 9. prepare logits processors and stopping criteria + # 8. prepare logits processors and stopping criteria prepared_logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, device=inputs_tensor.device, model_kwargs=model_kwargs, - negative_prompt_ids=negative_prompt_ids, - negative_prompt_attention_mask=negative_prompt_attention_mask, + **logits_processor_kwargs, ) prepared_stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs + generation_config=generation_config, + stopping_criteria=stopping_criteria, + tokenizer=gen_mode_kwargs.get("tokenizer"), ) # Set model_kwargs `use_cache` so we can use it later in forward runs model_kwargs["use_cache"] = generation_config.use_cache - # 10. go into different generation modes + # 9. go into different generation modes if isinstance(custom_generate, Callable): result = custom_generate( self, @@ -2516,19 +2547,19 @@ def generate( f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" ) - # 11. Get the candidate generator, given the parameterization + # 10. Get the candidate generator, given the parameterization candidate_generator = self._get_candidate_generator( generation_config=generation_config, input_ids=input_ids, inputs_tensor=inputs_tensor, assistant_model=assistant_model, logits_processor=logits_processor, - target_tokenizer=tokenizer, - assistant_tokenizer=assistant_tokenizer, + target_tokenizer=gen_mode_kwargs.get("tokenizer"), + assistant_tokenizer=gen_mode_kwargs.get("assistant_tokenizer"), model_kwargs=model_kwargs, ) - # 12. run assisted generate + # 11. run assisted generate result = self._assisted_decoding( input_ids, candidate_generator=candidate_generator, @@ -2541,7 +2572,7 @@ def generate( ) elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) + # 10. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) result = self._sample( input_ids, logits_processor=prepared_logits_processor, @@ -2553,7 +2584,7 @@ def generate( ) elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): - # 11. run beam sample + # 10. 
run beam sample result = self._beam_search( input_ids, logits_processor=prepared_logits_processor, diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py index 22b607ec2865..cce1d7a9a5a8 100644 --- a/src/transformers/models/dia/generation_dia.py +++ b/src/transformers/models/dia/generation_dia.py @@ -272,7 +272,9 @@ def _main_generate_loop( generation_config, use_model_defaults, **kwargs ) self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) + self._validate_assistant( + assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer + ) # 2. Set generation parameters if not already defined if synced_gpus is None: @@ -346,12 +348,13 @@ def _main_generate_loop( and not self.config.is_encoder_decoder ): max_cache_length += inputs_tensor.shape[1] - self._prepare_cache_for_generation( - generation_config, model_kwargs, assistant_model, batch_size, max_cache_length - ) # 8. determine generation mode - generation_mode = generation_config.get_generation_mode(assistant_model) + generation_mode = generation_config.get_generation_mode(assistant_model=assistant_model) + + self._prepare_cache_for_generation( + generation_config, model_kwargs, generation_mode, batch_size, max_cache_length + ) if streamer is not None and (generation_config.num_beams > 1): raise ValueError( @@ -371,7 +374,7 @@ def _main_generate_loop( negative_prompt_attention_mask=negative_prompt_attention_mask, ) prepared_stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs + generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer ) # Set model_kwargs `use_cache` so we can use it later in forward runs diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index 641eec0634d8..c10a0f80acf1 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -1222,7 +1222,7 @@ def _prepare_model_inputs( self.codec_model._prepare_cache_for_generation( generation_config=self.codec_model.generation_config, model_kwargs=temporary_model_kwargs, - assistant_model=None, + generation_mode=None, batch_size=batch_size, max_cache_length=self.config.codec_config.sliding_window, ) diff --git a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py index 03b442b2edbd..8541a911e947 100644 --- a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py @@ -357,7 +357,7 @@ def _prepare_model_inputs( self.codec_model._prepare_cache_for_generation( generation_config=self.codec_model.generation_config, model_kwargs=temporary_model_kwargs, - assistant_model=None, + generation_mode=None, batch_size=batch_size, max_cache_length=self.config.codec_config.sliding_window, ) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 92d2d049ae39..bfcf1191dad3 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ 
b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1257,7 +1257,7 @@ def generate( self._prepare_cache_for_generation( generation_config, model_kwargs, - assistant_model=None, + generation_mode=None, batch_size=batch_size, max_cache_length=max_cache_length, ) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 8133d33bac47..c634ffe6598d 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -2172,7 +2172,7 @@ def generate( self._prepare_cache_for_generation( generation_config, model_kwargs, - assistant_model=None, + generation_mode=None, batch_size=batch_size, max_cache_length=max_cache_length, ) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 5f1c592d3230..f3932137a082 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1566,7 +1566,7 @@ def extend_enc_output(tensor, num_beams=None): self._prepare_cache_for_generation( generation_config, model_kwargs, - assistant_model=None, + generation_mode=None, batch_size=input_ids.shape[0], max_cache_length=generation_config.max_length - 1, ) From 47580e49fc721cc738bf321f7581edaff7a11b08 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 1 Sep 2025 17:28:51 +0200 Subject: [PATCH 02/30] ops --- src/transformers/generation/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1cd949e46239..d99dcedca5b1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2335,7 +2335,9 @@ def generate( generation_config, use_model_defaults, **kwargs ) self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant(assistant_model=assistant_model) + self._validate_assistant( + assistant_model, gen_mode_kwargs.get("tokenizer"), gen_mode_kwargs.get("prefix_allowed_tokens_fn") + ) generation_mode = generation_config.get_generation_mode(assistant_model) From 3b09223f84c88df2c8e47a477bf74b862e51bbdb Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 1 Sep 2025 17:45:45 +0200 Subject: [PATCH 03/30] fix --- src/transformers/generation/utils.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d99dcedca5b1..e864619cfbd8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2172,17 +2172,17 @@ def _get_mode_processor_kwargs( "assistant_model": assistant_model, "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None), } - logits_processor_kwargs = { - "negative_prompt_ids": negative_prompt_ids, - "negative_prompt_attention_mask": negative_prompt_attention_mask, - "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn, - } - # If custom_generate is a callable, we need to extract its parameters + gen_mode_kwargs = {k: v for k, v in gen_mode_kwargs.items() if v is not None} if isinstance(custom_generate, Callable): usual_mode_kwargs = inspect.signature(GenerationMixin._sample).parameters.keys() custom_generate_kwargs = inspect.signature(custom_generate).parameters.keys() new_custom_keys = custom_generate_kwargs - usual_mode_kwargs gen_mode_kwargs = {k: kwargs.pop(k) for k in new_custom_keys if k in 
kwargs} + logits_processor_kwargs = { + "negative_prompt_ids": negative_prompt_ids, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn, + } return gen_mode_kwargs, logits_processor_kwargs @torch.no_grad() @@ -2335,9 +2335,7 @@ def generate( generation_config, use_model_defaults, **kwargs ) self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant( - assistant_model, gen_mode_kwargs.get("tokenizer"), gen_mode_kwargs.get("prefix_allowed_tokens_fn") - ) + self._validate_assistant(assistant_model, gen_mode_kwargs.get("tokenizer"), prefix_allowed_tokens_fn) generation_mode = generation_config.get_generation_mode(assistant_model) @@ -2354,7 +2352,7 @@ def generate( stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, - assistant_model, + gen_mode_kwargs.pop("assistant_model"), streamer, negative_prompt_ids, negative_prompt_attention_mask, @@ -2529,6 +2527,7 @@ def generate( synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, + **gen_mode_kwargs, ) elif generation_mode == GenerationMode.ASSISTED_GENERATION: if generation_config.num_return_sequences > 1: From 2a74c4ba9f9c8b5d1c744ca375c2a0c5536dd418 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 1 Sep 2025 17:55:15 +0200 Subject: [PATCH 04/30] ops --- src/transformers/generation/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e864619cfbd8..9605ab3c44f3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2335,7 +2335,9 @@ def generate( generation_config, use_model_defaults, **kwargs ) self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant(assistant_model, gen_mode_kwargs.get("tokenizer"), prefix_allowed_tokens_fn) + self._validate_assistant( + assistant_model, gen_mode_kwargs.get("tokenizer"), gen_mode_kwargs.get("assistant_tokenizer") + ) generation_mode = generation_config.get_generation_mode(assistant_model) From bfd41f1dd88885983b99d14fb9eaf656a19c6bd6 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 2 Sep 2025 15:47:50 +0200 Subject: [PATCH 05/30] review --- src/transformers/generation/utils.py | 164 +++++++++--------- src/transformers/models/dia/generation_dia.py | 2 +- 2 files changed, 79 insertions(+), 87 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9605ab3c44f3..1531f7acee17 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1492,35 +1492,44 @@ def compute_transition_scores( return transition_scores - def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): - if assistant_model is None: - return + def _validate_generation_mode(self, generation_mode, generation_mode_kwargs): + if "synced_gpus" not in generation_mode_kwargs: + generation_mode_kwargs["synced_gpus"] = ( + is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + ) and dist.get_world_size() > 1 - if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: - attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] - attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check] - are_equal = all( - getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check + if generation_mode == 
GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs: + raise ValueError( + "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." ) - if not are_equal: - raise ValueError( - "The main model and the assistant don't have compatible encoder-dependent input shapes. " - "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper." - ) - doc_reference = ( - "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)" - ) - if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: - if assistant_tokenizer is not None: - raise ValueError( - f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}." - ) - else: - if tokenizer is None or assistant_tokenizer is None: - raise ValueError( - f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}." + if generation_mode == GenerationMode.ASSISTED_GENERATION: + assistant_model = generation_mode_kwargs.get("assistant_model") + if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: + attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] + attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check] + are_equal = all( + getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check ) + if not are_equal: + raise ValueError( + "The main model and the assistant don't have compatible encoder-dependent input shapes. " + "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper." + ) + + doc_reference = ( + "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)" + ) + if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: + if "assistant_tokenizer" in generation_mode_kwargs: + raise ValueError( + f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}." + ) + else: + if "tokenizer" not in generation_mode_kwargs or "assistant_tokenizer" not in generation_mode_kwargs: + raise ValueError( + f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}." + ) def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" @@ -2155,35 +2164,33 @@ def _get_deprecated_gen_repo( ) return repo - def _get_mode_processor_kwargs( + def _extract_generation_mode_kwargs( self, custom_generate, kwargs, + synced_gpus, assistant_model, - negative_prompt_ids, - negative_prompt_attention_mask, - prefix_allowed_tokens_fn, - ) -> tuple[dict[str, Any], dict[str, Any]]: + streamer, + ) -> dict[str, Any]: """ - Extracts and returns the generation mode and logit processor related keyword arguments from the provided kwargs. + Extracts and returns the generation mode related keyword arguments from the provided kwargs. 
""" - gen_mode_kwargs = { + generation_mode_kwargs = { "tokenizer": kwargs.pop("tokenizer", None), - "assistant_model": assistant_model, "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None), + "synced_gpus": synced_gpus, + "assistant_model": assistant_model, + "streamer": streamer, } - gen_mode_kwargs = {k: v for k, v in gen_mode_kwargs.items() if v is not None} + generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None} + # Custom_generate callables can have their own set of arguments + # To extract them, we compare the signature with the standard _sample method if isinstance(custom_generate, Callable): usual_mode_kwargs = inspect.signature(GenerationMixin._sample).parameters.keys() custom_generate_kwargs = inspect.signature(custom_generate).parameters.keys() new_custom_keys = custom_generate_kwargs - usual_mode_kwargs - gen_mode_kwargs = {k: kwargs.pop(k) for k in new_custom_keys if k in kwargs} - logits_processor_kwargs = { - "negative_prompt_ids": negative_prompt_ids, - "negative_prompt_attention_mask": negative_prompt_attention_mask, - "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn, - } - return gen_mode_kwargs, logits_processor_kwargs + generation_mode_kwargs = {k: kwargs.pop(k) for k in new_custom_keys if k in kwargs} + return generation_mode_kwargs @torch.no_grad() def generate( @@ -2322,25 +2329,22 @@ def generate( return custom_generate_function(model=self, **generate_arguments) # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode - gen_mode_kwargs, logits_processor_kwargs = self._get_mode_processor_kwargs( + generation_mode_kwargs = self._extract_generation_mode_kwargs( custom_generate, kwargs, + synced_gpus, assistant_model, - negative_prompt_ids, - negative_prompt_attention_mask, - prefix_allowed_tokens_fn, + streamer, ) generation_config, model_kwargs = self._prepare_generation_config( generation_config, use_model_defaults, **kwargs ) - self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant( - assistant_model, gen_mode_kwargs.get("tokenizer"), gen_mode_kwargs.get("assistant_tokenizer") - ) - generation_mode = generation_config.get_generation_mode(assistant_model) + self._validate_model_kwargs(model_kwargs.copy()) + self._validate_generation_mode(generation_mode, generation_mode_kwargs) + # Deprecation-related step: set Hub repo for deprecated strategies. # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode. # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps. 
@@ -2348,27 +2352,22 @@ def generate( if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate): return GenerationMixin.generate( self, - inputs, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - gen_mode_kwargs.pop("assistant_model"), - streamer, - negative_prompt_ids, - negative_prompt_attention_mask, - use_model_defaults, + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + assistant_model=assistant_model, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + use_model_defaults=use_model_defaults, custom_generate=deprecate_mode_repo, trust_remote_code=trust_remote_code, - **gen_mode_kwargs, + **generation_mode_kwargs, **kwargs, ) # 2. Set generation parameters if not already defined - if synced_gpus is None: - synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() @@ -2442,7 +2441,7 @@ def generate( ) if generation_config.token_healing: - input_ids = self.heal_tokens(input_ids, gen_mode_kwargs.get("tokenizer")) + input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer")) if streamer is not None: streamer.put(input_ids.cpu()) @@ -2483,11 +2482,6 @@ def generate( generation_config, model_kwargs, generation_mode, batch_size, max_cache_length ) - if streamer is not None and generation_mode == GenerationMode.BEAM_SEARCH: - raise ValueError( - "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." 
- ) - if self.device.type != input_ids.device.type: warnings.warn( "You are calling .generate() with the `input_ids` being on a device type different" @@ -2504,15 +2498,17 @@ def generate( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, device=inputs_tensor.device, model_kwargs=model_kwargs, - **logits_processor_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, ) prepared_stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria, - tokenizer=gen_mode_kwargs.get("tokenizer"), + tokenizer=generation_mode_kwargs.get("tokenizer"), ) # Set model_kwargs `use_cache` so we can use it later in forward runs @@ -2526,10 +2522,8 @@ def generate( logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, + **generation_mode_kwargs, **model_kwargs, - **gen_mode_kwargs, ) elif generation_mode == GenerationMode.ASSISTED_GENERATION: if generation_config.num_return_sequences > 1: @@ -2555,10 +2549,10 @@ def generate( generation_config=generation_config, input_ids=input_ids, inputs_tensor=inputs_tensor, - assistant_model=assistant_model, + assistant_model=generation_mode_kwargs.pop("assistant_model", None), logits_processor=logits_processor, - target_tokenizer=gen_mode_kwargs.get("tokenizer"), - assistant_tokenizer=gen_mode_kwargs.get("assistant_tokenizer"), + target_tokenizer=generation_mode_kwargs.pop("tokenizer", None), + assistant_tokenizer=generation_mode_kwargs.pop("assistant_tokenizer", None), model_kwargs=model_kwargs, ) @@ -2569,8 +2563,7 @@ def generate( logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, + **generation_mode_kwargs, **model_kwargs, ) @@ -2581,8 +2574,7 @@ def generate( logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, + **generation_mode_kwargs, **model_kwargs, ) @@ -2593,7 +2585,7 @@ def generate( logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, - synced_gpus=synced_gpus, + **generation_mode_kwargs, **model_kwargs, ) @@ -2716,7 +2708,7 @@ def _sample( stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool, - streamer: Optional["BaseStreamer"], + streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -3482,7 +3474,7 @@ def _assisted_decoding( stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool, - streamer: Optional["BaseStreamer"], + streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py index cce1d7a9a5a8..3f471c231f6e 100644 --- a/src/transformers/models/dia/generation_dia.py +++ b/src/transformers/models/dia/generation_dia.py @@ -272,7 +272,7 @@ def _main_generate_loop( generation_config, use_model_defaults, **kwargs ) 
self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant( + self._validate_generation_mode( assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer ) From 494c9a8217125c9a2f4d2135935a989dc58a14f1 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 2 Sep 2025 16:12:38 +0200 Subject: [PATCH 06/30] fix --- src/transformers/generation/utils.py | 9 ++++----- src/transformers/models/csm/generation_csm.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1531f7acee17..b62c12a8df92 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1503,8 +1503,7 @@ def _validate_generation_mode(self, generation_mode, generation_mode_kwargs): "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." ) - if generation_mode == GenerationMode.ASSISTED_GENERATION: - assistant_model = generation_mode_kwargs.get("assistant_model") + if (assistant_model := generation_mode_kwargs.get("assistant_model")) is not None: if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check] @@ -2707,7 +2706,7 @@ def _sample( logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, - synced_gpus: bool, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: @@ -3136,7 +3135,7 @@ def _beam_search( logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, - synced_gpus: bool, + synced_gpus: bool = False, **model_kwargs, ) -> Union[GenerateBeamOutput, torch.LongTensor]: r""" @@ -3473,7 +3472,7 @@ def _assisted_decoding( logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, - synced_gpus: bool, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: diff --git a/src/transformers/models/csm/generation_csm.py b/src/transformers/models/csm/generation_csm.py index 9c2f06e6562f..b14f353685c2 100644 --- a/src/transformers/models/csm/generation_csm.py +++ b/src/transformers/models/csm/generation_csm.py @@ -153,8 +153,8 @@ def _sample( logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, - synced_gpus: bool, - streamer: Optional["BaseStreamer"], + synced_gpus: bool = False, + streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: """ From da72717eea6255736a78ece8228e92103e74933f Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 2 Sep 2025 18:02:55 +0200 Subject: [PATCH 07/30] fix dia --- src/transformers/models/dia/generation_dia.py | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py index 3f471c231f6e..439b498b0988 100644 --- a/src/transformers/models/dia/generation_dia.py +++ b/src/transformers/models/dia/generation_dia.py @@ -265,16 +265,20 @@ def _main_generate_loop( ): # ********** 
mostly taken from main generate function up to calling the different methods (see NOTE) ********** # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call - tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria - assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation - + generation_mode_kwargs = self._extract_generation_mode_kwargs( + custom_generate, + kwargs, + synced_gpus, + assistant_model, + streamer, + ) generation_config, model_kwargs = self._prepare_generation_config( generation_config, use_model_defaults, **kwargs ) + generation_mode = generation_config.get_generation_mode(assistant_model) + self._validate_model_kwargs(model_kwargs.copy()) - self._validate_generation_mode( - assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer - ) + self._validate_generation_mode(generation_mode, generation_mode_kwargs) # 2. Set generation parameters if not already defined if synced_gpus is None: @@ -310,7 +314,7 @@ def _main_generate_loop( ) if generation_config.token_healing: - input_ids = self.heal_tokens(input_ids, tokenizer) + input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer")) if streamer is not None: streamer.put(input_ids.cpu()) @@ -348,20 +352,11 @@ def _main_generate_loop( and not self.config.is_encoder_decoder ): max_cache_length += inputs_tensor.shape[1] - - # 8. determine generation mode - generation_mode = generation_config.get_generation_mode(assistant_model=assistant_model) - self._prepare_cache_for_generation( generation_config, model_kwargs, generation_mode, batch_size, max_cache_length ) - if streamer is not None and (generation_config.num_beams > 1): - raise ValueError( - "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." - ) - - # 9. prepare logits processors and stopping criteria + # 8. 
prepare logits processors and stopping criteria prepared_logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, @@ -374,7 +369,9 @@ def _main_generate_loop( negative_prompt_attention_mask=negative_prompt_attention_mask, ) prepared_stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer + generation_config=generation_config, + stopping_criteria=stopping_criteria, + tokenizer=generation_mode_kwargs.get("tokenizer"), ) # Set model_kwargs `use_cache` so we can use it later in forward runs @@ -396,8 +393,7 @@ def _main_generate_loop( logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, + **generation_mode_kwargs, **model_kwargs, ) else: From 27582a1d21bb610a257432ffb729ed8250cfaba6 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 3 Sep 2025 15:54:23 +0200 Subject: [PATCH 08/30] unify assisted generate to common decoding method signature --- src/transformers/generation/utils.py | 158 ++++++++++++--------------- 1 file changed, 72 insertions(+), 86 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b62c12a8df92..04e8be9f80f7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2141,16 +2141,9 @@ def _get_deprecated_gen_repo( """ Returns the Hub repo for a deprecated generation mode, if any. """ - moved_to_hub_modes = { - GenerationMode.DOLA_GENERATION: "transformers-community/dola", - GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search", - GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search", - GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search", - } - if custom_generate is not None or generation_mode not in moved_to_hub_modes: + if custom_generate is not None or not isinstance(repo := GENERATION_MODES_MAPPING[generation_mode], str): return None - repo = moved_to_hub_modes[generation_mode] logger.warning_once( f"{generation_mode.name.replace('_', ' ').title()} was moved to a `custom_generate` repo: https://hf.co/{repo}. " f"To prevent loss of backward compatibility, add `custom_generate='{repo}'` " @@ -2340,6 +2333,9 @@ def generate( generation_config, use_model_defaults, **kwargs ) generation_mode = generation_config.get_generation_mode(assistant_model) + generation_call = ( + GENERATION_MODES_MAPPING[generation_mode] if not isinstance(custom_generate, Callable) else custom_generate + ) self._validate_model_kwargs(model_kwargs.copy()) self._validate_generation_mode(generation_mode, generation_mode_kwargs) @@ -2378,6 +2374,8 @@ def generate( inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) + if "inputs_tensor" in inspect.signature(generation_call).parameters.keys(): + generation_mode_kwargs["inputs_tensor"] = inputs_tensor batch_size = inputs_tensor.shape[0] device = inputs_tensor.device @@ -2513,80 +2511,16 @@ def generate( # Set model_kwargs `use_cache` so we can use it later in forward runs model_kwargs["use_cache"] = generation_config.use_cache - # 9. 
go into different generation modes - if isinstance(custom_generate, Callable): - result = custom_generate( - self, - input_ids, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - **generation_mode_kwargs, - **model_kwargs, - ) - elif generation_mode == GenerationMode.ASSISTED_GENERATION: - if generation_config.num_return_sequences > 1: - raise ValueError( - "num_return_sequences has to be 1 when doing assisted generate, " - f"but is {generation_config.num_return_sequences}." - ) - if batch_size > 1: - raise ValueError("assisted generate is only supported for batch_size = 1") - if not model_kwargs["use_cache"]: - raise ValueError("assisted generate requires `use_cache=True`") - if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]: - raise ValueError("assisted generate is not supported with Static cache classes`") - if self._is_stateful: - # In assisted generation we need the ability to confirm whether the model would pick certain tokens, - # which is not possible with stateful models (they can't reset to a previous subset of generated text) - raise ValueError( - f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" - ) - - # 10. Get the candidate generator, given the parameterization - candidate_generator = self._get_candidate_generator( - generation_config=generation_config, - input_ids=input_ids, - inputs_tensor=inputs_tensor, - assistant_model=generation_mode_kwargs.pop("assistant_model", None), - logits_processor=logits_processor, - target_tokenizer=generation_mode_kwargs.pop("tokenizer", None), - assistant_tokenizer=generation_mode_kwargs.pop("assistant_tokenizer", None), - model_kwargs=model_kwargs, - ) - - # 11. run assisted generate - result = self._assisted_decoding( - input_ids, - candidate_generator=candidate_generator, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - **generation_mode_kwargs, - **model_kwargs, - ) - - elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 10. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) - result = self._sample( - input_ids, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - **generation_mode_kwargs, - **model_kwargs, - ) - - elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): - # 10. run beam sample - result = self._beam_search( - input_ids, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - **generation_mode_kwargs, - **model_kwargs, - ) + # 9. 
go into generation mode callable + result = generation_call( + self, + input_ids, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + **generation_mode_kwargs, + **model_kwargs, + ) # Convert to legacy cache format if requested if ( @@ -3468,12 +3402,15 @@ def _beam_search( def _assisted_decoding( self, input_ids: torch.LongTensor, - candidate_generator: CandidateGenerator, logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, + inputs_tensor: Optional[torch.FloatTensor] = None, + assistant_model: Optional["PreTrainedModel"] = None, + assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -3485,9 +3422,6 @@ def _assisted_decoding( Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. - candidate_generator (`CandidateGenerator`): - A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For - more information, the documentation of [`CandidateGenerator`] should be read. logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. @@ -3502,6 +3436,15 @@ def _assisted_decoding( streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + inputs_tensor (`torch.FloatTensor`, *optional*): + The input tensor for generation. For decoder models, usually `input_ids`. For encoder-decoder models, + the tensor that produced `model_kwargs["encoder_outputs"]`. + assistant_model (`PreTrainedModel`, *optional*): + The model used to assist the generation process. If not provided, the main model will be used. + assistant_tokenizer (`PreTrainedTokenizerBase`, *optional*): + The tokenizer used for the assistant model. If not provided, the token space is assumed to be the same. + tokenizer (`PreTrainedTokenizerBase`, *optional*): + The tokenizer used for the main model. If not provided, the token space is assumed to be the same. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -3513,6 +3456,35 @@ def _assisted_decoding( `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. """ + if generation_config.num_return_sequences > 1: + raise ValueError( + "num_return_sequences has to be 1 when doing assisted generate, " + f"but is {generation_config.num_return_sequences}." 
+ ) + if input_ids.shape[0] > 1: + raise ValueError("assisted generate is only supported for batch_size = 1") + if not model_kwargs["use_cache"]: + raise ValueError("assisted generate requires `use_cache=True`") + if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]: + raise ValueError("assisted generate is not supported with Static cache classes`") + if self._is_stateful: + # In assisted generation we need the ability to confirm whether the model would pick certain tokens, + # which is not possible with stateful models (they can't reset to a previous subset of generated text) + raise ValueError( + f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" + ) + + # 10. Get the candidate generator, given the parameterization + candidate_generator = self._get_candidate_generator( + generation_config=generation_config, + input_ids=input_ids, + inputs_tensor=inputs_tensor, + assistant_model=assistant_model, + logits_processor=logits_processor, + target_tokenizer=tokenizer, + assistant_tokenizer=assistant_tokenizer, + model_kwargs=model_kwargs, + ) # init values do_sample = generation_config.do_sample output_attentions = generation_config.output_attentions @@ -3856,3 +3828,17 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at new_tuple += (layer[..., i : i + 1, :last_dim_size],) outputs += (new_tuple,) return outputs + + +GENERATION_MODES_MAPPING = { + GenerationMode.SAMPLE: GenerationMixin._sample, + GenerationMode.GREEDY_SEARCH: GenerationMixin._sample, + GenerationMode.BEAM_SEARCH: GenerationMixin._beam_search, + GenerationMode.BEAM_SAMPLE: GenerationMixin._beam_search, + GenerationMode.ASSISTED_GENERATION: GenerationMixin._assisted_decoding, + # Deprecated methods + GenerationMode.DOLA_GENERATION: "transformers-community/dola", + GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search", + GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search", + GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search", +} From fcfc23dfb6458f16e0e79c24341010945317cd20 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 3 Sep 2025 18:17:18 +0200 Subject: [PATCH 09/30] move checks to validate steps where possible --- src/transformers/generation/utils.py | 48 ++++++++++++++++------------ 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ab86a056fb31..eb09297c99df 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1492,12 +1492,27 @@ def compute_transition_scores( return transition_scores - def _validate_generation_mode(self, generation_mode, generation_mode_kwargs): + def _validate_generation_mode(self, batch_size, generation_mode, generation_config, generation_mode_kwargs): if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs: raise ValueError( "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." ) + if generation_mode == GenerationMode.ASSISTED_GENERATION: + if generation_config.num_return_sequences > 1: + raise ValueError( + "num_return_sequences has to be 1 when doing assisted generate, " + f"but is {generation_config.num_return_sequences}." 
+ ) + if batch_size > 1: + raise ValueError("assisted generate is only supported for batch_size = 1") + if self._is_stateful: + # In assisted generation we need the ability to confirm whether the model would pick certain tokens, + # which is not possible with stateful models (they can't reset to a previous subset of generated text) + raise ValueError( + f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" + ) + if (assistant_model := generation_mode_kwargs.get("assistant_model")) is not None: if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] @@ -2336,7 +2351,6 @@ def generate( ) self._validate_model_kwargs(model_kwargs.copy()) - self._validate_generation_mode(generation_mode, generation_mode_kwargs) # Deprecation-related step: set Hub repo for deprecated strategies. # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode. @@ -2372,10 +2386,13 @@ def generate( inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) + # Some generation modes (e.g. assisted) need `inputs_tensor` to rerun encoder.forward() if "inputs_tensor" in inspect.signature(generation_call).parameters.keys(): generation_mode_kwargs["inputs_tensor"] = inputs_tensor batch_size = inputs_tensor.shape[0] + self._validate_generation_mode(batch_size, generation_mode, generation_config, generation_mode_kwargs) + device = inputs_tensor.device self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) @@ -2509,7 +2526,7 @@ def generate( # Set model_kwargs `use_cache` so we can use it later in forward runs model_kwargs["use_cache"] = generation_config.use_cache - # 9. go into generation mode callable + # 9. Call generation mode result = generation_call( self, input_ids, @@ -3405,7 +3422,7 @@ def _assisted_decoding( generation_config: GenerationConfig, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, - inputs_tensor: Optional[torch.FloatTensor] = None, + inputs_tensor: torch.FloatTensor = None, assistant_model: Optional["PreTrainedModel"] = None, assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None, tokenizer: Optional["PreTrainedTokenizerBase"] = None, @@ -3454,25 +3471,16 @@ def _assisted_decoding( `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. """ - if generation_config.num_return_sequences > 1: - raise ValueError( - "num_return_sequences has to be 1 when doing assisted generate, " - f"but is {generation_config.num_return_sequences}." 
- ) - if input_ids.shape[0] > 1: - raise ValueError("assisted generate is only supported for batch_size = 1") + # The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") - if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]: + if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or ( + "past_key_values" in model_kwargs + and hasattr(model_kwargs["past_key_values"], "layers") + and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers) + ): raise ValueError("assisted generate is not supported with Static cache classes`") - if self._is_stateful: - # In assisted generation we need the ability to confirm whether the model would pick certain tokens, - # which is not possible with stateful models (they can't reset to a previous subset of generated text) - raise ValueError( - f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" - ) - - # 10. Get the candidate generator, given the parameterization + # Get the candidate generator, given the parameterization candidate_generator = self._get_candidate_generator( generation_config=generation_config, input_ids=input_ids, From 35fc1169a3c9a0b2a8d17ab12b0858844e5f3eef Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 3 Sep 2025 18:44:27 +0200 Subject: [PATCH 10/30] fix csm and other models that override _sample --- src/transformers/generation/utils.py | 36 +++++++++++++++------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index eb09297c99df..75cc1e50d60c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2146,12 +2146,13 @@ def _get_deprecated_gen_repo( self, generation_mode: GenerationMode, trust_remote_code: bool, + generation_modes_mapping: dict[GenerationMode, Union[str, Callable]], custom_generate: Optional[str] = None, ) -> Optional[str]: """ Returns the Hub repo for a deprecated generation mode, if any. 
""" - if custom_generate is not None or not isinstance(repo := GENERATION_MODES_MAPPING[generation_mode], str): + if custom_generate is not None or not isinstance(repo := generation_modes_mapping[generation_mode], str): return None logger.warning_once( @@ -2346,8 +2347,21 @@ def generate( generation_config, use_model_defaults, **kwargs ) generation_mode = generation_config.get_generation_mode(assistant_model) + # Cannot be root level constant since subclasses might override the methods + generation_modes_mapping = { + GenerationMode.SAMPLE: type(self)._sample, + GenerationMode.GREEDY_SEARCH: type(self)._sample, + GenerationMode.BEAM_SEARCH: type(self)._beam_search, + GenerationMode.BEAM_SAMPLE: type(self)._beam_search, + GenerationMode.ASSISTED_GENERATION: type(self)._assisted_decoding, + # Deprecated methods + GenerationMode.DOLA_GENERATION: "transformers-community/dola", + GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search", + GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search", + GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search", + } generation_call = ( - GENERATION_MODES_MAPPING[generation_mode] if not isinstance(custom_generate, Callable) else custom_generate + generation_modes_mapping[generation_mode] if not isinstance(custom_generate, Callable) else custom_generate ) self._validate_model_kwargs(model_kwargs.copy()) @@ -2356,7 +2370,9 @@ def generate( # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode. # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps. # TODO joao, manuel: remove this in v4.62.0 - if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate): + if deprecate_mode_repo := self._get_deprecated_gen_repo( + generation_mode, trust_remote_code, generation_modes_mapping, custom_generate + ): return GenerationMixin.generate( self, inputs=inputs, @@ -3834,17 +3850,3 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at new_tuple += (layer[..., i : i + 1, :last_dim_size],) outputs += (new_tuple,) return outputs - - -GENERATION_MODES_MAPPING = { - GenerationMode.SAMPLE: GenerationMixin._sample, - GenerationMode.GREEDY_SEARCH: GenerationMixin._sample, - GenerationMode.BEAM_SEARCH: GenerationMixin._beam_search, - GenerationMode.BEAM_SAMPLE: GenerationMixin._beam_search, - GenerationMode.ASSISTED_GENERATION: GenerationMixin._assisted_decoding, - # Deprecated methods - GenerationMode.DOLA_GENERATION: "transformers-community/dola", - GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search", - GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search", - GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search", -} From 26919bd963fb725e71df68bf68827a0e876e6bfb Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 3 Sep 2025 18:56:15 +0200 Subject: [PATCH 11/30] ops dia you again --- src/transformers/models/dia/generation_dia.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py index 439b498b0988..f811aee179a3 100644 --- a/src/transformers/models/dia/generation_dia.py +++ b/src/transformers/models/dia/generation_dia.py @@ -278,7 +278,6 @@ def _main_generate_loop( generation_mode = 
generation_config.get_generation_mode(assistant_model) self._validate_model_kwargs(model_kwargs.copy()) - self._validate_generation_mode(generation_mode, generation_mode_kwargs) # 2. Set generation parameters if not already defined if synced_gpus is None: @@ -293,6 +292,7 @@ def _main_generate_loop( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] + self._validate_generation_mode(batch_size, generation_mode, generation_config, generation_mode_kwargs) device = inputs_tensor.device self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) From a3f7be323d477b55e097764f1c0641f896f4fa9f Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 4 Sep 2025 10:59:27 +0200 Subject: [PATCH 12/30] opsie --- src/transformers/generation/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 75cc1e50d60c..d99183b111f2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2184,10 +2184,11 @@ def _extract_generation_mode_kwargs( "assistant_model": assistant_model, "streamer": streamer, } - if synced_gpus is not None: - generation_mode_kwargs["synced_gpus"] = ( - is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) - ) and dist.get_world_size() > 1 + generation_mode_kwargs["synced_gpus"] = ( + (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 + if synced_gpus is None + else synced_gpus + ) generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None} # Custom_generate callables can have their own set of arguments # To extract them, we compare the signature with the standard _sample method From a580344e9c162d3090b627058590f03e880871c0 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 4 Sep 2025 17:17:45 +0200 Subject: [PATCH 13/30] joao review --- src/transformers/generation/utils.py | 51 +++++++++++++--------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d99183b111f2..16de45809427 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -129,6 +129,19 @@ "past_buckets_states", # reformer ] +GENERATION_MODES_MAPPING = { + GenerationMode.SAMPLE: "_sample", + GenerationMode.GREEDY_SEARCH: "_sample", + GenerationMode.BEAM_SEARCH: "_beam_search", + GenerationMode.BEAM_SAMPLE: "_beam_search", + GenerationMode.ASSISTED_GENERATION: "_assisted_decoding", + # Deprecated methods + GenerationMode.DOLA_GENERATION: "transformers-community/dola", + GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search", + GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search", + GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search", +} + @dataclass class GenerateDecoderOnlyOutput(ModelOutput): @@ -1492,7 +1505,7 @@ def compute_transition_scores( return transition_scores - def _validate_generation_mode(self, batch_size, generation_mode, generation_config, generation_mode_kwargs): + def _validate_generation_mode(self, generation_mode, generation_config, generation_mode_kwargs): if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs: raise ValueError( "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." 
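The module-level mapping introduced above deliberately stores method names (or "org/repo" ids for the deprecated strategies) instead of bound functions: the same patch detects deprecated modes simply by the "/" in the value, and the follow-up patches resolve the name with getattr(type(self), ...) so a subclass that overrides _sample or _beam_search is still dispatched to its own implementation. A minimal, self-contained sketch of that dispatch pattern follows; the class, enum, and method names are illustrative only, not the transformers API.

from enum import Enum


class GenerationMode(str, Enum):
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    DOLA_GENERATION = "dola_generation"


# Method names, or "org/repo" strings for modes that moved to the Hub.
MODES = {
    GenerationMode.GREEDY_SEARCH: "_sample",
    GenerationMode.SAMPLE: "_sample",
    GenerationMode.DOLA_GENERATION: "transformers-community/dola",
}


class BaseMixin:
    def _sample(self, prompt):
        return f"base sample of {prompt!r}"

    def generate(self, prompt, mode=GenerationMode.SAMPLE):
        target = MODES[mode]
        if "/" in target:
            # Deprecated mode: the real code would instead load custom generate code from this Hub repo.
            raise NotImplementedError(f"load custom generate from {target}")
        decoding_method = getattr(type(self), target)  # unbound lookup, so subclass overrides apply
        return decoding_method(self, prompt)


class MyModel(BaseMixin):
    def _sample(self, prompt):  # override is honoured because dispatch goes through type(self)
        return f"custom sample of {prompt!r}"


print(BaseMixin().generate("hi"))  # -> base sample of 'hi'
print(MyModel().generate("hi"))    # -> custom sample of 'hi'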
@@ -1504,8 +1517,6 @@ def _validate_generation_mode(self, batch_size, generation_mode, generation_conf "num_return_sequences has to be 1 when doing assisted generate, " f"but is {generation_config.num_return_sequences}." ) - if batch_size > 1: - raise ValueError("assisted generate is only supported for batch_size = 1") if self._is_stateful: # In assisted generation we need the ability to confirm whether the model would pick certain tokens, # which is not possible with stateful models (they can't reset to a previous subset of generated text) @@ -2146,13 +2157,12 @@ def _get_deprecated_gen_repo( self, generation_mode: GenerationMode, trust_remote_code: bool, - generation_modes_mapping: dict[GenerationMode, Union[str, Callable]], custom_generate: Optional[str] = None, ) -> Optional[str]: """ Returns the Hub repo for a deprecated generation mode, if any. """ - if custom_generate is not None or not isinstance(repo := generation_modes_mapping[generation_mode], str): + if custom_generate is not None or "/" not in (repo := GENERATION_MODES_MAPPING[generation_mode]): return None logger.warning_once( @@ -2349,31 +2359,18 @@ def generate( ) generation_mode = generation_config.get_generation_mode(assistant_model) # Cannot be root level constant since subclasses might override the methods - generation_modes_mapping = { - GenerationMode.SAMPLE: type(self)._sample, - GenerationMode.GREEDY_SEARCH: type(self)._sample, - GenerationMode.BEAM_SEARCH: type(self)._beam_search, - GenerationMode.BEAM_SAMPLE: type(self)._beam_search, - GenerationMode.ASSISTED_GENERATION: type(self)._assisted_decoding, - # Deprecated methods - GenerationMode.DOLA_GENERATION: "transformers-community/dola", - GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search", - GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search", - GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search", - } - generation_call = ( - generation_modes_mapping[generation_mode] if not isinstance(custom_generate, Callable) else custom_generate + decoding_method = ( + GENERATION_MODES_MAPPING[generation_mode] if not isinstance(custom_generate, Callable) else custom_generate ) self._validate_model_kwargs(model_kwargs.copy()) + self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs) # Deprecation-related step: set Hub repo for deprecated strategies. # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode. # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps. # TODO joao, manuel: remove this in v4.62.0 - if deprecate_mode_repo := self._get_deprecated_gen_repo( - generation_mode, trust_remote_code, generation_modes_mapping, custom_generate - ): + if deprecated_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate): return GenerationMixin.generate( self, inputs=inputs, @@ -2385,7 +2382,7 @@ def generate( negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask, use_model_defaults=use_model_defaults, - custom_generate=deprecate_mode_repo, + custom_generate=deprecated_mode_repo, trust_remote_code=trust_remote_code, **generation_mode_kwargs, **kwargs, @@ -2404,12 +2401,10 @@ def generate( inputs, generation_config.bos_token_id, model_kwargs ) # Some generation modes (e.g. 
assisted) need `inputs_tensor` to rerun encoder.forward() - if "inputs_tensor" in inspect.signature(generation_call).parameters.keys(): + if "inputs_tensor" in inspect.signature(decoding_method).parameters.keys(): generation_mode_kwargs["inputs_tensor"] = inputs_tensor batch_size = inputs_tensor.shape[0] - self._validate_generation_mode(batch_size, generation_mode, generation_config, generation_mode_kwargs) - device = inputs_tensor.device self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) @@ -2544,7 +2539,7 @@ def generate( model_kwargs["use_cache"] = generation_config.use_cache # 9. Call generation mode - result = generation_call( + result = decoding_method( self, input_ids, logits_processor=prepared_logits_processor, @@ -3532,6 +3527,8 @@ def _assisted_decoding( # keep track of which sequences are already finished batch_size, cur_len = input_ids.shape[:2] + if batch_size > 1: + raise ValueError("assisted generate is only supported for batch_size = 1") unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) From 58ffe607f5cff429f49c5719d51eb3b9ccbed4f0 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 4 Sep 2025 18:56:11 +0200 Subject: [PATCH 14/30] ops --- src/transformers/generation/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 16de45809427..73bc4577f69a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2359,9 +2359,10 @@ def generate( ) generation_mode = generation_config.get_generation_mode(assistant_model) # Cannot be root level constant since subclasses might override the methods - decoding_method = ( - GENERATION_MODES_MAPPING[generation_mode] if not isinstance(custom_generate, Callable) else custom_generate - ) + if isinstance(custom_generate, Callable): + decoding_method = custom_generate + else: + decoding_method = getattr(self, GENERATION_MODES_MAPPING[generation_mode]) self._validate_model_kwargs(model_kwargs.copy()) self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs) From 29404687c3f15ef14c2adce9e2436adef2dbaf00 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 4 Sep 2025 22:15:18 +0200 Subject: [PATCH 15/30] ops2 --- src/transformers/generation/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 73bc4577f69a..3ac83b854978 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2358,11 +2358,11 @@ def generate( generation_config, use_model_defaults, **kwargs ) generation_mode = generation_config.get_generation_mode(assistant_model) - # Cannot be root level constant since subclasses might override the methods if isinstance(custom_generate, Callable): decoding_method = custom_generate else: - decoding_method = getattr(self, GENERATION_MODES_MAPPING[generation_mode]) + # type() required to access the unbound class-level method + decoding_method = getattr(type(self), GENERATION_MODES_MAPPING[generation_mode]) self._validate_model_kwargs(model_kwargs.copy()) self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs) From ceaaf68905358d3ea8b3ad469b07ac774cf9e3b6 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral 
Date: Thu, 4 Sep 2025 22:29:51 +0200 Subject: [PATCH 16/30] dia --- src/transformers/models/dia/generation_dia.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py index f811aee179a3..45ee66d39a97 100644 --- a/src/transformers/models/dia/generation_dia.py +++ b/src/transformers/models/dia/generation_dia.py @@ -278,6 +278,7 @@ def _main_generate_loop( generation_mode = generation_config.get_generation_mode(assistant_model) self._validate_model_kwargs(model_kwargs.copy()) + self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs) # 2. Set generation parameters if not already defined if synced_gpus is None: @@ -292,7 +293,6 @@ def _main_generate_loop( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] - self._validate_generation_mode(batch_size, generation_mode, generation_config, generation_mode_kwargs) device = inputs_tensor.device self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) From 273b6d9d2d45b55991e58ed5ec91f577ecd2709d Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 5 Sep 2025 13:41:36 +0200 Subject: [PATCH 17/30] Move variable output controls to `prepare_inputs_for_generation` --- src/transformers/generation/utils.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f118cf5276b0..ccaac36cfbc7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1768,6 +1768,11 @@ def _prepare_generation_config( # Finally, apply any passed kwargs model_kwargs = generation_config.update(**kwargs) + # And keep in model_kwargs variable output controls + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + model_kwargs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_kwargs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) return generation_config, model_kwargs @@ -2772,10 +2777,6 @@ def _sample( # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # prepare variable output controls (note: some models won't accept all output controls) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) - if is_prefill: outputs = self(**model_inputs, return_dict=True) is_prefill = False @@ -3256,10 +3257,6 @@ def _beam_search( flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len]) model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs) - # prepare variable output controls (note: some models won't accept all output controls) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) - model_outputs = self(**model_inputs, return_dict=True) # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping @@ -3553,9 +3550,6 @@ def _assisted_decoding( model_inputs["logits_to_keep"] = candidate_length + 1 # 2.2. 
Run a forward pass on the candidate sequence - # prepare variable output controls (note: some models won't accept all output controls) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) outputs = self(**model_inputs) From ac7efcb6d74273f54ab303b0c4f8a7f3fd9a7244 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 5 Sep 2025 13:42:46 +0200 Subject: [PATCH 18/30] fix xlstm --- src/transformers/models/xlstm/modeling_xlstm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py index 14f189d2f1cc..6df4808f8c53 100644 --- a/src/transformers/models/xlstm/modeling_xlstm.py +++ b/src/transformers/models/xlstm/modeling_xlstm.py @@ -1556,6 +1556,12 @@ def prepare_inputs_for_generation( model_inputs = {"input_ids": input_ids} model_inputs.update({"cache_params": cache_params, "use_cache": use_cache}) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs @can_return_tuple From 9a25f88d3d61b6a16df49e2110ea11bb6011e80a Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 8 Sep 2025 14:27:27 +0200 Subject: [PATCH 19/30] skip on args check --- src/transformers/generation/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ccaac36cfbc7..97ae541c821a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -18,7 +18,7 @@ import os import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_type_hints import torch import torch.distributed as dist @@ -46,6 +46,7 @@ from ..tokenization_utils import ExtensionsTrie from ..utils import ( ModelOutput, + TransformersKwargs, is_accelerate_available, is_hqq_available, is_optimum_quanto_available, @@ -1563,7 +1564,7 @@ def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): model_args |= {f"decoder_{x}" for x in decoder_model_args} for key, value in model_kwargs.items(): - if value is not None and key not in model_args: + if value is not None and key not in model_args and key not in get_type_hints(TransformersKwargs): unused_model_args.append(key) if unused_model_args: From 315247ad6a4394c80a7d1295a45d48dc35fdf56c Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 8 Sep 2025 14:54:15 +0200 Subject: [PATCH 20/30] fix xlm roberta, zamba --- .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 12 +++++++++++- src/transformers/models/zamba/modeling_zamba.py | 6 ++++++ src/transformers/models/zamba2/modeling_zamba2.py | 6 ++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index c625ce7b53ea..8d68f448f663 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -1013,13 +1013,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if position_ids is not None: position_ids = position_ids[:, remove_prefix_length:] 
- return { + model_inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, "past_key_values": past_key_values, } + # They are calculated on the fly on XLMRobertaXLModel.forward() + model_kwargs.pop("token_type_ids", None) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in model_kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + @auto_docstring class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel): diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 9c0f86ea4489..2f9edb1e113c 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -1185,6 +1185,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index ddd4f6f69079..33e7e4b5a351 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1606,6 +1606,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs From 453ee1e28a7a77508c4c78650e2ac73c3c630140 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 8 Sep 2025 15:06:06 +0200 Subject: [PATCH 21/30] fix moshi, rwkv --- src/transformers/models/moshi/modeling_moshi.py | 7 ++++++- src/transformers/models/rwkv/modeling_rwkv.py | 6 ++++++ .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index df254675450d..78a2e3e9d281 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -2259,7 +2259,7 @@ def prepare_inputs_for_generation( # we want to do it after a first token has been generated if model_inputs["input_ids"] is not None: - last_hidden_state = kwargs.get("last_hidden_state") + last_hidden_state = kwargs.pop("last_hidden_state") # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim) last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1]) @@ -2291,6 +2291,11 @@ def prepare_inputs_for_generation( model_inputs["input_ids"] = None model_inputs["inputs_embeds"] = inputs_embeds + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs def _update_model_kwargs_for_generation( diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 0b16af278946..d86d4d0f8707 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -719,6 +719,12 @@ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=Non model_inputs["state"] = state model_inputs["use_cache"] = use_cache + + # Forward ALL kwargs that are uninitialized (e.g. 
`use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs @auto_docstring diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 8d68f448f663..ca85ea93814b 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -993,7 +993,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti # Create missing `position_ids` on the fly position_ids = None - if model_kwargs.get("position_ids") is None: + if model_kwargs.pop("position_ids", None) is None: position_ids = create_position_ids_from_input_ids( input_ids, padding_idx=self.config.pad_token_id ) # placed in kwargs for further processing (see below) From 9a3a82629af4be61c84af7ade292432c8b2d25a2 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 8 Sep 2025 15:13:42 +0200 Subject: [PATCH 22/30] fix mamba2 --- src/transformers/models/mamba2/modeling_mamba2.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 85cf026e49d0..738c5376c33e 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -989,6 +989,12 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs @auto_docstring From 18e1ec19f074ee67ade0f95bf23e08a0ad7496ab Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 8 Sep 2025 16:27:32 +0200 Subject: [PATCH 23/30] fix a bunch of models --- src/transformers/models/bamba/modeling_bamba.py | 6 ++++++ src/transformers/models/bamba/modular_bamba.py | 6 ++++++ src/transformers/models/bloom/modeling_bloom.py | 6 ++++++ src/transformers/models/ctrl/modeling_ctrl.py | 9 ++++++++- .../models/falcon_h1/modeling_falcon_h1.py | 6 ++++++ .../models/falcon_h1/modular_falcon_h1.py | 6 ++++++ .../falcon_mamba/modeling_falcon_mamba.py | 6 ++++++ src/transformers/models/git/modeling_git.py | 9 ++++++++- .../modeling_granitemoehybrid.py | 6 ++++++ .../modular_granitemoehybrid.py | 6 ++++++ src/transformers/models/jamba/modeling_jamba.py | 6 ++++++ .../models/kosmos2_5/modeling_kosmos2_5.py | 9 ++++++++- src/transformers/models/mamba/modeling_mamba.py | 6 ++++++ .../models/openai/modeling_openai.py | 9 ++++++++- .../models/prophetnet/modeling_prophetnet.py | 12 +++++++++++- .../models/reformer/modeling_reformer.py | 9 +++++++-- src/transformers/models/xlm/modeling_xlm.py | 15 +++++++++++++-- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 +- src/transformers/models/xlnet/modeling_xlnet.py | 17 +++++++++++++---- 19 files changed, 137 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index f5e337e52ebd..09f00845524d 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1492,6 +1492,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). 
+ for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index aec09861de81..52814930a172 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -1211,6 +1211,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 64576b683952..72585bb7dff4 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -822,6 +822,12 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs @auto_docstring diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index ece001f9ce1f..6b64bebc33bb 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -570,7 +570,14 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac input_ids = input_ids[:, remove_prefix_length:] - return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} + model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs @auto_docstring( diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index da16bbbd0327..674634b0ef32 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1607,6 +1607,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 34193212a99c..d688939a0b5e 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -1372,6 +1372,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). 
+ for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 7b61d2bdefd9..4f3577d8f50a 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -839,6 +839,12 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs @auto_docstring diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 2b69cf07a046..642d4b9a36e7 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1452,7 +1452,7 @@ def prepare_inputs_for_generation( if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - return { + model_inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": kwargs.get("pixel_values"), @@ -1460,5 +1460,12 @@ def prepare_inputs_for_generation( "use_cache": use_cache, } + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + __all__ = ["GitForCausalLM", "GitModel", "GitPreTrainedModel", "GitVisionModel"] diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index c0efccf4b5bb..2006ad6bf40a 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1829,6 +1829,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index 25151b6936b6..f18c2a2842b2 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -383,6 +383,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index db7461af73d7..f80f235626dd 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1448,6 +1448,12 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). 
+ for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs diff --git a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py index 27e692273c71..414841696cd7 100644 --- a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py +++ b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py @@ -1648,7 +1648,7 @@ def prepare_inputs_for_generation( dim=1, ) - return { + model_inputs = { "input_ids": input_ids, "image_embeds": image_embeds, "image_embeds_position_mask": image_embeds_position_mask, @@ -1658,6 +1658,13 @@ def prepare_inputs_for_generation( "use_cache": use_cache, } + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in model_kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + @add_start_docstrings( """ diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 3f39c2d8490b..677e62de57e8 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -777,6 +777,12 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, } ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs @auto_docstring diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 27c84910cb43..44fa05227ff8 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -602,7 +602,14 @@ def forward( def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict[str, Any]: # Overwritten -- old model with reduced inputs - return {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids} + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs @auto_docstring( diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index a79c803e77b9..5beefe79b314 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -2010,7 +2010,7 @@ def prepare_inputs_for_generation( if past_key_values is not None and past_key_values.get_seq_length() > 0: input_ids = input_ids[:, -1:] # first step, decoder_cached_states are empty - return { + model_inputs = { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed "attention_mask": attention_mask, "head_mask": head_mask, @@ -2018,6 +2018,16 @@ def prepare_inputs_for_generation( "use_cache": use_cache, } + # Prophetnet does not support cache_position + kwargs.pop("cache_position", None) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). 
+ for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel): """ diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 367af6692357..9125a5faefaf 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -2345,14 +2345,19 @@ def prepare_inputs_for_generation( if past_key_values is not None: input_ids = input_ids[:, -1:] - inputs_dict = { + model_inputs = { "input_ids": input_ids, "past_buckets_states": past_key_values, "use_cache": use_cache, "num_hashes": num_hashes, } - return inputs_dict + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs def _reorder_cache(self, past_key_values, beam_idx): reord_past_buckets_states = [] diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 4e7316fb781b..3a90a3057426 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -72,7 +72,7 @@ def get_masks(slen, lengths, causal, padding_mask=None): attn_mask = mask # sanity check - assert mask.size() == (bs, slen) + assert mask.size() == (bs, slen), f"mask.size(): {mask.size()}, should be: {(bs, slen)}" assert causal is False or attn_mask.size() == (bs, slen, slen) return mask, attn_mask @@ -994,7 +994,18 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs): langs = torch.full_like(input_ids, lang_id) else: langs = None - return {"input_ids": input_ids, "langs": langs} + model_inputs = {"input_ids": input_ids, "langs": langs} + + # They are calculated on the fly on XLMModel.forward() + kwargs.pop("token_type_ids", None) + kwargs.pop("attention_mask", None) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). 
+ for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs @auto_docstring def forward( diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index ca85ea93814b..8d68f448f663 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -993,7 +993,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti # Create missing `position_ids` on the fly position_ids = None - if model_kwargs.pop("position_ids", None) is None: + if model_kwargs.get("position_ids") is None: position_ids = create_position_ids_from_input_ids( input_ids, padding_idx=self.config.pad_token_id ) # placed in kwargs for further processing (see below) diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 736521ee9561..0c6b9f76eade 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -1472,7 +1472,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_mem ) target_mapping[:, 0, -1] = 1.0 - inputs = { + model_inputs = { "input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping, @@ -1481,9 +1481,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_mem # if past is defined in model kwargs then use it for faster decoding if past_key_values: - inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values) - - return inputs + model_inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values) + + # Attention mask is computed on the fly on XLNetModel.forward() + kwargs.pop("attention_mask", None) + # TODO: Ignoring use_cache should not happen, fixme. + kwargs.pop("use_cache", None) + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs @auto_docstring def forward( From 2e032485a432f7a72b56d33e973d1abce8d76dbe Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 8 Sep 2025 16:49:15 +0200 Subject: [PATCH 24/30] fix --- src/transformers/models/ctrl/modeling_ctrl.py | 3 +++ src/transformers/models/reformer/modeling_reformer.py | 3 +++ src/transformers/models/xlm/modeling_xlm.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 6b64bebc33bb..01f227c79185 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -572,9 +572,12 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} + # token_type_ids are computed on CTRLModel.forward() + kwargs.pop("token_type_ids", None) # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). 
for key, value in kwargs.items(): if key not in model_inputs: + print(f"Warning: {key} is not a recognized input.") model_inputs[key] = value return model_inputs diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 9125a5faefaf..990f21359bc0 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -2352,9 +2352,12 @@ def prepare_inputs_for_generation( "num_hashes": num_hashes, } + # Attention mask is computed on ReformerModel.forward() + kwargs.pop("attention_mask", None) # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). for key, value in kwargs.items(): if key not in model_inputs: + print(f"Warning: {key} is not a recognized input.") model_inputs[key] = value return model_inputs diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 3a90a3057426..a73b4a51cea4 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -72,7 +72,7 @@ def get_masks(slen, lengths, causal, padding_mask=None): attn_mask = mask # sanity check - assert mask.size() == (bs, slen), f"mask.size(): {mask.size()}, should be: {(bs, slen)}" + assert mask.size() == (bs, slen) assert causal is False or attn_mask.size() == (bs, slen, slen) return mask, attn_mask From c4f625780d274a9583f7c6e59407a50b177ec001 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 10 Sep 2025 20:44:21 +0200 Subject: [PATCH 25/30] review --- src/transformers/generation/utils.py | 13 +++-- tests/generation/test_utils.py | 60 +++++++++++++++++++++++ tests/models/dia/test_modeling_dia.py | 4 ++ tests/models/moshi/test_modeling_moshi.py | 8 +++ 4 files changed, 78 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 97ae541c821a..77a0bfbb2d47 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -18,7 +18,7 @@ import os import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_type_hints +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -547,11 +547,9 @@ def prepare_inputs_for_generation( **kwargs, ): """ - Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or - slicing inputs given the existing cache. - - See the forward pass in the model documentation for expected arguments (different models might have different - requirements for e.g. `past_key_values`). This function should work as is for most LLMs. + Prepare the model inputs for generation. Notable steps include selecting the correct input key and cloning when appropriate, + creating position_ids from the attention_mask when missing, slicing inputs and converting 2D attention masks to 4D for + compilable caches, and finally forwarding all additional keyword arguments unchanged to the model's forward pass. """ # 1. 
Handle BC: @@ -1563,8 +1561,9 @@ def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): decoder_model_args = set(inspect.signature(decoder.forward).parameters) model_args |= {f"decoder_{x}" for x in decoder_model_args} + # TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions' for key, value in model_kwargs.items(): - if value is not None and key not in model_args and key not in get_type_hints(TransformersKwargs): + if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__: unused_model_args.append(key) if unused_model_args: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f285ce2818bf..6881c8eb59ed 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1835,6 +1835,66 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + @pytest.mark.generate + def test_prepare_inputs_for_generation_kwargs_forwarding(self, **extra_kwargs): + """Tests that prepare_inputs_for_generation forwards arbitrary kwargs while manipulating specific args.""" + # TODO: fixme, these old models do not clone input ids like the reference `prepare_inputs_for_generation` + non_compilable_model_classes = [ + "BambaModel", + "CTRLModel", + "FalconH1Model", + "FalconMambaModel", + "GitModel", + "GraniteMoeHybridModel", + "JambaModel", + "Kosmos2_5Model", + "MambaModel", + "Mamba2Model", + "OpenAIGPTModel", + "ProphetNetStandaloneDecoderModel", + "ReformerLocalAttnModel", + "ReformerLSHAttnModel", + "RwkvModel", + "XLMRobertaXLModel", + "xLSTMModel", + "ZambaModel", + "Zamba2Model", + ] + + for model_class in self.all_generative_model_classes: + config, _ = self.prepare_config_and_inputs_for_generate() + + model = model_class(config).to(torch_device).eval() + + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) + + # Test that arbitrary kwargs are forwarded unchanged + input_args = { + "input_ids": input_ids, + "cache_position": [0], + "position_ids": torch.tensor([[0, 1, 2], [0, 1, 2]]).to(torch_device), + } + arbitrary_kwargs = { + "output_attentions": True, + "output_hidden_states": True, + "custom_arg": "test_value", + "numeric_arg": 42, + } + + model_inputs = model.prepare_inputs_for_generation(**input_args, **arbitrary_kwargs, **extra_kwargs) + + if model_class.__name__ in non_compilable_model_classes: + # Verify that input_ids is cloned + input_ids_key = "decoder_input_ids" if config.is_encoder_decoder else "input_ids" + self.assertTrue(model_inputs[input_ids_key] is not input_ids) + + # Verify that all arbitrary kwargs are forwarded unchanged + for key, value in arbitrary_kwargs.items(): + self.assertTrue(key in model_inputs) + self.assertTrue( + model_inputs[key] == value, f"Expected {key} to be {value}, but got {model_inputs[key]}" + ) + def _test_attention_implementation(self, attn_implementation): """ Compares the output of generate with the eager attention implementation against other implementations. 
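The new test above pins down the contract that the model-side patches earlier in this series (xlstm, zamba, mamba, bloom, ctrl, and friends) implement: prepare_inputs_for_generation builds its explicit inputs first and then forwards every remaining keyword argument unchanged. That is what allows the per-step model_inputs.update(...) calls for output_attentions / output_hidden_states to be dropped from the decoding loops, and it pairs with _validate_model_kwargs now skipping keys listed in TransformersKwargs.__optional_keys__ (the frozenset of non-required keys a TypedDict exposes). A small sketch of the forwarding pattern, using illustrative argument names rather than any particular model's signature:

from typing import Any


def prepare_inputs_for_generation(input_ids, past_key_values=None, use_cache=True, **kwargs) -> dict[str, Any]:
    # Explicitly initialized inputs come first...
    model_inputs = {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": use_cache,
    }
    # ...then forward ALL kwargs that were not initialized above (e.g. output controls or custom args).
    for key, value in kwargs.items():
        if key not in model_inputs:
            model_inputs[key] = value
    return model_inputs


inputs = prepare_inputs_for_generation([[1, 2, 3]], output_attentions=True, custom_arg="test_value")
assert inputs["output_attentions"] is True and inputs["custom_arg"] == "test_value"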
diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py index b927b903085b..136e8df4b3dc 100644 --- a/tests/models/dia/test_modeling_dia.py +++ b/tests/models/dia/test_modeling_dia.py @@ -517,6 +517,10 @@ def test_generate_continue_from_past_key_values(self): ) ) + @pytest.mark.generate + def test_prepare_inputs_for_generation_kwargs_forwarding(self): + super().test_prepare_inputs_for_generation_kwargs_forwarding(encoder_outputs=torch.randn(2, 2, 32)) + @unittest.skip(reason="Indirectly checked in Dia through the generate methods.") def test_past_key_values_format(self, custom_all_cache_shapes=None): pass diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index cc33f1492dd3..5200e824fed3 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -880,6 +880,14 @@ def test_generate_continue_from_inputs_embeds(self): def test_save_load(self): super().test_save_load() + @pytest.mark.generate + @unittest.skip(reason="Moshi requires setting `model.generated_audio_codes` in generate() before preparing inputs") + def test_prepare_inputs_for_generation_kwargs_forwarding(self): + # If in the future `model.generated_audio_codes` is not required, this test can be re-enabled + super().test_prepare_inputs_for_generation_kwargs_forwarding( + last_hidden_state=torch.randn(2, 3, 32), kwargs_depth_decoder={} + ) + def place_dict_on_device(dict_to_place, device): for key in dict_to_place: From 190dfb5419b695bd454132c2304e4bff285c80d2 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 10 Sep 2025 20:56:45 +0200 Subject: [PATCH 26/30] ops --- src/transformers/generation/utils.py | 3 ++ tests/generation/test_utils.py | 45 +++++++++++++--------------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 77a0bfbb2d47..eef89db327ff 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -550,6 +550,9 @@ def prepare_inputs_for_generation( Prepare the model inputs for generation. Notable steps include selecting the correct input key and cloning when appropriate, creating position_ids from the attention_mask when missing, slicing inputs and converting 2D attention masks to 4D for compilable caches, and finally forwarding all additional keyword arguments unchanged to the model's forward pass. + + See the forward pass in the model documentation for expected arguments (different models might have different + requirements for e.g. `past_key_values`). This function should work as is for most LLMs. """ # 1. 
Handle BC: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6881c8eb59ed..be1b9c39f3c3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1840,25 +1840,24 @@ def test_prepare_inputs_for_generation_kwargs_forwarding(self, **extra_kwargs): """Tests that prepare_inputs_for_generation forwards arbitrary kwargs while manipulating specific args.""" # TODO: fixme, these old models do not clone input ids like the reference `prepare_inputs_for_generation` non_compilable_model_classes = [ - "BambaModel", - "CTRLModel", - "FalconH1Model", - "FalconMambaModel", - "GitModel", - "GraniteMoeHybridModel", - "JambaModel", - "Kosmos2_5Model", - "MambaModel", - "Mamba2Model", - "OpenAIGPTModel", - "ProphetNetStandaloneDecoderModel", - "ReformerLocalAttnModel", - "ReformerLSHAttnModel", - "RwkvModel", - "XLMRobertaXLModel", - "xLSTMModel", - "ZambaModel", - "Zamba2Model", + "BambaForCausalLM", + "CTRLLMHeadModel", + "FalconH1ForCausalLM", + "FalconMambaForCausalLM", + "GitForCausalLM", + "GraniteMoeHybridForCausalLM", + "JambaForCausalLM", + "Kosmos2_5ForConditionalGeneration", + "MambaForCausalLM", + "Mamba2ForCausalLM", + "OpenAIGPTLMHeadModel", + "ProphetNetForCausalLM", + "ReformerModelWithLMHead", + "RwkvForCausalLM", + "XLMRobertaXLForCausalLM", + "xLSTMForCausalLM", + "ZambaForCausalLM", + "Zamba2ForCausalLM", ] for model_class in self.all_generative_model_classes: @@ -1868,7 +1867,6 @@ def test_prepare_inputs_for_generation_kwargs_forwarding(self, **extra_kwargs): input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) - # Test that arbitrary kwargs are forwarded unchanged input_args = { "input_ids": input_ids, "cache_position": [0], @@ -1883,17 +1881,14 @@ def test_prepare_inputs_for_generation_kwargs_forwarding(self, **extra_kwargs): model_inputs = model.prepare_inputs_for_generation(**input_args, **arbitrary_kwargs, **extra_kwargs) - if model_class.__name__ in non_compilable_model_classes: + if model_class.__name__ not in non_compilable_model_classes: # Verify that input_ids is cloned input_ids_key = "decoder_input_ids" if config.is_encoder_decoder else "input_ids" self.assertTrue(model_inputs[input_ids_key] is not input_ids) - # Verify that all arbitrary kwargs are forwarded unchanged for key, value in arbitrary_kwargs.items(): self.assertTrue(key in model_inputs) - self.assertTrue( - model_inputs[key] == value, f"Expected {key} to be {value}, but got {model_inputs[key]}" - ) + self.assertTrue(model_inputs[key] == value) def _test_attention_implementation(self, attn_implementation): """ From aa93797c2bd0c9b00d8e6ba4f3e3ccf44c02c907 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 10 Sep 2025 20:59:52 +0200 Subject: [PATCH 27/30] better comment --- tests/generation/test_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index be1b9c39f3c3..25731c43c728 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1838,7 +1838,8 @@ def test_inherits_generation_mixin(self): @pytest.mark.generate def test_prepare_inputs_for_generation_kwargs_forwarding(self, **extra_kwargs): """Tests that prepare_inputs_for_generation forwards arbitrary kwargs while manipulating specific args.""" - # TODO: fixme, these old models do not clone input ids like the reference `prepare_inputs_for_generation` + # TODO: fixme. These old models do not clone input ids like the reference `prepare_inputs_for_generation`. 
+        # Thus, we skip the clone check on them.
         non_compilable_model_classes = [
             "BambaForCausalLM",
             "CTRLLMHeadModel",

From f668884bff79dc8d304db252d339c47524c119a6 Mon Sep 17 00:00:00 2001
From: Manuel de Prada Corral
Date: Sun, 14 Sep 2025 12:37:38 +0200
Subject: [PATCH 28/30] back to basics

---
 tests/generation/test_utils.py | 35 +++++++---------------------------
 1 file changed, 7 insertions(+), 28 deletions(-)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 25731c43c728..3b028e2d9b0f 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1838,29 +1838,6 @@ def test_inherits_generation_mixin(self):
     @pytest.mark.generate
     def test_prepare_inputs_for_generation_kwargs_forwarding(self, **extra_kwargs):
         """Tests that prepare_inputs_for_generation forwards arbitrary kwargs while manipulating specific args."""
-        # TODO: fixme. These old models do not clone input ids like the reference `prepare_inputs_for_generation`.
-        # Thus, we skip the clone check on them.
-        non_compilable_model_classes = [
-            "BambaForCausalLM",
-            "CTRLLMHeadModel",
-            "FalconH1ForCausalLM",
-            "FalconMambaForCausalLM",
-            "GitForCausalLM",
-            "GraniteMoeHybridForCausalLM",
-            "JambaForCausalLM",
-            "Kosmos2_5ForConditionalGeneration",
-            "MambaForCausalLM",
-            "Mamba2ForCausalLM",
-            "OpenAIGPTLMHeadModel",
-            "ProphetNetForCausalLM",
-            "ReformerModelWithLMHead",
-            "RwkvForCausalLM",
-            "XLMRobertaXLForCausalLM",
-            "xLSTMForCausalLM",
-            "ZambaForCausalLM",
-            "Zamba2ForCausalLM",
-        ]
-
         for model_class in self.all_generative_model_classes:
             config, _ = self.prepare_config_and_inputs_for_generate()
 
@@ -1870,7 +1847,7 @@ def test_prepare_inputs_for_generation_kwargs_forwarding(self, **extra_kwargs):
             input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
 
             input_args = {
                 "input_ids": input_ids,
-                "cache_position": [0],
+                "cache_position": torch.tensor([9]).to(torch_device),
                 "position_ids": torch.tensor([[0, 1, 2], [0, 1, 2]]).to(torch_device),
             }
             arbitrary_kwargs = {
@@ -1882,11 +1859,13 @@ def test_prepare_inputs_for_generation_kwargs_forwarding(self, **extra_kwargs):
 
             model_inputs = model.prepare_inputs_for_generation(**input_args, **arbitrary_kwargs, **extra_kwargs)
 
-            if model_class.__name__ not in non_compilable_model_classes:
-                # Verify that input_ids is cloned
-                input_ids_key = "decoder_input_ids" if config.is_encoder_decoder else "input_ids"
-                self.assertTrue(model_inputs[input_ids_key] is not input_ids)
+            # Verify that input_ids has proper name
+            if config.is_encoder_decoder:
+                self.assertTrue("decoder_input_ids" in model_inputs)
+            else:
+                self.assertTrue("input_ids" in model_inputs)
 
+            # Verify that arbitrary kwargs are forwarded
             for key, value in arbitrary_kwargs.items():
                 self.assertTrue(key in model_inputs)
                 self.assertTrue(model_inputs[key] == value)
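For reference, the behaviour that the reworked shared test asserts can be reproduced by hand. A minimal sketch, assuming `model` is any already-instantiated generative model (the kwarg name `foo` is purely illustrative and not part of this patch series):

import torch

input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
model_inputs = model.prepare_inputs_for_generation(
    input_ids=input_ids,
    cache_position=torch.tensor([9]),
    position_ids=torch.tensor([[0, 1, 2], [0, 1, 2]]),
    foo="bar",  # arbitrary kwarg: expected to reach the returned dict unchanged
)
# encoder-decoder models expose the tensor as `decoder_input_ids`, decoder-only models as `input_ids`
expected_key = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids"
assert expected_key in model_inputs
assert model_inputs["foo"] == "bar"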
args.""" + def test_prepare_inputs_for_generation_kwargs_forwards(self, **extra_kwargs): + """Tests that prepare_inputs_for_generation forwards arbitrary kwargs.""" for model_class in self.all_generative_model_classes: config, _ = self.prepare_config_and_inputs_for_generate() From 712f39ac28b84924a0f093561a6a639292395e7f Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 15 Sep 2025 12:58:43 +0200 Subject: [PATCH 30/30] ops --- tests/models/dia/test_modeling_dia.py | 4 ++-- tests/models/moshi/test_modeling_moshi.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py index cd7b5fbb0aeb..5ac321c5a753 100644 --- a/tests/models/dia/test_modeling_dia.py +++ b/tests/models/dia/test_modeling_dia.py @@ -518,8 +518,8 @@ def test_generate_continue_from_past_key_values(self): ) @pytest.mark.generate - def test_prepare_inputs_for_generation_kwargs_forwarding(self): - super().test_prepare_inputs_for_generation_kwargs_forwarding(encoder_outputs=torch.randn(2, 2, 32)) + def test_prepare_inputs_for_generation_kwargs_forwards(self): + super().test_prepare_inputs_for_generation_kwargs_forwards(encoder_outputs=torch.randn(2, 2, 32)) @unittest.skip(reason="Indirectly checked in Dia through the generate methods.") def test_hidden_states_output(self): diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index a36ee6b9dbc7..21f56e1bc56d 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -870,9 +870,9 @@ def test_save_load(self): @pytest.mark.generate @unittest.skip(reason="Moshi requires setting `model.generated_audio_codes` in generate() before preparing inputs") - def test_prepare_inputs_for_generation_kwargs_forwarding(self): + def test_prepare_inputs_for_generation_kwargs_forwards(self): # If in the future `model.generated_audio_codes` is not required, this test can be re-enabled - super().test_prepare_inputs_for_generation_kwargs_forwarding( + super().test_prepare_inputs_for_generation_kwargs_forwards( last_hidden_state=torch.randn(2, 3, 32), kwargs_depth_decoder={} )