From 751f88abdd74a6c16ce040c84ee122c73974488c Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 31 Oct 2024 12:30:36 +0100 Subject: [PATCH 01/17] handle single timestamp ending --- .../models/whisper/tokenization_whisper.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 0a6eb75c55f6..bd9a4a47b3a0 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -528,7 +528,7 @@ def basic_normalize(text, remove_diacritics=False): normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics) return normalizer(text) - def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str: + def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500) -> str: """ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". @@ -538,6 +538,8 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre cur_max_timestamp = 0.0 prev_segments_len = 0.0 + # track if last timestamp was single ending + penultimate_timestamp = 0.0 for token in token_ids: if token >= timestamp_begin: @@ -545,8 +547,13 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre if timestamp < cur_max_timestamp: # next segment has started - prev_segments_len += cur_max_timestamp + last_was_single_ending = not cur_max_timestamp == penultimate_timestamp + if last_was_single_ending: + prev_segments_len += time_precision * segment_size + else: + prev_segments_len += cur_max_timestamp + penultimate_timestamp = cur_max_timestamp cur_max_timestamp = timestamp outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>") From 000dccdb6476f364e7997578cb3e25b570b6e838 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 31 Oct 2024 12:30:58 +0100 Subject: [PATCH 02/17] include last timestamp token --- src/transformers/models/whisper/generation_whisper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 0ecdcb4dbdea..8217235c39af 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1799,6 +1799,9 @@ def _retrieve_segment( segments = [] if single_timestamp_ending: slices.append(len(seek_sequence)) + else: + # we want to include the last timestamp token in the last segment to know it was no single ending + slices[-1] += 1 last_slice = 0 # Add each segment to list of all segments From 70c8aacab51db60ac3f398c9a701032cdbbe35a5 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 31 Oct 2024 15:33:34 +0100 Subject: [PATCH 03/17] handle single timestamp ending --- .../models/whisper/tokenization_whisper.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index bd9a4a47b3a0..658503bc3f78 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -565,7 +565,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre ] return "".join(outputs) - def _compute_offsets(self, token_ids, time_precision=0.02): + def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): """ Compute offsets for a given tokenized input @@ -574,6 +574,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, *optional*, defaults to 0.02): The time ratio to convert from token to time. + segment_size (`int`, *optional*, defaults to 1500): + The number of features in the input mel spectrogram. """ offsets = [] # ensure torch tensor of token ids is placed on cpu @@ -604,7 +606,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02): if start_timestamp_position < cur_max_timestamp: # next segment has started - prev_segments_len += cur_max_timestamp + # here in the worst case we have [<|start_timestamp_position before|>, <|cur_max_timestamp|>, <|start_timestamp_position|>], so last_slice (idx of start_timestamp_position) - 2 is safe + is_single_ending = not token_ids[last_slice - 2] == token_ids[last_slice - 1] + if is_single_ending: + prev_segments_len += segment_size + else: + prev_segments_len += cur_max_timestamp cur_max_timestamp = end_timestamp_position From e7532066d48cfe37e09d7cb65c4e2f071b55b865 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 31 Oct 2024 19:54:38 +0100 Subject: [PATCH 04/17] avoid floating points arithm limitations --- src/transformers/models/whisper/tokenization_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 658503bc3f78..d4543450eba3 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -623,8 +623,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): { "text": text, "timestamp": ( - (start_timestamp_position + prev_segments_len) * time_precision, - (end_timestamp_position + prev_segments_len) * time_precision, + start_timestamp_position * time_precision + prev_segments_len * time_precision, + end_timestamp_position * time_precision + prev_segments_len * time_precision, ), } ) From c53fb2cec8f1bb89217fc08187dc44769636e53a Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 31 Oct 2024 19:55:02 +0100 Subject: [PATCH 05/17] ensure float64 operations --- src/transformers/models/whisper/generation_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 8217235c39af..e382af0bf16f 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -629,7 +629,7 @@ def generate( cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, ) - time_offset = seek * time_precision / input_stride + time_offset = seek.to(torch.float64) * time_precision / input_stride seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) # 6.2 cut out next 30s segment from input features From 185fb5551044e492d2f66f4f16a8d308bf5d2da6 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 31 Oct 2024 19:55:15 +0100 Subject: [PATCH 06/17] new test --- tests/models/whisper/test_modeling_whisper.py | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 12aedaca8cf9..99c4f32ea320 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2128,6 +2128,93 @@ def test_tiny_longform_timestamps_generation(self): transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT) + @slow + def test_small_longform_timestamps_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-small.en") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en") + model.to(torch_device) + + dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation") + sample = dataset[0]["audio"]["array"] + sampling_rate = dataset[0]["audio"]["sampling_rate"] + + sample = [*sample[:15 * sampling_rate], *np.zeros(16 * sampling_rate).tolist(), *sample[15 * sampling_rate:]] + sample = np.array(sample) + + input_features = processor( + sample, + sampling_rate=16_000, + padding="longest", + truncation=False, + return_attention_mask=True, + return_tensors="pt", + ).input_features + + input_features = input_features.to(torch_device) + generated_ids = model.generate(input_features, return_timestamps=True, return_segments=True) + + EXPECTED_TRANSCRIPT = [ + { + "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", + "timestamp": (0.0, 6.38), + }, + { + "text": " Nor is Mr. Quilter's manner less interesting than his matter.", + "timestamp": (6.38, 11.32), + }, + { + "text": " He tells us that at this festive season of the year,", + "timestamp": (11.32, 15.0), + }, + { + "text": " With Christmas and roast beef looming before us, similes drawn from eating and its results", + "timestamp": (30.0, 36.76), + }, + { + "text": " occur most readily to the mind.", + "timestamp": (36.76, 39.80), + }, + { + "text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and", + "timestamp": (39.80, 45.36), + }, + { + "text": " can discover in it but little of rocky Ithaca.", + "timestamp": (45.36, 49.0), + }, + { + "text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles", + "timestamp": (49.0, 56.28), + }, + { + "text": " are as national as a jingo poem. Mr. Burkett fosters landscape's smile at one much in", + "timestamp": (56.28, 64.12), + }, + { + "text": " the same way that Mr. Karker used to flash his teeth. And Mr. John Collier gives his", + "timestamp": (64.12, 70.76), + }, + { + "text": " sitter a cheerful slap on the back before he says, like a shampoo or in a Turkish bath,", + "timestamp": (70.76, 77.16), + }, + { + "text": " Next Man", + "timestamp": (77.16, 78.16), + }, + ] + + transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True) + self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT) + + transcript_segments = [ + { + "text": processor.decode(seg["tokens"], skip_special_tokens=True), + "timestamp": (seg["start"].item(), seg["end"].item()) + } for seg in generated_ids["segments"][0] + ] + self.assertEqual(transcript_segments, EXPECTED_TRANSCRIPT) + @slow def test_large_timestamp_generation(self): set_seed(0) From 429904c05bed2190f8baaa7b8a54a33ef710d3a6 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 31 Oct 2024 20:05:53 +0100 Subject: [PATCH 07/17] make fixup --- src/transformers/models/whisper/tokenization_whisper.py | 4 +++- tests/models/whisper/test_modeling_whisper.py | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index d4543450eba3..824a9839b84b 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -528,7 +528,9 @@ def basic_normalize(text, remove_diacritics=False): normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics) return normalizer(text) - def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500) -> str: + def _decode_with_timestamps( + self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500 + ) -> str: """ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 99c4f32ea320..e96f5953f7e2 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2133,12 +2133,12 @@ def test_small_longform_timestamps_generation(self): processor = WhisperProcessor.from_pretrained("openai/whisper-small.en") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en") model.to(torch_device) - + dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation") sample = dataset[0]["audio"]["array"] sampling_rate = dataset[0]["audio"]["sampling_rate"] - sample = [*sample[:15 * sampling_rate], *np.zeros(16 * sampling_rate).tolist(), *sample[15 * sampling_rate:]] + sample = [*sample[: 15 * sampling_rate], *np.zeros(16 * sampling_rate).tolist(), *sample[15 * sampling_rate :]] sample = np.array(sample) input_features = processor( @@ -2210,8 +2210,9 @@ def test_small_longform_timestamps_generation(self): transcript_segments = [ { "text": processor.decode(seg["tokens"], skip_special_tokens=True), - "timestamp": (seg["start"].item(), seg["end"].item()) - } for seg in generated_ids["segments"][0] + "timestamp": (seg["start"].item(), seg["end"].item()), + } + for seg in generated_ids["segments"][0] ] self.assertEqual(transcript_segments, EXPECTED_TRANSCRIPT) From 1c7224433b68cec2bb1a37ae6574aa012171b9f3 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 31 Oct 2024 20:13:05 +0100 Subject: [PATCH 08/17] make copies --- .../whisper/tokenization_whisper_fast.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 66cf412cc2a8..8a155e5bd64c 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -169,7 +169,9 @@ def _encode_plus(self, *args, **kwargs) -> BatchEncoding: return super()._encode_plus(*args, **kwargs) # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps - def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str: + def _decode_with_timestamps( + self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500 + ) -> str: """ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". @@ -179,6 +181,8 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre cur_max_timestamp = 0.0 prev_segments_len = 0.0 + # track if last timestamp was single ending + penultimate_timestamp = 0.0 for token in token_ids: if token >= timestamp_begin: @@ -186,8 +190,13 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre if timestamp < cur_max_timestamp: # next segment has started - prev_segments_len += cur_max_timestamp + last_was_single_ending = not cur_max_timestamp == penultimate_timestamp + if last_was_single_ending: + prev_segments_len += time_precision * segment_size + else: + prev_segments_len += cur_max_timestamp + penultimate_timestamp = cur_max_timestamp cur_max_timestamp = timestamp outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>") @@ -200,7 +209,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre return "".join(outputs) # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets - def _compute_offsets(self, token_ids, time_precision=0.02): + def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): """ Compute offsets for a given tokenized input @@ -209,6 +218,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, *optional*, defaults to 0.02): The time ratio to convert from token to time. + segment_size (`int`, *optional*, defaults to 1500): + The number of features in the input mel spectrogram. """ offsets = [] # ensure torch tensor of token ids is placed on cpu @@ -239,7 +250,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02): if start_timestamp_position < cur_max_timestamp: # next segment has started - prev_segments_len += cur_max_timestamp + # here in the worst case we have [<|start_timestamp_position before|>, <|cur_max_timestamp|>, <|start_timestamp_position|>], so last_slice (idx of start_timestamp_position) - 2 is safe + is_single_ending = not token_ids[last_slice - 2] == token_ids[last_slice - 1] + if is_single_ending: + prev_segments_len += segment_size + else: + prev_segments_len += cur_max_timestamp cur_max_timestamp = end_timestamp_position @@ -251,8 +267,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): { "text": text, "timestamp": ( - (start_timestamp_position + prev_segments_len) * time_precision, - (end_timestamp_position + prev_segments_len) * time_precision, + start_timestamp_position * time_precision + prev_segments_len * time_precision, + end_timestamp_position * time_precision + prev_segments_len * time_precision, ), } ) From 7d6f9b4711e6ee6924abf2e5c2bd070af0767e6e Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 4 Nov 2024 18:00:38 +0100 Subject: [PATCH 09/17] handle edge case double tokens ending with different tokens --- .../models/whisper/generation_whisper.py | 14 ++++++++++---- .../models/whisper/tokenization_whisper.py | 16 ++++++++++------ .../models/whisper/tokenization_whisper_fast.py | 16 ++++++++++------ 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index e382af0bf16f..845d895efdee 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -308,6 +308,7 @@ def generate( num_segment_frames: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, time_precision: float = 0.02, + time_precision_features: float = 0.01, return_token_timestamps: Optional[bool] = None, return_segments: bool = False, return_dict_in_generate: Optional[bool] = None, @@ -417,6 +418,8 @@ def generate( time_precision (`int`, *optional*, defaults to 0.02): The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts for 20 ms. + time_precision_features (`int`, *optional*, defaults to 0.01): + The duration represented by a feature frame in seconds. return_token_timestamps (`bool`, *optional*): Whether to return token-level timestamps with the text. This can be used with or without the `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into @@ -718,6 +721,7 @@ def generate( timestamp_begin=timestamp_begin, seek_num_frames=seek_num_frames, time_precision=time_precision, + time_precision_features=time_precision_features, input_stride=input_stride, prev_idx=prev_i, idx=i, @@ -1778,6 +1782,7 @@ def _retrieve_segment( timestamp_begin, seek_num_frames, time_precision, + time_precision_features, input_stride, prev_idx, idx, @@ -1805,10 +1810,11 @@ def _retrieve_segment( last_slice = 0 # Add each segment to list of all segments - for current_slice in slices: + for i, current_slice in enumerate(slices): + is_last_slice = i == len(slices) - 1 sliced_tokens = seek_sequence[last_slice:current_slice] start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin - end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin + end_timestamp_pos = sliced_tokens[-1 if not is_last_slice else -2].item() - timestamp_begin segments.append( { "start": time_offset[prev_idx] + start_timestamp_pos * time_precision, @@ -1830,13 +1836,13 @@ def _retrieve_segment( # otherwise, ignore the unfinished segment and seek to the last timestamp # here we throw away all predictions after the last predicted "end of segment" # since we are cutting right in the middle of an audio - last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin + last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin segment_offset = last_timestamp_pos * input_stride else: # If whisper does not predict any "end of segment" token, then # the whole decoding is considered a segment and we add it to the list of segments timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()] - last_timestamp_pos = seek_num_frames[prev_idx] + last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision) if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin: # no consecutive timestamps but it has a timestamp; use the last one. last_timestamp_pos = timestamps[-1].item() - timestamp_begin diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 824a9839b84b..e537ef95da67 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -540,20 +540,23 @@ def _decode_with_timestamps( cur_max_timestamp = 0.0 prev_segments_len = 0.0 - # track if last timestamp was single ending penultimate_timestamp = 0.0 - for token in token_ids: + for i, token in enumerate(token_ids): if token >= timestamp_begin: timestamp = float((token - timestamp_begin) * time_precision) if timestamp < cur_max_timestamp: # next segment has started - last_was_single_ending = not cur_max_timestamp == penultimate_timestamp + last_was_single_ending = i >= 2 and not ( + token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin + ) if last_was_single_ending: prev_segments_len += time_precision * segment_size else: - prev_segments_len += cur_max_timestamp + cur_max_timestamp = penultimate_timestamp + prev_segments_len += penultimate_timestamp + outputs = outputs[:-2] penultimate_timestamp = cur_max_timestamp cur_max_timestamp = timestamp @@ -608,8 +611,9 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): if start_timestamp_position < cur_max_timestamp: # next segment has started - # here in the worst case we have [<|start_timestamp_position before|>, <|cur_max_timestamp|>, <|start_timestamp_position|>], so last_slice (idx of start_timestamp_position) - 2 is safe - is_single_ending = not token_ids[last_slice - 2] == token_ids[last_slice - 1] + is_single_ending = last_slice >= 2 and not ( + token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin + ) if is_single_ending: prev_segments_len += segment_size else: diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 8a155e5bd64c..f0383cb0def7 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -181,20 +181,23 @@ def _decode_with_timestamps( cur_max_timestamp = 0.0 prev_segments_len = 0.0 - # track if last timestamp was single ending penultimate_timestamp = 0.0 - for token in token_ids: + for i, token in enumerate(token_ids): if token >= timestamp_begin: timestamp = float((token - timestamp_begin) * time_precision) if timestamp < cur_max_timestamp: # next segment has started - last_was_single_ending = not cur_max_timestamp == penultimate_timestamp + last_was_single_ending = i >= 2 and not ( + token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin + ) if last_was_single_ending: prev_segments_len += time_precision * segment_size else: - prev_segments_len += cur_max_timestamp + cur_max_timestamp = penultimate_timestamp + prev_segments_len += penultimate_timestamp + outputs = outputs[:-2] penultimate_timestamp = cur_max_timestamp cur_max_timestamp = timestamp @@ -250,8 +253,9 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): if start_timestamp_position < cur_max_timestamp: # next segment has started - # here in the worst case we have [<|start_timestamp_position before|>, <|cur_max_timestamp|>, <|start_timestamp_position|>], so last_slice (idx of start_timestamp_position) - 2 is safe - is_single_ending = not token_ids[last_slice - 2] == token_ids[last_slice - 1] + is_single_ending = last_slice >= 2 and not ( + token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin + ) if is_single_ending: prev_segments_len += segment_size else: From 937cd2a1ffefaeb5bafe14fd63d576bba0c69d95 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 4 Nov 2024 18:46:06 +0100 Subject: [PATCH 10/17] handle single timestamp ending --- src/transformers/models/whisper/generation_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 845d895efdee..dc0ee59e3e65 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1814,7 +1814,7 @@ def _retrieve_segment( is_last_slice = i == len(slices) - 1 sliced_tokens = seek_sequence[last_slice:current_slice] start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin - end_timestamp_pos = sliced_tokens[-1 if not is_last_slice else -2].item() - timestamp_begin + end_timestamp_pos = sliced_tokens[-1 if not is_last_slice or single_timestamp_ending else -2].item() - timestamp_begin segments.append( { "start": time_offset[prev_idx] + start_timestamp_pos * time_precision, From 9b1d51edb82ad98d4dff9d1414b9fe8b5802753b Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 4 Nov 2024 19:06:05 +0100 Subject: [PATCH 11/17] make fixup --- src/transformers/models/whisper/generation_whisper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index dc0ee59e3e65..68af102875cf 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1814,7 +1814,9 @@ def _retrieve_segment( is_last_slice = i == len(slices) - 1 sliced_tokens = seek_sequence[last_slice:current_slice] start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin - end_timestamp_pos = sliced_tokens[-1 if not is_last_slice or single_timestamp_ending else -2].item() - timestamp_begin + end_timestamp_pos = ( + sliced_tokens[-1 if not is_last_slice or single_timestamp_ending else -2].item() - timestamp_begin + ) segments.append( { "start": time_offset[prev_idx] + start_timestamp_pos * time_precision, From a3cbe9f4a80a9132e9c5680d16525cdc2587ae16 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Tue, 5 Nov 2024 19:04:37 +0100 Subject: [PATCH 12/17] handle conditioning on prev segments --- src/transformers/models/whisper/generation_whisper.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 68af102875cf..b0b715bd4c99 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1669,6 +1669,7 @@ def _prepare_decoder_input_ids( config, device, suppress_tokens, + timestamp_begin, kwargs, ): if "decoder_input_ids" in kwargs: @@ -1688,6 +1689,14 @@ def _prepare_decoder_input_ids( # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map] + for segments in active_segments: + for seg in segments: + if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin: + # the segment finishes with two timestamp tokens + # we need to ignore the last timestamp token + # see https://github.com/huggingface/transformers/pull/34537 + seg["tokens"] = seg["tokens"][:-1] + if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments": prev_ids = prompt_ids else: From 7c0da36b884df4706478ffc98881fde98b4e9f8f Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Tue, 5 Nov 2024 19:45:35 +0100 Subject: [PATCH 13/17] fix --- src/transformers/models/whisper/generation_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index b0b715bd4c99..5fe111604197 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -661,6 +661,7 @@ def generate( config=self.config, device=init_tokens.device, suppress_tokens=suppress_tokens, + timestamp_begin=timestamp_begin, kwargs=kwargs, ) From 4a21249a212d1850d86e839193a1522868d3c7dc Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 21 Nov 2024 18:15:55 +0100 Subject: [PATCH 14/17] Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> --- src/transformers/models/whisper/generation_whisper.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 5fe111604197..97b36e2fc695 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1824,9 +1824,8 @@ def _retrieve_segment( is_last_slice = i == len(slices) - 1 sliced_tokens = seek_sequence[last_slice:current_slice] start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin - end_timestamp_pos = ( - sliced_tokens[-1 if not is_last_slice or single_timestamp_ending else -2].item() - timestamp_begin - ) + idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2 + end_timestamp_pos = sliced_tokens[idx_sliced_tokens].item() - timestamp_begin segments.append( { "start": time_offset[prev_idx] + start_timestamp_pos * time_precision, From e8f2f690b7305c572c23c2a6c419e0003b9f0126 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 27 Nov 2024 11:55:10 +0100 Subject: [PATCH 15/17] [run-slow] whisper From 5fba3e0f0d9b422ce7b1aae9a527211749b89fce Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 5 Dec 2024 12:22:36 +0100 Subject: [PATCH 16/17] don't call item() to avoid unnecessary sync --- .../models/whisper/generation_whisper.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 97b36e2fc695..be6412b132e7 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1823,13 +1823,13 @@ def _retrieve_segment( for i, current_slice in enumerate(slices): is_last_slice = i == len(slices) - 1 sliced_tokens = seek_sequence[last_slice:current_slice] - start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin + start_timestamp_pos = sliced_tokens[0] - timestamp_begin idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2 - end_timestamp_pos = sliced_tokens[idx_sliced_tokens].item() - timestamp_begin + end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin segments.append( { - "start": time_offset[prev_idx] + start_timestamp_pos * time_precision, - "end": time_offset[prev_idx] + end_timestamp_pos * time_precision, + "start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision, + "end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision, "tokens": sliced_tokens, "result": seek_outputs[idx], } @@ -1854,13 +1854,13 @@ def _retrieve_segment( # the whole decoding is considered a segment and we add it to the list of segments timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()] last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision) - if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin: + if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin: # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = timestamps[-1].item() - timestamp_begin + last_timestamp_pos = timestamps[-1] - timestamp_begin segments = [ { "start": time_offset[prev_idx], - "end": time_offset[prev_idx] + last_timestamp_pos * time_precision, + "end": time_offset[prev_idx] + last_timestamp_pos.to(torch.float64) * time_precision, "tokens": seek_sequence, "result": seek_outputs[idx], } From 88587bb510968ab924c8530e90263055ce7843a9 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 5 Dec 2024 13:28:11 +0100 Subject: [PATCH 17/17] fix --- src/transformers/models/whisper/generation_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index be6412b132e7..2f58375f3de7 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1856,11 +1856,11 @@ def _retrieve_segment( last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision) if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin: # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = timestamps[-1] - timestamp_begin + last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64) segments = [ { "start": time_offset[prev_idx], - "end": time_offset[prev_idx] + last_timestamp_pos.to(torch.float64) * time_precision, + "end": time_offset[prev_idx] + last_timestamp_pos * time_precision, "tokens": seek_sequence, "result": seek_outputs[idx], }