Require input_ids for repetition penalty #45389
ruben-aghayan wants to merge 3 commits into huggingface:main from
Conversation
Force-pushed from d1d35b7 to 3a4294c
Thank you for your comment! Is List considered valid input? Generate args are all tensors. Admittedly, `input_ids` goes in kwargs so it is not explicitly typed. But even today, such an input would fail, e.g. produces
I am unfamiliar with the encoder and the inputs_embeds path, so I would prefer if this can be passed on. If no one picks this up, I will when I have some room!
cc @Cyrilvallez for generation maybe, but if you're overloaded we might need to find someone to own the generation code!
Hey @ruben-aghayan! We can indeed raise in such cases, but this code should live inside
Force-pushed from b84bcc7 to 1721159
Force-pushed from 1721159 to 08ac3d8
Thanks + done. I noticed
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45389&sha=036192
Looks unrelated?
```python
inputs_embeds = model_kwargs.get("inputs_embeds") if model_kwargs is not None else None
if inputs_embeds is not None and (input_ids_seq_length is None or input_ids_seq_length == 0):
    warnings.warn(
        "Passing `repetition_penalty` requires some form of `input_ids` to be passed to "
        "`generate`, ignoring the argument.",
        UserWarning,
    )
else:
    processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
```
We don't want to skip, only warn that the repetition penalty will apply only to the new tokens, vs applying it to the full sequence including the prompt.
What does this PR do?
This PR warns when using the repetition penalty or n-gram repetition penalty in decoder models called with `inputs_embeds` but no `input_ids` argument.
Previously, users could pass `repetition_penalty` to `generate` calls that supplied only `inputs_embeds`. Since embeddings carry no token ids, the repetition penalty was not applied to the prompt, only to the generated tokens.
An equivalent call (i.e. passing the tokens corresponding to those embeddings) would behave differently, applying the repetition penalty to the prompt tokens as well.
This change makes it so that the repetition penalty is not applied in that case and a warning is shown.
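To make the inconsistency concrete, here is a minimal standalone sketch of the penalty logic (a hypothetical reimplementation for illustration, not the actual `RepetitionPenaltyLogitsProcessor` code): when the prompt arrives as `inputs_embeds`, no prompt token ids are available, so the prompt contributes nothing to the set of penalized tokens and the scores come out different from the `input_ids` path.

```python
def apply_repetition_penalty(logits, seen_token_ids, penalty):
    # Sketch of the standard repetition-penalty rule: scores of
    # already-seen tokens are divided by `penalty` when positive and
    # multiplied by `penalty` when negative (pushing both toward
    # "less likely").
    out = list(logits)
    for t in set(seen_token_ids):
        out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
    return out

logits = [2.0, -1.0, 0.5]

# Prompt passed as input_ids: tokens 0 and 1 are known and penalized.
with_prompt = apply_repetition_penalty(logits, [0, 1], penalty=2.0)
# -> [1.0, -2.0, 0.5]

# Prompt passed only as inputs_embeds: no token ids are known, so
# nothing from the prompt is penalized and the scores are unchanged.
without_prompt = apply_repetition_penalty(logits, [], penalty=2.0)
# -> [2.0, -1.0, 0.5]
```

The two calls describe the same prompt but produce different scores, which is the divergence this PR warns about.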
Testing
pytest tests/generation -vv -k 'not test_text_streamer_decode_kwargs'
test_text_streamer_decode_kwargs was giving an unrelated failure
Who can review?
@gante