From 8530d3d9062f67bd449f10d76fc57717785b53cc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Nov 2023 17:56:17 +0000 Subject: [PATCH 01/75] finalize --- src/transformers/generation/logits_process.py | 3 ++ src/transformers/generation/utils.py | 2 + .../models/whisper/modeling_whisper.py | 49 +++++++++++++++---- 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 6c7b84f6ae67..c85b53b73956 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1537,6 +1537,9 @@ def __init__( ) self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) + def set_begin_index(self, begin_index): + self.begin_index = begin_index + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # suppress <|notimestamps|> which is handled by without_timestamps diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 077bc16aff8b..054328cf74b2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -773,6 +773,8 @@ def _prepare_decoder_input_ids_for_generation( # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): pass + elif self.config.model_type in ['whisper']: + pass # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e88fe3a6aacd..430d60b0d9b2 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -15,6 +15,7 @@ """ PyTorch Whisper model.""" import math +from pickle import decode_long import warnings from typing import Optional, Tuple, Union @@ -1741,6 +1742,7 @@ def generate( task=None, language=None, is_multilingual=None, + condition_on_previous_tokens: Optional[bool] = None, prompt_ids: Optional[torch.Tensor] = None, num_segment_frames: Optional[int] = None, return_token_timestamps: Optional[bool] = None, @@ -2143,6 +2145,9 @@ def generate( return outputs + condition_on_previous_tokens = condition_on_previous_tokens or getattr(self.generation_config, "condition_on_previous_tokens", False) + self.generation_config.condition_on_previous_tokens = condition_on_previous_tokens + # 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated # timestamp tokens # 6.1 Set running parameters for while loop @@ -2214,6 +2219,17 @@ def generate( segment_input = torch.cat(segment_input, dim=0) + decoder_input_ids = None + if condition_on_previous_tokens and len(current_segments[0]) > 0: + # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 + cut_off_length = self.config.max_target_positions // 2 - 1 + active_segments = [current_segments[i] for i in new_cur_to_prev_index_map] + decoder_input_ids = self._pad_to_max_length(active_segments, self.generation_config.pad_token_id, padding="left") + decoder_input_ids = torch.cat([decoder_input_ids[:, -cut_off_length:], torch.ones((cur_bsz, 1), device=decoder_input_ids.device, dtype=torch.long) * self.config.decoder_start_token_id], dim=-1) + + timestamp_processor = [proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor)][0] + timestamp_processor.set_begin_index(decoder_input_ids.shape[-1]) + # 6.6 Batch generate current chunk seek_outputs = super().generate( segment_input, @@ -2223,6 +2239,7 @@ def generate( prefix_allowed_tokens_fn, synced_gpus, return_dict_in_generate=return_dict_in_generate, + decoder_input_ids=decoder_input_ids, **kwargs, ) @@ -2241,6 +2258,9 @@ def generate( else: seek_sequences = seek_outputs + if decoder_input_ids is not None: + seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1]:] + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): prev_i = cur_to_prev_index_map[i] @@ -2273,25 +2293,34 @@ def generate( # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output - sequences = [] + sequences = self._pad_to_max_length(current_segments, self.generation_config.pad_token_id, padding="right") + + # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. + if return_segments: + return {"sequences": sequences, "segments": current_segments} + + return sequences + + @staticmethod + def _pad_to_max_length(current_segments, pad_token_id, padding="right"): max_total_length = 0 + sequences = [] + if padding not in ["right", "left"]: + raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}") + for current_segment_list in current_segments: sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1)) max_total_length = max(max_total_length, len(sequences[-1])) - for i in range(batch_size): - sequences[i] = F.pad( - sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id - ) + for i in range(len(current_segments)): + pad_length = max_total_length - len(sequences[i]) + pad = (0, pad_length) if padding == "right" else (pad_length, 0) + sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id) sequences = torch.stack(sequences, dim=0) - - # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. - if return_segments: - return {"sequences": sequences, "segments": current_segments} - return sequences + @staticmethod def _retrieve_segment( seek_sequence, From c4826fdcb77093dc0d8a4340e875e119b923f2e0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Nov 2023 19:17:46 +0000 Subject: [PATCH 02/75] make fix copies whisper --- .../models/whisper/modeling_whisper.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 430d60b0d9b2..6652ba1aac40 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2145,8 +2145,7 @@ def generate( return outputs - condition_on_previous_tokens = condition_on_previous_tokens or getattr(self.generation_config, "condition_on_previous_tokens", False) - self.generation_config.condition_on_previous_tokens = condition_on_previous_tokens + condition_on_previous_tokens = condition_on_previous_tokens if condition_on_previous_tokens is not None else getattr(self.generation_config, "condition_on_previous_tokens", False) # 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated # timestamp tokens @@ -2225,11 +2224,32 @@ def generate( cut_off_length = self.config.max_target_positions // 2 - 1 active_segments = [current_segments[i] for i in new_cur_to_prev_index_map] decoder_input_ids = self._pad_to_max_length(active_segments, self.generation_config.pad_token_id, padding="left") - decoder_input_ids = torch.cat([decoder_input_ids[:, -cut_off_length:], torch.ones((cur_bsz, 1), device=decoder_input_ids.device, dtype=torch.long) * self.config.decoder_start_token_id], dim=-1) + + prev_start_of_text = 50360 # TODO(Patrick): Need to put in generation_config + one_tensor = torch.ones((cur_bsz, 1), device=decoder_input_ids.device, dtype=torch.long) + + decoder_input_ids = torch.cat([prev_start_of_text * one_tensor, decoder_input_ids[:, -cut_off_length:], self.config.decoder_start_token_id * one_tensor], dim=-1) timestamp_processor = [proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor)][0] timestamp_processor.set_begin_index(decoder_input_ids.shape[-1]) + passed_max_length = kwargs.get("max_length", None) + passed_max_new_tokens = kwargs.get("max_new_tokens", None) + max_length_config = getattr(self.generation_config, "max_length", None) + max_new_tokens_config = getattr(self.generation_config, "max_new_tokens", None) + + # Make sure we don't get larger than `max_length` + if passed_max_length is not None and passed_max_new_tokens is None: + kwargs["max_length"] = max(kwargs["max_length"] + cut_off_length + 1, self.config.max_target_positions) + logger.info(f"Increase max_length from {passed_max_length} to {kwargs['max_length']} since input is conditioned on previous segment.") + elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: + kwargs["max_length"] = max(self.generation_config.max_length + cut_off_length + 1, self.config.max_target_positions) + logger.info(f"Increase max_length from {max_length_config} to {kwargs['max_length']} since input is conditioned on previous segment.") + elif passed_max_new_tokens is not None and passed_max_new_tokens + cut_off_length + 2 > self.config.max_target_positions: + kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 + elif passed_max_new_tokens is None and max_new_tokens_config is not None and max_new_tokens_config + cut_off_length + 2 > self.config.max_target_positions: + kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 + # 6.6 Batch generate current chunk seek_outputs = super().generate( segment_input, @@ -2281,7 +2301,6 @@ def generate( time_offset=time_offset, timestamp_begin=timestamp_begin, seek_num_frames=seek_num_frames, - cur_bsz=cur_bsz, time_precision=time_precision, input_stride=input_stride, prev_idx=prev_i, @@ -2328,7 +2347,6 @@ def _retrieve_segment( time_offset, timestamp_begin, seek_num_frames, - cur_bsz, time_precision, input_stride, prev_idx, @@ -2337,7 +2355,7 @@ def _retrieve_segment( # find the predicted "end of segment" predictions of Whisper # "end of segment" predictions occur whenever Whisper predicts a timestamp token timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin) - single_timestamp_ending = timestamp_tokens[-2:].tolist() == cur_bsz * [[False, True]] + single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] # If whisper predicted a "end of segment" via a timestep token, let's go ever each From 2de5fe0b12d8f0456bf52df59f90d7569e1b7fef Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Nov 2023 19:18:39 +0000 Subject: [PATCH 03/75] [Tests] Make sure that we don't run tests mulitple times --- src/transformers/generation/utils.py | 2 +- .../models/whisper/modeling_whisper.py | 54 ++++++++++++++----- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 054328cf74b2..aafd46c64cfa 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -773,7 +773,7 @@ def _prepare_decoder_input_ids_for_generation( # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): pass - elif self.config.model_type in ['whisper']: + elif self.config.model_type in ["whisper"]: pass # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 6652ba1aac40..580b9b3e8655 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -15,7 +15,6 @@ """ PyTorch Whisper model.""" import math -from pickle import decode_long import warnings from typing import Optional, Tuple, Union @@ -2145,7 +2144,11 @@ def generate( return outputs - condition_on_previous_tokens = condition_on_previous_tokens if condition_on_previous_tokens is not None else getattr(self.generation_config, "condition_on_previous_tokens", False) + condition_on_previous_tokens = ( + condition_on_previous_tokens + if condition_on_previous_tokens is not None + else getattr(self.generation_config, "condition_on_previous_tokens", False) + ) # 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated # timestamp tokens @@ -2223,14 +2226,25 @@ def generate( # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 cut_off_length = self.config.max_target_positions // 2 - 1 active_segments = [current_segments[i] for i in new_cur_to_prev_index_map] - decoder_input_ids = self._pad_to_max_length(active_segments, self.generation_config.pad_token_id, padding="left") + decoder_input_ids = self._pad_to_max_length( + active_segments, self.generation_config.pad_token_id, padding="left" + ) prev_start_of_text = 50360 # TODO(Patrick): Need to put in generation_config one_tensor = torch.ones((cur_bsz, 1), device=decoder_input_ids.device, dtype=torch.long) - decoder_input_ids = torch.cat([prev_start_of_text * one_tensor, decoder_input_ids[:, -cut_off_length:], self.config.decoder_start_token_id * one_tensor], dim=-1) + decoder_input_ids = torch.cat( + [ + prev_start_of_text * one_tensor, + decoder_input_ids[:, -cut_off_length:], + self.config.decoder_start_token_id * one_tensor, + ], + dim=-1, + ) - timestamp_processor = [proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor)][0] + timestamp_processor = [ + proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor) + ][0] timestamp_processor.set_begin_index(decoder_input_ids.shape[-1]) passed_max_length = kwargs.get("max_length", None) @@ -2240,14 +2254,29 @@ def generate( # Make sure we don't get larger than `max_length` if passed_max_length is not None and passed_max_new_tokens is None: - kwargs["max_length"] = max(kwargs["max_length"] + cut_off_length + 1, self.config.max_target_positions) - logger.info(f"Increase max_length from {passed_max_length} to {kwargs['max_length']} since input is conditioned on previous segment.") + kwargs["max_length"] = max( + kwargs["max_length"] + cut_off_length + 1, self.config.max_target_positions + ) + logger.info( + f"Increase max_length from {passed_max_length} to {kwargs['max_length']} since input is conditioned on previous segment." + ) elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: - kwargs["max_length"] = max(self.generation_config.max_length + cut_off_length + 1, self.config.max_target_positions) - logger.info(f"Increase max_length from {max_length_config} to {kwargs['max_length']} since input is conditioned on previous segment.") - elif passed_max_new_tokens is not None and passed_max_new_tokens + cut_off_length + 2 > self.config.max_target_positions: + kwargs["max_length"] = max( + self.generation_config.max_length + cut_off_length + 1, self.config.max_target_positions + ) + logger.info( + f"Increase max_length from {max_length_config} to {kwargs['max_length']} since input is conditioned on previous segment." + ) + elif ( + passed_max_new_tokens is not None + and passed_max_new_tokens + cut_off_length + 2 > self.config.max_target_positions + ): kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 - elif passed_max_new_tokens is None and max_new_tokens_config is not None and max_new_tokens_config + cut_off_length + 2 > self.config.max_target_positions: + elif ( + passed_max_new_tokens is None + and max_new_tokens_config is not None + and max_new_tokens_config + cut_off_length + 2 > self.config.max_target_positions + ): kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 # 6.6 Batch generate current chunk @@ -2279,7 +2308,7 @@ def generate( seek_sequences = seek_outputs if decoder_input_ids is not None: - seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1]:] + seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1] :] # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): @@ -2339,7 +2368,6 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right"): sequences = torch.stack(sequences, dim=0) return sequences - @staticmethod def _retrieve_segment( seek_sequence, From 3106c518acf6aa73d3f51a13b1dd8487b051eeed Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Nov 2023 20:19:46 +0100 Subject: [PATCH 04/75] Update src/transformers/models/whisper/modeling_whisper.py --- src/transformers/models/whisper/modeling_whisper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 580b9b3e8655..73d690cb027e 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -15,7 +15,6 @@ """ PyTorch Whisper model.""" import math -import warnings from typing import Optional, Tuple, Union import numpy as np From cfc19981fd797ff4f3c1cdf14b19e63c1a5b6975 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Nov 2023 00:54:30 +0000 Subject: [PATCH 05/75] [Tests] Make sure that we don't run tests mulitple times --- src/transformers/models/whisper/modeling_whisper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 580b9b3e8655..2930403228dc 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2230,7 +2230,7 @@ def generate( active_segments, self.generation_config.pad_token_id, padding="left" ) - prev_start_of_text = 50360 # TODO(Patrick): Need to put in generation_config + prev_start_of_text = self.generation_config.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config one_tensor = torch.ones((cur_bsz, 1), device=decoder_input_ids.device, dtype=torch.long) decoder_input_ids = torch.cat( @@ -2241,6 +2241,8 @@ def generate( ], dim=-1, ) + # print(decoder_input_ids) + # import ipdb; ipdb.set_trace() timestamp_processor = [ proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor) From 00df894e8ef57be50048e970a72650659f97eb65 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 1 Dec 2023 16:30:47 +0000 Subject: [PATCH 06/75] fix more --- .../models/whisper/modeling_whisper.py | 49 ++++++++++++------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 6d0861197269..1c34a3541dfd 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2105,17 +2105,18 @@ def generate( generation_config.num_frames = kwargs.pop("num_frames") if generation_config.return_timestamps is True: + forced_decoder_ids = generation_config.forced_decoder_ids last_forced_decoder_ids = ( - generation_config.forced_decoder_ids[-1][-1] - if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids + forced_decoder_ids[-1][-1] + if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None else None ) if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id: # remove no_timestamp to be forcefully generated if we want to return timestamps # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly - forced_decoder_ids = generation_config.forced_decoder_ids[:-1] + forced_decoder_ids = forced_decoder_ids[:-1] if len(forced_decoder_ids) > 1 else None # Make sure that if list is empty we set it to None - generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids + generation_config.forced_decoder_ids = forced_decoder_ids timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)] logits_processor = ( @@ -2179,6 +2180,17 @@ def generate( # batch size can decrease during the run cur_bsz = prev_bsz = batch_size + init_tokens = [self.generation_config.decoder_start_token_id] + if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: + i = 1 + while len(forced_decoder_ids) > i and forced_decoder_ids[0][0] == i: + init_tokens += [forced_decoder_ids[i - 1][1]] + forced_decoder_ids = forced_decoder_ids[1:] + i += 1 + + forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None + generation_config.forced_decoder_ids = forced_decoder_ids + # 6.2 Transcribe audio until we reach the end of all input audios while (seek < max_frames).any(): prev_bsz = cur_bsz @@ -2220,29 +2232,20 @@ def generate( segment_input = torch.cat(segment_input, dim=0) - decoder_input_ids = None + one_tensor = torch.ones((cur_bsz, 1), device=segment_input.device, dtype=torch.long) + decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) + if condition_on_previous_tokens and len(current_segments[0]) > 0: # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 cut_off_length = self.config.max_target_positions // 2 - 1 active_segments = [current_segments[i] for i in new_cur_to_prev_index_map] - decoder_input_ids = self._pad_to_max_length( + prev_tokens = self._pad_to_max_length( active_segments, self.generation_config.pad_token_id, padding="left" ) prev_start_of_text = self.generation_config.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config - one_tensor = torch.ones((cur_bsz, 1), device=decoder_input_ids.device, dtype=torch.long) - - decoder_input_ids = torch.cat( - [ - prev_start_of_text * one_tensor, - decoder_input_ids[:, -cut_off_length:], - self.config.decoder_start_token_id * one_tensor, - ], - dim=-1, - ) - # print(decoder_input_ids) - # import ipdb; ipdb.set_trace() + decoder_input_ids = torch.cat([prev_start_of_text * one_tensor, prev_tokens[:, -cut_off_length:], decoder_input_ids], dim=-1) timestamp_processor = [ proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor) ][0] @@ -2280,6 +2283,8 @@ def generate( ): kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 + # print(decoder_input_ids) + # 6.6 Batch generate current chunk seek_outputs = super().generate( segment_input, @@ -2311,6 +2316,9 @@ def generate( if decoder_input_ids is not None: seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1] :] + # print("hf tokens", seek_sequences) + # import ipdb; ipdb.set_trace() + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): prev_i = cur_to_prev_index_map[i] @@ -2340,6 +2348,8 @@ def generate( current_segments[prev_i] += segments seek[prev_i] += segment_offset + print(seek) + # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output sequences = self._pad_to_max_length(current_segments, self.generation_config.pad_token_id, padding="right") @@ -2415,12 +2425,14 @@ def _retrieve_segment( if single_timestamp_ending: # single timestamp at the end means no speech after the last timestamp. segment_offset = seek_num_frames[prev_idx] + print("single timestamp") else: # otherwise, ignore the unfinished segment and seek to the last timestamp # here we throw away all predictions after the last predicted "end of segment" # since we are cutting right in the middle of an audio last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin segment_offset = last_timestamp_pos * input_stride + print("cut") else: # If whisper does not predict any "end of segment" token, then # the whole decoding is considered a segment and we add it to the list of segments @@ -2439,6 +2451,7 @@ def _retrieve_segment( } ] segment_offset = seek_num_frames[prev_idx] + print("all") return segments, segment_offset From 25bcd697167fc97f6297b2345e71a0d1b576a28b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 Dec 2023 13:03:48 +0000 Subject: [PATCH 07/75] improve --- .../models/whisper/modeling_whisper.py | 20 ++++++++-------- tests/models/whisper/test_modeling_whisper.py | 23 ++++++++++++------- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 1c34a3541dfd..182bd79fb327 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1740,7 +1740,7 @@ def generate( task=None, language=None, is_multilingual=None, - condition_on_previous_tokens: Optional[bool] = None, + condition_on_prev_tokens: Optional[bool] = None, prompt_ids: Optional[torch.Tensor] = None, num_segment_frames: Optional[int] = None, return_token_timestamps: Optional[bool] = None, @@ -2144,10 +2144,10 @@ def generate( return outputs - condition_on_previous_tokens = ( - condition_on_previous_tokens - if condition_on_previous_tokens is not None - else getattr(self.generation_config, "condition_on_previous_tokens", False) + condition_on_prev_tokens = ( + condition_on_prev_tokens + if condition_on_prev_tokens is not None + else getattr(self.generation_config, "condition_on_prev_tokens", False) ) # 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated @@ -2235,7 +2235,7 @@ def generate( one_tensor = torch.ones((cur_bsz, 1), device=segment_input.device, dtype=torch.long) decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) - if condition_on_previous_tokens and len(current_segments[0]) > 0: + if condition_on_prev_tokens and len(current_segments[0]) > 0: # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 cut_off_length = self.config.max_target_positions // 2 - 1 active_segments = [current_segments[i] for i in new_cur_to_prev_index_map] @@ -2348,7 +2348,7 @@ def generate( current_segments[prev_i] += segments seek[prev_i] += segment_offset - print(seek) + # print(seek) # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output @@ -2425,14 +2425,14 @@ def _retrieve_segment( if single_timestamp_ending: # single timestamp at the end means no speech after the last timestamp. segment_offset = seek_num_frames[prev_idx] - print("single timestamp") + # print("single timestamp") else: # otherwise, ignore the unfinished segment and seek to the last timestamp # here we throw away all predictions after the last predicted "end of segment" # since we are cutting right in the middle of an audio last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin segment_offset = last_timestamp_pos * input_stride - print("cut") + # print("cut") else: # If whisper does not predict any "end of segment" token, then # the whole decoding is considered a segment and we add it to the list of segments @@ -2451,7 +2451,7 @@ def _retrieve_segment( } ] segment_offset = seek_num_frames[prev_idx] - print("all") + # print("all") return segments, segment_offset diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f77d81d76e52..063f1606075a 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1311,7 +1311,7 @@ def test_generate_with_prompt_ids_max_length(self): model.generate(input_features, max_new_tokens=1, prompt_ids=prompt_ids) - def test_longform_generate_single_batch(self): + def _check_longform_generate_single_batch(self, condition_on_prev_tokens): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() model = WhisperForConditionalGeneration(config).eval().to(torch_device) @@ -1355,16 +1355,12 @@ def test_longform_generate_single_batch(self): # make sure that we only have the same begin token model.generation_config.max_initial_timestamp_index = 0 - outputs = model.generate(long_input_features, logits_processor=logits_processor, return_segments=True) + outputs = model.generate(long_input_features, logits_processor=logits_processor, condition_on_prev_tokens=condition_on_prev_tokens, return_segments=True) segments = outputs["segments"][0] for i, segment in enumerate(segments): assert segment["start"] <= segment["end"], "start has to be smaller equal end" - assert ( - segment["tokens"][0] == model.generation_config.decoder_start_token_id - or segment["tokens"][0] >= timestamp_begin - ), "First segment token should be a timestamp token" assert any( s > timestamp_begin for s in segment["tokens"][1:] ), f"At least one segment token should be a timestamp token, but not first., {segment['tokens']}" @@ -1372,7 +1368,13 @@ def test_longform_generate_single_batch(self): segment["tokens"].shape[-1] <= max_length ), "make sure that no segment is larger than max generation length" - def test_longform_generate_multi_batch(self): + def test_longform_generate_single_batch(self): + self._check_longform_generate_single_batch(condition_on_prev_tokens=False) + + def test_longform_generate_single_batch_cond_prev(self): + self._check_longform_generate_single_batch(condition_on_prev_tokens=True) + + def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() model = WhisperForConditionalGeneration(config).eval().to(torch_device) @@ -1426,7 +1428,7 @@ def test_longform_generate_multi_batch(self): ) ] outputs = model.generate( - long_input_features, attention_mask=attention_mask, logits_processor=logits_processor, return_segments=True + long_input_features, attention_mask=attention_mask, logits_processor=logits_processor, return_segments=True, condition_on_prev_tokens=condition_on_prev_tokens, ) tokens = outputs["sequences"][1] segments = outputs["segments"][1] @@ -1438,6 +1440,11 @@ def test_longform_generate_multi_batch(self): assert seg1["end"] == seg2["end"] assert seg1["tokens"].tolist() == seg2["tokens"].tolist() + def test_longform_generate_multi_batch(self): + self._check_longform_generate_multi_batch(condition_on_prev_tokens=False) + + def test_longform_generate_multi_batch_cond_prev(self): + self._check_longform_generate_multi_batch(condition_on_prev_tokens=True) @require_torch @require_torchaudio From 30a78a32d9acfc5d7e41662dbf0c6c8656900890 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 Dec 2023 15:28:16 +0000 Subject: [PATCH 08/75] improve --- .../models/whisper/modeling_whisper.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 182bd79fb327..23dec2884f50 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2283,7 +2283,7 @@ def generate( ): kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 - # print(decoder_input_ids) + print("Init", decoder_input_ids[0, :6].tolist()) # 6.6 Batch generate current chunk seek_outputs = super().generate( @@ -2333,7 +2333,8 @@ def generate( num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum() seek_sequence = seek_sequence[:-num_paddings] - segments, segment_offset = self._retrieve_segment( + # TODO(Patrick: delete cut type) + segments, segment_offset, cut_type = self._retrieve_segment( seek_sequence=seek_sequence, seek_outputs=seek_outputs, time_offset=time_offset, @@ -2348,7 +2349,7 @@ def generate( current_segments[prev_i] += segments seek[prev_i] += segment_offset - # print(seek) + print(f"{cut_type} seek {seek[0]}") # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output @@ -2425,14 +2426,14 @@ def _retrieve_segment( if single_timestamp_ending: # single timestamp at the end means no speech after the last timestamp. segment_offset = seek_num_frames[prev_idx] - # print("single timestamp") + cut_type = "single ending" else: # otherwise, ignore the unfinished segment and seek to the last timestamp # here we throw away all predictions after the last predicted "end of segment" # since we are cutting right in the middle of an audio last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin segment_offset = last_timestamp_pos * input_stride - # print("cut") + cut_type = "cut" else: # If whisper does not predict any "end of segment" token, then # the whole decoding is considered a segment and we add it to the list of segments @@ -2451,9 +2452,9 @@ def _retrieve_segment( } ] segment_offset = seek_num_frames[prev_idx] - # print("all") + cut_type = "all" - return segments, segment_offset + return segments, segment_offset, cut_type def prepare_inputs_for_generation( self, From e3bff24f37defbb605b8885910d0e6931187e73c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 Dec 2023 16:52:35 +0100 Subject: [PATCH 09/75] improve further --- src/transformers/models/whisper/modeling_whisper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 23dec2884f50..79135c2468c7 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2313,8 +2313,9 @@ def generate( else: seek_sequences = seek_outputs - if decoder_input_ids is not None: - seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1] :] + if condition_on_prev_tokens is not None: + # remove all previously passed decoder input ids except start token + seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1] - 1:] # print("hf tokens", seek_sequences) # import ipdb; ipdb.set_trace() From 25c934524cf2e13b7b6b0235345e445cf23ad7f3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 Dec 2023 20:04:40 +0100 Subject: [PATCH 10/75] improve more --- src/transformers/generation/logits_process.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index c85b53b73956..aab7d505ded3 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1536,6 +1536,8 @@ def __init__( len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1 ) self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) + # TODO(Patrick, remove hardcoded setting) + self.max_initial_timestamp_index = 50 def set_begin_index(self, begin_index): self.begin_index = begin_index From cd7734bbe3aa4f38246e1cb14f201f1f29fa5724 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Dec 2023 15:54:30 +0100 Subject: [PATCH 11/75] improve --- .../models/whisper/modeling_whisper.py | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 79135c2468c7..c544a8ed5929 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1741,6 +1741,9 @@ def generate( language=None, is_multilingual=None, condition_on_prev_tokens: Optional[bool] = None, + temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1.0, prompt_ids: Optional[torch.Tensor] = None, num_segment_frames: Optional[int] = None, return_token_timestamps: Optional[bool] = None, @@ -1902,6 +1905,9 @@ def generate( ``` """ + temperature = 0.0 + compression_ratio_threshold = None + logprob_threshold = None if "inputs" in kwargs: input_features = kwargs.pop("inputs") @@ -2180,6 +2186,10 @@ def generate( # batch size can decrease during the run cur_bsz = prev_bsz = batch_size + temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature + return_scores = compression_ratio_threshold is not None or logprob_threshold is not None + return_dict_in_generate = return_dict_in_generate or return_scores + init_tokens = [self.generation_config.decoder_start_token_id] if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: i = 1 @@ -2286,17 +2296,23 @@ def generate( print("Init", decoder_input_ids[0, :6].tolist()) # 6.6 Batch generate current chunk - seek_outputs = super().generate( - segment_input, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - return_dict_in_generate=return_dict_in_generate, - decoder_input_ids=decoder_input_ids, - **kwargs, - ) + for temperature in temperatures: + do_sample = temperature > 0.0 + + seek_outputs = super().generate( + segment_input, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + temperature=temperature, + do_sample=do_sample, + return_scores=return_scores, + return_dict_in_generate=return_dict_in_generate, + decoder_input_ids=decoder_input_ids, + **kwargs, + ) if return_token_timestamps and hasattr(generation_config, "alignment_heads"): num_frames = getattr(generation_config, "num_frames", None) @@ -2313,10 +2329,17 @@ def generate( else: seek_sequences = seek_outputs + if compression_ratio_threshold is not None: + pass + if condition_on_prev_tokens is not None: # remove all previously passed decoder input ids except start token seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1] - 1:] + if return_scores: + scores = seek_outputs["scores"] if return_scores else None + logprops = self._retrieve_logprobs(scores, seek_sequences) + # print("hf tokens", seek_sequences) # import ipdb; ipdb.set_trace() From 3cf1752053badc1f3d7cf7eb92805afa9fde2161 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Dec 2023 14:54:47 +0000 Subject: [PATCH 12/75] fix more --- src/transformers/models/whisper/modeling_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 79135c2468c7..13a1ccfd01d2 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2183,8 +2183,8 @@ def generate( init_tokens = [self.generation_config.decoder_start_token_id] if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: i = 1 - while len(forced_decoder_ids) > i and forced_decoder_ids[0][0] == i: - init_tokens += [forced_decoder_ids[i - 1][1]] + while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: + init_tokens += [forced_decoder_ids[0][1]] forced_decoder_ids = forced_decoder_ids[1:] i += 1 From 133d17e0aa6e39ee92ab0c70083051e155199989 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Dec 2023 18:00:20 +0000 Subject: [PATCH 13/75] git commit and git push --- .../models/whisper/modeling_whisper.py | 105 ++++++++++++------ 1 file changed, 71 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8d0c1ac61da9..cbbaff480b8e 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1742,7 +1742,7 @@ def generate( is_multilingual=None, condition_on_prev_tokens: Optional[bool] = None, temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - compression_ratio_threshold: Optional[float] = 2.4, + compression_ratio_threshold: Optional[float] = 2.0, logprob_threshold: Optional[float] = -1.0, prompt_ids: Optional[torch.Tensor] = None, num_segment_frames: Optional[int] = None, @@ -1905,10 +1905,6 @@ def generate( ``` """ - temperature = 0.0 - compression_ratio_threshold = None - logprob_threshold = None - if "inputs" in kwargs: input_features = kwargs.pop("inputs") warnings.warn( @@ -2187,8 +2183,10 @@ def generate( cur_bsz = prev_bsz = batch_size temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature - return_scores = compression_ratio_threshold is not None or logprob_threshold is not None - return_dict_in_generate = return_dict_in_generate or return_scores + temperature = temperatures[0] + + output_scores = logprob_threshold is not None + return_dict_in_generate = return_dict_in_generate or output_scores init_tokens = [self.generation_config.decoder_start_token_id] if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: @@ -2245,7 +2243,7 @@ def generate( one_tensor = torch.ones((cur_bsz, 1), device=segment_input.device, dtype=torch.long) decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) - if condition_on_prev_tokens and len(current_segments[0]) > 0: + if condition_on_prev_tokens and len(current_segments[0]) > 0 and temperature < 0.5: # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 cut_off_length = self.config.max_target_positions // 2 - 1 active_segments = [current_segments[i] for i in new_cur_to_prev_index_map] @@ -2268,14 +2266,14 @@ def generate( # Make sure we don't get larger than `max_length` if passed_max_length is not None and passed_max_new_tokens is None: - kwargs["max_length"] = max( + kwargs["max_length"] = min( kwargs["max_length"] + cut_off_length + 1, self.config.max_target_positions ) logger.info( f"Increase max_length from {passed_max_length} to {kwargs['max_length']} since input is conditioned on previous segment." ) elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: - kwargs["max_length"] = max( + kwargs["max_length"] = min( self.generation_config.max_length + cut_off_length + 1, self.config.max_target_positions ) logger.info( @@ -2308,40 +2306,59 @@ def generate( synced_gpus, temperature=temperature, do_sample=do_sample, - return_scores=return_scores, + output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, decoder_input_ids=decoder_input_ids, **kwargs, ) - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - seek_outputs["token_timestamps"] = self._extract_token_timestamps( - seek_outputs, generation_config.alignment_heads, num_frames=num_frames - ) + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + num_frames = getattr(generation_config, "num_frames", None) + seek_outputs["token_timestamps"] = self._extract_token_timestamps( + seek_outputs, generation_config.alignment_heads, num_frames=num_frames + ) - if return_dict_in_generate: - seek_sequences = seek_outputs["sequences"] - seek_outputs = [ - {k: v[i] for k, v in seek_outputs.items()} - for i in range(next(iter(seek_outputs.values())).size(0)) - ] - else: - seek_sequences = seek_outputs + if return_dict_in_generate: + def split_by_batch_index(values, key, batch_idx): + if key == "scores": + return list(v[batch_idx] for v in values) + return values[batch_idx] + + seek_sequences = seek_outputs["sequences"] + seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(cur_bsz)] + else: + seek_sequences = seek_outputs + + if condition_on_prev_tokens is not None: + # remove all previously passed decoder input ids except start token + seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1] - 1:] - if compression_ratio_threshold is not None: - pass + needs_fallback = False + if compression_ratio_threshold is not None: + compression_ratio = [seek_sequence.shape[0] / torch.unique(seek_sequence).shape[0] for seek_sequence in seek_sequences] - if condition_on_prev_tokens is not None: - # remove all previously passed decoder input ids except start token - seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1] - 1:] + #if seek.item() > 420000: + # import ipdb; ipdb.set_trace() - if return_scores: - scores = seek_outputs["scores"] if return_scores else None - logprops = self._retrieve_logprobs(scores, seek_sequences) + # TODO(PVP) only works for batch size = 1 currently + if compression_ratio[0] > compression_ratio_threshold: + print("fallback compression") + print("current temp", temperature) + + needs_fallback = True - # print("hf tokens", seek_sequences) - # import ipdb; ipdb.set_trace() + if logprob_threshold is not None: + scores = [s["scores"] for s in seek_outputs] if output_scores else None + logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id) + + # TODO(PVP) only works for batch size = 1 currently + if logprobs[0] < logprob_threshold: + print("fallback logprob") + print("current temp", temperature) + needs_fallback = True + + if not needs_fallback: + break # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): @@ -2404,6 +2421,26 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right"): sequences = torch.stack(sequences, dim=0) return sequences + @staticmethod + def _retrieve_avg_logprobs(scores, tokens, eos_token_id): + scores = torch.stack([torch.stack(score) for score in scores]) + logprobs = F.log_softmax(scores.float(), dim=-1).to(scores.dtype) + tokens = tokens[:, -scores.shape[1]:] + + def get_log_prob(logprob, token): + token_logprob = logprob.gather(-1, token) * (token[:, -1] != eos_token_id) + return token_logprob + + sum_logprobs = sum(get_log_prob(logprobs[:, i], tokens[:, i: i+1]) for i in range(logprobs.shape[1])) + + lengths = (tokens != eos_token_id).sum(-1) + + avg_logprobs = torch.div(sum_logprobs, lengths) + return avg_logprobs + + + # print("hf tokens", seek_sequences) + @staticmethod def _retrieve_segment( seek_sequence, From e86745a7520a89a06dde4e6fcd57fcec84b099f9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Dec 2023 18:00:44 +0000 Subject: [PATCH 14/75] fix more --- src/transformers/generation/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index aafd46c64cfa..24568f20ba1e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2624,7 +2624,10 @@ def greedy_search( if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits[:, -1, :] + try: + next_token_logits = outputs.logits[:, -1, :] + except: + import ipdb; ipdb.set_trace() # pre-process distribution next_tokens_scores = logits_processor(input_ids, next_token_logits) From cbae58c28e8ef06633ae1947856ed23e640b1c54 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Dec 2023 18:01:12 +0000 Subject: [PATCH 15/75] fix more --- src/transformers/generation/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 24568f20ba1e..aafd46c64cfa 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2624,10 +2624,7 @@ def greedy_search( if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need - try: - next_token_logits = outputs.logits[:, -1, :] - except: - import ipdb; ipdb.set_trace() + next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_tokens_scores = logits_processor(input_ids, next_token_logits) From 3933896d1ac747e9c3d405e49bb71678b0bc8859 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Dec 2023 18:56:01 +0000 Subject: [PATCH 16/75] fix more --- src/transformers/models/whisper/modeling_whisper.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index cbbaff480b8e..36ae28d381c4 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2321,8 +2321,11 @@ def generate( if return_dict_in_generate: def split_by_batch_index(values, key, batch_idx): if key == "scores": - return list(v[batch_idx] for v in values) - return values[batch_idx] + return list(v[batch_idx].cpu() for v in values) + if key == "past_key_values": + # we don't save `past_key_values` as this is too costly + return None + return values[batch_idx].cpu() seek_sequences = seek_outputs["sequences"] seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(cur_bsz)] @@ -2423,7 +2426,7 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right"): @staticmethod def _retrieve_avg_logprobs(scores, tokens, eos_token_id): - scores = torch.stack([torch.stack(score) for score in scores]) + scores = torch.stack([torch.stack(score) for score in scores]).to(tokens.device) logprobs = F.log_softmax(scores.float(), dim=-1).to(scores.dtype) tokens = tokens[:, -scores.shape[1]:] From 0e413e7963bb7de09611c59bd58d926b78a72f20 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 5 Dec 2023 13:24:21 +0000 Subject: [PATCH 17/75] New try --- src/transformers/generation/logits_process.py | 12 ++- .../models/whisper/modeling_whisper.py | 81 ++++++++++++------- 2 files changed, 59 insertions(+), 34 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index aab7d505ded3..bcb41fe0905c 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1424,6 +1424,9 @@ def __init__(self, begin_suppress_tokens, begin_index): self.begin_suppress_tokens = list(begin_suppress_tokens) self.begin_index = begin_index + def set_begin_index(self, begin_index): + self.begin_index = begin_index + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if input_ids.shape[1] == self.begin_index: @@ -1519,7 +1522,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): """ def __init__( - self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None + self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None ): # support for the kwargs self.eos_token_id = generate_config.eos_token_id self.no_timestamps_token_id = generate_config.no_timestamps_token_id @@ -1532,9 +1535,9 @@ def __init__( else getattr(generate_config, "_detect_timestamp_from_logprob", True) ) - self.begin_index = ( - len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1 - ) + num_forced_ids = len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0 + self.begin_index = begin_index or (num_forced_ids + 1) + self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) # TODO(Patrick, remove hardcoded setting) self.max_initial_timestamp_index = 50 @@ -1575,6 +1578,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to # apply the `max_initial_timestamp` option if input_ids.shape[1] == self.begin_index: + print("HF Sample begin", self.begin_index) scores[:, : self.timestamp_begin] = -float("inf") if self.max_initial_timestamp_index is not None: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 36ae28d381c4..38d7fc649019 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union import numpy as np +import copy import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -25,7 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.logits_process import WhisperTimeStampLogitsProcessor +from ...generation.logits_process import WhisperTimeStampLogitsProcessor, SuppressTokensAtBeginLogitsProcessor from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1742,8 +1743,11 @@ def generate( is_multilingual=None, condition_on_prev_tokens: Optional[bool] = None, temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - compression_ratio_threshold: Optional[float] = 2.0, + compression_ratio_threshold: Optional[float] = 2.4, logprob_threshold: Optional[float] = -1.0, + # temperature: Union[float, Tuple[float, ...]] = 0.0, + # compression_ratio_threshold: Optional[float] = None, + # logprob_threshold: Optional[float] = None, prompt_ids: Optional[torch.Tensor] = None, num_segment_frames: Optional[int] = None, return_token_timestamps: Optional[bool] = None, @@ -1912,15 +1916,15 @@ def generate( FutureWarning, ) + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate + else generation_config.return_dict_in_generate ) - if generation_config is None: - generation_config = self.generation_config - input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] if num_segment_frames is None: num_segment_frames = input_stride * self.config.max_source_positions @@ -2011,10 +2015,10 @@ def generate( if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: forced_decoder_ids = self.config.forced_decoder_ids elif ( - hasattr(self.generation_config, "forced_decoder_ids") - and self.generation_config.forced_decoder_ids is not None + hasattr(generation_config, "forced_decoder_ids") + and generation_config.forced_decoder_ids is not None ): - forced_decoder_ids = self.generation_config.forced_decoder_ids + forced_decoder_ids = generation_config.forced_decoder_ids else: forced_decoder_ids = kwargs.get("forced_decoder_ids", None) @@ -2113,7 +2117,7 @@ def generate( if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None else None ) - if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id: + if last_forced_decoder_ids == generation_config.no_timestamps_token_id: # remove no_timestamp to be forcefully generated if we want to return timestamps # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly forced_decoder_ids = forced_decoder_ids[:-1] if len(forced_decoder_ids) > 1 else None @@ -2125,6 +2129,13 @@ def generate( timestamp_processor if logits_processor is None else timestamp_processor + logits_processor ) + if hasattr(generation_config, "begin_suppress_tokens") and generation_config.begin_suppress_tokens is not None: + begin_suppress_processor = [SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, 1)] + logits_processor = ( + timestamp_processor if logits_processor is None else begin_suppress_processor + logits_processor + ) + generation_config.begin_suppress_tokens = None + # 5. If we're in shortform mode, simple generate the whole input at once and return the output if is_shortform: outputs = super().generate( @@ -2149,7 +2160,7 @@ def generate( condition_on_prev_tokens = ( condition_on_prev_tokens if condition_on_prev_tokens is not None - else getattr(self.generation_config, "condition_on_prev_tokens", False) + else getattr(generation_config, "condition_on_prev_tokens", False) ) # 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated @@ -2161,7 +2172,7 @@ def generate( ) # if input is longer than 30 seconds we default to long-form generation - timestamp_begin = self.generation_config.no_timestamps_token_id + 1 + timestamp_begin = generation_config.no_timestamps_token_id + 1 # input stride is mel frames per encoder output vector which is the product of all conv strides batch_size = input_features.shape[0] @@ -2188,7 +2199,7 @@ def generate( output_scores = logprob_threshold is not None return_dict_in_generate = return_dict_in_generate or output_scores - init_tokens = [self.generation_config.decoder_start_token_id] + init_tokens = [generation_config.decoder_start_token_id] if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: i = 1 while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: @@ -2248,21 +2259,16 @@ def generate( cut_off_length = self.config.max_target_positions // 2 - 1 active_segments = [current_segments[i] for i in new_cur_to_prev_index_map] prev_tokens = self._pad_to_max_length( - active_segments, self.generation_config.pad_token_id, padding="left" + active_segments, generation_config.pad_token_id, padding="left" ) - prev_start_of_text = self.generation_config.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config + prev_start_of_text = generation_config.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config decoder_input_ids = torch.cat([prev_start_of_text * one_tensor, prev_tokens[:, -cut_off_length:], decoder_input_ids], dim=-1) - timestamp_processor = [ - proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor) - ][0] - timestamp_processor.set_begin_index(decoder_input_ids.shape[-1]) - passed_max_length = kwargs.get("max_length", None) passed_max_new_tokens = kwargs.get("max_new_tokens", None) - max_length_config = getattr(self.generation_config, "max_length", None) - max_new_tokens_config = getattr(self.generation_config, "max_new_tokens", None) + max_length_config = getattr(generation_config, "max_length", None) + max_new_tokens_config = getattr(generation_config, "max_new_tokens", None) # Make sure we don't get larger than `max_length` if passed_max_length is not None and passed_max_new_tokens is None: @@ -2274,7 +2280,7 @@ def generate( ) elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: kwargs["max_length"] = min( - self.generation_config.max_length + cut_off_length + 1, self.config.max_target_positions + generation_config.max_length + cut_off_length + 1, self.config.max_target_positions ) logger.info( f"Increase max_length from {max_length_config} to {kwargs['max_length']} since input is conditioned on previous segment." @@ -2291,7 +2297,16 @@ def generate( ): kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 - print("Init", decoder_input_ids[0, :6].tolist()) + timestamp_processor = [ + proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor) + ][0] + timestamp_processor.set_begin_index(decoder_input_ids.shape[-1]) + begin_suppress_processor = [ + proc for proc in logits_processor if isinstance(proc, SuppressTokensAtBeginLogitsProcessor) + ][0] + begin_suppress_processor.set_begin_index(decoder_input_ids.shape[-1]) + + print("hf in tokens", decoder_input_ids[0].tolist()) # 6.6 Batch generate current chunk for temperature in temperatures: @@ -2334,7 +2349,10 @@ def split_by_batch_index(values, key, batch_idx): if condition_on_prev_tokens is not None: # remove all previously passed decoder input ids except start token - seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1] - 1:] + seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1]:] + else: + # cut BOS token + seek_sequences = seek_sequences[:, 1:] needs_fallback = False if compression_ratio_threshold is not None: @@ -2369,12 +2387,14 @@ def split_by_batch_index(values, key, batch_idx): # make sure we cut a predicted EOS token if we are not finished with the generation yet is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] - if is_not_final and seek_sequence[-1] == self.generation_config.eos_token_id: + if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: seek_sequence = seek_sequence[:-1] + print("hf out tokens", seek_sequence.tolist()) + # remove all padding tokens - if seek_sequence[-1] == self.generation_config.pad_token_id: - num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum() + if seek_sequence[-1] == generation_config.pad_token_id: + num_paddings = (seek_sequence == generation_config.pad_token_id).sum() seek_sequence = seek_sequence[:-num_paddings] # TODO(Patrick: delete cut type) @@ -2397,7 +2417,7 @@ def split_by_batch_index(values, key, batch_idx): # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output - sequences = self._pad_to_max_length(current_segments, self.generation_config.pad_token_id, padding="right") + sequences = self._pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right") # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. if return_segments: @@ -2461,6 +2481,7 @@ def _retrieve_segment( timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin) single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + timestamp_segment_indices.add_(1) # If whisper predicted a "end of segment" via a timestep token, let's go ever each # "end of segment" prediction and slice the decoding into segments accordingly @@ -2474,7 +2495,7 @@ def _retrieve_segment( last_slice = 0 # Add each segment to list of all segments for current_slice in slices: - sliced_tokens = seek_sequence[last_slice + 1 : current_slice + 1] + sliced_tokens = seek_sequence[last_slice : current_slice] start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin segments.append( From 8411a9e4570531d743414e905eb996fc7c166205 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 5 Dec 2023 17:23:22 +0000 Subject: [PATCH 18/75] Fix more whisper stuff --- src/transformers/generation/logits_process.py | 24 +++++++ src/transformers/generation/utils.py | 1 + .../models/whisper/modeling_whisper.py | 71 +++++++++++++------ 3 files changed, 76 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index bcb41fe0905c..ca33f27a40ba 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1596,6 +1596,30 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores +class WhisperNoSpeechDetection(LogitsProcessor): + r"""This processor can be used to detect silence when using Whisper.""" + + def __init__(self, no_speech_token: int, begin_index: int): + self.no_speech_token = no_speech_token + self.begin_index = begin_index + self._no_speech_prob = [0.0] + + @property + def no_speech_prob(self): + return self._no_speech_prob + + def set_begin_index(self, begin_index): + self.begin_index = begin_index + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if input_ids.shape[1] == self.begin_index: + probs = scores.float().softmax(dim=-1) + self._no_speech_prob = probs[:, self.no_speech_token] + + return scores + + class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): r"""Logits processor for classifier free guidance (CFG). The scores are split over the batch dimension, where the first half correspond to the conditional logits (predicted from the input prompt) and the second half diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index aafd46c64cfa..8e3c09c54ff9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1106,6 +1106,7 @@ def _get_logits_processor( # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: processors.append(LogitNormalization()) + return processors def _get_stopping_criteria( diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 38d7fc649019..0fb94e1af4af 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.logits_process import WhisperTimeStampLogitsProcessor, SuppressTokensAtBeginLogitsProcessor +from ...generation.logits_process import WhisperTimeStampLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, WhisperNoSpeechDetection, SuppressTokensLogitsProcessor from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1745,6 +1745,7 @@ def generate( temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold: Optional[float] = 2.4, logprob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = None, # temperature: Union[float, Tuple[float, ...]] = 0.0, # compression_ratio_threshold: Optional[float] = None, # logprob_threshold: Optional[float] = None, @@ -2090,7 +2091,7 @@ def generate( forced_decoder_ids = [ *text_prompt_ids, generation_config.decoder_start_token_id, - *[token for _rank, token in non_prompt_forced_decoder_ids], + *[token for _, token in non_prompt_forced_decoder_ids], ] forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] generation_config.forced_decoder_ids = forced_decoder_ids @@ -2110,6 +2111,7 @@ def generate( if kwargs.get("num_frames") is not None: generation_config.num_frames = kwargs.pop("num_frames") + begin_index = 1 if generation_config.return_timestamps is True: forced_decoder_ids = generation_config.forced_decoder_ids last_forced_decoder_ids = ( @@ -2124,18 +2126,33 @@ def generate( # Make sure that if list is empty we set it to None generation_config.forced_decoder_ids = forced_decoder_ids - timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)] + begin_index = begin_index + len(forced_decoder_ids) if forced_decoder_ids is not None else begin_index + + timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) + logits_processor = ( + [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor + ) + + if generation_config.suppress_tokens is not None: + suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens) logits_processor = ( - timestamp_processor if logits_processor is None else timestamp_processor + logits_processor + [suppress_tokens_processor] if logits_processor is None else [suppress_tokens_processor] + logits_processor ) + generation_config.suppress_tokens = None - if hasattr(generation_config, "begin_suppress_tokens") and generation_config.begin_suppress_tokens is not None: - begin_suppress_processor = [SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, 1)] + if generation_config.begin_suppress_tokens is not None: + begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index=begin_index) logits_processor = ( - timestamp_processor if logits_processor is None else begin_suppress_processor + logits_processor + [begin_suppress_processor] if logits_processor is None else [begin_suppress_processor] + logits_processor ) generation_config.begin_suppress_tokens = None + if no_speech_threshold is not None: + no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index) + logits_processor = ( + [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor + ) + # 5. If we're in shortform mode, simple generate the whole input at once and return the output if is_shortform: outputs = super().generate( @@ -2262,7 +2279,7 @@ def generate( active_segments, generation_config.pad_token_id, padding="left" ) - prev_start_of_text = generation_config.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config + prev_start_of_text = suppress_tokens_processor.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config decoder_input_ids = torch.cat([prev_start_of_text * one_tensor, prev_tokens[:, -cut_off_length:], decoder_input_ids], dim=-1) passed_max_length = kwargs.get("max_length", None) @@ -2297,21 +2314,22 @@ def generate( ): kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 - timestamp_processor = [ - proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor) - ][0] timestamp_processor.set_begin_index(decoder_input_ids.shape[-1]) - begin_suppress_processor = [ - proc for proc in logits_processor if isinstance(proc, SuppressTokensAtBeginLogitsProcessor) - ][0] begin_suppress_processor.set_begin_index(decoder_input_ids.shape[-1]) + if no_speech_threshold is not None: + no_speech_detector.set_begin_index(decoder_input_ids.shape[-1]) + print("hf in tokens", decoder_input_ids[0].tolist()) # 6.6 Batch generate current chunk + should_skip = False for temperature in temperatures: do_sample = temperature > 0.0 + num_beams = kwargs.pop("num_beams", 1) + generation_config.num_beams = num_beams if not do_sample else 1 + seek_outputs = super().generate( segment_input, generation_config, @@ -2355,6 +2373,7 @@ def split_by_batch_index(values, key, batch_idx): seek_sequences = seek_sequences[:, 1:] needs_fallback = False + if compression_ratio_threshold is not None: compression_ratio = [seek_sequence.shape[0] / torch.unique(seek_sequence).shape[0] for seek_sequence in seek_sequences] @@ -2369,14 +2388,24 @@ def split_by_batch_index(values, key, batch_idx): needs_fallback = True if logprob_threshold is not None: - scores = [s["scores"] for s in seek_outputs] if output_scores else None - logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id) + if "sequences_scores" in seek_outputs[0]: + logprobs = [s["sequences_scores"] for s in seek_outputs] + else: + scores = [s["scores"] for s in seek_outputs] + logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id) # TODO(PVP) only works for batch size = 1 currently if logprobs[0] < logprob_threshold: print("fallback logprob") print("current temp", temperature) needs_fallback = True + + if no_speech_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + # Need to do before all other logit processors + if no_speech_detector.no_speech_prob[0].cpu() > no_speech_threshold and logprobs[0] < logprob_threshold: + needs_fallback = False + should_skip = True if not needs_fallback: break @@ -2385,6 +2414,11 @@ def split_by_batch_index(values, key, batch_idx): for i, seek_sequence in enumerate(seek_sequences): prev_i = cur_to_prev_index_map[i] + if should_skip: + seek[prev_i] += seek_num_frames[prev_i] + print("Skipped!") + continue + # make sure we cut a predicted EOS token if we are not finished with the generation yet is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: @@ -2458,12 +2492,9 @@ def get_log_prob(logprob, token): lengths = (tokens != eos_token_id).sum(-1) - avg_logprobs = torch.div(sum_logprobs, lengths) + avg_logprobs = torch.div(sum_logprobs, lengths + 1) return avg_logprobs - - # print("hf tokens", seek_sequences) - @staticmethod def _retrieve_segment( seek_sequence, From 10cfdc6f6b46d79c071fb0d6bbb1176274854836 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 5 Dec 2023 23:38:05 +0000 Subject: [PATCH 19/75] Improve --- src/transformers/models/whisper/modeling_whisper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 255cfaeffd89..bf64761e1a82 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2320,15 +2320,15 @@ def generate( ) elif ( passed_max_new_tokens is not None - and passed_max_new_tokens + cut_off_length + 2 > self.config.max_target_positions + and passed_max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions ): - kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 + kwargs["max_new_tokens"] = self.config.max_target_positions - decoder_input_ids.shape[-1] elif ( passed_max_new_tokens is None and max_new_tokens_config is not None - and max_new_tokens_config + cut_off_length + 2 > self.config.max_target_positions + and max_new_tokens_config + decoder_input_ids.shape[-1] > self.config.max_target_positions ): - kwargs["max_new_tokens"] = self.config.max_target_positions - cut_off_length - 2 + kwargs["max_new_tokens"] = self.config.max_target_positions - decoder_input_ids.shape[-1] timestamp_processor.set_begin_index(decoder_input_ids.shape[-1]) begin_suppress_processor.set_begin_index(decoder_input_ids.shape[-1]) From b0897c7575a88804ae469adf56fec30de44ce65c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 Dec 2023 11:40:08 +0000 Subject: [PATCH 20/75] correct more --- src/transformers/models/whisper/modeling_whisper.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index bf64761e1a82..3b8c707d6f65 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2381,21 +2381,14 @@ def split_by_batch_index(values, key, batch_idx): else: seek_sequences = seek_outputs - if condition_on_prev_tokens is not None: - # remove all previously passed decoder input ids except start token - seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1]:] - else: - # cut BOS token - seek_sequences = seek_sequences[:, 1:] + # remove all previously passed decoder input ids + seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1]:] needs_fallback = False if compression_ratio_threshold is not None: compression_ratio = [seek_sequence.shape[0] / torch.unique(seek_sequence).shape[0] for seek_sequence in seek_sequences] - #if seek.item() > 420000: - # import ipdb; ipdb.set_trace() - # TODO(PVP) only works for batch size = 1 currently if compression_ratio[0] > compression_ratio_threshold: print("fallback compression") @@ -2563,7 +2556,7 @@ def _retrieve_segment( # otherwise, ignore the unfinished segment and seek to the last timestamp # here we throw away all predictions after the last predicted "end of segment" # since we are cutting right in the middle of an audio - last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin + last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin segment_offset = last_timestamp_pos * input_stride cut_type = "cut" else: From 24fa4632f0696a2d7cc871672dc14c7834b02a22 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 Dec 2023 11:48:36 +0000 Subject: [PATCH 21/75] correct more --- src/transformers/models/whisper/modeling_whisper.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 3b8c707d6f65..6a720850a43f 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1758,13 +1758,10 @@ def generate( language=None, is_multilingual=None, condition_on_prev_tokens: Optional[bool] = None, - temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - compression_ratio_threshold: Optional[float] = 2.4, - logprob_threshold: Optional[float] = -1.0, no_speech_threshold: Optional[float] = None, - # temperature: Union[float, Tuple[float, ...]] = 0.0, - # compression_ratio_threshold: Optional[float] = None, - # logprob_threshold: Optional[float] = None, + temperature: Union[float, Tuple[float, ...]] = 0.0, + compression_ratio_threshold: Optional[float] = None, + logprob_threshold: Optional[float] = None, prompt_ids: Optional[torch.Tensor] = None, num_segment_frames: Optional[int] = None, return_token_timestamps: Optional[bool] = None, From e0b7af3981f1d60d2dd49d5fe62ba8ea2ac92c59 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 Dec 2023 11:59:58 +0000 Subject: [PATCH 22/75] correct more --- src/transformers/models/whisper/modeling_whisper.py | 11 ++++------- tests/models/whisper/test_modeling_whisper.py | 2 ++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 6a720850a43f..f3d8d36e573c 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2291,8 +2291,7 @@ def generate( prev_tokens = self._pad_to_max_length( active_segments, generation_config.pad_token_id, padding="left" ) - - prev_start_of_text = suppress_tokens_processor.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config + prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or suppress_tokens_processor.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config decoder_input_ids = torch.cat([prev_start_of_text * one_tensor, prev_tokens[:, -cut_off_length:], decoder_input_ids], dim=-1) passed_max_length = kwargs.get("max_length", None) @@ -2327,11 +2326,9 @@ def generate( ): kwargs["max_new_tokens"] = self.config.max_target_positions - decoder_input_ids.shape[-1] - timestamp_processor.set_begin_index(decoder_input_ids.shape[-1]) - begin_suppress_processor.set_begin_index(decoder_input_ids.shape[-1]) - - if no_speech_threshold is not None: - no_speech_detector.set_begin_index(decoder_input_ids.shape[-1]) + for proc in logits_processor: + if hasattr(proc, "set_begin_index"): + proc.set_begin_index(decoder_input_ids.shape[-1]) print("hf in tokens", decoder_input_ids[0].tolist()) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index a53731ccdb57..3ac922028de6 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1355,6 +1355,7 @@ def _check_longform_generate_single_batch(self, condition_on_prev_tokens): model.generation_config._detect_timestamp_from_logprob = False # make sure that we only have the same begin token model.generation_config.max_initial_timestamp_index = 0 + model.generation_config.prev_bos_token_id = timestamp_begin - 3 outputs = model.generate(long_input_features, logits_processor=logits_processor, condition_on_prev_tokens=condition_on_prev_tokens, return_segments=True) @@ -1402,6 +1403,7 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): model.generation_config._detect_timestamp_from_logprob = False # make sure that we only have the same begin token model.generation_config.max_initial_timestamp_index = 0 + model.generation_config.prev_bos_token_id = timestamp_begin - 3 logits_processor = [ DummyTimestampLogitProcessor( From 5a67f752d8a334359a83af37d1e8fde40f9919b9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 Dec 2023 15:37:43 +0000 Subject: [PATCH 23/75] Fix some tests --- src/transformers/generation/logits_process.py | 12 +++++++++--- tests/models/whisper/test_modeling_whisper.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index cb2080e76beb..0aa6ac026b5e 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1524,9 +1524,9 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): def __init__( self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None ): # support for the kwargs - self.eos_token_id = generate_config.eos_token_id self.no_timestamps_token_id = generate_config.no_timestamps_token_id self.timestamp_begin = generate_config.no_timestamps_token_id + 1 + self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id # this variable is mostly just used for testing self._detect_timestamp_from_logprob = ( @@ -1539,8 +1539,8 @@ def __init__( self.begin_index = begin_index or (num_forced_ids + 1) self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) - # TODO(Patrick, remove hardcoded setting) - self.max_initial_timestamp_index = 50 + # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50 + # self.max_initial_timestamp_index = 50 def set_begin_index(self, begin_index): self.begin_index = begin_index @@ -1551,6 +1551,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to scores[:, self.no_timestamps_token_id] = -float("inf") # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly + no_timestamps = False for k in range(input_ids.shape[0]): sampled_tokens = input_ids[k, self.begin_index :] seq = list(sampled_tokens.tolist()) @@ -1561,6 +1562,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if last_was_timestamp: if penultimate_was_timestamp: # has to be non-timestamp scores[k, self.timestamp_begin :] = -float("inf") + no_timestamps = True else: # cannot be normal text tokens scores[k, : self.eos_token_id] = -float("inf") @@ -1591,8 +1593,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1) max_text_token_logprob = logprobs[k, : self.timestamp_begin].max() if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob: + import ipdb; ipdb.set_trace() scores[k, : self.timestamp_begin] = -float("inf") + if torch.isinf(scores).all(): + import ipdb; ipdb.set_trace() + return scores diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 3ac922028de6..8c0c856f52d9 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1361,7 +1361,7 @@ def _check_longform_generate_single_batch(self, condition_on_prev_tokens): segments = outputs["segments"][0] - for i, segment in enumerate(segments): + for _, segment in enumerate(segments): assert segment["start"] <= segment["end"], "start has to be smaller equal end" assert any( s > timestamp_begin for s in segment["tokens"][1:] From 404d5422c96b219a1cd66df04f1ebe987e9301b4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 Dec 2023 19:56:15 +0000 Subject: [PATCH 24/75] Add more tests --- src/transformers/generation/logits_process.py | 1 - .../models/whisper/modeling_whisper.py | 3 +- tests/models/whisper/test_modeling_whisper.py | 186 ++++++++++++++++-- 3 files changed, 174 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 0aa6ac026b5e..b361e95e6f4c 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1593,7 +1593,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1) max_text_token_logprob = logprobs[k, : self.timestamp_begin].max() if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob: - import ipdb; ipdb.set_trace() scores[k, : self.timestamp_begin] = -float("inf") if torch.isinf(scores).all(): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f3d8d36e573c..c0877ed96856 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2379,7 +2379,6 @@ def split_by_batch_index(values, key, batch_idx): seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1]:] needs_fallback = False - if compression_ratio_threshold is not None: compression_ratio = [seek_sequence.shape[0] / torch.unique(seek_sequence).shape[0] for seek_sequence in seek_sequences] @@ -2488,7 +2487,7 @@ def _retrieve_avg_logprobs(scores, tokens, eos_token_id): tokens = tokens[:, -scores.shape[1]:] def get_log_prob(logprob, token): - token_logprob = logprob.gather(-1, token) * (token[:, -1] != eos_token_id) + token_logprob = logprob.gather(-1, token)[:, 0] * (token[:, -1] != eos_token_id) return token_logprob sum_logprobs = sum(get_log_prob(logprobs[:, i], tokens[:, i: i+1]) for i in range(logprobs.shape[1])) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 8c0c856f52d9..e2cd45205982 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1357,7 +1357,19 @@ def _check_longform_generate_single_batch(self, condition_on_prev_tokens): model.generation_config.max_initial_timestamp_index = 0 model.generation_config.prev_bos_token_id = timestamp_begin - 3 - outputs = model.generate(long_input_features, logits_processor=logits_processor, condition_on_prev_tokens=condition_on_prev_tokens, return_segments=True) + gen_kwargs = { + "logits_processor": logits_processor, + "return_segments": True, + "condition_on_prev_tokens": condition_on_prev_tokens + } + + if condition_on_prev_tokens: + gen_kwargs["no_speech_threshold"] = 0.6 + gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) + gen_kwargs["compression_ratio_threshold"] = 2.4 + gen_kwargs["logprob_threshold"] = -1.0 + + outputs = model.generate(long_input_features, **gen_kwargs) segments = outputs["segments"][0] @@ -1430,9 +1442,20 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): seed=0, ) ] - outputs = model.generate( - long_input_features, attention_mask=attention_mask, logits_processor=logits_processor, return_segments=True, condition_on_prev_tokens=condition_on_prev_tokens, - ) + gen_kwargs = { + "logits_processor": logits_processor, + "return_segments": True, + "condition_on_prev_tokens": condition_on_prev_tokens, + "attention_mask": attention_mask, + } + + if condition_on_prev_tokens: + gen_kwargs["no_speech_threshold"] = 0.6 + gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) + gen_kwargs["compression_ratio_threshold"] = 2.4 + gen_kwargs["logprob_threshold"] = -1.0 + + outputs = model.generate(long_input_features, **gen_kwargs) tokens = outputs["sequences"][1] segments = outputs["segments"][1] @@ -2065,12 +2088,46 @@ def test_whisper_longform_single_batch(self): assert decoded == EXPECTED_TEXT + @slow + def test_whisper_longform_single_batch_prev_cond(self): + # fmt: off + EXPECTED_TEXT = [""" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite itals are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. When Mr. John Collier gives his sitter a cheerful slap in the back, before he says like a shampooer and a Turkish bath, next man it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. He tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in felicitous grace that many faces are feeling. Unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M. A. A man said to the universe, Sir, I exist. Sweat covered Breon's body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retroveilities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you're being a fool. But there was silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Your man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Breon's death was in some ways easier than defeat. Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that's rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggido long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confessed Shaggy. True, agreed Calico. Calico went to the big gong, and pounded on it, just as we're good to be used to do. But no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong, and then sat in the throne, wearing Regidos discarded Ruby Crown, and holding in his hand to scepter, which Regidos had so often thrown at his head."""] + # fmt: on + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + model = model.to("cuda") + + ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean") + one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32) + + input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[ + "input_features" + ] + input_features = input_features.to(device="cuda") + + gen_kwargs = {"return_timestamps": True} + gen_kwargs["no_speech_threshold"] = 0.6 + gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) + gen_kwargs["compression_ratio_threshold"] = 2.4 + gen_kwargs["condition_on_prev_tokens"] = True + gen_kwargs["logprob_threshold"] = -1.0 + + torch.manual_seed(0) + result = model.generate(input_features, **gen_kwargs) + decoded = processor.batch_decode(result, skip_special_tokens=True) + + result = f'"""{decoded[0]}"""' + + assert result == EXPECTED_TEXT + + @slow def test_whisper_longform_multi_batch(self): # fmt: off EXPECTED_TEXT_1 = [" Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing a poster or near the fire, and the ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. a Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the tight-wing cloth that was the only germany war. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were, triggered his muscles into complete relaxation. Oily his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the mazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, pre-inscented and new to fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, return Calico. Where is my brother now? choir-dshaggy, in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh, no, I'm quite sure he didn't. That's funny, remarked Betsy thoughtfully. I don't believe and knew any magic, or she'd have worked it before. I do not know, confess shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as Virgado used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgados discarded Ruby Crown, and holding in his hand to scepter, which Virgado had so often thrown at his head. head."] EXPECTED_TEXT_2 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Burkett Foster's landscapes smile at one much in the same way that Mr. Carker."] - EXPECTED_TEXT_3 = [" possible. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-guards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath, next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting, he tells us, is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire. any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the titling cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even to soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Oily his heart and lungs worked on at a strong measured rate. He was in In reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, re-insunced it and knew the fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled and disgraced, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now? quared shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. And that's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confess Shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as we're good to have used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the thrown wearing ruggedos discarded ruby crown and holding in his hand to septor which Ruggato had so often thrown at his head."] + EXPECTED_TEXT_3 = [" possible. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-guards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath, next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting, he tells us, is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire. any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the titling cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even to soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Oily his heart and lungs worked on at a strong measured rate. He was in In reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, re-insunced it and knew the fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled and disgraced, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now? quared shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. And that's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confess Shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as we're good to have used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the thrown wearing ruggedos discarded ruby crown and holding in his hand to septor which ruggedo had so often thrown at his head."] EXPECTED_TEXT_4 = [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter\'s manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton\'s work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell\'s pictures are a sort of up-gards and atom paintings, and Mason\'s exquisite idles are as national as a jingo poem. Mr. Birk at Foster\'s landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampoo or a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes the customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mantelboard. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon\'s body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered his muscles into complete relaxation. Oli\'s heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, The thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I\'m here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you\'re being a fool. out, through his silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry, and victory to the stronger. man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. had died before during the 20s and death during the last round was in some ways easier than defeat. Breathing deeply, Breon\'s softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent\'s face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that\'s rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with says he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn\'t work too hard, said Shaggy. He doesn\'t work at all. In fact, there\'s nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we\'ve turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The middle forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I\'m quite sure he didn\'t. That\'s funny, remarked Betsy thoughtfully. I don\'t believe Anne knew any magic, or she\'d have worked it before. I do not know, confess Shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Virgato used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgato\'s discarded ruby crown and holding in his hand to scepter which reggative head so often thrown at his head.'] # fmt: on @@ -2114,18 +2171,76 @@ def test_whisper_longform_multi_batch(self): assert decoded_all[2:3] == EXPECTED_TEXT_3 assert decoded_all[3:4] == EXPECTED_TEXT_4 + @slow + def test_whisper_longform_multi_batch_prev_cond(self): + # fmt: off + EXPECTED_TEXT_1 = [" Mr. Quilters manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca. The Nils, pictures are sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilters writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are of two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does get good. Mr. Quilters has missed his chance, for he has failed even to make himself the tougher of painting. My hair equal to M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment he wore. The cut on his chest still dripping blood. The ache of his overstrain dyes. Even the soaring arena around him with thousands of spectators, retrievalidies not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly, but that away, you're being a fool. Out, the resoundance then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. Our red-haired mountain of a man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were inexplicably linked into one. This strengthened enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the other hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our role. Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to the side, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help you run into escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico, whereas my brother now inquired shaggy in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked to Bedsey thoughtfully. I don't believe Anne knew any magic or she'd have worked before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Ruggano used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing Ruggano's discarded ruby crown. And holding in his hand the scepter which Ruggano had so often thrown at his head."] + EXPECTED_TEXT_2 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and can discover in it but little of rocky Ithaca. Lennials, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker"] + EXPECTED_TEXT_3 = [" gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating in its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of rocky ithaka. Lennils, pictures, are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostoror. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and falseness graced that many phases of feeling, only unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tougher of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the tight-wing cloth that was the only garment you wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even the soaring arena around him with thousands of spectators were trivealed, not worth thinking about. His instant panic was followed by a small sharp, blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie sliding out on the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The 20s, he must have drawn his gun because the intruder said quickly, but that away, he'll be in the fool. Out, there is silence then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the autohydrotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled up the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our ol' Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to decide, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Whereas my brother now, in Quaragejjegi, in the metal forest. Where is that? The metal forest is in the great Dome to Cavern, the largest and all our dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny remarked by the bad sea thoughtfully. I don't believe Anne knew any magic or she'd have worked it before. I do not know, confessed shaggy. True, a great Calico. Calico went to the big gong and pounded on it, just as we're good or used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing reggos, discarded ruby crown, and holding in his hand to scepter which reggos had so often thrown at his head."] + EXPECTED_TEXT_4 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and can discover in it but little of rocky Ithaca. Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like a shampooer in a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does, get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tougher of painting. By Harry Quilter, M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment you wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators were trivialities not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly, but that away, you're being a fool. Out, there is silence then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. I've read here at Mountain of a Man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma, as if the two were inexplicably linked into one. Just strengthed and enabled someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the autohydrotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled up the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our ol' Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to the side, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. She has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace and your friends are asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now, in Quaragejji, in the metal forest? Where is that? The metal forest is in the great Dome to Cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked a bit, see you thoughtfully. I don't believe Anne knew any magic or she'd have worked it before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as we're good we used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing reggos, discarded ruby crown and holding it his hand to scepter which reggo had so often thrown at his head."] + # fmt: on + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model = model.to("cuda") + + ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean") + one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32) + audios = [] + audios.append(one_audio[110000:]) + audios.append(one_audio[:800000]) + audios.append(one_audio[80000:]) + audios.append(one_audio[:]) + + gen_kwargs = {"return_timestamps": True} + gen_kwargs["no_speech_threshold"] = 0.6 + gen_kwargs["temperature"] = 0.0 + gen_kwargs["compression_ratio_threshold"] = 2.4 + gen_kwargs["condition_on_prev_tokens"] = True + gen_kwargs["logprob_threshold"] = -1.0 + + with open("/home/patrick/expected.txt", "w") as f: + decoded_single = [] + for audio in audios: + inputs = processor(audio, return_tensors="pt", truncation=False) + inputs = inputs.to(device="cuda") + + result = model.generate(**inputs, **gen_kwargs) + decoded_single.append(processor.batch_decode(result, skip_special_tokens=True)) + f.write(decoded_single[-1][0] + "\n") + + # inputs = processor( + # audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True + # ) + # inputs = inputs.to(device="cuda") + + # result = model.generate(**inputs, **gen_kwargs) + # decoded_all = processor.batch_decode(result, skip_special_tokens=True) + + # # make sure single & batch is exactly the same + # assert decoded_all[0:1] == decoded_single[0] + # assert decoded_all[1:2] == decoded_single[1] + # assert decoded_all[2:3] == decoded_single[2] + # assert decoded_all[3:4] == decoded_single[3] + + # exact match + assert decoded_single[0] == EXPECTED_TEXT_1 + assert decoded_single[1] == EXPECTED_TEXT_2 + assert decoded_single[2] == EXPECTED_TEXT_3 + assert decoded_single[3] == EXPECTED_TEXT_4 + @slow def test_whisper_longform_multi_batch_hard(self): # fmt: off EXPECTED_TEXT = [ - " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!", - " Folks, I spend a lot of time right over there, night after night after night, actually. Carefully selecting for you the day's noosiest, most aerodynamic headlines, stress testing, and those topical anti-lock breaks and power steering, painstakingly stitching, leather seating so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school and slap myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen moon, render a gas tank out of an empty big gulp, fill with white claw and denatured alcohol, then light a match and let her rip and the demented one man soapbox derby of news that is my segment. Me, Guadalupe! No!", - " Ladies and gentlemen, you know, I spent a lot of time right over there Raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their jokes swollen teats Churning the daily stories into the decadent proven-style style triple cream breed that is my nightly monologue But sometimes sometimes folks I stagger home hungry after being released by the police and Root around in the neighbor's trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-donned street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire then hunker down and hallucinate while eating the listeria laden demon custard of news that is my segment. You mean one of them.", - " Folks, if you watch this show, you know I spend most of my time right over there carefully sorting through the day's biggest stories and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ichol Gregoire Ferrandi, who carefully dye them in a palette of bright zesty shades and adorn them in the finest and most topical inlay work using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddles stitching. In line it with bees, wax, coated linen, finely attached a mallet, hammered strap, pearled hardware, and close-shit to create for you the one-of-a-kind hoke couture, Erme's Birkin bag that is my monologue. But sometimes, sometimes folks, sometimes. Sometimes I wake up in the last car of an abandoned roller coaster at Coney Island where I'm I'm hiding from the triads. I have some engine lubricants out of a safe way bag and stagger down the shore to tear the sail off a beach schooner. Then I rip the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel lovely folks. And use it to stitch the sail into a loose pouch like a rock sack. And I stow away in the back of a garbage truck to the junkyard where I pick through to the debris for only the broken toys that make me the saddest until I have loaded for you. The Hobo Fugitives bug out, bindle of news that is my segment. Me one!", - " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's biggest stories right over there. Meticulously selecting the most topical chakra affirming scented candles, and using Feng Shui to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue. But sometimes just sometimes I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself, and used fry oil, wrap my hands with some double-duct tape by stole from the broken car window. Pound a six-pack of blueberry hard-seltzer and a sack of pills I stole from a parked ambulance. Then arm wrestle a raccoon in the back alley vision quest of news that is my segment. Meanwhile!", - " You know, folks, I spend most of my time right over there. Mining the day's biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels. Then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press-black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards in a faceplate and, finally, using fluted strips of white alloyed molding, I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating, Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes folks. Sometimes, just sometimes, I come into my sense as fully naked on the deck of a pirate besieged melee container ship that picked me up floating on the detached door of a portapotty in the Indian Ocean. Then after a sunstroke-induced realization of the crew of this ship plans to sell me an exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe at a pool chain that accepting my new role as Captain and declaring myself king of the windarc seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create the sopping wet pirate crown of news that is my segment. Meanwhile!", - " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's Newsiest most topical flower eggs milk and butter and Stranding into a fine batter to make delicate and informative comedy pancakes Then I glaze them in the juice and zest of the most relevant midnight Valencia oranges and douse it all and a fine Dela main de voyage cognac Before prom baying and basting them tables. I deserve for you the James Beard award worthy crepe suzzette That is my nightly monologue, but sometimes just sometimes folks. I wake up in the baggage hold of Greyhound bus. It's being hoisted by the scrap yard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps and busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strain, pair of sweatpants and as oven mitts to extract and serve the demented transience poundcake of news that is my segment. Me, Guadalupe!", - " Folks, if you watched the show and I hope you do, I spent a lot of time right over there. Tiredlessly studying the lineage of the days most important thoroughbred stories and whole-stiner headlines, working with the best trainers, money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen. That is my nightly monologue, but sometimes, sometimes, folks, I break into an unincorporated veterinary genetics lab and grab whatever test tubes I can find and then under a grow light I got from a discarded chia pet. I mixed the pilfered DNA of a horse and whatever was in a tube labeled Keith Colan extra. Slurrying the concoction with caffeine pills and a microwave red bull, I screamed, sang a prayer to Janice, initiator of human life and God of transformation as a half horse, half man, freak. Seizes to life before me and the hideous collection of loose animal parts and corrupted man tissue that is my segment. Meanwhile!", + " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile." + " Folks, I spend a lot of time right over there, night after night after night, actually. Carefully selecting for you the day's noosiest, most aerodynamic headlines, stress testing, and those topical anti-lock breaks and power steering, painstakingly stitching, leather seating so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school and slap myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen moon, render a gas tank out of an empty big gulp, fill with white claw and denatured alcohol, then light a match and let her rip and the demented one man soapbox derby of news that is my segment. Me, Guadalupe! No!" + " Ladies and gentlemen, you know, I spent a lot of time right over there Raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their jokes swollen teats Churning the daily stories into the decadent proven-style style triple cream breed that is my nightly monologue But sometimes sometimes folks I stagger home hungry after being released by the police and Root around in the neighbor's trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-donned street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire then hunker down and hallucinate while eating the listeria laden demon custard of news that is my segment. You mean one of them." + " Folks, if you watch this show, you know I spend most of my time right over there carefully sorting through the day's biggest stories and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ichol Gregoire Ferrandi, who carefully dye them in a palette of bright zesty shades and adorn them in the finest and most topical inlay work using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddles stitching. In line it with bees, wax, coated linen, finely attached a mallet, hammered strap, pearled hardware, and close-shit to create for you the one-of-a-kind hoke couture, Erme's Birkin bag that is my monologue. But sometimes, sometimes folks, sometimes. Sometimes I wake up in the last car of an abandoned roller coaster at Coney Island where I'm I'm hiding from the triads. I have some engine lubricants out of a safe way bag and stagger down the shore to tear the sail off a beach schooner. Then I rip the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel lovely folks. And use it to stitch the sail into a loose pouch like a rock sack. And I stow away in the back of a garbage truck to the junkyard where I pick through to the debris for only the broken toys that make me the saddest until I have loaded for you. The Hobo Fugitives bug out, bindle of news that is my segment. Me one!" + " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's biggest stories right over there. Meticulously selecting the most topical chakra affirming scented candles, and using Feng Shui to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue. But sometimes just sometimes I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself, and used fry oil, wrap my hands with some double-duct tape by stole from the broken car window. Pound a six-pack of blueberry hard-seltzer and a sack of pills I stole from a parked ambulance. Then arm wrestle a raccoon in the back alley vision quest of news that is my segment. Meanwhile!" + " You know, folks, I spend most of my time right over there. Mining the day's biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels. Then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press-black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards in a faceplate and, finally, using fluted strips of white alloyed molding, I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating, Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes folks. Sometimes, just sometimes, I come into my sense as fully naked on the deck of a pirate besieged melee container ship that picked me up floating on the detached door of a portapotty in the Indian Ocean. Then after a sunstroke-induced realization of the crew of this ship plans to sell me an exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe at a pool chain that accepting my new role as Captain and declaring myself king of the windarc seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create the sopping wet pirate crown of news that is my segment. Meanwhile!" + " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's Newsiest most topical flower eggs milk and butter and Stranding into a fine batter to make delicate and informative comedy pancakes Then I glaze them in the juice and zest of the most relevant midnight Valencia oranges and douse it all and a fine Dela main de voyage cognac Before prom baying and basting them tables. I deserve for you the James Beard award worthy crepe suzzette That is my nightly monologue, but sometimes just sometimes folks. I wake up in the baggage hold of Greyhound bus. It's being hoisted by the scrap yard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps and busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strain, pair of sweatpants and as oven mitts to extract and serve the demented transience poundcake of news that is my segment. Me, Guadalupe!" + " Folks, if you watched the show and I hope you do, I spent a lot of time right over there. Tiredlessly studying the lineage of the days most important thoroughbred stories and whole-stiner headlines, working with the best trainers, money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen. That is my nightly monologue, but sometimes, sometimes, folks, I break into an unincorporated veterinary genetics lab and grab whatever test tubes I can find and then under a grow light I got from a discarded chia pet. I mixed the pilfered DNA of a horse and whatever was in a tube labeled Keith Colan extra. Slurrying the concoction with caffeine pills and a microwave red bull, I screamed, sang a prayer to Janice, initiator of human life and God of transformation as a half horse, half man, freak. Seizes to life before me and the hideous collection of loose animal parts and corrupted man tissue that is my segment. Meanwhile!" ] # fmt: on @@ -2161,6 +2276,51 @@ def test_whisper_longform_multi_batch_hard(self): assert decoded_all[i] == decoded_single[i] assert decoded_all[i] == EXPECTED_TEXT[i] + @slow + def test_whisper_longform_multi_batch_hard_prev_cond(self): + # fmt: off + EXPECTED_TEXT = [ + " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile." + " Folks, I spend a lot of time right over there, night after night after night, actually. Carefully selecting for you the day's noosiest, most aerodynamic headlines, stress testing, and those topical anti-lock breaks and power steering, painstakingly stitching, leather seating so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school and slap myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen moon, render a gas tank out of an empty big gulp, fill with white claw and denatured alcohol, then light a match and let her rip and the demented one man soapbox derby of news that is my segment. Me, Guadalupe! No!" + " Ladies and gentlemen, you know, I spent a lot of time right over there Raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their jokes swollen teats Churning the daily stories into the decadent proven-style style triple cream breed that is my nightly monologue But sometimes sometimes folks I stagger home hungry after being released by the police and Root around in the neighbor's trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-donned street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire then hunker down and hallucinate while eating the listeria laden demon custard of news that is my segment. You mean one of them." + " Folks, if you watch this show, you know I spend most of my time right over there carefully sorting through the day's biggest stories and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ichol Gregoire Ferrandi, who carefully dye them in a palette of bright zesty shades and adorn them in the finest and most topical inlay work using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddles stitching. In line it with bees, wax, coated linen, finely attached a mallet, hammered strap, pearled hardware, and close-shit to create for you the one-of-a-kind hoke couture, Erme's Birkin bag that is my monologue. But sometimes, sometimes folks, sometimes. Sometimes I wake up in the last car of an abandoned roller coaster at Coney Island where I'm I'm hiding from the triads. I have some engine lubricants out of a safe way bag and stagger down the shore to tear the sail off a beach schooner. Then I rip the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel lovely folks. And use it to stitch the sail into a loose pouch like a rock sack. And I stow away in the back of a garbage truck to the junkyard where I pick through to the debris for only the broken toys that make me the saddest until I have loaded for you. The Hobo Fugitives bug out, bindle of news that is my segment. Me one!" + " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's biggest stories right over there. Meticulously selecting the most topical chakra affirming scented candles, and using Feng Shui to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue. But sometimes just sometimes I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself, and used fry oil, wrap my hands with some double-duct tape by stole from the broken car window. Pound a six-pack of blueberry hard-seltzer and a sack of pills I stole from a parked ambulance. Then arm wrestle a raccoon in the back alley vision quest of news that is my segment. Meanwhile!" + " You know, folks, I spend most of my time right over there. Mining the day's biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels. Then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press-black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards in a faceplate and, finally, using fluted strips of white alloyed molding, I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating, Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes folks. Sometimes, just sometimes, I come into my sense as fully naked on the deck of a pirate besieged melee container ship that picked me up floating on the detached door of a portapotty in the Indian Ocean. Then after a sunstroke-induced realization of the crew of this ship plans to sell me an exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe at a pool chain that accepting my new role as Captain and declaring myself king of the windarc seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create the sopping wet pirate crown of news that is my segment. Meanwhile!" + " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's Newsiest most topical flower eggs milk and butter and Stranding into a fine batter to make delicate and informative comedy pancakes Then I glaze them in the juice and zest of the most relevant midnight Valencia oranges and douse it all and a fine Dela main de voyage cognac Before prom baying and basting them tables. I deserve for you the James Beard award worthy crepe suzzette That is my nightly monologue, but sometimes just sometimes folks. I wake up in the baggage hold of Greyhound bus. It's being hoisted by the scrap yard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps and busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strain, pair of sweatpants and as oven mitts to extract and serve the demented transience poundcake of news that is my segment. Me, Guadalupe!" + " Folks, if you watched the show and I hope you do, I spent a lot of time right over there. Tiredlessly studying the lineage of the days most important thoroughbred stories and whole-stiner headlines, working with the best trainers, money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen. That is my nightly monologue, but sometimes, sometimes, folks, I break into an unincorporated veterinary genetics lab and grab whatever test tubes I can find and then under a grow light I got from a discarded chia pet. I mixed the pilfered DNA of a horse and whatever was in a tube labeled Keith Colan extra. Slurrying the concoction with caffeine pills and a microwave red bull, I screamed, sang a prayer to Janice, initiator of human life and God of transformation as a half horse, half man, freak. Seizes to life before me and the hideous collection of loose animal parts and corrupted man tissue that is my segment. Meanwhile!" + ] + # fmt: on + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model = model.to("cuda") + + ds = load_dataset("distil-whisper/meanwhile", "default")["test"] + ds = ds.cast_column("audio", Audio(sampling_rate=16000)) + + num_samples = 8 + + audio = ds[:num_samples]["audio"] + audios = [x["array"] for x in audio] + + inputs = processor( + audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True + ) + inputs = inputs.to(device="cuda") + + gen_kwargs = {"return_timestamps": True} + gen_kwargs["no_speech_threshold"] = 0.6 + gen_kwargs["compression_ratio_threshold"] = 2.4 + gen_kwargs["condition_on_prev_tokens"] = True + gen_kwargs["logprob_threshold"] = -1.0 + gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) + gen_kwargs["num_beams"] = 5 + + result = model.generate(**inputs, **gen_kwargs) + decoded_all = processor.batch_decode(result, skip_special_tokens=True) + + for i in range(num_samples): + assert decoded_all[i] == EXPECTED_TEXT[i] def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): if head_mask is None: From 46cdb431bb0f8b7183404a94f22b4f3a29a5e497 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 8 Dec 2023 15:47:04 +0000 Subject: [PATCH 25/75] correct more --- .../models/whisper/modeling_whisper.py | 53 +++++++++++++------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index c0877ed96856..9dbf338cbb97 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -20,6 +20,7 @@ import numpy as np import copy import torch +import zlib import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -49,6 +50,11 @@ from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE +from transformers import AutoTokenizer + +# tok = AutoTokenizer.from_pretrained("openai/whisper-tiny") + + if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -2127,11 +2133,7 @@ def generate( begin_index = 1 if generation_config.return_timestamps is True: forced_decoder_ids = generation_config.forced_decoder_ids - last_forced_decoder_ids = ( - forced_decoder_ids[-1][-1] - if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None - else None - ) + last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None if last_forced_decoder_ids == generation_config.no_timestamps_token_id: # remove no_timestamp to be forcefully generated if we want to return timestamps # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly @@ -2370,23 +2372,22 @@ def split_by_batch_index(values, key, batch_idx): return None return values[batch_idx].cpu() - seek_sequences = seek_outputs["sequences"] + sequence_tokens = seek_outputs["sequences"] seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(cur_bsz)] else: - seek_sequences = seek_outputs + sequence_tokens = seek_outputs # remove all previously passed decoder input ids - seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1]:] + seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] needs_fallback = False if compression_ratio_threshold is not None: - compression_ratio = [seek_sequence.shape[0] / torch.unique(seek_sequence).shape[0] for seek_sequence in seek_sequences] - # TODO(PVP) only works for batch size = 1 currently - if compression_ratio[0] > compression_ratio_threshold: + compression_ratio = self._retrieve_compression_ratio(sequence_tokens[0], self.config.vocab_size) + + if compression_ratio > compression_ratio_threshold: print("fallback compression") print("current temp", temperature) - needs_fallback = True if logprob_threshold is not None: @@ -2394,7 +2395,7 @@ def split_by_batch_index(values, key, batch_idx): logprobs = [s["sequences_scores"] for s in seek_outputs] else: scores = [s["scores"] for s in seek_outputs] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id) + logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id, temperature) # TODO(PVP) only works for batch size = 1 currently if logprobs[0] < logprob_threshold: @@ -2405,7 +2406,8 @@ def split_by_batch_index(values, key, batch_idx): if no_speech_threshold is not None: # TODO(PVP) only works for batch size = 1 currently # Need to do before all other logit processors - if no_speech_detector.no_speech_prob[0].cpu() > no_speech_threshold and logprobs[0] < logprob_threshold: + if no_speech_detector.no_speech_prob[0] > no_speech_threshold and logprobs[0] < logprob_threshold: + print("Skip because of VAD") needs_fallback = False should_skip = True @@ -2481,9 +2483,28 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right"): return sequences @staticmethod - def _retrieve_avg_logprobs(scores, tokens, eos_token_id): + def _retrieve_compression_ratio(tokens, vocab_size): + length = int(math.log2(vocab_size) / 8) + 1 + token_bytes = b''.join([t.to_bytes(length, 'little') for t in tokens.tolist()]) + + # string = tok.decode(tokens, skip_special_tokens=True) + # string_bytes = string.encode("utf-8") + # string_compression_ratio = len(string_bytes) / len(zlib.compress(string_bytes)) + + compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes)) + + # print(f"HERE: string: {string}") + # print(f"HERE: string ratio: {string_compression_ratio}") + # print(f"HERE: token ratio: {compression_ratio}") + # print('HERE:' + 20 * '-') + + return compression_ratio + + @staticmethod + def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): + rescale_temperature = temperature if temperature > 0.0 else 1 scores = torch.stack([torch.stack(score) for score in scores]).to(tokens.device) - logprobs = F.log_softmax(scores.float(), dim=-1).to(scores.dtype) + logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype) tokens = tokens[:, -scores.shape[1]:] def get_log_prob(logprob, token): From 6818ebf480f3798f93164a3a17ba04a26b6e0a95 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 8 Dec 2023 16:47:23 +0000 Subject: [PATCH 26/75] correct more --- .../models/whisper/modeling_whisper.py | 79 +++++++++++-------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 9dbf338cbb97..5c56fa577b07 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2335,7 +2335,8 @@ def generate( print("hf in tokens", decoder_input_ids[0].tolist()) # 6.6 Batch generate current chunk - should_skip = False + token_sequences = [None for _ in range(cur_bsz)] + needs_fallback = [False for _ in range(cur_bsz)] for temperature in temperatures: do_sample = temperature > 0.0 @@ -2380,40 +2381,56 @@ def split_by_batch_index(values, key, batch_idx): # remove all previously passed decoder input ids seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] - needs_fallback = False - if compression_ratio_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - compression_ratio = self._retrieve_compression_ratio(sequence_tokens[0], self.config.vocab_size) - - if compression_ratio > compression_ratio_threshold: - print("fallback compression") - print("current temp", temperature) - needs_fallback = True - - if logprob_threshold is not None: - if "sequences_scores" in seek_outputs[0]: - logprobs = [s["sequences_scores"] for s in seek_outputs] - else: - scores = [s["scores"] for s in seek_outputs] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id, temperature) - - # TODO(PVP) only works for batch size = 1 currently - if logprobs[0] < logprob_threshold: - print("fallback logprob") - print("current temp", temperature) - needs_fallback = True + if False: + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length + for i, seek_sequence in enumerate(seek_sequences): + # make sure we cut a predicted EOS token if we are not finished with the generation yet + is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] + if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: + seek_sequence = seek_sequence[:-1] + + print("hf out tokens", seek_sequence.tolist()) + + # remove all padding tokens + if seek_sequence[-1] == generation_config.pad_token_id: + num_paddings = (seek_sequence == generation_config.pad_token_id).sum() + seek_sequence = seek_sequence[:-num_paddings] + + if compression_ratio_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) + + if compression_ratio > compression_ratio_threshold: + print("fallback compression") + print("current temp", temperature) + needs_fallback[i] = True + + if logprob_threshold is not None: + if "sequences_scores" in seek_outputs[0]: + logprobs = [s["sequences_scores"] for s in seek_outputs] + else: + scores = [s["scores"] for s in seek_outputs] + logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) + + # TODO(PVP) only works for batch size = 1 currently + if logprobs < logprob_threshold: + print("current temp", temperature) + needs_fallback[i] = True - if no_speech_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # Need to do before all other logit processors - if no_speech_detector.no_speech_prob[0] > no_speech_threshold and logprobs[0] < logprob_threshold: - print("Skip because of VAD") - needs_fallback = False - should_skip = True + if no_speech_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + # Need to do before all other logit processors + if no_speech_detector.no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: + print("Skip because of VAD") + needs_fallback[i] = False + should_skip = True if not needs_fallback: break + # if not needs_fallback.any(): + # break + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): prev_i = cur_to_prev_index_map[i] @@ -2516,7 +2533,7 @@ def get_log_prob(logprob, token): lengths = (tokens != eos_token_id).sum(-1) avg_logprobs = torch.div(sum_logprobs, lengths + 1) - return avg_logprobs + return avg_logprobs[0] @staticmethod def _retrieve_segment( From b46b63d6ffb5adccb17dc570059da650800c078f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 8 Dec 2023 16:58:27 +0000 Subject: [PATCH 27/75] correct more --- .../models/whisper/modeling_whisper.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 5c56fa577b07..2ce5e2a1f8d1 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2337,6 +2337,8 @@ def generate( # 6.6 Batch generate current chunk token_sequences = [None for _ in range(cur_bsz)] needs_fallback = [False for _ in range(cur_bsz)] + needs_fallback = False + should_skip = False for temperature in temperatures: do_sample = temperature > 0.0 @@ -2425,6 +2427,35 @@ def split_by_batch_index(values, key, batch_idx): needs_fallback[i] = False should_skip = True + if compression_ratio_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + compression_ratio = self._retrieve_compression_ratio(sequence_tokens[0], self.config.vocab_size) + + if compression_ratio > compression_ratio_threshold: + print("fallback compression") + print("current temp", temperature) + needs_fallback = True + + if logprob_threshold is not None: + if "sequences_scores" in seek_outputs[0]: + logprobs = [s["sequences_scores"] for s in seek_outputs] + else: + scores = [s["scores"] for s in seek_outputs] + logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id, temperature) + + # TODO(PVP) only works for batch size = 1 currently + if logprobs < logprob_threshold: + print("current temp", temperature) + needs_fallback = True + + if no_speech_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + # Need to do before all other logit processors + if no_speech_detector.no_speech_prob[0] > no_speech_threshold and logprobs < logprob_threshold: + print("Skip because of VAD") + needs_fallback = False + should_skip = True + if not needs_fallback: break From 184c888fbe8dc4e063426044bf16b7780c4686fb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 9 Dec 2023 12:40:18 +0100 Subject: [PATCH 28/75] push --- .../models/whisper/modeling_whisper.py | 209 +++++++++++------- 1 file changed, 128 insertions(+), 81 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 2ce5e2a1f8d1..d55784870e12 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2337,8 +2337,10 @@ def generate( # 6.6 Batch generate current chunk token_sequences = [None for _ in range(cur_bsz)] needs_fallback = [False for _ in range(cur_bsz)] - needs_fallback = False - should_skip = False + should_skip = [False for _ in range(cur_bsz)] + fallback_index_map = list(range(cur_bsz)) + # needs_fallback = False + # should_skip = False for temperature in temperatures: do_sample = temperature > 0.0 @@ -2383,7 +2385,7 @@ def split_by_batch_index(values, key, batch_idx): # remove all previously passed decoder input ids seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] - if False: + if True: # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): # make sure we cut a predicted EOS token if we are not finished with the generation yet @@ -2398,108 +2400,153 @@ def split_by_batch_index(values, key, batch_idx): num_paddings = (seek_sequence == generation_config.pad_token_id).sum() seek_sequence = seek_sequence[:-num_paddings] + if compression_ratio_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) + + if compression_ratio > compression_ratio_threshold: + print("fallback compression") + print("current temp", temperature) + needs_fallback[i] = True + + if logprob_threshold is not None: + if "sequences_scores" in seek_outputs[0]: + logprobs = [s["sequences_scores"] for s in seek_outputs] + else: + scores = [s["scores"] for s in seek_outputs] + logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) + + # TODO(PVP) only works for batch size = 1 currently + if logprobs < logprob_threshold: + print("current temp", temperature) + needs_fallback[i] = True + + if no_speech_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + # Need to do before all other logit processors + if no_speech_detector.no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: + print("Skip because of VAD") + needs_fallback[i] = False + should_skip[i] = True + + new_fallback_index_map = [] + new_segment_input = [] + new_decoder_input_ids = [] + for i, seek_sequence in enumerate(seek_sequences): + if needs_fallback[i]: + new_fallback_index_map.append(fallback_index_map[i]) + new_segment_input.append(segment_input[i]) + new_decoder_input_ids.append(decoder_input_ids[i]) + + token_sequences[fallback_index_map[i]] = seek_sequence + + fallback_index_map = new_fallback_index_map + seek_sequences = torch.stack(token_sequences) + + if len(fallback_index_map) == 0: + break + + decoder_input_ids = torch.stack(new_decoder_input_ids) + segment_input = torch.stack(new_segment_input) + + + if False: if compression_ratio_threshold is not None: # TODO(PVP) only works for batch size = 1 currently - compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) + compression_ratio = self._retrieve_compression_ratio(sequence_tokens[0], self.config.vocab_size) if compression_ratio > compression_ratio_threshold: print("fallback compression") print("current temp", temperature) - needs_fallback[i] = True + needs_fallback = True if logprob_threshold is not None: if "sequences_scores" in seek_outputs[0]: logprobs = [s["sequences_scores"] for s in seek_outputs] else: scores = [s["scores"] for s in seek_outputs] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) + logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id, temperature) # TODO(PVP) only works for batch size = 1 currently if logprobs < logprob_threshold: print("current temp", temperature) - needs_fallback[i] = True - + needs_fallback = True + if no_speech_threshold is not None: # TODO(PVP) only works for batch size = 1 currently # Need to do before all other logit processors - if no_speech_detector.no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: + if no_speech_detector.no_speech_prob[0] > no_speech_threshold and logprobs < logprob_threshold: print("Skip because of VAD") - needs_fallback[i] = False + needs_fallback = False should_skip = True + + if not needs_fallback: + break + + if True: + for i, seek_sequence in enumerate(seek_sequences): + prev_i = cur_to_prev_index_map[i] + + if should_skip[i]: + seek[prev_i] += seek_num_frames[prev_i] + print("Skipped!") + continue + + # TODO(Patrick: delete cut type) + segments, segment_offset, cut_type = self._retrieve_segment( + seek_sequence=seek_sequence, + seek_outputs=seek_outputs, + time_offset=time_offset, + timestamp_begin=timestamp_begin, + seek_num_frames=seek_num_frames, + time_precision=time_precision, + input_stride=input_stride, + prev_idx=prev_i, + idx=i, + ) - if compression_ratio_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - compression_ratio = self._retrieve_compression_ratio(sequence_tokens[0], self.config.vocab_size) - - if compression_ratio > compression_ratio_threshold: - print("fallback compression") - print("current temp", temperature) - needs_fallback = True - - if logprob_threshold is not None: - if "sequences_scores" in seek_outputs[0]: - logprobs = [s["sequences_scores"] for s in seek_outputs] - else: - scores = [s["scores"] for s in seek_outputs] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id, temperature) - - # TODO(PVP) only works for batch size = 1 currently - if logprobs < logprob_threshold: - print("current temp", temperature) - needs_fallback = True - - if no_speech_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # Need to do before all other logit processors - if no_speech_detector.no_speech_prob[0] > no_speech_threshold and logprobs < logprob_threshold: - print("Skip because of VAD") - needs_fallback = False - should_skip = True - - if not needs_fallback: - break - - # if not needs_fallback.any(): - # break - - # 6.7 Loop over each decoded audio individually as each decoding can be of a different length - for i, seek_sequence in enumerate(seek_sequences): - prev_i = cur_to_prev_index_map[i] - - if should_skip: - seek[prev_i] += seek_num_frames[prev_i] - print("Skipped!") - continue - - # make sure we cut a predicted EOS token if we are not finished with the generation yet - is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] - if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: - seek_sequence = seek_sequence[:-1] - - print("hf out tokens", seek_sequence.tolist()) - - # remove all padding tokens - if seek_sequence[-1] == generation_config.pad_token_id: - num_paddings = (seek_sequence == generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] - - # TODO(Patrick: delete cut type) - segments, segment_offset, cut_type = self._retrieve_segment( - seek_sequence=seek_sequence, - seek_outputs=seek_outputs, - time_offset=time_offset, - timestamp_begin=timestamp_begin, - seek_num_frames=seek_num_frames, - time_precision=time_precision, - input_stride=input_stride, - prev_idx=prev_i, - idx=i, - ) + current_segments[prev_i] += segments + seek[prev_i] += segment_offset + + if False: + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length + for i, seek_sequence in enumerate(seek_sequences): + prev_i = cur_to_prev_index_map[i] + + if should_skip: + seek[prev_i] += seek_num_frames[prev_i] + print("Skipped!") + continue + + # make sure we cut a predicted EOS token if we are not finished with the generation yet + is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] + if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: + seek_sequence = seek_sequence[:-1] + + print("hf out tokens", seek_sequence.tolist()) + + # remove all padding tokens + if seek_sequence[-1] == generation_config.pad_token_id: + num_paddings = (seek_sequence == generation_config.pad_token_id).sum() + seek_sequence = seek_sequence[:-num_paddings] + + # TODO(Patrick: delete cut type) + segments, segment_offset, cut_type = self._retrieve_segment( + seek_sequence=seek_sequence, + seek_outputs=seek_outputs, + time_offset=time_offset, + timestamp_begin=timestamp_begin, + seek_num_frames=seek_num_frames, + time_precision=time_precision, + input_stride=input_stride, + prev_idx=prev_i, + idx=i, + ) - current_segments[prev_i] += segments - seek[prev_i] += segment_offset + current_segments[prev_i] += segments + seek[prev_i] += segment_offset - print(f"{cut_type} seek {seek[0]}") + print(f"{cut_type} seek {seek[0]}") # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output From e4b7827e49d1e41c1b5362d402246eb9e7b91af2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 9 Dec 2023 12:13:27 +0000 Subject: [PATCH 29/75] correct more --- .../models/whisper/modeling_whisper.py | 44 ++++++++++--------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d55784870e12..e31245aa4df5 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Whisper model.""" +RUN_NEW_WAY = False import math from typing import Optional, Tuple, Union @@ -2335,12 +2336,16 @@ def generate( print("hf in tokens", decoder_input_ids[0].tolist()) # 6.6 Batch generate current chunk - token_sequences = [None for _ in range(cur_bsz)] - needs_fallback = [False for _ in range(cur_bsz)] - should_skip = [False for _ in range(cur_bsz)] - fallback_index_map = list(range(cur_bsz)) - # needs_fallback = False - # should_skip = False + if RUN_NEW_WAY: + token_sequences = [None for _ in range(cur_bsz)] + needs_fallback = [False for _ in range(cur_bsz)] + should_skip = [False for _ in range(cur_bsz)] + fallback_index_map = list(range(cur_bsz)) + + if not RUN_NEW_WAY: + needs_fallback = False + should_skip = False + for temperature in temperatures: do_sample = temperature > 0.0 @@ -2385,7 +2390,7 @@ def split_by_batch_index(values, key, batch_idx): # remove all previously passed decoder input ids seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] - if True: + if RUN_NEW_WAY: # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): # make sure we cut a predicted EOS token if we are not finished with the generation yet @@ -2413,7 +2418,7 @@ def split_by_batch_index(values, key, batch_idx): if "sequences_scores" in seek_outputs[0]: logprobs = [s["sequences_scores"] for s in seek_outputs] else: - scores = [s["scores"] for s in seek_outputs] + scores = seek_outputs[i]["scores"] logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) # TODO(PVP) only works for batch size = 1 currently @@ -2450,7 +2455,7 @@ def split_by_batch_index(values, key, batch_idx): segment_input = torch.stack(new_segment_input) - if False: + if not RUN_NEW_WAY: if compression_ratio_threshold is not None: # TODO(PVP) only works for batch size = 1 currently compression_ratio = self._retrieve_compression_ratio(sequence_tokens[0], self.config.vocab_size) @@ -2483,7 +2488,7 @@ def split_by_batch_index(values, key, batch_idx): if not needs_fallback: break - if True: + if RUN_NEW_WAY: for i, seek_sequence in enumerate(seek_sequences): prev_i = cur_to_prev_index_map[i] @@ -2508,7 +2513,7 @@ def split_by_batch_index(values, key, batch_idx): current_segments[prev_i] += segments seek[prev_i] += segment_offset - if False: + if not RUN_NEW_WAY: # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): prev_i = cur_to_prev_index_map[i] @@ -2598,20 +2603,17 @@ def _retrieve_compression_ratio(tokens, vocab_size): @staticmethod def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): rescale_temperature = temperature if temperature > 0.0 else 1 - scores = torch.stack([torch.stack(score) for score in scores]).to(tokens.device) + scores = torch.stack(scores).to(tokens.device) logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype) - tokens = tokens[:, -scores.shape[1]:] - - def get_log_prob(logprob, token): - token_logprob = logprob.gather(-1, token)[:, 0] * (token[:, -1] != eos_token_id) - return token_logprob + tokens = tokens[-scores.shape[0]:] - sum_logprobs = sum(get_log_prob(logprobs[:, i], tokens[:, i: i+1]) for i in range(logprobs.shape[1])) + # retrieve logprob of selected tokens and sum + sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0])) - lengths = (tokens != eos_token_id).sum(-1) + length = (tokens != eos_token_id).sum(-1) - avg_logprobs = torch.div(sum_logprobs, lengths + 1) - return avg_logprobs[0] + avg_logprobs = sum_logprobs / (length + 1) + return avg_logprobs @staticmethod def _retrieve_segment( From 380bd547d537a7ac01adc92638d6cf13f9836925 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 9 Dec 2023 13:42:26 +0000 Subject: [PATCH 30/75] Fix more --- .../models/whisper/modeling_whisper.py | 34 +++++++++++-------- tests/models/whisper/test_modeling_whisper.py | 29 ++++++++-------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e31245aa4df5..2e50fd04295b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Whisper model.""" -RUN_NEW_WAY = False +RUN_NEW_WAY = True import math from typing import Optional, Tuple, Union @@ -2337,7 +2337,7 @@ def generate( # 6.6 Batch generate current chunk if RUN_NEW_WAY: - token_sequences = [None for _ in range(cur_bsz)] + seek_sequence_list = [None for _ in range(cur_bsz)] needs_fallback = [False for _ in range(cur_bsz)] should_skip = [False for _ in range(cur_bsz)] fallback_index_map = list(range(cur_bsz)) @@ -2346,7 +2346,7 @@ def generate( needs_fallback = False should_skip = False - for temperature in temperatures: + for fallback_idx, temperature in enumerate(temperatures): do_sample = temperature > 0.0 num_beams = kwargs.pop("num_beams", 1) @@ -2391,6 +2391,9 @@ def split_by_batch_index(values, key, batch_idx): seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] if RUN_NEW_WAY: + new_fallback_index_map = [] + new_segment_input = [] + new_decoder_input_ids = [] # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): # make sure we cut a predicted EOS token if we are not finished with the generation yet @@ -2416,13 +2419,14 @@ def split_by_batch_index(values, key, batch_idx): if logprob_threshold is not None: if "sequences_scores" in seek_outputs[0]: - logprobs = [s["sequences_scores"] for s in seek_outputs] + logprobs = [s["sequences_scores"] for s in seek_outputs][i] else: scores = seek_outputs[i]["scores"] logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) # TODO(PVP) only works for batch size = 1 currently if logprobs < logprob_threshold: + print("fallback logprobs", logprobs) print("current temp", temperature) needs_fallback[i] = True @@ -2434,27 +2438,22 @@ def split_by_batch_index(values, key, batch_idx): needs_fallback[i] = False should_skip[i] = True - new_fallback_index_map = [] - new_segment_input = [] - new_decoder_input_ids = [] - for i, seek_sequence in enumerate(seek_sequences): if needs_fallback[i]: new_fallback_index_map.append(fallback_index_map[i]) new_segment_input.append(segment_input[i]) new_decoder_input_ids.append(decoder_input_ids[i]) - token_sequences[fallback_index_map[i]] = seek_sequence + seek_sequence_list[fallback_index_map[i]] = seek_sequence fallback_index_map = new_fallback_index_map - seek_sequences = torch.stack(token_sequences) - if len(fallback_index_map) == 0: + if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: + seek_sequences = seek_sequence_list break decoder_input_ids = torch.stack(new_decoder_input_ids) segment_input = torch.stack(new_segment_input) - if not RUN_NEW_WAY: if compression_ratio_threshold is not None: # TODO(PVP) only works for batch size = 1 currently @@ -2470,7 +2469,7 @@ def split_by_batch_index(values, key, batch_idx): logprobs = [s["sequences_scores"] for s in seek_outputs] else: scores = [s["scores"] for s in seek_outputs] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequences, self.config.eos_token_id, temperature) + logprobs = self._retrieve_avg_logprobs(scores[0], seek_sequences[0], self.config.eos_token_id, temperature) # TODO(PVP) only works for batch size = 1 currently if logprobs < logprob_threshold: @@ -2604,12 +2603,17 @@ def _retrieve_compression_ratio(tokens, vocab_size): def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): rescale_temperature = temperature if temperature > 0.0 else 1 scores = torch.stack(scores).to(tokens.device) + + # TODO(Patrick) - only leave scores = scores[:tokens.shape[0]] part + if scores.shape[0] > tokens.shape[0]: + scores = scores[:tokens.shape[0]] + else: + tokens = tokens[-scores.shape[0]:] + logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype) - tokens = tokens[-scores.shape[0]:] # retrieve logprob of selected tokens and sum sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0])) - length = (tokens != eos_token_id).sum(-1) avg_logprobs = sum_logprobs / (length + 1) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index e2cd45205982..b4d8f609e259 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2109,7 +2109,7 @@ def test_whisper_longform_single_batch_prev_cond(self): gen_kwargs = {"return_timestamps": True} gen_kwargs["no_speech_threshold"] = 0.6 gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) - gen_kwargs["compression_ratio_threshold"] = 2.4 + gen_kwargs["compression_ratio_threshold"] = 1.35 gen_kwargs["condition_on_prev_tokens"] = True gen_kwargs["logprob_threshold"] = -1.0 @@ -2117,9 +2117,7 @@ def test_whisper_longform_single_batch_prev_cond(self): result = model.generate(input_features, **gen_kwargs) decoded = processor.batch_decode(result, skip_special_tokens=True) - result = f'"""{decoded[0]}"""' - - assert result == EXPECTED_TEXT + assert decoded == EXPECTED_TEXT @slow @@ -2195,7 +2193,7 @@ def test_whisper_longform_multi_batch_prev_cond(self): gen_kwargs = {"return_timestamps": True} gen_kwargs["no_speech_threshold"] = 0.6 gen_kwargs["temperature"] = 0.0 - gen_kwargs["compression_ratio_threshold"] = 2.4 + gen_kwargs["compression_ratio_threshold"] = 1.35 gen_kwargs["condition_on_prev_tokens"] = True gen_kwargs["logprob_threshold"] = -1.0 @@ -2233,13 +2231,13 @@ def test_whisper_longform_multi_batch_prev_cond(self): def test_whisper_longform_multi_batch_hard(self): # fmt: off EXPECTED_TEXT = [ - " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile." - " Folks, I spend a lot of time right over there, night after night after night, actually. Carefully selecting for you the day's noosiest, most aerodynamic headlines, stress testing, and those topical anti-lock breaks and power steering, painstakingly stitching, leather seating so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school and slap myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen moon, render a gas tank out of an empty big gulp, fill with white claw and denatured alcohol, then light a match and let her rip and the demented one man soapbox derby of news that is my segment. Me, Guadalupe! No!" - " Ladies and gentlemen, you know, I spent a lot of time right over there Raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their jokes swollen teats Churning the daily stories into the decadent proven-style style triple cream breed that is my nightly monologue But sometimes sometimes folks I stagger home hungry after being released by the police and Root around in the neighbor's trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-donned street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire then hunker down and hallucinate while eating the listeria laden demon custard of news that is my segment. You mean one of them." - " Folks, if you watch this show, you know I spend most of my time right over there carefully sorting through the day's biggest stories and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ichol Gregoire Ferrandi, who carefully dye them in a palette of bright zesty shades and adorn them in the finest and most topical inlay work using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddles stitching. In line it with bees, wax, coated linen, finely attached a mallet, hammered strap, pearled hardware, and close-shit to create for you the one-of-a-kind hoke couture, Erme's Birkin bag that is my monologue. But sometimes, sometimes folks, sometimes. Sometimes I wake up in the last car of an abandoned roller coaster at Coney Island where I'm I'm hiding from the triads. I have some engine lubricants out of a safe way bag and stagger down the shore to tear the sail off a beach schooner. Then I rip the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel lovely folks. And use it to stitch the sail into a loose pouch like a rock sack. And I stow away in the back of a garbage truck to the junkyard where I pick through to the debris for only the broken toys that make me the saddest until I have loaded for you. The Hobo Fugitives bug out, bindle of news that is my segment. Me one!" - " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's biggest stories right over there. Meticulously selecting the most topical chakra affirming scented candles, and using Feng Shui to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue. But sometimes just sometimes I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself, and used fry oil, wrap my hands with some double-duct tape by stole from the broken car window. Pound a six-pack of blueberry hard-seltzer and a sack of pills I stole from a parked ambulance. Then arm wrestle a raccoon in the back alley vision quest of news that is my segment. Meanwhile!" - " You know, folks, I spend most of my time right over there. Mining the day's biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels. Then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press-black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards in a faceplate and, finally, using fluted strips of white alloyed molding, I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating, Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes folks. Sometimes, just sometimes, I come into my sense as fully naked on the deck of a pirate besieged melee container ship that picked me up floating on the detached door of a portapotty in the Indian Ocean. Then after a sunstroke-induced realization of the crew of this ship plans to sell me an exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe at a pool chain that accepting my new role as Captain and declaring myself king of the windarc seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create the sopping wet pirate crown of news that is my segment. Meanwhile!" - " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's Newsiest most topical flower eggs milk and butter and Stranding into a fine batter to make delicate and informative comedy pancakes Then I glaze them in the juice and zest of the most relevant midnight Valencia oranges and douse it all and a fine Dela main de voyage cognac Before prom baying and basting them tables. I deserve for you the James Beard award worthy crepe suzzette That is my nightly monologue, but sometimes just sometimes folks. I wake up in the baggage hold of Greyhound bus. It's being hoisted by the scrap yard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps and busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strain, pair of sweatpants and as oven mitts to extract and serve the demented transience poundcake of news that is my segment. Me, Guadalupe!" + " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile.", + " Folks, I spend a lot of time right over there, night after night after night, actually. Carefully selecting for you the day's noosiest, most aerodynamic headlines, stress testing, and those topical anti-lock breaks and power steering, painstakingly stitching, leather seating so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school and slap myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen moon, render a gas tank out of an empty big gulp, fill with white claw and denatured alcohol, then light a match and let her rip and the demented one man soapbox derby of news that is my segment. Me, Guadalupe! No!", + " Ladies and gentlemen, you know, I spent a lot of time right over there Raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their jokes swollen teats Churning the daily stories into the decadent proven-style style triple cream breed that is my nightly monologue But sometimes sometimes folks I stagger home hungry after being released by the police and Root around in the neighbor's trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-donned street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire then hunker down and hallucinate while eating the listeria laden demon custard of news that is my segment. You mean one of them.", + " Folks, if you watch this show, you know I spend most of my time right over there carefully sorting through the day's biggest stories and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ichol Gregoire Ferrandi, who carefully dye them in a palette of bright zesty shades and adorn them in the finest and most topical inlay work using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddles stitching. In line it with bees, wax, coated linen, finely attached a mallet, hammered strap, pearled hardware, and close-shit to create for you the one-of-a-kind hoke couture, Erme's Birkin bag that is my monologue. But sometimes, sometimes folks, sometimes. Sometimes I wake up in the last car of an abandoned roller coaster at Coney Island where I'm I'm hiding from the triads. I have some engine lubricants out of a safe way bag and stagger down the shore to tear the sail off a beach schooner. Then I rip the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel lovely folks. And use it to stitch the sail into a loose pouch like a rock sack. And I stow away in the back of a garbage truck to the junkyard where I pick through to the debris for only the broken toys that make me the saddest until I have loaded for you. The Hobo Fugitives bug out, bindle of news that is my segment. Me one!", + " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's biggest stories right over there. Meticulously selecting the most topical chakra affirming scented candles, and using Feng Shui to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue. But sometimes just sometimes I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself, and used fry oil, wrap my hands with some double-duct tape by stole from the broken car window. Pound a six-pack of blueberry hard-seltzer and a sack of pills I stole from a parked ambulance. Then arm wrestle a raccoon in the back alley vision quest of news that is my segment. Meanwhile!", + " You know, folks, I spend most of my time right over there. Mining the day's biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels. Then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press-black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards in a faceplate and, finally, using fluted strips of white alloyed molding, I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating, Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes folks. Sometimes, just sometimes, I come into my sense as fully naked on the deck of a pirate besieged melee container ship that picked me up floating on the detached door of a portapotty in the Indian Ocean. Then after a sunstroke-induced realization of the crew of this ship plans to sell me an exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe at a pool chain that accepting my new role as Captain and declaring myself king of the windarc seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create the sopping wet pirate crown of news that is my segment. Meanwhile!", + " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's Newsiest most topical flower eggs milk and butter and Stranding into a fine batter to make delicate and informative comedy pancakes Then I glaze them in the juice and zest of the most relevant midnight Valencia oranges and douse it all and a fine Dela main de voyage cognac Before prom baying and basting them tables. I deserve for you the James Beard award worthy crepe suzzette That is my nightly monologue, but sometimes just sometimes folks. I wake up in the baggage hold of Greyhound bus. It's being hoisted by the scrap yard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps and busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strain, pair of sweatpants and as oven mitts to extract and serve the demented transience poundcake of news that is my segment. Me, Guadalupe!", " Folks, if you watched the show and I hope you do, I spent a lot of time right over there. Tiredlessly studying the lineage of the days most important thoroughbred stories and whole-stiner headlines, working with the best trainers, money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen. That is my nightly monologue, but sometimes, sometimes, folks, I break into an unincorporated veterinary genetics lab and grab whatever test tubes I can find and then under a grow light I got from a discarded chia pet. I mixed the pilfered DNA of a horse and whatever was in a tube labeled Keith Colan extra. Slurrying the concoction with caffeine pills and a microwave red bull, I screamed, sang a prayer to Janice, initiator of human life and God of transformation as a half horse, half man, freak. Seizes to life before me and the hideous collection of loose animal parts and corrupted man tissue that is my segment. Meanwhile!" ] # fmt: on @@ -2274,7 +2272,10 @@ def test_whisper_longform_multi_batch_hard(self): for i in range(num_samples): assert decoded_all[i] == decoded_single[i] - assert decoded_all[i] == EXPECTED_TEXT[i] + try: + assert decoded_all[i] == EXPECTED_TEXT[i] + except: + import ipdb; ipdb.set_trace() @slow def test_whisper_longform_multi_batch_hard_prev_cond(self): @@ -2310,7 +2311,7 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): gen_kwargs = {"return_timestamps": True} gen_kwargs["no_speech_threshold"] = 0.6 - gen_kwargs["compression_ratio_threshold"] = 2.4 + gen_kwargs["compression_ratio_threshold"] = 1.35 gen_kwargs["condition_on_prev_tokens"] = True gen_kwargs["logprob_threshold"] = -1.0 gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) From 3e49df8a494c707134614269f846d446c1db9039 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 9 Dec 2023 19:25:41 +0000 Subject: [PATCH 31/75] Better --- src/transformers/generation/logits_process.py | 4 +- .../models/whisper/modeling_whisper.py | 68 ++++++++++++++----- tests/models/whisper/test_modeling_whisper.py | 2 + 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index b361e95e6f4c..1a7ac744f715 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1595,8 +1595,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob: scores[k, : self.timestamp_begin] = -float("inf") - if torch.isinf(scores).all(): - import ipdb; ipdb.set_trace() + # if torch.isinf(scores).all(): + # import ipdb; ipdb.set_trace() return scores diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 2e50fd04295b..861c61c91a3b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -314,8 +314,11 @@ class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__(num_positions, embedding_dim) - def forward(self, input_ids, past_key_values_length=0): - return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] + def forward(self, input_ids, past_key_values_length=0, position_ids=None): + if position_ids is None: + return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] + else: + return self.weight[position_ids] class WhisperAttention(nn.Module): @@ -1255,6 +1258,7 @@ def forward( cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, + position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -1353,9 +1357,9 @@ def forward( # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1537,6 +1541,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1595,6 +1600,7 @@ def forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, + position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1668,6 +1674,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1722,6 +1729,7 @@ def forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, decoder_inputs_embeds=decoder_inputs_embeds, + decoder_position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -2228,6 +2236,7 @@ def generate( temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature temperature = temperatures[0] + do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)] output_scores = logprob_threshold is not None return_dict_in_generate = return_dict_in_generate or output_scores @@ -2287,16 +2296,22 @@ def generate( one_tensor = torch.ones((cur_bsz, 1), device=segment_input.device, dtype=torch.long) decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) - if condition_on_prev_tokens and len(current_segments[0]) > 0 and temperature < 0.5: + # if condition_on_prev_tokens and len(current_segments[0]) > 0 and temperature < 0.5: + if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 cut_off_length = self.config.max_target_positions // 2 - 1 - active_segments = [current_segments[i] for i in new_cur_to_prev_index_map] + active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in new_cur_to_prev_index_map] + prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or suppress_tokens_processor.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config + + bos_token_tensor = prev_start_of_text * one_tensor[0] prev_tokens = self._pad_to_max_length( - active_segments, generation_config.pad_token_id, padding="left" + active_segments, generation_config.pad_token_id, padding="left", bos_token_tensor=bos_token_tensor, cut_off_length=cut_off_length ) - prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or suppress_tokens_processor.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config + decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) + + kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) + kwargs["decoder_position_ids"] = (kwargs["decoder_attention_mask"].cumsum(-1) - 1).clamp(min=0) - decoder_input_ids = torch.cat([prev_start_of_text * one_tensor, prev_tokens[:, -cut_off_length:], decoder_input_ids], dim=-1) passed_max_length = kwargs.get("max_length", None) passed_max_new_tokens = kwargs.get("max_new_tokens", None) max_length_config = getattr(generation_config, "max_length", None) @@ -2338,6 +2353,7 @@ def generate( # 6.6 Batch generate current chunk if RUN_NEW_WAY: seek_sequence_list = [None for _ in range(cur_bsz)] + seek_outputs_list = [None for _ in range(cur_bsz)] needs_fallback = [False for _ in range(cur_bsz)] should_skip = [False for _ in range(cur_bsz)] fallback_index_map = list(range(cur_bsz)) @@ -2383,7 +2399,7 @@ def split_by_batch_index(values, key, batch_idx): return values[batch_idx].cpu() sequence_tokens = seek_outputs["sequences"] - seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(cur_bsz)] + seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] else: sequence_tokens = seek_outputs @@ -2438,17 +2454,20 @@ def split_by_batch_index(values, key, batch_idx): needs_fallback[i] = False should_skip[i] = True + seek_sequence_list[fallback_index_map[i]] = seek_sequence + seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] + do_condition_on_prev_tokens[fallback_index_map[i]] = temperature < 0.5 + if needs_fallback[i]: new_fallback_index_map.append(fallback_index_map[i]) new_segment_input.append(segment_input[i]) new_decoder_input_ids.append(decoder_input_ids[i]) - seek_sequence_list[fallback_index_map[i]] = seek_sequence - fallback_index_map = new_fallback_index_map if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: seek_sequences = seek_sequence_list + seek_outputs = seek_outputs_list break decoder_input_ids = torch.stack(new_decoder_input_ids) @@ -2563,15 +2582,24 @@ def split_by_batch_index(values, key, batch_idx): return sequences @staticmethod - def _pad_to_max_length(current_segments, pad_token_id, padding="right"): + def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None): max_total_length = 0 sequences = [] if padding not in ["right", "left"]: raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}") for current_segment_list in current_segments: - sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1)) - max_total_length = max(max_total_length, len(sequences[-1])) + if current_segment_list is not None: + sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) + + if cut_off_length is not None: + sequence = sequence[-cut_off_length:] + + sequence = torch.cat([bos_token_tensor, sequence]) + sequences.append(sequence) + max_total_length = max(max_total_length, len(sequences[-1])) + else: + sequences.append(bos_token_tensor) for i in range(len(current_segments)): pad_length = max_total_length - len(sequences[i]) @@ -2703,6 +2731,8 @@ def prepare_inputs_for_generation( use_cache=None, encoder_outputs=None, attention_mask=None, + decoder_position_ids=None, + decoder_attention_mask=None, **kwargs, ): if past_key_values is not None: @@ -2717,12 +2747,18 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]: + decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] + + decoder_attention_mask = kwargs.pop("decoder_attention_mask", None) + return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "use_cache": use_cache, - "decoder_attention_mask": None, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, } @staticmethod diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index b4d8f609e259..21f0044bae77 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2320,7 +2320,9 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): result = model.generate(**inputs, **gen_kwargs) decoded_all = processor.batch_decode(result, skip_special_tokens=True) + torch.manual_seed(0) for i in range(num_samples): + import ipdb; ipdb.set_trace() assert decoded_all[i] == EXPECTED_TEXT[i] def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): From c2387edf9c5072d2ca47fe95bda8cc368051855f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 9 Dec 2023 21:25:09 +0000 Subject: [PATCH 32/75] without dec mask --- src/transformers/generation/logits_process.py | 4 ++-- .../models/whisper/modeling_whisper.py | 18 +++++++++++------- tests/models/whisper/test_modeling_whisper.py | 15 +++++++++------ 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 1a7ac744f715..92531c011dae 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1595,8 +1595,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob: scores[k, : self.timestamp_begin] = -float("inf") - # if torch.isinf(scores).all(): - # import ipdb; ipdb.set_trace() + if torch.isinf(scores).all(): + import ipdb; ipdb.set_trace() return scores diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 861c61c91a3b..f1de60947c37 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2309,8 +2309,7 @@ def generate( ) decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) - kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) - kwargs["decoder_position_ids"] = (kwargs["decoder_attention_mask"].cumsum(-1) - 1).clamp(min=0) + # kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) passed_max_length = kwargs.get("max_length", None) passed_max_new_tokens = kwargs.get("max_new_tokens", None) @@ -2426,7 +2425,8 @@ def split_by_batch_index(values, key, batch_idx): if compression_ratio_threshold is not None: # TODO(PVP) only works for batch size = 1 currently - compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) + # compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) + compression_ratio = self._retrieve_compression_ratio(seek_sequence, self.config.vocab_size) if compression_ratio > compression_ratio_threshold: print("fallback compression") @@ -2456,7 +2456,7 @@ def split_by_batch_index(values, key, batch_idx): seek_sequence_list[fallback_index_map[i]] = seek_sequence seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] - do_condition_on_prev_tokens[fallback_index_map[i]] = temperature < 0.5 + do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature < 0.5 if needs_fallback[i]: new_fallback_index_map.append(fallback_index_map[i]) @@ -2595,7 +2595,9 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_toke if cut_off_length is not None: sequence = sequence[-cut_off_length:] - sequence = torch.cat([bos_token_tensor, sequence]) + if bos_token_tensor is not None: + sequence = torch.cat([bos_token_tensor, sequence]) + sequences.append(sequence) max_total_length = max(max_total_length, len(sequences[-1])) else: @@ -2731,10 +2733,13 @@ def prepare_inputs_for_generation( use_cache=None, encoder_outputs=None, attention_mask=None, - decoder_position_ids=None, decoder_attention_mask=None, **kwargs, ): + decoder_position_ids = None + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0) + if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -2750,7 +2755,6 @@ def prepare_inputs_for_generation( if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]: decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] - decoder_attention_mask = kwargs.pop("decoder_attention_mask", None) return { "encoder_outputs": encoder_outputs, diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 21f0044bae77..f432e8ea1990 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -84,6 +84,7 @@ def __init__( self.batch_size = batch_size self.max_length = max_length self.count = 0 + self.begin_index = 0 self.let_pass = [[] for _ in range(batch_size)] for k in range(batch_size): @@ -91,9 +92,14 @@ def __init__( for _ in range(10000): self.let_pass[k].append(random.randint(1, 10) <= 3) + def set_begin_index(self, begin_index: int): + self.begin_index = begin_index + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # we don't want to randomely sample timestamp tokens - if input_ids.shape[-1] > 1: + orig_scores = scores.clone() + + if input_ids.shape[-1] != self.begin_index: scores[:, self.timestamp_begin :] = -float("inf") self.no_time_stamp_counter = [x + 1 for x in self.no_time_stamp_counter] @@ -132,6 +138,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to self.count += 1 if torch.isinf(scores).all(): + import ipdb; ipdb.set_trace() raise ValueError("Dummy logit processor is incorrectly set up. Scores should not be all inf.") return scores @@ -2272,10 +2279,7 @@ def test_whisper_longform_multi_batch_hard(self): for i in range(num_samples): assert decoded_all[i] == decoded_single[i] - try: - assert decoded_all[i] == EXPECTED_TEXT[i] - except: - import ipdb; ipdb.set_trace() + assert decoded_all[i] == EXPECTED_TEXT[i] @slow def test_whisper_longform_multi_batch_hard_prev_cond(self): @@ -2322,7 +2326,6 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): torch.manual_seed(0) for i in range(num_samples): - import ipdb; ipdb.set_trace() assert decoded_all[i] == EXPECTED_TEXT[i] def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): From 1cca405d5bc83582c52c5fdd39abb5e1639b22ef Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 10 Dec 2023 15:28:12 +0000 Subject: [PATCH 33/75] correct more --- src/transformers/generation/logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 92531c011dae..2d7758243160 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1596,7 +1596,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to scores[k, : self.timestamp_begin] = -float("inf") if torch.isinf(scores).all(): - import ipdb; ipdb.set_trace() + print("RED FLAG") return scores From 29a983095071e7a3022eaa748fb3a4b0ac9a63bb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 10 Dec 2023 18:15:40 +0100 Subject: [PATCH 34/75] clean --- .../models/whisper/modeling_whisper.py | 825 +++++++++--------- 1 file changed, 404 insertions(+), 421 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f1de60947c37..d52d84972559 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -16,6 +16,7 @@ RUN_NEW_WAY = True import math +import warnings from typing import Optional, Tuple, Union import numpy as np @@ -1938,42 +1939,284 @@ def generate( ``` """ + # 0. deprecate old inputs if "inputs" in kwargs: input_features = kwargs.pop("inputs") warnings.warn( "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.", FutureWarning, ) - + # 1. copy generation config if generation_config is None: generation_config = copy.deepcopy(self.generation_config) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else generation_config.return_dict_in_generate - ) - + # 2. set global generate variables input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] - if num_segment_frames is None: - num_segment_frames = input_stride * self.config.max_source_positions + num_segment_frames = num_frames or (input_stride * self.config.max_source_positions) + total_input_frames = self._retrieve_total_input_frames(input_features, kwargs) + is_shortform = total_input_frames <= num_segment_frames + + # 3. Make sure generation config is correctly set + # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not + self._set_return_outputs(return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config) + self._set_return_timestamps(return_timestamps, is_shortform, generation_config) + self._set_language_and_task(language, task, is_multilingual, generation_config) + # pass self.config for backward compatibility + self._set_forced_decoder_ids(task, language, prompt_ids, generation_config, kwargs, self.config) + self._set_num_frames(return_token_timestamps, generation_config, kwargs) + + # 4. Retrieve logits processors + logits_processor = self._retrieve_logit_processors(generation_config) + + # 5. If we're in shortform mode, simple generate the whole input at once and return the output + if is_shortform: + outputs = super().generate( + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + + if generation_config.return_token_timestamps and hasattr(generation_config, "alignment_heads"): + outputs["token_timestamps"] = self._extract_token_timestamps( + outputs, generation_config.alignment_heads, num_frames=generation_config.num_frames + ) + + return outputs + + + # 6. Else we're in longform mode which is more complex. + # We need to chunk the audio input depending on when the model generates timestamp tokens + + # 6.1 Set and retrieve global longform generation variables + self._set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config) + + timestamp_begin = generation_config.no_timestamps_token_id + 1 + temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature + temperature = temperatures[0] + batch_size = input_features.shape[0] + + max_frames, seek = self._retrieve_max_frames_and_seek(batch_size, attention_mask) + init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config) + + # 6.2 Preppare running variables, list for generation + cur_bsz = batch_size + current_segments = [[] for _ in range(batch_size)] + batch_idx_map = list(range(batch_size)) + do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)] + + # 6.2 Transcribe audio until we reach the end of all input audios + while (seek < max_frames).any(): + # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop + # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order + # to know which original audio is being decoded + # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk + input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map) + time_offset = seek * time_precision / input_stride + seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) + + # 6.4 cut out next 30s segment from input features + segment_input = self._get_input_segment(input_features, cur_bsz, batch_idx_map) + + # 6.5 prepare decoder input ids + # TODO(Patrick) - clean up prev_start_of_text + prev_start_of_text = [l for l in logits_processor if isinstance(l, SuppressTokensLogitsProcessor)][0].suppress_tokens[-2] + decoder_input_ids, decoder_attention_mask = self._prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx_map, do_condition_on_prev_tokens, generation_config, self.config, device=segment_input.device, prev_start_of_text=prev_start_of_text) + + # 6.6 set max new tokens or max length + max_new_tokens, max_length = self._get_max_new_tokens_and_length(self.config, decoder_input_ids, generation_config, kwargs): + + # 6.7 Set current `begin_index` for all logit processors + for proc in logits_processor: + if hasattr(proc, "set_begin_index"): + proc.set_begin_index(decoder_input_ids.shape[-1]) + + print("hf in tokens", decoder_input_ids[0].tolist()) + + # 6.6 Batch generate current chunk + seek_sequence_list = [None for _ in range(cur_bsz)] + seek_outputs_list = [None for _ in range(cur_bsz)] + needs_fallback = [False for _ in range(cur_bsz)] + should_skip = [False for _ in range(cur_bsz)] + fallback_index_map = list(range(cur_bsz)) + + for fallback_idx, temperature in enumerate(temperatures): + generation_config.do_sample = temperature > 0.0 + generation_config.temperature = temperature + generation_config.num_beams = kwargs.pop("num_beams", 1) if not do_sample else 1 + + seek_outputs = super().generate( + segment_input, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + max_new_tokens=max_new_tokens, + max_length=max_length, + **kwargs, + ) + + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + num_frames = getattr(generation_config, "num_frames", None) + seek_outputs["token_timestamps"] = self._extract_token_timestamps( + seek_outputs, generation_config.alignment_heads, num_frames=num_frames + ) + + if generation_config.return_dict_in_generate: + def split_by_batch_index(values, key, batch_idx): + if key == "scores": + return list(v[batch_idx].cpu() for v in values) + if key == "past_key_values": + # we don't save `past_key_values` as this is too costly + return None + return values[batch_idx].cpu() + + sequence_tokens = seek_outputs["sequences"] + seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] + else: + sequence_tokens = seek_outputs + + # remove all previously passed decoder input ids + seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] + + new_fallback_index_map = [] + new_segment_input = [] + new_decoder_input_ids = [] + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length + for i, seek_sequence in enumerate(seek_sequences): + # make sure we cut a predicted EOS token if we are not finished with the generation yet + is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] + if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: + seek_sequence = seek_sequence[:-1] + + print("hf out tokens", seek_sequence.tolist()) + + # remove all padding tokens + if seek_sequence[-1] == generation_config.pad_token_id: + num_paddings = (seek_sequence == generation_config.pad_token_id).sum() + seek_sequence = seek_sequence[:-num_paddings] + + if compression_ratio_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + # compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) + compression_ratio = self._retrieve_compression_ratio(seek_sequence, self.config.vocab_size) + + if compression_ratio > compression_ratio_threshold: + print("fallback compression") + print("current temp", temperature) + needs_fallback[i] = True + + if logprob_threshold is not None: + if "sequences_scores" in seek_outputs[0]: + logprobs = [s["sequences_scores"] for s in seek_outputs][i] + else: + scores = seek_outputs[i]["scores"] + logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) + + # TODO(PVP) only works for batch size = 1 currently + if logprobs < logprob_threshold: + print("fallback logprobs", logprobs) + print("current temp", temperature) + needs_fallback[i] = True + + if no_speech_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + # Need to do before all other logit processors + if no_speech_detector.no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: + print("Skip because of VAD") + needs_fallback[i] = False + should_skip[i] = True + + seek_sequence_list[fallback_index_map[i]] = seek_sequence + seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] + do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature < 0.5 + + if needs_fallback[i]: + new_fallback_index_map.append(fallback_index_map[i]) + new_segment_input.append(segment_input[i]) + new_decoder_input_ids.append(decoder_input_ids[i]) + + fallback_index_map = new_fallback_index_map + + if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: + seek_sequences = seek_sequence_list + seek_outputs = seek_outputs_list + break + + decoder_input_ids = torch.stack(new_decoder_input_ids) + segment_input = torch.stack(new_segment_input) + + for i, seek_sequence in enumerate(seek_sequences): + prev_i = batch_idx_map[i] + + if should_skip[i]: + seek[prev_i] += seek_num_frames[prev_i] + print("Skipped!") + continue + + # TODO(Patrick: delete cut type) + segments, segment_offset, cut_type = self._retrieve_segment( + seek_sequence=seek_sequence, + seek_outputs=seek_outputs, + time_offset=time_offset, + timestamp_begin=timestamp_begin, + seek_num_frames=seek_num_frames, + time_precision=time_precision, + input_stride=input_stride, + prev_idx=prev_i, + idx=i, + ) + + current_segments[prev_i] += segments + seek[prev_i] += segment_offset + + # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted + # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output + sequences = self._pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right") + + # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. + if return_segments: + return {"sequences": sequences, "segments": current_segments} - # 1. Check whether we're in shortform or longform mode + return sequences + + @staticmethod + def _retrieve_total_input_frames(input_features, kwargs): if input_features is not None: - total_input_frames = input_features.shape[-1] - elif "encoder_outputs" in kwargs: + return input_features.shape[-1] + + if "encoder_outputs" in kwargs: encoder_outputs_shape = ( kwargs["encoder_outputs"][0].shape if isinstance(kwargs["encoder_outputs"], BaseModelOutput) else kwargs["encoder_outputs"].shape ) - total_input_frames = encoder_outputs_shape[1] * input_stride - else: - raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") + return encoder_outputs_shape[1] * input_stride + + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") - is_shortform = total_input_frames <= num_segment_frames + @staticmethod + def _set_return_outputs(return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config): + if return_dict_in_generate is None: + return_dict_in_generate = generation_config.return_dict_in_generate - # 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not + if return_token_timestamps: + return_dict_in_generate = True + generation_config.output_attentions = True + + if not is_shortform and logprob_threshold is not None: + generation_config.output_scores = True + generation_config.output_attentions = True + + @staticmethod + def _set_return_timestamps(return_timestamps, is_shortform, generation_config): if return_timestamps is True: if not hasattr(generation_config, "no_timestamps_token_id"): raise ValueError( @@ -1981,7 +2224,7 @@ def generate( "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" ) - generation_config.return_timestamps = return_timestamps + generation_config.return_timestamps = True elif not is_shortform: if return_timestamps is False: raise ValueError( @@ -2003,7 +2246,8 @@ def generate( else: generation_config.return_timestamps = False - # 3. Make sure to correctly set language-related parameters + @staticmethod + def _set_language_and_task(language, task, is_multilingual, generation_config): if is_multilingual is not None: if not hasattr(generation_config, "is_multilingual"): raise ValueError( @@ -2029,6 +2273,7 @@ def generate( ) language = language.lower() generation_config.language = language + if task is not None: if not hasattr(generation_config, "task_to_id"): raise ValueError( @@ -2038,18 +2283,19 @@ def generate( ) generation_config.task = task - # 4. Add forced decoder ids depending on passed `language`, `task`,`prompt_ids`, `return_token_timestamps` and `return_timestamps` + @staticmethod + def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, config, kwargs): forced_decoder_ids = None # Legacy code for backward compatibility - if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: - forced_decoder_ids = self.config.forced_decoder_ids + if hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None: + forced_decoder_ids = config.forced_decoder_ids elif ( hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None ): forced_decoder_ids = generation_config.forced_decoder_ids else: - forced_decoder_ids = kwargs.get("forced_decoder_ids", None) + forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): forced_decoder_ids = [] @@ -2095,21 +2341,21 @@ def generate( decoder_start_token_id, *text_prompt_ids = prompt_ids # Slicing the text prompt ids in a manner consistent with the OpenAI implementation # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) - text_prompt_ids = text_prompt_ids[-self.config.max_target_positions // 2 - 1 :] + text_prompt_ids = text_prompt_ids[-config.max_target_positions // 2 - 1 :] # Set the decoder_start_token_id to <|startofprev|> kwargs.update({"decoder_start_token_id": decoder_start_token_id}) # If the user passes `max_new_tokens`, increase its number to account for the prompt if kwargs.get("max_new_tokens", None) is not None: kwargs["max_new_tokens"] += len(text_prompt_ids) - if kwargs["max_new_tokens"] >= self.config.max_target_positions: + if kwargs["max_new_tokens"] >= config.max_target_positions: raise ValueError( f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` " f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced " f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the " - f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. " + f"`max_target_positions` of the Whisper model: {config.max_target_positions}. " "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " - f"so that their combined length is less that {self.config.max_target_positions}." + f"so that their combined length is less that {config.max_target_positions}." ) # Reformat the forced_decoder_ids to incorporate the prompt @@ -2124,10 +2370,9 @@ def generate( forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] generation_config.forced_decoder_ids = forced_decoder_ids + @staticmethod + def _set_num_frames(return_token_timestamps, generation_config, kwargs): if return_token_timestamps: - kwargs["output_attentions"] = True - return_dict_in_generate = True - if getattr(generation_config, "task", None) == "translate": logger.warning("Token-level timestamps may not be reliable for task 'translate'.") if not hasattr(generation_config, "alignment_heads"): @@ -2136,9 +2381,48 @@ def generate( "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." ) - if kwargs.get("num_frames") is not None: - generation_config.num_frames = kwargs.pop("num_frames") + generation_config.num_frames = kwargs.pop("num_frames", None) + + @staticmethod + def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config): + condition_on_prev_tokens = ( + condition_on_prev_tokens + if condition_on_prev_tokens is not None + else getattr(generation_config, "condition_on_prev_tokens", False) + ) + generation_config.condition_on_prev_tokens = condition_on_prev_tokens + + @staticmethod + def _retrieve_max_frames_and_seek(batch_size, attention_mask): + if batch_size > 1 and attention_mask is None: + raise ValueError( + "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " + ) + elif batch_size > 1: + max_frames = attention_mask.sum(-1).cpu().to(torch.long) + seek = torch.zeros((batch_size,), dtype=torch.long) + else: + max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames + seek = torch.zeros((1,), dtype=torch.long) + + return max_frames, seek + + @staticmethod + def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): + init_tokens = [generation_config.decoder_start_token_id] + forced_decoder_ids = generation_config.forced_decoder_ids + if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: + i = 1 + while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: + init_tokens += [forced_decoder_ids[0][1]] + forced_decoder_ids = forced_decoder_ids[1:] + i += 1 + + forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None + generation_config.forced_decoder_ids = forced_decoder_ids + @staticmethod + def _retrieve_logit_processors(generation_config): begin_index = 1 if generation_config.return_timestamps is True: forced_decoder_ids = generation_config.forced_decoder_ids @@ -2177,409 +2461,108 @@ def generate( [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor ) - # 5. If we're in shortform mode, simple generate the whole input at once and return the output - if is_shortform: - outputs = super().generate( - input_features, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - ) - - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - outputs["token_timestamps"] = self._extract_token_timestamps( - outputs, generation_config.alignment_heads, num_frames=num_frames - ) - - return outputs - - condition_on_prev_tokens = ( - condition_on_prev_tokens - if condition_on_prev_tokens is not None - else getattr(generation_config, "condition_on_prev_tokens", False) - ) - - # 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated - # timestamp tokens - # 6.1 Set running parameters for while loop - if not return_segments and return_dict_in_generate: - raise ValueError( - "Make sure to set `return_segments=True` to return generation outputs as part of the `'segments' key.`" - ) - - # if input is longer than 30 seconds we default to long-form generation - timestamp_begin = generation_config.no_timestamps_token_id + 1 - # input stride is mel frames per encoder output vector which is the product of all conv strides - batch_size = input_features.shape[0] - - if batch_size > 1 and attention_mask is None: - raise ValueError( - "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " - ) - elif batch_size > 1: - max_frames = attention_mask.sum(-1).cpu().to(torch.long) - seek = torch.zeros((batch_size,), dtype=torch.long) - else: - max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames - seek = torch.zeros((1,), dtype=torch.long) - - current_segments = [[] for _ in range(batch_size)] - cur_to_prev_index_map = list(range(batch_size)) - - # batch size can decrease during the run - cur_bsz = prev_bsz = batch_size - - temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature - temperature = temperatures[0] - do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)] - - output_scores = logprob_threshold is not None - return_dict_in_generate = return_dict_in_generate or output_scores - - init_tokens = [generation_config.decoder_start_token_id] - if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: - i = 1 - while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: - init_tokens += [forced_decoder_ids[0][1]] - forced_decoder_ids = forced_decoder_ids[1:] - i += 1 - - forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None - generation_config.forced_decoder_ids = forced_decoder_ids - - # 6.2 Transcribe audio until we reach the end of all input audios - while (seek < max_frames).any(): - prev_bsz = cur_bsz - - # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop - # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order - # to know which original audio is being decoded - new_cur_to_prev_index_map = [] - for i in range(prev_bsz): - prev_i = cur_to_prev_index_map[i] - if seek[prev_i] >= max_frames[prev_i]: - cut_index = i + (cur_bsz - prev_bsz) - cur_bsz -= 1 - input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0) - else: - # cut out index that goes away - new_cur_to_prev_index_map.append(prev_i) - - # 6.4 Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk - cur_to_prev_index_map = new_cur_to_prev_index_map - time_offset = seek * time_precision / input_stride - seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) - - # 6.5 Make sure that all inputs are padded to the same input length - segment_input = [] - for i in range(cur_bsz): - prev_i = cur_to_prev_index_map[i] - segment_input_slice = input_features[ - i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i] - ] - - if segment_input_slice.shape[-1] < num_segment_frames: - # pad to 3000 if necessary - segment_input_slice = F.pad( - segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1]) - ) - - segment_input.append(segment_input_slice) - - segment_input = torch.cat(segment_input, dim=0) - - one_tensor = torch.ones((cur_bsz, 1), device=segment_input.device, dtype=torch.long) - decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) - - # if condition_on_prev_tokens and len(current_segments[0]) > 0 and temperature < 0.5: - if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: - # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 - cut_off_length = self.config.max_target_positions // 2 - 1 - active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in new_cur_to_prev_index_map] - prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or suppress_tokens_processor.suppress_tokens[-2] # TODO(Patrick): Need to put in generation_config - - bos_token_tensor = prev_start_of_text * one_tensor[0] - prev_tokens = self._pad_to_max_length( - active_segments, generation_config.pad_token_id, padding="left", bos_token_tensor=bos_token_tensor, cut_off_length=cut_off_length - ) - decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) - - # kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) - - passed_max_length = kwargs.get("max_length", None) - passed_max_new_tokens = kwargs.get("max_new_tokens", None) - max_length_config = getattr(generation_config, "max_length", None) - max_new_tokens_config = getattr(generation_config, "max_new_tokens", None) - - # Make sure we don't get larger than `max_length` - if passed_max_length is not None and passed_max_new_tokens is None: - kwargs["max_length"] = min( - kwargs["max_length"] + cut_off_length + 1, self.config.max_target_positions - ) - logger.info( - f"Increase max_length from {passed_max_length} to {kwargs['max_length']} since input is conditioned on previous segment." - ) - elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: - kwargs["max_length"] = min( - generation_config.max_length + cut_off_length + 1, self.config.max_target_positions - ) - logger.info( - f"Increase max_length from {max_length_config} to {kwargs['max_length']} since input is conditioned on previous segment." - ) - elif ( - passed_max_new_tokens is not None - and passed_max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions - ): - kwargs["max_new_tokens"] = self.config.max_target_positions - decoder_input_ids.shape[-1] - elif ( - passed_max_new_tokens is None - and max_new_tokens_config is not None - and max_new_tokens_config + decoder_input_ids.shape[-1] > self.config.max_target_positions - ): - kwargs["max_new_tokens"] = self.config.max_target_positions - decoder_input_ids.shape[-1] - - for proc in logits_processor: - if hasattr(proc, "set_begin_index"): - proc.set_begin_index(decoder_input_ids.shape[-1]) - - print("hf in tokens", decoder_input_ids[0].tolist()) - - # 6.6 Batch generate current chunk - if RUN_NEW_WAY: - seek_sequence_list = [None for _ in range(cur_bsz)] - seek_outputs_list = [None for _ in range(cur_bsz)] - needs_fallback = [False for _ in range(cur_bsz)] - should_skip = [False for _ in range(cur_bsz)] - fallback_index_map = list(range(cur_bsz)) - - if not RUN_NEW_WAY: - needs_fallback = False - should_skip = False + @staticmethod + def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map): + prev_bsz = cur_bsz + new_batch_idx_map = [] + for i in range(prev_bsz): + prev_i = batch_idx_map[i] + if seek[prev_i] >= max_frames[prev_i]: + cut_index = i + (cur_bsz - prev_bsz) + cur_bsz -= 1 + input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0) + else: + # cut out index that goes away + new_batch_idx_map.append(prev_i) - for fallback_idx, temperature in enumerate(temperatures): - do_sample = temperature > 0.0 + return input_features, cur_bsz, new_batch_idx_map - num_beams = kwargs.pop("num_beams", 1) - generation_config.num_beams = num_beams if not do_sample else 1 + @staticmethod + def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map): + segment_input = [] + for i in range(cur_bsz): + prev_i = batch_idx_map[i] + segment_input_slice = input_features[ + i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i] + ] - seek_outputs = super().generate( - segment_input, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - temperature=temperature, - do_sample=do_sample, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - decoder_input_ids=decoder_input_ids, - **kwargs, + if segment_input_slice.shape[-1] < num_segment_frames: + # pad to 3000 if necessary + segment_input_slice = F.pad( + segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1]) ) - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - seek_outputs["token_timestamps"] = self._extract_token_timestamps( - seek_outputs, generation_config.alignment_heads, num_frames=num_frames - ) - - if return_dict_in_generate: - def split_by_batch_index(values, key, batch_idx): - if key == "scores": - return list(v[batch_idx].cpu() for v in values) - if key == "past_key_values": - # we don't save `past_key_values` as this is too costly - return None - return values[batch_idx].cpu() - - sequence_tokens = seek_outputs["sequences"] - seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] - else: - sequence_tokens = seek_outputs - - # remove all previously passed decoder input ids - seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] - - if RUN_NEW_WAY: - new_fallback_index_map = [] - new_segment_input = [] - new_decoder_input_ids = [] - # 6.7 Loop over each decoded audio individually as each decoding can be of a different length - for i, seek_sequence in enumerate(seek_sequences): - # make sure we cut a predicted EOS token if we are not finished with the generation yet - is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] - if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: - seek_sequence = seek_sequence[:-1] - - print("hf out tokens", seek_sequence.tolist()) - - # remove all padding tokens - if seek_sequence[-1] == generation_config.pad_token_id: - num_paddings = (seek_sequence == generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] - - if compression_ratio_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) - compression_ratio = self._retrieve_compression_ratio(seek_sequence, self.config.vocab_size) - - if compression_ratio > compression_ratio_threshold: - print("fallback compression") - print("current temp", temperature) - needs_fallback[i] = True - - if logprob_threshold is not None: - if "sequences_scores" in seek_outputs[0]: - logprobs = [s["sequences_scores"] for s in seek_outputs][i] - else: - scores = seek_outputs[i]["scores"] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) - - # TODO(PVP) only works for batch size = 1 currently - if logprobs < logprob_threshold: - print("fallback logprobs", logprobs) - print("current temp", temperature) - needs_fallback[i] = True - - if no_speech_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # Need to do before all other logit processors - if no_speech_detector.no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: - print("Skip because of VAD") - needs_fallback[i] = False - should_skip[i] = True - - seek_sequence_list[fallback_index_map[i]] = seek_sequence - seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] - do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature < 0.5 - - if needs_fallback[i]: - new_fallback_index_map.append(fallback_index_map[i]) - new_segment_input.append(segment_input[i]) - new_decoder_input_ids.append(decoder_input_ids[i]) - - fallback_index_map = new_fallback_index_map - - if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: - seek_sequences = seek_sequence_list - seek_outputs = seek_outputs_list - break - - decoder_input_ids = torch.stack(new_decoder_input_ids) - segment_input = torch.stack(new_segment_input) - - if not RUN_NEW_WAY: - if compression_ratio_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - compression_ratio = self._retrieve_compression_ratio(sequence_tokens[0], self.config.vocab_size) - - if compression_ratio > compression_ratio_threshold: - print("fallback compression") - print("current temp", temperature) - needs_fallback = True - - if logprob_threshold is not None: - if "sequences_scores" in seek_outputs[0]: - logprobs = [s["sequences_scores"] for s in seek_outputs] - else: - scores = [s["scores"] for s in seek_outputs] - logprobs = self._retrieve_avg_logprobs(scores[0], seek_sequences[0], self.config.eos_token_id, temperature) - - # TODO(PVP) only works for batch size = 1 currently - if logprobs < logprob_threshold: - print("current temp", temperature) - needs_fallback = True - - if no_speech_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # Need to do before all other logit processors - if no_speech_detector.no_speech_prob[0] > no_speech_threshold and logprobs < logprob_threshold: - print("Skip because of VAD") - needs_fallback = False - should_skip = True - - if not needs_fallback: - break - - if RUN_NEW_WAY: - for i, seek_sequence in enumerate(seek_sequences): - prev_i = cur_to_prev_index_map[i] - - if should_skip[i]: - seek[prev_i] += seek_num_frames[prev_i] - print("Skipped!") - continue - - # TODO(Patrick: delete cut type) - segments, segment_offset, cut_type = self._retrieve_segment( - seek_sequence=seek_sequence, - seek_outputs=seek_outputs, - time_offset=time_offset, - timestamp_begin=timestamp_begin, - seek_num_frames=seek_num_frames, - time_precision=time_precision, - input_stride=input_stride, - prev_idx=prev_i, - idx=i, - ) - - current_segments[prev_i] += segments - seek[prev_i] += segment_offset - - if not RUN_NEW_WAY: - # 6.7 Loop over each decoded audio individually as each decoding can be of a different length - for i, seek_sequence in enumerate(seek_sequences): - prev_i = cur_to_prev_index_map[i] + segment_input.append(segment_input_slice) - if should_skip: - seek[prev_i] += seek_num_frames[prev_i] - print("Skipped!") - continue + segment_input = torch.cat(segment_input, dim=0) - # make sure we cut a predicted EOS token if we are not finished with the generation yet - is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] - if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: - seek_sequence = seek_sequence[:-1] - - print("hf out tokens", seek_sequence.tolist()) + return segment_input - # remove all padding tokens - if seek_sequence[-1] == generation_config.pad_token_id: - num_paddings = (seek_sequence == generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] + @staticmethod + # TODO(Patrick) - remove prev_start_of_text + def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx_map, do_condition_on_prev_tokens, generation_config, config, device, prev_start_of_text): + cut_off_length = config.max_target_positions // 2 - 1 + + one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long) + decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) + decoder_attention_mask = None + + # if condition_on_prev_tokens and len(current_segments[0]) > 0 and temperature < 0.5: + if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: + # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 + active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map] + prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text + + bos_token_tensor = prev_start_of_text * one_tensor[0] + prev_tokens = self._pad_to_max_length( + active_segments, generation_config.pad_token_id, padding="left", bos_token_tensor=bos_token_tensor, cut_off_length=cut_off_length + ) + decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) - # TODO(Patrick: delete cut type) - segments, segment_offset, cut_type = self._retrieve_segment( - seek_sequence=seek_sequence, - seek_outputs=seek_outputs, - time_offset=time_offset, - timestamp_begin=timestamp_begin, - seek_num_frames=seek_num_frames, - time_precision=time_precision, - input_stride=input_stride, - prev_idx=prev_i, - idx=i, - ) + decoder_attention_mask = (decoder_input_ids != generation_config.pad_token_id) + + return decoder_input_ids, decoder_attention_mask - current_segments[prev_i] += segments - seek[prev_i] += segment_offset + @staticmethod + def _get_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs): + cut_off_length = config.max_target_positions // 2 - 1 - print(f"{cut_type} seek {seek[0]}") + passed_max_length = kwargs.get("max_length", None) + passed_max_new_tokens = kwargs.get("max_new_tokens", None) + max_length_config = getattr(generation_config, "max_length", None) + max_new_tokens_config = getattr(generation_config, "max_new_tokens", None) - # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted - # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output - sequences = self._pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right") + max_new_tokens = None + max_length = None - # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. - if return_segments: - return {"sequences": sequences, "segments": current_segments} + # Make sure we don't get larger than `max_length` + if passed_max_length is not None and passed_max_new_tokens is None: + max_length = min( + kwargs["max_length"] + cut_off_length + 1, config.max_target_positions + ) + logger.info( + f"Increase max_length from {passed_max_length} to {kwargs['max_length']} since input is conditioned on previous segment." + ) + elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: + max_length = min( + generation_config.max_length + cut_off_length + 1, config.max_target_positions + ) + logger.info( + f"Increase max_length from {max_length_config} to {kwargs['max_length']} since input is conditioned on previous segment." + ) + elif ( + passed_max_new_tokens is not None + and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions + ): + max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] + elif ( + passed_max_new_tokens is None + and max_new_tokens_config is not None + and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions + ): + max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] - return sequences + return max_new_tokens, max_length @staticmethod def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None): From 032d45abccda19c26fb64c0ac7019cbac0e2874b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 10 Dec 2023 18:46:51 +0000 Subject: [PATCH 35/75] save intermediate --- src/transformers/generation/utils.py | 1 + src/transformers/models/whisper/modeling_whisper.py | 3 +-- tests/models/whisper/test_modeling_whisper.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e773ecbe256b..21eb5327bd3b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2575,6 +2575,7 @@ def greedy_search( break # prepare model inputs + import ipdb; ipdb.set_trace() model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f1de60947c37..c1cc24910d18 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2309,7 +2309,7 @@ def generate( ) decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) - # kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) + kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) passed_max_length = kwargs.get("max_length", None) passed_max_new_tokens = kwargs.get("max_new_tokens", None) @@ -2755,7 +2755,6 @@ def prepare_inputs_for_generation( if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]: decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] - return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f432e8ea1990..d9b8228075f5 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2326,6 +2326,7 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): torch.manual_seed(0) for i in range(num_samples): + import ipdb; ipdb.set_trace() assert decoded_all[i] == EXPECTED_TEXT[i] def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): From 06b598a71c5370f39fadc907aa9bb59a5f4d2a76 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 10 Dec 2023 20:11:15 +0000 Subject: [PATCH 36/75] Fix more --- .../models/whisper/modeling_whisper.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e225fadc8f57..2cd8c0ff53eb 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2024,8 +2024,8 @@ def generate( # 6.5 prepare decoder input ids # TODO(Patrick) - clean up prev_start_of_text - suppress_tokens_processor = [l for l in logits_processor if isinstance(l, SuppressTokensLogitsProcessor)] - prev_start_of_text = suppress_tokens_processor[0].suppress_tokens[-2] if len(suppress_tokens_processor) > 0 else None + suppress_tokens = self._get_attr_from_logit_processors(logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens") + prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None decoder_input_ids, kwargs = self._prepare_decoder_input_ids(cur_bsz=cur_bsz, init_tokens=init_tokens, current_segments=current_segments, batch_idx_map=batch_idx_map, do_condition_on_prev_tokens=do_condition_on_prev_tokens, generation_config=generation_config, config=self.config, device=segment_input.device, kwargs=kwargs, prev_start_of_text=prev_start_of_text) # 6.6 set max new tokens or max length @@ -2130,7 +2130,8 @@ def split_by_batch_index(values, key, batch_idx): if no_speech_threshold is not None: # TODO(PVP) only works for batch size = 1 currently # Need to do before all other logit processors - if no_speech_detector.no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: + no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") + if no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: print("Skip because of VAD") needs_fallback[i] = False should_skip[i] = True @@ -2188,6 +2189,13 @@ def split_by_batch_index(values, key, batch_idx): return sequences + @staticmethod + def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): + logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) + if logit_processor: + return getattr(logit_processor, attribute_name, None) + return None + @staticmethod def _retrieve_total_input_frames(input_features, kwargs): if input_features is not None: @@ -2213,8 +2221,10 @@ def _set_return_outputs(return_dict_in_generate, return_token_timestamps, is_sho generation_config.output_attentions = True if not is_shortform and logprob_threshold is not None: + return_dict_in_generate = True generation_config.output_scores = True - generation_config.output_attentions = True + + generation_config.return_dict_in_generate = return_dict_in_generate @staticmethod def _set_return_timestamps(return_timestamps, is_shortform, generation_config): @@ -2519,7 +2529,7 @@ def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text bos_token_tensor = prev_start_of_text * one_tensor[0] - prev_tokens = self._pad_to_max_length( + prev_tokens = WhisperForConditionalGeneration._pad_to_max_length( active_segments, generation_config.pad_token_id, padding="left", bos_token_tensor=bos_token_tensor, cut_off_length=cut_off_length ) decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) From d1021ec3ac991e90737c58cf6bd0b642ad28ff64 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 12 Dec 2023 19:49:10 +0000 Subject: [PATCH 37/75] Fix VAD for large-v2 --- src/transformers/generation/logits_process.py | 28 +++++++++++++++++-- src/transformers/generation/utils.py | 6 ++-- .../models/whisper/modeling_whisper.py | 25 +++++++++++------ 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 2d7758243160..1658e805cf10 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -50,6 +50,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." ) + @property + def pass_all_logits(self): + return getattr(self, "_pass_all_logits", False) class LogitsWarper: """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" @@ -60,6 +63,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." ) + @property + def pass_all_logits(self): + return getattr(self, "_pass_all_logits", False) + class LogitsProcessorList(list): """ @@ -84,6 +91,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa The processed prediction scores. """ + if not any(processor.pass_all_logits for processor in self) and len(scores.shape) > 2: + scores = scores[:, -1, :] + for processor in self: function_args = inspect.signature(processor.__call__).parameters if len(function_args) > 2: @@ -95,6 +105,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa scores = processor(input_ids, scores, **kwargs) else: scores = processor(input_ids, scores) + return scores @@ -1604,10 +1615,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class WhisperNoSpeechDetection(LogitsProcessor): r"""This processor can be used to detect silence when using Whisper.""" - def __init__(self, no_speech_token: int, begin_index: int): + def __init__(self, no_speech_token: int, begin_index: int, begin_index_offset: int): self.no_speech_token = no_speech_token self.begin_index = begin_index + self.begin_index_offset = begin_index_offset self._no_speech_prob = [0.0] + self._has_run = False + + # make sure we pass all logits + self._pass_all_logits = True @property def no_speech_prob(self): @@ -1615,12 +1631,18 @@ def no_speech_prob(self): def set_begin_index(self, begin_index): self.begin_index = begin_index + self._has_run = False @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if input_ids.shape[1] == self.begin_index: - probs = scores.float().softmax(dim=-1) + no_speech_index = (self.begin_index - self.begin_index_offset) + if (input_ids.shape[1] >= no_speech_index) and not self._has_run: + no_speech_scores = scores[:, no_speech_index - 1] + probs = no_speech_scores.float().softmax(dim=-1) self._no_speech_prob = probs[:, self.no_speech_token] + self._has_run = True + + scores = scores[:, -1, :] return scores diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e773ecbe256b..1de7b026cea3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2588,10 +2588,8 @@ def greedy_search( if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits[:, -1, :] - # pre-process distribution - next_tokens_scores = logits_processor(input_ids, next_token_logits) + next_tokens_scores = logits_processor(input_ids, outputs.logits) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -2870,7 +2868,7 @@ def sample( if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits[:, -1, :] + next_token_logits = outputs.logits # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 2cd8c0ff53eb..86f200f5437c 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1965,11 +1965,11 @@ def generate( self._set_forced_decoder_ids(task=task, language=language, prompt_ids=prompt_ids, generation_config=generation_config, config=self.config, kwargs=kwargs) self._set_num_frames(return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs) - # 4. Retrieve logits processors - logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold) - # 5. If we're in shortform mode, simple generate the whole input at once and return the output if is_shortform: + # 4. Retrieve logits processors + logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold) + outputs = super().generate( input_features, generation_config, @@ -2003,6 +2003,9 @@ def generate( max_frames, seek = self._retrieve_max_frames_and_seek(batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames) init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config) + # 4. Retrieve logits processors + logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, init_tokens=init_tokens) + # 6.2 Preppare running variables, list for generation cur_bsz = batch_size current_segments = [[] for _ in range(batch_size)] @@ -2030,6 +2033,8 @@ def generate( # 6.6 set max new tokens or max length max_new_tokens, max_length = self._get_max_new_tokens_and_length(config=self.config, decoder_input_ids=decoder_input_ids, generation_config=generation_config, kwargs=kwargs) + kwargs.pop("max_length", None) + kwargs.pop("max_new_tokens", None) # 6.7 Set current `begin_index` for all logit processors for proc in logits_processor: @@ -2131,6 +2136,9 @@ def split_by_batch_index(values, key, batch_idx): # TODO(PVP) only works for batch size = 1 currently # Need to do before all other logit processors no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") + print("WATCH") + print(no_speech_prob) + print(logprobs) if no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: print("Skip because of VAD") needs_fallback[i] = False @@ -2216,6 +2224,7 @@ def _set_return_outputs(return_dict_in_generate, return_token_timestamps, is_sho if return_dict_in_generate is None: return_dict_in_generate = generation_config.return_dict_in_generate + generation_config.return_token_timestamps = return_token_timestamps if return_token_timestamps: return_dict_in_generate = True generation_config.output_attentions = True @@ -2435,7 +2444,7 @@ def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): return init_tokens @staticmethod - def _retrieve_logit_processors(generation_config, logits_processor, no_speech_threshold): + def _retrieve_logit_processors(generation_config, logits_processor, no_speech_threshold, init_tokens=None): begin_index = 1 if generation_config.return_timestamps is True: forced_decoder_ids = generation_config.forced_decoder_ids @@ -2468,8 +2477,9 @@ def _retrieve_logit_processors(generation_config, logits_processor, no_speech_th ) generation_config.begin_suppress_tokens = None - if no_speech_threshold is not None: - no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index) + if no_speech_threshold is not None and init_tokens is not None: + begin_index_offset = (len(init_tokens) - 1) + no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, begin_index_offset=begin_index_offset) logits_processor = ( [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor ) @@ -2520,7 +2530,6 @@ def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long) decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) - decoder_attention_mask = None # if condition_on_prev_tokens and len(current_segments[0]) > 0 and temperature < 0.5: if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: @@ -2534,7 +2543,7 @@ def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx ) decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) - kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) + # kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) return decoder_input_ids, kwargs From 23d214928d3b786189787b0e3a5d4072d8e6d459 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 13 Dec 2023 11:20:46 +0000 Subject: [PATCH 38/75] Save new --- .../models/whisper/modeling_whisper.py | 257 +++++++++--------- 1 file changed, 129 insertions(+), 128 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 86f200f5437c..b9b68e242c52 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2032,9 +2032,7 @@ def generate( decoder_input_ids, kwargs = self._prepare_decoder_input_ids(cur_bsz=cur_bsz, init_tokens=init_tokens, current_segments=current_segments, batch_idx_map=batch_idx_map, do_condition_on_prev_tokens=do_condition_on_prev_tokens, generation_config=generation_config, config=self.config, device=segment_input.device, kwargs=kwargs, prev_start_of_text=prev_start_of_text) # 6.6 set max new tokens or max length - max_new_tokens, max_length = self._get_max_new_tokens_and_length(config=self.config, decoder_input_ids=decoder_input_ids, generation_config=generation_config, kwargs=kwargs) - kwargs.pop("max_length", None) - kwargs.pop("max_new_tokens", None) + kwargs = self._set_max_new_tokens_and_length(config=self.config, decoder_input_ids=decoder_input_ids, generation_config=generation_config, kwargs=kwargs) # 6.7 Set current `begin_index` for all logit processors for proc in logits_processor: @@ -2042,126 +2040,7 @@ def generate( proc.set_begin_index(decoder_input_ids.shape[-1]) print("hf in tokens", decoder_input_ids[0].tolist()) - - # 6.6 Batch generate current chunk - seek_sequence_list = [None for _ in range(cur_bsz)] - seek_outputs_list = [None for _ in range(cur_bsz)] - needs_fallback = [False for _ in range(cur_bsz)] - should_skip = [False for _ in range(cur_bsz)] - fallback_index_map = list(range(cur_bsz)) - - for fallback_idx, temperature in enumerate(temperatures): - generation_config.do_sample = temperature > 0.0 - generation_config.temperature = temperature - generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1 - - seek_outputs = super().generate( - segment_input, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - decoder_input_ids=decoder_input_ids, - max_new_tokens=max_new_tokens, - max_length=max_length, - **kwargs, - ) - - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - seek_outputs["token_timestamps"] = self._extract_token_timestamps( - seek_outputs, generation_config.alignment_heads, num_frames=num_frames - ) - - if generation_config.return_dict_in_generate: - def split_by_batch_index(values, key, batch_idx): - if key == "scores": - return list(v[batch_idx].cpu() for v in values) - if key == "past_key_values": - # we don't save `past_key_values` as this is too costly - return None - return values[batch_idx].cpu() - - sequence_tokens = seek_outputs["sequences"] - seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] - else: - sequence_tokens = seek_outputs - - # remove all previously passed decoder input ids - seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] - - new_fallback_index_map = [] - new_segment_input = [] - new_decoder_input_ids = [] - # 6.7 Loop over each decoded audio individually as each decoding can be of a different length - for i, seek_sequence in enumerate(seek_sequences): - # make sure we cut a predicted EOS token if we are not finished with the generation yet - prev_i = batch_idx_map[fallback_index_map[i]] - is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] - if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: - seek_sequence = seek_sequence[:-1] - - print("hf out tokens", seek_sequence.tolist()) - - # remove all padding tokens - if seek_sequence[-1] == generation_config.pad_token_id: - num_paddings = (seek_sequence == generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] - - if compression_ratio_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) - compression_ratio = self._retrieve_compression_ratio(seek_sequence, self.config.vocab_size) - - if compression_ratio > compression_ratio_threshold: - print("fallback compression") - print("current temp", temperature) - needs_fallback[i] = True - - if logprob_threshold is not None: - if "sequences_scores" in seek_outputs[0]: - logprobs = [s["sequences_scores"] for s in seek_outputs][i] - else: - scores = seek_outputs[i]["scores"] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) - - # TODO(PVP) only works for batch size = 1 currently - if logprobs < logprob_threshold: - print("fallback logprobs", logprobs) - print("current temp", temperature) - needs_fallback[i] = True - - if no_speech_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # Need to do before all other logit processors - no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") - print("WATCH") - print(no_speech_prob) - print(logprobs) - if no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: - print("Skip because of VAD") - needs_fallback[i] = False - should_skip[i] = True - - seek_sequence_list[fallback_index_map[i]] = seek_sequence - seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] - do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature < 0.5 - - if needs_fallback[i]: - new_fallback_index_map.append(fallback_index_map[i]) - new_segment_input.append(segment_input[i]) - new_decoder_input_ids.append(decoder_input_ids[i]) - - fallback_index_map = new_fallback_index_map - - if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: - seek_sequences = seek_sequence_list - seek_outputs = seek_outputs_list - break - - decoder_input_ids = torch.stack(new_decoder_input_ids) - segment_input = torch.stack(new_segment_input) + seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(segment_input=segment_input, decoder_input_ids=decoder_input_ids, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, seek=seek, num_segment_frames=num_segment_frames, max_frames=max_frames, temperatures=temperatures, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold, do_condition_on_prev_tokens=do_condition_on_prev_tokens, condition_on_prev_tokens=condition_on_prev_tokens, kwargs=kwargs) for i, seek_sequence in enumerate(seek_sequences): prev_i = batch_idx_map[i] @@ -2197,6 +2076,125 @@ def split_by_batch_index(values, key, batch_idx): return sequences + def generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batch_idx_map, seek, num_segment_frames, max_frames, temperatures, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, compression_ratio_threshold, logprob_threshold, no_speech_threshold, do_condition_on_prev_tokens, condition_on_prev_tokens, kwargs): + # 6.6 Batch generate current chunk + seek_sequence_list = [None for _ in range(cur_bsz)] + seek_outputs_list = [None for _ in range(cur_bsz)] + needs_fallback = [False for _ in range(cur_bsz)] + should_skip = [False for _ in range(cur_bsz)] + fallback_index_map = list(range(cur_bsz)) + + for fallback_idx, temperature in enumerate(temperatures): + generation_config.do_sample = temperature > 0.0 + generation_config.temperature = temperature + generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1 + + seek_outputs = super().generate( + segment_input, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=decoder_input_ids, + **kwargs, + ) + + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + num_frames = getattr(generation_config, "num_frames", None) + seek_outputs["token_timestamps"] = self._extract_token_timestamps( + seek_outputs, generation_config.alignment_heads, num_frames=num_frames + ) + + if generation_config.return_dict_in_generate: + def split_by_batch_index(values, key, batch_idx): + if key == "scores": + return list(v[batch_idx].cpu() for v in values) + if key == "past_key_values": + # we don't save `past_key_values` as this is too costly + return None + return values[batch_idx].cpu() + + sequence_tokens = seek_outputs["sequences"] + seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] + else: + sequence_tokens = seek_outputs + + # remove all previously passed decoder input ids + seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] + + new_fallback_index_map = [] + new_segment_input = [] + new_decoder_input_ids = [] + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length + for i, seek_sequence in enumerate(seek_sequences): + # make sure we cut a predicted EOS token if we are not finished with the generation yet + prev_i = batch_idx_map[fallback_index_map[i]] + is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] + if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: + seek_sequence = seek_sequence[:-1] + + print("hf out tokens", seek_sequence.tolist()) + + # remove all padding tokens + if seek_sequence[-1] == generation_config.pad_token_id: + num_paddings = (seek_sequence == generation_config.pad_token_id).sum() + seek_sequence = seek_sequence[:-num_paddings] + + if compression_ratio_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + # compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) + compression_ratio = self._retrieve_compression_ratio(seek_sequence, self.config.vocab_size) + + if compression_ratio > compression_ratio_threshold: + print("fallback compression") + print("current temp", temperature) + needs_fallback[i] = True + + if logprob_threshold is not None: + if "sequences_scores" in seek_outputs[0]: + logprobs = [s["sequences_scores"] for s in seek_outputs][i] + else: + scores = seek_outputs[i]["scores"] + logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) + + # TODO(PVP) only works for batch size = 1 currently + if logprobs < logprob_threshold: + print("fallback logprobs", logprobs) + print("current temp", temperature) + needs_fallback[i] = True + + if no_speech_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + # Need to do before all other logit processors + no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") + if no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: + print("Skip because of VAD") + needs_fallback[i] = False + should_skip[i] = True + + seek_sequence_list[fallback_index_map[i]] = seek_sequence + seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] + do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature < 0.5 + + if needs_fallback[i]: + new_fallback_index_map.append(fallback_index_map[i]) + new_segment_input.append(segment_input[i]) + new_decoder_input_ids.append(decoder_input_ids[i]) + + fallback_index_map = new_fallback_index_map + + if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: + seek_sequences = seek_sequence_list + seek_outputs = seek_outputs_list + break + + decoder_input_ids = torch.stack(new_decoder_input_ids) + segment_input = torch.stack(new_segment_input) + + return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens + + @staticmethod def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) @@ -2549,11 +2547,11 @@ def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx return decoder_input_ids, kwargs @staticmethod - def _get_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs): + def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs): cut_off_length = config.max_target_positions // 2 - 1 - passed_max_length = kwargs.get("max_length", None) - passed_max_new_tokens = kwargs.get("max_new_tokens", None) + passed_max_length = kwargs.pop("max_length", None) + passed_max_new_tokens = kwargs.pop("max_new_tokens", None) max_length_config = getattr(generation_config, "max_length", None) max_new_tokens_config = getattr(generation_config, "max_new_tokens", None) @@ -2563,7 +2561,7 @@ def _get_max_new_tokens_and_length(config, decoder_input_ids, generation_config, # Make sure we don't get larger than `max_length` if passed_max_length is not None and passed_max_new_tokens is None: max_length = min( - kwargs["max_length"] + cut_off_length + 1, config.max_target_positions + passed_max_length + cut_off_length + 1, config.max_target_positions ) logger.info( f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment." @@ -2587,7 +2585,10 @@ def _get_max_new_tokens_and_length(config, decoder_input_ids, generation_config, ): max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] - return max_new_tokens, max_length + kwargs["max_new_tokens"] = max_new_tokens + kwargs["max_length"] = max_length + + return kwargs @staticmethod def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None): From 1caa2cbe1cc801c7dd24880128a32db8f5deca46 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 13 Dec 2023 12:30:46 +0000 Subject: [PATCH 39/75] Correct more --- .../models/whisper/modeling_whisper.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index b9b68e242c52..ee9f068bd2e1 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1965,11 +1965,12 @@ def generate( self._set_forced_decoder_ids(task=task, language=language, prompt_ids=prompt_ids, generation_config=generation_config, config=self.config, kwargs=kwargs) self._set_num_frames(return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs) + # 4. Retrieve logits processors + num_start_tokens = len(generation_config.forced_decoder_ids) if generation_config.forced_decoder_ids is not None else 1 + logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, num_start_tokens=num_start_tokens, is_shortform=is_shortform) + # 5. If we're in shortform mode, simple generate the whole input at once and return the output if is_shortform: - # 4. Retrieve logits processors - logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold) - outputs = super().generate( input_features, generation_config, @@ -2003,9 +2004,6 @@ def generate( max_frames, seek = self._retrieve_max_frames_and_seek(batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames) init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config) - # 4. Retrieve logits processors - logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, init_tokens=init_tokens) - # 6.2 Preppare running variables, list for generation cur_bsz = batch_size current_segments = [[] for _ in range(batch_size)] @@ -2442,7 +2440,7 @@ def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): return init_tokens @staticmethod - def _retrieve_logit_processors(generation_config, logits_processor, no_speech_threshold, init_tokens=None): + def _retrieve_logit_processors(generation_config, logits_processor, no_speech_threshold, num_start_tokens, is_shortform): begin_index = 1 if generation_config.return_timestamps is True: forced_decoder_ids = generation_config.forced_decoder_ids @@ -2475,8 +2473,8 @@ def _retrieve_logit_processors(generation_config, logits_processor, no_speech_th ) generation_config.begin_suppress_tokens = None - if no_speech_threshold is not None and init_tokens is not None: - begin_index_offset = (len(init_tokens) - 1) + if no_speech_threshold is not None and not is_shortform: + begin_index_offset = num_start_tokens - 1 no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, begin_index_offset=begin_index_offset) logits_processor = ( [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor @@ -2548,7 +2546,7 @@ def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx @staticmethod def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs): - cut_off_length = config.max_target_positions // 2 - 1 + num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1) passed_max_length = kwargs.pop("max_length", None) passed_max_new_tokens = kwargs.pop("max_new_tokens", None) @@ -2561,14 +2559,14 @@ def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, # Make sure we don't get larger than `max_length` if passed_max_length is not None and passed_max_new_tokens is None: max_length = min( - passed_max_length + cut_off_length + 1, config.max_target_positions + passed_max_length + num_initial_tokens, config.max_target_positions ) logger.info( f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment." ) elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: max_length = min( - generation_config.max_length + cut_off_length + 1, config.max_target_positions + generation_config.max_length + num_initial_tokens, config.max_target_positions ) logger.info( f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment." From 947e5424a9b649e4e62be12f0a5247a02f72af04 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 13 Dec 2023 18:01:28 +0000 Subject: [PATCH 40/75] make cleaner --- src/transformers/generation/logits_process.py | 9 +++------ .../models/whisper/modeling_whisper.py | 15 ++++++++++----- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 1658e805cf10..7704829a8f45 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1620,7 +1620,6 @@ def __init__(self, no_speech_token: int, begin_index: int, begin_index_offset: i self.begin_index = begin_index self.begin_index_offset = begin_index_offset self._no_speech_prob = [0.0] - self._has_run = False # make sure we pass all logits self._pass_all_logits = True @@ -1631,16 +1630,14 @@ def no_speech_prob(self): def set_begin_index(self, begin_index): self.begin_index = begin_index - self._has_run = False @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - no_speech_index = (self.begin_index - self.begin_index_offset) - if (input_ids.shape[1] >= no_speech_index) and not self._has_run: - no_speech_scores = scores[:, no_speech_index - 1] + if input_ids.shape[1] == self.begin_index: + no_speech_index = self.begin_index - self.begin_index_offset + no_speech_scores = scores[:, no_speech_index] probs = no_speech_scores.float().softmax(dim=-1) self._no_speech_prob = probs[:, self.no_speech_token] - self._has_run = True scores = scores[:, -1, :] diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index ee9f068bd2e1..df990a3bbb98 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1966,6 +1966,7 @@ def generate( self._set_num_frames(return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs) # 4. Retrieve logits processors + # => TODO(Patrick): The way `num_start_tokens` is retrieved here is too brittle. Need a better approach num_start_tokens = len(generation_config.forced_decoder_ids) if generation_config.forced_decoder_ids is not None else 1 logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, num_start_tokens=num_start_tokens, is_shortform=is_shortform) @@ -1989,7 +1990,6 @@ def generate( return outputs - # 6. Else we're in longform mode which is more complex. # We need to chunk the audio input depending on when the model generates timestamp tokens @@ -2038,8 +2038,10 @@ def generate( proc.set_begin_index(decoder_input_ids.shape[-1]) print("hf in tokens", decoder_input_ids[0].tolist()) + # 6.8 Run generate with fallback seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(segment_input=segment_input, decoder_input_ids=decoder_input_ids, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, seek=seek, num_segment_frames=num_segment_frames, max_frames=max_frames, temperatures=temperatures, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold, do_condition_on_prev_tokens=do_condition_on_prev_tokens, condition_on_prev_tokens=condition_on_prev_tokens, kwargs=kwargs) + # 6.9 In every generated sequence, split by timestamp tokens and extract segments for i, seek_sequence in enumerate(seek_sequences): prev_i = batch_idx_map[i] @@ -2129,6 +2131,8 @@ def split_by_batch_index(values, key, batch_idx): # make sure we cut a predicted EOS token if we are not finished with the generation yet prev_i = batch_idx_map[fallback_index_map[i]] is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] + + # remove eos token id if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: seek_sequence = seek_sequence[:-1] @@ -2140,8 +2144,6 @@ def split_by_batch_index(values, key, batch_idx): seek_sequence = seek_sequence[:-num_paddings] if compression_ratio_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # compression_ratio = self._retrieve_compression_ratio(sequence_tokens[i], self.config.vocab_size) compression_ratio = self._retrieve_compression_ratio(seek_sequence, self.config.vocab_size) if compression_ratio > compression_ratio_threshold: @@ -2166,6 +2168,8 @@ def split_by_batch_index(values, key, batch_idx): # TODO(PVP) only works for batch size = 1 currently # Need to do before all other logit processors no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") + print("WATCH") + print(no_speech_prob) if no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: print("Skip because of VAD") needs_fallback[i] = False @@ -2182,11 +2186,13 @@ def split_by_batch_index(values, key, batch_idx): fallback_index_map = new_fallback_index_map + # if no sequence needs to be run with temperature fallback, we're finished if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: seek_sequences = seek_sequence_list seek_outputs = seek_outputs_list break + # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors decoder_input_ids = torch.stack(new_decoder_input_ids) segment_input = torch.stack(new_segment_input) @@ -2474,8 +2480,7 @@ def _retrieve_logit_processors(generation_config, logits_processor, no_speech_th generation_config.begin_suppress_tokens = None if no_speech_threshold is not None and not is_shortform: - begin_index_offset = num_start_tokens - 1 - no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, begin_index_offset=begin_index_offset) + no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, begin_index_offset=num_start_tokens) logits_processor = ( [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor ) From c0d03afe213ed77400eecc26ff5592901d50b754 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 13 Dec 2023 18:33:55 +0000 Subject: [PATCH 41/75] correct tests --- tests/models/whisper/test_modeling_whisper.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 3eaf60287983..6a3a033e21ab 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -138,7 +138,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to self.count += 1 if torch.isinf(scores).all(): - import ipdb; ipdb.set_trace() + import ipdb + + ipdb.set_trace() raise ValueError("Dummy logit processor is incorrectly set up. Scores should not be all inf.") return scores @@ -1369,7 +1371,7 @@ def _check_longform_generate_single_batch(self, condition_on_prev_tokens): gen_kwargs = { "logits_processor": logits_processor, "return_segments": True, - "condition_on_prev_tokens": condition_on_prev_tokens + "condition_on_prev_tokens": condition_on_prev_tokens, } if condition_on_prev_tokens: @@ -1481,6 +1483,7 @@ def test_longform_generate_multi_batch(self): def test_longform_generate_multi_batch_cond_prev(self): self._check_longform_generate_multi_batch(condition_on_prev_tokens=True) + @require_torch @require_torchaudio class WhisperModelIntegrationTests(unittest.TestCase): @@ -2128,7 +2131,6 @@ def test_whisper_longform_single_batch_prev_cond(self): assert decoded == EXPECTED_TEXT - @slow def test_whisper_longform_multi_batch(self): # fmt: off @@ -2328,9 +2330,12 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): torch.manual_seed(0) for i in range(num_samples): - import ipdb; ipdb.set_trace() + import ipdb + + ipdb.set_trace() assert decoded_all[i] == EXPECTED_TEXT[i] + def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): if head_mask is None: head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) From abb3d56e5a4615892c0e0b7dc73567e028005c04 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 13 Dec 2023 18:34:04 +0000 Subject: [PATCH 42/75] correct src --- src/transformers/generation/logits_process.py | 13 +- src/transformers/generation/utils.py | 2 +- .../models/whisper/generation_whisper.py | 1022 +++++++++++++++++ .../models/whisper/modeling_whisper.py | 1002 +--------------- 4 files changed, 1039 insertions(+), 1000 deletions(-) create mode 100644 src/transformers/models/whisper/generation_whisper.py diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 4357b88d1fca..a65db4f1c158 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -54,6 +54,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to def pass_all_logits(self): return getattr(self, "_pass_all_logits", False) + class LogitsWarper: """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" @@ -93,7 +94,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa """ if not any(processor.pass_all_logits for processor in self) and len(scores.shape) > 2: scores = scores[:, -1, :] - + for processor in self: function_args = inspect.signature(processor.__call__).parameters if len(function_args) > 2: @@ -1824,7 +1825,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): """ def __init__( - self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None + self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None ): # support for the kwargs self.no_timestamps_token_id = generate_config.no_timestamps_token_id self.timestamp_begin = generate_config.no_timestamps_token_id + 1 @@ -1837,9 +1838,11 @@ def __init__( else getattr(generate_config, "_detect_timestamp_from_logprob", True) ) - num_forced_ids = len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0 + num_forced_ids = ( + len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0 + ) self.begin_index = begin_index or (num_forced_ids + 1) - + self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50 # self.max_initial_timestamp_index = 50 @@ -1929,7 +1932,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to no_speech_scores = scores[:, no_speech_index] probs = no_speech_scores.float().softmax(dim=-1) self._no_speech_prob = probs[:, self.no_speech_token] - + scores = scores[:, -1, :] return scores diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b56555c019a1..ae9e022c7ad8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1129,7 +1129,7 @@ def _get_logits_processor( # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: processors.append(LogitNormalization()) - + return processors def _get_stopping_criteria( diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py new file mode 100644 index 000000000000..af524312ea09 --- /dev/null +++ b/src/transformers/models/whisper/generation_whisper.py @@ -0,0 +1,1022 @@ +# coding=utf-8 +# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import math +import warnings +import zlib +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +from ...generation.logits_process import ( + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, + WhisperNoSpeechDetection, + WhisperTimeStampLogitsProcessor, +) +from ...modeling_outputs import BaseModelOutput +from ...utils import logging +from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE + + +logger = logging.get_logger(__name__) + + +class WhisperGenerationMixin: + + def generate( + self, + input_features: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_timestamps=None, + task=None, + language=None, + is_multilingual=None, + condition_on_prev_tokens: Optional[bool] = None, + no_speech_threshold: Optional[float] = None, + temperature: Union[float, Tuple[float, ...]] = 0.0, + compression_ratio_threshold: Optional[float] = None, + logprob_threshold: Optional[float] = None, + prompt_ids: Optional[torch.Tensor] = None, + num_segment_frames: Optional[int] = None, + return_token_timestamps: Optional[bool] = None, + return_segments: bool = False, + attention_mask: Optional[torch.Tensor] = None, + time_precision: int = 0.02, + return_dict_in_generate: Optional[bool] = None, + **kwargs, + ): + """ + Transcribes or translates passed mel input features to a sequence of token ids. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + return_timestamps (`bool`, *optional*): + Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. + task (`str`, *optional*): + Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` + will be updated accordingly. + language (`str`, *optional*): + Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can + find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary. + is_multilingual (`bool`, *optional*): + Whether or not the model is multilingual. + prompt_ids (`torch.Tensor`, *optional*): + Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is + provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for + transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words + correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. + return_token_timestamps (`bool`, *optional*): + Whether to return token-level timestamps with the text. This can be used with or without the + `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into + words. + return_segments (`bool`, *optional*, defaults to `False`): + Whether to additionally return a list of all segments. Note that this option can only be enabled + when doing long-form transcription. + attention_mask (`torch.Tensor`, *optional*): + `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1. + time_precision (`int`, *optional*, defaults to 0.02): + The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts + for 20 ms. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens. + Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when + `return_segments` is set True. In this case the generation outputs of each segment is added to each + segment. + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`. + + If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned. + + else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + + else only the generated output sequence ids are returned. + + Example: + + - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset, Audio + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + >>> model.cuda() + + >>> # load audios > 30 seconds + >>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"] + >>> # resample to 16kHz + >>> ds = ds.cast_column("audio", Audio(sampling_rate=16000)) + >>> # take first 8 audios and retrieve array + >>> audio = ds[:8]["audio"] + >>> audio = [x["array"] for x in audio] + + >>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio + >>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000) + >>> inputs = inputs.to("cuda", torch.float32) + + >>> # transcribe audio to ids + >>> generated_ids = model.generate(**inputs) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True) + >>> transcription[0] + ' Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!' + ``` + + - *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate. + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ``` + + """ + # 0. deprecate old inputs + if "inputs" in kwargs: + input_features = kwargs.pop("inputs") + warnings.warn( + "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.", + FutureWarning, + ) + # 1. copy generation config + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + + # 2. set global generate variables + input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] + num_segment_frames = input_stride * self.config.max_source_positions + total_input_frames = self._retrieve_total_input_frames(input_features=input_features, input_stride=input_stride, kwargs=kwargs) + is_shortform = total_input_frames <= num_segment_frames + + # 3. Make sure generation config is correctly set + # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not + self._set_return_outputs(return_dict_in_generate=return_dict_in_generate, return_token_timestamps=return_token_timestamps, is_shortform=is_shortform, logprob_threshold=logprob_threshold, generation_config=generation_config) + self._set_return_timestamps(return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config) + self._set_language_and_task(language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config) + # pass self.config for backward compatibility + self._set_forced_decoder_ids(task=task, language=language, prompt_ids=prompt_ids, generation_config=generation_config, config=self.config, kwargs=kwargs) + self._set_num_frames(return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs) + + # 4. Retrieve logits processors + # => TODO(Patrick): The way `num_start_tokens` is retrieved here is too brittle. Need a better approach + num_start_tokens = len(generation_config.forced_decoder_ids) if generation_config.forced_decoder_ids is not None else 1 + logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, num_start_tokens=num_start_tokens, is_shortform=is_shortform) + + # 5. If we're in shortform mode, simple generate the whole input at once and return the output + if is_shortform: + outputs = super().generate( + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + + if generation_config.return_token_timestamps and hasattr(generation_config, "alignment_heads"): + outputs["token_timestamps"] = self._extract_token_timestamps( + outputs, generation_config.alignment_heads, num_frames=generation_config.num_frames + ) + + return outputs + + # 6. Else we're in longform mode which is more complex. + # We need to chunk the audio input depending on when the model generates timestamp tokens + + # 6.1 Set and retrieve global longform generation variables + self._set_condition_on_prev_tokens(condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config) + + timestamp_begin = generation_config.no_timestamps_token_id + 1 + temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature + temperature = temperatures[0] + batch_size = input_features.shape[0] + + max_frames, seek = self._retrieve_max_frames_and_seek(batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames) + init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config) + + # 6.2 Preppare running variables, list for generation + cur_bsz = batch_size + current_segments = [[] for _ in range(batch_size)] + batch_idx_map = list(range(batch_size)) + do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)] + + # 6.2 Transcribe audio until we reach the end of all input audios + while (seek < max_frames).any(): + # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop + # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order + # to know which original audio is being decoded + # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk + input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(input_features=input_features, seek=seek, max_frames=max_frames, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map) + time_offset = seek * time_precision / input_stride + seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) + + # 6.4 cut out next 30s segment from input features + segment_input = self._get_input_segment(input_features=input_features, seek=seek, seek_num_frames=seek_num_frames, num_segment_frames=num_segment_frames, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map) + + # 6.5 prepare decoder input ids + # TODO(Patrick) - clean up prev_start_of_text + suppress_tokens = self._get_attr_from_logit_processors(logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens") + prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None + decoder_input_ids, kwargs = self._prepare_decoder_input_ids(cur_bsz=cur_bsz, init_tokens=init_tokens, current_segments=current_segments, batch_idx_map=batch_idx_map, do_condition_on_prev_tokens=do_condition_on_prev_tokens, generation_config=generation_config, config=self.config, device=segment_input.device, kwargs=kwargs, prev_start_of_text=prev_start_of_text) + + # 6.6 set max new tokens or max length + kwargs = self._set_max_new_tokens_and_length(config=self.config, decoder_input_ids=decoder_input_ids, generation_config=generation_config, kwargs=kwargs) + + # 6.7 Set current `begin_index` for all logit processors + for proc in logits_processor: + if hasattr(proc, "set_begin_index"): + proc.set_begin_index(decoder_input_ids.shape[-1]) + + print("hf in tokens", decoder_input_ids[0].tolist()) + # 6.8 Run generate with fallback + seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(segment_input=segment_input, decoder_input_ids=decoder_input_ids, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, seek=seek, num_segment_frames=num_segment_frames, max_frames=max_frames, temperatures=temperatures, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold, do_condition_on_prev_tokens=do_condition_on_prev_tokens, condition_on_prev_tokens=condition_on_prev_tokens, kwargs=kwargs) + + # 6.9 In every generated sequence, split by timestamp tokens and extract segments + for i, seek_sequence in enumerate(seek_sequences): + prev_i = batch_idx_map[i] + + if should_skip[i]: + seek[prev_i] += seek_num_frames[prev_i] + print("Skipped!") + continue + + # TODO(Patrick: delete cut type) + segments, segment_offset, cut_type = self._retrieve_segment( + seek_sequence=seek_sequence, + seek_outputs=seek_outputs, + time_offset=time_offset, + timestamp_begin=timestamp_begin, + seek_num_frames=seek_num_frames, + time_precision=time_precision, + input_stride=input_stride, + prev_idx=prev_i, + idx=i, + ) + + current_segments[prev_i] += segments + seek[prev_i] += segment_offset + + # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted + # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output + sequences = self._pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right") + + # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. + if return_segments: + return {"sequences": sequences, "segments": current_segments} + + return sequences + + def generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batch_idx_map, seek, num_segment_frames, max_frames, temperatures, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, compression_ratio_threshold, logprob_threshold, no_speech_threshold, do_condition_on_prev_tokens, condition_on_prev_tokens, kwargs): + # 6.6 Batch generate current chunk + seek_sequence_list = [None for _ in range(cur_bsz)] + seek_outputs_list = [None for _ in range(cur_bsz)] + needs_fallback = [False for _ in range(cur_bsz)] + should_skip = [False for _ in range(cur_bsz)] + fallback_index_map = list(range(cur_bsz)) + + for fallback_idx, temperature in enumerate(temperatures): + generation_config.do_sample = temperature > 0.0 + generation_config.temperature = temperature + generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1 + + seek_outputs = super().generate( + segment_input, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=decoder_input_ids, + **kwargs, + ) + + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + num_frames = getattr(generation_config, "num_frames", None) + seek_outputs["token_timestamps"] = self._extract_token_timestamps( + seek_outputs, generation_config.alignment_heads, num_frames=num_frames + ) + + if generation_config.return_dict_in_generate: + def split_by_batch_index(values, key, batch_idx): + if key == "scores": + return [v[batch_idx].cpu() for v in values] + if key == "past_key_values": + # we don't save `past_key_values` as this is too costly + return None + return values[batch_idx].cpu() + + sequence_tokens = seek_outputs["sequences"] + seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] + else: + sequence_tokens = seek_outputs + + # remove all previously passed decoder input ids + seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] + + new_fallback_index_map = [] + new_segment_input = [] + new_decoder_input_ids = [] + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length + for i, seek_sequence in enumerate(seek_sequences): + # make sure we cut a predicted EOS token if we are not finished with the generation yet + prev_i = batch_idx_map[fallback_index_map[i]] + is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] + + # remove eos token id + if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: + seek_sequence = seek_sequence[:-1] + + print("hf out tokens", seek_sequence.tolist()) + + # remove all padding tokens + if seek_sequence[-1] == generation_config.pad_token_id: + num_paddings = (seek_sequence == generation_config.pad_token_id).sum() + seek_sequence = seek_sequence[:-num_paddings] + + if compression_ratio_threshold is not None: + compression_ratio = self._retrieve_compression_ratio(seek_sequence, self.config.vocab_size) + + if compression_ratio > compression_ratio_threshold: + print("fallback compression") + print("current temp", temperature) + needs_fallback[i] = True + + if logprob_threshold is not None: + if "sequences_scores" in seek_outputs[0]: + logprobs = [s["sequences_scores"] for s in seek_outputs][i] + else: + scores = seek_outputs[i]["scores"] + logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) + + # TODO(PVP) only works for batch size = 1 currently + if logprobs < logprob_threshold: + print("fallback logprobs", logprobs) + print("current temp", temperature) + needs_fallback[i] = True + + if no_speech_threshold is not None: + # TODO(PVP) only works for batch size = 1 currently + # Need to do before all other logit processors + no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") + print("WATCH") + print(no_speech_prob) + if no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: + print("Skip because of VAD") + needs_fallback[i] = False + should_skip[i] = True + + seek_sequence_list[fallback_index_map[i]] = seek_sequence + seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] + do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature < 0.5 + + if needs_fallback[i]: + new_fallback_index_map.append(fallback_index_map[i]) + new_segment_input.append(segment_input[i]) + new_decoder_input_ids.append(decoder_input_ids[i]) + + fallback_index_map = new_fallback_index_map + + # if no sequence needs to be run with temperature fallback, we're finished + if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: + seek_sequences = seek_sequence_list + seek_outputs = seek_outputs_list + break + + # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors + decoder_input_ids = torch.stack(new_decoder_input_ids) + segment_input = torch.stack(new_segment_input) + + return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens + + + @staticmethod + def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): + logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) + if logit_processor: + return getattr(logit_processor, attribute_name, None) + return None + + @staticmethod + def _retrieve_total_input_frames(input_features, input_stride, kwargs): + if input_features is not None: + return input_features.shape[-1] + + if "encoder_outputs" in kwargs: + encoder_outputs_shape = ( + kwargs["encoder_outputs"][0].shape + if isinstance(kwargs["encoder_outputs"], BaseModelOutput) + else kwargs["encoder_outputs"].shape + ) + return encoder_outputs_shape[1] * input_stride + + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") + + @staticmethod + def _set_return_outputs(return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config): + if return_dict_in_generate is None: + return_dict_in_generate = generation_config.return_dict_in_generate + + generation_config.return_token_timestamps = return_token_timestamps + if return_token_timestamps: + return_dict_in_generate = True + generation_config.output_attentions = True + + if not is_shortform and logprob_threshold is not None: + return_dict_in_generate = True + generation_config.output_scores = True + + generation_config.return_dict_in_generate = return_dict_in_generate + + @staticmethod + def _set_return_timestamps(return_timestamps, is_shortform, generation_config): + if return_timestamps is True: + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError( + "You are trying to return timestamps, but the generation config is not properly set. " + "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " + "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + ) + generation_config.return_timestamps = True + elif not is_shortform: + if return_timestamps is False: + raise ValueError( + "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " + "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." + ) + + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError( + "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " + "requires the generation config to have `no_timestamps_token_id` correctly. " + "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " + "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + "or make sure to pass no more than 3000 mel input features." + ) + + logger.info("Setting `return_timestamps=True` for long-form generation.") + generation_config.return_timestamps = True + else: + generation_config.return_timestamps = False + + @staticmethod + def _set_language_and_task(language, task, is_multilingual, generation_config): + if is_multilingual is not None: + if not hasattr(generation_config, "is_multilingual"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `is_multilingual` argument " + "to `generate`. Please update the generation config as per the instructions " + "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + generation_config.is_multilingual = is_multilingual + + if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual: + if task is not None or language is not None: + raise ValueError( + "Cannot specify `task` or `language` for an English-only model. If the model is intended to be " + "multilingual, pass `is_multilingual=True` to generate, or update the generation config." + ) + + if language is not None: + if not hasattr(generation_config, "lang_to_id"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `language` argument " + "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, " + "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + language = language.lower() + generation_config.language = language + + if task is not None: + if not hasattr(generation_config, "task_to_id"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `task` argument " + "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, " + "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + generation_config.task = task + + @staticmethod + def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, config, kwargs): + forced_decoder_ids = None + # Legacy code for backward compatibility + if hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None: + forced_decoder_ids = config.forced_decoder_ids + elif ( + hasattr(generation_config, "forced_decoder_ids") + and generation_config.forced_decoder_ids is not None + ): + forced_decoder_ids = generation_config.forced_decoder_ids + else: + forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) + + if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): + forced_decoder_ids = [] + if hasattr(generation_config, "language"): + if generation_config.language in generation_config.lang_to_id.keys(): + language_token = generation_config.language + elif generation_config.language in TO_LANGUAGE_CODE.keys(): + language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" + elif generation_config.language in TO_LANGUAGE_CODE.values(): + language_token = f"<|{generation_config.language}|>" + else: + is_language_code = len(generation_config.language) == 2 + raise ValueError( + f"Unsupported language: {generation_config.language}. Language should be one of:" + f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." + ) + if language_token not in generation_config.lang_to_id: + raise ValueError( + f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." + "(You should just add it to the generation config)" + ) + forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) + else: + forced_decoder_ids.append((1, None)) # automatically detect the language + + if hasattr(generation_config, "task"): + if generation_config.task in TASK_IDS: + forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) + else: + raise ValueError( + f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`" + ) + elif hasattr(generation_config, "task_to_id"): + forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe + if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: + idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 + forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) + + if forced_decoder_ids is not None: + generation_config.forced_decoder_ids = forced_decoder_ids + + if prompt_ids is not None: + if kwargs.get("decoder_start_token_id") is not None: + raise ValueError( + "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." + ) + prompt_ids = prompt_ids.tolist() + decoder_start_token_id, *text_prompt_ids = prompt_ids + # Slicing the text prompt ids in a manner consistent with the OpenAI implementation + # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) + text_prompt_ids = text_prompt_ids[-config.max_target_positions // 2 - 1 :] + # Set the decoder_start_token_id to <|startofprev|> + kwargs.update({"decoder_start_token_id": decoder_start_token_id}) + + # If the user passes `max_new_tokens`, increase its number to account for the prompt + if kwargs.get("max_new_tokens", None) is not None: + kwargs["max_new_tokens"] += len(text_prompt_ids) + if kwargs["max_new_tokens"] >= config.max_target_positions: + raise ValueError( + f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` " + f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced " + f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the " + f"`max_target_positions` of the Whisper model: {config.max_target_positions}. " + "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " + f"so that their combined length is less that {config.max_target_positions}." + ) + + # Reformat the forced_decoder_ids to incorporate the prompt + non_prompt_forced_decoder_ids = ( + kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids + ) + forced_decoder_ids = [ + *text_prompt_ids, + generation_config.decoder_start_token_id, + *[token for _, token in non_prompt_forced_decoder_ids], + ] + forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] + generation_config.forced_decoder_ids = forced_decoder_ids + + @staticmethod + def _set_num_frames(return_token_timestamps, generation_config, kwargs): + if return_token_timestamps: + if getattr(generation_config, "task", None) == "translate": + logger.warning("Token-level timestamps may not be reliable for task 'translate'.") + if not hasattr(generation_config, "alignment_heads"): + raise ValueError( + "Model generation config has no `alignment_heads`, token-level timestamps not available. " + "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." + ) + + generation_config.num_frames = kwargs.pop("num_frames", None) + + @staticmethod + def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config): + condition_on_prev_tokens = ( + condition_on_prev_tokens + if condition_on_prev_tokens is not None + else getattr(generation_config, "condition_on_prev_tokens", False) + ) + generation_config.condition_on_prev_tokens = condition_on_prev_tokens + + @staticmethod + def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames): + if batch_size > 1 and attention_mask is None: + raise ValueError( + "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " + ) + elif batch_size > 1: + max_frames = attention_mask.sum(-1).cpu().to(torch.long) + seek = torch.zeros((batch_size,), dtype=torch.long) + else: + max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames + seek = torch.zeros((1,), dtype=torch.long) + + return max_frames, seek + + @staticmethod + def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): + init_tokens = [generation_config.decoder_start_token_id] + forced_decoder_ids = generation_config.forced_decoder_ids + if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: + i = 1 + while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: + init_tokens += [forced_decoder_ids[0][1]] + forced_decoder_ids = forced_decoder_ids[1:] + i += 1 + + forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None + generation_config.forced_decoder_ids = forced_decoder_ids + + return init_tokens + + @staticmethod + def _retrieve_logit_processors(generation_config, logits_processor, no_speech_threshold, num_start_tokens, is_shortform): + begin_index = 1 + if generation_config.return_timestamps is True: + forced_decoder_ids = generation_config.forced_decoder_ids + last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None + if last_forced_decoder_ids == generation_config.no_timestamps_token_id: + # remove no_timestamp to be forcefully generated if we want to return timestamps + # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly + forced_decoder_ids = forced_decoder_ids[:-1] if len(forced_decoder_ids) > 1 else None + # Make sure that if list is empty we set it to None + generation_config.forced_decoder_ids = forced_decoder_ids + + begin_index = begin_index + len(forced_decoder_ids) if forced_decoder_ids is not None else begin_index + + timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) + logits_processor = ( + [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor + ) + + if generation_config.suppress_tokens is not None: + suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens) + logits_processor = ( + [suppress_tokens_processor] if logits_processor is None else [suppress_tokens_processor] + logits_processor + ) + generation_config.suppress_tokens = None + + if generation_config.begin_suppress_tokens is not None: + begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index=begin_index) + logits_processor = ( + [begin_suppress_processor] if logits_processor is None else [begin_suppress_processor] + logits_processor + ) + generation_config.begin_suppress_tokens = None + + if no_speech_threshold is not None and not is_shortform: + no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, begin_index_offset=num_start_tokens) + logits_processor = ( + [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor + ) + + return logits_processor + + @staticmethod + def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map): + prev_bsz = cur_bsz + new_batch_idx_map = [] + for i in range(prev_bsz): + prev_i = batch_idx_map[i] + if seek[prev_i] >= max_frames[prev_i]: + cut_index = i + (cur_bsz - prev_bsz) + cur_bsz -= 1 + input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0) + else: + # cut out index that goes away + new_batch_idx_map.append(prev_i) + + return input_features, cur_bsz, new_batch_idx_map + + @staticmethod + def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map): + segment_input = [] + for i in range(cur_bsz): + prev_i = batch_idx_map[i] + segment_input_slice = input_features[ + i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i] + ] + + if segment_input_slice.shape[-1] < num_segment_frames: + # pad to 3000 if necessary + segment_input_slice = F.pad( + segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1]) + ) + + segment_input.append(segment_input_slice) + + segment_input = torch.cat(segment_input, dim=0) + + return segment_input + + # TODO(Patrick) - remove prev_start_of_text + @staticmethod + def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx_map, do_condition_on_prev_tokens, generation_config, config, device, kwargs, prev_start_of_text): + cut_off_length = config.max_target_positions // 2 - 1 + + one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long) + decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) + + # if condition_on_prev_tokens and len(current_segments[0]) > 0 and temperature < 0.5: + if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: + # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 + active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map] + prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text + + bos_token_tensor = prev_start_of_text * one_tensor[0] + prev_tokens = WhisperGenerationMixin._pad_to_max_length( + active_segments, generation_config.pad_token_id, padding="left", bos_token_tensor=bos_token_tensor, cut_off_length=cut_off_length + ) + decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) + + # kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) + + + return decoder_input_ids, kwargs + + @staticmethod + def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs): + num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1) + + passed_max_length = kwargs.pop("max_length", None) + passed_max_new_tokens = kwargs.pop("max_new_tokens", None) + max_length_config = getattr(generation_config, "max_length", None) + max_new_tokens_config = getattr(generation_config, "max_new_tokens", None) + + max_new_tokens = None + max_length = None + + # Make sure we don't get larger than `max_length` + if passed_max_length is not None and passed_max_new_tokens is None: + max_length = min( + passed_max_length + num_initial_tokens, config.max_target_positions + ) + logger.info( + f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment." + ) + elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: + max_length = min( + generation_config.max_length + num_initial_tokens, config.max_target_positions + ) + logger.info( + f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment." + ) + elif ( + passed_max_new_tokens is not None + and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions + ): + max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] + elif ( + passed_max_new_tokens is None + and max_new_tokens_config is not None + and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions + ): + max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] + + kwargs["max_new_tokens"] = max_new_tokens + kwargs["max_length"] = max_length + + return kwargs + + @staticmethod + def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None): + max_total_length = 0 + sequences = [] + if padding not in ["right", "left"]: + raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}") + + for current_segment_list in current_segments: + if current_segment_list is not None: + sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) + + if cut_off_length is not None: + sequence = sequence[-cut_off_length:] + + if bos_token_tensor is not None: + sequence = torch.cat([bos_token_tensor, sequence]) + + sequences.append(sequence) + max_total_length = max(max_total_length, len(sequences[-1])) + else: + sequences.append(bos_token_tensor) + + for i in range(len(current_segments)): + pad_length = max_total_length - len(sequences[i]) + pad = (0, pad_length) if padding == "right" else (pad_length, 0) + sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id) + + sequences = torch.stack(sequences, dim=0) + return sequences + + @staticmethod + def _retrieve_compression_ratio(tokens, vocab_size): + length = int(math.log2(vocab_size) / 8) + 1 + token_bytes = b''.join([t.to_bytes(length, 'little') for t in tokens.tolist()]) + + # string = tok.decode(tokens, skip_special_tokens=True) + # string_bytes = string.encode("utf-8") + # string_compression_ratio = len(string_bytes) / len(zlib.compress(string_bytes)) + + compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes)) + + # print(f"HERE: string: {string}") + # print(f"HERE: string ratio: {string_compression_ratio}") + # print(f"HERE: token ratio: {compression_ratio}") + # print('HERE:' + 20 * '-') + + return compression_ratio + + @staticmethod + def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): + rescale_temperature = temperature if temperature > 0.0 else 1 + scores = torch.stack(scores).to(tokens.device) + + # TODO(Patrick) - only leave scores = scores[:tokens.shape[0]] part + if scores.shape[0] > tokens.shape[0]: + scores = scores[:tokens.shape[0]] + else: + tokens = tokens[-scores.shape[0]:] + + logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype) + + # retrieve logprob of selected tokens and sum + sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0])) + length = (tokens != eos_token_id).sum(-1) + + avg_logprobs = sum_logprobs / (length + 1) + return avg_logprobs + + @staticmethod + def _retrieve_segment( + seek_sequence, + seek_outputs, + time_offset, + timestamp_begin, + seek_num_frames, + time_precision, + input_stride, + prev_idx, + idx, + ): + # find the predicted "end of segment" predictions of Whisper + # "end of segment" predictions occur whenever Whisper predicts a timestamp token + timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin) + single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] + timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + timestamp_segment_indices.add_(1) + + # If whisper predicted a "end of segment" via a timestep token, let's go ever each + # "end of segment" prediction and slice the decoding into segments accordingly + if len(timestamp_segment_indices) > 0: + # if the output contains two consecutive timestamp tokens + slices = timestamp_segment_indices.tolist() + segments = [] + if single_timestamp_ending: + slices.append(len(seek_sequence)) + + last_slice = 0 + # Add each segment to list of all segments + for current_slice in slices: + sliced_tokens = seek_sequence[last_slice : current_slice] + start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin + end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin + segments.append( + { + "start": time_offset[prev_idx] + start_timestamp_pos * time_precision, + "end": time_offset[prev_idx] + end_timestamp_pos * time_precision, + "tokens": sliced_tokens, + "result": seek_outputs[idx], + } + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last timestamp. + segment_offset = seek_num_frames[prev_idx] + cut_type = "single ending" + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + # here we throw away all predictions after the last predicted "end of segment" + # since we are cutting right in the middle of an audio + last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin + segment_offset = last_timestamp_pos * input_stride + cut_type = "cut" + else: + # If whisper does not predict any "end of segment" token, then + # the whole decoding is considered a segment and we add it to the list of segments + timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()] + last_timestamp_pos = seek_num_frames[prev_idx] + if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin: + # no consecutive timestamps but it has a timestamp; use the last one. + last_timestamp_pos = timestamps[-1].item() - timestamp_begin + + segments = [ + { + "start": time_offset[prev_idx], + "end": time_offset[prev_idx] + last_timestamp_pos * time_precision, + "tokens": seek_sequence, + "result": seek_outputs[idx], + } + ] + segment_offset = seek_num_frames[prev_idx] + cut_type = "all" + + return segments, segment_offset, cut_type + diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index bcec47151ff3..7e3a75cbb7c5 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -15,23 +15,17 @@ """ PyTorch Whisper model.""" RUN_NEW_WAY = True -import copy import math -import warnings from typing import Optional, Tuple, Union import numpy as np -import copy import torch -import zlib import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.logits_process import WhisperTimeStampLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, WhisperNoSpeechDetection, SuppressTokensLogitsProcessor -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutput, @@ -51,11 +45,9 @@ replace_return_docstrings, ) from .configuration_whisper import WhisperConfig -from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE +from .generation_whisper import WhisperGenerationMixin -from transformers import AutoTokenizer - # tok = AutoTokenizer.from_pretrained("openai/whisper-tiny") @@ -1470,9 +1462,13 @@ def forward( # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids) + positions = self.embed_positions( + input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids + ) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids) + positions = self.embed_positions( + inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1739,7 +1735,7 @@ def forward( "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.", WHISPER_START_DOCSTRING, ) -class WhisperForConditionalGeneration(WhisperPreTrainedModel): +class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel): base_model_prefix = "model" _tied_weights_keys = ["proj_out.weight"] @@ -1873,988 +1869,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def generate( - self, - input_features: Optional[torch.Tensor] = None, - generation_config=None, - logits_processor=None, - stopping_criteria=None, - prefix_allowed_tokens_fn=None, - synced_gpus=False, - return_timestamps=None, - task=None, - language=None, - is_multilingual=None, - condition_on_prev_tokens: Optional[bool] = None, - no_speech_threshold: Optional[float] = None, - temperature: Union[float, Tuple[float, ...]] = 0.0, - compression_ratio_threshold: Optional[float] = None, - logprob_threshold: Optional[float] = None, - prompt_ids: Optional[torch.Tensor] = None, - num_segment_frames: Optional[int] = None, - return_token_timestamps: Optional[bool] = None, - return_segments: bool = False, - attention_mask: Optional[torch.Tensor] = None, - time_precision: int = 0.02, - return_dict_in_generate: Optional[bool] = None, - **kwargs, - ): - """ - Transcribes or translates passed mel input features to a sequence of token ids. - - - - Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the - model's default generation configuration. You can override any `generation_config` by passing the corresponding - parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - - For an overview of generation strategies and code examples, check out the [following - guide](./generation_strategies). - - - - Parameters: - inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the - method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of - `input_ids`, `input_values`, `input_features`, or `pixel_values`. - generation_config (`~generation.GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and - generation config. If a logit processor is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - generation config. If a stopping criteria is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): - If provided, this function constraints the beam search to allowed tokens only at each step. If not - provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and - `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned - on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful - for constrained generation conditioned on the prefix, as described in [Autoregressive Entity - Retrieval](https://arxiv.org/abs/2010.00904). - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - return_timestamps (`bool`, *optional*): - Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. - task (`str`, *optional*): - Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` - will be updated accordingly. - language (`str`, *optional*): - Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can - find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary. - is_multilingual (`bool`, *optional*): - Whether or not the model is multilingual. - prompt_ids (`torch.Tensor`, *optional*): - Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is - provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for - transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words - correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. - return_token_timestamps (`bool`, *optional*): - Whether to return token-level timestamps with the text. This can be used with or without the - `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into - words. - return_segments (`bool`, *optional*, defaults to `False`): - Whether to additionally return a list of all segments. Note that this option can only be enabled - when doing long-form transcription. - attention_mask (`torch.Tensor`, *optional*): - `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1. - time_precision (`int`, *optional*, defaults to 0.02): - The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts - for 20 ms. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens. - Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when - `return_segments` is set True. In this case the generation outputs of each segment is added to each - segment. - kwargs (`Dict[str, Any]`, *optional*): - Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder - specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`. - - If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned. - - else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are: - - - [`~generation.GreedySearchEncoderDecoderOutput`], - - [`~generation.SampleEncoderDecoderOutput`], - - [`~generation.BeamSearchEncoderDecoderOutput`], - - [`~generation.BeamSampleEncoderDecoderOutput`] - - else only the generated output sequence ids are returned. - - Example: - - - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. - - ```python - >>> import torch - >>> from transformers import AutoProcessor, WhisperForConditionalGeneration - >>> from datasets import load_dataset, Audio - - >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - >>> model.cuda() - - >>> # load audios > 30 seconds - >>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"] - >>> # resample to 16kHz - >>> ds = ds.cast_column("audio", Audio(sampling_rate=16000)) - >>> # take first 8 audios and retrieve array - >>> audio = ds[:8]["audio"] - >>> audio = [x["array"] for x in audio] - - >>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio - >>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000) - >>> inputs = inputs.to("cuda", torch.float32) - - >>> # transcribe audio to ids - >>> generated_ids = model.generate(**inputs) - - >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True) - >>> transcription[0] - ' Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!' - ``` - - - *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate. - - ```python - >>> import torch - >>> from transformers import AutoProcessor, WhisperForConditionalGeneration - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - - >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") - >>> input_features = inputs.input_features - - >>> generated_ids = model.generate(inputs=input_features) - - >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - >>> transcription - ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' - ``` - - """ - # 0. deprecate old inputs - if "inputs" in kwargs: - input_features = kwargs.pop("inputs") - warnings.warn( - "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.", - FutureWarning, - ) - # 1. copy generation config - if generation_config is None: - generation_config = copy.deepcopy(self.generation_config) - - # 2. set global generate variables - input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] - num_segment_frames = input_stride * self.config.max_source_positions - total_input_frames = self._retrieve_total_input_frames(input_features=input_features, kwargs=kwargs) - is_shortform = total_input_frames <= num_segment_frames - - # 3. Make sure generation config is correctly set - # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not - self._set_return_outputs(return_dict_in_generate=return_dict_in_generate, return_token_timestamps=return_token_timestamps, is_shortform=is_shortform, logprob_threshold=logprob_threshold, generation_config=generation_config) - self._set_return_timestamps(return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config) - self._set_language_and_task(language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config) - # pass self.config for backward compatibility - self._set_forced_decoder_ids(task=task, language=language, prompt_ids=prompt_ids, generation_config=generation_config, config=self.config, kwargs=kwargs) - self._set_num_frames(return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs) - - # 4. Retrieve logits processors - # => TODO(Patrick): The way `num_start_tokens` is retrieved here is too brittle. Need a better approach - num_start_tokens = len(generation_config.forced_decoder_ids) if generation_config.forced_decoder_ids is not None else 1 - logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, num_start_tokens=num_start_tokens, is_shortform=is_shortform) - - # 5. If we're in shortform mode, simple generate the whole input at once and return the output - if is_shortform: - outputs = super().generate( - input_features, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - ) - - if generation_config.return_token_timestamps and hasattr(generation_config, "alignment_heads"): - outputs["token_timestamps"] = self._extract_token_timestamps( - outputs, generation_config.alignment_heads, num_frames=generation_config.num_frames - ) - - return outputs - - # 6. Else we're in longform mode which is more complex. - # We need to chunk the audio input depending on when the model generates timestamp tokens - - # 6.1 Set and retrieve global longform generation variables - self._set_condition_on_prev_tokens(condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config) - - timestamp_begin = generation_config.no_timestamps_token_id + 1 - temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature - temperature = temperatures[0] - batch_size = input_features.shape[0] - - max_frames, seek = self._retrieve_max_frames_and_seek(batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames) - init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config) - - # 6.2 Preppare running variables, list for generation - cur_bsz = batch_size - current_segments = [[] for _ in range(batch_size)] - batch_idx_map = list(range(batch_size)) - do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)] - - # 6.2 Transcribe audio until we reach the end of all input audios - while (seek < max_frames).any(): - # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop - # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order - # to know which original audio is being decoded - # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk - input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(input_features=input_features, seek=seek, max_frames=max_frames, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map) - time_offset = seek * time_precision / input_stride - seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) - - # 6.4 cut out next 30s segment from input features - segment_input = self._get_input_segment(input_features=input_features, seek=seek, seek_num_frames=seek_num_frames, num_segment_frames=num_segment_frames, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map) - - # 6.5 prepare decoder input ids - # TODO(Patrick) - clean up prev_start_of_text - suppress_tokens = self._get_attr_from_logit_processors(logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens") - prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None - decoder_input_ids, kwargs = self._prepare_decoder_input_ids(cur_bsz=cur_bsz, init_tokens=init_tokens, current_segments=current_segments, batch_idx_map=batch_idx_map, do_condition_on_prev_tokens=do_condition_on_prev_tokens, generation_config=generation_config, config=self.config, device=segment_input.device, kwargs=kwargs, prev_start_of_text=prev_start_of_text) - - # 6.6 set max new tokens or max length - kwargs = self._set_max_new_tokens_and_length(config=self.config, decoder_input_ids=decoder_input_ids, generation_config=generation_config, kwargs=kwargs) - - # 6.7 Set current `begin_index` for all logit processors - for proc in logits_processor: - if hasattr(proc, "set_begin_index"): - proc.set_begin_index(decoder_input_ids.shape[-1]) - - print("hf in tokens", decoder_input_ids[0].tolist()) - # 6.8 Run generate with fallback - seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(segment_input=segment_input, decoder_input_ids=decoder_input_ids, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, seek=seek, num_segment_frames=num_segment_frames, max_frames=max_frames, temperatures=temperatures, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold, do_condition_on_prev_tokens=do_condition_on_prev_tokens, condition_on_prev_tokens=condition_on_prev_tokens, kwargs=kwargs) - - # 6.9 In every generated sequence, split by timestamp tokens and extract segments - for i, seek_sequence in enumerate(seek_sequences): - prev_i = batch_idx_map[i] - - if should_skip[i]: - seek[prev_i] += seek_num_frames[prev_i] - print("Skipped!") - continue - - # TODO(Patrick: delete cut type) - segments, segment_offset, cut_type = self._retrieve_segment( - seek_sequence=seek_sequence, - seek_outputs=seek_outputs, - time_offset=time_offset, - timestamp_begin=timestamp_begin, - seek_num_frames=seek_num_frames, - time_precision=time_precision, - input_stride=input_stride, - prev_idx=prev_i, - idx=i, - ) - - current_segments[prev_i] += segments - seek[prev_i] += segment_offset - - # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted - # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output - sequences = self._pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right") - - # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. - if return_segments: - return {"sequences": sequences, "segments": current_segments} - - return sequences - - def generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batch_idx_map, seek, num_segment_frames, max_frames, temperatures, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, compression_ratio_threshold, logprob_threshold, no_speech_threshold, do_condition_on_prev_tokens, condition_on_prev_tokens, kwargs): - # 6.6 Batch generate current chunk - seek_sequence_list = [None for _ in range(cur_bsz)] - seek_outputs_list = [None for _ in range(cur_bsz)] - needs_fallback = [False for _ in range(cur_bsz)] - should_skip = [False for _ in range(cur_bsz)] - fallback_index_map = list(range(cur_bsz)) - - for fallback_idx, temperature in enumerate(temperatures): - generation_config.do_sample = temperature > 0.0 - generation_config.temperature = temperature - generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1 - - seek_outputs = super().generate( - segment_input, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - decoder_input_ids=decoder_input_ids, - **kwargs, - ) - - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - seek_outputs["token_timestamps"] = self._extract_token_timestamps( - seek_outputs, generation_config.alignment_heads, num_frames=num_frames - ) - - if generation_config.return_dict_in_generate: - def split_by_batch_index(values, key, batch_idx): - if key == "scores": - return list(v[batch_idx].cpu() for v in values) - if key == "past_key_values": - # we don't save `past_key_values` as this is too costly - return None - return values[batch_idx].cpu() - - sequence_tokens = seek_outputs["sequences"] - seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] - else: - sequence_tokens = seek_outputs - - # remove all previously passed decoder input ids - seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] - - new_fallback_index_map = [] - new_segment_input = [] - new_decoder_input_ids = [] - # 6.7 Loop over each decoded audio individually as each decoding can be of a different length - for i, seek_sequence in enumerate(seek_sequences): - # make sure we cut a predicted EOS token if we are not finished with the generation yet - prev_i = batch_idx_map[fallback_index_map[i]] - is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] - - # remove eos token id - if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: - seek_sequence = seek_sequence[:-1] - - print("hf out tokens", seek_sequence.tolist()) - - # remove all padding tokens - if seek_sequence[-1] == generation_config.pad_token_id: - num_paddings = (seek_sequence == generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] - - if compression_ratio_threshold is not None: - compression_ratio = self._retrieve_compression_ratio(seek_sequence, self.config.vocab_size) - - if compression_ratio > compression_ratio_threshold: - print("fallback compression") - print("current temp", temperature) - needs_fallback[i] = True - - if logprob_threshold is not None: - if "sequences_scores" in seek_outputs[0]: - logprobs = [s["sequences_scores"] for s in seek_outputs][i] - else: - scores = seek_outputs[i]["scores"] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) - - # TODO(PVP) only works for batch size = 1 currently - if logprobs < logprob_threshold: - print("fallback logprobs", logprobs) - print("current temp", temperature) - needs_fallback[i] = True - - if no_speech_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # Need to do before all other logit processors - no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") - print("WATCH") - print(no_speech_prob) - if no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: - print("Skip because of VAD") - needs_fallback[i] = False - should_skip[i] = True - - seek_sequence_list[fallback_index_map[i]] = seek_sequence - seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] - do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature < 0.5 - - if needs_fallback[i]: - new_fallback_index_map.append(fallback_index_map[i]) - new_segment_input.append(segment_input[i]) - new_decoder_input_ids.append(decoder_input_ids[i]) - - fallback_index_map = new_fallback_index_map - - # if no sequence needs to be run with temperature fallback, we're finished - if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: - seek_sequences = seek_sequence_list - seek_outputs = seek_outputs_list - break - - # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors - decoder_input_ids = torch.stack(new_decoder_input_ids) - segment_input = torch.stack(new_segment_input) - - return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens - - - @staticmethod - def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): - logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) - if logit_processor: - return getattr(logit_processor, attribute_name, None) - return None - - @staticmethod - def _retrieve_total_input_frames(input_features, kwargs): - if input_features is not None: - return input_features.shape[-1] - - if "encoder_outputs" in kwargs: - encoder_outputs_shape = ( - kwargs["encoder_outputs"][0].shape - if isinstance(kwargs["encoder_outputs"], BaseModelOutput) - else kwargs["encoder_outputs"].shape - ) - return encoder_outputs_shape[1] * input_stride - - raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") - - @staticmethod - def _set_return_outputs(return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config): - if return_dict_in_generate is None: - return_dict_in_generate = generation_config.return_dict_in_generate - - generation_config.return_token_timestamps = return_token_timestamps - if return_token_timestamps: - return_dict_in_generate = True - generation_config.output_attentions = True - - if not is_shortform and logprob_threshold is not None: - return_dict_in_generate = True - generation_config.output_scores = True - - generation_config.return_dict_in_generate = return_dict_in_generate - - @staticmethod - def _set_return_timestamps(return_timestamps, is_shortform, generation_config): - if return_timestamps is True: - if not hasattr(generation_config, "no_timestamps_token_id"): - raise ValueError( - "You are trying to return timestamps, but the generation config is not properly set. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" - ) - generation_config.return_timestamps = True - elif not is_shortform: - if return_timestamps is False: - raise ValueError( - "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " - "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." - ) - - if not hasattr(generation_config, "no_timestamps_token_id"): - raise ValueError( - "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " - "requires the generation config to have `no_timestamps_token_id` correctly. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" - "or make sure to pass no more than 3000 mel input features." - ) - - logger.info("Setting `return_timestamps=True` for long-form generation.") - generation_config.return_timestamps = True - else: - generation_config.return_timestamps = False - - @staticmethod - def _set_language_and_task(language, task, is_multilingual, generation_config): - if is_multilingual is not None: - if not hasattr(generation_config, "is_multilingual"): - raise ValueError( - "The generation config is outdated and is thus not compatible with the `is_multilingual` argument " - "to `generate`. Please update the generation config as per the instructions " - "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" - ) - generation_config.is_multilingual = is_multilingual - - if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual: - if task is not None or language is not None: - raise ValueError( - "Cannot specify `task` or `language` for an English-only model. If the model is intended to be " - "multilingual, pass `is_multilingual=True` to generate, or update the generation config." - ) - - if language is not None: - if not hasattr(generation_config, "lang_to_id"): - raise ValueError( - "The generation config is outdated and is thus not compatible with the `language` argument " - "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, " - "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" - ) - language = language.lower() - generation_config.language = language - - if task is not None: - if not hasattr(generation_config, "task_to_id"): - raise ValueError( - "The generation config is outdated and is thus not compatible with the `task` argument " - "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, " - "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" - ) - generation_config.task = task - - @staticmethod - def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, config, kwargs): - forced_decoder_ids = None - # Legacy code for backward compatibility - if hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None: - forced_decoder_ids = config.forced_decoder_ids - elif ( - hasattr(generation_config, "forced_decoder_ids") - and generation_config.forced_decoder_ids is not None - ): - forced_decoder_ids = generation_config.forced_decoder_ids - else: - forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) - - if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): - forced_decoder_ids = [] - if hasattr(generation_config, "language"): - if generation_config.language in generation_config.lang_to_id.keys(): - language_token = generation_config.language - elif generation_config.language in TO_LANGUAGE_CODE.keys(): - language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" - elif generation_config.language in TO_LANGUAGE_CODE.values(): - language_token = f"<|{generation_config.language}|>" - else: - is_language_code = len(generation_config.language) == 2 - raise ValueError( - f"Unsupported language: {generation_config.language}. Language should be one of:" - f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." - ) - if language_token not in generation_config.lang_to_id: - raise ValueError( - f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." - "(You should just add it to the generation config)" - ) - forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) - else: - forced_decoder_ids.append((1, None)) # automatically detect the language - - if hasattr(generation_config, "task"): - if generation_config.task in TASK_IDS: - forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) - else: - raise ValueError( - f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`" - ) - elif hasattr(generation_config, "task_to_id"): - forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe - if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: - idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 - forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) - - if forced_decoder_ids is not None: - generation_config.forced_decoder_ids = forced_decoder_ids - - if prompt_ids is not None: - if kwargs.get("decoder_start_token_id") is not None: - raise ValueError( - "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." - ) - prompt_ids = prompt_ids.tolist() - decoder_start_token_id, *text_prompt_ids = prompt_ids - # Slicing the text prompt ids in a manner consistent with the OpenAI implementation - # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) - text_prompt_ids = text_prompt_ids[-config.max_target_positions // 2 - 1 :] - # Set the decoder_start_token_id to <|startofprev|> - kwargs.update({"decoder_start_token_id": decoder_start_token_id}) - - # If the user passes `max_new_tokens`, increase its number to account for the prompt - if kwargs.get("max_new_tokens", None) is not None: - kwargs["max_new_tokens"] += len(text_prompt_ids) - if kwargs["max_new_tokens"] >= config.max_target_positions: - raise ValueError( - f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` " - f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced " - f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the " - f"`max_target_positions` of the Whisper model: {config.max_target_positions}. " - "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " - f"so that their combined length is less that {config.max_target_positions}." - ) - - # Reformat the forced_decoder_ids to incorporate the prompt - non_prompt_forced_decoder_ids = ( - kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids - ) - forced_decoder_ids = [ - *text_prompt_ids, - generation_config.decoder_start_token_id, - *[token for _, token in non_prompt_forced_decoder_ids], - ] - forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] - generation_config.forced_decoder_ids = forced_decoder_ids - - @staticmethod - def _set_num_frames(return_token_timestamps, generation_config, kwargs): - if return_token_timestamps: - if getattr(generation_config, "task", None) == "translate": - logger.warning("Token-level timestamps may not be reliable for task 'translate'.") - if not hasattr(generation_config, "alignment_heads"): - raise ValueError( - "Model generation config has no `alignment_heads`, token-level timestamps not available. " - "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." - ) - - generation_config.num_frames = kwargs.pop("num_frames", None) - - @staticmethod - def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config): - condition_on_prev_tokens = ( - condition_on_prev_tokens - if condition_on_prev_tokens is not None - else getattr(generation_config, "condition_on_prev_tokens", False) - ) - generation_config.condition_on_prev_tokens = condition_on_prev_tokens - - @staticmethod - def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames): - if batch_size > 1 and attention_mask is None: - raise ValueError( - "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " - ) - elif batch_size > 1: - max_frames = attention_mask.sum(-1).cpu().to(torch.long) - seek = torch.zeros((batch_size,), dtype=torch.long) - else: - max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames - seek = torch.zeros((1,), dtype=torch.long) - - return max_frames, seek - - @staticmethod - def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): - init_tokens = [generation_config.decoder_start_token_id] - forced_decoder_ids = generation_config.forced_decoder_ids - if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: - i = 1 - while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: - init_tokens += [forced_decoder_ids[0][1]] - forced_decoder_ids = forced_decoder_ids[1:] - i += 1 - - forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None - generation_config.forced_decoder_ids = forced_decoder_ids - - return init_tokens - - @staticmethod - def _retrieve_logit_processors(generation_config, logits_processor, no_speech_threshold, num_start_tokens, is_shortform): - begin_index = 1 - if generation_config.return_timestamps is True: - forced_decoder_ids = generation_config.forced_decoder_ids - last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None - if last_forced_decoder_ids == generation_config.no_timestamps_token_id: - # remove no_timestamp to be forcefully generated if we want to return timestamps - # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly - forced_decoder_ids = forced_decoder_ids[:-1] if len(forced_decoder_ids) > 1 else None - # Make sure that if list is empty we set it to None - generation_config.forced_decoder_ids = forced_decoder_ids - - begin_index = begin_index + len(forced_decoder_ids) if forced_decoder_ids is not None else begin_index - - timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) - logits_processor = ( - [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor - ) - - if generation_config.suppress_tokens is not None: - suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens) - logits_processor = ( - [suppress_tokens_processor] if logits_processor is None else [suppress_tokens_processor] + logits_processor - ) - generation_config.suppress_tokens = None - - if generation_config.begin_suppress_tokens is not None: - begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index=begin_index) - logits_processor = ( - [begin_suppress_processor] if logits_processor is None else [begin_suppress_processor] + logits_processor - ) - generation_config.begin_suppress_tokens = None - - if no_speech_threshold is not None and not is_shortform: - no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, begin_index_offset=num_start_tokens) - logits_processor = ( - [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor - ) - - return logits_processor - - @staticmethod - def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map): - prev_bsz = cur_bsz - new_batch_idx_map = [] - for i in range(prev_bsz): - prev_i = batch_idx_map[i] - if seek[prev_i] >= max_frames[prev_i]: - cut_index = i + (cur_bsz - prev_bsz) - cur_bsz -= 1 - input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0) - else: - # cut out index that goes away - new_batch_idx_map.append(prev_i) - - return input_features, cur_bsz, new_batch_idx_map - - @staticmethod - def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map): - segment_input = [] - for i in range(cur_bsz): - prev_i = batch_idx_map[i] - segment_input_slice = input_features[ - i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i] - ] - - if segment_input_slice.shape[-1] < num_segment_frames: - # pad to 3000 if necessary - segment_input_slice = F.pad( - segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1]) - ) - - segment_input.append(segment_input_slice) - - segment_input = torch.cat(segment_input, dim=0) - - return segment_input - - # TODO(Patrick) - remove prev_start_of_text - @staticmethod - def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx_map, do_condition_on_prev_tokens, generation_config, config, device, kwargs, prev_start_of_text): - cut_off_length = config.max_target_positions // 2 - 1 - - one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long) - decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) - - # if condition_on_prev_tokens and len(current_segments[0]) > 0 and temperature < 0.5: - if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: - # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 - active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map] - prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text - - bos_token_tensor = prev_start_of_text * one_tensor[0] - prev_tokens = WhisperForConditionalGeneration._pad_to_max_length( - active_segments, generation_config.pad_token_id, padding="left", bos_token_tensor=bos_token_tensor, cut_off_length=cut_off_length - ) - decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) - - # kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) - - - return decoder_input_ids, kwargs - - @staticmethod - def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs): - num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1) - - passed_max_length = kwargs.pop("max_length", None) - passed_max_new_tokens = kwargs.pop("max_new_tokens", None) - max_length_config = getattr(generation_config, "max_length", None) - max_new_tokens_config = getattr(generation_config, "max_new_tokens", None) - - max_new_tokens = None - max_length = None - - # Make sure we don't get larger than `max_length` - if passed_max_length is not None and passed_max_new_tokens is None: - max_length = min( - passed_max_length + num_initial_tokens, config.max_target_positions - ) - logger.info( - f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment." - ) - elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: - max_length = min( - generation_config.max_length + num_initial_tokens, config.max_target_positions - ) - logger.info( - f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment." - ) - elif ( - passed_max_new_tokens is not None - and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions - ): - max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] - elif ( - passed_max_new_tokens is None - and max_new_tokens_config is not None - and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions - ): - max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] - - kwargs["max_new_tokens"] = max_new_tokens - kwargs["max_length"] = max_length - - return kwargs - - @staticmethod - def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None): - max_total_length = 0 - sequences = [] - if padding not in ["right", "left"]: - raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}") - - for current_segment_list in current_segments: - if current_segment_list is not None: - sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) - - if cut_off_length is not None: - sequence = sequence[-cut_off_length:] - - if bos_token_tensor is not None: - sequence = torch.cat([bos_token_tensor, sequence]) - - sequences.append(sequence) - max_total_length = max(max_total_length, len(sequences[-1])) - else: - sequences.append(bos_token_tensor) - - for i in range(len(current_segments)): - pad_length = max_total_length - len(sequences[i]) - pad = (0, pad_length) if padding == "right" else (pad_length, 0) - sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id) - - sequences = torch.stack(sequences, dim=0) - return sequences - - @staticmethod - def _retrieve_compression_ratio(tokens, vocab_size): - length = int(math.log2(vocab_size) / 8) + 1 - token_bytes = b''.join([t.to_bytes(length, 'little') for t in tokens.tolist()]) - - # string = tok.decode(tokens, skip_special_tokens=True) - # string_bytes = string.encode("utf-8") - # string_compression_ratio = len(string_bytes) / len(zlib.compress(string_bytes)) - - compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes)) - - # print(f"HERE: string: {string}") - # print(f"HERE: string ratio: {string_compression_ratio}") - # print(f"HERE: token ratio: {compression_ratio}") - # print('HERE:' + 20 * '-') - - return compression_ratio - - @staticmethod - def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): - rescale_temperature = temperature if temperature > 0.0 else 1 - scores = torch.stack(scores).to(tokens.device) - - # TODO(Patrick) - only leave scores = scores[:tokens.shape[0]] part - if scores.shape[0] > tokens.shape[0]: - scores = scores[:tokens.shape[0]] - else: - tokens = tokens[-scores.shape[0]:] - - logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype) - - # retrieve logprob of selected tokens and sum - sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0])) - length = (tokens != eos_token_id).sum(-1) - - avg_logprobs = sum_logprobs / (length + 1) - return avg_logprobs - - @staticmethod - def _retrieve_segment( - seek_sequence, - seek_outputs, - time_offset, - timestamp_begin, - seek_num_frames, - time_precision, - input_stride, - prev_idx, - idx, - ): - # find the predicted "end of segment" predictions of Whisper - # "end of segment" predictions occur whenever Whisper predicts a timestamp token - timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin) - single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] - timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] - timestamp_segment_indices.add_(1) - - # If whisper predicted a "end of segment" via a timestep token, let's go ever each - # "end of segment" prediction and slice the decoding into segments accordingly - if len(timestamp_segment_indices) > 0: - # if the output contains two consecutive timestamp tokens - slices = timestamp_segment_indices.tolist() - segments = [] - if single_timestamp_ending: - slices.append(len(seek_sequence)) - - last_slice = 0 - # Add each segment to list of all segments - for current_slice in slices: - sliced_tokens = seek_sequence[last_slice : current_slice] - start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin - end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin - segments.append( - { - "start": time_offset[prev_idx] + start_timestamp_pos * time_precision, - "end": time_offset[prev_idx] + end_timestamp_pos * time_precision, - "tokens": sliced_tokens, - "result": seek_outputs[idx], - } - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. - segment_offset = seek_num_frames[prev_idx] - cut_type = "single ending" - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - # here we throw away all predictions after the last predicted "end of segment" - # since we are cutting right in the middle of an audio - last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin - segment_offset = last_timestamp_pos * input_stride - cut_type = "cut" - else: - # If whisper does not predict any "end of segment" token, then - # the whole decoding is considered a segment and we add it to the list of segments - timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()] - last_timestamp_pos = seek_num_frames[prev_idx] - if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin: - # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = timestamps[-1].item() - timestamp_begin - - segments = [ - { - "start": time_offset[prev_idx], - "end": time_offset[prev_idx] + last_timestamp_pos * time_precision, - "tokens": seek_sequence, - "result": seek_outputs[idx], - } - ] - segment_offset = seek_num_frames[prev_idx] - cut_type = "all" - - return segments, segment_offset, cut_type - def prepare_inputs_for_generation( self, decoder_input_ids, From a8b8446a8408e027faaf6dd13aa9e2a9bb6102a2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 19 Dec 2023 12:38:03 +0000 Subject: [PATCH 43/75] Finish --- src/transformers/generation/logits_process.py | 9 +++- src/transformers/generation/utils.py | 13 +++++- .../models/whisper/generation_whisper.py | 18 +++++--- tests/models/whisper/test_modeling_whisper.py | 44 +++++++------------ 4 files changed, 48 insertions(+), 36 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index a65db4f1c158..7bbae30d766a 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1909,11 +1909,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class WhisperNoSpeechDetection(LogitsProcessor): r"""This processor can be used to detect silence when using Whisper.""" - def __init__(self, no_speech_token: int, begin_index: int, begin_index_offset: int): + def __init__(self, no_speech_token: int, begin_index: int, begin_index_offset: int, scores_is_logprobs: bool = False): self.no_speech_token = no_speech_token self.begin_index = begin_index self.begin_index_offset = begin_index_offset self._no_speech_prob = [0.0] + self.is_scores_logprobs = scores_is_logprobs # make sure we pass all logits self._pass_all_logits = True @@ -1930,7 +1931,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if input_ids.shape[1] == self.begin_index: no_speech_index = self.begin_index - self.begin_index_offset no_speech_scores = scores[:, no_speech_index] - probs = no_speech_scores.float().softmax(dim=-1) + if self.is_scores_logprobs: + probs = no_speech_scores.exp() + else: + probs = no_speech_scores.float().softmax(dim=-1) + self._no_speech_prob = probs[:, self.no_speech_token] scores = scores[:, -1, :] diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ae9e022c7ad8..05738c0000bd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1390,6 +1390,7 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de "generation.", UserWarning, ) + if input_ids_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" warnings.warn( @@ -2605,6 +2606,9 @@ def greedy_search( # pre-process distribution next_tokens_scores = logits_processor(input_ids, outputs.logits) + if len(next_tokens_scores.shape) > 2: + next_tokens_scores = next_tokens_scores[:, -1, :] + # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: @@ -2888,6 +2892,9 @@ def sample( next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_warper(input_ids, next_token_scores) + if len(next_token_scores.shape) > 2: + next_token_scores = next_token_scores[:, -1, :] + # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: @@ -3203,12 +3210,16 @@ def beam_search( cur_len = cur_len + 1 continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits[:, -1, :] + next_token_logits = outputs.logits next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) + + if len(next_token_scores_processed.shape) > 2: + next_token_scores_processed = next_token_scores_processed[:, -1, :] + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( next_token_scores_processed ) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index af524312ea09..29b4893ba269 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -403,6 +403,8 @@ def split_by_batch_index(values, key, batch_idx): new_fallback_index_map = [] new_segment_input = [] new_decoder_input_ids = [] + new_decoder_attention_mask = [] + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): # make sure we cut a predicted EOS token if we are not finished with the generation yet @@ -445,8 +447,6 @@ def split_by_batch_index(values, key, batch_idx): # TODO(PVP) only works for batch size = 1 currently # Need to do before all other logit processors no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") - print("WATCH") - print(no_speech_prob) if no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: print("Skip because of VAD") needs_fallback[i] = False @@ -460,6 +460,8 @@ def split_by_batch_index(values, key, batch_idx): new_fallback_index_map.append(fallback_index_map[i]) new_segment_input.append(segment_input[i]) new_decoder_input_ids.append(decoder_input_ids[i]) + if "decoder_attention_mask" in kwargs: + new_decoder_attention_mask.append(kwargs['decoder_attention_mask'][i]) fallback_index_map = new_fallback_index_map @@ -472,6 +474,8 @@ def split_by_batch_index(values, key, batch_idx): # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors decoder_input_ids = torch.stack(new_decoder_input_ids) segment_input = torch.stack(new_segment_input) + if "decoder_attention_mask" in kwargs: + kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask) return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens @@ -826,8 +830,7 @@ def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx ) decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) - # kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) - + kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) return decoder_input_ids, kwargs @@ -870,8 +873,11 @@ def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, ): max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] - kwargs["max_new_tokens"] = max_new_tokens - kwargs["max_length"] = max_length + if max_new_tokens is not None: + kwargs["max_new_tokens"] = max_new_tokens + + if max_length is not None: + kwargs["max_length"] = max_length return kwargs diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 6a3a033e21ab..db6a4ecde07a 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -138,9 +138,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to self.count += 1 if torch.isinf(scores).all(): - import ipdb - - ipdb.set_trace() raise ValueError("Dummy logit processor is incorrectly set up. Scores should not be all inf.") return scores @@ -1407,7 +1404,6 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): # len = 250 with num_input_frames = 60 long_input_features = torch.cat([input_features.repeat(1, 1, 4), input_features[:, :, :10]], dim=-1) - long_input_features[:1, :, :200] input_features_2 = long_input_features[1:] attention_mask = torch.ones( (2, long_input_features.shape[-1]), dtype=input_features.dtype, device=input_features.device @@ -1419,13 +1415,14 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): batch_size = 1 num_timestamp_tokens = 20 - max_length = 16 + max_new_tokens = 16 timestamp_begin = vocab_size - num_timestamp_tokens model.generation_config.no_timestamps_token_id = timestamp_begin - 1 model.generation_config.eos_token_id = None model.generation_config._detect_timestamp_from_logprob = False # make sure that we only have the same begin token model.generation_config.max_initial_timestamp_index = 0 + model.generation_config.max_new_tokens = max_new_tokens model.generation_config.prev_bos_token_id = timestamp_begin - 3 logits_processor = [ @@ -1433,12 +1430,12 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): vocab_size - num_timestamp_tokens, vocab_size, batch_size=batch_size, - max_length=max_length, + max_length=max_new_tokens, min_space=4, seed=1, ) ] - outputs_2 = model.generate(input_features_2, logits_processor=logits_processor, return_segments=True) + outputs_2 = model.generate(input_features_2, max_new_tokens=max_new_tokens, logits_processor=logits_processor, condition_on_prev_tokens=condition_on_prev_tokens, return_segments=True) tokens_2 = outputs_2["sequences"][0] segments_2 = outputs_2["segments"][0] @@ -1448,7 +1445,7 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): vocab_size - num_timestamp_tokens, vocab_size, batch_size=batch_size, - max_length=max_length, + max_length=max_new_tokens, min_space=4, seed=0, ) @@ -1458,19 +1455,15 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): "return_segments": True, "condition_on_prev_tokens": condition_on_prev_tokens, "attention_mask": attention_mask, + "max_new_tokens": max_new_tokens, } - if condition_on_prev_tokens: - gen_kwargs["no_speech_threshold"] = 0.6 - gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) - gen_kwargs["compression_ratio_threshold"] = 2.4 - gen_kwargs["logprob_threshold"] = -1.0 - outputs = model.generate(long_input_features, **gen_kwargs) tokens = outputs["sequences"][1] segments = outputs["segments"][1] - assert tokens_2.tolist() == tokens.tolist() + # make sure batched and non-batched is the same + assert tokens_2.tolist() == tokens[:tokens_2.shape[-1]].tolist() for seg1, seg2 in zip(segments_2, segments): assert seg1["start"] == seg2["start"] @@ -2289,14 +2282,14 @@ def test_whisper_longform_multi_batch_hard(self): def test_whisper_longform_multi_batch_hard_prev_cond(self): # fmt: off EXPECTED_TEXT = [ - " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile." - " Folks, I spend a lot of time right over there, night after night after night, actually. Carefully selecting for you the day's noosiest, most aerodynamic headlines, stress testing, and those topical anti-lock breaks and power steering, painstakingly stitching, leather seating so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school and slap myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen moon, render a gas tank out of an empty big gulp, fill with white claw and denatured alcohol, then light a match and let her rip and the demented one man soapbox derby of news that is my segment. Me, Guadalupe! No!" - " Ladies and gentlemen, you know, I spent a lot of time right over there Raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their jokes swollen teats Churning the daily stories into the decadent proven-style style triple cream breed that is my nightly monologue But sometimes sometimes folks I stagger home hungry after being released by the police and Root around in the neighbor's trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-donned street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire then hunker down and hallucinate while eating the listeria laden demon custard of news that is my segment. You mean one of them." - " Folks, if you watch this show, you know I spend most of my time right over there carefully sorting through the day's biggest stories and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ichol Gregoire Ferrandi, who carefully dye them in a palette of bright zesty shades and adorn them in the finest and most topical inlay work using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddles stitching. In line it with bees, wax, coated linen, finely attached a mallet, hammered strap, pearled hardware, and close-shit to create for you the one-of-a-kind hoke couture, Erme's Birkin bag that is my monologue. But sometimes, sometimes folks, sometimes. Sometimes I wake up in the last car of an abandoned roller coaster at Coney Island where I'm I'm hiding from the triads. I have some engine lubricants out of a safe way bag and stagger down the shore to tear the sail off a beach schooner. Then I rip the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel lovely folks. And use it to stitch the sail into a loose pouch like a rock sack. And I stow away in the back of a garbage truck to the junkyard where I pick through to the debris for only the broken toys that make me the saddest until I have loaded for you. The Hobo Fugitives bug out, bindle of news that is my segment. Me one!" - " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's biggest stories right over there. Meticulously selecting the most topical chakra affirming scented candles, and using Feng Shui to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue. But sometimes just sometimes I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself, and used fry oil, wrap my hands with some double-duct tape by stole from the broken car window. Pound a six-pack of blueberry hard-seltzer and a sack of pills I stole from a parked ambulance. Then arm wrestle a raccoon in the back alley vision quest of news that is my segment. Meanwhile!" - " You know, folks, I spend most of my time right over there. Mining the day's biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels. Then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press-black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards in a faceplate and, finally, using fluted strips of white alloyed molding, I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating, Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes folks. Sometimes, just sometimes, I come into my sense as fully naked on the deck of a pirate besieged melee container ship that picked me up floating on the detached door of a portapotty in the Indian Ocean. Then after a sunstroke-induced realization of the crew of this ship plans to sell me an exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe at a pool chain that accepting my new role as Captain and declaring myself king of the windarc seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create the sopping wet pirate crown of news that is my segment. Meanwhile!" - " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's Newsiest most topical flower eggs milk and butter and Stranding into a fine batter to make delicate and informative comedy pancakes Then I glaze them in the juice and zest of the most relevant midnight Valencia oranges and douse it all and a fine Dela main de voyage cognac Before prom baying and basting them tables. I deserve for you the James Beard award worthy crepe suzzette That is my nightly monologue, but sometimes just sometimes folks. I wake up in the baggage hold of Greyhound bus. It's being hoisted by the scrap yard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps and busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strain, pair of sweatpants and as oven mitts to extract and serve the demented transience poundcake of news that is my segment. Me, Guadalupe!" - " Folks, if you watched the show and I hope you do, I spent a lot of time right over there. Tiredlessly studying the lineage of the days most important thoroughbred stories and whole-stiner headlines, working with the best trainers, money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen. That is my nightly monologue, but sometimes, sometimes, folks, I break into an unincorporated veterinary genetics lab and grab whatever test tubes I can find and then under a grow light I got from a discarded chia pet. I mixed the pilfered DNA of a horse and whatever was in a tube labeled Keith Colan extra. Slurrying the concoction with caffeine pills and a microwave red bull, I screamed, sang a prayer to Janice, initiator of human life and God of transformation as a half horse, half man, freak. Seizes to life before me and the hideous collection of loose animal parts and corrupted man tissue that is my segment. Meanwhile!" + " Folks, if you watch the show, you know I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories, developing the central headline pawns, definitely maneuvering an oh-so-topical night to F6, faming of classic Sicilian, named or variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a Fisher show's in lip-nitsky attack that culminates in the elegant lethal slow played all pass on checkmate that is my nightly monologue, but sometimes sometimes folks I sometimes I start a little wake-up side down in the monkey bars of a condemned playground on a super fun site, get all hepped up on goofballs, rummage that would discard a tag bag of defective toys, yank out a fistball of disembodied doll limbs, toss them on a stain kid's place mad from a defunked denies, set up a table inside a rusty cargo container down by the warf and challenge toothless drifters to the godless bughouse blitz of tournament that is my segment.", + " Folks, I spent a lot of time right over there night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing on those topical anti-lock breaks and power steering, painstakingly stitching, leather seating, so soft, it would make JD power and her associates blush. To create the luxury sedan that is my nightly monologue, but sometimes I just sometimes focus. I lurched to consciousness in the back of an abandoned school bus and slapped myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen-moon render a gas tank out of an empty big gulp, filled with white claw and de-natured alcohol, then light a match, letter-ripping the dis-mented one-man soapbox derby of news that is my segment.", + " Ladies and gentlemen, you know, I spent a lot of time right over there, raising the finest hosting news cattle firmly, yet tenderly milking the latest headlines from their jokes, swollen teats, churning the daily stories into the decadent Provincil style triple cream-breed. It is my nightly monologue, but sometimes sometimes I stagger home hungry after being released by the police and root around in the neighbors trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rind I won from a rat and a pre-drawn street fight. Put it into discarded paint can to leave it to ferment next to a trash fire than a hunker down in hallucinate while eating the lusteria latent demon custard of news that is my segment.", + " Folks, you watched this show, you know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle, and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ickel Greg Waferandi, who carefully died them in a pallet of bright, zesty shades, and adorn them in the finest most topical inlay work, using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddle stitching, and line it with bees, wax, coated linen, and finally attach a mallet hammered strap, perled hardware, and close-shet to create for you the one of a kind hope, kutur, earn-may is burkin bag that is my monologue, but sometimes, sometimes, sometimes. Sometimes, sometimes I wake up in the last car of an abandoned roller coaster at Kony Island, where I'm hiding from the triads, I have some engine lubricants out of a safe way bag and staggered down the shore to tear the sail off a beach sooner than I ripped the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel Lovelyfokes, and use it to stitch the sail into a loose pouch like rock sack, and I stole a bag of a garbage truck to the junkyard, where I picked through to the debris for only the broken toys that make me the saddest, until I have loaded for you. The hobo fugitives bug out Bindle of news that is my segment.", + " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui, to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue, but sometimes just sometimes, I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself and use fry oil, wrap my hands and some old duct tape I stole from a broken car window, pound a six pack of blueberry hard-seller and a second pill, as I stole from a park damsel, and it's then arm wrestle a raccoon in the back alley vision quest of news that is my segment.", + " You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press, black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards and a face plate, and finally using fluted strips of white alloyed molding I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes, folks. Sometimes, just sometimes, I come to my senses fully naked on the deck of a pirate, beceived, melee, container ship that picked me up floating on the detainees. Then after I sunstroke in juice, realization of the crew of this ship plans to sell me and exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe in a pool chain that accepting my new role as captain and declaring myself king of the wind arc seas. I grab a dirty muck bucket covered in barnacles and a dornet with the teeth of the vanquished to create the softening wet pirate crown of news that is my segment. I'm going to use the white paper to create the softened white paper to create the softened white paper to create the softened white pirate crown of news that is my segment. Meanwhile.", + " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most relevant midnight valencio oranges. And doubts at all, and I find delimane de voyage cognac, before from bang and basting them tables, I deserve you the James Beard Award worthy creeps to ZET. That is my nightly monologue, but sometimes sometimes folks I wake up in the baggage hole of Greyhound bus, it's being hoisted by the scrapyard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps, busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strained pair of sweatpants and as ovenmets to extract and serve the demented transients pound cake of news that is my segment. Me wild!", + " Folks, if you watch the show and I hope you do, I spend a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines, working with the best trainers money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen that is my nightly monologue. But sometimes sometimes folks I break into an unincorporated veterinary genetics lab. And grab whatever test tubes I can find and then under a grow light I got from it a discarded chia pet. I mixed the pill for DNA of a horse and whatever was in a tube labeled Keith Cole and extra. Sloering the concoction with caffeine pills and a microwave bread bowl, I screamed sing a prayer to Janice initiator of human life and God of transformation as a half horse, half man freak, seasons to life before me. And the hideous collection of loose animal parts and corrupted men tissue that is my segment. Meanwhile.", ] # fmt: on @@ -2325,14 +2318,11 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) gen_kwargs["num_beams"] = 5 + torch.manual_seed(0) result = model.generate(**inputs, **gen_kwargs) decoded_all = processor.batch_decode(result, skip_special_tokens=True) - torch.manual_seed(0) for i in range(num_samples): - import ipdb - - ipdb.set_trace() assert decoded_all[i] == EXPECTED_TEXT[i] From 875beab54bcf57872cf46587e09ec612c8b2375d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 19 Dec 2023 13:53:12 +0000 Subject: [PATCH 44/75] Fix more --- src/transformers/generation/logits_process.py | 11 ----------- .../models/whisper/generation_whisper.py | 13 +++---------- 2 files changed, 3 insertions(+), 21 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 7bbae30d766a..9ab68e89b3f3 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -50,10 +50,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." ) - @property - def pass_all_logits(self): - return getattr(self, "_pass_all_logits", False) - class LogitsWarper: """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" @@ -64,10 +60,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." ) - @property - def pass_all_logits(self): - return getattr(self, "_pass_all_logits", False) - class LogitsProcessorList(list): """ @@ -92,9 +84,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa The processed prediction scores. """ - if not any(processor.pass_all_logits for processor in self) and len(scores.shape) > 2: - scores = scores[:, -1, :] - for processor in self: function_args = inspect.signature(processor.__call__).parameters if len(function_args) > 2: diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 29b4893ba269..edac9d210584 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -226,6 +226,8 @@ def generate( # 1. copy generation config if generation_config is None: generation_config = copy.deepcopy(self.generation_config) + else: + generation_config = copy.deepcopy(generation_config) # 2. set global generate variables input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] @@ -913,20 +915,11 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_toke @staticmethod def _retrieve_compression_ratio(tokens, vocab_size): + """ Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes """ length = int(math.log2(vocab_size) / 8) + 1 token_bytes = b''.join([t.to_bytes(length, 'little') for t in tokens.tolist()]) - - # string = tok.decode(tokens, skip_special_tokens=True) - # string_bytes = string.encode("utf-8") - # string_compression_ratio = len(string_bytes) / len(zlib.compress(string_bytes)) - compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes)) - # print(f"HERE: string: {string}") - # print(f"HERE: string ratio: {string_compression_ratio}") - # print(f"HERE: token ratio: {compression_ratio}") - # print('HERE:' + 20 * '-') - return compression_ratio @staticmethod From 39034c7ca53f4884b3a0fb753b467e987705c995 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 19 Dec 2023 13:53:40 +0000 Subject: [PATCH 45/75] Fix more --- src/transformers/generation/utils.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d2b2f1ebe476..d23f7f9245d7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -781,8 +781,6 @@ def _prepare_decoder_input_ids_for_generation( # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): pass - elif self.config.model_type in ["whisper"]: - pass # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): @@ -1129,7 +1127,6 @@ def _get_logits_processor( # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: processors.append(LogitNormalization()) - return processors def _get_stopping_criteria( @@ -1390,7 +1387,6 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de "generation.", UserWarning, ) - if input_ids_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" warnings.warn( @@ -2603,11 +2599,10 @@ def greedy_search( if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need - # pre-process distribution - next_tokens_scores = logits_processor(input_ids, outputs.logits) + next_token_logits = outputs.logits[:, -1, :] - if len(next_tokens_scores.shape) > 2: - next_tokens_scores = next_tokens_scores[:, -1, :] + # pre-process distribution + next_tokens_scores = logits_processor(input_ids, next_token_logits) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -2886,15 +2881,12 @@ def sample( if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits + next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_warper(input_ids, next_token_scores) - if len(next_token_scores.shape) > 2: - next_token_scores = next_token_scores[:, -1, :] - # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: @@ -3210,16 +3202,12 @@ def beam_search( cur_len = cur_len + 1 continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits + next_token_logits = outputs.logits[:, -1, :] next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - - if len(next_token_scores_processed.shape) > 2: - next_token_scores_processed = next_token_scores_processed[:, -1, :] - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( next_token_scores_processed ) From 66c08eecb1939230e4c3f616529079a62dbf33bf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 19 Dec 2023 16:18:05 +0000 Subject: [PATCH 46/75] finish --- src/transformers/generation/logits_process.py | 24 +++++++++++++++---- src/transformers/generation/utils.py | 2 ++ .../models/whisper/generation_whisper.py | 20 ++++++++++++---- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 9ab68e89b3f3..a278b574fa7b 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1908,6 +1908,17 @@ def __init__(self, no_speech_token: int, begin_index: int, begin_index_offset: i # make sure we pass all logits self._pass_all_logits = True + # overwritten dynamically + self.model = None + self.inputs = None + + def set_model(self, model): + self.model = model + + def set_inputs(self, inputs): + self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs} + self.inputs["input_features"] = self.inputs.pop("inputs") + @property def no_speech_prob(self): return self._no_speech_prob @@ -1918,8 +1929,15 @@ def set_begin_index(self, begin_index): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if input_ids.shape[1] == self.begin_index: - no_speech_index = self.begin_index - self.begin_index_offset - no_speech_scores = scores[:, no_speech_index] + if self.begin_index_offset > 1: + with torch.no_grad(): + logits = self.model(**self.inputs).logits + + no_speech_index = self.begin_index - self.begin_index_offset + no_speech_scores = logits[:, no_speech_index] + else: + no_speech_scores = scores + if self.is_scores_logprobs: probs = no_speech_scores.exp() else: @@ -1927,8 +1945,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to self._no_speech_prob = probs[:, self.no_speech_token] - scores = scores[:, -1, :] - return scores diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d23f7f9245d7..02f3ebb40564 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -781,6 +781,8 @@ def _prepare_decoder_input_ids_for_generation( # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): pass + elif self.config.model_type in ["whisper"]: + pass # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index edac9d210584..2debd5fbc8ab 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -247,7 +247,7 @@ def generate( # 4. Retrieve logits processors # => TODO(Patrick): The way `num_start_tokens` is retrieved here is too brittle. Need a better approach num_start_tokens = len(generation_config.forced_decoder_ids) if generation_config.forced_decoder_ids is not None else 1 - logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, num_start_tokens=num_start_tokens, is_shortform=is_shortform) + logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, num_start_tokens=num_start_tokens, is_shortform=is_shortform, num_beams=kwargs.get("num_beams", 1)) # 5. If we're in shortform mode, simple generate the whole input at once and return the output if is_shortform: @@ -363,6 +363,15 @@ def generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batc should_skip = [False for _ in range(cur_bsz)] fallback_index_map = list(range(cur_bsz)) + if no_speech_threshold is not None: + set_inputs = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs") + extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)} + set_inputs({ + "inputs": segment_input, + "decoder_input_ids": decoder_input_ids, + **extra_kwargs + }) + for fallback_idx, temperature in enumerate(temperatures): generation_config.do_sample = temperature > 0.0 generation_config.temperature = temperature @@ -449,7 +458,8 @@ def split_by_batch_index(values, key, batch_idx): # TODO(PVP) only works for batch size = 1 currently # Need to do before all other logit processors no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") - if no_speech_prob[i] > no_speech_threshold and logprobs < logprob_threshold: + + if logprobs < logprob_threshold and no_speech_prob[i] > no_speech_threshold: print("Skip because of VAD") needs_fallback[i] = False should_skip[i] = True @@ -733,8 +743,7 @@ def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): return init_tokens - @staticmethod - def _retrieve_logit_processors(generation_config, logits_processor, no_speech_threshold, num_start_tokens, is_shortform): + def _retrieve_logit_processors(self, generation_config, logits_processor, no_speech_threshold, num_start_tokens, is_shortform, num_beams): begin_index = 1 if generation_config.return_timestamps is True: forced_decoder_ids = generation_config.forced_decoder_ids @@ -768,10 +777,11 @@ def _retrieve_logit_processors(generation_config, logits_processor, no_speech_th generation_config.begin_suppress_tokens = None if no_speech_threshold is not None and not is_shortform: - no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, begin_index_offset=num_start_tokens) + no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, begin_index_offset=num_start_tokens, scores_is_logprobs=num_beams > 1) logits_processor = ( [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor ) + no_speech_detector.set_model(self) return logits_processor From 254026c323e00894f7bf6060d4988642cbc7c6a1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 2 Jan 2024 18:24:44 +0000 Subject: [PATCH 47/75] Fix edge cases --- src/transformers/models/whisper/generation_whisper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index f1571d56f73e..1bc9235928da 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1013,6 +1013,9 @@ def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) + else: + # make sure `"decoder_attention_mask"` is not passed to forward + kwargs.pop("decoder_attention_mask", None) return decoder_input_ids, kwargs @@ -1071,7 +1074,7 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_toke raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}") for current_segment_list in current_segments: - if current_segment_list is not None: + if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0: sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) if cut_off_length is not None: From 85c68f298b990b71d3650a4a7bb07f8a550e62d7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 3 Jan 2024 10:00:29 +0100 Subject: [PATCH 48/75] fix return_dict_in_generate --- src/transformers/models/whisper/generation_whisper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 1bc9235928da..fa8503ed68d7 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -427,7 +427,6 @@ def generate( stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, - return_dict_in_generate=return_dict_in_generate, **kwargs, ) From c2fa76a291ca7a3d7f445e814f64cfbd20d1b0ae Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 3 Jan 2024 11:59:18 +0100 Subject: [PATCH 49/75] fix all tests --- src/transformers/generation/logits_process.py | 4 - .../models/whisper/generation_whisper.py | 161 ++++++++++-------- tests/models/whisper/test_modeling_whisper.py | 2 + 3 files changed, 91 insertions(+), 76 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index a278b574fa7b..30ed2015bc6a 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1874,7 +1874,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to # apply the `max_initial_timestamp` option if input_ids.shape[1] == self.begin_index: - print("HF Sample begin", self.begin_index) scores[:, : self.timestamp_begin] = -float("inf") if self.max_initial_timestamp_index is not None: @@ -1889,9 +1888,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob: scores[k, : self.timestamp_begin] = -float("inf") - if torch.isinf(scores).all(): - print("RED FLAG") - return scores diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index fa8503ed68d7..35588a440c30 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -221,7 +221,7 @@ def generate( is_multilingual=None, condition_on_prev_tokens: Optional[bool] = None, no_speech_threshold: Optional[float] = None, - temperature: Union[float, Tuple[float, ...]] = 0.0, + temperature: Optional[Union[float, Tuple[float, ...]]] = None, compression_ratio_threshold: Optional[float] = None, logprob_threshold: Optional[float] = None, prompt_ids: Optional[torch.Tensor] = None, @@ -411,6 +411,7 @@ def generate( self._set_language_and_task(language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config) # pass self.config for backward compatibility self._set_forced_decoder_ids(task=task, language=language, prompt_ids=prompt_ids, generation_config=generation_config, config=self.config, kwargs=kwargs) + self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs) self._set_num_frames(return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs) # 4. Retrieve logits processors @@ -427,6 +428,7 @@ def generate( stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, + temperature=temperature, **kwargs, ) @@ -484,7 +486,6 @@ def generate( if hasattr(proc, "set_begin_index"): proc.set_begin_index(decoder_input_ids.shape[-1]) - print("hf in tokens", decoder_input_ids[0].tolist()) # 6.8 Run generate with fallback seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(segment_input=segment_input, decoder_input_ids=decoder_input_ids, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, seek=seek, num_segment_frames=num_segment_frames, max_frames=max_frames, temperatures=temperatures, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold, do_condition_on_prev_tokens=do_condition_on_prev_tokens, condition_on_prev_tokens=condition_on_prev_tokens, kwargs=kwargs) @@ -494,11 +495,9 @@ def generate( if should_skip[i]: seek[prev_i] += seek_num_frames[prev_i] - print("Skipped!") continue - # TODO(Patrick: delete cut type) - segments, segment_offset, cut_type = self._retrieve_segment( + segments, segment_offset = self._retrieve_segment( seek_sequence=seek_sequence, seek_outputs=seek_outputs, time_offset=time_offset, @@ -532,16 +531,10 @@ def generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batc fallback_index_map = list(range(cur_bsz)) if no_speech_threshold is not None: - set_inputs = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs") - extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)} - set_inputs({ - "inputs": segment_input, - "decoder_input_ids": decoder_input_ids, - **extra_kwargs - }) + self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs) for fallback_idx, temperature in enumerate(temperatures): - generation_config.do_sample = temperature > 0.0 + generation_config.do_sample = temperature is not None and temperature > 0.0 generation_config.temperature = temperature generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1 @@ -556,35 +549,19 @@ def generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batc **kwargs, ) - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - seek_outputs["token_timestamps"] = self._extract_token_timestamps( - seek_outputs, generation_config.alignment_heads, num_frames=num_frames - ) - - if generation_config.return_dict_in_generate: - def split_by_batch_index(values, key, batch_idx): - if key == "scores": - return [v[batch_idx].cpu() for v in values] - if key == "past_key_values": - # we don't save `past_key_values` as this is too costly - return None - return values[batch_idx].cpu() - - sequence_tokens = seek_outputs["sequences"] - seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] - else: - sequence_tokens = seek_outputs + # post-process sequence tokens and outputs to be in list form + sequence_tokens, seek_outputs = self._postprocess_outputs(seek_outputs, return_token_timestamps, generation_config) # remove all previously passed decoder input ids seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] + # 6.7 Extract cut sequences from every sequence and check if fallback should be applied + # Loop over each decoded audio individually as each decoding can be of a different length new_fallback_index_map = [] new_segment_input = [] new_decoder_input_ids = [] new_decoder_attention_mask = [] - # 6.7 Loop over each decoded audio individually as each decoding can be of a different length for i, seek_sequence in enumerate(seek_sequences): # make sure we cut a predicted EOS token if we are not finished with the generation yet prev_i = batch_idx_map[fallback_index_map[i]] @@ -594,47 +571,17 @@ def split_by_batch_index(values, key, batch_idx): if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: seek_sequence = seek_sequence[:-1] - print("hf out tokens", seek_sequence.tolist()) - # remove all padding tokens if seek_sequence[-1] == generation_config.pad_token_id: num_paddings = (seek_sequence == generation_config.pad_token_id).sum() seek_sequence = seek_sequence[:-num_paddings] - if compression_ratio_threshold is not None: - compression_ratio = self._retrieve_compression_ratio(seek_sequence, self.config.vocab_size) - - if compression_ratio > compression_ratio_threshold: - print("fallback compression") - print("current temp", temperature) - needs_fallback[i] = True - - if logprob_threshold is not None: - if "sequences_scores" in seek_outputs[0]: - logprobs = [s["sequences_scores"] for s in seek_outputs][i] - else: - scores = seek_outputs[i]["scores"] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, self.config.eos_token_id, temperature) - - # TODO(PVP) only works for batch size = 1 currently - if logprobs < logprob_threshold: - print("fallback logprobs", logprobs) - print("current temp", temperature) - needs_fallback[i] = True - - if no_speech_threshold is not None: - # TODO(PVP) only works for batch size = 1 currently - # Need to do before all other logit processors - no_speech_prob = self._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") - - if logprobs < logprob_threshold and no_speech_prob[i] > no_speech_threshold: - print("Skip because of VAD") - needs_fallback[i] = False - should_skip[i] = True + # check which sequences in batch need fallback & which should be skipped + needs_fallback[i], should_skip[i] = self._need_fallback(seek_sequence, seek_outputs, i, logits_processor, compression_ratio_threshold, logprob_threshold, no_speech_threshold, self.config.vocab_size, generation_config.eos_token_id, temperature) seek_sequence_list[fallback_index_map[i]] = seek_sequence seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] - do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature < 0.5 + do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature is not None and temperature < 0.5 if needs_fallback[i]: new_fallback_index_map.append(fallback_index_map[i]) @@ -659,6 +606,69 @@ def split_by_batch_index(values, key, batch_idx): return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens + def _postprocess_outputs(self, seek_outputs, return_token_timestamps, generation_config): + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + num_frames = getattr(generation_config, "num_frames", None) + seek_outputs["token_timestamps"] = self._extract_token_timestamps( + seek_outputs, generation_config.alignment_heads, num_frames=num_frames + ) + + if generation_config.return_dict_in_generate: + def split_by_batch_index(values, key, batch_idx): + if key == "scores": + return [v[batch_idx].cpu() for v in values] + if key == "past_key_values": + # we don't save `past_key_values` as this is too costly + return None + return values[batch_idx].cpu() + + sequence_tokens = seek_outputs["sequences"] + seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] + else: + sequence_tokens = seek_outputs + + return sequence_tokens, seek_outputs + + @staticmethod + def _need_fallback(seek_sequence, seek_outputs, index, logits_processor, compression_ratio_threshold, logprob_threshold, no_speech_threshold, vocab_size, eos_token_id, temperature): + needs_fallback = False + should_skip = False + if compression_ratio_threshold is not None: + compression_ratio = WhisperGenerationMixin._retrieve_compression_ratio(seek_sequence, vocab_size) + + if compression_ratio > compression_ratio_threshold: + needs_fallback = True + + if logprob_threshold is not None: + if "sequences_scores" in seek_outputs[0]: + logprobs = [s["sequences_scores"] for s in seek_outputs][index] + else: + scores = seek_outputs[index]["scores"] + logprobs = WhisperGenerationMixin._retrieve_avg_logprobs(scores, seek_sequence, eos_token_id, temperature) + + if logprobs < logprob_threshold: + needs_fallback = True + + if no_speech_threshold is not None: + no_speech_prob = WhisperGenerationMixin._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") + + if logprobs < logprob_threshold and no_speech_prob[index] > no_speech_threshold: + needs_fallback = False + should_skip = True + + return needs_fallback, should_skip + + + @staticmethod + def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs): + set_inputs = WhisperGenerationMixin._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs") + extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)} + set_inputs({ + "inputs": segment_input, + "decoder_input_ids": decoder_input_ids, + **extra_kwargs + }) + @staticmethod def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): @@ -859,6 +869,17 @@ def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, confi forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] generation_config.forced_decoder_ids = forced_decoder_ids + @staticmethod + def _set_token_ids(generation_config, config, kwargs): + eos_token_id = kwargs.pop("eos_token_id", None) + decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) + + eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id + decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id + + generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id + generation_config.decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id + @staticmethod def _set_num_frames(return_token_timestamps, generation_config, kwargs): if return_token_timestamps: @@ -1109,7 +1130,6 @@ def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): rescale_temperature = temperature if temperature > 0.0 else 1 scores = torch.stack(scores).to(tokens.device) - # TODO(Patrick) - only leave scores = scores[:tokens.shape[0]] part if scores.shape[0] > tokens.shape[0]: scores = scores[:tokens.shape[0]] else: @@ -1119,7 +1139,7 @@ def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): # retrieve logprob of selected tokens and sum sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0])) - length = (tokens != eos_token_id).sum(-1) + length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0] avg_logprobs = sum_logprobs / (length + 1) return avg_logprobs @@ -1171,14 +1191,12 @@ def _retrieve_segment( if single_timestamp_ending: # single timestamp at the end means no speech after the last timestamp. segment_offset = seek_num_frames[prev_idx] - cut_type = "single ending" else: # otherwise, ignore the unfinished segment and seek to the last timestamp # here we throw away all predictions after the last predicted "end of segment" # since we are cutting right in the middle of an audio last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin segment_offset = last_timestamp_pos * input_stride - cut_type = "cut" else: # If whisper does not predict any "end of segment" token, then # the whole decoding is considered a segment and we add it to the list of segments @@ -1197,7 +1215,6 @@ def _retrieve_segment( } ] segment_offset = seek_num_frames[prev_idx] - cut_type = "all" - return segments, segment_offset, cut_type + return segments, segment_offset diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 7176aee50078..82428b632ea9 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1360,6 +1360,7 @@ def _check_longform_generate_single_batch(self, condition_on_prev_tokens): timestamp_begin = vocab_size - num_timestamp_tokens model.generation_config.no_timestamps_token_id = timestamp_begin - 1 model.generation_config.eos_token_id = None + model.config.eos_token_id = None model.generation_config._detect_timestamp_from_logprob = False # make sure that we only have the same begin token model.generation_config.max_initial_timestamp_index = 0 @@ -1419,6 +1420,7 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): timestamp_begin = vocab_size - num_timestamp_tokens model.generation_config.no_timestamps_token_id = timestamp_begin - 1 model.generation_config.eos_token_id = None + model.config.eos_token_id = None model.generation_config._detect_timestamp_from_logprob = False # make sure that we only have the same begin token model.generation_config.max_initial_timestamp_index = 0 From affdb6d1babd31f6f0b26162416f26c61d2be487 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 3 Jan 2024 15:25:01 +0100 Subject: [PATCH 50/75] make style --- src/transformers/generation/logits_process.py | 6 +- .../models/whisper/generation_whisper.py | 300 ++++++++++++++---- .../models/whisper/modeling_whisper.py | 2 - tests/models/whisper/test_modeling_whisper.py | 12 +- 4 files changed, 243 insertions(+), 77 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 30ed2015bc6a..545d3f653754 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1845,7 +1845,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to scores[:, self.no_timestamps_token_id] = -float("inf") # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly - no_timestamps = False for k in range(input_ids.shape[0]): sampled_tokens = input_ids[k, self.begin_index :] seq = list(sampled_tokens.tolist()) @@ -1856,7 +1855,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if last_was_timestamp: if penultimate_was_timestamp: # has to be non-timestamp scores[k, self.timestamp_begin :] = -float("inf") - no_timestamps = True else: # cannot be normal text tokens scores[k, : self.eos_token_id] = -float("inf") @@ -1894,7 +1892,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class WhisperNoSpeechDetection(LogitsProcessor): r"""This processor can be used to detect silence when using Whisper.""" - def __init__(self, no_speech_token: int, begin_index: int, begin_index_offset: int, scores_is_logprobs: bool = False): + def __init__( + self, no_speech_token: int, begin_index: int, begin_index_offset: int, scores_is_logprobs: bool = False + ): self.no_speech_token = no_speech_token self.begin_index = begin_index self.begin_index_offset = begin_index_offset diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 35588a440c30..56b9074397f9 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -16,12 +16,13 @@ import copy import math import warnings -import numpy as np import zlib from typing import Optional, Tuple, Union +import numpy as np import torch import torch.nn.functional as F +from torch import nn from ...generation.logits_process import ( SuppressTokensAtBeginLogitsProcessor, @@ -113,7 +114,6 @@ def _dynamic_time_warping(matrix: np.ndarray): class WhisperGenerationMixin: - def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None): """ Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to @@ -221,7 +221,7 @@ def generate( is_multilingual=None, condition_on_prev_tokens: Optional[bool] = None, no_speech_threshold: Optional[float] = None, - temperature: Optional[Union[float, Tuple[float, ...]]] = None, + temperature: Optional[Union[float, Tuple[float, ...]]] = None, compression_ratio_threshold: Optional[float] = None, logprob_threshold: Optional[float] = None, prompt_ids: Optional[torch.Tensor] = None, @@ -401,23 +401,53 @@ def generate( # 2. set global generate variables input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] num_segment_frames = input_stride * self.config.max_source_positions - total_input_frames = self._retrieve_total_input_frames(input_features=input_features, input_stride=input_stride, kwargs=kwargs) + total_input_frames = self._retrieve_total_input_frames( + input_features=input_features, input_stride=input_stride, kwargs=kwargs + ) is_shortform = total_input_frames <= num_segment_frames # 3. Make sure generation config is correctly set # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not - self._set_return_outputs(return_dict_in_generate=return_dict_in_generate, return_token_timestamps=return_token_timestamps, is_shortform=is_shortform, logprob_threshold=logprob_threshold, generation_config=generation_config) - self._set_return_timestamps(return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config) - self._set_language_and_task(language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config) + self._set_return_outputs( + return_dict_in_generate=return_dict_in_generate, + return_token_timestamps=return_token_timestamps, + is_shortform=is_shortform, + logprob_threshold=logprob_threshold, + generation_config=generation_config, + ) + self._set_return_timestamps( + return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config + ) + self._set_language_and_task( + language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config + ) # pass self.config for backward compatibility - self._set_forced_decoder_ids(task=task, language=language, prompt_ids=prompt_ids, generation_config=generation_config, config=self.config, kwargs=kwargs) + self._set_forced_decoder_ids( + task=task, + language=language, + prompt_ids=prompt_ids, + generation_config=generation_config, + config=self.config, + kwargs=kwargs, + ) self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs) - self._set_num_frames(return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs) + self._set_num_frames( + return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs + ) # 4. Retrieve logits processors # => TODO(Patrick): The way `num_start_tokens` is retrieved here is too brittle. Need a better approach - num_start_tokens = len(generation_config.forced_decoder_ids) if generation_config.forced_decoder_ids is not None else 1 - logits_processor = self._retrieve_logit_processors(generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, num_start_tokens=num_start_tokens, is_shortform=is_shortform, num_beams=kwargs.get("num_beams", 1)) + num_start_tokens = ( + len(generation_config.forced_decoder_ids) if generation_config.forced_decoder_ids is not None else 1 + ) + logits_processor = self._retrieve_logit_processors( + generation_config=generation_config, + logits_processor=logits_processor, + no_speech_threshold=no_speech_threshold, + num_start_tokens=num_start_tokens, + is_shortform=is_shortform, + num_beams=kwargs.get("num_beams", 1), + ) # 5. If we're in shortform mode, simple generate the whole input at once and return the output if is_shortform: @@ -443,14 +473,18 @@ def generate( # We need to chunk the audio input depending on when the model generates timestamp tokens # 6.1 Set and retrieve global longform generation variables - self._set_condition_on_prev_tokens(condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config) + self._set_condition_on_prev_tokens( + condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config + ) timestamp_begin = generation_config.no_timestamps_token_id + 1 temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature temperature = temperatures[0] batch_size = input_features.shape[0] - max_frames, seek = self._retrieve_max_frames_and_seek(batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames) + max_frames, seek = self._retrieve_max_frames_and_seek( + batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames + ) init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config) # 6.2 Preppare running variables, list for generation @@ -465,21 +499,52 @@ def generate( # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order # to know which original audio is being decoded # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk - input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(input_features=input_features, seek=seek, max_frames=max_frames, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map) + input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch( + input_features=input_features, + seek=seek, + max_frames=max_frames, + cur_bsz=cur_bsz, + batch_idx_map=batch_idx_map, + ) time_offset = seek * time_precision / input_stride seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) # 6.4 cut out next 30s segment from input features - segment_input = self._get_input_segment(input_features=input_features, seek=seek, seek_num_frames=seek_num_frames, num_segment_frames=num_segment_frames, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map) + segment_input = self._get_input_segment( + input_features=input_features, + seek=seek, + seek_num_frames=seek_num_frames, + num_segment_frames=num_segment_frames, + cur_bsz=cur_bsz, + batch_idx_map=batch_idx_map, + ) # 6.5 prepare decoder input ids # TODO(Patrick) - clean up prev_start_of_text - suppress_tokens = self._get_attr_from_logit_processors(logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens") + suppress_tokens = self._get_attr_from_logit_processors( + logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens" + ) prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None - decoder_input_ids, kwargs = self._prepare_decoder_input_ids(cur_bsz=cur_bsz, init_tokens=init_tokens, current_segments=current_segments, batch_idx_map=batch_idx_map, do_condition_on_prev_tokens=do_condition_on_prev_tokens, generation_config=generation_config, config=self.config, device=segment_input.device, kwargs=kwargs, prev_start_of_text=prev_start_of_text) + decoder_input_ids, kwargs = self._prepare_decoder_input_ids( + cur_bsz=cur_bsz, + init_tokens=init_tokens, + current_segments=current_segments, + batch_idx_map=batch_idx_map, + do_condition_on_prev_tokens=do_condition_on_prev_tokens, + generation_config=generation_config, + config=self.config, + device=segment_input.device, + kwargs=kwargs, + prev_start_of_text=prev_start_of_text, + ) # 6.6 set max new tokens or max length - kwargs = self._set_max_new_tokens_and_length(config=self.config, decoder_input_ids=decoder_input_ids, generation_config=generation_config, kwargs=kwargs) + kwargs = self._set_max_new_tokens_and_length( + config=self.config, + decoder_input_ids=decoder_input_ids, + generation_config=generation_config, + kwargs=kwargs, + ) # 6.7 Set current `begin_index` for all logit processors for proc in logits_processor: @@ -487,7 +552,28 @@ def generate( proc.set_begin_index(decoder_input_ids.shape[-1]) # 6.8 Run generate with fallback - seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(segment_input=segment_input, decoder_input_ids=decoder_input_ids, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, seek=seek, num_segment_frames=num_segment_frames, max_frames=max_frames, temperatures=temperatures, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold, do_condition_on_prev_tokens=do_condition_on_prev_tokens, condition_on_prev_tokens=condition_on_prev_tokens, kwargs=kwargs) + seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback( + segment_input=segment_input, + decoder_input_ids=decoder_input_ids, + cur_bsz=cur_bsz, + batch_idx_map=batch_idx_map, + seek=seek, + num_segment_frames=num_segment_frames, + max_frames=max_frames, + temperatures=temperatures, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + return_token_timestamps=return_token_timestamps, + compression_ratio_threshold=compression_ratio_threshold, + logprob_threshold=logprob_threshold, + no_speech_threshold=no_speech_threshold, + do_condition_on_prev_tokens=do_condition_on_prev_tokens, + condition_on_prev_tokens=condition_on_prev_tokens, + kwargs=kwargs, + ) # 6.9 In every generated sequence, split by timestamp tokens and extract segments for i, seek_sequence in enumerate(seek_sequences): @@ -522,7 +608,29 @@ def generate( return sequences - def generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batch_idx_map, seek, num_segment_frames, max_frames, temperatures, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, compression_ratio_threshold, logprob_threshold, no_speech_threshold, do_condition_on_prev_tokens, condition_on_prev_tokens, kwargs): + def generate_with_fallback( + self, + segment_input, + decoder_input_ids, + cur_bsz, + batch_idx_map, + seek, + num_segment_frames, + max_frames, + temperatures, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_token_timestamps, + compression_ratio_threshold, + logprob_threshold, + no_speech_threshold, + do_condition_on_prev_tokens, + condition_on_prev_tokens, + kwargs, + ): # 6.6 Batch generate current chunk seek_sequence_list = [None for _ in range(cur_bsz)] seek_outputs_list = [None for _ in range(cur_bsz)] @@ -550,10 +658,12 @@ def generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batc ) # post-process sequence tokens and outputs to be in list form - sequence_tokens, seek_outputs = self._postprocess_outputs(seek_outputs, return_token_timestamps, generation_config) + sequence_tokens, seek_outputs = self._postprocess_outputs( + seek_outputs, return_token_timestamps, generation_config + ) # remove all previously passed decoder input ids - seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1]:] + seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1] :] # 6.7 Extract cut sequences from every sequence and check if fallback should be applied # Loop over each decoded audio individually as each decoding can be of a different length @@ -577,18 +687,31 @@ def generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batc seek_sequence = seek_sequence[:-num_paddings] # check which sequences in batch need fallback & which should be skipped - needs_fallback[i], should_skip[i] = self._need_fallback(seek_sequence, seek_outputs, i, logits_processor, compression_ratio_threshold, logprob_threshold, no_speech_threshold, self.config.vocab_size, generation_config.eos_token_id, temperature) + needs_fallback[i], should_skip[i] = self._need_fallback( + seek_sequence, + seek_outputs, + i, + logits_processor, + compression_ratio_threshold, + logprob_threshold, + no_speech_threshold, + self.config.vocab_size, + generation_config.eos_token_id, + temperature, + ) seek_sequence_list[fallback_index_map[i]] = seek_sequence seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] - do_condition_on_prev_tokens[fallback_index_map[i]] = condition_on_prev_tokens and temperature is not None and temperature < 0.5 + do_condition_on_prev_tokens[fallback_index_map[i]] = ( + condition_on_prev_tokens and temperature is not None and temperature < 0.5 + ) if needs_fallback[i]: new_fallback_index_map.append(fallback_index_map[i]) new_segment_input.append(segment_input[i]) new_decoder_input_ids.append(decoder_input_ids[i]) if "decoder_attention_mask" in kwargs: - new_decoder_attention_mask.append(kwargs['decoder_attention_mask'][i]) + new_decoder_attention_mask.append(kwargs["decoder_attention_mask"][i]) fallback_index_map = new_fallback_index_map @@ -614,6 +737,7 @@ def _postprocess_outputs(self, seek_outputs, return_token_timestamps, generation ) if generation_config.return_dict_in_generate: + def split_by_batch_index(values, key, batch_idx): if key == "scores": return [v[batch_idx].cpu() for v in values] @@ -623,14 +747,28 @@ def split_by_batch_index(values, key, batch_idx): return values[batch_idx].cpu() sequence_tokens = seek_outputs["sequences"] - seek_outputs = [{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0])] + seek_outputs = [ + {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} + for i in range(sequence_tokens.shape[0]) + ] else: sequence_tokens = seek_outputs return sequence_tokens, seek_outputs @staticmethod - def _need_fallback(seek_sequence, seek_outputs, index, logits_processor, compression_ratio_threshold, logprob_threshold, no_speech_threshold, vocab_size, eos_token_id, temperature): + def _need_fallback( + seek_sequence, + seek_outputs, + index, + logits_processor, + compression_ratio_threshold, + logprob_threshold, + no_speech_threshold, + vocab_size, + eos_token_id, + temperature, + ): needs_fallback = False should_skip = False if compression_ratio_threshold is not None: @@ -644,13 +782,17 @@ def _need_fallback(seek_sequence, seek_outputs, index, logits_processor, compres logprobs = [s["sequences_scores"] for s in seek_outputs][index] else: scores = seek_outputs[index]["scores"] - logprobs = WhisperGenerationMixin._retrieve_avg_logprobs(scores, seek_sequence, eos_token_id, temperature) + logprobs = WhisperGenerationMixin._retrieve_avg_logprobs( + scores, seek_sequence, eos_token_id, temperature + ) if logprobs < logprob_threshold: needs_fallback = True if no_speech_threshold is not None: - no_speech_prob = WhisperGenerationMixin._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "no_speech_prob") + no_speech_prob = WhisperGenerationMixin._get_attr_from_logit_processors( + logits_processor, WhisperNoSpeechDetection, "no_speech_prob" + ) if logprobs < logprob_threshold and no_speech_prob[index] > no_speech_threshold: needs_fallback = False @@ -658,17 +800,13 @@ def _need_fallback(seek_sequence, seek_outputs, index, logits_processor, compres return needs_fallback, should_skip - @staticmethod def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs): - set_inputs = WhisperGenerationMixin._get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs") + set_inputs = WhisperGenerationMixin._get_attr_from_logit_processors( + logits_processor, WhisperNoSpeechDetection, "set_inputs" + ) extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)} - set_inputs({ - "inputs": segment_input, - "decoder_input_ids": decoder_input_ids, - **extra_kwargs - }) - + set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs}) @staticmethod def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): @@ -693,7 +831,9 @@ def _retrieve_total_input_frames(input_features, input_stride, kwargs): raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") @staticmethod - def _set_return_outputs(return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config): + def _set_return_outputs( + return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config + ): if return_dict_in_generate is None: return_dict_in_generate = generation_config.return_dict_in_generate @@ -783,10 +923,7 @@ def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, confi # Legacy code for backward compatibility if hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None: forced_decoder_ids = config.forced_decoder_ids - elif ( - hasattr(generation_config, "forced_decoder_ids") - and generation_config.forced_decoder_ids is not None - ): + elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None: forced_decoder_ids = generation_config.forced_decoder_ids else: forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) @@ -875,10 +1012,14 @@ def _set_token_ids(generation_config, config, kwargs): decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id - decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id + ) generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id - generation_config.decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id + generation_config.decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id + ) @staticmethod def _set_num_frames(return_token_timestamps, generation_config, kwargs): @@ -933,7 +1074,9 @@ def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): return init_tokens - def _retrieve_logit_processors(self, generation_config, logits_processor, no_speech_threshold, num_start_tokens, is_shortform, num_beams): + def _retrieve_logit_processors( + self, generation_config, logits_processor, no_speech_threshold, num_start_tokens, is_shortform, num_beams + ): begin_index = 1 if generation_config.return_timestamps is True: forced_decoder_ids = generation_config.forced_decoder_ids @@ -955,19 +1098,30 @@ def _retrieve_logit_processors(self, generation_config, logits_processor, no_spe if generation_config.suppress_tokens is not None: suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens) logits_processor = ( - [suppress_tokens_processor] if logits_processor is None else [suppress_tokens_processor] + logits_processor + [suppress_tokens_processor] + if logits_processor is None + else [suppress_tokens_processor] + logits_processor ) generation_config.suppress_tokens = None if generation_config.begin_suppress_tokens is not None: - begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index=begin_index) + begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor( + generation_config.begin_suppress_tokens, begin_index=begin_index + ) logits_processor = ( - [begin_suppress_processor] if logits_processor is None else [begin_suppress_processor] + logits_processor + [begin_suppress_processor] + if logits_processor is None + else [begin_suppress_processor] + logits_processor ) generation_config.begin_suppress_tokens = None if no_speech_threshold is not None and not is_shortform: - no_speech_detector = WhisperNoSpeechDetection(no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, begin_index_offset=num_start_tokens, scores_is_logprobs=num_beams > 1) + no_speech_detector = WhisperNoSpeechDetection( + no_speech_token=generation_config.no_timestamps_token_id - 1, + begin_index=begin_index, + begin_index_offset=num_start_tokens, + scores_is_logprobs=num_beams > 1, + ) logits_processor = ( [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor ) @@ -996,9 +1150,7 @@ def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames segment_input = [] for i in range(cur_bsz): prev_i = batch_idx_map[i] - segment_input_slice = input_features[ - i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i] - ] + segment_input_slice = input_features[i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]] if segment_input_slice.shape[-1] < num_segment_frames: # pad to 3000 if necessary @@ -1014,7 +1166,18 @@ def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames # TODO(Patrick) - remove prev_start_of_text @staticmethod - def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx_map, do_condition_on_prev_tokens, generation_config, config, device, kwargs, prev_start_of_text): + def _prepare_decoder_input_ids( + cur_bsz, + init_tokens, + current_segments, + batch_idx_map, + do_condition_on_prev_tokens, + generation_config, + config, + device, + kwargs, + prev_start_of_text, + ): cut_off_length = config.max_target_positions // 2 - 1 one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long) @@ -1028,11 +1191,15 @@ def _prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx bos_token_tensor = prev_start_of_text * one_tensor[0] prev_tokens = WhisperGenerationMixin._pad_to_max_length( - active_segments, generation_config.pad_token_id, padding="left", bos_token_tensor=bos_token_tensor, cut_off_length=cut_off_length + active_segments, + generation_config.pad_token_id, + padding="left", + bos_token_tensor=bos_token_tensor, + cut_off_length=cut_off_length, ) decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) - kwargs["decoder_attention_mask"] = (decoder_input_ids != generation_config.pad_token_id) + kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id else: # make sure `"decoder_attention_mask"` is not passed to forward kwargs.pop("decoder_attention_mask", None) @@ -1053,16 +1220,12 @@ def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, # Make sure we don't get larger than `max_length` if passed_max_length is not None and passed_max_new_tokens is None: - max_length = min( - passed_max_length + num_initial_tokens, config.max_target_positions - ) + max_length = min(passed_max_length + num_initial_tokens, config.max_target_positions) logger.info( f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment." ) elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: - max_length = min( - generation_config.max_length + num_initial_tokens, config.max_target_positions - ) + max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions) logger.info( f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment." ) @@ -1087,7 +1250,9 @@ def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, return kwargs @staticmethod - def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None): + def _pad_to_max_length( + current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None + ): max_total_length = 0 sequences = [] if padding not in ["right", "left"]: @@ -1118,9 +1283,9 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_toke @staticmethod def _retrieve_compression_ratio(tokens, vocab_size): - """ Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes """ + """Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes""" length = int(math.log2(vocab_size) / 8) + 1 - token_bytes = b''.join([t.to_bytes(length, 'little') for t in tokens.tolist()]) + token_bytes = b"".join([t.to_bytes(length, "little") for t in tokens.tolist()]) compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes)) return compression_ratio @@ -1131,9 +1296,9 @@ def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): scores = torch.stack(scores).to(tokens.device) if scores.shape[0] > tokens.shape[0]: - scores = scores[:tokens.shape[0]] + scores = scores[: tokens.shape[0]] else: - tokens = tokens[-scores.shape[0]:] + tokens = tokens[-scores.shape[0] :] logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype) @@ -1175,7 +1340,7 @@ def _retrieve_segment( last_slice = 0 # Add each segment to list of all segments for current_slice in slices: - sliced_tokens = seek_sequence[last_slice : current_slice] + sliced_tokens = seek_sequence[last_slice:current_slice] start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin segments.append( @@ -1217,4 +1382,3 @@ def _retrieve_segment( segment_offset = seek_num_frames[prev_idx] return segments, segment_offset - diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 7351988790f3..c204f72723c4 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Whisper model.""" -RUN_NEW_WAY = True - import math from typing import Optional, Tuple, Union diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 82428b632ea9..af1c601be4a0 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -97,8 +97,6 @@ def set_begin_index(self, begin_index: int): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # we don't want to randomely sample timestamp tokens - orig_scores = scores.clone() - if input_ids.shape[-1] != self.begin_index: scores[:, self.timestamp_begin :] = -float("inf") @@ -1437,7 +1435,13 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): seed=1, ) ] - outputs_2 = model.generate(input_features_2, max_new_tokens=max_new_tokens, logits_processor=logits_processor, condition_on_prev_tokens=condition_on_prev_tokens, return_segments=True) + outputs_2 = model.generate( + input_features_2, + max_new_tokens=max_new_tokens, + logits_processor=logits_processor, + condition_on_prev_tokens=condition_on_prev_tokens, + return_segments=True, + ) tokens_2 = outputs_2["sequences"][0] segments_2 = outputs_2["segments"][0] @@ -1465,7 +1469,7 @@ def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): segments = outputs["segments"][1] # make sure batched and non-batched is the same - assert tokens_2.tolist() == tokens[:tokens_2.shape[-1]].tolist() + assert tokens_2.tolist() == tokens[: tokens_2.shape[-1]].tolist() for seg1, seg2 in zip(segments_2, segments): assert seg1["start"] == seg2["start"] From bf7ee487e393b28099c58befa62dd8dffd9557b1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 3 Jan 2024 16:08:50 +0100 Subject: [PATCH 51/75] add docstrings --- .../models/whisper/generation_whisper.py | 75 ++++++++++++++----- 1 file changed, 56 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 56b9074397f9..3cf016cfe184 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -17,19 +17,22 @@ import math import warnings import zlib -from typing import Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn +from ...generation.configuration import GenerationConfig from ...generation.logits_process import ( + LogitsProcessorList, SuppressTokensAtBeginLogitsProcessor, SuppressTokensLogitsProcessor, WhisperNoSpeechDetection, WhisperTimeStampLogitsProcessor, ) +from ...generation.stopping_criteria import StoppingCriteriaList from ...modeling_outputs import BaseModelOutput from ...utils import logging from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE @@ -210,26 +213,26 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec def generate( self, input_features: Optional[torch.Tensor] = None, - generation_config=None, - logits_processor=None, - stopping_criteria=None, - prefix_allowed_tokens_fn=None, - synced_gpus=False, - return_timestamps=None, - task=None, - language=None, - is_multilingual=None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: bool = False, + return_timestamps: Optional[bool] = None, + task: Optional[str] = None, + language: Optional[str] = None, + is_multilingual: Optional[bool] = None, + prompt_ids: Optional[torch.Tensor] = None, condition_on_prev_tokens: Optional[bool] = None, no_speech_threshold: Optional[float] = None, temperature: Optional[Union[float, Tuple[float, ...]]] = None, compression_ratio_threshold: Optional[float] = None, logprob_threshold: Optional[float] = None, - prompt_ids: Optional[torch.Tensor] = None, num_segment_frames: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + time_precision: float = 0.02, return_token_timestamps: Optional[bool] = None, return_segments: bool = False, - attention_mask: Optional[torch.Tensor] = None, - time_precision: int = 0.02, return_dict_in_generate: Optional[bool] = None, **kwargs, ): @@ -248,7 +251,7 @@ def generate( Parameters: - inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + input_features (`torch.Tensor` of varying shape depending on the modality, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of @@ -292,6 +295,45 @@ def generate( provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. + condition_on_prev_tokens (`bool`, *optional*): + Only relevant for long-form transcription. Whether to condition each segment on the previous segment. + As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + no_speech_threshold (`float`, *optional*): + Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold` + is used to determine whether a segment contains only silence. In this case, the transcription for this segment + is skipped. + As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + temperature (`float` or list of `float`, *optional*): + The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates + generation using sampling. For long-form transcription, temperature fallback can be activated by passing + a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8). As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + compression_ratio_threshold (`float`, *optional*): + Only relevant for long-form transcription. If defined, the zlib compression rate of each segment will be computed. If the compression rate of + a segment is higher than `compression_ratio_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is + repeated using a higher temperature. The intuition behind this feature is that segments with very high compression rates + suffer from a lot of repetition. The unwanted repetition can be reduced by injecting more randomness by increasing the temperature. If `compression_ratio_threshold` is defined + make sure that `temperature` is a list of values. A common value for `compression_ratio_threshold` is 1.35. + As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + logprob_threshold (`float`, *optional*): + Only relevant for long-form transcription. If defined, the average log-probability of each segment will be computed. If the log-probability of + a given segment is lower than `logprob_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is + repeated using a higher temperature. The intuition behind this feature is that segments of low logprobability + can be improved by injecting more randomness by increasing the temperature. If `logprob_threshold` is defined + make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0. + As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + num_segment_frames (`int`, *optional*): + The number of frames a single segment is made of. If not defined, `num_segment_frames` defaults to the model's stride + times the maximum input length. + attention_mask (`torch.Tensor`, *optional*): + `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1. + time_precision (`int`, *optional*, defaults to 0.02): + The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts + for 20 ms. return_token_timestamps (`bool`, *optional*): Whether to return token-level timestamps with the text. This can be used with or without the `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into @@ -299,11 +341,6 @@ def generate( return_segments (`bool`, *optional*, defaults to `False`): Whether to additionally return a list of all segments. Note that this option can only be enabled when doing long-form transcription. - attention_mask (`torch.Tensor`, *optional*): - `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1. - time_precision (`int`, *optional*, defaults to 0.02): - The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts - for 20 ms. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens. Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when From 8096e4a700ab27098cf0f1d66cb3f4752f738b9b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 3 Jan 2024 23:28:55 +0100 Subject: [PATCH 52/75] add docstrings --- src/transformers/models/whisper/generation_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 3cf016cfe184..71f90045d033 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from torch import nn -from ...generation.configuration import GenerationConfig +from ...generation.configuration_utils import GenerationConfig from ...generation.logits_process import ( LogitsProcessorList, SuppressTokensAtBeginLogitsProcessor, From f667fc4ec5779b5ec5ce059210443680342d1752 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Jan 2024 14:55:49 +0000 Subject: [PATCH 53/75] Fix logit processor --- .../models/whisper/generation_whisper.py | 27 ++++++++++++++----- .../models/whisper/modeling_whisper.py | 3 --- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 71f90045d033..0cff7fe3d649 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -30,6 +30,7 @@ SuppressTokensAtBeginLogitsProcessor, SuppressTokensLogitsProcessor, WhisperNoSpeechDetection, + ForceTokensLogitsProcessor, WhisperTimeStampLogitsProcessor, ) from ...generation.stopping_criteria import StoppingCriteriaList @@ -488,14 +489,16 @@ def generate( # 5. If we're in shortform mode, simple generate the whole input at once and return the output if is_shortform: + if temperature is not None: + kwargs["temperature"] = temperature + outputs = super().generate( input_features, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - temperature=temperature, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, **kwargs, ) @@ -1164,6 +1167,18 @@ def _retrieve_logit_processors( ) no_speech_detector.set_model(self) + if is_shortform and generation_config.forced_decoder_ids is not None: + forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids) + # TODO(Patrick): It's important that the `forced_tokens_proc` processor is appended after + # the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf + # which would lead to unexpected behavior + # The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead + # initialize all of them as `decoder_input_ids`. + logits_processor = ( + [forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc] + ) + generation_config.forced_decoder_ids = None + return logits_processor @staticmethod diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8db26a8f2917..ee38a88bc5fd 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -46,9 +46,6 @@ from .generation_whisper import WhisperGenerationMixin -# tok = AutoTokenizer.from_pretrained("openai/whisper-tiny") - - if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa From 33b1903cb3d5dd966c7290195276c3ebea375bea Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Jan 2024 16:27:56 +0000 Subject: [PATCH 54/75] make style --- src/transformers/models/whisper/generation_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 0cff7fe3d649..f0a7f4ca4313 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -26,11 +26,11 @@ from ...generation.configuration_utils import GenerationConfig from ...generation.logits_process import ( + ForceTokensLogitsProcessor, LogitsProcessorList, SuppressTokensAtBeginLogitsProcessor, SuppressTokensLogitsProcessor, WhisperNoSpeechDetection, - ForceTokensLogitsProcessor, WhisperTimeStampLogitsProcessor, ) from ...generation.stopping_criteria import StoppingCriteriaList @@ -1170,9 +1170,9 @@ def _retrieve_logit_processors( if is_shortform and generation_config.forced_decoder_ids is not None: forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids) # TODO(Patrick): It's important that the `forced_tokens_proc` processor is appended after - # the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf + # the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf # which would lead to unexpected behavior - # The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead + # The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead # initialize all of them as `decoder_input_ids`. logits_processor = ( [forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc] From 76134ec5d804e4aee25b56baa6c41609a191ae7b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Jan 2024 17:11:18 +0000 Subject: [PATCH 55/75] fix pipeline test --- tests/pipelines/test_pipelines_automatic_speech_recognition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 3da55ab9da10..5b8480408f02 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -1135,7 +1135,7 @@ def test_with_local_lm_fast(self): @slow def test_whisper_longform(self): # fmt: off - EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out of fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct denny's, set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!""" + EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile.""" # fmt: on processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") From 8f00ba3f8e34ed86aa7b8b7a4ec3e9dd51270a74 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Jan 2024 20:44:53 +0000 Subject: [PATCH 56/75] fix more style --- .../research_projects/jax-projects/big_bird/bigbird_flax.py | 2 +- examples/research_projects/jax-projects/big_bird/train.py | 2 +- examples/research_projects/vqgan-clip/VQGAN_CLIP.py | 2 +- src/transformers/generation/logits_process.py | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py index af5e11c83a6a..c171b88800ed 100644 --- a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py +++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py @@ -9,13 +9,13 @@ import jax.numpy as jnp import joblib import optax -import wandb from flax import jax_utils, struct, traverse_util from flax.serialization import from_bytes, to_bytes from flax.training import train_state from flax.training.common_utils import shard from tqdm.auto import tqdm +import wandb from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule diff --git a/examples/research_projects/jax-projects/big_bird/train.py b/examples/research_projects/jax-projects/big_bird/train.py index ce37b7f975bb..3840918d16ae 100644 --- a/examples/research_projects/jax-projects/big_bird/train.py +++ b/examples/research_projects/jax-projects/big_bird/train.py @@ -2,11 +2,11 @@ from dataclasses import replace import jax -import wandb from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step from datasets import load_dataset from flax import jax_utils +import wandb from transformers import BigBirdTokenizerFast diff --git a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py index 1bfbc4cd5c36..2a39955e347f 100644 --- a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py +++ b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py @@ -4,12 +4,12 @@ import imageio import torch import torchvision -import wandb from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan from loaders import load_vqgan from PIL import Image from torch import nn +import wandb from transformers import CLIPModel, CLIPTokenizerFast from utils import get_device, get_timestamp, show_pil diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 0afc0608b4f1..1f766b51831d 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1782,6 +1782,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): max_initial_timestamp_index (`int`, *optional*, defaults to 1): Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting timestamps that are too far in the future. + begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model. _detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps. Examples: From 5634a64a4a30a31160cceb95c1ae6da21c9d4f72 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Jan 2024 21:50:48 +0100 Subject: [PATCH 57/75] Apply suggestions from code review --- .../research_projects/jax-projects/big_bird/bigbird_flax.py | 2 +- examples/research_projects/jax-projects/big_bird/train.py | 2 +- examples/research_projects/vqgan-clip/VQGAN_CLIP.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py index c171b88800ed..af5e11c83a6a 100644 --- a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py +++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py @@ -9,13 +9,13 @@ import jax.numpy as jnp import joblib import optax +import wandb from flax import jax_utils, struct, traverse_util from flax.serialization import from_bytes, to_bytes from flax.training import train_state from flax.training.common_utils import shard from tqdm.auto import tqdm -import wandb from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule diff --git a/examples/research_projects/jax-projects/big_bird/train.py b/examples/research_projects/jax-projects/big_bird/train.py index 3840918d16ae..ce37b7f975bb 100644 --- a/examples/research_projects/jax-projects/big_bird/train.py +++ b/examples/research_projects/jax-projects/big_bird/train.py @@ -2,11 +2,11 @@ from dataclasses import replace import jax +import wandb from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step from datasets import load_dataset from flax import jax_utils -import wandb from transformers import BigBirdTokenizerFast diff --git a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py index 2a39955e347f..1bfbc4cd5c36 100644 --- a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py +++ b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py @@ -4,12 +4,12 @@ import imageio import torch import torchvision +import wandb from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan from loaders import load_vqgan from PIL import Image from torch import nn -import wandb from transformers import CLIPModel, CLIPTokenizerFast from utils import get_device, get_timestamp, show_pil From 530c246ce839f12ac6bb466008441791a62f2ac5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 10:58:51 +0200 Subject: [PATCH 58/75] apply feedback Sanchit --- src/transformers/generation/logits_process.py | 12 ++++++------ .../models/whisper/generation_whisper.py | 10 ++-------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 9217002624f9..9b31346bac92 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1893,12 +1893,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class WhisperNoSpeechDetection(LogitsProcessor): r"""This processor can be used to detect silence when using Whisper.""" - def __init__( - self, no_speech_token: int, begin_index: int, begin_index_offset: int, scores_is_logprobs: bool = False - ): + def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False): self.no_speech_token = no_speech_token + self.orig_begin_index = begin_index + + # `self.begin_index` is a running value that is changed on the fly self.begin_index = begin_index - self.begin_index_offset = begin_index_offset self._no_speech_prob = [0.0] self.is_scores_logprobs = scores_is_logprobs @@ -1926,11 +1926,11 @@ def set_begin_index(self, begin_index): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if input_ids.shape[1] == self.begin_index: - if self.begin_index_offset > 1: + if self.orig_begin_index > 1: with torch.no_grad(): logits = self.model(**self.inputs).logits - no_speech_index = self.begin_index - self.begin_index_offset + no_speech_index = self.begin_index - self.orig_begin_index no_speech_scores = logits[:, no_speech_index] else: no_speech_scores = scores diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 76ae905145d5..b073addd89c1 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -472,15 +472,10 @@ def generate( ) # 4. Retrieve logits processors - # => TODO(Patrick): The way `num_start_tokens` is retrieved here is too brittle. Need a better approach - num_start_tokens = ( - len(generation_config.forced_decoder_ids) if generation_config.forced_decoder_ids is not None else 1 - ) logits_processor = self._retrieve_logit_processors( generation_config=generation_config, logits_processor=logits_processor, no_speech_threshold=no_speech_threshold, - num_start_tokens=num_start_tokens, is_shortform=is_shortform, num_beams=kwargs.get("num_beams", 1), ) @@ -1115,7 +1110,6 @@ def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): def _retrieve_logit_processors( self, generation_config, logits_processor, no_speech_threshold, num_start_tokens, is_shortform, num_beams ): - begin_index = 1 if generation_config.return_timestamps is True: forced_decoder_ids = generation_config.forced_decoder_ids last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None @@ -1126,8 +1120,9 @@ def _retrieve_logit_processors( # Make sure that if list is empty we set it to None generation_config.forced_decoder_ids = forced_decoder_ids - begin_index = begin_index + len(forced_decoder_ids) if forced_decoder_ids is not None else begin_index + begin_index = len(forced_decoder_ids) + 1 if forced_decoder_ids is not None else 1 + if generation_config.return_timestamps is True: timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) logits_processor = ( [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor @@ -1157,7 +1152,6 @@ def _retrieve_logit_processors( no_speech_detector = WhisperNoSpeechDetection( no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, - begin_index_offset=num_start_tokens, scores_is_logprobs=num_beams > 1, ) logits_processor = ( From d1662afbabc0c6c10b80f212b0f21e4080e75cc1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 09:11:34 +0000 Subject: [PATCH 59/75] correct more --- src/transformers/generation/logits_process.py | 8 +++++--- src/transformers/models/whisper/generation_whisper.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 9b31346bac92..76fb7d2ee69c 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1895,7 +1895,9 @@ class WhisperNoSpeechDetection(LogitsProcessor): def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False): self.no_speech_token = no_speech_token - self.orig_begin_index = begin_index + # offset between token, , in paper and first generated token + # is equal to the position of the first generated token index + self.start_of_trans_offset = begin_index # `self.begin_index` is a running value that is changed on the fly self.begin_index = begin_index @@ -1926,11 +1928,11 @@ def set_begin_index(self, begin_index): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if input_ids.shape[1] == self.begin_index: - if self.orig_begin_index > 1: + if self.start_of_trans_offset > 1: with torch.no_grad(): logits = self.model(**self.inputs).logits - no_speech_index = self.begin_index - self.orig_begin_index + no_speech_index = self.begin_index - self.start_of_trans_offset no_speech_scores = logits[:, no_speech_index] else: no_speech_scores = scores diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index b073addd89c1..14ceb0b59ea5 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1108,7 +1108,7 @@ def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): return init_tokens def _retrieve_logit_processors( - self, generation_config, logits_processor, no_speech_threshold, num_start_tokens, is_shortform, num_beams + self, generation_config, logits_processor, no_speech_threshold, is_shortform, num_beams ): if generation_config.return_timestamps is True: forced_decoder_ids = generation_config.forced_decoder_ids From d22a9b3c072e808725217fcb24c0d22e68a2be3e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 11:13:44 +0200 Subject: [PATCH 60/75] Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- .../models/whisper/generation_whisper.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 14ceb0b59ea5..bc576df67b78 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -238,7 +238,7 @@ def generate( **kwargs, ): """ - Transcribes or translates passed mel input features to a sequence of token ids. + Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids. @@ -252,11 +252,12 @@ def generate( Parameters: - input_features (`torch.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the - method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of - `input_ids`, `input_values`, `input_features`, or `pixel_values`. + input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*): + Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details. generation_config (`~generation.GenerationConfig`, *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If From c34574f3d459547279b7a02607e31923d69d5a63 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 11:17:19 +0200 Subject: [PATCH 61/75] Apply suggestions from code review Co-authored-by: Joao Gante Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- .../models/whisper/generation_whisper.py | 9 +-- tests/models/whisper/test_modeling_whisper.py | 58 ++++++++----------- 2 files changed, 28 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index bc576df67b78..e137fe67b013 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -310,7 +310,7 @@ def generate( temperature (`float` or list of `float`, *optional*): The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates generation using sampling. For long-form transcription, temperature fallback can be activated by passing - a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8). As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8, 1.0). As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve performance. compression_ratio_threshold (`float`, *optional*): Only relevant for long-form transcription. If defined, the zlib compression rate of each segment will be computed. If the compression rate of @@ -323,7 +323,7 @@ def generate( logprob_threshold (`float`, *optional*): Only relevant for long-form transcription. If defined, the average log-probability of each segment will be computed. If the log-probability of a given segment is lower than `logprob_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is - repeated using a higher temperature. The intuition behind this feature is that segments of low logprobability + repeated using a higher temperature. The intuition behind this feature is that segments of low log-probability can be improved by injecting more randomness by increasing the temperature. If `logprob_threshold` is defined make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0. As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve @@ -555,9 +555,7 @@ def generate( # 6.5 prepare decoder input ids # TODO(Patrick) - clean up prev_start_of_text - suppress_tokens = self._get_attr_from_logit_processors( - logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens" - ) + suppress_tokens = generation_config.suppress_tokens prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None decoder_input_ids, kwargs = self._prepare_decoder_input_ids( cur_bsz=cur_bsz, @@ -1228,7 +1226,6 @@ def _prepare_decoder_input_ids( one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long) decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) - # if condition_on_prev_tokens and len(current_segments[0]) > 0 and temperature < 0.5: if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map] diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index dc2f310554e1..97a01ccefab0 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2148,12 +2148,14 @@ def test_whisper_longform_single_batch_prev_cond(self): ] input_features = input_features.to(device="cuda") - gen_kwargs = {"return_timestamps": True} - gen_kwargs["no_speech_threshold"] = 0.6 - gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) - gen_kwargs["compression_ratio_threshold"] = 1.35 - gen_kwargs["condition_on_prev_tokens"] = True - gen_kwargs["logprob_threshold"] = -1.0 + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, + } torch.manual_seed(0) result = model.generate(input_features, **gen_kwargs) @@ -2231,12 +2233,14 @@ def test_whisper_longform_multi_batch_prev_cond(self): audios.append(one_audio[80000:]) audios.append(one_audio[:]) - gen_kwargs = {"return_timestamps": True} - gen_kwargs["no_speech_threshold"] = 0.6 - gen_kwargs["temperature"] = 0.0 - gen_kwargs["compression_ratio_threshold"] = 1.35 - gen_kwargs["condition_on_prev_tokens"] = True - gen_kwargs["logprob_threshold"] = -1.0 + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": 0.0, + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, + } with open("/home/patrick/expected.txt", "w") as f: decoded_single = [] @@ -2246,21 +2250,7 @@ def test_whisper_longform_multi_batch_prev_cond(self): result = model.generate(**inputs, **gen_kwargs) decoded_single.append(processor.batch_decode(result, skip_special_tokens=True)) - f.write(decoded_single[-1][0] + "\n") - # inputs = processor( - # audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True - # ) - # inputs = inputs.to(device="cuda") - - # result = model.generate(**inputs, **gen_kwargs) - # decoded_all = processor.batch_decode(result, skip_special_tokens=True) - - # # make sure single & batch is exactly the same - # assert decoded_all[0:1] == decoded_single[0] - # assert decoded_all[1:2] == decoded_single[1] - # assert decoded_all[2:3] == decoded_single[2] - # assert decoded_all[3:4] == decoded_single[3] # exact match assert decoded_single[0] == EXPECTED_TEXT_1 @@ -2347,13 +2337,15 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): ) inputs = inputs.to(device="cuda") - gen_kwargs = {"return_timestamps": True} - gen_kwargs["no_speech_threshold"] = 0.6 - gen_kwargs["compression_ratio_threshold"] = 1.35 - gen_kwargs["condition_on_prev_tokens"] = True - gen_kwargs["logprob_threshold"] = -1.0 - gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) - gen_kwargs["num_beams"] = 5 + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, + "num_beams": 5, + } torch.manual_seed(0) result = model.generate(**inputs, **gen_kwargs) From 0e7c86e302c395ba2d856e67faed1d5ba581d507 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 09:20:39 +0000 Subject: [PATCH 62/75] correct more --- src/transformers/models/whisper/generation_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index e137fe67b013..f1bf620a204d 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1109,8 +1109,8 @@ def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): def _retrieve_logit_processors( self, generation_config, logits_processor, no_speech_threshold, is_shortform, num_beams ): + forced_decoder_ids = generation_config.forced_decoder_ids if generation_config.return_timestamps is True: - forced_decoder_ids = generation_config.forced_decoder_ids last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None if last_forced_decoder_ids == generation_config.no_timestamps_token_id: # remove no_timestamp to be forcefully generated if we want to return timestamps From 67c2ea4645d78b7b1c1ef428e25e22ffb79558b7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 09:58:08 +0000 Subject: [PATCH 63/75] correct more --- .../models/whisper/generation_whisper.py | 153 +++++++++++------- 1 file changed, 95 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index f1bf620a204d..0ce8622110c8 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -117,6 +117,43 @@ def _dynamic_time_warping(matrix: np.ndarray): return text_indices, time_indices +def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): + logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) + if logit_processor: + return getattr(logit_processor, attribute_name, None) + return None + + +def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None): + max_total_length = 0 + sequences = [] + if padding not in ["right", "left"]: + raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}") + + for current_segment_list in current_segments: + if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0: + sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) + + if cut_off_length is not None: + sequence = sequence[-cut_off_length:] + + if bos_token_tensor is not None: + sequence = torch.cat([bos_token_tensor, sequence]) + + sequences.append(sequence) + max_total_length = max(max_total_length, len(sequences[-1])) + else: + sequences.append(bos_token_tensor) + + for i in range(len(current_segments)): + pad_length = max_total_length - len(sequences[i]) + pad = (0, pad_length) if padding == "right" else (pad_length, 0) + sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id) + + sequences = torch.stack(sequences, dim=0) + return sequences + + class WhisperGenerationMixin: def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None): """ @@ -225,10 +262,10 @@ def generate( is_multilingual: Optional[bool] = None, prompt_ids: Optional[torch.Tensor] = None, condition_on_prev_tokens: Optional[bool] = None, - no_speech_threshold: Optional[float] = None, temperature: Optional[Union[float, Tuple[float, ...]]] = None, compression_ratio_threshold: Optional[float] = None, logprob_threshold: Optional[float] = None, + no_speech_threshold: Optional[float] = None, num_segment_frames: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, time_precision: float = 0.02, @@ -301,12 +338,6 @@ def generate( Only relevant for long-form transcription. Whether to condition each segment on the previous segment. As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve performance. - no_speech_threshold (`float`, *optional*): - Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold` - is used to determine whether a segment contains only silence. In this case, the transcription for this segment - is skipped. - As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve - performance. temperature (`float` or list of `float`, *optional*): The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates generation using sampling. For long-form transcription, temperature fallback can be activated by passing @@ -328,6 +359,12 @@ def generate( make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0. As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve performance. + no_speech_threshold (`float`, *optional*): + Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold` + is used to determine whether a segment contains only silence. In this case, the transcription for this segment + is skipped. + As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. num_segment_frames (`int`, *optional*): The number of frames a single segment is made of. If not defined, `num_segment_frames` defaults to the model's stride times the maximum input length. @@ -443,6 +480,17 @@ def generate( ) is_shortform = total_input_frames <= num_segment_frames + if is_shortform: + # warn user of ignored inputs + self._maybe_warn_unused_inputs( + condition_on_prev_tokens=condition_on_prev_tokens, + temperature=temperature, + compression_ratio_threshold=compression_ratio_threshold, + logprob_threshold=logprob_threshold, + no_speech_threshold=no_speech_threshold, + total_input_frames=total_input_frames, + ) + # 3. Make sure generation config is correctly set # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not self._set_return_outputs( @@ -632,7 +680,7 @@ def generate( # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output - sequences = self._pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right") + sequences = _pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right") # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. if return_segments: @@ -788,8 +836,8 @@ def split_by_batch_index(values, key, batch_idx): return sequence_tokens, seek_outputs - @staticmethod def _need_fallback( + self, seek_sequence, seek_outputs, index, @@ -804,7 +852,7 @@ def _need_fallback( needs_fallback = False should_skip = False if compression_ratio_threshold is not None: - compression_ratio = WhisperGenerationMixin._retrieve_compression_ratio(seek_sequence, vocab_size) + compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size) if compression_ratio > compression_ratio_threshold: needs_fallback = True @@ -814,15 +862,13 @@ def _need_fallback( logprobs = [s["sequences_scores"] for s in seek_outputs][index] else: scores = seek_outputs[index]["scores"] - logprobs = WhisperGenerationMixin._retrieve_avg_logprobs( - scores, seek_sequence, eos_token_id, temperature - ) + logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, eos_token_id, temperature) if logprobs < logprob_threshold: needs_fallback = True if no_speech_threshold is not None: - no_speech_prob = WhisperGenerationMixin._get_attr_from_logit_processors( + no_speech_prob = _get_attr_from_logit_processors( logits_processor, WhisperNoSpeechDetection, "no_speech_prob" ) @@ -834,19 +880,10 @@ def _need_fallback( @staticmethod def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs): - set_inputs = WhisperGenerationMixin._get_attr_from_logit_processors( - logits_processor, WhisperNoSpeechDetection, "set_inputs" - ) + set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs") extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)} set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs}) - @staticmethod - def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): - logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) - if logit_processor: - return getattr(logit_processor, attribute_name, None) - return None - @staticmethod def _retrieve_total_input_frames(input_features, input_stride, kwargs): if input_features is not None: @@ -862,6 +899,39 @@ def _retrieve_total_input_frames(input_features, input_stride, kwargs): raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") + @staticmethod + def _maybe_warn_unused_inputs( + condition_on_prev_tokens, + temperature, + compression_ratio_threshold, + logprob_threshold, + no_speech_threshold, + total_input_frames, + ): + warning_prefix = ( + f"Audio input consists of only {total_input_frames}. " + "Short-form transcription is activated." + "{}, but will be ignored." + ) + if condition_on_prev_tokens is not None: + logger.warn(warning_prefix.format(f"condition_on_prev_tokens is set to {condition_on_prev_tokens}")) + + if compression_ratio_threshold is not None: + logger.warn(warning_prefix.format(f"compression_ratio_threshold is set to {compression_ratio_threshold}")) + + if logprob_threshold is not None: + logger.warn(warning_prefix.format(f"logprob_threshold is set to {logprob_threshold}")) + + if no_speech_threshold is not None: + logger.warn(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}")) + + # when passing temperature as a list it cannot just be ignored => throw error in this case + if isinstance(temperature, (list, tuple)): + raise ValueError( + f"Audio input consists of only {total_input_frames}. Short-form transcription is activated." + f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation." + ) + @staticmethod def _set_return_outputs( return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config @@ -1232,7 +1302,7 @@ def _prepare_decoder_input_ids( prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text bos_token_tensor = prev_start_of_text * one_tensor[0] - prev_tokens = WhisperGenerationMixin._pad_to_max_length( + prev_tokens = _pad_to_max_length( active_segments, generation_config.pad_token_id, padding="left", @@ -1291,39 +1361,6 @@ def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, return kwargs - @staticmethod - def _pad_to_max_length( - current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None - ): - max_total_length = 0 - sequences = [] - if padding not in ["right", "left"]: - raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}") - - for current_segment_list in current_segments: - if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0: - sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) - - if cut_off_length is not None: - sequence = sequence[-cut_off_length:] - - if bos_token_tensor is not None: - sequence = torch.cat([bos_token_tensor, sequence]) - - sequences.append(sequence) - max_total_length = max(max_total_length, len(sequences[-1])) - else: - sequences.append(bos_token_tensor) - - for i in range(len(current_segments)): - pad_length = max_total_length - len(sequences[i]) - pad = (0, pad_length) if padding == "right" else (pad_length, 0) - sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id) - - sequences = torch.stack(sequences, dim=0) - return sequences - - @staticmethod def _retrieve_compression_ratio(tokens, vocab_size): """Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes""" length = int(math.log2(vocab_size) / 8) + 1 From c9da44d7607a36d5539349f3c3e552d2a96687b3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 09:58:17 +0000 Subject: [PATCH 64/75] correct more --- tests/models/whisper/test_modeling_whisper.py | 52 +++++++++---------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 97a01ccefab0..34c9da40e5de 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2149,12 +2149,12 @@ def test_whisper_longform_single_batch_prev_cond(self): input_features = input_features.to(device="cuda") gen_kwargs = { - "return_timestamps": True, - "no_speech_threshold": 0.6, - "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - "compression_ratio_threshold": 1.35, - "condition_on_prev_tokens": True, - "logprob_threshold": -1.0, + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, } torch.manual_seed(0) @@ -2234,23 +2234,21 @@ def test_whisper_longform_multi_batch_prev_cond(self): audios.append(one_audio[:]) gen_kwargs = { - "return_timestamps": True, - "no_speech_threshold": 0.6, - "temperature": 0.0, - "compression_ratio_threshold": 1.35, - "condition_on_prev_tokens": True, - "logprob_threshold": -1.0, + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": 0.0, + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, } - with open("/home/patrick/expected.txt", "w") as f: - decoded_single = [] - for audio in audios: - inputs = processor(audio, return_tensors="pt", truncation=False) - inputs = inputs.to(device="cuda") - - result = model.generate(**inputs, **gen_kwargs) - decoded_single.append(processor.batch_decode(result, skip_special_tokens=True)) + decoded_single = [] + for audio in audios: + inputs = processor(audio, return_tensors="pt", truncation=False) + inputs = inputs.to(device="cuda") + result = model.generate(**inputs, **gen_kwargs) + decoded_single.append(processor.batch_decode(result, skip_special_tokens=True)) # exact match assert decoded_single[0] == EXPECTED_TEXT_1 @@ -2338,13 +2336,13 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): inputs = inputs.to(device="cuda") gen_kwargs = { - "return_timestamps": True, - "no_speech_threshold": 0.6, - "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - "compression_ratio_threshold": 1.35, - "condition_on_prev_tokens": True, - "logprob_threshold": -1.0, - "num_beams": 5, + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, + "num_beams": 5, } torch.manual_seed(0) From 6541bac020b64edce08dd00d22c7215e280baa83 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 10:01:47 +0000 Subject: [PATCH 65/75] Fix staticmethod --- src/transformers/models/whisper/generation_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 0ce8622110c8..ea6103701e04 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1361,6 +1361,7 @@ def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, return kwargs + @staticmethod def _retrieve_compression_ratio(tokens, vocab_size): """Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes""" length = int(math.log2(vocab_size) / 8) + 1 From aae16f316eff7239737187c11a9002f57f155203 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 12:41:11 +0000 Subject: [PATCH 66/75] correct more --- .../jax-projects/big_bird/bigbird_flax.py | 2 +- .../jax-projects/big_bird/train.py | 2 +- .../research_projects/vqgan-clip/VQGAN_CLIP.py | 2 +- .../models/whisper/tokenization_whisper.py | 17 ++++++++++++++--- tests/models/whisper/test_modeling_whisper.py | 15 +++++++++++++++ 5 files changed, 32 insertions(+), 6 deletions(-) diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py index af5e11c83a6a..c171b88800ed 100644 --- a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py +++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py @@ -9,13 +9,13 @@ import jax.numpy as jnp import joblib import optax -import wandb from flax import jax_utils, struct, traverse_util from flax.serialization import from_bytes, to_bytes from flax.training import train_state from flax.training.common_utils import shard from tqdm.auto import tqdm +import wandb from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule diff --git a/examples/research_projects/jax-projects/big_bird/train.py b/examples/research_projects/jax-projects/big_bird/train.py index ce37b7f975bb..3840918d16ae 100644 --- a/examples/research_projects/jax-projects/big_bird/train.py +++ b/examples/research_projects/jax-projects/big_bird/train.py @@ -2,11 +2,11 @@ from dataclasses import replace import jax -import wandb from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step from datasets import load_dataset from flax import jax_utils +import wandb from transformers import BigBirdTokenizerFast diff --git a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py index 1bfbc4cd5c36..2a39955e347f 100644 --- a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py +++ b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py @@ -4,12 +4,12 @@ import imageio import torch import torchvision -import wandb from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan from loaders import load_vqgan from PIL import Image from torch import nn +import wandb from transformers import CLIPModel, CLIPTokenizerFast from utils import get_device, get_timestamp, show_pil diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index a54103ccef8f..51a6fb757f38 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -530,10 +530,21 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre """ timestamp_begin = self.all_special_ids[-1] + 1 outputs = [[]] + + cur_max_timestamp = 0.0 + prev_segments_len = 0.0 + for token in token_ids: if token >= timestamp_begin: - timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>" - outputs.append(timestamp) + timestamp = float(f"{(token - timestamp_begin) * time_precision:.2f}") + if timestamp < cur_max_timestamp: + # next segment has started + prev_segments_len += cur_max_timestamp + + cur_max_timestamp = timestamp + timestamp = round(timestamp + prev_segments_len, 2) + + outputs.append(f"<|{timestamp}|>") outputs.append([]) else: outputs[-1].append(token) @@ -628,7 +639,7 @@ def decode( skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None, output_offsets: bool = False, - time_precision=0.02, + time_precision: float = 0.02, decode_with_timestamps: bool = False, normalize: bool = False, basic_normalize: bool = False, diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 34c9da40e5de..defa0d2c015d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -18,6 +18,7 @@ import inspect import os import random +import re import tempfile import time import unittest @@ -2130,6 +2131,20 @@ def test_whisper_longform_single_batch(self): assert decoded == EXPECTED_TEXT + decoded_with_timestamps = processor.batch_decode(result, skip_special_tokens=True, decode_with_timestamps=True) + + no_timestamp_matches = re.split(r"<\|[\d\.]+\|>", decoded_with_timestamps[0]) + + assert ["".join(no_timestamp_matches)] == EXPECTED_TEXT + + timestamp_matches = re.findall(r"<\|[\d\.]+\|>", decoded_with_timestamps[0]) + + timestamp_floats = [float(t[2:-2]) for t in timestamp_matches] + + is_increasing = all(timestamp_floats[i] <= timestamp_floats[i + 1] for i in range(len(timestamp_floats) - 1)) + + assert is_increasing + @slow def test_whisper_longform_single_batch_prev_cond(self): # fmt: off From d24a4d856c398d1eff7da8d68efe3e63cfb7ef6a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 15:34:11 +0000 Subject: [PATCH 67/75] fix --- .../research_projects/jax-projects/big_bird/bigbird_flax.py | 2 +- examples/research_projects/jax-projects/big_bird/train.py | 2 +- examples/research_projects/vqgan-clip/VQGAN_CLIP.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py index c171b88800ed..af5e11c83a6a 100644 --- a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py +++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py @@ -9,13 +9,13 @@ import jax.numpy as jnp import joblib import optax +import wandb from flax import jax_utils, struct, traverse_util from flax.serialization import from_bytes, to_bytes from flax.training import train_state from flax.training.common_utils import shard from tqdm.auto import tqdm -import wandb from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule diff --git a/examples/research_projects/jax-projects/big_bird/train.py b/examples/research_projects/jax-projects/big_bird/train.py index 3840918d16ae..ce37b7f975bb 100644 --- a/examples/research_projects/jax-projects/big_bird/train.py +++ b/examples/research_projects/jax-projects/big_bird/train.py @@ -2,11 +2,11 @@ from dataclasses import replace import jax +import wandb from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step from datasets import load_dataset from flax import jax_utils -import wandb from transformers import BigBirdTokenizerFast diff --git a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py index 2a39955e347f..1bfbc4cd5c36 100644 --- a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py +++ b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py @@ -4,12 +4,12 @@ import imageio import torch import torchvision +import wandb from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan from loaders import load_vqgan from PIL import Image from torch import nn -import wandb from transformers import CLIPModel, CLIPTokenizerFast from utils import get_device, get_timestamp, show_pil From 85ec8feeaddc67e52ec29ffe286e01542e8780b2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 16:21:37 +0000 Subject: [PATCH 68/75] fix slow tests --- src/transformers/models/whisper/generation_whisper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index ea6103701e04..56eba28f55f4 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -602,8 +602,9 @@ def generate( ) # 6.5 prepare decoder input ids - # TODO(Patrick) - clean up prev_start_of_text - suppress_tokens = generation_config.suppress_tokens + suppress_tokens = _get_attr_from_logit_processors( + logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens" + ) prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None decoder_input_ids, kwargs = self._prepare_decoder_input_ids( cur_bsz=cur_bsz, From 32b745c100232aae3ced8220ee3eed611480fe4c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jan 2024 17:34:15 +0000 Subject: [PATCH 69/75] make style --- .../models/whisper/tokenization_whisper_fast.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index ee44bb5918d2..4b2cdb8f9ff0 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -224,10 +224,21 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre """ timestamp_begin = self.all_special_ids[-1] + 1 outputs = [[]] + + cur_max_timestamp = 0.0 + prev_segments_len = 0.0 + for token in token_ids: if token >= timestamp_begin: - timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>" - outputs.append(timestamp) + timestamp = float(f"{(token - timestamp_begin) * time_precision:.2f}") + if timestamp < cur_max_timestamp: + # next segment has started + prev_segments_len += cur_max_timestamp + + cur_max_timestamp = timestamp + timestamp = round(timestamp + prev_segments_len, 2) + + outputs.append(f"<|{timestamp}|>") outputs.append([]) else: outputs[-1].append(token) @@ -327,7 +338,7 @@ def decode( skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None, output_offsets: bool = False, - time_precision=0.02, + time_precision: float = 0.02, decode_with_timestamps: bool = False, normalize: bool = False, basic_normalize: bool = False, From 113d6781b2529d21658d50010f976170dac3ad26 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 17 Jan 2024 09:22:40 +0000 Subject: [PATCH 70/75] fix tokenizer test --- src/transformers/models/whisper/tokenization_whisper.py | 6 +++--- .../models/whisper/tokenization_whisper_fast.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 51a6fb757f38..dbb7cc6684ac 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -536,15 +536,15 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre for token in token_ids: if token >= timestamp_begin: - timestamp = float(f"{(token - timestamp_begin) * time_precision:.2f}") + timestamp = float((token - timestamp_begin) * time_precision) + if timestamp < cur_max_timestamp: # next segment has started prev_segments_len += cur_max_timestamp cur_max_timestamp = timestamp - timestamp = round(timestamp + prev_segments_len, 2) - outputs.append(f"<|{timestamp}|>") + outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>") outputs.append([]) else: outputs[-1].append(token) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 4b2cdb8f9ff0..089d6dc52bff 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -230,15 +230,15 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre for token in token_ids: if token >= timestamp_begin: - timestamp = float(f"{(token - timestamp_begin) * time_precision:.2f}") + timestamp = float((token - timestamp_begin) * time_precision) + if timestamp < cur_max_timestamp: # next segment has started prev_segments_len += cur_max_timestamp cur_max_timestamp = timestamp - timestamp = round(timestamp + prev_segments_len, 2) - outputs.append(f"<|{timestamp}|>") + outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>") outputs.append([]) else: outputs[-1].append(token) From de34f234446040f6de1692230de9334a26b95858 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 17 Jan 2024 09:22:53 +0000 Subject: [PATCH 71/75] fix tokenizer test --- src/transformers/models/whisper/tokenization_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index dbb7cc6684ac..b24ab62a8b50 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -537,7 +537,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre for token in token_ids: if token >= timestamp_begin: timestamp = float((token - timestamp_begin) * time_precision) - + if timestamp < cur_max_timestamp: # next segment has started prev_segments_len += cur_max_timestamp From db02e95ba217cbd32b82d5c48fc870840dbcbcec Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 19 Jan 2024 12:35:57 +0200 Subject: [PATCH 72/75] Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/logits_process.py | 4 +--- src/transformers/models/whisper/generation_whisper.py | 5 ++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 76fb7d2ee69c..acc69360ca0c 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1891,7 +1891,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class WhisperNoSpeechDetection(LogitsProcessor): - r"""This processor can be used to detect silence when using Whisper.""" + r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation""" def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False): self.no_speech_token = no_speech_token @@ -1904,8 +1904,6 @@ def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: b self._no_speech_prob = [0.0] self.is_scores_logprobs = scores_is_logprobs - # make sure we pass all logits - self._pass_all_logits = True # overwritten dynamically self.model = None diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 56eba28f55f4..270322c7cb29 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1,6 +1,5 @@ # coding=utf-8 -# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -710,7 +709,7 @@ def generate_with_fallback( no_speech_threshold, do_condition_on_prev_tokens, condition_on_prev_tokens, - kwargs, + **kwargs, ): # 6.6 Batch generate current chunk seek_sequence_list = [None for _ in range(cur_bsz)] From 71b4893a8596eb7058995d45be9941bbbe1d21a9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 19 Jan 2024 11:21:42 +0000 Subject: [PATCH 73/75] finish --- src/transformers/generation/logits_process.py | 1 - .../models/whisper/generation_whisper.py | 87 ++++++++++++------- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index acc69360ca0c..04120e39fbd2 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1904,7 +1904,6 @@ def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: b self._no_speech_prob = [0.0] self.is_scores_logprobs = scores_is_logprobs - # overwritten dynamically self.model = None self.inputs = None diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 270322c7cb29..67b5369ec849 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -518,6 +518,13 @@ def generate( self._set_num_frames( return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs ) + self._set_thresholds_and_condition( + generation_config=generation_config, + logprob_threshold=logprob_threshold, + compression_ratio_threshold=compression_ratio_threshold, + no_speech_threshold=no_speech_threshold, + condition_on_prev_tokens=condition_on_prev_tokens, + ) # 4. Retrieve logits processors logits_processor = self._retrieve_logit_processors( @@ -604,7 +611,6 @@ def generate( suppress_tokens = _get_attr_from_logit_processors( logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens" ) - prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None decoder_input_ids, kwargs = self._prepare_decoder_input_ids( cur_bsz=cur_bsz, init_tokens=init_tokens, @@ -614,8 +620,8 @@ def generate( generation_config=generation_config, config=self.config, device=segment_input.device, + suppress_tokens=suppress_tokens, kwargs=kwargs, - prev_start_of_text=prev_start_of_text, ) # 6.6 set max new tokens or max length @@ -647,11 +653,7 @@ def generate( prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, - compression_ratio_threshold=compression_ratio_threshold, - logprob_threshold=logprob_threshold, - no_speech_threshold=no_speech_threshold, do_condition_on_prev_tokens=do_condition_on_prev_tokens, - condition_on_prev_tokens=condition_on_prev_tokens, kwargs=kwargs, ) @@ -704,11 +706,7 @@ def generate_with_fallback( prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, - compression_ratio_threshold, - logprob_threshold, - no_speech_threshold, do_condition_on_prev_tokens, - condition_on_prev_tokens, **kwargs, ): # 6.6 Batch generate current chunk @@ -718,7 +716,7 @@ def generate_with_fallback( should_skip = [False for _ in range(cur_bsz)] fallback_index_map = list(range(cur_bsz)) - if no_speech_threshold is not None: + if generation_config.no_speech_threshold is not None: self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs) for fallback_idx, temperature in enumerate(temperatures): @@ -772,18 +770,15 @@ def generate_with_fallback( seek_outputs, i, logits_processor, - compression_ratio_threshold, - logprob_threshold, - no_speech_threshold, + generation_config, self.config.vocab_size, - generation_config.eos_token_id, temperature, ) seek_sequence_list[fallback_index_map[i]] = seek_sequence seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] do_condition_on_prev_tokens[fallback_index_map[i]] = ( - condition_on_prev_tokens and temperature is not None and temperature < 0.5 + generation_config.condition_on_prev_tokens and temperature is not None and temperature < 0.5 ) if needs_fallback[i]: @@ -842,37 +837,39 @@ def _need_fallback( seek_outputs, index, logits_processor, - compression_ratio_threshold, - logprob_threshold, - no_speech_threshold, + generation_config, vocab_size, - eos_token_id, temperature, ): needs_fallback = False should_skip = False - if compression_ratio_threshold is not None: + if generation_config.compression_ratio_threshold is not None: compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size) - if compression_ratio > compression_ratio_threshold: + if compression_ratio > generation_config.compression_ratio_threshold: needs_fallback = True - if logprob_threshold is not None: + if generation_config.logprob_threshold is not None: if "sequences_scores" in seek_outputs[0]: logprobs = [s["sequences_scores"] for s in seek_outputs][index] else: scores = seek_outputs[index]["scores"] - logprobs = self._retrieve_avg_logprobs(scores, seek_sequence, eos_token_id, temperature) + logprobs = self._retrieve_avg_logprobs( + scores, seek_sequence, generation_config.eos_token_id, temperature + ) - if logprobs < logprob_threshold: + if logprobs < generation_config.logprob_threshold: needs_fallback = True - if no_speech_threshold is not None: + if generation_config.no_speech_threshold is not None: no_speech_prob = _get_attr_from_logit_processors( logits_processor, WhisperNoSpeechDetection, "no_speech_prob" ) - if logprobs < logprob_threshold and no_speech_prob[index] > no_speech_threshold: + if ( + logprobs < generation_config.logprob_threshold + and no_speech_prob[index] > generation_config.no_speech_threshold + ): needs_fallback = False should_skip = True @@ -1136,6 +1133,35 @@ def _set_num_frames(return_token_timestamps, generation_config, kwargs): generation_config.num_frames = kwargs.pop("num_frames", None) + @staticmethod + def _set_thresholds_and_condition( + generation_config, + logprob_threshold, + compression_ratio_threshold, + no_speech_threshold, + condition_on_prev_tokens, + ): + generation_config.logprob_threshold = ( + logprob_threshold + if logprob_threshold is not None + else getattr(generation_config, "logprob_threshold", None) + ) + generation_config.compression_ratio_threshold = ( + compression_ratio_threshold + if compression_ratio_threshold is not None + else getattr(generation_config, "compression_ratio_threshold", None) + ) + generation_config.logprob_threshold = ( + no_speech_threshold + if no_speech_threshold is not None + else getattr(generation_config, "no_speech_threshold", None) + ) + generation_config.condition_on_prev_tokens = ( + condition_on_prev_tokens + if condition_on_prev_tokens is not None + else getattr(generation_config, "condition_on_prev_tokens", None) + ) + @staticmethod def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config): condition_on_prev_tokens = ( @@ -1277,7 +1303,6 @@ def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames return segment_input - # TODO(Patrick) - remove prev_start_of_text @staticmethod def _prepare_decoder_input_ids( cur_bsz, @@ -1288,14 +1313,18 @@ def _prepare_decoder_input_ids( generation_config, config, device, + suppress_tokens, kwargs, - prev_start_of_text, ): cut_off_length = config.max_target_positions // 2 - 1 one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long) decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) + prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None) + if prev_start_of_text is None: + prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None + if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map] From e3711a3bacf42ec01020df1519b937495e101e4b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 19 Jan 2024 11:26:00 +0000 Subject: [PATCH 74/75] finish --- src/transformers/models/whisper/generation_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 67b5369ec849..dc7240bcdb21 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1151,7 +1151,7 @@ def _set_thresholds_and_condition( if compression_ratio_threshold is not None else getattr(generation_config, "compression_ratio_threshold", None) ) - generation_config.logprob_threshold = ( + generation_config.no_speech_threshold = ( no_speech_threshold if no_speech_threshold is not None else getattr(generation_config, "no_speech_threshold", None) From e9673c5900cb1de24af87cd195fe8347082a312f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 19 Jan 2024 11:37:48 +0000 Subject: [PATCH 75/75] revert kwargs change --- src/transformers/models/whisper/generation_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index dc7240bcdb21..c45fffb984b1 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -707,7 +707,7 @@ def generate_with_fallback( synced_gpus, return_token_timestamps, do_condition_on_prev_tokens, - **kwargs, + kwargs, ): # 6.6 Batch generate current chunk seek_sequence_list = [None for _ in range(cur_bsz)]