Generate: fix logits processors doctests #29718
Conversation
cc @zucchini-nlp, to rebase your PRs after this gets merged :)
```diff
 ... '''
 ... if input_ids[-1] == entity[0]:
-...     return entity[1]
+...     return [entity[1].item()]
```
`prefix_allowed_tokens_fn` should be a `Callable[[int, torch.Tensor], List[int]]`, as explained in the docs
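For reference, a minimal sketch of a callable that satisfies this contract (the token ids and `vocab_size` below are illustrative, not taken from the PR):

```python
from typing import List

import torch

vocab_size = 50257  # illustrative value, not from the PR


def prefix_allowed_tokens_fn(batch_id: int, input_ids: torch.Tensor) -> List[int]:
    # Toy constraint: after token 42, only token 7 may be generated next.
    if input_ids[-1] == 42:
        return [7]
    # Otherwise, allow the full vocabulary.
    return list(range(vocab_size))


# outputs = model.generate(**inputs, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
```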
```python
>>> # distribution, summing to 1
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(torch.sum(torch.exp(outputs.scores[-1])))
tensor(816.3250)
```
This value was sensitive to numerical fluctuations across versions, and the exact value was not relevant for the test. The main point is that it is not approximately 1.0 :)
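A minimal sketch of how that point can be asserted without pinning a brittle value (assuming `outputs` comes from the `generate` call quoted above):

```python
import torch

# The last-step scores are raw logits here, so exp(scores) should NOT sum to ~1.0.
scores_sum = torch.sum(torch.exp(outputs.scores[-1]))
print(torch.allclose(scores_sum, torch.tensor(1.0), rtol=1e-4))
# Expected: False (the exact sum, e.g. ~816, varies across hardware and versions)
```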
```python
>>> # it can't generate an EOS token in the first iteration, but it can in the others.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(outputs.scores[1][0, 50256])  # 1 (and not 0) is the first freely generated token
>>> print(outputs.scores[0][0, 50256])
```
Whisper processor changes: @sanchit-gandhi, let me know if they make sense given the recent changes in Whisper.
Looks good to me - thanks for the update @gante!
```
sampled at their corresponding index. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).

Examples:
```
This processor is going to be removed in v4.40, so I didn't want to spend time fixing the test :D
```diff
-else:
-    generation_config = copy.deepcopy(generation_config)
+# 1. prepare generation config
+generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
```
This function from the main `generate` body (`_prepare_generation_config`) pulls generation parameterization from `kwargs` into `generation_config`.
Some Whisper-based doctests were incorrect without this functionality.
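Roughly, the idea is the following (a simplified sketch, not the actual implementation in `utils.py`):

```python
import copy


def _prepare_generation_config_sketch(self, generation_config, **kwargs):
    # Fall back to the model's own config when none is passed explicitly.
    if generation_config is None:
        generation_config = self.generation_config
    generation_config = copy.deepcopy(generation_config)
    # Fold generation-related kwargs (e.g. max_new_tokens=10) into the config;
    # `GenerationConfig.update` returns the kwargs it did not consume.
    model_kwargs = generation_config.update(**kwargs)
    return generation_config, model_kwargs
```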
amyeroberts left a comment
Thanks for working on this fix!
I have a few questions about the changes, in particular why we need to change the seed
```diff
-else:
-    generation_config = copy.deepcopy(generation_config)
+# 1. prepare generation config
+generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
```
The lines above imply there's a `self.generation_config` which should be used if `generation_config` is `None`
`self._prepare_generation_config()` does precisely that (see `transformers/src/transformers/generation/utils.py`, line 1204 at `484e10f`). It is a more complex version of this original if/else that preserves additional backward (and forward!) compatibility features of `generate` :)
```diff
 >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

->>> set_seed(0)
+>>> set_seed(1)
```
The seed is changed because the sampled output changed (more on that below), and a new seed was selected to illustrate the point of the example 🤗 I wanted a seed that produced a bad output in the unparameterized call and a good output in the parameterized call. Bear in mind that the model used in the examples is very small, and thus noisy with sampling.
The output of sampling can shift for many innocuous reasons: tiny numerical differences between library versions, reordering of operations, corrections in the architecture code, different RNG behavior in torch (unlikely), and so on. As I've written in the PR header, I don't think it's worth our time finding the exact cause; the results in most other sampling tests are unchanged, and pinning down the culprit would be time-consuming.
Examples:
```python
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset

>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")

>>> # This Whisper model forces the generation to start with `50362` at the first position by default, i.e.
>>> # `"forced_decoder_ids": [[1, 50362]]`. This means all other tokens are masked out.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(
...     all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
True
>>> print(outputs.scores[0][0, 50362])
tensor(0.)

>>> # If we disable `forced_decoder_ids`, we stop seeing that effect
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None)
>>> print(
...     all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
False
>>> print(outputs.scores[0][0, 50362])
tensor(19.3140)
```
Why remove the example here?
This processor is going to be removed in v4.40, so I didn't want to spend time fixing the test :D
:)
```python
>>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
False
```
The previous output was more informative imo - there are infinitely many ways to not be close to 1.
True, but it is beyond the scope of the example -- the key point here is that adding the flag normalizes the probability distribution.
Testing against the exact number caused the test to fail. In fact, if we run this test on different hardware (local compute vs DGX), we get a slightly different number. We could work around it with `torch.allclose`, but I don't think it adds value to the test :)
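For context, a small sketch of the behavior being discussed, assuming the flag in question is `renormalize_logits` (as in the `LogitNormalization` docstring) and reusing the `model`/`inputs` from the example:

```python
import torch

# With the flag enabled, the returned scores are log-probabilities, so exp(scores)
# should sum to ~1.0 at each step; without it, they are raw, unnormalized logits.
outputs = model.generate(
    **inputs, return_dict_in_generate=True, output_scores=True, renormalize_logits=True
)
print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.tensor(1.0), rtol=1e-4))
# Expected: True (the exact sum may still drift slightly across hardware/versions)
```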
```python
>>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
>>> # it can't generate an EOS token in the first iteration, but it can in the others.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(outputs.scores[1][0, 50256])  # 1 (and not 0) is the first freely generated token
```
out of interest - what changed here?
I believe the indexing of the first freely decoded token changed recently in Whisper, but I'd like @sanchit-gandhi to confirm the correctness of these changes :)
This might be a possible BC issue :/
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
amyeroberts left a comment
Thanks for working on this!
Happy with the changes - my only concern is the difference in the Whisper processor. @sanchit-gandhi, can you confirm this?
```python
>>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
False
```
```
sampled at their corresponding index. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).

Examples:
```
sanchit-gandhi left a comment
Thanks for the fixes @gante!
What does this PR do?
The doctests got stale 👀 (related PR to prevent this from happening again: #29716)
There are 2 main categories of fixes:
v4.35), but I don't think it's worth diving through to find the root cause, as many harmless things can change the output of sampling.

All tests are passing after these changes (`pytest --doctest-modules src/transformers/generation/logits_process.py -vv`).