Require input_ids for repetition penalty#45389

Draft
ruben-aghayan wants to merge 3 commits into huggingface:main from ruben-aghayan:fix-repetition-penalty-inputs-embeds

Conversation

@ruben-aghayan

@ruben-aghayan ruben-aghayan commented Apr 13, 2026

What does this PR do?

This PR adds a warning when the repetition penalty or the n-gram repetition penalty is used with decoder models that receive `inputs_embeds` without `input_ids`.

Previously, users could call `generate` with a repetition penalty while passing only `inputs_embeds`. Since no prompt tokens are available in that case, the penalty was applied only to the generated tokens, not to the prompt. An equivalent call that passed the tokens corresponding to those embeddings as `input_ids` would behave differently, applying the repetition penalty to the prompt tokens as well.
With this change, the repetition penalty is not applied in the embeddings-only case and a warning is shown.
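To illustrate the behavior difference the PR describes, here is a minimal pure-Python sketch of the repetition-penalty rule (mirroring the rule used by transformers' `RepetitionPenaltyLogitsProcessor`: positive scores are divided by the penalty, negative ones multiplied). The function and values are illustrative, not the library's actual implementation.

```python
def apply_repetition_penalty(scores, seen_token_ids, penalty):
    """Penalize the logits of tokens that already appear in the sequence.

    Positive scores are divided by the penalty, negative scores multiplied,
    so repeated tokens always become less likely.
    """
    out = list(scores)
    for t in set(seen_token_ids):
        out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
    return out

# With input_ids available, the prompt tokens (here ids 0 and 2) are penalized.
with_prompt = apply_repetition_penalty([1.0, 1.0, -1.0], seen_token_ids=[0, 2], penalty=2.0)

# With inputs_embeds only, there are no prompt token ids, so nothing is penalized
# until generation produces new tokens.
without_prompt = apply_repetition_penalty([1.0, 1.0, -1.0], seen_token_ids=[], penalty=2.0)
```

The two calls return different scores for the same underlying prompt, which is exactly the inconsistency the PR warns about.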

Testing

pytest tests/generation -vv -k 'not test_text_streamer_decode_kwargs'
`test_text_streamer_decode_kwargs` was giving an unrelated failure.

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@afurm

afurm commented Apr 13, 2026

Does prompt_input_ids.get() (or a follow-up check) need to handle the case where it's a list rather than a tensor? If input_ids is passed as a plain Python list, isinstance(..., torch.Tensor) would be False and this would raise the error even for valid input.

@ruben-aghayan ruben-aghayan force-pushed the fix-repetition-penalty-inputs-embeds branch from d1d35b7 to 3a4294c on April 13, 2026 05:46
@ruben-aghayan ruben-aghayan changed the title from "Guard repetition penalty for inputs_embeds" to "Require input_ids for repetition penalty" on Apr 13, 2026
@ruben-aghayan
Author

Does prompt_input_ids.get() (or a follow-up check) need to handle the case where it's a list rather than a tensor? If input_ids is passed as a plain Python list, isinstance(..., torch.Tensor) would be False and this would raise the error even for valid input.

Thank you for your comment!

Is a list considered valid input? The `generate` args are all tensors. Admittedly, `input_ids` goes in through kwargs, so it is not explicitly typed. But even today, such an input would fail, e.g.

  from transformers import AutoTokenizer, AutoModelForCausalLM

  tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
  model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

  ids = tok("Hello world", return_tensors="pt").input_ids[0].tolist()
  model.generate(input_ids=ids)

produces `AttributeError: 'list' object has no attribute 'shape'`, since the list is treated as a tensor.
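A minimal sketch of why the list fails: `generate` reads tensor attributes such as `.shape`, which a plain list does not provide. The conventional fix (wrapping the list in a batched tensor, assuming torch is available as it is wherever transformers runs) is noted in a comment rather than executed here.

```python
ids = [101, 2023, 102]  # a plain Python list of token ids, as in the failing call above

# generate() treats input_ids as a tensor, e.g. reading input_ids.shape,
# which a list does not have:
print(hasattr(ids, "shape"))  # False

# The conventional fix is to wrap the list in a batched tensor first, e.g.
#   input_ids = torch.tensor([ids])
# before calling model.generate(input_ids=input_ids).
```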

@Rocketknight1
Member

@remi-or @McPatate for generation/CB, but feel free to pass it on to someone else if you're not comfortable reviewing it!

@remi-or
Collaborator

remi-or commented Apr 14, 2026

I am unfamiliar with encoders and `inputs_embeds`, so I would prefer that this be passed on. If no one picks it up, I will when I have some room!

@Rocketknight1
Member

cc @Cyrilvallez for generation maybe, but if you're overloaded we might need to find someone to own generation code!

@Cyrilvallez
Member

Hey @ruben-aghayan! We can indeed raise in such cases, but this code should live inside _get_logits_processor!

@ruben-aghayan ruben-aghayan marked this pull request as draft April 25, 2026 01:44
@ruben-aghayan ruben-aghayan force-pushed the fix-repetition-penalty-inputs-embeds branch 3 times, most recently from b84bcc7 to 1721159 on April 25, 2026 04:07
@ruben-aghayan ruben-aghayan force-pushed the fix-repetition-penalty-inputs-embeds branch from 1721159 to 08ac3d8 on April 25, 2026 04:11
@ruben-aghayan
Author

Hey @ruben-aghayan! We can indeed raise in such cases, but this code should live inside _get_logits_processor!

thanks + done

I noticed that `EncoderRepetitionPenaltyLogitsProcessor` above just warns, so I switched to a warning (this would have been enough for my use case). I also extended it to `NoRepeatNGramLogitsProcessor`.

@ruben-aghayan ruben-aghayan marked this pull request as ready for review April 25, 2026 04:15
@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45389&sha=036192

@ruben-aghayan
Author

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45389&sha=036192

looks unrelated?

Comment on lines +1092 to +1100
        inputs_embeds = model_kwargs.get("inputs_embeds") if model_kwargs is not None else None
        if inputs_embeds is not None and (input_ids_seq_length is None or input_ids_seq_length == 0):
            warnings.warn(
                "Passing `repetition_penalty` requires some form of `input_ids` to be passed to "
                "`generate`, ignoring the argument.",
                UserWarning,
            )
        else:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
Member


We don't want to skip the processor, only warn that the repetition penalty will apply to the new tokens only, rather than to the full sequence including the prompt.
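The reviewer's preferred behavior can be sketched in plain Python: keep penalizing whatever tokens are known (the generated ones), and warn that the unknown prompt tokens are not covered, instead of dropping the processor entirely. The helper function and values below are illustrative, not the library's code.

```python
import warnings

def penalize(scores, token_ids, penalty):
    # Repetition-penalty rule: divide positive scores, multiply negative ones.
    out = list(scores)
    for t in set(token_ids):
        out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
    return out

prompt_ids = []        # unknown: the prompt came in as inputs_embeds
generated_ids = [1]    # tokens produced so far during decoding
scores = [2.0, 2.0, 2.0]

# Warn, but still penalize the tokens we do know about (the generated ones).
warnings.warn("prompt tokens unknown; repetition penalty covers generated tokens only")
scores = penalize(scores, prompt_ids + generated_ids, penalty=2.0)
```

Only index 1 (the generated token) is penalized; the prompt's contribution is silently absent, which is what the warning communicates to the user.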
