
[Whisper] 🚨 Fix pipeline word timestamp: timestamp token is end of token time !!! #36632

Merged
eustlb merged 24 commits into huggingface:main from eustlb:whisper-pipeline-fix-word-timestamps on Jun 27, 2025

Conversation

@eustlb
Contributor

@eustlb eustlb commented Mar 10, 2025

Fixes #33552 #36228

What is happening? Here is what a careful review of the original OpenAI Whisper codebase revealed:

1️⃣ Context
OpenAI’s implementation follows a slightly different approach:
• First, they compute text tokens.
• Then, they redo a forward pass to retrieve cross-attention weights (which is inefficient—hence our different approach).
• Their forward pass takes as input:
[SOT sequence + no_timestamps token + all text tokens + EOS token]
• A hook retrieves cross-attention weights, meaning each token gets its cross-attention values (shape: [num_heads, 1500]).
• After scaling operations and dynamic time warping, they compute alignments between each token and its corresponding audio frame index (a value between 0 and 1499).
• Since each frame represents 0.02 sec of audio, the index can be mapped to a timestamp (see the sketch below).
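To make the frame-to-time mapping concrete, here is a minimal sketch (hypothetical values, not OpenAI's actual code), assuming the 0.02 s-per-frame resolution described above:

FRAME_DURATION_S = 0.02  # 30 s window / 1500 encoder frames

def frame_indices_to_timestamps(frame_indices):
    # Map DTW-aligned encoder frame indices (0..1499) to seconds.
    return [round(idx * FRAME_DURATION_S, 2) for idx in frame_indices]

print(frame_indices_to_timestamps([48, 56, 61]))  # [0.96, 1.12, 1.22]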

2️⃣ The important part
• These timestamp values are used as end-of-word times when merging tokens into words.
• Each timestamp represents the timing for the end of a token.
• But wait—how do they determine both start and end times when boundaries require N+1 timestamps for N tokens? 🤔
• Simple: they retrieve timestamps for the N text tokens and use the no_timestamps token as a boundary marker (which always ends up as 0.0s); a minimal sketch of the resulting word merging follows below.
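Here is a minimal, hypothetical sketch (not the library's code) of that merging logic: each word's start time is the end time of whatever precedes it, with the first start of 0.0 s supplied by the no_timestamps boundary slot:

def merge_words(words, end_times):
    chunks = []
    start = 0.0  # boundary slot provided by the no_timestamps token
    for word, end in zip(words, end_times):
        chunks.append((word, start, end))
        start = end  # this word's end is the next word's start
    return chunks

print(merge_words(["Mr.", "Quilter", "is"], [0.96, 1.22, 1.5]))
# [('Mr.', 0.0, 0.96), ('Quilter', 0.96, 1.22), ('is', 1.22, 1.5)]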

What is incorrect in our implementation?

Our pipeline incorrectly treated timestamp tokens as start times instead of end times.
Moreover, token_timestamps are not correctly aligned in the current implementation:

timestamps[batch_idx, 1:] = torch.tensor(jump_times)

makes the last value of jump_times (i.e., of token_timestamps) end up associated with the EOS token, while by design the EOS token cannot have a token timestamp (remember that we do not have access to its cross-attention weights). Instead, it is better to replicate the last predicted token's timestamp for the EOS token. This is also not equivalent to what the current implementation does, where token timestamps are taken as start times and timestamps and tokens are misaligned by one position: that setup loses the last token timestamp in the process (it is associated with the EOS token and then cut out when concatenating sequences). See the sketch below.
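A conceptual sketch of the fixed alignment (shapes and values are hypothetical; the real code operates on batched tensors):

import torch

# 4 text tokens + EOS; DTW yields one jump time per text token, none for EOS.
jump_times = torch.tensor([0.96, 1.12, 1.22, 1.50])
timestamps = torch.zeros(5)

# Buggy (main): timestamps[1:] = jump_times pairs the last jump time with EOS.
# Fixed: keep jump times aligned with their tokens, then replicate the last
# predicted token's timestamp for EOS, which has no cross-attention weights.
timestamps[:-1] = jump_times
timestamps[-1] = jump_times[-1]
print(timestamps)  # tensor([0.9600, 1.1200, 1.2200, 1.5000, 1.5000])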

Also, we take decoder_input_ids into account when computing the DTW, which can negatively impact its precision, while OpenAI's implementation does not (see the sketch below).
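A minimal sketch of the idea (names and shapes are hypothetical): slice the prompt tokens out of the cross-attention weights before DTW, and report 0.0 s for them afterwards, as OpenAI's implementation effectively does:

import torch

num_heads, num_tokens, num_frames = 6, 10, 1500
num_decoder_input_ids = 3  # e.g. <|startoftranscript|><|en|><|transcribe|>

weights = torch.randn(num_heads, num_tokens, num_frames)  # cross-attention
weights_for_dtw = weights[:, num_decoder_input_ids:, :]   # text tokens only
prompt_timestamps = torch.zeros(num_decoder_input_ids)    # reported as 0.0 s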

The fix this PR brings

This PR fixes that by:

  1. Setting the first timestamp to 0.0s in the pipeline, similar to OpenAI’s implementation.
  2. Correctly using timestamp tokens as end times instead of start times and correctly aligning them with tokens.
  3. Skipping decoder_input_ids when computing the DTW and setting their timestamps to 0.0s.

Other changes

The num_frames kwarg is broken (and was not documented anyway…)!
The processor's return_token_timestamps kwarg, which is supposed to add num_frames as a kwarg for Whisper's generate, is not useful! The attention_mask can be used to infer the number of frames, and IMO it is not good practice to silently require return_token_timestamps=True on the processor (it is not mentioned in Whisper's generate doc) in order to get precise token timestamps in generate (this is even why our test was not correct 🫠). Instead, we want the user to pass the attention mask when using return_token_timestamps, and warn if they do not! (See the sketch below.)

🚨 Changes for the user

This is somewhat breaking: token timestamps are now aligned with tokens and represent the end time of each token, whereas before they were all shifted by one and represented the token's start time (a migration sketch follows below).
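If downstream code relied on the old start-time convention, start times can be recovered from the new end times; a minimal sketch with hypothetical values:

import torch

end_times = torch.tensor([0.96, 1.12, 1.22])               # new convention
start_times = torch.cat([torch.zeros(1), end_times[:-1]])  # old-style starts
print(start_times)  # tensor([0.0000, 0.9600, 1.1200])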

snippet
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech_samples = ds.sort("id").select(range(1))[:1]["audio"]
input_speech = [x["array"] for x in speech_samples]

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

inputs = processor(input_speech, return_tensors="pt", sampling_rate=16_000, return_attention_mask=True)

generate_outputs = model.generate(**inputs, return_token_timestamps=True)
print(f"Token timestamps: {generate_outputs['token_timestamps']}")

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)

pipeline_result = pipe(input_speech[0], return_timestamps="word")
chunks = pipeline_result.get('chunks', [])

print("\nWord timestamps:")
for chunk in chunks:
    start = chunk['timestamp'][0]
    end = chunk['timestamp'][1] if chunk['timestamp'][1] is not None else "END"
    word = chunk['text'].strip()
    print(f"{word:10} {start:4.2f}{end}")

current output with main

Token timestamps: [0.44, 0.82, 0.96, 1.12, 1.12, 1.22, 1.5, 1.72, 2.0, 2.34, 2.5, 2.66, 3.18, 3.58, 3.68, 3.8, 4.1, 4.32, 4.58, 4.94, 5.4]
Word timestamps:
Mr.        0.44 → 0.96
Quilter    0.96 → 1.22
is         1.22 → 1.5
the        1.50 → 1.72
apostle    1.72 → 2.0
of         2.00 → 2.34
the        2.34 → 2.5
middle     2.50 → 2.66
classes    2.66 → 3.18
and        3.18 → 3.58
we         3.58 → 3.68
are        3.68 → 3.8
glad       3.80 → 4.1
to         4.10 → 4.32
welcome    4.32 → 4.58
his        4.58 → 4.94
gospel.    4.94 → None (bug)

output with this PR

Token timestamps: [0.0, 0.96, 1.12, 1.12, 1.22, 1.5, 1.72, 1.98, 2.34, 2.5, 2.66, 3.2, 3.58, 3.68, 3.8, 4.1, 4.32, 4.58, 4.94, 5.42, 5.84]
Word timestamps:
Mr.        0.00 → 0.98
Quilter    0.98 → 1.22
is         1.22 → 1.5
the        1.50 → 1.72
apostle    1.72 → 1.98
of         1.98 → 2.34
the        2.34 → 2.5
middle     2.50 → 2.66
classes    2.66 → 3.2
and        3.20 → 3.56
we         3.56 → 3.68
are        3.68 → 3.8
glad       3.80 → 4.1
to         4.10 → 4.3
welcome    4.30 → 4.58
his        4.58 → 4.94
gospel.    4.94 → 5.84

@github-actions
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@github-actions github-actions Bot marked this pull request as draft March 10, 2025 18:49
@eustlb eustlb marked this pull request as ready for review March 10, 2025 18:49
@github-actions github-actions Bot requested a review from ArthurZucker March 10, 2025 18:50
@eustlb eustlb changed the title from "[Whisper ]timestamp token is end of token time !!!" to "[Whisper] Fix pipeline word timestamp: timestamp token is end of token time !!!" on Mar 10, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@eustlb eustlb changed the title from "[Whisper] Fix pipeline word timestamp: timestamp token is end of token time !!!" to "[Whisper] 🚨 Fix pipeline word timestamp: timestamp token is end of token time !!!" on Mar 11, 2025
@csetanmayjain

I applied the fixes mentioned in this PR on transformers==4.49.0, but the issue persists:

[torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])]
Interestingly, this problem only occurs with a few chunks, despite their properties (duration, format, context, etc.) being almost identical to other chunks that process successfully.

Collaborator

@ArthurZucker ArthurZucker left a comment


Would rather put this breaking change in the next release!
Otherwise it would be nice to have a "visual" example of the fix (to just see the string of timestamps), but LGTM otherwise

Comment on lines +333 to +335
warnings.warn(
    f"`return_token_timestamps` is deprecated for {self.__class__.__name__} and will be removed in Transformers v5. Use `return_attention_mask` instead, as the number of frames can be inferred from it."
)
Collaborator


use logger.warning_once please
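For reference, a minimal sketch of the suggested pattern with transformers' logging utilities, which deduplicate identical messages so the warning fires only once per run:

from transformers.utils import logging

logger = logging.get_logger(__name__)
logger.warning_once(
    "`return_token_timestamps` is deprecated and will be removed in Transformers v5. "
    "Use `return_attention_mask` instead, as the number of frames can be inferred from it."
)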

else:
    generation_config.num_frames = torch.tensor(generation_config.num_frames)

warnings.warn(
Collaborator


why use warning and logger hahha, let's also only warn once and warn as little as possible

@Noah-Jaffe

please merge, I can't run whisper on any of my large files

@Noah-Jaffe

Tested on my machine, this does fix the issue.

@eustlb
Contributor Author

eustlb commented Apr 11, 2025

Addressed the logging comments + added a visual example of the fix to the PR description @ArthurZucker 😊
Ready to merge!

@eustlb
Contributor Author

eustlb commented Apr 14, 2025

cc @Cyrilvallez, I've addressed Arthur's comments, can you approve the PR please?

Member

@Cyrilvallez Cyrilvallez left a comment


LGTM on principle, but it looks like the Whisper tokenization tests are not happy. You probably need to fix the tests there to reflect the new behavior! 🤗

@eustlb eustlb enabled auto-merge (squash) June 27, 2025 12:15
Member

@Cyrilvallez Cyrilvallez left a comment


Yes, LGTM! Time to ship it! 🤗🚀

@eustlb eustlb merged commit 2b85b6c into huggingface:main Jun 27, 2025
20 checks passed
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
[Whisper] 🚨 Fix pipeline word timestamp: timestamp token is end of token time !!! (huggingface#36632)

* timestamp token is end of token time !!!

* ensure correct alignment between tokens and timestamp tokens

* ignore input tokens for DTW computation

* use num_frames to avoid token timestamp hallucinations

* token timestamps test updates !

* num_frames: deprecate and use attention_mask instead

* avoid breaking change

* fix the pipeline usage for chunk approach

* make style

* better logging

* better logging

* make style

* update tests with correct values


Development

Successfully merging this pull request may close these issues.

[Whisper] TypeError: '<=' not supported between instances of 'NoneType' and 'float'
