Run the code below to get 5 hypotheses from Beam Search on audio transcription:
```python
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import torch
import librosa

# Load the processor and model
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")

# Load and preprocess the audio file
audio_path = "audio.mp3"
audio, sr = librosa.load(audio_path, sr=16000)  # Ensure the sample rate is 16 kHz

# Preprocess the audio to get the input features
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

# Generate the transcription using Beam Search with the model
beam_outputs = model.generate(
    inputs["input_features"],
    num_beams=5,             # Number of beams
    num_return_sequences=5,  # Number of hypotheses to return
    early_stopping=True,
    output_scores=True,
    return_dict_in_generate=True,
)

# Decode the generated transcriptions
hypotheses = [
    processor.decode(output_ids, skip_special_tokens=True)
    for output_ids in beam_outputs.sequences
]

# Print out the hypotheses
for i, hypothesis in enumerate(hypotheses):
    print(f"Hypothesis {i+1}: {hypothesis}. Score: {beam_outputs.sequences_scores[i]}")
```
Expected behavior
Together with @ylacombe, we identified that after Pull Request #30984, Whisper Beam Search generation no longer works as intended. See the more detailed discussion in Pull Request #32970.
The code above should return 5 distinct hypotheses, following the core principle of Beam Search: keep the `num_beams` best candidate sequences at each step, in a top-k fashion. Instead, we get the single highest-probability result repeated. See below for how Beam Search worked in version v4.25.1 and how it works now.
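To make the expected behavior concrete, here is a minimal, library-free sketch of the beam search principle described above (toy token log-probabilities, not real Whisper logits): at each step, every beam is extended by every candidate token and only the `num_beams` highest-scoring sequences survive, so the final result is a set of distinct ranked hypotheses.

```python
import math

def toy_beam_search(log_probs_per_step, num_beams):
    """Minimal beam search over a fixed table of per-step token log-probs.

    log_probs_per_step: list of dicts mapping token -> log-prob for each step
    (a stand-in for model logits). Returns the num_beams best hypotheses as
    (token_list, total_log_prob) pairs, sorted best-first.
    """
    beams = [([], 0.0)]  # (token sequence, cumulative log-prob)
    for step_log_probs in log_probs_per_step:
        candidates = [
            (tokens + [tok], score + lp)
            for tokens, score in beams
            for tok, lp in step_log_probs.items()
        ]
        # Keep the num_beams highest-scoring candidates: this is what should
        # produce several distinct hypotheses, not one result repeated.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams

# Toy two-step vocabulary (probabilities are made up for illustration)
steps = [
    {"Thank": math.log(0.7), "How": math.log(0.3)},
    {"you": math.log(0.6), "Q": math.log(0.3), "cute": math.log(0.1)},
]
for tokens, score in toy_beam_search(steps, num_beams=3):
    print(" ".join(tokens), round(score, 3))
```

Running this prints three different hypotheses with decreasing scores, which is the shape of output the Whisper snippet above should also produce.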
transformers v4.25.1
Hypothesis 1: How is Mozilla going to handle and be with this? Thank you.. Score: -0.4627407491207123
Hypothesis 2: How is Mozilla going to handle and be with this? Thank you and Q.. Score: -0.4789799749851227
Hypothesis 3: How is Mozilla going to handle and be with this? Thank you, and cute.. Score: -0.48414239287376404
Hypothesis 4: How is Mozilla going to handle and be with this? Thank you and cute.. Score: -0.4972183108329773
Hypothesis 5: How is Mozilla going to handle and be with this? Thank you, and Q.. Score: -0.5054414868354797
transformers v4.44.1 + My Fix from #32970
Hypothesis 1: How is Mozilla going to handle and be with this? Thank you.. Score: -0.5495038032531738
Hypothesis 2: How is Mozilla going to handle and be with this? Thank you.. Score: -0.5495040416717529
Hypothesis 3: How is Mozilla going to handle and be with this? Thank you.. Score: -0.5495036840438843
Hypothesis 4: How is Mozilla going to handle and be with this? Thank you.. Score: -0.5495036244392395
Hypothesis 5: How is Mozilla going to handle and be with this? Thank you.. Score: -0.5495033264160156
@ylacombe has found the bug in the `_expand_variables_for_generation` function. The function artificially expands the batch size to `num_return_sequences`, which causes an issue when this expanded batch is passed to `GenerationMixin.generate`. Specifically, if `batch_size=5` and `num_return_sequences > 1`, the model generates `batch_size * num_beams` beams but retains only the most probable beam for each element of the original batch.
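A toy reproduction of this selection mistake (plain Python, not the actual `transformers` code, with made-up scores) shows why every returned hypothesis ends up identical: duplicating the input `num_return_sequences` times and then keeping only the top beam of each identical copy discards the 2nd..Nth beams entirely.

```python
num_return_sequences = 5
num_beams = 5

# Pretend the model produced num_beams scored beams for one audio input,
# sorted best-first (scores are invented for illustration).
beams = [(f"hypothesis {i}", -0.46 - 0.01 * i) for i in range(num_beams)]

# Buggy path: the input is duplicated num_return_sequences times, and only
# the single best beam of each (identical) copy is kept.
buggy = [beams[0] for _ in range(num_return_sequences)]

# Intended path: keep the num_return_sequences best beams of the one input.
intended = beams[:num_return_sequences]

print([h for h, _ in buggy])     # the same hypothesis repeated 5 times
print([h for h, _ in intended])  # 5 distinct hypotheses
```

This mirrors the two output listings above: the buggy path yields one transcription five times, while the intended path yields five ranked alternatives.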
Impact
This bug makes the `num_return_sequences` parameter incompatible with both short-form and long-form generation: users expecting multiple return sequences receive only the most probable sequence, repeated, which defeats the purpose of the parameter.

cc @eustlb
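Until a fix lands, affected users can detect the collapsed-beam symptom with a quick sanity check on the decoded hypotheses from the reproduction snippet above (this is a hypothetical helper, not part of `transformers`):

```python
def beams_look_collapsed(hypotheses):
    """Return True if beam search appears to have collapsed: more than one
    hypothesis was requested but all decoded strings are identical."""
    return len(hypotheses) > 1 and len(set(hypotheses)) == 1

# Symptom reported in this issue: five copies of the same transcription.
print(beams_look_collapsed(["How is Mozilla going to handle this? Thank you."] * 5))

# Healthy output (as in v4.25.1): distinct hypotheses.
print(beams_look_collapsed(["Thank you.", "Thank you and Q.", "Thank you, and cute."]))
```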
Who can help?
@ylacombe @eustlb