🚨 Generation config defaults are now None #42702

Merged
zucchini-nlp merged 24 commits into huggingface:main from zucchini-nlp:generation-config-defaults
Dec 18, 2025

Conversation

@zucchini-nlp
Member

@zucchini-nlp zucchini-nlp commented Dec 8, 2025

What does this PR do?

As per the title: this should have been done a long time ago, but we couldn't break BC. The current impl breaks BC only halfway, i.e. the generation loop is not affected and will keep using the old defaults. The biggest difference is for users who init, access, or modify the model's generation config directly:

# `0` before this PR, `None` after the PR
print(model.generation_config.no_repeat_ngram_size)
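Downstream code that relied on the old concrete defaults can set them explicitly instead of reading them from the model. A minimal sketch, assuming the values below are the former hard-coded defaults and nothing is read from a checkpoint:

from transformers import GenerationConfig

# Restore the old behavior explicitly instead of relying on
# model.generation_config, whose unset fields are now None.
explicit_config = GenerationConfig(no_repeat_ngram_size=0, do_sample=False, num_beams=1)
print(explicit_config.no_repeat_ngram_size)  # 0, as before this PR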

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread src/transformers/generation/candidate_generator.py
Comment thread src/transformers/generation/configuration_utils.py
Comment thread src/transformers/generation/configuration_utils.py
Comment thread src/transformers/generation/watermarking.py
Comment on lines +1069 to +1074
generation_params = {}
default_config = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
for key in GenerationConfig._get_default_generation_params().keys():
if hasattr(self, key) and getattr(self, key) is not None and key not in default_config:
generation_params[key] = getattr(self, key)

Member Author

could have been simplified because we no longer have any generation params in model.config

Contributor

Just for my understanding: the generation config is top-level either way, hence we no longer need to distinguish submodels etc. in composite models (which could possibly have different config values here)?

Member Author

@zucchini-nlp zucchini-nlp Dec 11, 2025

Yeah, that, and also because these lines are more of a workaround for old models (e.g. bart). New models don't have any generation params in the model config anyway; we haven't allowed it for quite a long time.

Contributor

Gotcha, makes sense to me

Comment thread tests/models/reformer/test_modeling_reformer.py
@zucchini-nlp zucchini-nlp requested a review from vasqu December 10, 2025 12:39
@zucchini-nlp
Member Author

Ready for review!

Contributor

@vasqu vasqu left a comment

Let's still add an 🚨 even if it's not completely breaking; I'd rather be safe than sorry here. We never know.

First round of comments; my biggest issue would be the kwargs vs. generation config passing. But you also left a note there.

Comment thread src/transformers/configuration_utils.py Outdated

Comment thread src/transformers/generation/candidate_generator.py Outdated
Comment thread src/transformers/generation/configuration_utils.py
Comment thread src/transformers/generation/configuration_utils.py
Comment thread src/transformers/generation/utils.py Outdated
Comment thread src/transformers/generation/watermarking.py Outdated
Comment thread src/transformers/models/whisper/generation_whisper.py
Comment thread tests/utils/test_cache_utils.py
zucchini-nlp and others added 2 commits December 11, 2025 11:54
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
@zucchini-nlp zucchini-nlp changed the title Generation config defaults are now None 🚨 Generation config defaults are now None Dec 11, 2025
@zucchini-nlp
Member Author

run-slow: bart, csm, dia, encoder_decoder, musicgen, rag, reformer, speech_encoder_decoder, vision_encoder_decoder, whisper

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/bart", "models/csm", "models/dia", "models/encoder_decoder", "models/musicgen", "models/rag", "models/reformer", "models/speech_encoder_decoder", "models/vision_encoder_decoder", "models/whisper"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@zucchini-nlp
Member Author

@vasqu requesting another review :)

@zucchini-nlp zucchini-nlp requested a review from vasqu December 11, 2025 12:23
Contributor

@vasqu vasqu left a comment

LGTM, just a few last nits


-        elif self.num_beams == 1:
-            if self.do_sample is False:
+        elif self.num_beams is None or self.num_beams == 1:
+            if self.do_sample is not True:
Contributor

Looks like it was missed here?
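
For context, a hedged illustration of the None-aware check being discussed (standalone simplified code, not the actual library code):

num_beams = None   # old default was 1; unset is now None
do_sample = None   # old default was False; unset is now None

# Treat the unset (None) value like the old default when picking the decoding mode.
if num_beams is None or num_beams == 1:   # previously `num_beams == 1`
    if do_sample is not True:             # previously `do_sample is False`
        mode = "greedy search"

print(mode)  # "greedy search"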

Comment thread src/transformers/generation/utils.py Outdated
Comment on lines +1768 to +1769
# user-defined kwargs or `generation_config` > `self.generation_config` > global default values
# NOTE: doesn't make sense to allow kwargs and `generation_config`. Might be strict and make them mutually exclusive?
Contributor

Fair enough, but maybe we can break for v5? Not super important, but it gives us a good opportunity to do so.

Either way, let's upgrade this to a TODO (as well).
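
For reference, a hedged sketch of the precedence order stated in the quoted comment (a hypothetical resolve helper, not the library's merge code):

def resolve(user_value, saved_value, global_default):
    # user-defined kwargs / `generation_config` > model's `generation_config` > global defaults
    if user_value is not None:
        return user_value
    if saved_value is not None:
        return saved_value
    return global_default

print(resolve(None, 20, 50))  # 20: the model's saved value wins when the user passes nothing
print(resolve(5, 20, 50))     # 5: the explicit user value wins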

Comment thread tests/generation/test_configuration_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py
@zucchini-nlp
Member Author

Will merge when CI is fixed on main

@albertvillanova
Member

albertvillanova commented Dec 15, 2025

Hi,

First of all, thanks for your quick reaction and for addressing the underlying issue.

On the other hand, I did a quick pass over the generation code, and I think there is a subtle semantic point worth double-checking: None is not always just "unset", it can also mean "disable this behavior".

Concretely, top_k=None disables top-k filtering entirely and allows all tokens. Because of that, we now have two different concepts that can both be represented as "None":

  • Use the global/default value for top_k (e.g. 50)
  • Explicitly disable top-k filtering (top_k=None or 0) even if the model's generation config or the global defaults say otherwise

Therefore, a situation can occur during training where:

  • the user-provided generation_config has top_k=None (intending to disable top-k filtering),
  • the model’s own generation_config has a non-None value for top_k,
  • merging logic currently preserves the model’s value instead of respecting the explicit None.

In that scenario, None is not just "unset"; it is a meaningful instruction ("don't apply top-k filtering"). If so, the merge semantics may need refinement to avoid unintentionally re-enabling filtering (e.g. by using a sentinel value instead).
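
A minimal sketch of the ambiguity described above (hypothetical dicts, not the library's actual merge logic):

user_config = {"top_k": None}   # reading 1: "use the default"; reading 2: "disable top-k filtering"
saved_config = {"top_k": 40}    # value stored in the model's generation config

# A merge that treats None purely as "unset" keeps the saved value,
# silently re-enabling filtering if the user meant "disable".
merged = {key: saved_config.get(key) if value is None else value for key, value in user_config.items()}
print(merged)  # {'top_k': 40}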

I'll continue digging into this, but flagging it early so we can discuss it before merging.

@zucchini-nlp
Copy link
Copy Markdown
Member Author

@albertvillanova I think if top_k=None in the generation config, it is the same as if the user did not pass any top_k in kwargs. A value not being set (None) does not specifically mean that the user is requesting not to use it, so users would need to explicitly unset it with top_k=0 if the model has saved a different value.

Unfortunately, we have no way to know 100% what users want when they set values to None in the current code. The only way would be for us to not update the generation config with the model's defaults when users pass my_generation_config. But that would be much more breaking and would require users to always create a custom config from model.generation_config.

For example, if everyone prepared custom configs as below, we could fix your issue. I'm afraid that's not the case for most users:

my_generation_config = model.generation_config
my_generation_config.top_k = None
model.generate(inputs, generation_config=my_generation_config)

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: bart, csm, dia, encoder_decoder, musicgen, rag, reformer, speech_encoder_decoder, vision_encoder_decoder, whisper

@vasqu
Contributor

vasqu commented Dec 17, 2025

I think we can merge? @zucchini-nlp

@zucchini-nlp zucchini-nlp merged commit a81e04a into huggingface:main Dec 18, 2025
25 checks passed
modified_values = {}
global_default_generation_config = GenerationConfig()
model_generation_config = self.generation_config
# we iterate over the model's generation config: it may hold custom keys, which we'll want to copy
Contributor

@ebezzam ebezzam Dec 19, 2025

@vasqu, @zucchini-nlp are custom keys no longer copied over?

Seems like only the ones here are defaulted at this line, which wouldn't copy custom keys from self.generation_config anymore:

global_defaults = self.generation_config._get_default_generation_params()
generation_config.update(**self.generation_config.to_dict(), defaults_only=True)
generation_config.update(**global_defaults, defaults_only=True)

Contributor

Something like this just after?

# add custom keys not in global defaults
for key, value in self.generation_config.to_dict().items():
    if not hasattr(generation_config, key):
        setattr(generation_config, key, value)
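
For illustration, a hypothetical example of the kind of custom key such a loop would preserve (assuming GenerationConfig keeps unknown kwargs as extra attributes):

from transformers import GenerationConfig

custom = GenerationConfig(my_task_specific_flag=True)  # hypothetical custom key, not a global default
print(hasattr(custom, "my_task_specific_flag"))        # True: the kind of key the loop above copies over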

SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
* this way better maybe?

* delete legacy from bart and mvp

* import not found

* fix some tests

* fix more tests

* revert smth to run tests again

* I thought I fixed it already, but there were more models

* commit and check tests, clean-up later

* assisted decoding should work now

* docs and whisper

* fix a few more tests

* no circular import errors pls

* wording

* add a test for defaults following TRL example

* nit

* Update src/transformers/configuration_utils.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

* comments

* final fix tests

* more comments

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
@ebezzam ebezzam mentioned this pull request Feb 4, 2026