Support generating with fallback for short form audio in Whisper #30984
kamilakesbi merged 72 commits into huggingface:main
sanchit-gandhi left a comment:
Looks like a good start @kamilakesbi. Two biggest suggestions are related to the designs of i) assisted generation, and ii) num return sequences. Think both can be simplified and assisted generation made more rigorous.
Two further design questions:
- Should we return the original `decoder_input_ids` and EOS tokens in the sequences for long-form generation as well? IMO it is an inconsistency that we return them for short-form but not long-form, and I would be in favour of unifying the two in this PR.
- Is it correct to de-activate beam search when `temperature>0`? We currently don't do this for long-form generation, but given the original Whisper repo does, it would be good to determine whether this is a 'bug' or an intended design decision.
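For context on the second question, here is a minimal sketch of a temperature fallback loop, modeled on how the original Whisper repo behaves (beam search only at temperature 0, sampling at higher temperatures). It is not the Transformers implementation: `decode_fn`, `DecodeResult`, and `decode_with_fallback` are hypothetical names for illustration, and the default thresholds are assumed to mirror the original repo's usual values.

```python
import zlib
from dataclasses import dataclass


@dataclass
class DecodeResult:
    """Hypothetical container for one decoding attempt."""
    text: str
    avg_logprob: float
    no_speech_prob: float = 0.0

    @property
    def compression_ratio(self) -> float:
        # Highly repetitive text compresses well, giving a large ratio.
        raw = self.text.encode("utf-8")
        return len(raw) / len(zlib.compress(raw))


def decode_with_fallback(
    decode_fn,
    temperatures=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    beam_size=5,
    compression_ratio_threshold=2.4,
    logprob_threshold=-1.0,
    no_speech_threshold=0.6,
):
    """Retry decoding at increasing temperatures until the result looks sane."""
    result = None
    for t in temperatures:
        options = {"temperature": t}
        if t == 0:
            # Beam search is only kept for the greedy (t == 0) attempt.
            options["beam_size"] = beam_size
        result = decode_fn(**options)

        needs_fallback = (
            result.compression_ratio > compression_ratio_threshold
            or result.avg_logprob < logprob_threshold
        )
        if result.no_speech_prob > no_speech_threshold:
            # Likely silence: accept the result instead of retrying.
            needs_fallback = False
        if not needs_fallback:
            break
    return result
```

In this sketch a repetitive or low-confidence result triggers a retry at the next temperature, and only the first attempt receives `beam_size`.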
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@ArthurZucker thanks for your review! I took your remarks into account :) Failing tests are unrelated to this PR. If this is ok for you, we can perhaps merge or wait for the CI to be green...
Let's wait for the full CI. Seems alright now!
Also a question not answered!
The CI is green, yes :) If it's ok for you I can merge!
ArthurZucker left a comment:
Thanks! Last two nits and you can merge!
What does this PR do?
The aim of this PR is to refactor the Whisper `generate` method to handle both short form and long form audio generation similarly. It will support short form audio generation with fallback (as requested in #29508).
Here's what I've done:
Removed previous short-form scripts:
I've removed the part of the code used for short form generation. This involves lines 562 to 603 and lines 498 to 505 in main. Now when a short form audio (or a batch of short form audios) is passed to `generate`, it's processed by the part of the code previously used for long form generation.
Use `is_shortform` to still distinguish between short form and long form in some cases:
In the `_postprocess_outputs` method, we only return `past_key_values` if the audios are short form. For long form audios it is too expensive (cf. this line).
In `_retrieve_max_frames_and_seek`: for long form audios, we necessarily need to pass an attention mask, but not for short form audios. We can thus compute `max_frames` and `seek` without relying on the attention mask for short form audios.
I've also updated the `split_by_batch_index` method: the previous method was broken when `return_dict_in_generate` was set to True for different short form audio cases. Now it handles both short form and long form audios.
I've removed the `is_shortform` parameter from the inputs to the `_retrieve_logit_processors` method to allow the use of `generation_config.no_speech_threshold` for short form audios.
I've removed the `is_shortform` parameter from the inputs to the `_set_return_outputs` method to allow the use of `logprob_threshold` for short form audios.
Make `num_return_sequences>1` compatible with `generate_with_fallback`:
`generate_with_fallback` can't handle `num_return_sequences>1` by design. I've added a new method, called `_expand_variables_for_generation`, which expands the different variables before passing them into `generate_with_fallback` when `generation_config.num_return_sequences>1`. After expansion, it sets `generation_config.num_return_sequences` to 1 for compatibility with `generate_with_fallback`.
Ensure that the output format for short form audio is compatible with the output format in main:
The output format for long-form audio is different from that for short-form audio. To ensure that the output matches what main produces when processing short form audio, we need to add a few post-processing steps. This is what is done in lines 721 to 765. In particular:
- We add the `EOS` token back to the output sequence, as it was removed during generation with fallback.
- We return the token timestamps in the correct format when `return_token_timestamps` is True (see here).
- When `return_dict_in_generate` is True, we use the new method `_stack_split_outputs` to get the output dict (containing all attributes such as scores, encoder_attentions, etc.) in the right format. `_stack_split_outputs` basically performs the opposite operations to `split_by_batch_index`.
Make failing slow tests pass:
Add new tests to make sure generation with fallback works for short form audios:
I've added two tests: `test_whisper_shortform_single_batch_prev_cond` and `test_whisper_shortform_multi_batch_hard_prev_cond`.
Who can review:
@sanchit-gandhi
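For reviewers, the expansion step described under "Make num_return_sequences>1 compatible with generate_with_fallback" can be pictured roughly as follows. This is a framework-free sketch using plain lists, not the actual `_expand_variables_for_generation` code: `GenConfig` and `expand_for_generation` are hypothetical stand-ins, and the choice of per-sample variables (`features`, `seek`, `max_frames`) is an assumption for illustration.

```python
import copy
from dataclasses import dataclass


@dataclass
class GenConfig:
    """Hypothetical stand-in for generation_config."""
    num_return_sequences: int = 1


def expand_for_generation(features, seek, max_frames, config):
    """Repeat every per-sample variable n times, then ask for 1 sequence each."""
    n = config.num_return_sequences
    if n <= 1:
        return features, seek, max_frames, config

    def repeat(items):
        # [a, b] with n=2 becomes [a, a, b, b]: copies of a sample stay adjacent.
        return [item for item in items for _ in range(n)]

    # Copy the config so the caller's object is left untouched.
    config = copy.copy(config)
    config.num_return_sequences = 1
    return repeat(features), repeat(seek), repeat(max_frames), config


features, seek, max_frames, config = expand_for_generation(
    ["audio_0", "audio_1"], [0, 0], [3000, 1500], GenConfig(num_return_sequences=2)
)
```

After expansion the batch holds one row per requested sequence, so the fallback loop can treat each row independently while still producing `n` candidates per input.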