diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 7983799ad8a7..a931358c0798 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -1072,14 +1072,14 @@ def new_chunk(): # 4/ Regular token # We just append to the list of all tokens so we can handle # merges later and decode into text. - current_tokens.append(token) if return_timestamps == "word": start_time = round(token_timestamps[i] + time_offset, 2) if i + 1 < len(token_timestamps): + current_tokens.append(token) end_time = round(token_timestamps[i + 1] + time_offset, 2) + current_token_timestamps.append((start_time, end_time)) else: end_time = None # should never happen - current_token_timestamps.append((start_time, end_time)) if "stride" in output: time_offset += chunk_len - stride_right diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index 27b24448d5a2..dac69b9d3092 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -374,6 +374,48 @@ def test_decode_asr_with_word_level_timestamps(self): ) self.assertEqual(result, EXPECTED_OUTPUT) + # fmt: off + model_outputs = [ + { + 'stride': (30.0, 0.0, 5.0), + 'tokens': np.array([[286, 478, 2633, 760, 420, 2633, 264, 558, 2372, 13, 286, 841, 264, 596, 346, 13, 583, 406, 1547, 281]]), + 'token_timestamps': np.array( + [[23.88, 24.06, 24.06, 24.3, 24.54, 24.72, 24.98, 25.2, 25.36, 25.62, + 25.66, 25.8, 26.06, 26.26, 26.34, 26.48, 26.52, 26.72, 26.86, 27.08]]) + }, + { + 'stride': (10.0075, 5.0, 0.0), + 'tokens': np.array([[2633, 6385, 286, 478, 2633, 264, 558, 2372, 286, 841, 264, 4588, 457, 406, 1547, 281, 652, 385, 605, 493]]), + 'token_timestamps': np.array( + [[4.12, 4.32, 4.58, 4.76, 4.84, 4.9, 5.2, 5.36, 5.62, 5.82, + 6.02, 6.26, 6.48, 6.74, 6.86, 7.08, 7.32, 7.42, 7.66, 7.8]]) + } + ] + # fmt: on + + result = tokenizer._decode_asr( + model_outputs, return_timestamps="word", return_language=False, time_precision=0.02 + ) + + EXPECTED_OUTPUT = ( + " ofectjoy knowocjoy sace threat.ublic s influencept Lians anoryusical", + { + "chunks": [ + {"text": " ofectjoy", "timestamp": (23.88, 24.3)}, + {"text": " knowocjoy", "timestamp": (24.3, 24.98)}, + {"text": " sace", "timestamp": (24.98, 25.36)}, + {"text": " threat", "timestamp": (25.36, 25.62)}, + {"text": ".ublic", "timestamp": (25.62, 26.02)}, + {"text": " s", "timestamp": (26.02, 26.26)}, + {"text": " influencept", "timestamp": (26.26, 26.74)}, + {"text": " Lians", "timestamp": (26.74, 27.08)}, + {"text": " anoryusical", "timestamp": (27.08, 27.8)}, + ] + }, + ) + + self.assertEqual(result, EXPECTED_OUTPUT) + class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): checkpoint_name = "openai/whisper-small.en"