From 331a87a711578d765091627cb47031e25c0b7ecc Mon Sep 17 00:00:00 2001
From: Manuel de Prada Corral
Date: Mon, 1 Sep 2025 17:20:57 +0200
Subject: [PATCH 1/8] Squashed commit of the following:

commit beb2b5f7a04ea9e12876696db66f3589fbae10c5
Author: Manuel de Prada Corral
Date:   Mon Sep 1 16:03:25 2025 +0200

    also standardize _get_stopping_criteria

commit 15c25663fa991e0a215a7f3cdcf13a9d3a989faa
Author: Manuel de Prada Corral
Date:   Mon Sep 1 15:48:38 2025 +0200

    watch super.generate() usages

commit 67dd845be2202d191a54b2872f1cb3f71b74b7d6
Author: Manuel de Prada Corral
Date:   Mon Sep 1 14:44:32 2025 +0200

    ops

commit 4655dfa28fd59d5dc083a41d8396de042d99858c
Author: Manuel de Prada Corral
Date:   Mon Sep 1 14:41:36 2025 +0200

    wrong merge

commit 46478143994e7b27d51c972a7881e0fea3cb6e3c
Merge: a72c2c4b2f 8564e210ca
Author: Manuel de Prada Corral
Date:   Mon Sep 1 14:36:15 2025 +0200

    Merge branch 'main' of github.com:huggingface/transformers into fix-custom-gen-from-function2

commit a72c2c4b2f9c0e09fe6ec7992d4d02bfa279da2a
Author: Manuel de Prada Corral
Date:   Mon Sep 1 14:04:59 2025 +0200

    ops5

commit e72f91411b961979bb3d271810f57905cee5b577
Author: Manuel de Prada Corral
Date:   Mon Sep 1 12:06:19 2025 +0200

    ops4

commit 12ca97b1078a42167143e0243036f6ef87d5fdac
Author: Manuel de Prada Corral
Date:   Mon Sep 1 11:58:59 2025 +0200

    ops3

commit 8cac6c60a318dd381793d4bf1ef3775823f3c95b
Author: Manuel de Prada Corral
Date:   Mon Sep 1 11:43:03 2025 +0200

    ops2

commit 4681a7d5dc6c8b96a515d9d79f06380c096b9a9f
Author: Manuel de Prada Corral
Date:   Mon Sep 1 11:40:51 2025 +0200

    ops

commit 0d72aa6cbd99a5933c5a95a39bea9088ee21e50f
Merge: e0d47e980e 5bb6186b8e
Author: Manuel de Prada Corral
Date:   Mon Sep 1 11:37:28 2025 +0200

    Merge branch 'remove-constrained-bs' into fix-custom-gen-from-function2

commit 5bb6186b8efbd5fdb8e3464a22f958343b9c450c
Merge: 44973dac7d b0db5a02f3
Author: Manuel de Prada Corral
Date:   Mon Sep 1 11:36:30 2025 +0200

    Merge branch 'main' into remove-constrained-bs

commit 44973dac7df4b4e2111c71f5fac918be21f3de52
Merge: 1ddab4bee1 893d89e5e6
Author: Manuel de Prada Corral
Date:   Mon Sep 1 11:29:48 2025 +0200

    Merge commit '893d89e5e6fac7279fe4292bfa3b027172287162' into remove-constrained-bs

commit e0d47e980e26d32b028c2b402ccb71262637a7a7
Merge: 88128e4563 1ddab4bee1
Author: Manuel de Prada Corral
Date:   Mon Sep 1 10:52:50 2025 +0200

    Merge branch 'remove-constrained-bs' into fix-custom-gen-from-function2

commit 88128e4563c0be583728e1d3c639bc93143c4029
Author: Manuel de Prada Corral
Date:   Mon Sep 1 10:44:38 2025 +0200

    fix custom generate args, refactor gen mode args

commit 1ddab4bee159f6c20722e7ff5cd41d5041fab0aa
Author: Manuel de Prada Corral
Date:   Sun Aug 31 21:03:53 2025 +0200

    fix

commit 6095fdda677ef7fbeb06c05f4f914a11b45257b4
Merge: 4a8b6d2ce1 04addbc9ec
Author: Manuel de Prada Corral
Date:   Thu Aug 28 17:49:16 2025 +0200

    Merge branch 'remove-constrained-bs' of github.com:manueldeprada/transformers into remove-constrained-bs

commit 4a8b6d2ce18b3a8b52c5261fea427e2416f65187
Author: Manuel de Prada Corral
Date:   Thu Aug 28 17:48:25 2025 +0200

    restore and deprecate beam obkects

commit 04addbc9ec62dd4f59d15128e8cd9499e2cda3bb
Merge: e800c7841e becab2c601
Author: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>
Date:   Thu Aug 28 14:38:29 2025 +0200

    Merge branch 'main' into remove-constrained-bs

commit e800c7841e5c46ce5698fc9be309d0808f85d23c
Author: Manuel de Prada Corral
Date:   Thu Aug 28 14:38:10 2025 +0200

    tests gone after green

commit 33971d21ac40aef76a7e1122f4a98ef28beadbe8
Author: Manuel de Prada Corral
Date:   Thu Aug 28 14:07:11 2025 +0200

    tests green, changed handling of deprecated methods

commit ab303835c184d0a87789da7aed7d8de5ba85d867
Author: Manuel de Prada Corral
Date:   Thu Aug 28 12:58:01 2025 +0200

    tests fix

commit ec74274ca52a6aa0b5f300374fda838609680506
Author: Manuel de Prada Corral
Date:   Thu Aug 28 12:32:05 2025 +0200

    ops

commit 0fb19004ccd285dcad485fce0865b355ce5493e0
Author: Manuel de Prada Corral
Date:   Thu Aug 28 11:45:16 2025 +0200

    whoops

commit c946bea5e45aea021c8878c57fcabc2a13f06fe5
Author: Manuel de Prada Corral
Date:   Thu Aug 28 11:35:36 2025 +0200

    testing...

commit 924c0dec6d9ea6b4890644fe7f711dc778f820bb
Author: Manuel de Prada Corral
Date:   Thu Aug 28 11:22:46 2025 +0200

    sweeep ready for tests

commit b05aa771d3994b07cd460cda74b274c9e4f315e6
Author: Manuel de Prada Corral
Date:   Thu Aug 28 11:13:01 2025 +0200

    restore and deprecate constraints

commit 9c7962d10efa7178b69d3c99e69663756e1cd979
Merge: fceeb383f9 c17bf304d5
Author: Manuel de Prada Corral
Date:   Wed Aug 27 20:44:21 2025 +0200

    Merge branch 'remove-group-bs' into remove-constrained-bs

commit c17bf304d5cf33af7f34f9f6057915d5f5821dae
Author: Manuel de Prada Corral
Date:   Wed Aug 27 17:00:50 2025 +0200

    fix test

commit d579aeec6706b77fcc24c1f6806cd7277d7db56e
Merge: 822efd8c3c ed5dd2999c
Author: Manuel de Prada Corral
Date:   Wed Aug 27 16:04:31 2025 +0200

    Merge branch 'main' of github.com:huggingface/transformers into remove-group-bs

commit 822efd8c3cf475d079e64293aa06e4ab59740fd7
Author: Manuel de Prada Corral
Date:   Wed Aug 27 15:59:51 2025 +0200

    aaand remove tests after all green!!

commit 62cb274a4acb9f24201902242f1b0dc4e46daac1
Author: Manuel de Prada Corral
Date:   Wed Aug 27 11:48:19 2025 +0200

    fix

commit c89c892e7b24a7d71831f2b35264456005030925
Author: Manuel de Prada Corral
Date:   Wed Aug 27 11:45:20 2025 +0200

    testing that hub works the same

commit fceeb383f99e4a836679d67b1d2a8520152eaf49
Author: Manuel de Prada Corral
Date:   Tue Aug 26 20:06:59 2025 +0200

    draft

commit 6a9b384078f3798587ba865ac7ddfefc9a79e41c
Merge: 8af3af13ab 58cebc848b
Author: Manuel de Prada Corral
Date:   Tue Aug 26 15:00:05 2025 +0200

    Merge branch 'main' of github.com:huggingface/transformers into remove-group-bs

commit 8af3af13abb85ca60e795d0390832f398a56c34f
Author: Manuel de Prada Corral
Date:   Tue Aug 26 11:55:45 2025 +0200

    Squashed commit remove-constrastive-search
---
 src/transformers/generation/utils.py          | 107 +++++++++++-------
 src/transformers/models/dia/generation_dia.py |  15 ++-
 .../modeling_kyutai_speech_to_text.py         |   2 +-
 .../modular_kyutai_speech_to_text.py          |   2 +-
 .../models/musicgen/modeling_musicgen.py      |   2 +-
 .../modeling_musicgen_melody.py               |   2 +-
 src/transformers/models/rag/modeling_rag.py   |   2 +-
 7 files changed, 83 insertions(+), 49 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index f118cf5276b0..1cd949e46239 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -999,11 +999,11 @@ def _get_candidate_generator(
         generation_config: GenerationConfig,
         input_ids: torch.LongTensor,
         inputs_tensor: torch.Tensor,
-        assistant_model: "PreTrainedModel",
         logits_processor: LogitsProcessorList,
-        target_tokenizer: "PreTrainedTokenizerBase",
-        assistant_tokenizer: "PreTrainedTokenizerBase",
         model_kwargs: dict,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        target_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
     ) -> CandidateGenerator:
         """
         Returns the candidate generator to be used in `assisted_generation`
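A note on the reordering in the hunk above: Python requires defaulted parameters to follow the non-defaulted ones, so making the assistant-related arguments optional forces `model_kwargs` ahead of them. A minimal illustration with a toy function (not from the codebase):

    # defaulted parameters must trail required ones
    # def f(a, b=None, c): ...   # SyntaxError: non-default argument follows default argument
    def f(a, c, b=None):         # valid: required first, optional last
        return a, c, b

    f(1, 2)  # the optional, assistant-style argument can now simply be omitted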
@@ -1300,7 +1300,6 @@ def _get_stopping_criteria(
         generation_config: GenerationConfig,
         stopping_criteria: Optional[StoppingCriteriaList],
         tokenizer: Optional["PreTrainedTokenizerBase"] = None,
-        **kwargs,
     ) -> StoppingCriteriaList:
         criteria = StoppingCriteriaList()
         if generation_config.max_length is not None:
@@ -1869,7 +1868,7 @@ def _prepare_cache_for_generation(
         self,
         generation_config: GenerationConfig,
         model_kwargs: dict,
-        assistant_model: "PreTrainedModel",
+        generation_mode: GenerationMode,
         batch_size: int,
         max_cache_length: int,
     ) -> bool:
@@ -1923,7 +1922,10 @@ def _prepare_cache_for_generation(
 
         # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
         # which is only supported in dynamic caches atm
-        if assistant_model is not None and generation_config.cache_implementation is not None:
+        if (
+            generation_mode == GenerationMode.ASSISTED_GENERATION
+            and generation_config.cache_implementation is not None
+        ):
             logger.warning_once(
                 "An assistant model is provided, using a dynamic cache instead of a cache of type="
                 f"'{generation_config.cache_implementation}'."
@@ -1933,7 +1935,6 @@ def _prepare_cache_for_generation(
         # Assisted decoding and contrastive search require cache rollback, which is incompatible with sliding layers.
         # To handle this, we skip passing the model config to DynamicCache (forcing a full-layer cache).
         # The "dynamic_full" option is a shortcut for generate() users to avoid sliding layers on their own.
-        generation_mode = generation_config.get_generation_mode(assistant_model)
         if (
             generation_mode in (GenerationMode.ASSISTED_GENERATION, GenerationMode.CONTRASTIVE_SEARCH)
             or generation_config.cache_implementation == "dynamic_full"
@@ -2125,15 +2126,13 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: Ge
 
     def _get_deprecated_gen_repo(
         self,
-        generation_config: GenerationConfig,
+        generation_mode: GenerationMode,
         trust_remote_code: bool,
         custom_generate: Optional[str] = None,
-        assistant_model: Optional["PreTrainedModel"] = None,
     ) -> Optional[str]:
         """
-        Returns the Hub repo for a deprecated generation strategy, if any.
+        Returns the Hub repo for a deprecated generation mode, if any.
         """
-        generation_mode = generation_config.get_generation_mode(assistant_model)
         moved_to_hub_modes = {
             GenerationMode.DOLA_GENERATION: "transformers-community/dola",
             GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
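For readers unfamiliar with the deprecation mechanism referenced here: deprecated decoding strategies are redirected to community Hub repos and re-enter `generate()` as `custom_generate` code. Roughly, the user-facing call shape (repo id from the mapping above; `penalty_alpha`/`top_k` are the usual contrastive-search knobs):

    # contrastive search, now served from a community repo instead of core transformers
    output = model.generate(
        **inputs,
        penalty_alpha=0.6,
        top_k=4,
        custom_generate="transformers-community/contrastive-search",
        trust_remote_code=True,
    )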
+ """ + gen_mode_kwargs = { + "tokenizer": kwargs.pop("tokenizer", None), + "assistant_model": assistant_model, + "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None), + } + logits_processor_kwargs = { + "negative_prompt_ids": negative_prompt_ids, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn, + } + # If custom_generate is a callable, we need to extract its parameters + if isinstance(custom_generate, Callable): + usual_mode_kwargs = inspect.signature(GenerationMixin._sample).parameters.keys() + custom_generate_kwargs = inspect.signature(custom_generate).parameters.keys() + new_custom_keys = custom_generate_kwargs - usual_mode_kwargs + gen_mode_kwargs = {k: kwargs.pop(k) for k in new_custom_keys if k in kwargs} + return gen_mode_kwargs, logits_processor_kwargs + @torch.no_grad() def generate( self, @@ -2292,23 +2321,29 @@ def generate( ) return custom_generate_function(model=self, **generate_arguments) - # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call - tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria - assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation + # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode + gen_mode_kwargs, logits_processor_kwargs = self._get_mode_processor_kwargs( + custom_generate, + kwargs, + assistant_model, + negative_prompt_ids, + negative_prompt_attention_mask, + prefix_allowed_tokens_fn, + ) generation_config, model_kwargs = self._prepare_generation_config( generation_config, use_model_defaults, **kwargs ) self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) + self._validate_assistant(assistant_model=assistant_model) + + generation_mode = generation_config.get_generation_mode(assistant_model) # Deprecation-related step: set Hub repo for deprecated strategies. # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode. # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps. # TODO joao, manuel: remove this in v4.62.0 - if deprecate_mode_repo := self._get_deprecated_gen_repo( - generation_config, trust_remote_code, custom_generate, assistant_model - ): + if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate): return GenerationMixin.generate( self, inputs, @@ -2324,8 +2359,7 @@ def generate( use_model_defaults, custom_generate=deprecate_mode_repo, trust_remote_code=trust_remote_code, - tokenizer=tokenizer, - assistant_tokenizer=assistant_tokenizer, + **gen_mode_kwargs, **kwargs, ) @@ -2406,7 +2440,7 @@ def generate( ) if generation_config.token_healing: - input_ids = self.heal_tokens(input_ids, tokenizer) + input_ids = self.heal_tokens(input_ids, gen_mode_kwargs.get("tokenizer")) if streamer is not None: streamer.put(input_ids.cpu()) @@ -2444,13 +2478,10 @@ def generate( ): max_cache_length += inputs_tensor.shape[1] self._prepare_cache_for_generation( - generation_config, model_kwargs, assistant_model, batch_size, max_cache_length + generation_config, model_kwargs, generation_mode, batch_size, max_cache_length ) - # 8. 
@@ -2292,23 +2321,29 @@ def generate(
             )
             return custom_generate_function(model=self, **generate_arguments)
 
-        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
-        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation
+        # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
+        gen_mode_kwargs, logits_processor_kwargs = self._get_mode_processor_kwargs(
+            custom_generate,
+            kwargs,
+            assistant_model,
+            negative_prompt_ids,
+            negative_prompt_attention_mask,
+            prefix_allowed_tokens_fn,
+        )
         generation_config, model_kwargs = self._prepare_generation_config(
             generation_config, use_model_defaults, **kwargs
         )
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
+        self._validate_assistant(assistant_model=assistant_model)
+
+        generation_mode = generation_config.get_generation_mode(assistant_model)
 
         # Deprecation-related step: set Hub repo for deprecated strategies.
         # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
         # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
         # TODO joao, manuel: remove this in v4.62.0
-        if deprecate_mode_repo := self._get_deprecated_gen_repo(
-            generation_config, trust_remote_code, custom_generate, assistant_model
-        ):
+        if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
             return GenerationMixin.generate(
                 self,
                 inputs,
@@ -2324,8 +2359,7 @@ def generate(
                 use_model_defaults,
                 custom_generate=deprecate_mode_repo,
                 trust_remote_code=trust_remote_code,
-                tokenizer=tokenizer,
-                assistant_tokenizer=assistant_tokenizer,
+                **gen_mode_kwargs,
                 **kwargs,
             )
 
@@ -2406,7 +2440,7 @@ def generate(
         )
 
         if generation_config.token_healing:
-            input_ids = self.heal_tokens(input_ids, tokenizer)
+            input_ids = self.heal_tokens(input_ids, gen_mode_kwargs.get("tokenizer"))
 
         if streamer is not None:
             streamer.put(input_ids.cpu())
@@ -2444,13 +2478,10 @@ def generate(
         ):
             max_cache_length += inputs_tensor.shape[1]
         self._prepare_cache_for_generation(
-            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length
+            generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
         )
 
-        # 8. determine generation mode
-        generation_mode = generation_config.get_generation_mode(assistant_model)
-
-        if streamer is not None and (generation_config.num_beams > 1):
+        if streamer is not None and generation_mode == GenerationMode.BEAM_SEARCH:
             raise ValueError(
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
             )
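The through-line of these hunks: the generation mode is now resolved once, right after the config is prepared, and downstream helpers key on the mode instead of re-deriving it from `assistant_model`. The cache decision, condensed into a sketch (conditions reproduced from the hunks above):

    # sketch of the cache guard after the refactor
    needs_rollback = generation_mode in (
        GenerationMode.ASSISTED_GENERATION,
        GenerationMode.CONTRASTIVE_SEARCH,
    )
    if needs_rollback or generation_config.cache_implementation == "dynamic_full":
        # force a full-layer DynamicCache: rollback is incompatible with sliding layers
        ...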
@@ -2466,26 +2497,26 @@ def generate(
                 UserWarning,
             )
 
-        # 9. prepare logits processors and stopping criteria
+        # 8. prepare logits processors and stopping criteria
         prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
             encoder_input_ids=inputs_tensor,
-            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             logits_processor=logits_processor,
             device=inputs_tensor.device,
             model_kwargs=model_kwargs,
-            negative_prompt_ids=negative_prompt_ids,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            **logits_processor_kwargs,
         )
         prepared_stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria,
+            tokenizer=gen_mode_kwargs.get("tokenizer"),
         )
 
         # Set model_kwargs `use_cache` so we can use it later in forward runs
         model_kwargs["use_cache"] = generation_config.use_cache
 
-        # 10. go into different generation modes
+        # 9. go into different generation modes
         if isinstance(custom_generate, Callable):
             result = custom_generate(
                 self,
@@ -2516,19 +2547,19 @@ def generate(
                 f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
             )
 
-            # 11. Get the candidate generator, given the parameterization
+            # 10. Get the candidate generator, given the parameterization
             candidate_generator = self._get_candidate_generator(
                 generation_config=generation_config,
                 input_ids=input_ids,
                 inputs_tensor=inputs_tensor,
                 assistant_model=assistant_model,
                 logits_processor=logits_processor,
-                target_tokenizer=tokenizer,
-                assistant_tokenizer=assistant_tokenizer,
+                target_tokenizer=gen_mode_kwargs.get("tokenizer"),
+                assistant_tokenizer=gen_mode_kwargs.get("assistant_tokenizer"),
                 model_kwargs=model_kwargs,
             )
 
-            # 12. run assisted generate
+            # 11. run assisted generate
             result = self._assisted_decoding(
                 input_ids,
                 candidate_generator=candidate_generator,
@@ -2541,7 +2572,7 @@ def generate(
             )
 
         elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+            # 10. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
@@ -2553,7 +2584,7 @@ def generate(
             )
 
         elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 11. run beam sample
+            # 10. run beam sample
             result = self._beam_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py
index 22b607ec2865..cce1d7a9a5a8 100644
--- a/src/transformers/models/dia/generation_dia.py
+++ b/src/transformers/models/dia/generation_dia.py
@@ -272,7 +272,9 @@ def _main_generate_loop(
             generation_config, use_model_defaults, **kwargs
         )
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
+        self._validate_assistant(
+            assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer
+        )
 
         # 2. Set generation parameters if not already defined
         if synced_gpus is None:
@@ -346,12 +348,13 @@ def _main_generate_loop(
             and not self.config.is_encoder_decoder
         ):
             max_cache_length += inputs_tensor.shape[1]
-        self._prepare_cache_for_generation(
-            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length
-        )
 
         # 8. determine generation mode
-        generation_mode = generation_config.get_generation_mode(assistant_model)
+        generation_mode = generation_config.get_generation_mode(assistant_model=assistant_model)
+
+        self._prepare_cache_for_generation(
+            generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
+        )
 
         if streamer is not None and (generation_config.num_beams > 1):
             raise ValueError(
@@ -371,7 +374,7 @@ def _main_generate_loop(
             negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
         prepared_stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
+            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer
         )
 
         # Set model_kwargs `use_cache` so we can use it later in forward runs
diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
index 641eec0634d8..c10a0f80acf1 100644
--- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
+++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
@@ -1222,7 +1222,7 @@ def _prepare_model_inputs(
         self.codec_model._prepare_cache_for_generation(
             generation_config=self.codec_model.generation_config,
             model_kwargs=temporary_model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=self.config.codec_config.sliding_window,
         )
diff --git a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py
index 03b442b2edbd..8541a911e947 100644
--- a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py
+++ b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py
@@ -357,7 +357,7 @@ def _prepare_model_inputs(
         self.codec_model._prepare_cache_for_generation(
             generation_config=self.codec_model.generation_config,
             model_kwargs=temporary_model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=self.config.codec_config.sliding_window,
         )
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 92d2d049ae39..bfcf1191dad3 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -1257,7 +1257,7 @@ def generate(
         self._prepare_cache_for_generation(
             generation_config,
             model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=max_cache_length,
         )
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index 8133d33bac47..c634ffe6598d 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -2172,7 +2172,7 @@ def generate(
         self._prepare_cache_for_generation(
             generation_config,
             model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=max_cache_length,
         )
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 5f1c592d3230..f3932137a082 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -1566,7 +1566,7 @@ def extend_enc_output(tensor, num_beams=None):
         self._prepare_cache_for_generation(
             generation_config,
             model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=input_ids.shape[0],
             max_cache_length=generation_config.max_length - 1,
         )

From 47580e49fc721cc738bf321f7581edaff7a11b08 Mon Sep 17 00:00:00 2001
From: Manuel de Prada Corral
Date: Mon, 1 Sep 2025 17:28:51 +0200
Subject: [PATCH 2/8] ops

---
 src/transformers/generation/utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 1cd949e46239..d99dcedca5b1 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2335,7 +2335,9 @@ def generate(
             generation_config, use_model_defaults, **kwargs
         )
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(assistant_model=assistant_model)
+        self._validate_assistant(
+            assistant_model, gen_mode_kwargs.get("tokenizer"), gen_mode_kwargs.get("prefix_allowed_tokens_fn")
+        )
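Worth flagging for review: `_validate_assistant`'s third positional parameter is `assistant_tokenizer`, yet the call restored above passes `gen_mode_kwargs.get("prefix_allowed_tokens_fn")` in that slot; patches 3 and 4 below iterate until the right value lands there. A toy illustration of why keyword arguments catch this class of slip:

    def validate(assistant_model, tokenizer, assistant_tokenizer):
        ...  # stand-in for the real checks

    model = tok = prefix_fn = assistant_tok = object()  # placeholders
    # positional call: prefix_fn silently binds to the assistant_tokenizer parameter
    validate(model, tok, prefix_fn)
    # keyword call: a misnamed argument would fail immediately with a TypeError
    validate(model, tok, assistant_tokenizer=assistant_tok)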
From 3b09223f84c88df2c8e47a477bf74b862e51bbdb Mon Sep 17 00:00:00 2001
From: Manuel de Prada Corral
Date: Mon, 1 Sep 2025 17:45:45 +0200
Subject: [PATCH 3/8] fix

---
 src/transformers/generation/utils.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index d99dcedca5b1..e864619cfbd8 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2172,17 +2172,17 @@ def _get_mode_processor_kwargs(
             "assistant_model": assistant_model,
             "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None),
         }
-        logits_processor_kwargs = {
-            "negative_prompt_ids": negative_prompt_ids,
-            "negative_prompt_attention_mask": negative_prompt_attention_mask,
-            "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn,
-        }
-        # If custom_generate is a callable, we need to extract its parameters
+        gen_mode_kwargs = {k: v for k, v in gen_mode_kwargs.items() if v is not None}
         if isinstance(custom_generate, Callable):
             usual_mode_kwargs = inspect.signature(GenerationMixin._sample).parameters.keys()
             custom_generate_kwargs = inspect.signature(custom_generate).parameters.keys()
             new_custom_keys = custom_generate_kwargs - usual_mode_kwargs
             gen_mode_kwargs = {k: kwargs.pop(k) for k in new_custom_keys if k in kwargs}
+        logits_processor_kwargs = {
+            "negative_prompt_ids": negative_prompt_ids,
+            "negative_prompt_attention_mask": negative_prompt_attention_mask,
+            "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn,
+        }
         return gen_mode_kwargs, logits_processor_kwargs
@@ -2335,9 +2335,7 @@ def generate(
             generation_config, use_model_defaults, **kwargs
         )
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(
-            assistant_model, gen_mode_kwargs.get("tokenizer"), gen_mode_kwargs.get("prefix_allowed_tokens_fn")
-        )
+        self._validate_assistant(assistant_model, gen_mode_kwargs.get("tokenizer"), prefix_allowed_tokens_fn)
@@ -2354,7 +2352,7 @@ def generate(
             stopping_criteria,
             prefix_allowed_tokens_fn,
             synced_gpus,
-            assistant_model,
+            gen_mode_kwargs.pop("assistant_model"),
             streamer,
             negative_prompt_ids,
             negative_prompt_attention_mask,
@@ -2529,6 +2527,7 @@ def generate(
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
+                **gen_mode_kwargs,
             )
         elif generation_mode == GenerationMode.ASSISTED_GENERATION:
             if generation_config.num_return_sequences > 1:

From 2a74c4ba9f9c8b5d1c744ca375c2a0c5536dd418 Mon Sep 17 00:00:00 2001
From: Manuel de Prada Corral
Date: Mon, 1 Sep 2025 17:55:15 +0200
Subject: [PATCH 4/8] ops

---
 src/transformers/generation/utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index e864619cfbd8..9605ab3c44f3 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2335,7 +2335,9 @@ def generate(
             generation_config, use_model_defaults, **kwargs
         )
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(assistant_model, gen_mode_kwargs.get("tokenizer"), prefix_allowed_tokens_fn)
+        self._validate_assistant(
+            assistant_model, gen_mode_kwargs.get("tokenizer"), gen_mode_kwargs.get("assistant_tokenizer")
+        )
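One wrinkle introduced in patch 3: after the new `{k: v ... if v is not None}` filter, `gen_mode_kwargs.pop("assistant_model")` in the deprecated-repo path has no key to pop whenever no assistant was provided. A minimal repro of the failure mode, and the safe variant patch 5 ends up using:

    gen_mode_kwargs = {k: v for k, v in {"assistant_model": None}.items() if v is not None}
    try:
        gen_mode_kwargs.pop("assistant_model")      # KeyError: the None entry was filtered out
    except KeyError:
        pass
    gen_mode_kwargs.pop("assistant_model", None)    # safe: falls back to None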
"streamer" in generation_mode_kwargs: + raise ValueError( + "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." ) - if not are_equal: - raise ValueError( - "The main model and the assistant don't have compatible encoder-dependent input shapes. " - "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper." - ) - doc_reference = ( - "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)" - ) - if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: - if assistant_tokenizer is not None: - raise ValueError( - f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}." - ) - else: - if tokenizer is None or assistant_tokenizer is None: - raise ValueError( - f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}." + if generation_mode == GenerationMode.ASSISTED_GENERATION: + assistant_model = generation_mode_kwargs.get("assistant_model") + if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: + attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] + attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check] + are_equal = all( + getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check ) + if not are_equal: + raise ValueError( + "The main model and the assistant don't have compatible encoder-dependent input shapes. " + "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper." + ) + + doc_reference = ( + "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)" + ) + if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: + if "assistant_tokenizer" in generation_mode_kwargs: + raise ValueError( + f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}." + ) + else: + if "tokenizer" not in generation_mode_kwargs or "assistant_tokenizer" not in generation_mode_kwargs: + raise ValueError( + f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}." + ) def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" @@ -2155,35 +2164,33 @@ def _get_deprecated_gen_repo( ) return repo - def _get_mode_processor_kwargs( + def _extract_generation_mode_kwargs( self, custom_generate, kwargs, + synced_gpus, assistant_model, - negative_prompt_ids, - negative_prompt_attention_mask, - prefix_allowed_tokens_fn, - ) -> tuple[dict[str, Any], dict[str, Any]]: + streamer, + ) -> dict[str, Any]: """ - Extracts and returns the generation mode and logit processor related keyword arguments from the provided kwargs. + Extracts and returns the generation mode related keyword arguments from the provided kwargs. 
""" - gen_mode_kwargs = { + generation_mode_kwargs = { "tokenizer": kwargs.pop("tokenizer", None), - "assistant_model": assistant_model, "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None), + "synced_gpus": synced_gpus, + "assistant_model": assistant_model, + "streamer": streamer, } - gen_mode_kwargs = {k: v for k, v in gen_mode_kwargs.items() if v is not None} + generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None} + # Custom_generate callables can have their own set of arguments + # To extract them, we compare the signature with the standard _sample method if isinstance(custom_generate, Callable): usual_mode_kwargs = inspect.signature(GenerationMixin._sample).parameters.keys() custom_generate_kwargs = inspect.signature(custom_generate).parameters.keys() new_custom_keys = custom_generate_kwargs - usual_mode_kwargs - gen_mode_kwargs = {k: kwargs.pop(k) for k in new_custom_keys if k in kwargs} - logits_processor_kwargs = { - "negative_prompt_ids": negative_prompt_ids, - "negative_prompt_attention_mask": negative_prompt_attention_mask, - "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn, - } - return gen_mode_kwargs, logits_processor_kwargs + generation_mode_kwargs = {k: kwargs.pop(k) for k in new_custom_keys if k in kwargs} + return generation_mode_kwargs @torch.no_grad() def generate( @@ -2322,25 +2329,22 @@ def generate( return custom_generate_function(model=self, **generate_arguments) # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode - gen_mode_kwargs, logits_processor_kwargs = self._get_mode_processor_kwargs( + generation_mode_kwargs = self._extract_generation_mode_kwargs( custom_generate, kwargs, + synced_gpus, assistant_model, - negative_prompt_ids, - negative_prompt_attention_mask, - prefix_allowed_tokens_fn, + streamer, ) generation_config, model_kwargs = self._prepare_generation_config( generation_config, use_model_defaults, **kwargs ) - self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant( - assistant_model, gen_mode_kwargs.get("tokenizer"), gen_mode_kwargs.get("assistant_tokenizer") - ) - generation_mode = generation_config.get_generation_mode(assistant_model) + self._validate_model_kwargs(model_kwargs.copy()) + self._validate_generation_mode(generation_mode, generation_mode_kwargs) + # Deprecation-related step: set Hub repo for deprecated strategies. # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode. # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps. 
@@ -2322,25 +2329,22 @@ def generate(
             return custom_generate_function(model=self, **generate_arguments)
 
         # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
-        gen_mode_kwargs, logits_processor_kwargs = self._get_mode_processor_kwargs(
+        generation_mode_kwargs = self._extract_generation_mode_kwargs(
             custom_generate,
             kwargs,
+            synced_gpus,
             assistant_model,
-            negative_prompt_ids,
-            negative_prompt_attention_mask,
-            prefix_allowed_tokens_fn,
+            streamer,
         )
         generation_config, model_kwargs = self._prepare_generation_config(
             generation_config, use_model_defaults, **kwargs
         )
-        self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(
-            assistant_model, gen_mode_kwargs.get("tokenizer"), gen_mode_kwargs.get("assistant_tokenizer")
-        )
-
         generation_mode = generation_config.get_generation_mode(assistant_model)
 
+        self._validate_model_kwargs(model_kwargs.copy())
+        self._validate_generation_mode(generation_mode, generation_mode_kwargs)
+
         # Deprecation-related step: set Hub repo for deprecated strategies.
         # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
         # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
@@ -2348,27 +2352,22 @@ def generate(
         # TODO joao, manuel: remove this in v4.62.0
         if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
             return GenerationMixin.generate(
                 self,
-                inputs,
-                generation_config,
-                logits_processor,
-                stopping_criteria,
-                prefix_allowed_tokens_fn,
-                synced_gpus,
-                gen_mode_kwargs.pop("assistant_model"),
-                streamer,
-                negative_prompt_ids,
-                negative_prompt_attention_mask,
-                use_model_defaults,
+                inputs=inputs,
+                generation_config=generation_config,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                assistant_model=assistant_model,
+                negative_prompt_ids=negative_prompt_ids,
+                negative_prompt_attention_mask=negative_prompt_attention_mask,
+                use_model_defaults=use_model_defaults,
                 custom_generate=deprecate_mode_repo,
                 trust_remote_code=trust_remote_code,
-                **gen_mode_kwargs,
+                **generation_mode_kwargs,
                 **kwargs,
             )
 
         # 2. Set generation parameters if not already defined
-        if synced_gpus is None:
-            synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
-
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -2442,7 +2441,7 @@ def generate(
         )
 
         if generation_config.token_healing:
-            input_ids = self.heal_tokens(input_ids, gen_mode_kwargs.get("tokenizer"))
+            input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
 
         if streamer is not None:
             streamer.put(input_ids.cpu())
@@ -2483,11 +2482,6 @@ def generate(
             generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
         )
 
-        if streamer is not None and generation_mode == GenerationMode.BEAM_SEARCH:
-            raise ValueError(
-                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
-            )
-
         if self.device.type != input_ids.device.type:
             warnings.warn(
                 "You are calling .generate() with the `input_ids` being on a device type different"
@@ -2504,15 +2498,17 @@ def generate(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
             encoder_input_ids=inputs_tensor,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             logits_processor=logits_processor,
             device=inputs_tensor.device,
             model_kwargs=model_kwargs,
-            **logits_processor_kwargs,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
         prepared_stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config,
             stopping_criteria=stopping_criteria,
-            tokenizer=gen_mode_kwargs.get("tokenizer"),
+            tokenizer=generation_mode_kwargs.get("tokenizer"),
         )
 
         # Set model_kwargs `use_cache` so we can use it later in forward runs
@@ -2526,10 +2522,8 @@ def generate(
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
+                **generation_mode_kwargs,
                 **model_kwargs,
-                **gen_mode_kwargs,
             )
         elif generation_mode == GenerationMode.ASSISTED_GENERATION:
             if generation_config.num_return_sequences > 1:
@@ -2555,10 +2549,10 @@ def generate(
             generation_config=generation_config,
             input_ids=input_ids,
             inputs_tensor=inputs_tensor,
-            assistant_model=assistant_model,
+            assistant_model=generation_mode_kwargs.pop("assistant_model", None),
             logits_processor=logits_processor,
-            target_tokenizer=gen_mode_kwargs.get("tokenizer"),
-            assistant_tokenizer=gen_mode_kwargs.get("assistant_tokenizer"),
+            target_tokenizer=generation_mode_kwargs.pop("tokenizer", None),
+            assistant_tokenizer=generation_mode_kwargs.pop("assistant_tokenizer", None),
             model_kwargs=model_kwargs,
         )
 
@@ -2569,8 +2563,7 @@ def generate(
             logits_processor=prepared_logits_processor,
             stopping_criteria=prepared_stopping_criteria,
             generation_config=generation_config,
-            synced_gpus=synced_gpus,
-            streamer=streamer,
+            **generation_mode_kwargs,
             **model_kwargs,
         )
 
@@ -2581,8 +2574,7 @@ def generate(
             logits_processor=prepared_logits_processor,
             stopping_criteria=prepared_stopping_criteria,
             generation_config=generation_config,
-            synced_gpus=synced_gpus,
-            streamer=streamer,
+            **generation_mode_kwargs,
             **model_kwargs,
         )
 
@@ -2593,7 +2585,7 @@ def generate(
             logits_processor=prepared_logits_processor,
             stopping_criteria=prepared_stopping_criteria,
             generation_config=generation_config,
-            synced_gpus=synced_gpus,
+            **generation_mode_kwargs,
             **model_kwargs,
         )
 
@@ -2716,7 +2708,7 @@ def _sample(
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
-        streamer: Optional["BaseStreamer"],
+        streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -3482,7 +3474,7 @@ def _assisted_decoding(
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
-        streamer: Optional["BaseStreamer"],
+        streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
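The defaults added to `streamer` (and, in patch 6, `synced_gpus`) are what make the `**generation_mode_kwargs` splatting above safe: the dict drops `None` entries, so every callee must tolerate missing keys. A condensed repro of the interaction:

    def run(input_ids, synced_gpus=False, streamer=None, **model_kwargs):
        return synced_gpus, streamer

    raw = {"streamer": None, "synced_gpus": True}
    generation_mode_kwargs = {k: v for k, v in raw.items() if v is not None}
    run([1, 2, 3], **generation_mode_kwargs)  # streamer falls back to its default None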
diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py
index cce1d7a9a5a8..3f471c231f6e 100644
--- a/src/transformers/models/dia/generation_dia.py
+++ b/src/transformers/models/dia/generation_dia.py
@@ -272,7 +272,7 @@ def _main_generate_loop(
             generation_config, use_model_defaults, **kwargs
         )
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(
+        self._validate_generation_mode(
             assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer
         )

From 494c9a8217125c9a2f4d2135935a989dc58a14f1 Mon Sep 17 00:00:00 2001
From: Manuel de Prada Corral
Date: Tue, 2 Sep 2025 16:12:38 +0200
Subject: [PATCH 6/8] fix

---
 src/transformers/generation/utils.py          | 9 ++++-----
 src/transformers/models/csm/generation_csm.py | 4 ++--
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 1531f7acee17..b62c12a8df92 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1503,8 +1503,7 @@ def _validate_generation_mode(self, generation_mode, generation_mode_kwargs):
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
             )
 
-        if generation_mode == GenerationMode.ASSISTED_GENERATION:
-            assistant_model = generation_mode_kwargs.get("assistant_model")
+        if (assistant_model := generation_mode_kwargs.get("assistant_model")) is not None:
             if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
                 attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
                 attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
@@ -2707,7 +2706,7 @@ def _sample(
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
-        synced_gpus: bool,
+        synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
@@ -3136,7 +3135,7 @@ def _beam_search(
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
-        synced_gpus: bool,
+        synced_gpus: bool = False,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -3473,7 +3472,7 @@ def _assisted_decoding(
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
-        synced_gpus: bool,
+        synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
diff --git a/src/transformers/models/csm/generation_csm.py b/src/transformers/models/csm/generation_csm.py
index 9c2f06e6562f..b14f353685c2 100644
--- a/src/transformers/models/csm/generation_csm.py
+++ b/src/transformers/models/csm/generation_csm.py
@@ -153,8 +153,8 @@ def _sample(
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
-        synced_gpus: bool,
-        streamer: Optional["BaseStreamer"],
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         """
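The walrus rewrite in the hunk above widens the check from "assisted mode" to "an assistant model is present in the kwargs", binding and testing in a single expression (PEP 572). A self-contained before/after sketch:

    generation_mode_kwargs = {}  # no assistant passed

    # before: bind, then test, in two statements
    assistant_model = generation_mode_kwargs.get("assistant_model")
    if assistant_model is not None:
        ...
    # after (PEP 572): bind and test in one expression
    if (assistant_model := generation_mode_kwargs.get("assistant_model")) is not None:
        ...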
From da72717eea6255736a78ece8228e92103e74933f Mon Sep 17 00:00:00 2001
From: Manuel de Prada Corral
Date: Tue, 2 Sep 2025 18:02:55 +0200
Subject: [PATCH 7/8] fix dia

---
 src/transformers/models/dia/generation_dia.py | 36 +++++++++----------
 1 file changed, 16 insertions(+), 20 deletions(-)

diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py
index 3f471c231f6e..439b498b0988 100644
--- a/src/transformers/models/dia/generation_dia.py
+++ b/src/transformers/models/dia/generation_dia.py
@@ -265,16 +265,20 @@ def _main_generate_loop(
     ):
         # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
-        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation
-
+        generation_mode_kwargs = self._extract_generation_mode_kwargs(
+            custom_generate,
+            kwargs,
+            synced_gpus,
+            assistant_model,
+            streamer,
+        )
         generation_config, model_kwargs = self._prepare_generation_config(
             generation_config, use_model_defaults, **kwargs
         )
+        generation_mode = generation_config.get_generation_mode(assistant_model)
+
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_generation_mode(
-            assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer
-        )
+        self._validate_generation_mode(generation_mode, generation_mode_kwargs)
 
         # 2. Set generation parameters if not already defined
         if synced_gpus is None:
@@ -310,7 +314,7 @@ def _main_generate_loop(
         )
 
         if generation_config.token_healing:
-            input_ids = self.heal_tokens(input_ids, tokenizer)
+            input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
 
         if streamer is not None:
             streamer.put(input_ids.cpu())
@@ -348,20 +352,11 @@ def _main_generate_loop(
             and not self.config.is_encoder_decoder
         ):
             max_cache_length += inputs_tensor.shape[1]
-
-        # 8. determine generation mode
-        generation_mode = generation_config.get_generation_mode(assistant_model=assistant_model)
-
         self._prepare_cache_for_generation(
             generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
         )
 
-        if streamer is not None and (generation_config.num_beams > 1):
-            raise ValueError(
-                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
-            )
-
-        # 9. prepare logits processors and stopping criteria
+        # 8. prepare logits processors and stopping criteria
         prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
@@ -374,7 +369,9 @@ def _main_generate_loop(
             negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
         prepared_stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria,
+            tokenizer=generation_mode_kwargs.get("tokenizer"),
         )
 
         # Set model_kwargs `use_cache` so we can use it later in forward runs
@@ -396,8 +393,7 @@ def _main_generate_loop(
             logits_processor=prepared_logits_processor,
             stopping_criteria=prepared_stopping_criteria,
             generation_config=generation_config,
-            synced_gpus=synced_gpus,
-            streamer=streamer,
+            **generation_mode_kwargs,
             **model_kwargs,
         )
     else:
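Net effect of the Dia fix: the model's `_main_generate_loop` now mirrors the preamble of the base `generate()` exactly. Condensed from the hunks above, the shared sequence is:

    # the standardized preamble shared by generate() and Dia's _main_generate_loop
    generation_mode_kwargs = self._extract_generation_mode_kwargs(
        custom_generate, kwargs, synced_gpus, assistant_model, streamer
    )
    generation_config, model_kwargs = self._prepare_generation_config(
        generation_config, use_model_defaults, **kwargs
    )
    generation_mode = generation_config.get_generation_mode(assistant_model)
    self._validate_model_kwargs(model_kwargs.copy())
    self._validate_generation_mode(generation_mode, generation_mode_kwargs)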
From cec61ab18170196c2f30d49c9be1137aadb658c6 Mon Sep 17 00:00:00 2001
From: Manuel de Prada Corral
Date: Wed, 3 Sep 2025 17:17:09 +0200
Subject: [PATCH 8/8] review

---
 src/transformers/generation/utils.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index b62c12a8df92..2716c79c702f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1493,11 +1493,6 @@ def compute_transition_scores(
         return transition_scores
 
     def _validate_generation_mode(self, generation_mode, generation_mode_kwargs):
-        if "synced_gpus" not in generation_mode_kwargs:
-            generation_mode_kwargs["synced_gpus"] = (
-                is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
-            ) and dist.get_world_size() > 1
-
         if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs:
             raise ValueError(
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
@@ -2177,10 +2172,13 @@ def _extract_generation_mode_kwargs(
         generation_mode_kwargs = {
             "tokenizer": kwargs.pop("tokenizer", None),
             "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None),
-            "synced_gpus": synced_gpus,
             "assistant_model": assistant_model,
             "streamer": streamer,
         }
+        if synced_gpus is not None:
+            generation_mode_kwargs["synced_gpus"] = (
+                is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+            ) and dist.get_world_size() > 1
         generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None}
         # Custom_generate callables can have their own set of arguments
         # To extract them, we compare the signature with the standard _sample method