Fix whisper return_language with return_timestamp=word #39938

Metric-Void wants to merge 5 commits into huggingface:main

Conversation
@Metric-Void thanks for the PR! Does the same example in #39404 (below) now return the expected timestamps and language? Could you share the output?

```python
import torch
from transformers import pipeline
from transformers.configuration_utils import PretrainedConfig

pipeline = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny",
    torch_dtype=torch.float16,
    config=PretrainedConfig(
        attn_implementation="flash_attention_2"
    )
)

result = pipeline("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac", return_language=True, return_timestamps='word')
result["chunks"]
```

Regarding errors, the ones you are getting are related to missing libraries. When I run the tests, I get the following:

```
# pytest tests/models/whisper
==== short test summary info ====
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelTest::test_multi_gpu_data_parallel_forward - TypeError: EncoderDecoderCache.__init__() missing 1 required positional argument: 'cross_attention_cache'
==== 1 failed, 467 passed, 285 skipped, 37 warnings in 418.89s (0:06:58) ====
```

It is consistent before and after your changes, so you haven't introduced any failing tests 👍 I would still wait for @eustlb's input on how to adjust Whisper's generate code.
@ebezzam Yes, here's the output. #39404 was mine, so it only makes sense if it fixes that. More tests are in https://gist.github.com/Metric-Void/79f7fcecc432d0e648af0fd896b5016a. Though it seems that Whisper (at least the tiny model) does not predict additional language tokens when the language changes. For long audios, I'm not sure if I should add tests for this use case; there was such a test, but it was removed later (see transformers/tests/pipelines/test_pipelines_automatic_speech_recognition.py, lines 369 to 421 at b31d595).
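For reference, with return_language=True and return_timestamps='word' the pipeline output under discussion has roughly this shape (an illustrative sketch only; the words, timestamps, and language values below are made up, and the real outputs are in the gist above):

```python
# Illustrative shape only; the actual output is in the linked gist.
result = {
    "text": " I have a dream ...",
    "chunks": [
        # one entry per word, each carrying a timestamp and a language
        {"text": " I", "timestamp": (0.0, 0.6), "language": "english"},
        {"text": " have", "timestamp": (0.6, 0.9), "language": "english"},
        # ...
    ],
}
```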
ebezzam left a comment
Thanks @Metric-Void for sharing the outputs and tests!
Could you add some of your tests to test_modeling_whisper.py so that we don't run into this problem again? Thanks 👍
[For maintainers] Suggested jobs to run (before merge): run-slow: whisper
I've added tests to test_pipelines_automatic_speech_recognition.py, since this feature depends on being called through the pipeline; that's also where the test originally lived. I've also added comments explaining why there are two tokens. A rough sketch of what such a test looks like is shown below.
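As an illustration only (this is a hypothetical sketch, not the test added in this PR; the test name and assertions are assumptions):

```python
# Hypothetical sketch, not the actual test from this PR.
import unittest

from transformers import pipeline
from transformers.testing_utils import slow


class WhisperReturnLanguageSketch(unittest.TestCase):
    @slow
    def test_return_language_with_word_timestamps(self):
        asr = pipeline(
            task="automatic-speech-recognition",
            model="openai/whisper-tiny",
        )
        result = asr(
            "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac",
            return_language=True,
            return_timestamps="word",
        )
        for chunk in result["chunks"]:
            # each word-level chunk should carry both a timestamp and a language
            self.assertIn("timestamp", chunk)
            self.assertIn("language", chunk)
```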
eustlb left a comment
Hey @Metric-Void, thanks for the work! 🤗
Actually, adding such a parameter isn’t necessary since the decoder input ids can be retrieved from tokens['segments'][0][0]['result']['sequences']. I’m strongly against adding it, as a lot of effort and thorough testing already went into fixing the Whisper generation logic and ensuring a 1-to-1 correspondence with the OAI implementation.
As you noticed, language changes aren’t detected because only the first 30 seconds of the input are used for language detection. Would you mind reworking the logic to remove changes to generation_whisper.py and instead handle the decoder input IDs directly as mentioned above?
If you prefer, I can also quickly open a PR to supersede this one and add you as a co-author.
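For concreteness, a minimal sketch of the suggested direction (the helper name below is hypothetical; it only assumes the tokens['segments'][0][0]['result']['sequences'] structure mentioned above):

```python
# Hypothetical helper illustrating the suggestion above; not library API.
def get_decoder_input_ids(tokens):
    """Pull the raw decoder sequence (with special tokens, including the
    language token, still present) out of Whisper's segmented generate()
    output, so the language can be derived without touching
    generation_whisper.py. Assumes `tokens` carries a 'segments' field.
    """
    first_segment = tokens["segments"][0][0]  # first segment of first batch item
    return first_segment["result"]["sequences"]
```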
@Metric-Void any updates on this?
Thank you. I couldn't find a way to make this modification without changing the pipeline, which risks breaking compatibility for pipelines that don't have these two switches enabled.
What does this PR do?
Fixes #39404.
Adds a switch to Whisper.generate() that preserves certain special tokens, which are then stripped in retrieve_segments to ensure timestamp alignment.
Tested on short and long audios in English, French, and Cantonese. Predictions and timestamps align, and the language is detected correctly.
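As a rough illustration of the idea (the parameter name and helpers below are hypothetical placeholders; the actual change is in this PR's diff):

```python
# Hypothetical illustration of the approach; not the actual diff.
def generate(self, input_features, preserve_special_tokens=False, **kwargs):
    # run the usual Whisper generation (hypothetical helper)
    sequences = self._run_generation(input_features, **kwargs)
    if not preserve_special_tokens:
        # default behavior: strip special tokens immediately, as before
        sequences = self._strip_special_tokens(sequences)  # hypothetical helper
    # when preserved, language tokens survive until retrieve_segments,
    # which strips them only after word timestamps have been aligned
    return sequences
```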
Who can review?
@eustlb @ebezzam
Local failed tests (WSL2, RUN_SLOW)
I don't think any of these failures are related to this PR.