From fa7d6e0ef88fb1bd68a902d0b83c86bd63d07fe6 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Sun, 12 Oct 2025 20:04:38 -0400 Subject: [PATCH 01/67] parakeet tdt intergration --- .../models/auto/configuration_auto.py | 11 +- .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 19 + .../models/parakeet/configuration_parakeet.py | 140 +++++- .../models/parakeet/convert_nemo_to_hf.py | 137 +++--- .../models/parakeet/modeling_parakeet.py | 404 ++++++++++++++++- .../models/parakeet/modular_parakeet.py | 386 +++++++++++++++- src/transformers/pipelines/__init__.py | 3 +- .../pipelines/automatic_speech_recognition.py | 13 +- .../models/parakeet/test_modeling_parakeet.py | 425 ++++++++++++++++++ 10 files changed, 1459 insertions(+), 80 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 18f8c632182a..af8ea68ebb26 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -325,7 +325,10 @@ ("paddleocr_vl", "PaddleOCRVLConfig"), ("paligemma", "PaliGemmaConfig"), ("parakeet_ctc", "ParakeetCTCConfig"), + ("parakeet_tdt", "ParakeetTDTConfig"), ("parakeet_encoder", "ParakeetEncoderConfig"), + ("parakeet_tdt_decoder", "ParakeetTDTDecoderConfig"), + ("parakeet_tdt_joint", "ParakeetTDTJointConfig"), ("patchtsmixer", "PatchTSMixerConfig"), ("patchtst", "PatchTSTConfig"), ("pe_audio", "PeAudioConfig"), @@ -823,7 +826,10 @@ ("paligemma", "PaliGemma"), ("parakeet", "Parakeet"), ("parakeet_ctc", "Parakeet"), + ("parakeet_tdt", "ParakeetTDT"), ("parakeet_encoder", "ParakeetEncoder"), + ("parakeet_tdt_decoder", "ParakeetTDTDecoder"), + ("parakeet_tdt_joint", "ParakeetTDTJoint"), ("patchtsmixer", "PatchTSMixer"), ("patchtst", "PatchTST"), ("pe_audio", "PeAudio"), @@ -1083,8 +1089,11 @@ ("pe_audio_video_encoder", "pe_audio_video"), ("video_llama_3_vision", "video_llama_3"), ("parakeet_encoder", "parakeet"), - ("lw_detr_vit", "lw_detr"), + ("parakeet_tdt_decoder", "parakeet"), + ("parakeet_tdt_joint", "parakeet"), ("parakeet_ctc", "parakeet"), + ("parakeet_tdt", "parakeet"), + ("lw_detr_vit", "lw_detr"), ("lasr_encoder", "lasr"), ("lasr_ctc", "lasr"), ("wav2vec2-bert", "wav2vec2_bert"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 58e253718af2..baf70fd306b1 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -59,6 +59,7 @@ ("musicgen", "EncodecFeatureExtractor"), ("musicgen_melody", "MusicgenMelodyFeatureExtractor"), ("parakeet_ctc", "ParakeetFeatureExtractor"), + ("parakeet_tdt", "ParakeetFeatureExtractor"), ("parakeet_encoder", "ParakeetFeatureExtractor"), ("pe_audio", "PeAudioFeatureExtractor"), ("pe_audio_video", "PeAudioFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 952ff1da2bfa..39250122e5a8 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -319,7 +319,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("owlvit", "OwlViTModel"), ("paligemma", "PaliGemmaModel"), ("parakeet_ctc", "ParakeetForCTC"), + ("parakeet_tdt", "ParakeetForTDT"), ("parakeet_encoder", "ParakeetEncoder"), + ("parakeet_tdt_decoder", "ParakeetTDTDecoder"), + ("parakeet_tdt_joint", "ParakeetTDTJoint"), ("patchtsmixer", 
"PatchTSMixerModel"), ("patchtst", "PatchTSTModel"), ("pe_audio", "PeAudioModel"), @@ -1551,6 +1554,14 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ] ) +MODEL_FOR_TDT_MAPPING_NAMES = OrderedDict( + [ + # Model for Token-and-Duration Transducer (TDT) mapping. + ("parakeet_tdt", "ParakeetForTDT"), + ] +) + + MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Audio Classification mapping @@ -1816,6 +1827,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES ) MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES) +MODEL_FOR_TDT_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TDT_MAPPING_NAMES) MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES @@ -2124,6 +2136,11 @@ class AutoModelForCTC(_BaseAutoModelClass): AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") +class AutoModelForTDT(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TDT_MAPPING + + +AutoModelForTDT = auto_class_update(AutoModelForTDT, head_doc="token-and-duration transducer") class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING @@ -2187,6 +2204,7 @@ class AutoModelForAudioTokenization(_BaseAutoModelClass): "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", "MODEL_FOR_CAUSAL_LM_MAPPING", "MODEL_FOR_CTC_MAPPING", + "MODEL_FOR_TDT_MAPPING", "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", @@ -2233,6 +2251,7 @@ class AutoModelForAudioTokenization(_BaseAutoModelClass): "AutoModelForAudioXVector", "AutoModelForCausalLM", "AutoModelForCTC", + "AutoModelForTDT", "AutoModelForDepthEstimation", "AutoModelForImageClassification", "AutoModelForImageSegmentation", diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 6b8ead0a1e85..96d11ca012bb 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -149,6 +149,65 @@ def __init__( ) + +class ParakeetTDTDecoderConfig(PreTrainedConfig): + model_type = "parakeet_tdt_decoder" + keys_to_ignore_at_inference = ["past_key_values"] + output_hidden_states = False + + def __init__( + self, + hidden_size=640, + num_hidden_layers=1, + dropout=0, + vocab_size=1024, + forget_gate_bias=1.0, + t_max=None, + weights_init_scale=1.0, + hidden_hidden_bias_scale=0, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.dropout = dropout + self.vocab_size = vocab_size + self.forget_gate_bias=forget_gate_bias + self.t_max=t_max + self.weights_init_scale=weights_init_scale + self.hidden_hidden_bias_scale=hidden_hidden_bias_scale + + +class ParakeetTDTJointConfig(PreTrainedConfig): + model_type = "parakeet_tdt_joint" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + enc_hidden_size=1024, + pred_hidden_size=640, + hidden_size=640, + vocab_size=1024, + durations=[0,1,2,3,4], + norm=None, + dropout=0.0, + activation='relu', + **kwargs, + ): + super().__init__( + **kwargs, + ) + 
self.enc_hidden_size = enc_hidden_size + self.pred_hidden_size = pred_hidden_size + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.durations = durations + self.dropout = dropout + self.activation = activation + + class ParakeetCTCConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`ParakeetForCTC`]. It is used to instantiate a @@ -229,4 +288,83 @@ def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): return cls(encoder_config=encoder_config.to_dict(), **kwargs) -__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig"] +class ParakeetTDTConfig(PreTrainedConfig): + + model_type = "parakeet_tdt" + sub_configs = {"encoder_config": ParakeetEncoderConfig, "decoder_config": ParakeetTDTDecoderConfig, "joint_config": ParakeetTDTJointConfig} + + def __init__( + self, +# bos_token_id=1, +# eos_token_id=2, +# pad_token_id=1024, + tdt_loss_reduction="mean", + encoder_config: Union[dict, ParakeetEncoderConfig] = None, + decoder_config: Union[dict, ParakeetTDTDecoderConfig] = None, + joint_config: Union[dict, ParakeetTDTJointConfig] = None, + **kwargs, + ): + + if encoder_config is None: + self.encoder_config = ParakeetEncoderConfig() + elif isinstance(encoder_config, dict): + self.encoder_config = ParakeetEncoderConfig(**encoder_config) + elif isinstance(encoder_config, ParakeetEncoderConfig): + self.encoder_config = encoder_config + else: + raise ValueError( + f"`encoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}" + ) + + if decoder_config is None: + self.decoder_config = ParakeetTDTDecoderConfig() + elif isinstance(decoder_config, dict): + self.decoder_config = ParakeetTDTDecoderConfig(**decoder_config) + elif isinstance(decoder_config, ParakeetTDTDecoderConfig): + self.decoder_config = decoder_config + else: + raise ValueError( + f"`decoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}" + ) + + if joint_config is None: + self.joint_config = ParakeetTDTJointConfig() + elif isinstance(joint_config, dict): + self.joint_config = ParakeetTDTJointConfig(**joint_config) + elif isinstance(joint_config, ParakeetTDTJointConfig): + self.joint_config = joint_config + else: + raise ValueError( + f"`decoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}" + ) + + vocab_size = self.joint_config.vocab_size + self.vocab_size = vocab_size + + self.blank_token_id = vocab_size + super().__init__( +# pad_token_id=self.blank_token_id, + **kwargs, + ) + + @classmethod + def from_configs( + cls, + encoder_config: ParakeetEncoderConfig, + decoder_config: ParakeetTDTDecoderConfig, + joint_config: ParakeetTDTJointConfig, + **kwargs): + r""" + Instantiate a [`ParakeetConfig`] (or a derived class) from parakeet encoder model configuration. 
+ + Returns: + [`ParakeetConfig`]: An instance of a configuration object + """ + + return cls( + encoder_config=encoder_config.to_dict(), + decoder_config=decoder_config.to_dict(), + joint_config=joint_config.to_dict(), + **kwargs) + +__all__ = ["ParakeetCTCConfig", "ParakeetTDTConfig", "ParakeetEncoderConfig", "ParakeetTDTDecoderConfig", "ParakeetTDTJointConfig"] diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index 2d4085e6d340..b57a58a6eca1 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -24,10 +24,10 @@ from transformers import ( ParakeetCTCConfig, - ParakeetEncoder, - ParakeetEncoderConfig, + ParakeetTDTConfig, ParakeetFeatureExtractor, ParakeetForCTC, + ParakeetForTDT, ParakeetProcessor, ParakeetTokenizer, ) @@ -223,8 +223,8 @@ def convert_encoder_config(nemo_config): "conv_context_size", "dropout_pre_encoder", "reduction", - "reduction_factor", "reduction_position", + "reduction_factor", ] encoder_config_keys_mapping = { "d_model": "hidden_size", @@ -243,17 +243,62 @@ def convert_encoder_config(nemo_config): } converted_encoder_config = {} + decoder_keys_to_ignore = [ + "_target_", + "normalization_mode", + "random_state_sampling", + "blank_as_pad", + "prednet", + ] + decoder_config_keys_mapping = { + "vocab_size": "vocab_size", + } + converted_decoder_config = {} + + joint_keys_to_ignore = [ + "_target_", + 'log_softmax', + 'preserve_memory', + 'fuse_loss_wer', + 'fused_batch_size', + 'jointnet', + 'vocabulary' + ] + joint_config_keys_mapping = { + "vocab_size": "vocab_size", + "num_classes": "num_classes", + "num_extra_outputs": "num_extra_outputs", + } + converted_joint_config = {} + + for key, value in nemo_config["encoder"].items(): if key in encoder_keys_to_ignore: continue if key in encoder_config_keys_mapping: converted_encoder_config[encoder_config_keys_mapping[key]] = value - # NeMo uses 'use_bias' for both attention and convolution bias, but HF separates them - if key == "use_bias": - converted_encoder_config["convolution_bias"] = value else: raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") + if model_type == 'tdt': + for key, value in nemo_config["decoder"].items(): + if key in decoder_keys_to_ignore: + continue + if key in decoder_config_keys_mapping: + converted_decoder_config[decoder_config_keys_mapping[key]] = value + else: + raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") + + for key, value in nemo_config["joint"].items(): + if key in joint_keys_to_ignore: + continue + if key in joint_config_keys_mapping: + converted_joint_config[joint_config_keys_mapping[key]] = value + else: + raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") + + converted_joint_config["vocab_size"] = converted_joint_config["num_classes"] + return ParakeetEncoderConfig(**converted_encoder_config) @@ -286,63 +331,39 @@ def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_re print("Saving the model.") model.save_pretrained(output_dir) - if push_to_repo_id: - model.push_to_hub(push_to_repo_id) + elif model_type == "tdt": + num_classes = converted_joint_config["num_classes"] + model_config = ParakeetTDTConfig( + pad_token_id=num_classes, + vocab_size=num_classes+1, + blank_token_id=num_classes, + encoder_config=converted_encoder_config, + decoder_config=converted_decoder_config, + joint_config=converted_joint_config, + ) + print("Loading the 
checkpoint in a Parakeet TDT model.") + with torch.device("meta"): + model = ParakeetForTDT(model_config) + model.load_state_dict(converted_state_dict, strict=True, assign=True) + print("Checkpoint loaded successfully.") + del model.config._name_or_path - del model + print("Saving the model.") + model.save_pretrained(output_dir) - # Safety check: reload the converted model - gc.collect() - print("Reloading the model to check if it's saved correctly.") - ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") - print("Model reloaded successfully.") + if push_to_repo_id: + model.push_to_hub(push_to_repo_id) + del converted_state_dict, model -def write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None): - """Write encoder model using encoder config and converted state dict.""" - # Filter to only encoder weights (exclude CTC head if present) - encoder_state_dict = { - k.replace("encoder.", "", 1) if k.startswith("encoder.") else k: v - for k, v in converted_state_dict.items() - if k.startswith("encoder.") - } + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + ParakeetForTDT.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") + print("Model reloaded successfully.") - print("Loading the checkpoint in a Parakeet Encoder model (for TDT).") - with torch.device("meta"): - model = ParakeetEncoder(encoder_config) - model.load_state_dict(encoder_state_dict, strict=True, assign=True) - print("Checkpoint loaded successfully.") - del model.config._name_or_path - print("Saving the model.") - model.save_pretrained(output_dir) - - if push_to_repo_id: - model.push_to_hub(push_to_repo_id) - del model - - # Safety check: reload the converted model - gc.collect() - print("Reloading the model to check if it's saved correctly.") - ParakeetEncoder.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") - print("Model reloaded successfully.") - - -def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None): - """Main model conversion function.""" - # Step 1: Convert encoder config (shared across all model types) - encoder_config = convert_encoder_config(nemo_config) - print(f"Converted encoder config: {encoder_config}") - - # Step 2: Load and convert state dict (shared across all model types) - converted_state_dict = load_and_convert_state_dict(model_files) - - # Step 3: Write model based on type - if model_type == "encoder": - write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id) - elif model_type == "ctc": - write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id) else: raise ValueError(f"Model type {model_type} not supported.") @@ -366,9 +387,7 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co") - parser.add_argument( - "--model_type", required=True, choices=["encoder", "ctc"], help="Model type (`encoder`, `ctc`)" - ) + parser.add_argument("--model_type", required=True, choices=["ctc","tdt"], help="Model type (`ctc`, `tdt`)") parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model") parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub") args = parser.parse_args() diff --git a/src/transformers/models/parakeet/modeling_parakeet.py 
b/src/transformers/models/parakeet/modeling_parakeet.py index 23be85a2f827..76251433ee92 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -29,13 +29,19 @@ from ...activations import ACT2FN from ...integrations import use_kernel_func_from_hub, use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithNoAttention, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import maybe_autocast, merge_with_config_defaults -from ...utils.output_capturing import capture_outputs -from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig +from ...utils.generic import check_model_inputs +from .configuration_parakeet import ( + ParakeetCTCConfig, + ParakeetEncoderConfig, + ParakeetTDTConfig, + ParakeetTDTDecoderConfig, + ParakeetTDTJointConfig, + PreTrainedConfig, +) @dataclass @@ -132,7 +138,7 @@ def __init__(self, config: ParakeetEncoderConfig, module_config=None): self.activation = ACT2FN[module_config.get("activation", "silu")] self.padding = (kernel_size - 1) // 2 self.pointwise_conv1 = nn.Conv1d( - channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias + channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias ) self.depthwise_conv = nn.Conv1d( channels, @@ -141,11 +147,11 @@ def __init__(self, config: ParakeetEncoderConfig, module_config=None): stride=1, padding=self.padding, groups=channels, - bias=config.convolution_bias, + bias=config.attention_bias, ) self.norm = nn.BatchNorm1d(channels) self.pointwise_conv2 = nn.Conv1d( - channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias + channels, channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias ) def forward(self, hidden_states, attention_mask=None): @@ -282,7 +288,9 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) # W_{k,R} projection - self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.relative_k_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) # global content bias self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim)) # global positional bias @@ -472,7 +480,7 @@ def forward( @auto_docstring class ParakeetPreTrainedModel(PreTrainedModel): - config: ParakeetCTCConfig + config: PreTrainedConfig base_model_prefix = "model" main_input_name = "input_features" input_modalities = "audio" @@ -513,7 +521,11 @@ def _init_weights(self, module): init.copy_(module.inv_freq, inv_freq) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): - encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config + encoder_config = ( + self.config.encoder_config + if isinstance(self.config, (ParakeetCTCConfig, ParakeetTDTConfig)) + else self.config + ) kernel_size = encoder_config.subsampling_conv_kernel_size stride = encoder_config.subsampling_conv_stride @@ -667,6 +679,270 @@ class 
ParakeetGenerateOutput(ModelOutput): hidden_states: tuple[tuple[torch.FloatTensor]] | None = None +class ParakeetLSTM(torch.nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + dropout: Optional[float], + forget_gate_bias: Optional[float], + t_max: Optional[int] = None, + weights_init_scale: float = 1.0, + hidden_hidden_bias_scale: float = 0.0, + proj_size: int = 0, + ): + """Returns an LSTM with forget gate bias init to `forget_gate_bias`. + Args: + input_size: See `torch.nn.LSTM`. + hidden_size: See `torch.nn.LSTM`. + num_layers: See `torch.nn.LSTM`. + dropout: See `torch.nn.LSTM`. + + forget_gate_bias: float, set by default to 1.0, which constructs a forget gate + initialized to 1.0. + Reference: + [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) + + t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization + of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course + of training. + Reference: + [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) + + weights_init_scale: Float scale of the weights after initialization. Setting to lower than one + sometimes helps reduce variance between runs. + + hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for + the default behaviour. + + Returns: + A `torch.nn.LSTM`. + """ + super(ParakeetLSTM, self).__init__() + + self.lstm = torch.nn.LSTM( + input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, proj_size=proj_size + ) + + if t_max is not None: + # apply chrono init + for name, v in self.lstm.named_parameters(): + if "bias" in name: + p = getattr(self.lstm, name) + n = p.nelement() + hidden_size = n // 4 + p.data.fill_(0) + p.data[hidden_size : 2 * hidden_size] = torch.log( + torch.nn.init.uniform_(p.data[0:hidden_size], 1, t_max - 1) + ) + # forget gate biases = log(uniform(1, Tmax-1)) + p.data[0:hidden_size] = -p.data[hidden_size : 2 * hidden_size] + # input gate biases = -(forget gate biases) + + elif forget_gate_bias is not None: + for name, v in self.lstm.named_parameters(): + if "bias_ih" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias) + if "bias_hh" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale) + + self.dropout = torch.nn.Dropout(dropout) if dropout else None + + for name, v in self.named_parameters(): + if "weight" in name or "bias" in name: + v.data *= float(weights_init_scale) + + def forward( + self, x: torch.Tensor, h: Optional[tuple[torch.Tensor, torch.Tensor]] = None + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + x, h = self.lstm(x, h) + + if self.dropout: + x = self.dropout(x) + + return x, h + + +class ParakeetTDTJoint(ParakeetPreTrainedModel): + config: ParakeetTDTJointConfig + base_model_prefix = "" # joint" + main_input_name = "enc" + _supports_flat_attention_mask = False + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = False + _can_record_outputs = {} + _no_split_modules = None + + def __init__(self, config: ParakeetTDTJointConfig): + super().__init__(config) + self.config = config + self.gradient_checkpointing = False + + self.enc = torch.nn.Linear(config.enc_hidden_size, config.hidden_size) + self.pred = torch.nn.Linear(config.pred_hidden_size, 
config.hidden_size) + + num_classes = config.vocab_size + 1 + len(config.durations) + + layers = ( + [torch.nn.ReLU(inplace=True)] + + ([torch.nn.Dropout(p=self.config.dropout)]) + + [torch.nn.Linear(config.hidden_size, num_classes)] + ) + self.joint_net = torch.nn.Sequential(*layers) + self.post_init() + + @auto_docstring + @check_model_inputs() + def forward( + self, + enc: torch.Tensor, + pred: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithNoAttention: + # Right now we only support joint for inference. + + pred = pred.view([-1, self.config.pred_hidden_size]) # making it B, D + enc = enc.view([-1, self.config.enc_hidden_size]) # making it B, D + enc = self.enc(enc) + pred = self.pred(pred) + + assert enc.shape[0] == pred.shape[0] + output = self.joint_net(enc + pred) + return BaseModelOutput(last_hidden_state=output) + + +class ParakeetTDTPredictor(ParakeetPreTrainedModel): + def __init__(self, config: ParakeetTDTDecoderConfig): + super().__init__(config) + self.gradient_checkpointing = False + self.config = config + + self.embed = torch.nn.Embedding(config.vocab_size + 1, config.hidden_size) # +1 for blank + self.dec_rnn = self.rnn( + config.hidden_size, + config.hidden_size, + config.num_hidden_layers + 1, + config.forget_gate_bias, + config.dropout, + config.t_max, + config.weights_init_scale, + config.hidden_hidden_bias_scale, + ) + self.post_init() + + def rnn( + self, + input_size: int, + hidden_size: int, + num_layers: int, + forget_gate_bias: Optional[float] = 1.0, + dropout: Optional[float] = 0.0, + t_max: Optional[int] = None, + weights_init_scale: float = 1.0, + hidden_hidden_bias_scale: float = 0.0, + proj_size: int = 0, + ) -> torch.nn.Module: + return ParakeetLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + weights_init_scale=weights_init_scale, + hidden_hidden_bias_scale=hidden_hidden_bias_scale, + proj_size=proj_size, + ) + + @auto_docstring + @check_model_inputs() + @can_return_tuple + def forward( + self, + input_token, + states, + hidden_state=None, + **kwargs: Unpack[TransformersKwargs], + ): + assert input_token is not None + + device = self.embed.weight.device + if input_token.device != device: + input_token = input_token.to(device) + return self.predict(input_token, state=states) + + def predict(self, y, state): + # Get device and dtype of current module + + # (B, U) -> (B, U, H) + y = self.embed(y).transpose(0, 1) # (U + 1, B, H) + + g, hid = self.dec_rnn(y, state) + g = g.transpose(0, 1).transpose(1, 2) # (B, H, U + 1) + + return g, hid + + +@auto_docstring( + custom_intro=""" + The Parakeet TDT Decoder. This class encapsulates both the predictor and joint network for TDT models. 
+ """ +) +class ParakeetTDTDecoder(ParakeetPreTrainedModel): + config: ParakeetTDTDecoderConfig + base_model_prefix = "decoder" + main_input_name = "input_token" + _supports_flat_attention_mask = False + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = False + _can_record_outputs = {} + _no_split_modules = None + + def __init__(self, config: ParakeetTDTDecoderConfig): + super().__init__(config) + self.config = config + self.gradient_checkpointing = False + self.prediction = ParakeetTDTPredictor(config) + self.post_init() + + def _init_weights(self, module): + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + else: + # 0.02 is the standard default value accross the library + std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + + module.prediction.embed.weight.data.normal_(mean=0.0, std=std) + for param in module.prediction.dec_rnn.lstm.parameters(): + param.data.normal_(mean=0.0, std=std) + + def get_input_embeddings(self): + return self.prediction.embed + + def set_input_embeddings(self, embed): + self.prediction.embed = embed + + @auto_docstring + @check_model_inputs() + @can_return_tuple + def forward( + self, + input_token, + hidden_state=None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithNoAttention: + if hidden_state is not None: + hidden_state = tuple(hidden_state.unbind(dim=0)) + + h_out, h_state = self.prediction(input_token, hidden_state, **kwargs) + return BaseModelOutputWithNoAttention(h_out, torch.stack(h_state, dim=0)) + + @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -811,4 +1087,110 @@ def generate( return sequences -__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"] +@auto_docstring( + custom_intro=""" + Parakeet TDT model. + """ +) +class ParakeetForTDT(ParakeetPreTrainedModel): + config: ParakeetTDTConfig + + def __init__(self, config: ParakeetTDTConfig): + super().__init__(config) + self.encoder = ParakeetEncoder(config.encoder_config) + self.decoder = ParakeetTDTDecoder(config.decoder_config) + self.joint = ParakeetTDTJoint(config.joint_config) + self.blank_token_id = config.blank_token_id + self.max_token_per_frame = 2 + self.post_init() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_features: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + encoder_outputs = self.encoder( + input_features=input_features, + **kwargs, + ) + + logits = self.joint.joint_net( + self.joint.enc(encoder_outputs.last_hidden_state) + ) # [:,:,:self.joint.vocab_size] + + return CausalLMOutput( + loss=torch.sum(encoder_outputs.last_hidden_state), # a fake loss here. 
+ logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + encoder_outputs = self.encoder( + input_features=input_features, + **kwargs, + ) + output = self.greedy_decode(encoder_outputs.last_hidden_state) + + return output + + def greedy_decode(self, encoder_output): + T = encoder_output.shape[1] + t = 0 + hyp = [] + last_label = torch.LongTensor([[self.blank_token_id]]) + dec_out = self.decoder(input_token=last_label) + g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states + + symbols_added = 0 + while t < T: + enc = encoder_output[0, t, :] + while symbols_added < self.max_token_per_frame: + logits = self.joint(enc, g).last_hidden_state + + logits = logits.view([-1]) + + token_logits = logits[: self.blank_token_id + 1].softmax(-1) + duration_logits = logits[self.blank_token_id + 1 :].softmax(-1) + + v, token = token_logits.max(-1) + v_duration, duration = duration_logits.max(-1) + token = token.item() + duration = duration.item() + + if token != self.blank_token_id: + hyp.append(token) + last_label = token + last_label = torch.LongTensor([[last_label]]) + dec_out = self.decoder(last_label, hidden_prime) + g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states + + if duration == 0: + symbols_added += 1 + else: + t += duration + symbols_added = 0 + break + + if symbols_added == self.max_token_per_frame: + t += 1 + symbols_added = 0 + + return hyp + + +__all__ = [ + "ParakeetForCTC", + "ParakeetForTDT", + "ParakeetEncoder", + "ParakeetTDTDecoder", + "ParakeetTDTJoint", + "ParakeetPreTrainedModel", +] diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 93ef18b1fa49..594ea73e3f8a 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -16,6 +16,7 @@ import math from collections.abc import Callable from dataclasses import dataclass +from typing import Optional, Union, Tuple import torch from torch import nn @@ -23,7 +24,7 @@ from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, BaseModelOutputWithNoAttention from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple @@ -31,7 +32,7 @@ from ...utils.output_capturing import capture_outputs from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule from ..llama.modeling_llama import LlamaAttention, eager_attention_forward -from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig +from .configuration_parakeet import PreTrainedConfig, ParakeetCTCConfig, ParakeetTDTConfig, ParakeetEncoderConfig, ParakeetTDTDecoderConfig, ParakeetTDTJointConfig @dataclass @@ -121,7 +122,7 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): super().__init__(config, layer_idx=layer_idx) self.is_causal = False # W_{k,R} projection - self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) # global content bias self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim)) # global positional bias @@ -311,7 +312,7 @@ def forward( @auto_docstring class ParakeetPreTrainedModel(PreTrainedModel): - config: ParakeetCTCConfig + config: PreTrainedConfig base_model_prefix = "model" main_input_name = "input_features" input_modalities = "audio" @@ -352,7 +353,7 @@ def _init_weights(self, module): init.copy_(module.inv_freq, inv_freq) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): - encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config + encoder_config = self.config.encoder_config if isinstance(self.config, (ParakeetCTCConfig, ParakeetTDTConfig)) else self.config kernel_size = encoder_config.subsampling_conv_kernel_size stride = encoder_config.subsampling_conv_stride @@ -506,6 +507,277 @@ class ParakeetGenerateOutput(ModelOutput): hidden_states: tuple[tuple[torch.FloatTensor]] | None = None +class ParakeetLSTM(torch.nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + dropout: Optional[float], + forget_gate_bias: Optional[float], + t_max: Optional[int] = None, + weights_init_scale: float = 1.0, + hidden_hidden_bias_scale: float = 0.0, + proj_size: int = 0, + ): + """Returns an LSTM with forget gate bias init to `forget_gate_bias`. + Args: + input_size: See `torch.nn.LSTM`. + hidden_size: See `torch.nn.LSTM`. + num_layers: See `torch.nn.LSTM`. + dropout: See `torch.nn.LSTM`. + + forget_gate_bias: float, set by default to 1.0, which constructs a forget gate + initialized to 1.0. + Reference: + [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) + + t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization + of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course + of training. 
+ Reference: + [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) + + weights_init_scale: Float scale of the weights after initialization. Setting to lower than one + sometimes helps reduce variance between runs. + + hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for + the default behaviour. + + Returns: + A `torch.nn.LSTM`. + """ + super(ParakeetLSTM, self).__init__() + + self.lstm = torch.nn.LSTM( + input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, proj_size=proj_size + ) + + if t_max is not None: + # apply chrono init + for name, v in self.lstm.named_parameters(): + if 'bias' in name: + p = getattr(self.lstm, name) + n = p.nelement() + hidden_size = n // 4 + p.data.fill_(0) + p.data[hidden_size : 2 * hidden_size] = torch.log( + torch.nn.init.uniform_(p.data[0:hidden_size], 1, t_max - 1) + ) + # forget gate biases = log(uniform(1, Tmax-1)) + p.data[0:hidden_size] = -p.data[hidden_size : 2 * hidden_size] + # input gate biases = -(forget gate biases) + + elif forget_gate_bias is not None: + for name, v in self.lstm.named_parameters(): + if "bias_ih" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias) + if "bias_hh" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale) + + self.dropout = torch.nn.Dropout(dropout) if dropout else None + + for name, v in self.named_parameters(): + if 'weight' in name or 'bias' in name: + v.data *= float(weights_init_scale) + + def forward( + self, x: torch.Tensor, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + x, h = self.lstm(x, h) + + if self.dropout: + x = self.dropout(x) + + return x, h + +class ParakeetTDTJoint(ParakeetPreTrainedModel): + config: ParakeetTDTJointConfig + base_model_prefix = "" #joint" + main_input_name = "enc" + _supports_flat_attention_mask = False + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = False + _can_record_outputs = {} + _no_split_modules = None + + def __init__(self, config: ParakeetTDTJointConfig): + super().__init__(config) + self.config = config + self.gradient_checkpointing = False + + self.enc = torch.nn.Linear(config.enc_hidden_size, config.hidden_size) + self.pred = torch.nn.Linear(config.pred_hidden_size, config.hidden_size) + + num_classes = config.vocab_size + 1 + len(config.durations) + + layers = ( + [torch.nn.ReLU(inplace=True)] + + ([torch.nn.Dropout(p=self.config.dropout)]) + + [torch.nn.Linear(config.hidden_size, num_classes)] + ) + self.joint_net = torch.nn.Sequential(*layers) + self.post_init() + + @auto_docstring + @check_model_inputs() + def forward( + self, + enc: torch.Tensor, + pred: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithNoAttention: + + # Right now we only support joint for inference. 
+ + pred = pred.view([-1, self.config.pred_hidden_size]) # making it B, D + enc = enc.view([-1, self.config.enc_hidden_size]) # making it B, D + enc = self.enc(enc) + pred = self.pred(pred) + + assert enc.shape[0] == pred.shape[0] + output = self.joint_net(enc + pred) + return BaseModelOutput(last_hidden_state=output) + + +class ParakeetTDTPredictor(ParakeetPreTrainedModel): + + def __init__(self, config: ParakeetTDTDecoderConfig): + super().__init__(config) + self.gradient_checkpointing = False + self.config = config + + self.embed = torch.nn.Embedding(config.vocab_size + 1, config.hidden_size) # +1 for blank + self.dec_rnn = self.rnn( + config.hidden_size, + config.hidden_size, + config.num_hidden_layers + 1, + config.forget_gate_bias, + config.dropout, + config.t_max, + config.weights_init_scale, + config.hidden_hidden_bias_scale, + ) + self.post_init() + + + def rnn( + self, + input_size: int, + hidden_size: int, + num_layers: int, + forget_gate_bias: Optional[float] = 1.0, + dropout: Optional[float] = 0.0, + t_max: Optional[int] = None, + weights_init_scale: float = 1.0, + hidden_hidden_bias_scale: float = 0.0, + proj_size: int = 0, + ) -> torch.nn.Module: + return ParakeetLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + weights_init_scale=weights_init_scale, + hidden_hidden_bias_scale=hidden_hidden_bias_scale, + proj_size=proj_size, + ) + + + @auto_docstring + @check_model_inputs() + @can_return_tuple + def forward( + self, + input_token, + states, + hidden_state = None, + **kwargs: Unpack[TransformersKwargs], + ): + assert input_token is not None + + device = self.embed.weight.device + if input_token.device != device: + input_token = input_token.to(device) + return self.predict(input_token, state=states) + + def predict(self, y, state): + # Get device and dtype of current module + + # (B, U) -> (B, U, H) + y = self.embed(y).transpose(0, 1) # (U + 1, B, H) + + g, hid = self.dec_rnn(y, state) + g = g.transpose(0, 1).transpose(1, 2) # (B, H, U + 1) + + return g, hid + + + +@auto_docstring( + custom_intro=""" + The Parakeet TDT Decoder. This class encapsulates both the predictor and joint network for TDT models. 
+ """ +) +class ParakeetTDTDecoder(ParakeetPreTrainedModel): + config: ParakeetTDTDecoderConfig + base_model_prefix = "decoder" + main_input_name = "input_token" + _supports_flat_attention_mask = False + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = False + _can_record_outputs = {} + _no_split_modules = None + + def __init__(self, config: ParakeetTDTDecoderConfig): + super().__init__(config) + self.config = config + self.gradient_checkpointing = False + self.prediction = ParakeetTDTPredictor(config) + self.post_init() + + def _init_weights(self, module): + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + else: + # 0.02 is the standard default value accross the library + std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + + module.prediction.embed.weight.data.normal_(mean=0.0, std=std) + for param in module.prediction.dec_rnn.lstm.parameters(): + param.data.normal_(mean=0.0, std=std) + + def get_input_embeddings(self): + return self.prediction.embed + + def set_input_embeddings(self, embed): + self.prediction.embed = embed + + @auto_docstring + @check_model_inputs() + @can_return_tuple + def forward( + self, + input_token, + hidden_state = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithNoAttention: + + if hidden_state is not None: + hidden_state = tuple(hidden_state.unbind(dim=0)) + + h_out, h_state = self.prediction(input_token, hidden_state, **kwargs) + return BaseModelOutputWithNoAttention(h_out, torch.stack(h_state, dim=0)) + + + @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -650,4 +922,106 @@ def generate( return sequences -__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"] +@auto_docstring( + custom_intro=""" + Parakeet TDT model. + """ +) +class ParakeetForTDT(ParakeetPreTrainedModel): + config: ParakeetTDTConfig + + def __init__(self, config: ParakeetTDTConfig): + super().__init__(config) + self.encoder = ParakeetEncoder(config.encoder_config) + self.decoder = ParakeetTDTDecoder(config.decoder_config) + self.joint = ParakeetTDTJoint(config.joint_config) + self.blank_token_id = config.blank_token_id + self.max_token_per_frame = 2 + self.post_init() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_features: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + encoder_outputs = self.encoder( + input_features=input_features, + **kwargs, + ) + + logits = self.joint.joint_net(self.joint.enc(encoder_outputs.last_hidden_state)) #[:,:,:self.joint.vocab_size] + + return CausalLMOutput( + loss=torch.sum(encoder_outputs.last_hidden_state), # a fake loss here. 
+ logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + + encoder_outputs = self.encoder( + input_features=input_features, + **kwargs, + ) + output = self.greedy_decode(encoder_outputs.last_hidden_state) + + return output + + def greedy_decode(self, encoder_output): + T = encoder_output.shape[1] + t = 0 + hyp = [] + last_label = torch.LongTensor([[self.blank_token_id]]) + dec_out = self.decoder(input_token=last_label) + g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states + + symbols_added = 0 + while t < T: + enc = encoder_output[0,t,:] + while symbols_added < self.max_token_per_frame: + logits = self.joint(enc, g).last_hidden_state + + logits = logits.view([-1]) + + token_logits = logits[:self.blank_token_id + 1].softmax(-1) + duration_logits = logits[self.blank_token_id + 1:].softmax(-1) + + v, token = token_logits.max(-1) + v_duration, duration = duration_logits.max(-1) + token = token.item() + duration = duration.item() + + if token != self.blank_token_id: + hyp.append(token) + last_label = token + last_label = torch.LongTensor([[last_label]]) + dec_out = self.decoder(last_label, hidden_prime) + g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states + + if duration == 0: + symbols_added += 1 + else: + t += duration + symbols_added = 0 + break + + if symbols_added == self.max_token_per_frame: + t += 1 + symbols_added = 0 + + + return hyp + + + + + +__all__ = ["ParakeetForCTC", "ParakeetForTDT", "ParakeetEncoder", "ParakeetTDTDecoder", "ParakeetTDTJoint", "ParakeetPreTrainedModel"] diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 57c7a806fdf2..da825ff39223 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -98,6 +98,7 @@ AutoModelForAudioClassification, AutoModelForCausalLM, AutoModelForCTC, + AutoModelForTDT, AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, @@ -147,7 +148,7 @@ }, "automatic-speech-recognition": { "impl": AutomaticSpeechRecognitionPipeline, - "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (), + "pt": (AutoModelForCTC, AutoModelForTDT, AutoModelForSpeechSeq2Seq) if is_torch_available() else (), "default": {"model": ("facebook/wav2vec2-base-960h", "22aad52")}, "type": "multimodal", }, diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 3019f74328c7..30eb9c987697 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -198,6 +198,8 @@ def __init__( self.type = "seq2seq_whisper" elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values(): self.type = "seq2seq" + elif model.config.model_type == "parakeet_tdt": + self.type = "tdt" elif decoder is not None: self.decoder = decoder self.type = "ctc_with_lm" @@ -556,7 +558,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): if stride is not None: out["stride"] = stride - else: + elif self.type in {"ctc", "ctc_with_lm"}: inputs = { self.model.main_input_name: model_inputs.pop(self.model.main_input_name), "attention_mask": attention_mask, @@ -577,6 +579,15 @@ def _forward(self, model_inputs, return_timestamps=False, 
**generate_kwargs): out["stride"] = rescale_stride([stride], ratio)[0] else: out["stride"] = rescale_stride(stride, ratio) + elif self.type == 'tdt': + inputs = { + self.model.main_input_name: model_inputs.pop(self.model.main_input_name), + } + outputs = self.model.generate(**inputs) + out = {"tokens": torch.LongTensor(outputs).view([1, -1])} + else: + raise ValueError("Unsupported model type {self.type}.") + # Leftover extra = model_inputs return {"is_last": is_last, **out, **extra} diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 0d23383a130a..46e53421dbd5 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -14,6 +14,7 @@ """Testing suite for the PyTorch Parakeet model.""" import json +import copy import tempfile import unittest from pathlib import Path @@ -34,9 +35,15 @@ from transformers import ( AutoProcessor, ParakeetCTCConfig, + ParakeetTDTConfig, ParakeetEncoder, ParakeetEncoderConfig, + ParakeetTDTDecoder, + ParakeetTDTDecoderConfig, + ParakeetTDTJoint, + ParakeetTDTJointConfig, ParakeetForCTC, + ParakeetForTDT, ) @@ -183,6 +190,232 @@ def test_model_get_set_embeddings(self): pass +class ParakeetTDTDecoderModelTester: + def __init__( + self, + parent, + batch_size=16, + vocab_size=128, + hidden_size=64, + num_hidden_layers=2, + seq_length=32, + is_training=True, + dropout=0, # so gradient checkpointing doesn't fail + ): + # testing suite parameters + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + + # config parameters + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.seq_length = seq_length + self.output_seq_length = seq_length + self.vocab_size = vocab_size + + def prepare_config_and_inputs(self): + input_token = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + config = self.get_config() + + return config, input_token + + def get_config(self): + return ParakeetTDTDecoderConfig( + num_hidden_layers=self.num_hidden_layers, + hidden_size=self.hidden_size, + vocab_size=self.vocab_size, + ) + + def create_and_check_model(self, config, input_token): + pass + model = ParakeetTDTDecoder(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_token) + + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, config.hidden_size, self.output_seq_length) + ) + + def prepare_config_and_inputs_for_common(self): + config, input_token = self.prepare_config_and_inputs() + inputs_dict = { + "input_token": input_token, + } + return config, inputs_dict + + + + +@require_torch +class ParakeetTDTDecoderModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (ParakeetTDTDecoder,) if is_torch_available() else () + + test_resize_embeddings = False + test_torch_exportable = True + has_attentions = False + is_encoder_decoder = False + + def setUp(self): + self.model_tester = ParakeetTDTDecoderModelTester(self) + self.config_tester = ConfigTester(self, config_class=ParakeetTDTDecoderConfig, has_text_modality=False, common_properties=['hidden_size','num_hidden_layers']) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = 
model_class(copy.deepcopy(config)) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(hidden_states.shape[1], expected_num_layers) + + if hasattr(self.model_tester, "encoder_seq_length"): + seq_length = self.model_tester.encoder_seq_length + if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1: + seq_length = seq_length * self.model_tester.chunk_length + else: + seq_length = self.model_tester.seq_length + + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + for k in config.sub_configs: + if getattr(config, k) is not None: + getattr(config, k).output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + @unittest.skip(reason="this class only returns the last hidden state not prior ones, and there is no gradient on last hidden state w.r.t output.") + def test_retain_grad_hidden_states_attentions(self): + pass + + +class ParakeetTDTJointModelTester: + def __init__( + self, + parent, + batch_size=16, + vocab_size=128, + hidden_size=64, + pred_hidden_size=64, + enc_hidden_size=64, + num_hidden_layers=2, + durations=[0,1,2,3,4], + is_training=True, + dropout=0.1, # so gradient checkpointing doesn't fail + ): + # testing suite parameters + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + + # config parameters + self.hidden_size = hidden_size + self.pred_hidden_size = pred_hidden_size + self.enc_hidden_size = enc_hidden_size + self.num_hidden_layers = num_hidden_layers + self.t_length = 1 # so far only support 1 + self.u_length = 1 # so far only support 1 + self.output_seq_length = -1 + self.vocab_size = vocab_size + self.durations = durations + + def prepare_config_and_inputs(self): + enc = floats_tensor([self.batch_size, self.t_length, self.enc_hidden_size]) + pred = floats_tensor([self.batch_size, self.u_length, self.pred_hidden_size]) + config = self.get_config() + + return config, enc, pred + + def get_config(self): + return ParakeetTDTJointConfig( + num_hidden_layers=self.num_hidden_layers, + hidden_size=self.hidden_size, + pred_hidden_size=self.enc_hidden_size, + enc_hidden_size=self.enc_hidden_size, + vocab_size=self.vocab_size, + durations=self.durations, + ) + + def create_and_check_model(self, config, enc, pred): + model = ParakeetTDTJoint(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(enc, pred) + + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, config.vocab_size + 1 + len(config.durations)) + ) + + def prepare_config_and_inputs_for_common(self): + config, enc, pred = self.prepare_config_and_inputs() + inputs_dict = { + "enc": enc, + "pred": pred, + } + return config, inputs_dict + + + + +@require_torch +class ParakeetTDTJointModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (ParakeetTDTJoint,) if is_torch_available() else () + + test_resize_embeddings = False + test_torch_exportable = True + 
has_attentions = False + is_encoder_decoder = False + + def setUp(self): + self.model_tester = ParakeetTDTJointModelTester(self) + self.config_tester = ConfigTester(self, config_class=ParakeetTDTJointConfig, has_text_modality=False, common_properties=['hidden_size']) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="this class doesn't have hidden states.") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="this class doesn't have hidden states.") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="ParakeetJoint does not use inputs_embeds") + def test_model_get_set_embeddings(self): + pass + + + class ParakeetForCTCModelTester: def __init__(self, parent, encoder_kwargs=None, is_training=True, vocab_size=128, pad_token_id=0): if encoder_kwargs is None: @@ -373,3 +606,195 @@ def test_1b_model_integration_batched(self): torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) + + + +class ParakeetForTDTModelTester: + def __init__(self, + parent, + encoder_kwargs=None, + decoder_kwargs=None, + joint_kwargs=None, + is_training=True, + vocab_size=128, + durations=[0,1,2,3,4], + pad_token_id=0 + ): + if encoder_kwargs is None: + encoder_kwargs = {} + if decoder_kwargs is None: + decoder_kwargs = {} + if joint_kwargs is None: + joint_kwargs = {} + + self.parent = parent + self.encoder_model_tester = ParakeetEncoderModelTester(parent, **encoder_kwargs) + self.decoder_model_tester = ParakeetTDTDecoderModelTester(parent, **decoder_kwargs) + self.joint_model_tester = ParakeetTDTJointModelTester(parent, **joint_kwargs) + self.is_training = is_training + + self.batch_size = self.encoder_model_tester.batch_size + self.output_seq_length = self.encoder_model_tester.output_seq_length + self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers + self.seq_length = vocab_size + self.enc_hidden_size = self.encoder_model_tester.hidden_size + self.hidden_size = self.encoder_model_tester.hidden_size # this field is needed for test class + self.pred_hidden_size = self.decoder_model_tester.hidden_size + self.joint_hidden_size = self.joint_model_tester.hidden_size + + self.durations = durations + + self.vocab_size = vocab_size + len(self.durations) + 1 + self.pad_token_id = pad_token_id + + + def prepare_config_and_inputs(self): + _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs() + config = self.get_config() + return config, input_features, attention_mask + + def get_config(self): + return ParakeetTDTConfig.from_configs( + encoder_config=self.encoder_model_tester.get_config(), + decoder_config=self.decoder_model_tester.get_config(), + joint_config=self.joint_model_tester.get_config(), + vocab_size=self.vocab_size, + durations=self.durations, + ) + + def create_and_check_model(self, config, input_features, attention_mask): + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_features, attention_mask=attention_mask) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size)) + + def 
prepare_config_and_inputs_for_common(self): + config, input_features, attention_mask = self.prepare_config_and_inputs() + inputs_dict = { + "input_features": input_features, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class ParakeetForTDTModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (ParakeetForTDT,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": ParakeetEncoder, + "automatic-speech-recognition": ParakeetForTDT, + } + if is_torch_available() + else {} + ) + + test_attention_outputs = False + + test_resize_embeddings = False + test_torch_exportable = True + + _is_composite = True + + def setUp(self): + self.model_tester = ParakeetForTDTModelTester(self) + self.config_tester = ConfigTester(self, config_class=ParakeetTDTConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="ParakeetEncoder does not use inputs_embeds") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="batching not supported") + def test_batching_equivalence(self): + pass + + # Original function assumes vision+text model, so overwrite since Parakeet is audio+text + # Below is modified from `tests/models/granite_speech/test_modeling_granite_speech.py` + def test_sdpa_can_dispatch_composite_models(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + +@require_torch +class ParakeetForTDTIntegrationTest(unittest.TestCase): + _dataset = None + + @classmethod + def setUp(cls): + cls.checkpoint_name = "hainanx/parakeet-tdt-0.6b-v3" + cls.dtype = torch.bfloat16 + cls.processor = AutoProcessor.from_pretrained("hainanx/parakeet-tdt-0.6b-v3") + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @classmethod + def _load_dataset(cls): + # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process. 
+ if cls._dataset is None: + cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + cls._dataset = cls._dataset.cast_column( + "audio", Audio(sampling_rate=cls.processor.feature_extractor.sampling_rate) + ) + + def _load_datasamples(self, num_samples): + self._load_dataset() + ds = self._dataset + speech_samples = ds.sort("id")[:num_samples]["audio"] + return [x["array"] for x in speech_samples] + + @slow + def test_1b_model_integration(self): + """ + bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py + eustlb reproducer: https://gist.github.com/eustlb/6e9e3aa85de3f7c340ec3c36e65f2fe6 + """ + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single.json" + with open(RESULTS_PATH, "r") as f: + raw_data = json.load(f) + EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) + EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] + + samples = self._load_datasamples(1) + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) + model.eval() + model.to(torch_device) + + # -- apply + inputs = self.processor(samples) + inputs.to(torch_device, dtype=self.dtype) + predicted_ids = model.generate(**inputs) + torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) + predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) + self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) From f2b493805d62a8f8e181ee8409aa51ccc7dc592a Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Fri, 20 Feb 2026 09:20:54 +0100 Subject: [PATCH 02/67] Add TDT decoder support for Parakeet ASR models Implement Token-and-Duration Transducer (TDT) decoding for Parakeet models, extending the existing CTC-only support. This adds ParakeetForTDT with greedy TDT decoding in generate(), per-token timestamp generation, and full integration with AutoModelForTDT, processors, and ASR pipeline. --- docs/source/en/model_doc/auto.md | 4 + docs/source/en/model_doc/parakeet.md | 52 ++ .../models/auto/configuration_auto.py | 10 +- .../models/auto/feature_extraction_auto.py | 2 +- src/transformers/models/auto/modeling_auto.py | 6 +- .../models/auto/processing_auto.py | 2 + .../models/auto/tokenization_auto.py | 2 + src/transformers/models/lasr/modeling_lasr.py | 4 + src/transformers/models/parakeet/__init__.py | 3 +- .../models/parakeet/configuration_parakeet.py | 187 ++--- .../models/parakeet/convert_nemo_to_hf.py | 243 ++++--- .../models/parakeet/modeling_parakeet.py | 664 ++++++++---------- .../models/parakeet/modular_parakeet.py | 636 ++++++++--------- src/transformers/pipelines/__init__.py | 2 +- .../pipelines/automatic_speech_recognition.py | 6 +- .../models/parakeet/test_modeling_parakeet.py | 399 ++--------- 16 files changed, 964 insertions(+), 1258 deletions(-) diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md index b45b3bfdb187..aaf4a240153b 100644 --- a/docs/source/en/model_doc/auto.md +++ b/docs/source/en/model_doc/auto.md @@ -217,6 +217,10 @@ The following auto classes are available for the following audio tasks. 
[[autodoc]] AutoModelForCTC +### AutoModelForTDT + +[[autodoc]] AutoModelForTDT + ### AutoModelForSpeechSeq2Seq [[autodoc]] AutoModelForSpeechSeq2Seq diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index b075e6d5ccf7..a758608482e3 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -34,6 +34,11 @@ Parakeet models, [introduced by NVIDIA NeMo](https://developer.nvidia.com/blog/p - 1D convolution projection from encoder hidden size to vocabulary size (for optimal NeMo compatibility). - CTC loss computation for training. - Greedy CTC decoding for inference. +- [**ParakeetForTDT**](#parakeetfortdt): a Fast Conformer Encoder + a TDT (Token-and-Duration Transducer) decoder + - **TDT Decoder**: Jointly predicts tokens and their durations, enabling efficient decoding: + - LSTM prediction network maintains language context across token predictions. + - Joint network combines encoder and decoder outputs. + - Duration head predicts how many frames to skip, enabling fast inference. The original implementation can be found in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo). Model checkpoints are to be found under [the NVIDIA organization](https://huggingface.co/nvidia/models?search=parakeet). @@ -81,6 +86,45 @@ print(processor.batch_decode(outputs)) +### TDT usage + + + + +```py +from transformers import pipeline + +pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-tdt-0.6b-v3") +out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") +print(out) +``` + + + + +```py +from transformers import AutoModelForTDT, AutoProcessor +from datasets import load_dataset, Audio +import torch + +device = "cuda" if torch.cuda.is_available() else "cpu" + +processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") +model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", dtype="auto", device_map=device) + +ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") +ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) +speech_samples = [el['array'] for el in ds["audio"][:5]] + +inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate) +inputs.to(model.device, dtype=model.dtype) +output = model.generate(**inputs, return_dict_in_generate=True) +print(processor.batch_decode(output.sequences, skip_special_tokens=True)) +``` + + + + ### Making The Model Go Brrr Parakeet supports full-graph compilation with CUDA graphs! This optimization is most effective when you know the maximum audio length you want to transcribe. The key idea is using static input shapes to avoid recompilation. For example, if you know your audio will be under 30 seconds, you can use the processor to pad all inputs to 30 seconds, preparing consistent input features and attention masks. See the example below!
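A minimal sketch of that static-shape recipe (assumptions: the feature extractor forwards `padding="max_length"` with `max_length` given in samples, and the encoder forward traces under `fullgraph=True`; the greedy TDT loop itself stays in eager mode):

```py
import torch
from datasets import Audio, load_dataset
from transformers import AutoModelForTDT, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3")
model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", dtype="auto", device_map=device)
model.eval()

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el["array"] for el in ds["audio"][:5]]

# Pad every utterance to a fixed 30 s so the compiled graph always sees the same input shapes
# (the `padding`/`max_length` kwargs are assumed to be forwarded to the feature extractor).
max_samples = 30 * processor.feature_extractor.sampling_rate
inputs = processor(
    speech_samples,
    sampling_rate=processor.feature_extractor.sampling_rate,
    padding="max_length",
    max_length=max_samples,
)
inputs = inputs.to(model.device, dtype=model.dtype)

# Compile the encoder-side forward once; `generate` calls it, then runs greedy TDT decoding in eager mode.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

with torch.no_grad():
    outputs = model.generate(**inputs)
print(processor.batch_decode(outputs, skip_special_tokens=True))
```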
@@ -212,6 +256,10 @@ outputs.loss.backward() [[autodoc]] ParakeetCTCConfig +## ParakeetTDTConfig + +[[autodoc]] ParakeetTDTConfig + ## ParakeetEncoder [[autodoc]] ParakeetEncoder @@ -219,3 +267,7 @@ outputs.loss.backward() ## ParakeetForCTC [[autodoc]] ParakeetForCTC + +## ParakeetForTDT + +[[autodoc]] ParakeetForTDT diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index af8ea68ebb26..321486d96866 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -325,10 +325,8 @@ ("paddleocr_vl", "PaddleOCRVLConfig"), ("paligemma", "PaliGemmaConfig"), ("parakeet_ctc", "ParakeetCTCConfig"), - ("parakeet_tdt", "ParakeetTDTConfig"), ("parakeet_encoder", "ParakeetEncoderConfig"), - ("parakeet_tdt_decoder", "ParakeetTDTDecoderConfig"), - ("parakeet_tdt_joint", "ParakeetTDTJointConfig"), + ("parakeet_tdt", "ParakeetTDTConfig"), ("patchtsmixer", "PatchTSMixerConfig"), ("patchtst", "PatchTSTConfig"), ("pe_audio", "PeAudioConfig"), @@ -826,10 +824,8 @@ ("paligemma", "PaliGemma"), ("parakeet", "Parakeet"), ("parakeet_ctc", "Parakeet"), - ("parakeet_tdt", "ParakeetTDT"), ("parakeet_encoder", "ParakeetEncoder"), - ("parakeet_tdt_decoder", "ParakeetTDTDecoder"), - ("parakeet_tdt_joint", "ParakeetTDTJoint"), + ("parakeet_tdt", "ParakeetTDT"), ("patchtsmixer", "PatchTSMixer"), ("patchtst", "PatchTST"), ("pe_audio", "PeAudio"), @@ -1089,8 +1085,6 @@ ("pe_audio_video_encoder", "pe_audio_video"), ("video_llama_3_vision", "video_llama_3"), ("parakeet_encoder", "parakeet"), - ("parakeet_tdt_decoder", "parakeet"), - ("parakeet_tdt_joint", "parakeet"), ("parakeet_ctc", "parakeet"), ("parakeet_tdt", "parakeet"), ("lw_detr_vit", "lw_detr"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index baf70fd306b1..a4cb2deae8ea 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -59,8 +59,8 @@ ("musicgen", "EncodecFeatureExtractor"), ("musicgen_melody", "MusicgenMelodyFeatureExtractor"), ("parakeet_ctc", "ParakeetFeatureExtractor"), - ("parakeet_tdt", "ParakeetFeatureExtractor"), ("parakeet_encoder", "ParakeetFeatureExtractor"), + ("parakeet_tdt", "ParakeetFeatureExtractor"), ("pe_audio", "PeAudioFeatureExtractor"), ("pe_audio_video", "PeAudioFeatureExtractor"), ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 39250122e5a8..e38ee21ca865 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -319,10 +319,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("owlvit", "OwlViTModel"), ("paligemma", "PaliGemmaModel"), ("parakeet_ctc", "ParakeetForCTC"), - ("parakeet_tdt", "ParakeetForTDT"), ("parakeet_encoder", "ParakeetEncoder"), - ("parakeet_tdt_decoder", "ParakeetTDTDecoder"), - ("parakeet_tdt_joint", "ParakeetTDTJoint"), + ("parakeet_tdt", "ParakeetForTDT"), ("patchtsmixer", "PatchTSMixerModel"), ("patchtst", "PatchTSTModel"), ("pe_audio", "PeAudioModel"), @@ -2136,12 +2134,14 @@ class AutoModelForCTC(_BaseAutoModelClass): AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") + class AutoModelForTDT(_BaseAutoModelClass): _model_mapping = MODEL_FOR_TDT_MAPPING AutoModelForTDT = 
auto_class_update(AutoModelForTDT, head_doc="token-and-duration transducer") + class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index ec8e460ac32a..200d9e89bef3 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -124,6 +124,8 @@ ("owlvit", "OwlViTProcessor"), ("paddleocr_vl", "PaddleOCRVLProcessor"), ("paligemma", "PaliGemmaProcessor"), + ("parakeet_ctc", "ParakeetProcessor"), + ("parakeet_tdt", "ParakeetProcessor"), ("perception_lm", "PerceptionLMProcessor"), ("phi4_multimodal", "Phi4MultimodalProcessor"), ("pix2struct", "Pix2StructProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 6b6ff939a50c..4a8797afd6d4 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -236,6 +236,8 @@ ("ovis2", "Qwen2Tokenizer" if is_tokenizers_available() else None), ("owlv2", "CLIPTokenizer" if is_tokenizers_available() else None), ("owlvit", "CLIPTokenizer" if is_tokenizers_available() else None), + ("parakeet_ctc", "ParakeetTokenizer" if is_tokenizers_available() else None), + ("parakeet_tdt", "ParakeetTokenizer" if is_tokenizers_available() else None), ("pegasus", "PegasusTokenizer" if is_tokenizers_available() else None), ("pegasus_x", "PegasusTokenizer" if is_tokenizers_available() else None), ("perceiver", "PerceiverTokenizer"), diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index 7ecea9099410..83623dcaf067 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -563,6 +563,9 @@ class LasrGenerateOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. + token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level timestamps in seconds indicating when each token was emitted. Only returned by TDT models + when `return_timestamps=True` is passed to `generate()`. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. 
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for @@ -576,6 +579,7 @@ class LasrGenerateOutput(ModelOutput): """ sequences: torch.LongTensor + token_timestamps: torch.FloatTensor | None = None logits: tuple[torch.FloatTensor] | None = None attentions: tuple[tuple[torch.FloatTensor]] | None = None hidden_states: tuple[tuple[torch.FloatTensor]] | None = None diff --git a/src/transformers/models/parakeet/__init__.py b/src/transformers/models/parakeet/__init__.py index 5c54b2e2eadb..e8bbfe7faf45 100644 --- a/src/transformers/models/parakeet/__init__.py +++ b/src/transformers/models/parakeet/__init__.py @@ -21,7 +21,8 @@ from .configuration_parakeet import * from .feature_extraction_parakeet import * from .modeling_parakeet import * - from .tokenization_parakeet_fast import * + from .processing_parakeet import * + from .tokenization_parakeet import * else: import sys diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 96d11ca012bb..256c4c30cc35 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -149,65 +149,6 @@ def __init__( ) - -class ParakeetTDTDecoderConfig(PreTrainedConfig): - model_type = "parakeet_tdt_decoder" - keys_to_ignore_at_inference = ["past_key_values"] - output_hidden_states = False - - def __init__( - self, - hidden_size=640, - num_hidden_layers=1, - dropout=0, - vocab_size=1024, - forget_gate_bias=1.0, - t_max=None, - weights_init_scale=1.0, - hidden_hidden_bias_scale=0, - **kwargs, - ): - super().__init__( - **kwargs, - ) - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.dropout = dropout - self.vocab_size = vocab_size - self.forget_gate_bias=forget_gate_bias - self.t_max=t_max - self.weights_init_scale=weights_init_scale - self.hidden_hidden_bias_scale=hidden_hidden_bias_scale - - -class ParakeetTDTJointConfig(PreTrainedConfig): - model_type = "parakeet_tdt_joint" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - enc_hidden_size=1024, - pred_hidden_size=640, - hidden_size=640, - vocab_size=1024, - durations=[0,1,2,3,4], - norm=None, - dropout=0.0, - activation='relu', - **kwargs, - ): - super().__init__( - **kwargs, - ) - self.enc_hidden_size = enc_hidden_size - self.pred_hidden_size = pred_hidden_size - self.hidden_size = hidden_size - self.vocab_size = vocab_size - self.durations = durations - self.dropout = dropout - self.activation = activation - - class ParakeetCTCConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`ParakeetForCTC`]. It is used to instantiate a @@ -289,82 +230,98 @@ def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): class ParakeetTDTConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ParakeetForTDT`]. It is used to instantiate a + Parakeet TDT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Parakeet TDT + [nvidia/parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. 
+ + Args: + vocab_size (`int`, *optional*, defaults to 8192): + Vocabulary size of the model. + decoder_hidden_size (`int`, *optional*, defaults to 640): + Hidden size of the LSTM prediction network and joint network. + num_decoder_layers (`int`, *optional*, defaults to 1): + Number of LSTM layers in the prediction network. + num_duration_bins (`int`, *optional*, defaults to 5): + Number of duration bins for predicting token durations. + hidden_act (`str`, *optional*, defaults to `"relu"`): + The activation function in the joint network. + max_symbols_per_step (`int`, *optional*, defaults to 10): + Maximum number of symbols to emit per encoder time step during greedy decoding. + seconds_per_frame (`float`, *optional*, defaults to 0.08): + Duration in seconds of each encoder output frame. Used for computing token timestamps. + Computed as `hop_length * subsampling_factor / sampling_rate` (e.g. 160 * 8 / 16000 = 0.08). + encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): + The config object or dictionary of the encoder. + pad_token_id (`int`, *optional*, defaults to 8192): + Padding token id. Also used as blank token id for TDT decoding. + + Example: + ```python + >>> from transformers import ParakeetForTDT, ParakeetTDTConfig + + >>> # Initializing a Parakeet TDT configuration + >>> configuration = ParakeetTDTConfig() + + >>> # Initializing a model from the configuration + >>> model = ParakeetForTDT(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ model_type = "parakeet_tdt" - sub_configs = {"encoder_config": ParakeetEncoderConfig, "decoder_config": ParakeetTDTDecoderConfig, "joint_config": ParakeetTDTJointConfig} + sub_configs = {"encoder_config": ParakeetEncoderConfig} def __init__( self, -# bos_token_id=1, -# eos_token_id=2, -# pad_token_id=1024, - tdt_loss_reduction="mean", - encoder_config: Union[dict, ParakeetEncoderConfig] = None, - decoder_config: Union[dict, ParakeetTDTDecoderConfig] = None, - joint_config: Union[dict, ParakeetTDTJointConfig] = None, + vocab_size=8192, + decoder_hidden_size=640, + num_decoder_layers=1, + num_duration_bins=5, + hidden_act="relu", + max_symbols_per_step=10, + seconds_per_frame=0.08, + encoder_config: dict | ParakeetEncoderConfig = None, + pad_token_id=8192, **kwargs, ): + self.vocab_size = vocab_size + self.decoder_hidden_size = decoder_hidden_size + self.num_decoder_layers = num_decoder_layers + self.num_duration_bins = num_duration_bins + self.hidden_act = hidden_act + self.max_symbols_per_step = max_symbols_per_step + self.seconds_per_frame = seconds_per_frame - if encoder_config is None: - self.encoder_config = ParakeetEncoderConfig() - elif isinstance(encoder_config, dict): + if isinstance(encoder_config, dict): self.encoder_config = ParakeetEncoderConfig(**encoder_config) - elif isinstance(encoder_config, ParakeetEncoderConfig): - self.encoder_config = encoder_config - else: - raise ValueError( - f"`encoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}" - ) - - if decoder_config is None: - self.decoder_config = ParakeetTDTDecoderConfig() - elif isinstance(decoder_config, dict): - self.decoder_config = ParakeetTDTDecoderConfig(**decoder_config) - elif isinstance(decoder_config, ParakeetTDTDecoderConfig): - self.decoder_config = decoder_config - else: - raise ValueError( - f"`decoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}" - ) - - if joint_config is None: - 
self.joint_config = ParakeetTDTJointConfig() - elif isinstance(joint_config, dict): - self.joint_config = ParakeetTDTJointConfig(**joint_config) - elif isinstance(joint_config, ParakeetTDTJointConfig): - self.joint_config = joint_config + elif encoder_config is None: + self.encoder_config = ParakeetEncoderConfig() else: - raise ValueError( - f"`decoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}" - ) + self.encoder_config = encoder_config - vocab_size = self.joint_config.vocab_size - self.vocab_size = vocab_size + self.initializer_range = self.encoder_config.initializer_range - self.blank_token_id = vocab_size super().__init__( -# pad_token_id=self.blank_token_id, + pad_token_id=pad_token_id, **kwargs, ) @classmethod - def from_configs( - cls, - encoder_config: ParakeetEncoderConfig, - decoder_config: ParakeetTDTDecoderConfig, - joint_config: ParakeetTDTJointConfig, - **kwargs): + def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): r""" - Instantiate a [`ParakeetConfig`] (or a derived class) from parakeet encoder model configuration. + Instantiate a [`ParakeetTDTConfig`] (or a derived class) from parakeet encoder model configuration. Returns: - [`ParakeetConfig`]: An instance of a configuration object + [`ParakeetTDTConfig`]: An instance of a configuration object """ + return cls(encoder_config=encoder_config.to_dict(), **kwargs) - return cls( - encoder_config=encoder_config.to_dict(), - decoder_config=decoder_config.to_dict(), - joint_config=joint_config.to_dict(), - **kwargs) -__all__ = ["ParakeetCTCConfig", "ParakeetTDTConfig", "ParakeetEncoderConfig", "ParakeetTDTDecoderConfig", "ParakeetTDTJointConfig"] +__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig", "ParakeetTDTConfig"] diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index b57a58a6eca1..51ea38214527 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -24,11 +24,12 @@ from transformers import ( ParakeetCTCConfig, - ParakeetTDTConfig, + ParakeetEncoderConfig, ParakeetFeatureExtractor, ParakeetForCTC, ParakeetForTDT, ParakeetProcessor, + ParakeetTDTConfig, ParakeetTokenizer, ) from transformers.convert_slow_tokenizer import ParakeetConverter @@ -48,6 +49,14 @@ r"linear_pos": r"relative_k_proj", } +# Additional mappings for TDT decoder and joint network +NEMO_TDT_WEIGHT_MAPPING = { + r"decoder\.prediction\.embed\.": r"decoder.embedding.", + r"decoder\.prediction\.dec_rnn\.lstm\.": r"decoder.lstm.", + r"joint\.enc\.": r"joint.encoder_projector.", + r"joint\.pred\.": r"decoder.decoder_projector.", +} + def convert_key(key, mapping): for pattern, replacement in mapping.items(): @@ -56,22 +65,12 @@ def convert_key(key, mapping): def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str]: - """ - Extract .nemo file (tar archive) and return paths to important files. - - Args: - nemo_file_path: Path to .nemo file - extract_dir: Directory to extract to - - Returns: - Dictionary with paths to model.pt, model_config.yaml, etc. 
- """ + """Extract .nemo file (tar archive) and return paths to important files.""" print(f"Extracting NeMo archive: {nemo_file_path}") with tarfile.open(nemo_file_path, "r", encoding="utf-8") as tar: tar.extractall(extract_dir) - # Log all extracted files for debugging all_files = [] for root, dirs, files in os.walk(extract_dir): for file in files: @@ -80,14 +79,12 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str print(f"All extracted files: {[os.path.basename(f) for f in all_files]}") - # Find important files with more robust detection model_files = {} for root, dirs, files in os.walk(extract_dir): for file in files: file_path = os.path.join(root, file) file_lower = file.lower() - # Look for model weights with various common names if ( file.endswith(".pt") or file.endswith(".pth") @@ -102,26 +99,23 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str model_files["model_weights"] = file_path print(f"Found model weights: {file}") - # Look for config files elif ( file == "model_config.yaml" or file == "config.yaml" or (file.endswith(".yaml") and "config" in file_lower) ): - if "model_config" not in model_files: # Prefer model_config.yaml + if "model_config" not in model_files: model_files["model_config"] = file_path print(f"Found config file: {file}") if file == "model_config.yaml": - model_files["model_config"] = file_path # Override with preferred name + model_files["model_config"] = file_path - # Look for vocabulary files elif ( file.endswith(".vocab") or file.endswith(".model") or file.endswith(".txt") or ("tokenizer" in file_lower and (file.endswith(".vocab") or file.endswith(".model"))) ): - # Prefer .vocab files over others if "tokenizer_model_file" not in model_files or file.endswith(".model"): model_files["tokenizer_model_file"] = file_path print(f"Found tokenizer model file: {file}") @@ -130,7 +124,6 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str print(f"Found model files: {list(model_files.keys())}") - # Validate that we found the required files if "model_weights" not in model_files: raise FileNotFoundError( f"Could not find model weights file in {nemo_file_path}. 
" @@ -223,8 +216,8 @@ def convert_encoder_config(nemo_config): "conv_context_size", "dropout_pre_encoder", "reduction", - "reduction_position", "reduction_factor", + "reduction_position", ] encoder_config_keys_mapping = { "d_model": "hidden_size", @@ -243,62 +236,16 @@ def convert_encoder_config(nemo_config): } converted_encoder_config = {} - decoder_keys_to_ignore = [ - "_target_", - "normalization_mode", - "random_state_sampling", - "blank_as_pad", - "prednet", - ] - decoder_config_keys_mapping = { - "vocab_size": "vocab_size", - } - converted_decoder_config = {} - - joint_keys_to_ignore = [ - "_target_", - 'log_softmax', - 'preserve_memory', - 'fuse_loss_wer', - 'fused_batch_size', - 'jointnet', - 'vocabulary' - ] - joint_config_keys_mapping = { - "vocab_size": "vocab_size", - "num_classes": "num_classes", - "num_extra_outputs": "num_extra_outputs", - } - converted_joint_config = {} - - for key, value in nemo_config["encoder"].items(): if key in encoder_keys_to_ignore: continue if key in encoder_config_keys_mapping: converted_encoder_config[encoder_config_keys_mapping[key]] = value + if key == "use_bias": + converted_encoder_config["convolution_bias"] = value else: raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") - if model_type == 'tdt': - for key, value in nemo_config["decoder"].items(): - if key in decoder_keys_to_ignore: - continue - if key in decoder_config_keys_mapping: - converted_decoder_config[decoder_config_keys_mapping[key]] = value - else: - raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") - - for key, value in nemo_config["joint"].items(): - if key in joint_keys_to_ignore: - continue - if key in joint_config_keys_mapping: - converted_joint_config[joint_config_keys_mapping[key]] = value - else: - raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") - - converted_joint_config["vocab_size"] = converted_joint_config["num_classes"] - return ParakeetEncoderConfig(**converted_encoder_config) @@ -307,7 +254,6 @@ def load_and_convert_state_dict(model_files): state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True) converted_state_dict = {} for key, value in state_dict.items(): - # Skip preprocessing weights (featurizer components) if key.endswith("featurizer.window") or key.endswith("featurizer.fb"): print(f"Skipping preprocessing weight: {key}") continue @@ -331,39 +277,142 @@ def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_re print("Saving the model.") model.save_pretrained(output_dir) - elif model_type == "tdt": - num_classes = converted_joint_config["num_classes"] - model_config = ParakeetTDTConfig( - pad_token_id=num_classes, - vocab_size=num_classes+1, - blank_token_id=num_classes, - encoder_config=converted_encoder_config, - decoder_config=converted_decoder_config, - joint_config=converted_joint_config, - ) - print("Loading the checkpoint in a Parakeet TDT model.") - with torch.device("meta"): - model = ParakeetForTDT(model_config) - model.load_state_dict(converted_state_dict, strict=True, assign=True) - print("Checkpoint loaded successfully.") - del model.config._name_or_path + if push_to_repo_id: + model.push_to_hub(push_to_repo_id) - print("Saving the model.") - model.save_pretrained(output_dir) + del model - if push_to_repo_id: - model.push_to_hub(push_to_repo_id) + gc.collect() + print("Reloading the model to check if it's saved correctly.") + ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") + print("Model 
reloaded successfully.") - del converted_state_dict, model - # Safety check: reload the converted model - gc.collect() - print("Reloading the model to check if it's saved correctly.") - ParakeetForTDT.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") - print("Model reloaded successfully.") +def convert_tdt_config(nemo_config, encoder_config): + """Convert NeMo TDT config to HF TDT config.""" + decoder_config = nemo_config.get("decoder", {}) + decoding_config = nemo_config.get("decoding", {}) + labels = nemo_config.get("labels", []) + vocab_size = len(labels) if labels else decoder_config.get("vocab_size", 1024) + prednet = decoder_config.get("prednet", {}) + decoder_hidden_size = prednet.get("pred_hidden", 640) + num_decoder_layers = prednet.get("pred_rnn_layers", 2) + + durations = decoding_config.get("durations", [0, 1, 2, 3, 4]) + num_duration_bins = len(durations) + + preprocessor = nemo_config.get("preprocessor", {}) + sample_rate = preprocessor.get("sample_rate", 16000) + window_stride = preprocessor.get("window_stride", 0.01) + hop_length = int(window_stride * sample_rate) + subsampling_factor = encoder_config.subsampling_factor + seconds_per_frame = (hop_length * subsampling_factor) / sample_rate + + print( + f"TDT config: vocab_size={vocab_size}, decoder_hidden={decoder_hidden_size}, " + f"decoder_layers={num_decoder_layers}, num_durations={num_duration_bins}, " + f"seconds_per_frame={seconds_per_frame}" + ) + return ParakeetTDTConfig( + vocab_size=vocab_size, + decoder_hidden_size=decoder_hidden_size, + num_decoder_layers=num_decoder_layers, + num_duration_bins=num_duration_bins, + hidden_act="relu", + max_symbols_per_step=10, + seconds_per_frame=seconds_per_frame, + encoder_config=encoder_config.to_dict(), + pad_token_id=vocab_size, + ) + + +def load_and_convert_tdt_state_dict(model_files, vocab_size, num_duration_bins): + """Load NeMo TDT state dict and convert keys to HF format, splitting combined head.""" + state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True) + converted_state_dict = {} + + all_mappings = {**NEMO_TO_HF_WEIGHT_MAPPING, **NEMO_TDT_WEIGHT_MAPPING} + + for key, value in state_dict.items(): + if key.endswith("featurizer.window") or key.endswith("featurizer.fb"): + print(f"Skipping preprocessing weight: {key}") + continue + + # Handle combined output head split + if key == "joint.joint_net.2.weight": + token_weight = value[: vocab_size + 1, :] + duration_weight = value[vocab_size + 1 :, :] + converted_state_dict["joint.token_head.weight"] = token_weight + converted_state_dict["joint.duration_head.weight"] = duration_weight + print(f"Split combined weight: token_head {token_weight.shape}, duration_head {duration_weight.shape}") + continue + + if key == "joint.joint_net.2.bias": + token_bias = value[: vocab_size + 1] + duration_bias = value[vocab_size + 1 :] + converted_state_dict["joint.token_head.bias"] = token_bias + converted_state_dict["joint.duration_head.bias"] = duration_bias + print(f"Split combined bias: token_head {token_bias.shape}, duration_head {duration_bias.shape}") + continue + + converted_key = convert_key(key, all_mappings) + converted_state_dict[converted_key] = value + + return converted_state_dict + + +def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id=None): + """Write TDT model using encoder config, TDT config, and converted state dict.""" + model_config = convert_tdt_config(nemo_config, encoder_config) + print(f"Converted TDT config: 
{model_config}") + + converted_state_dict = load_and_convert_tdt_state_dict( + model_files, model_config.vocab_size, model_config.num_duration_bins + ) + + print("Loading the checkpoint in a Parakeet TDT model.") + with torch.device("meta"): + model = ParakeetForTDT(model_config) + + missing_keys, unexpected_keys = model.load_state_dict(converted_state_dict, strict=False, assign=True) + + if missing_keys: + print(f"Warning: Missing keys: {missing_keys}") + if unexpected_keys: + print(f"Warning: Unexpected keys: {unexpected_keys}") + + if not missing_keys and not unexpected_keys: + print("All weights loaded successfully!") + + del model.config._name_or_path + + print("Saving the model.") + model.save_pretrained(output_dir) + + if push_to_repo_id: + model.push_to_hub(push_to_repo_id) + + del model + + gc.collect() + print("Reloading the model to check if it's saved correctly.") + ParakeetForTDT.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto") + print("Model reloaded successfully.") + + +def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None): + """Main model conversion function.""" + encoder_config = convert_encoder_config(nemo_config) + print(f"Converted encoder config: {encoder_config}") + + if model_type == "ctc": + converted_state_dict = load_and_convert_state_dict(model_files) + write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id) + elif model_type == "tdt": + write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id) else: raise ValueError(f"Model type {model_type} not supported.") @@ -387,7 +436,7 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co") - parser.add_argument("--model_type", required=True, choices=["ctc","tdt"], help="Model type (`ctc`, `tdt`)") + parser.add_argument("--model_type", required=True, choices=["ctc", "tdt"], help="Model type (`ctc`, `tdt`)") parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model") parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub") args = parser.parse_args() diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 76251433ee92..91c9aea5003c 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -29,19 +29,13 @@ from ...activations import ACT2FN from ...integrations import use_kernel_func_from_hub, use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithNoAttention, CausalLMOutput +from ...modeling_outputs import BaseModelOutput, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs -from .configuration_parakeet import ( - ParakeetCTCConfig, - ParakeetEncoderConfig, - ParakeetTDTConfig, - ParakeetTDTDecoderConfig, - ParakeetTDTJointConfig, - PreTrainedConfig, -) +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_torchaudio_available +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs 
+from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig @dataclass @@ -138,7 +132,7 @@ def __init__(self, config: ParakeetEncoderConfig, module_config=None): self.activation = ACT2FN[module_config.get("activation", "silu")] self.padding = (kernel_size - 1) // 2 self.pointwise_conv1 = nn.Conv1d( - channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias + channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias ) self.depthwise_conv = nn.Conv1d( channels, @@ -147,11 +141,11 @@ def __init__(self, config: ParakeetEncoderConfig, module_config=None): stride=1, padding=self.padding, groups=channels, - bias=config.attention_bias, + bias=config.convolution_bias, ) self.norm = nn.BatchNorm1d(channels) self.pointwise_conv2 = nn.Conv1d( - channels, channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias + channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias ) def forward(self, hidden_states, attention_mask=None): @@ -480,7 +474,7 @@ def forward( @auto_docstring class ParakeetPreTrainedModel(PreTrainedModel): - config: PreTrainedConfig + config: ParakeetCTCConfig base_model_prefix = "model" main_input_name = "input_features" input_modalities = "audio" @@ -515,17 +509,21 @@ def _init_weights(self, module): init.normal_(module.bias_u, mean=0.0, std=std) init.normal_(module.bias_v, mean=0.0, std=std) elif isinstance(module, ParakeetEncoderRelPositionalEncoding): + encoder_config = getattr(self.config, "encoder_config", self.config) inv_freq = 1.0 / ( - 10000.0 ** (torch.arange(0, self.config.hidden_size, 2, dtype=torch.int64) / self.config.hidden_size) + 10000.0 + ** (torch.arange(0, encoder_config.hidden_size, 2, dtype=torch.int64) / encoder_config.hidden_size) ) init.copy_(module.inv_freq, inv_freq) + elif isinstance(module, nn.LSTM): + for name, param in module.named_parameters(): + if "weight" in name: + init.normal_(param, mean=0.0, std=std) + elif "bias" in name: + init.zeros_(param) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): - encoder_config = ( - self.config.encoder_config - if isinstance(self.config, (ParakeetCTCConfig, ParakeetTDTConfig)) - else self.config - ) + encoder_config = getattr(self.config, "encoder_config", self.config) kernel_size = encoder_config.subsampling_conv_kernel_size stride = encoder_config.subsampling_conv_stride @@ -625,6 +623,7 @@ def forward( position_embeddings, p=self.dropout_positions, training=self.training ) + output_mask = None if attention_mask is not None: output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1]) attention_mask = output_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1) @@ -648,7 +647,8 @@ def forward( ) return ParakeetEncoderModelOutput( - last_hidden_state=hidden_states, attention_mask=output_mask.int() if output_attention_mask else None + last_hidden_state=hidden_states, + attention_mask=output_mask.int() if output_attention_mask and output_mask is not None else None, ) @@ -661,6 +661,9 @@ class ParakeetGenerateOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. 
+ token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level timestamps in seconds indicating when each token was emitted. Only returned by TDT models + when `return_timestamps=True` is passed to `generate()`. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for @@ -674,275 +677,12 @@ class ParakeetGenerateOutput(ModelOutput): """ sequences: torch.LongTensor + token_timestamps: torch.FloatTensor | None = None logits: tuple[torch.FloatTensor] | None = None attentions: tuple[tuple[torch.FloatTensor]] | None = None hidden_states: tuple[tuple[torch.FloatTensor]] | None = None -class ParakeetLSTM(torch.nn.Module): - def __init__( - self, - input_size: int, - hidden_size: int, - num_layers: int, - dropout: Optional[float], - forget_gate_bias: Optional[float], - t_max: Optional[int] = None, - weights_init_scale: float = 1.0, - hidden_hidden_bias_scale: float = 0.0, - proj_size: int = 0, - ): - """Returns an LSTM with forget gate bias init to `forget_gate_bias`. - Args: - input_size: See `torch.nn.LSTM`. - hidden_size: See `torch.nn.LSTM`. - num_layers: See `torch.nn.LSTM`. - dropout: See `torch.nn.LSTM`. - - forget_gate_bias: float, set by default to 1.0, which constructs a forget gate - initialized to 1.0. - Reference: - [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) - - t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization - of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course - of training. - Reference: - [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) - - weights_init_scale: Float scale of the weights after initialization. Setting to lower than one - sometimes helps reduce variance between runs. - - hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for - the default behaviour. - - Returns: - A `torch.nn.LSTM`. 
- """ - super(ParakeetLSTM, self).__init__() - - self.lstm = torch.nn.LSTM( - input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, proj_size=proj_size - ) - - if t_max is not None: - # apply chrono init - for name, v in self.lstm.named_parameters(): - if "bias" in name: - p = getattr(self.lstm, name) - n = p.nelement() - hidden_size = n // 4 - p.data.fill_(0) - p.data[hidden_size : 2 * hidden_size] = torch.log( - torch.nn.init.uniform_(p.data[0:hidden_size], 1, t_max - 1) - ) - # forget gate biases = log(uniform(1, Tmax-1)) - p.data[0:hidden_size] = -p.data[hidden_size : 2 * hidden_size] - # input gate biases = -(forget gate biases) - - elif forget_gate_bias is not None: - for name, v in self.lstm.named_parameters(): - if "bias_ih" in name: - bias = getattr(self.lstm, name) - bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias) - if "bias_hh" in name: - bias = getattr(self.lstm, name) - bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale) - - self.dropout = torch.nn.Dropout(dropout) if dropout else None - - for name, v in self.named_parameters(): - if "weight" in name or "bias" in name: - v.data *= float(weights_init_scale) - - def forward( - self, x: torch.Tensor, h: Optional[tuple[torch.Tensor, torch.Tensor]] = None - ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - x, h = self.lstm(x, h) - - if self.dropout: - x = self.dropout(x) - - return x, h - - -class ParakeetTDTJoint(ParakeetPreTrainedModel): - config: ParakeetTDTJointConfig - base_model_prefix = "" # joint" - main_input_name = "enc" - _supports_flat_attention_mask = False - _supports_sdpa = True - _supports_flex_attn = False - _supports_attention_backend = False - _can_record_outputs = {} - _no_split_modules = None - - def __init__(self, config: ParakeetTDTJointConfig): - super().__init__(config) - self.config = config - self.gradient_checkpointing = False - - self.enc = torch.nn.Linear(config.enc_hidden_size, config.hidden_size) - self.pred = torch.nn.Linear(config.pred_hidden_size, config.hidden_size) - - num_classes = config.vocab_size + 1 + len(config.durations) - - layers = ( - [torch.nn.ReLU(inplace=True)] - + ([torch.nn.Dropout(p=self.config.dropout)]) - + [torch.nn.Linear(config.hidden_size, num_classes)] - ) - self.joint_net = torch.nn.Sequential(*layers) - self.post_init() - - @auto_docstring - @check_model_inputs() - def forward( - self, - enc: torch.Tensor, - pred: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithNoAttention: - # Right now we only support joint for inference. 
- - pred = pred.view([-1, self.config.pred_hidden_size]) # making it B, D - enc = enc.view([-1, self.config.enc_hidden_size]) # making it B, D - enc = self.enc(enc) - pred = self.pred(pred) - - assert enc.shape[0] == pred.shape[0] - output = self.joint_net(enc + pred) - return BaseModelOutput(last_hidden_state=output) - - -class ParakeetTDTPredictor(ParakeetPreTrainedModel): - def __init__(self, config: ParakeetTDTDecoderConfig): - super().__init__(config) - self.gradient_checkpointing = False - self.config = config - - self.embed = torch.nn.Embedding(config.vocab_size + 1, config.hidden_size) # +1 for blank - self.dec_rnn = self.rnn( - config.hidden_size, - config.hidden_size, - config.num_hidden_layers + 1, - config.forget_gate_bias, - config.dropout, - config.t_max, - config.weights_init_scale, - config.hidden_hidden_bias_scale, - ) - self.post_init() - - def rnn( - self, - input_size: int, - hidden_size: int, - num_layers: int, - forget_gate_bias: Optional[float] = 1.0, - dropout: Optional[float] = 0.0, - t_max: Optional[int] = None, - weights_init_scale: float = 1.0, - hidden_hidden_bias_scale: float = 0.0, - proj_size: int = 0, - ) -> torch.nn.Module: - return ParakeetLSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - dropout=dropout, - forget_gate_bias=forget_gate_bias, - t_max=t_max, - weights_init_scale=weights_init_scale, - hidden_hidden_bias_scale=hidden_hidden_bias_scale, - proj_size=proj_size, - ) - - @auto_docstring - @check_model_inputs() - @can_return_tuple - def forward( - self, - input_token, - states, - hidden_state=None, - **kwargs: Unpack[TransformersKwargs], - ): - assert input_token is not None - - device = self.embed.weight.device - if input_token.device != device: - input_token = input_token.to(device) - return self.predict(input_token, state=states) - - def predict(self, y, state): - # Get device and dtype of current module - - # (B, U) -> (B, U, H) - y = self.embed(y).transpose(0, 1) # (U + 1, B, H) - - g, hid = self.dec_rnn(y, state) - g = g.transpose(0, 1).transpose(1, 2) # (B, H, U + 1) - - return g, hid - - -@auto_docstring( - custom_intro=""" - The Parakeet TDT Decoder. This class encapsulates both the predictor and joint network for TDT models. 
- """ -) -class ParakeetTDTDecoder(ParakeetPreTrainedModel): - config: ParakeetTDTDecoderConfig - base_model_prefix = "decoder" - main_input_name = "input_token" - _supports_flat_attention_mask = False - _supports_sdpa = True - _supports_flex_attn = False - _supports_attention_backend = False - _can_record_outputs = {} - _no_split_modules = None - - def __init__(self, config: ParakeetTDTDecoderConfig): - super().__init__(config) - self.config = config - self.gradient_checkpointing = False - self.prediction = ParakeetTDTPredictor(config) - self.post_init() - - def _init_weights(self, module): - if hasattr(self.config, "initializer_range"): - std = self.config.initializer_range - else: - # 0.02 is the standard default value accross the library - std = getattr(self.config.get_text_config(), "initializer_range", 0.02) - - module.prediction.embed.weight.data.normal_(mean=0.0, std=std) - for param in module.prediction.dec_rnn.lstm.parameters(): - param.data.normal_(mean=0.0, std=std) - - def get_input_embeddings(self): - return self.prediction.embed - - def set_input_embeddings(self, embed): - self.prediction.embed = embed - - @auto_docstring - @check_model_inputs() - @can_return_tuple - def forward( - self, - input_token, - hidden_state=None, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithNoAttention: - if hidden_state is not None: - hidden_state = tuple(hidden_state.unbind(dim=0)) - - h_out, h_state = self.prediction(input_token, hidden_state, **kwargs) - return BaseModelOutputWithNoAttention(h_out, torch.stack(h_state, dim=0)) - - @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -1087,9 +827,58 @@ def generate( return sequences +class ParakeetTDTDecoder(nn.Module): + """LSTM-based prediction network for TDT.""" + + def __init__(self, config: ParakeetTDTConfig): + super().__init__() + self.embedding = nn.Embedding(config.vocab_size + 1, config.decoder_hidden_size) + self.lstm = nn.LSTM( + input_size=config.decoder_hidden_size, + hidden_size=config.decoder_hidden_size, + num_layers=config.num_decoder_layers, + batch_first=True, + ) + self.decoder_projector = nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size) + + def forward( + self, + input_ids: torch.LongTensor, + hidden_state: torch.Tensor | None = None, + cell_state: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + embeddings = self.embedding(input_ids) + lstm_state = (hidden_state, cell_state) if hidden_state is not None else None + lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, lstm_state) + decoder_output = self.decoder_projector(lstm_output) + return decoder_output, hidden_state, cell_state + + +class ParakeetTDTJointNetwork(nn.Module): + """Joint network that combines encoder and decoder outputs to predict tokens and durations.""" + + def __init__(self, config: ParakeetTDTConfig): + super().__init__() + self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) + self.activation = ACT2FN[config.hidden_act] + self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size + 1) + self.duration_head = nn.Linear(config.decoder_hidden_size, config.num_duration_bins) + + def forward( + self, + encoder_output: torch.Tensor, + decoder_output: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + encoder_projected = self.encoder_projector(encoder_output) + joint_output = self.activation(encoder_projected + decoder_output) + token_logits = 
self.token_head(joint_output) + duration_logits = self.duration_head(joint_output) + return token_logits, duration_logits + + @auto_docstring( custom_intro=""" - Parakeet TDT model. + Parakeet model with TDT (Token Duration Transducer) head for speech recognition. """ ) class ParakeetForTDT(ParakeetPreTrainedModel): @@ -1098,10 +887,9 @@ class ParakeetForTDT(ParakeetPreTrainedModel): def __init__(self, config: ParakeetTDTConfig): super().__init__(config) self.encoder = ParakeetEncoder(config.encoder_config) - self.decoder = ParakeetTDTDecoder(config.decoder_config) - self.joint = ParakeetTDTJoint(config.joint_config) - self.blank_token_id = config.blank_token_id - self.max_token_per_frame = 2 + self.decoder = ParakeetTDTDecoder(config) + self.joint = ParakeetTDTJointNetwork(config) + self.post_init() @auto_docstring @@ -1109,20 +897,86 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], - ): + ) -> CausalLMOutput: + r""" + Example: + + ```python + >>> from transformers import AutoProcessor, ParakeetForTDT + >>> from datasets import load_dataset, Audio + + >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = ParakeetForTDT.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + >>> inputs = processor(ds[0]["audio"]["array"]) + >>> outputs = model(**inputs) + ``` + """ encoder_outputs = self.encoder( input_features=input_features, + attention_mask=attention_mask, **kwargs, ) - logits = self.joint.joint_net( - self.joint.enc(encoder_outputs.last_hidden_state) - ) # [:,:,:self.joint.vocab_size] + encoder_hidden_states = encoder_outputs.last_hidden_state + + loss = None + if labels is not None: + if not is_torchaudio_available(): + raise ImportError( + "torchaudio is required for TDT loss computation. 
Install it with: pip install torchaudio" + ) + from torchaudio.functional import rnnt_loss + + # Compute encoder output lengths + attention_mask = ( + attention_mask + if attention_mask is not None + else torch.ones(input_features.shape[:-1], dtype=torch.long, device=input_features.device) + ) + encoder_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) + + # Compute target lengths (non-pad tokens) + labels_mask = labels != self.config.pad_token_id + target_lengths = labels_mask.sum(-1) + + # Prepare decoder input: prepend blank token to labels + blank_tokens = torch.full( + (labels.shape[0], 1), self.config.pad_token_id, dtype=labels.dtype, device=labels.device + ) + decoder_input = torch.cat([blank_tokens, labels], dim=1) + + # Run decoder on full label sequence: (batch, U+1, decoder_hidden_size) + decoder_output, _, _ = self.decoder(decoder_input) + + # Compute joint output for all (T, U+1) pairs via broadcasting + # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) + # decoder: (batch, 1, U+1, decoder_hidden_size) + token_logits, _ = self.joint( + encoder_hidden_states.unsqueeze(2), + decoder_output.unsqueeze(1), + ) + # token_logits: (batch, T, U+1, vocab_size+1) + + loss = rnnt_loss( + logits=token_logits.float(), + targets=labels.int(), + logit_lengths=encoder_lengths.int(), + target_lengths=target_lengths.int(), + blank=self.config.pad_token_id, + reduction="mean", + ) return CausalLMOutput( - loss=torch.sum(encoder_outputs.last_hidden_state), # a fake loss here. - logits=logits, + loss=loss, + logits=encoder_hidden_states, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) @@ -1131,66 +985,166 @@ def forward( def generate( self, input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + return_timestamps: bool = False, + return_dict_in_generate: bool = False, **kwargs: Unpack[TransformersKwargs], - ): - encoder_outputs = self.encoder( + ) -> ParakeetGenerateOutput | torch.LongTensor: + r""" + Perform TDT greedy decoding to generate token sequences. + + Args: + return_timestamps (`bool`, *optional*, defaults to `False`): + Whether to return per-token timestamps in seconds. When `True`, forces + `return_dict_in_generate=True` and includes `token_timestamps` in the output. 
+ + Example: + + ```python + >>> from transformers import AutoProcessor, ParakeetForTDT + >>> from datasets import load_dataset, Audio + + >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = ParakeetForTDT.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + >>> inputs = processor(ds[0]["audio"]["array"]) + >>> output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) + + >>> transcription = processor.batch_decode(output.sequences, skip_special_tokens=True) + >>> print(transcription) + >>> print(output.token_timestamps) + ``` + """ + if return_timestamps: + return_dict_in_generate = True + + blank_id = self.config.pad_token_id + max_symbols_per_step = self.config.max_symbols_per_step + device = input_features.device + batch_size = input_features.shape[0] + + kwargs["return_dict"] = True + outputs: CausalLMOutput = self( input_features=input_features, + attention_mask=attention_mask, **kwargs, ) - output = self.greedy_decode(encoder_outputs.last_hidden_state) - - return output - - def greedy_decode(self, encoder_output): - T = encoder_output.shape[1] - t = 0 - hyp = [] - last_label = torch.LongTensor([[self.blank_token_id]]) - dec_out = self.decoder(input_token=last_label) - g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states - - symbols_added = 0 - while t < T: - enc = encoder_output[0, t, :] - while symbols_added < self.max_token_per_frame: - logits = self.joint(enc, g).last_hidden_state - - logits = logits.view([-1]) - - token_logits = logits[: self.blank_token_id + 1].softmax(-1) - duration_logits = logits[self.blank_token_id + 1 :].softmax(-1) - - v, token = token_logits.max(-1) - v_duration, duration = duration_logits.max(-1) - token = token.item() - duration = duration.item() - - if token != self.blank_token_id: - hyp.append(token) - last_label = token - last_label = torch.LongTensor([[last_label]]) - dec_out = self.decoder(last_label, hidden_prime) - g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states - - if duration == 0: + encoder_hidden_states = outputs.logits + + sequence_length = encoder_hidden_states.shape[1] + if attention_mask is not None: + encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) + valid_lengths = encoder_attention_mask.sum(dim=1).int() + else: + valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) + + # Initialize decoder LSTM state + hidden_state = torch.zeros( + self.config.num_decoder_layers, + batch_size, + self.config.decoder_hidden_size, + device=device, + dtype=encoder_hidden_states.dtype, + ) + cell_state = torch.zeros_like(hidden_state) + + # Initialize with blank token + prev_tokens = torch.full((batch_size, 1), blank_id, dtype=torch.long, device=device) + decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) + + all_tokens = [[] for _ in range(batch_size)] + token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None + time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) + active_mask = time_indices < valid_lengths + + while active_mask.any(): + safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) + encoder_frames = encoder_hidden_states[ + torch.arange(batch_size, 
device=device), safe_time_indices + ].unsqueeze(1) + + symbols_added = 0 + while symbols_added < max_symbols_per_step: + token_logits, duration_logits = self.joint(encoder_frames, decoder_output) + token_logits = token_logits.squeeze(1) + duration_logits = duration_logits.squeeze(1) + + tokens = token_logits.argmax(dim=-1) + durations = duration_logits.argmax(dim=-1) + + is_blank = tokens == blank_id + emit_mask = active_mask & ~is_blank + + for i in range(batch_size): + if emit_mask[i]: + all_tokens[i].append(tokens[i].item()) + if token_frame_indices is not None: + token_frame_indices[i].append(time_indices[i].item()) + + if emit_mask.any(): + new_prev_tokens = tokens.unsqueeze(1) + new_decoder_output, new_hidden_state, new_cell_state = self.decoder( + new_prev_tokens, hidden_state, cell_state + ) + + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) + + emit_mask_state = emit_mask.view(1, batch_size, 1) + hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) + cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) + + # If duration is 0, stay on same frame (emit more tokens) + stay_mask = active_mask & (durations == 0) + if stay_mask.any(): symbols_added += 1 - else: - t += duration - symbols_added = 0 - break - - if symbols_added == self.max_token_per_frame: - t += 1 - symbols_added = 0 - - return hyp - - -__all__ = [ - "ParakeetForCTC", - "ParakeetForTDT", - "ParakeetEncoder", - "ParakeetTDTDecoder", - "ParakeetTDTJoint", - "ParakeetPreTrainedModel", -] + if symbols_added >= max_symbols_per_step: + time_indices = time_indices + 1 + break + continue + + # Duration > 0: advance time + time_indices = time_indices + torch.where(active_mask, durations, torch.zeros_like(durations)) + break + + active_mask = time_indices < valid_lengths + + # Pad sequences to same length + max_len = max((len(seq) for seq in all_tokens), default=0) + if max_len == 0: + max_len = 1 + + sequences = torch.full((batch_size, max_len), self.config.pad_token_id, dtype=torch.long, device=device) + for i in range(batch_size): + seq_len = len(all_tokens[i]) + if seq_len > 0: + sequences[i, :seq_len] = torch.tensor(all_tokens[i], dtype=torch.long, device=device) + + token_timestamps = None + if return_timestamps: + seconds_per_frame = self.config.seconds_per_frame + token_timestamps = torch.full((batch_size, max_len), 0.0, dtype=torch.float, device=device) + for i in range(batch_size): + num_tokens = len(token_frame_indices[i]) + if num_tokens > 0: + token_timestamps[i, :num_tokens] = ( + torch.tensor(token_frame_indices[i], dtype=torch.float, device=device) * seconds_per_frame + ) + + if return_dict_in_generate: + return ParakeetGenerateOutput( + sequences=sequences, + token_timestamps=token_timestamps, + logits=None, + attentions=outputs.attentions, + hidden_states=outputs.hidden_states, + ) + + return sequences + + +__all__ = ["ParakeetForCTC", "ParakeetForTDT", "ParakeetEncoder", "ParakeetPreTrainedModel"] diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 594ea73e3f8a..0329443e1902 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -16,7 +16,6 @@ import math from collections.abc import Callable from dataclasses import dataclass -from typing import Optional, Union, Tuple import torch from torch import nn @@ -24,15 +23,15 @@ from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, CausalLMOutput, BaseModelOutputWithNoAttention +from ...modeling_outputs import BaseModelOutput, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_torchaudio_available from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule from ..llama.modeling_llama import LlamaAttention, eager_attention_forward -from .configuration_parakeet import PreTrainedConfig, ParakeetCTCConfig, ParakeetTDTConfig, ParakeetEncoderConfig, ParakeetTDTDecoderConfig, ParakeetTDTJointConfig +from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig @dataclass @@ -122,7 +121,9 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): super().__init__(config, layer_idx=layer_idx) self.is_causal = False # W_{k,R} projection - self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.relative_k_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) # global content bias self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim)) # global positional bias @@ -312,7 +313,7 @@ def forward( @auto_docstring class ParakeetPreTrainedModel(PreTrainedModel): - config: PreTrainedConfig + config: ParakeetCTCConfig base_model_prefix = "model" main_input_name = "input_features" input_modalities = "audio" @@ -347,13 +348,21 @@ def _init_weights(self, module): init.normal_(module.bias_u, mean=0.0, std=std) init.normal_(module.bias_v, mean=0.0, std=std) elif isinstance(module, ParakeetEncoderRelPositionalEncoding): + encoder_config = getattr(self.config, "encoder_config", self.config) inv_freq = 1.0 / ( - 10000.0 ** (torch.arange(0, self.config.hidden_size, 2, dtype=torch.int64) / self.config.hidden_size) + 10000.0 + ** (torch.arange(0, encoder_config.hidden_size, 2, dtype=torch.int64) / encoder_config.hidden_size) ) init.copy_(module.inv_freq, inv_freq) + elif isinstance(module, nn.LSTM): + for name, param in module.named_parameters(): + if "weight" in name: + init.normal_(param, mean=0.0, std=std) + elif "bias" in name: + init.zeros_(param) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): - encoder_config = self.config.encoder_config if isinstance(self.config, (ParakeetCTCConfig, ParakeetTDTConfig)) else self.config + encoder_config = getattr(self.config, "encoder_config", self.config) kernel_size = encoder_config.subsampling_conv_kernel_size stride = encoder_config.subsampling_conv_stride @@ -453,6 +462,7 @@ def forward( position_embeddings, p=self.dropout_positions, training=self.training ) + output_mask = None if attention_mask is not None: output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1]) attention_mask = output_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1) @@ -476,7 +486,8 @@ def forward( ) return ParakeetEncoderModelOutput( - last_hidden_state=hidden_states, 
attention_mask=output_mask.int() if output_attention_mask else None + last_hidden_state=hidden_states, + attention_mask=output_mask.int() if output_attention_mask and output_mask is not None else None, ) @@ -489,6 +500,9 @@ class ParakeetGenerateOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. + token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level timestamps in seconds indicating when each token was emitted. Only returned by TDT models + when `return_timestamps=True` is passed to `generate()`. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for @@ -502,282 +516,12 @@ class ParakeetGenerateOutput(ModelOutput): """ sequences: torch.LongTensor + token_timestamps: torch.FloatTensor | None = None logits: tuple[torch.FloatTensor] | None = None attentions: tuple[tuple[torch.FloatTensor]] | None = None hidden_states: tuple[tuple[torch.FloatTensor]] | None = None -class ParakeetLSTM(torch.nn.Module): - def __init__( - self, - input_size: int, - hidden_size: int, - num_layers: int, - dropout: Optional[float], - forget_gate_bias: Optional[float], - t_max: Optional[int] = None, - weights_init_scale: float = 1.0, - hidden_hidden_bias_scale: float = 0.0, - proj_size: int = 0, - ): - """Returns an LSTM with forget gate bias init to `forget_gate_bias`. - Args: - input_size: See `torch.nn.LSTM`. - hidden_size: See `torch.nn.LSTM`. - num_layers: See `torch.nn.LSTM`. - dropout: See `torch.nn.LSTM`. - - forget_gate_bias: float, set by default to 1.0, which constructs a forget gate - initialized to 1.0. - Reference: - [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) - - t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization - of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course - of training. - Reference: - [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) - - weights_init_scale: Float scale of the weights after initialization. Setting to lower than one - sometimes helps reduce variance between runs. - - hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for - the default behaviour. - - Returns: - A `torch.nn.LSTM`. 
- """ - super(ParakeetLSTM, self).__init__() - - self.lstm = torch.nn.LSTM( - input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, proj_size=proj_size - ) - - if t_max is not None: - # apply chrono init - for name, v in self.lstm.named_parameters(): - if 'bias' in name: - p = getattr(self.lstm, name) - n = p.nelement() - hidden_size = n // 4 - p.data.fill_(0) - p.data[hidden_size : 2 * hidden_size] = torch.log( - torch.nn.init.uniform_(p.data[0:hidden_size], 1, t_max - 1) - ) - # forget gate biases = log(uniform(1, Tmax-1)) - p.data[0:hidden_size] = -p.data[hidden_size : 2 * hidden_size] - # input gate biases = -(forget gate biases) - - elif forget_gate_bias is not None: - for name, v in self.lstm.named_parameters(): - if "bias_ih" in name: - bias = getattr(self.lstm, name) - bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias) - if "bias_hh" in name: - bias = getattr(self.lstm, name) - bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale) - - self.dropout = torch.nn.Dropout(dropout) if dropout else None - - for name, v in self.named_parameters(): - if 'weight' in name or 'bias' in name: - v.data *= float(weights_init_scale) - - def forward( - self, x: torch.Tensor, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - - x, h = self.lstm(x, h) - - if self.dropout: - x = self.dropout(x) - - return x, h - -class ParakeetTDTJoint(ParakeetPreTrainedModel): - config: ParakeetTDTJointConfig - base_model_prefix = "" #joint" - main_input_name = "enc" - _supports_flat_attention_mask = False - _supports_sdpa = True - _supports_flex_attn = False - _supports_attention_backend = False - _can_record_outputs = {} - _no_split_modules = None - - def __init__(self, config: ParakeetTDTJointConfig): - super().__init__(config) - self.config = config - self.gradient_checkpointing = False - - self.enc = torch.nn.Linear(config.enc_hidden_size, config.hidden_size) - self.pred = torch.nn.Linear(config.pred_hidden_size, config.hidden_size) - - num_classes = config.vocab_size + 1 + len(config.durations) - - layers = ( - [torch.nn.ReLU(inplace=True)] - + ([torch.nn.Dropout(p=self.config.dropout)]) - + [torch.nn.Linear(config.hidden_size, num_classes)] - ) - self.joint_net = torch.nn.Sequential(*layers) - self.post_init() - - @auto_docstring - @check_model_inputs() - def forward( - self, - enc: torch.Tensor, - pred: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithNoAttention: - - # Right now we only support joint for inference. 
- - pred = pred.view([-1, self.config.pred_hidden_size]) # making it B, D - enc = enc.view([-1, self.config.enc_hidden_size]) # making it B, D - enc = self.enc(enc) - pred = self.pred(pred) - - assert enc.shape[0] == pred.shape[0] - output = self.joint_net(enc + pred) - return BaseModelOutput(last_hidden_state=output) - - -class ParakeetTDTPredictor(ParakeetPreTrainedModel): - - def __init__(self, config: ParakeetTDTDecoderConfig): - super().__init__(config) - self.gradient_checkpointing = False - self.config = config - - self.embed = torch.nn.Embedding(config.vocab_size + 1, config.hidden_size) # +1 for blank - self.dec_rnn = self.rnn( - config.hidden_size, - config.hidden_size, - config.num_hidden_layers + 1, - config.forget_gate_bias, - config.dropout, - config.t_max, - config.weights_init_scale, - config.hidden_hidden_bias_scale, - ) - self.post_init() - - - def rnn( - self, - input_size: int, - hidden_size: int, - num_layers: int, - forget_gate_bias: Optional[float] = 1.0, - dropout: Optional[float] = 0.0, - t_max: Optional[int] = None, - weights_init_scale: float = 1.0, - hidden_hidden_bias_scale: float = 0.0, - proj_size: int = 0, - ) -> torch.nn.Module: - return ParakeetLSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - dropout=dropout, - forget_gate_bias=forget_gate_bias, - t_max=t_max, - weights_init_scale=weights_init_scale, - hidden_hidden_bias_scale=hidden_hidden_bias_scale, - proj_size=proj_size, - ) - - - @auto_docstring - @check_model_inputs() - @can_return_tuple - def forward( - self, - input_token, - states, - hidden_state = None, - **kwargs: Unpack[TransformersKwargs], - ): - assert input_token is not None - - device = self.embed.weight.device - if input_token.device != device: - input_token = input_token.to(device) - return self.predict(input_token, state=states) - - def predict(self, y, state): - # Get device and dtype of current module - - # (B, U) -> (B, U, H) - y = self.embed(y).transpose(0, 1) # (U + 1, B, H) - - g, hid = self.dec_rnn(y, state) - g = g.transpose(0, 1).transpose(1, 2) # (B, H, U + 1) - - return g, hid - - - -@auto_docstring( - custom_intro=""" - The Parakeet TDT Decoder. This class encapsulates both the predictor and joint network for TDT models. 
- """ -) -class ParakeetTDTDecoder(ParakeetPreTrainedModel): - config: ParakeetTDTDecoderConfig - base_model_prefix = "decoder" - main_input_name = "input_token" - _supports_flat_attention_mask = False - _supports_sdpa = True - _supports_flex_attn = False - _supports_attention_backend = False - _can_record_outputs = {} - _no_split_modules = None - - def __init__(self, config: ParakeetTDTDecoderConfig): - super().__init__(config) - self.config = config - self.gradient_checkpointing = False - self.prediction = ParakeetTDTPredictor(config) - self.post_init() - - def _init_weights(self, module): - if hasattr(self.config, "initializer_range"): - std = self.config.initializer_range - else: - # 0.02 is the standard default value accross the library - std = getattr(self.config.get_text_config(), "initializer_range", 0.02) - - module.prediction.embed.weight.data.normal_(mean=0.0, std=std) - for param in module.prediction.dec_rnn.lstm.parameters(): - param.data.normal_(mean=0.0, std=std) - - def get_input_embeddings(self): - return self.prediction.embed - - def set_input_embeddings(self, embed): - self.prediction.embed = embed - - @auto_docstring - @check_model_inputs() - @can_return_tuple - def forward( - self, - input_token, - hidden_state = None, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithNoAttention: - - if hidden_state is not None: - hidden_state = tuple(hidden_state.unbind(dim=0)) - - h_out, h_state = self.prediction(input_token, hidden_state, **kwargs) - return BaseModelOutputWithNoAttention(h_out, torch.stack(h_state, dim=0)) - - - @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -922,9 +666,58 @@ def generate( return sequences +class ParakeetTDTDecoder(nn.Module): + """LSTM-based prediction network for TDT.""" + + def __init__(self, config: ParakeetTDTConfig): + super().__init__() + self.embedding = nn.Embedding(config.vocab_size + 1, config.decoder_hidden_size) + self.lstm = nn.LSTM( + input_size=config.decoder_hidden_size, + hidden_size=config.decoder_hidden_size, + num_layers=config.num_decoder_layers, + batch_first=True, + ) + self.decoder_projector = nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size) + + def forward( + self, + input_ids: torch.LongTensor, + hidden_state: torch.Tensor | None = None, + cell_state: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + embeddings = self.embedding(input_ids) + lstm_state = (hidden_state, cell_state) if hidden_state is not None else None + lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, lstm_state) + decoder_output = self.decoder_projector(lstm_output) + return decoder_output, hidden_state, cell_state + + +class ParakeetTDTJointNetwork(nn.Module): + """Joint network that combines encoder and decoder outputs to predict tokens and durations.""" + + def __init__(self, config: ParakeetTDTConfig): + super().__init__() + self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) + self.activation = ACT2FN[config.hidden_act] + self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size + 1) + self.duration_head = nn.Linear(config.decoder_hidden_size, config.num_duration_bins) + + def forward( + self, + encoder_output: torch.Tensor, + decoder_output: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + encoder_projected = self.encoder_projector(encoder_output) + joint_output = self.activation(encoder_projected + decoder_output) + token_logits = 
self.token_head(joint_output) + duration_logits = self.duration_head(joint_output) + return token_logits, duration_logits + + @auto_docstring( custom_intro=""" - Parakeet TDT model. + Parakeet model with TDT (Token Duration Transducer) head for speech recognition. """ ) class ParakeetForTDT(ParakeetPreTrainedModel): @@ -933,10 +726,9 @@ class ParakeetForTDT(ParakeetPreTrainedModel): def __init__(self, config: ParakeetTDTConfig): super().__init__(config) self.encoder = ParakeetEncoder(config.encoder_config) - self.decoder = ParakeetTDTDecoder(config.decoder_config) - self.joint = ParakeetTDTJoint(config.joint_config) - self.blank_token_id = config.blank_token_id - self.max_token_per_frame = 2 + self.decoder = ParakeetTDTDecoder(config) + self.joint = ParakeetTDTJointNetwork(config) + self.post_init() @auto_docstring @@ -944,18 +736,86 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], - ): + ) -> CausalLMOutput: + r""" + Example: + + ```python + >>> from transformers import AutoProcessor, ParakeetForTDT + >>> from datasets import load_dataset, Audio + + >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = ParakeetForTDT.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + >>> inputs = processor(ds[0]["audio"]["array"]) + >>> outputs = model(**inputs) + ``` + """ encoder_outputs = self.encoder( input_features=input_features, + attention_mask=attention_mask, **kwargs, ) - logits = self.joint.joint_net(self.joint.enc(encoder_outputs.last_hidden_state)) #[:,:,:self.joint.vocab_size] + encoder_hidden_states = encoder_outputs.last_hidden_state + + loss = None + if labels is not None: + if not is_torchaudio_available(): + raise ImportError( + "torchaudio is required for TDT loss computation. 
Install it with: pip install torchaudio" + ) + from torchaudio.functional import rnnt_loss + + # Compute encoder output lengths + attention_mask = ( + attention_mask + if attention_mask is not None + else torch.ones(input_features.shape[:-1], dtype=torch.long, device=input_features.device) + ) + encoder_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) + + # Compute target lengths (non-pad tokens) + labels_mask = labels != self.config.pad_token_id + target_lengths = labels_mask.sum(-1) + + # Prepare decoder input: prepend blank token to labels + blank_tokens = torch.full( + (labels.shape[0], 1), self.config.pad_token_id, dtype=labels.dtype, device=labels.device + ) + decoder_input = torch.cat([blank_tokens, labels], dim=1) + + # Run decoder on full label sequence: (batch, U+1, decoder_hidden_size) + decoder_output, _, _ = self.decoder(decoder_input) + + # Compute joint output for all (T, U+1) pairs via broadcasting + # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) + # decoder: (batch, 1, U+1, decoder_hidden_size) + token_logits, _ = self.joint( + encoder_hidden_states.unsqueeze(2), + decoder_output.unsqueeze(1), + ) + # token_logits: (batch, T, U+1, vocab_size+1) + + loss = rnnt_loss( + logits=token_logits.float(), + targets=labels.int(), + logit_lengths=encoder_lengths.int(), + target_lengths=target_lengths.int(), + blank=self.config.pad_token_id, + reduction="mean", + ) return CausalLMOutput( - loss=torch.sum(encoder_outputs.last_hidden_state), # a fake loss here. - logits=logits, + loss=loss, + logits=encoder_hidden_states, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) @@ -964,64 +824,166 @@ def forward( def generate( self, input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + return_timestamps: bool = False, + return_dict_in_generate: bool = False, **kwargs: Unpack[TransformersKwargs], - ): + ) -> ParakeetGenerateOutput | torch.LongTensor: + r""" + Perform TDT greedy decoding to generate token sequences. - encoder_outputs = self.encoder( + Args: + return_timestamps (`bool`, *optional*, defaults to `False`): + Whether to return per-token timestamps in seconds. When `True`, forces + `return_dict_in_generate=True` and includes `token_timestamps` in the output. 
+ + Example: + + ```python + >>> from transformers import AutoProcessor, ParakeetForTDT + >>> from datasets import load_dataset, Audio + + >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = ParakeetForTDT.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + >>> inputs = processor(ds[0]["audio"]["array"]) + >>> output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) + + >>> transcription = processor.batch_decode(output.sequences, skip_special_tokens=True) + >>> print(transcription) + >>> print(output.token_timestamps) + ``` + """ + if return_timestamps: + return_dict_in_generate = True + + blank_id = self.config.pad_token_id + max_symbols_per_step = self.config.max_symbols_per_step + device = input_features.device + batch_size = input_features.shape[0] + + kwargs["return_dict"] = True + outputs: CausalLMOutput = self( input_features=input_features, + attention_mask=attention_mask, **kwargs, ) - output = self.greedy_decode(encoder_outputs.last_hidden_state) - - return output - - def greedy_decode(self, encoder_output): - T = encoder_output.shape[1] - t = 0 - hyp = [] - last_label = torch.LongTensor([[self.blank_token_id]]) - dec_out = self.decoder(input_token=last_label) - g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states - - symbols_added = 0 - while t < T: - enc = encoder_output[0,t,:] - while symbols_added < self.max_token_per_frame: - logits = self.joint(enc, g).last_hidden_state - - logits = logits.view([-1]) - - token_logits = logits[:self.blank_token_id + 1].softmax(-1) - duration_logits = logits[self.blank_token_id + 1:].softmax(-1) - - v, token = token_logits.max(-1) - v_duration, duration = duration_logits.max(-1) - token = token.item() - duration = duration.item() - - if token != self.blank_token_id: - hyp.append(token) - last_label = token - last_label = torch.LongTensor([[last_label]]) - dec_out = self.decoder(last_label, hidden_prime) - g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states - - if duration == 0: - symbols_added += 1 - else: - t += duration - symbols_added = 0 - break + encoder_hidden_states = outputs.logits - if symbols_added == self.max_token_per_frame: - t += 1 - symbols_added = 0 + sequence_length = encoder_hidden_states.shape[1] + if attention_mask is not None: + encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) + valid_lengths = encoder_attention_mask.sum(dim=1).int() + else: + valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) + + # Initialize decoder LSTM state + hidden_state = torch.zeros( + self.config.num_decoder_layers, + batch_size, + self.config.decoder_hidden_size, + device=device, + dtype=encoder_hidden_states.dtype, + ) + cell_state = torch.zeros_like(hidden_state) + + # Initialize with blank token + prev_tokens = torch.full((batch_size, 1), blank_id, dtype=torch.long, device=device) + decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) + + all_tokens = [[] for _ in range(batch_size)] + token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None + time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) + active_mask = time_indices < valid_lengths + + while 
active_mask.any(): + safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) + encoder_frames = encoder_hidden_states[ + torch.arange(batch_size, device=device), safe_time_indices + ].unsqueeze(1) + + symbols_added = 0 + while symbols_added < max_symbols_per_step: + token_logits, duration_logits = self.joint(encoder_frames, decoder_output) + token_logits = token_logits.squeeze(1) + duration_logits = duration_logits.squeeze(1) + + tokens = token_logits.argmax(dim=-1) + durations = duration_logits.argmax(dim=-1) + + is_blank = tokens == blank_id + emit_mask = active_mask & ~is_blank + + for i in range(batch_size): + if emit_mask[i]: + all_tokens[i].append(tokens[i].item()) + if token_frame_indices is not None: + token_frame_indices[i].append(time_indices[i].item()) + + if emit_mask.any(): + new_prev_tokens = tokens.unsqueeze(1) + new_decoder_output, new_hidden_state, new_cell_state = self.decoder( + new_prev_tokens, hidden_state, cell_state + ) + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) - return hyp + emit_mask_state = emit_mask.view(1, batch_size, 1) + hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) + cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) + # If duration is 0, stay on same frame (emit more tokens) + stay_mask = active_mask & (durations == 0) + if stay_mask.any(): + symbols_added += 1 + if symbols_added >= max_symbols_per_step: + time_indices = time_indices + 1 + break + continue + + # Duration > 0: advance time + time_indices = time_indices + torch.where(active_mask, durations, torch.zeros_like(durations)) + break + + active_mask = time_indices < valid_lengths + + # Pad sequences to same length + max_len = max((len(seq) for seq in all_tokens), default=0) + if max_len == 0: + max_len = 1 + + sequences = torch.full((batch_size, max_len), self.config.pad_token_id, dtype=torch.long, device=device) + for i in range(batch_size): + seq_len = len(all_tokens[i]) + if seq_len > 0: + sequences[i, :seq_len] = torch.tensor(all_tokens[i], dtype=torch.long, device=device) + + token_timestamps = None + if return_timestamps: + seconds_per_frame = self.config.seconds_per_frame + token_timestamps = torch.full((batch_size, max_len), 0.0, dtype=torch.float, device=device) + for i in range(batch_size): + num_tokens = len(token_frame_indices[i]) + if num_tokens > 0: + token_timestamps[i, :num_tokens] = ( + torch.tensor(token_frame_indices[i], dtype=torch.float, device=device) * seconds_per_frame + ) + if return_dict_in_generate: + return ParakeetGenerateOutput( + sequences=sequences, + token_timestamps=token_timestamps, + logits=None, + attentions=outputs.attentions, + hidden_states=outputs.hidden_states, + ) + return sequences -__all__ = ["ParakeetForCTC", "ParakeetForTDT", "ParakeetEncoder", "ParakeetTDTDecoder", "ParakeetTDTJoint", "ParakeetPreTrainedModel"] +__all__ = ["ParakeetForCTC", "ParakeetForTDT", "ParakeetEncoder", "ParakeetPreTrainedModel"] diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index da825ff39223..481ec4c79021 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -98,7 +98,6 @@ AutoModelForAudioClassification, AutoModelForCausalLM, AutoModelForCTC, - AutoModelForTDT, AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, @@ -114,6 +113,7 @@ AutoModelForSequenceClassification, 
AutoModelForSpeechSeq2Seq, AutoModelForTableQuestionAnswering, + AutoModelForTDT, AutoModelForTextToSpectrogram, AutoModelForTextToWaveform, AutoModelForTokenClassification, diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 30eb9c987697..f7af0df8fe69 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -579,12 +579,14 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): out["stride"] = rescale_stride([stride], ratio)[0] else: out["stride"] = rescale_stride(stride, ratio) - elif self.type == 'tdt': + elif self.type == "tdt": inputs = { self.model.main_input_name: model_inputs.pop(self.model.main_input_name), } + if "attention_mask" in model_inputs: + inputs["attention_mask"] = model_inputs.pop("attention_mask") outputs = self.model.generate(**inputs) - out = {"tokens": torch.LongTensor(outputs).view([1, -1])} + out = {"tokens": outputs} else: raise ValueError("Unsupported model type {self.type}.") diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 46e53421dbd5..b4279b1d9d24 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -14,7 +14,6 @@ """Testing suite for the PyTorch Parakeet model.""" import json -import copy import tempfile import unittest from pathlib import Path @@ -23,7 +22,7 @@ from transformers.testing_utils import cleanup, require_torch, slow, torch_device from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_modeling_common import ModelTesterMixin, floats_tensor, random_attention_mask if is_datasets_available(): @@ -35,15 +34,11 @@ from transformers import ( AutoProcessor, ParakeetCTCConfig, - ParakeetTDTConfig, ParakeetEncoder, ParakeetEncoderConfig, - ParakeetTDTDecoder, - ParakeetTDTDecoderConfig, - ParakeetTDTJoint, - ParakeetTDTJointConfig, ParakeetForCTC, ParakeetForTDT, + ParakeetTDTConfig, ) @@ -63,7 +58,7 @@ def __init__( conv_kernel_size=9, subsampling_factor=8, subsampling_conv_channels=32, - use_bias=True, + attention_bias=True, num_mel_bins=80, scale_input=True, ): @@ -84,7 +79,7 @@ def __init__( self.conv_kernel_size = conv_kernel_size self.subsampling_factor = subsampling_factor self.subsampling_conv_channels = subsampling_conv_channels - self.use_bias = use_bias + self.attention_bias = attention_bias self.num_mel_bins = num_mel_bins self.scale_input = scale_input @@ -115,7 +110,7 @@ def get_config(self): conv_kernel_size=self.conv_kernel_size, subsampling_factor=self.subsampling_factor, subsampling_conv_channels=self.subsampling_conv_channels, - use_bias=self.use_bias, + attention_bias=self.attention_bias, num_mel_bins=self.num_mel_bins, scale_input=self.scale_input, ) @@ -139,34 +134,6 @@ def prepare_config_and_inputs_for_common(self): } return config, inputs_dict - def check_ctc_loss(self, config, input_values, *args): - model = ParakeetForCTC(config=config) - model.to(torch_device) - - # make sure that dropout is disabled - model.eval() - - input_values = input_values[:3] - attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) - - input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] - max_length_labels = 
model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) - labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size) - - # pad input - for i in range(len(input_lengths)): - input_values[i, input_lengths[i] :] = 0.0 - attention_mask[i, input_lengths[i] :] = 0 - - model.config.ctc_loss_reduction = "sum" - sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() - - model.config.ctc_loss_reduction = "mean" - mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() - - self.parent.assertTrue(isinstance(sum_loss, float)) - self.parent.assertTrue(isinstance(mean_loss, float)) - @require_torch class ParakeetEncoderModelTest(ModelTesterMixin, unittest.TestCase): @@ -190,232 +157,6 @@ def test_model_get_set_embeddings(self): pass -class ParakeetTDTDecoderModelTester: - def __init__( - self, - parent, - batch_size=16, - vocab_size=128, - hidden_size=64, - num_hidden_layers=2, - seq_length=32, - is_training=True, - dropout=0, # so gradient checkpointing doesn't fail - ): - # testing suite parameters - self.parent = parent - self.batch_size = batch_size - self.is_training = is_training - - # config parameters - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.seq_length = seq_length - self.output_seq_length = seq_length - self.vocab_size = vocab_size - - def prepare_config_and_inputs(self): - input_token = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - config = self.get_config() - - return config, input_token - - def get_config(self): - return ParakeetTDTDecoderConfig( - num_hidden_layers=self.num_hidden_layers, - hidden_size=self.hidden_size, - vocab_size=self.vocab_size, - ) - - def create_and_check_model(self, config, input_token): - pass - model = ParakeetTDTDecoder(config=config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - result = model(input_token) - - self.parent.assertEqual( - result.last_hidden_state.shape, (self.batch_size, config.hidden_size, self.output_seq_length) - ) - - def prepare_config_and_inputs_for_common(self): - config, input_token = self.prepare_config_and_inputs() - inputs_dict = { - "input_token": input_token, - } - return config, inputs_dict - - - - -@require_torch -class ParakeetTDTDecoderModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (ParakeetTDTDecoder,) if is_torch_available() else () - - test_resize_embeddings = False - test_torch_exportable = True - has_attentions = False - is_encoder_decoder = False - - def setUp(self): - self.model_tester = ParakeetTDTDecoderModelTester(self) - self.config_tester = ConfigTester(self, config_class=ParakeetTDTDecoderConfig, has_text_modality=False, common_properties=['hidden_size','num_hidden_layers']) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_hidden_states_output(self): - def check_hidden_states_output(inputs_dict, config, model_class): - model = model_class(copy.deepcopy(config)) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - - hidden_states = outputs.hidden_states - - expected_num_layers = getattr( - self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 - ) - self.assertEqual(hidden_states.shape[1], 
expected_num_layers) - - if hasattr(self.model_tester, "encoder_seq_length"): - seq_length = self.model_tester.encoder_seq_length - if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1: - seq_length = seq_length * self.model_tester.chunk_length - else: - seq_length = self.model_tester.seq_length - - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - inputs_dict["output_hidden_states"] = True - check_hidden_states_output(inputs_dict, config, model_class) - - # check that output_hidden_states also work using config - del inputs_dict["output_hidden_states"] - config.output_hidden_states = True - for k in config.sub_configs: - if getattr(config, k) is not None: - getattr(config, k).output_hidden_states = True - - check_hidden_states_output(inputs_dict, config, model_class) - - @unittest.skip(reason="this class only returns the last hidden state not prior ones, and there is no gradient on last hidden state w.r.t output.") - def test_retain_grad_hidden_states_attentions(self): - pass - - -class ParakeetTDTJointModelTester: - def __init__( - self, - parent, - batch_size=16, - vocab_size=128, - hidden_size=64, - pred_hidden_size=64, - enc_hidden_size=64, - num_hidden_layers=2, - durations=[0,1,2,3,4], - is_training=True, - dropout=0.1, # so gradient checkpointing doesn't fail - ): - # testing suite parameters - self.parent = parent - self.batch_size = batch_size - self.is_training = is_training - - # config parameters - self.hidden_size = hidden_size - self.pred_hidden_size = pred_hidden_size - self.enc_hidden_size = enc_hidden_size - self.num_hidden_layers = num_hidden_layers - self.t_length = 1 # so far only support 1 - self.u_length = 1 # so far only support 1 - self.output_seq_length = -1 - self.vocab_size = vocab_size - self.durations = durations - - def prepare_config_and_inputs(self): - enc = floats_tensor([self.batch_size, self.t_length, self.enc_hidden_size]) - pred = floats_tensor([self.batch_size, self.u_length, self.pred_hidden_size]) - config = self.get_config() - - return config, enc, pred - - def get_config(self): - return ParakeetTDTJointConfig( - num_hidden_layers=self.num_hidden_layers, - hidden_size=self.hidden_size, - pred_hidden_size=self.enc_hidden_size, - enc_hidden_size=self.enc_hidden_size, - vocab_size=self.vocab_size, - durations=self.durations, - ) - - def create_and_check_model(self, config, enc, pred): - model = ParakeetTDTJoint(config=config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - result = model(enc, pred) - - self.parent.assertEqual( - result.last_hidden_state.shape, (self.batch_size, config.vocab_size + 1 + len(config.durations)) - ) - - def prepare_config_and_inputs_for_common(self): - config, enc, pred = self.prepare_config_and_inputs() - inputs_dict = { - "enc": enc, - "pred": pred, - } - return config, inputs_dict - - - - -@require_torch -class ParakeetTDTJointModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (ParakeetTDTJoint,) if is_torch_available() else () - - test_resize_embeddings = False - test_torch_exportable = True - has_attentions = False - is_encoder_decoder = False - - def setUp(self): - self.model_tester = ParakeetTDTJointModelTester(self) - self.config_tester = ConfigTester(self, config_class=ParakeetTDTJointConfig, has_text_modality=False, common_properties=['hidden_size']) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = 
self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - @unittest.skip(reason="this class doesn't have hidden states.") - def test_retain_grad_hidden_states_attentions(self): - pass - - @unittest.skip(reason="this class doesn't have hidden states.") - def test_hidden_states_output(self): - pass - - @unittest.skip(reason="ParakeetJoint does not use inputs_embeds") - def test_model_get_set_embeddings(self): - pass - - - class ParakeetForCTCModelTester: def __init__(self, parent, encoder_kwargs=None, is_training=True, vocab_size=128, pad_token_id=0): if encoder_kwargs is None: @@ -462,10 +203,6 @@ def prepare_config_and_inputs_for_common(self): } return config, inputs_dict - def test_ctc_loss_inference(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.encoder_model_tester.check_ctc_loss(*config_and_inputs) - @require_torch class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase): @@ -543,7 +280,6 @@ def tearDown(self): @classmethod def _load_dataset(cls): - # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process. if cls._dataset is None: cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") cls._dataset = cls._dataset.cast_column( @@ -558,10 +294,6 @@ def _load_datasamples(self, num_samples): @slow def test_1b_model_integration(self): - """ - bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py - eustlb reproducer: https://gist.github.com/eustlb/6e9e3aa85de3f7c340ec3c36e65f2fe6 - """ RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single.json" with open(RESULTS_PATH, "r") as f: raw_data = json.load(f) @@ -573,7 +305,6 @@ def test_1b_model_integration(self): model.eval() model.to(torch_device) - # -- apply inputs = self.processor(samples) inputs.to(torch_device, dtype=self.dtype) predicted_ids = model.generate(**inputs) @@ -583,11 +314,6 @@ def test_1b_model_integration(self): @slow def test_1b_model_integration_batched(self): - """ - bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batched-py - eustlb reproducer: https://gist.github.com/eustlb/575b5da58de34a70116a1955b1183596 - """ - RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch.json" with open(RESULTS_PATH, "r") as f: raw_data = json.load(f) @@ -599,7 +325,6 @@ def test_1b_model_integration_batched(self): model.eval() model.to(torch_device) - # -- apply inputs = self.processor(samples) inputs.to(torch_device, dtype=self.dtype) predicted_ids = model.generate(**inputs) @@ -608,58 +333,57 @@ def test_1b_model_integration_batched(self): self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) - class ParakeetForTDTModelTester: - def __init__(self, - parent, - encoder_kwargs=None, - decoder_kwargs=None, - joint_kwargs=None, - is_training=True, - vocab_size=128, - durations=[0,1,2,3,4], - pad_token_id=0 - ): + def __init__( + self, + parent, + encoder_kwargs=None, + is_training=True, + vocab_size=128, + decoder_hidden_size=64, + num_decoder_layers=1, + num_duration_bins=5, + hidden_act="relu", + max_symbols_per_step=10, + pad_token_id=128, + ): if encoder_kwargs is None: encoder_kwargs = {} - if decoder_kwargs is None: - decoder_kwargs = {} - if joint_kwargs is None: - joint_kwargs = 
{} self.parent = parent self.encoder_model_tester = ParakeetEncoderModelTester(parent, **encoder_kwargs) - self.decoder_model_tester = ParakeetTDTDecoderModelTester(parent, **decoder_kwargs) - self.joint_model_tester = ParakeetTDTJointModelTester(parent, **joint_kwargs) self.is_training = is_training self.batch_size = self.encoder_model_tester.batch_size self.output_seq_length = self.encoder_model_tester.output_seq_length self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers - self.seq_length = vocab_size - self.enc_hidden_size = self.encoder_model_tester.hidden_size - self.hidden_size = self.encoder_model_tester.hidden_size # this field is needed for test class - self.pred_hidden_size = self.decoder_model_tester.hidden_size - self.joint_hidden_size = self.joint_model_tester.hidden_size - - self.durations = durations + self.hidden_size = self.encoder_model_tester.hidden_size + self.seq_length = self.encoder_model_tester.output_seq_length + self.encoder_seq_length = self.encoder_model_tester.output_seq_length - self.vocab_size = vocab_size + len(self.durations) + 1 + self.vocab_size = vocab_size + self.decoder_hidden_size = decoder_hidden_size + self.num_decoder_layers = num_decoder_layers + self.num_duration_bins = num_duration_bins + self.hidden_act = hidden_act + self.max_symbols_per_step = max_symbols_per_step self.pad_token_id = pad_token_id - def prepare_config_and_inputs(self): _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs() config = self.get_config() return config, input_features, attention_mask def get_config(self): - return ParakeetTDTConfig.from_configs( - encoder_config=self.encoder_model_tester.get_config(), - decoder_config=self.decoder_model_tester.get_config(), - joint_config=self.joint_model_tester.get_config(), + return ParakeetTDTConfig( vocab_size=self.vocab_size, - durations=self.durations, + decoder_hidden_size=self.decoder_hidden_size, + num_decoder_layers=self.num_decoder_layers, + num_duration_bins=self.num_duration_bins, + hidden_act=self.hidden_act, + max_symbols_per_step=self.max_symbols_per_step, + encoder_config=self.encoder_model_tester.get_config().to_dict(), + pad_token_id=self.pad_token_id, ) def create_and_check_model(self, config, input_features, attention_mask): @@ -669,7 +393,10 @@ def create_and_check_model(self, config, input_features, attention_mask): with torch.no_grad(): result = model(input_features, attention_mask=attention_mask) - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size)) + # forward() returns encoder hidden states as logits + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.output_seq_length, self.encoder_model_tester.hidden_size) + ) def prepare_config_and_inputs_for_common(self): config, input_features, attention_mask = self.prepare_config_and_inputs() @@ -695,7 +422,6 @@ class ParakeetForTDTModelTest(ModelTesterMixin, unittest.TestCase): test_attention_outputs = False test_resize_embeddings = False - test_torch_exportable = True _is_composite = True @@ -710,16 +436,11 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip(reason="ParakeetEncoder does not use inputs_embeds") + @unittest.skip(reason="ParakeetForTDT does not use inputs_embeds") def test_model_get_set_embeddings(self): pass - @unittest.skip(reason="batching not supported") - def test_batching_equivalence(self): - pass - # 
Original function assumes vision+text model, so overwrite since Parakeet is audio+text - # Below is modified from `tests/models/granite_speech/test_modeling_granite_speech.py` def test_sdpa_can_dispatch_composite_models(self): if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") @@ -745,6 +466,20 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") + def test_generate(self): + """Test that generate() produces valid output.""" + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + sequences = model.generate(input_features, attention_mask=attention_mask) + + self.assertIsInstance(sequences, torch.Tensor) + self.assertEqual(sequences.dim(), 2) + self.assertEqual(sequences.shape[0], self.model_tester.batch_size) + @require_torch class ParakeetForTDTIntegrationTest(unittest.TestCase): @@ -752,16 +487,15 @@ class ParakeetForTDTIntegrationTest(unittest.TestCase): @classmethod def setUp(cls): - cls.checkpoint_name = "hainanx/parakeet-tdt-0.6b-v3" + cls.checkpoint_name = "nvidia/parakeet-tdt-0.6b-v3" cls.dtype = torch.bfloat16 - cls.processor = AutoProcessor.from_pretrained("hainanx/parakeet-tdt-0.6b-v3") + cls.processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") def tearDown(self): cleanup(torch_device, gc_collect=True) @classmethod def _load_dataset(cls): - # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process. if cls._dataset is None: cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") cls._dataset = cls._dataset.cast_column( @@ -775,26 +509,15 @@ def _load_datasamples(self, num_samples): return [x["array"] for x in speech_samples] @slow - def test_1b_model_integration(self): - """ - bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py - eustlb reproducer: https://gist.github.com/eustlb/6e9e3aa85de3f7c340ec3c36e65f2fe6 - """ - RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single.json" - with open(RESULTS_PATH, "r") as f: - raw_data = json.load(f) - EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) - EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] - + def test_tdt_model_integration(self): samples = self._load_datasamples(1) model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) model.eval() model.to(torch_device) - # -- apply inputs = self.processor(samples) inputs.to(torch_device, dtype=self.dtype) - predicted_ids = model.generate(**inputs) - torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) - predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) - self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) + output = model.generate(**inputs, return_dict_in_generate=True) + predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) + self.assertTrue(len(predicted_transcripts) > 0) + self.assertTrue(len(predicted_transcripts[0]) > 0) From fa36657f86e55df11ef94f683713bd210ed87c9f Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 25 Feb 2026 17:23:42 
+0100 Subject: [PATCH 03/67] Add expected outputs for TDT, small fixes. --- docs/source/en/model_doc/parakeet.md | 6 +-- src/transformers/convert_slow_tokenizer.py | 6 ++- .../models/parakeet/configuration_parakeet.py | 8 ++- .../models/parakeet/convert_nemo_to_hf.py | 19 +++++++ .../models/parakeet/modeling_parakeet.py | 4 +- .../models/parakeet/modular_parakeet.py | 4 +- .../parakeet/expected_results_batch_tdt.json | 1 + .../parakeet/expected_results_single_tdt.json | 1 + .../models/parakeet/test_modeling_parakeet.py | 50 +++++++++++++++++-- 9 files changed, 79 insertions(+), 20 deletions(-) create mode 100644 tests/fixtures/parakeet/expected_results_batch_tdt.json create mode 100644 tests/fixtures/parakeet/expected_results_single_tdt.json diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index a758608482e3..e709f9f54ce0 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -43,11 +43,11 @@ Parakeet models, [introduced by NVIDIA NeMo](https://developer.nvidia.com/blog/p The original implementation can be found in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo). Model checkpoints are to be found under [the NVIDIA organization](https://huggingface.co/nvidia/models?search=parakeet). -This model was contributed by [Nithin Rao Koluguri](https://huggingface.co/nithinraok), [Eustache Le Bihan](https://huggingface.co/eustlb) and [Eric Bezzam](https://huggingface.co/bezzam). +This model was contributed by [Nithin Rao Koluguri](https://huggingface.co/nithinraok), [Eustache Le Bihan](https://huggingface.co/eustlb), [Eric Bezzam](https://huggingface.co/bezzam), [Maksym Lypivskyi](https://huggingface.co/MaksL), and [Hainan Xu](https://huggingface.co/hainanx). ## Usage -### Basic usage +### `ParakeetForCTC` usage @@ -86,7 +86,7 @@ print(processor.batch_decode(outputs)) -### TDT usage +### `ParakeetForTDT` usage diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index ce4385e478b2..94b54b64ae22 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -686,7 +686,8 @@ def tokenizer(self, proto): ) elif model_type == 2: - _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores) + result = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(None) + merges = result["merges"] bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)} tokenizer = Tokenizer( BPE( @@ -1771,7 +1772,8 @@ def __init__(self, vocab_file=None, *args): def tokenizer(self, proto): vocab_scores = self.vocab(proto) - _, merges = self.SpmExtractor(self.vocab_file).extract(vocab_scores) + result = self.SpmExtractor(self.vocab_file).extract(None) + merges = result["merges"] bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)} tokenizer = Tokenizer( BPE( diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 256c4c30cc35..3abd3b897fc8 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -234,7 +234,7 @@ class ParakeetTDTConfig(PreTrainedConfig): This is the configuration class to store the configuration of a [`ParakeetForTDT`]. It is used to instantiate a Parakeet TDT model according to the specified arguments, defining the model architecture. 
Instantiating a configuration with the defaults will yield a similar configuration to that of the Parakeet TDT - [nvidia/parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2) architecture. + [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) architecture. Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. @@ -307,11 +307,9 @@ def __init__( self.encoder_config = encoder_config self.initializer_range = self.encoder_config.initializer_range + self.pad_token_id = pad_token_id - super().__init__( - pad_token_id=pad_token_id, - **kwargs, - ) + super().__init__(**kwargs) @classmethod def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index 51ea38214527..f4ace95cf7ed 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -433,6 +433,25 @@ def main( write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id) +""" +CTC conversion example: +```bash +python src/transformers/models/parakeet/convert_nemo_to_hf.py \ + --hf_repo_id nvidia/parakeet-ctc-1.1b \ + --model_type ctc \ + --output_dir OUTPUT_DIR \ + --push_to_repo_id USERNAME/parakeet-ctc-1.1b +``` + +TDT conversion example: +```bash +python src/transformers/models/parakeet/convert_nemo_to_hf.py \ + --hf_repo_id nvidia/parakeet-tdt-0.6b-v3 \ + --model_type tdt \ + --output_dir OUTPUT_DIR \ + --push_to_repo_id USERNAME/parakeet-tdt-0.6b-v3-hf +``` +""" if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co") diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 91c9aea5003c..f14ebb7340cb 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -282,9 +282,7 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) # W_{k,R} projection - self.relative_k_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) + self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) # global content bias self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim)) # global positional bias diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 0329443e1902..983330309838 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -121,9 +121,7 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): super().__init__(config, layer_idx=layer_idx) self.is_causal = False # W_{k,R} projection - self.relative_k_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) + self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) # global content bias self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim)) # 
global positional bias diff --git a/tests/fixtures/parakeet/expected_results_batch_tdt.json b/tests/fixtures/parakeet/expected_results_batch_tdt.json new file mode 100644 index 000000000000..c3f46c17321d --- /dev/null +++ b/tests/fixtures/parakeet/expected_results_batch_tdt.json @@ -0,0 +1 @@ +{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.", "He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and can discover in it but little of Rocky Ithaca.", "Linnell's pictures are a sort of up guards an atom paintings, and Mason's exquisite idols are as national as a jingo poem. mister Burkett Foster's landscapes smile at one much in the same way that mister Carker used to flash his teeth. And mister John Collier gives his sitter a cheerful slap on the back, before he says, like a shampooer in a Turkish bath Next man"], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 
8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 2281, 1969, 507, 3362, 7886, 769, 328, 1299, 1239, 7319, 6447, 901, 1413, 1333, 3720, 289, 7931, 7870, 6182, 508, 5600, 4190, 377, 799, 441, 1111, 7877, 575, 2059, 5371, 3230, 334, 869, 2681, 7052, 592, 3341, 725, 7893, 2336, 7882, 566, 7865, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [439, 1538, 530, 7931, 7870, 5970, 7868, 4147, 1714, 279, 275, 621, 592, 1840, 1980, 961, 7870, 411, 407, 313, 849, 942, 2399, 7877, 575, 2945, 289, 7931, 7870, 743, 341, 290, 582, 312, 7874, 324, 7870, 1714, 618, 285, 5858, 618, 279, 300, 381, 7869, 408, 311, 7883, 282, 3459, 426, 344, 7876, 861, 515, 308, 441, 7931, 7870, 3650, 7870, 7880, 474, 283, 1530, 787, 407, 2678, 4457, 334, 506, 766, 7864, 7195, 1050, 282, 3459, 3551, 1684, 1441, 326, 366, 309, 1028, 7882, 2745, 478, 291, 7882, 7883, 1976, 282, 3459, 3483, 4003, 332, 277, 317, 416, 283, 2745, 3488, 441, 279, 774, 277, 5346, 275, 4226, 431, 506, 6507, 7877, 555, 786, 7864, 813, 498, 676, 7877, 2656, 279, 275, 3930, 726, 7869, 277, 334, 279, 5183, 7876, 2739, 302, 7152, 1030, 3127, 698]]} \ No newline at end of file diff --git a/tests/fixtures/parakeet/expected_results_single_tdt.json b/tests/fixtures/parakeet/expected_results_single_tdt.json new file mode 100644 index 000000000000..93a43c9fa9e8 --- /dev/null +++ b/tests/fixtures/parakeet/expected_results_single_tdt.json @@ -0,0 +1 @@ +{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."], "scores": [-90.4653091430664], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883]]} \ No newline at end of file diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index b4279b1d9d24..abd1cf10cc3c 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -294,6 +294,9 @@ def _load_datasamples(self, num_samples): @slow def test_1b_model_integration(self): + """ + reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py + """ RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single.json" with open(RESULTS_PATH, "r") as f: raw_data = json.load(f) @@ -314,6 +317,9 @@ def test_1b_model_integration(self): @slow def test_1b_model_integration_batched(self): + """ + reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batched-py + """ RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch.json" with open(RESULTS_PATH, "r") as f: raw_data = json.load(f) @@ -487,9 +493,13 @@ class ParakeetForTDTIntegrationTest(unittest.TestCase): @classmethod def setUp(cls): - cls.checkpoint_name = "nvidia/parakeet-tdt-0.6b-v3" + # cls.checkpoint_name = "nvidia/parakeet-tdt-0.6b-v3" + # 
cls.dtype = torch.bfloat16 + # cls.processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") + + cls.checkpoint_name = "bezzam/parakeet-tdt-0.6b-v3-hf" cls.dtype = torch.bfloat16 - cls.processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") + cls.processor = AutoProcessor.from_pretrained("bezzam/parakeet-tdt-0.6b-v3-hf") def tearDown(self): cleanup(torch_device, gc_collect=True) @@ -510,6 +520,15 @@ def _load_datasamples(self, num_samples): @slow def test_tdt_model_integration(self): + """ + reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single_tdt-py + """ + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single_tdt.json" + with open(RESULTS_PATH, "r") as f: + raw_data = json.load(f) + EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) + EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] + samples = self._load_datasamples(1) model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) model.eval() @@ -518,6 +537,29 @@ def test_tdt_model_integration(self): inputs = self.processor(samples) inputs.to(torch_device, dtype=self.dtype) output = model.generate(**inputs, return_dict_in_generate=True) + torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) + predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) + self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) + + @slow + def test_tdt_model_integration_batched(self): + """ + reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batch_tdt-py + """ + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch_tdt.json" + with open(RESULTS_PATH, "r") as f: + raw_data = json.load(f) + EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) + EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] + + samples = self._load_datasamples(5) + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) + model.eval() + model.to(torch_device) + + inputs = self.processor(samples) + inputs.to(torch_device, dtype=self.dtype) + output = model.generate(**inputs, return_dict_in_generate=True) + torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) - self.assertTrue(len(predicted_transcripts) > 0) - self.assertTrue(len(predicted_transcripts[0]) > 0) + self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) From 05e2e346bd869016aab37882685cd8d561798224 Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 25 Feb 2026 17:38:54 +0100 Subject: [PATCH 04/67] Separate CTC and TDT generate outputs. --- src/transformers/models/lasr/modeling_lasr.py | 12 ++---- .../models/parakeet/modeling_parakeet.py | 43 ++++++++++++++----- .../models/parakeet/modular_parakeet.py | 43 ++++++++++++++----- 3 files changed, 68 insertions(+), 30 deletions(-) diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index 83623dcaf067..24fa4872a2a8 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -555,17 +555,14 @@ def forward( @dataclass -class LasrGenerateOutput(ModelOutput): +class LasrCTCGenerateOutput(ModelOutput): """ - Outputs of Lasr models. + Outputs of Lasr CTC model generation. 
Args: sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level timestamps in seconds indicating when each token was emitted. Only returned by TDT models - when `return_timestamps=True` is passed to `generate()`. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for @@ -579,7 +576,6 @@ class LasrGenerateOutput(ModelOutput): """ sequences: torch.LongTensor - token_timestamps: torch.FloatTensor | None = None logits: tuple[torch.FloatTensor] | None = None attentions: tuple[tuple[torch.FloatTensor]] | None = None hidden_states: tuple[tuple[torch.FloatTensor]] | None = None @@ -681,7 +677,7 @@ def generate( attention_mask: torch.Tensor | None = None, return_dict_in_generate: bool = False, **kwargs: Unpack[TransformersKwargs], - ) -> LasrGenerateOutput | torch.LongTensor: + ) -> LasrCTCGenerateOutput | torch.LongTensor: r""" Example: @@ -719,7 +715,7 @@ def generate( sequences[~attention_mask] = self.config.pad_token_id if return_dict_in_generate: - return LasrGenerateOutput( + return LasrCTCGenerateOutput( sequences=sequences, logits=outputs.logits, attentions=outputs.attentions, diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index f14ebb7340cb..312a67bc9bc9 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -651,17 +651,14 @@ def forward( @dataclass -class ParakeetGenerateOutput(ModelOutput): +class ParakeetCTCGenerateOutput(ModelOutput): """ - Outputs of Parakeet models. + Outputs of Parakeet CTC model generation. Args: sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level timestamps in seconds indicating when each token was emitted. Only returned by TDT models - when `return_timestamps=True` is passed to `generate()`. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for @@ -675,12 +672,37 @@ class ParakeetGenerateOutput(ModelOutput): """ sequences: torch.LongTensor - token_timestamps: torch.FloatTensor | None = None logits: tuple[torch.FloatTensor] | None = None attentions: tuple[tuple[torch.FloatTensor]] | None = None hidden_states: tuple[tuple[torch.FloatTensor]] | None = None +@dataclass +class ParakeetTDTGenerateOutput(ModelOutput): + """ + Outputs of Parakeet TDT model generation. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. 
The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level timestamps in seconds indicating when each token was emitted. Only returned when + `return_timestamps=True` is passed to `generate()`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor + token_timestamps: torch.FloatTensor | None = None + attentions: tuple[tuple[torch.FloatTensor]] | None = None + hidden_states: tuple[tuple[torch.FloatTensor]] | None = None + + @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -777,7 +799,7 @@ def generate( attention_mask: torch.Tensor | None = None, return_dict_in_generate: bool = False, **kwargs: Unpack[TransformersKwargs], - ) -> ParakeetGenerateOutput | torch.LongTensor: + ) -> ParakeetCTCGenerateOutput | torch.LongTensor: r""" Example: @@ -815,7 +837,7 @@ def generate( sequences[~attention_mask] = self.config.pad_token_id if return_dict_in_generate: - return ParakeetGenerateOutput( + return ParakeetCTCGenerateOutput( sequences=sequences, logits=outputs.logits, attentions=outputs.attentions, @@ -987,7 +1009,7 @@ def generate( return_timestamps: bool = False, return_dict_in_generate: bool = False, **kwargs: Unpack[TransformersKwargs], - ) -> ParakeetGenerateOutput | torch.LongTensor: + ) -> ParakeetTDTGenerateOutput | torch.LongTensor: r""" Perform TDT greedy decoding to generate token sequences. @@ -1134,10 +1156,9 @@ def generate( ) if return_dict_in_generate: - return ParakeetGenerateOutput( + return ParakeetTDTGenerateOutput( sequences=sequences, token_timestamps=token_timestamps, - logits=None, attentions=outputs.attentions, hidden_states=outputs.hidden_states, ) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 983330309838..29553da39255 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -490,17 +490,14 @@ def forward( @dataclass -class ParakeetGenerateOutput(ModelOutput): +class ParakeetCTCGenerateOutput(ModelOutput): """ - Outputs of Parakeet models. + Outputs of Parakeet CTC model generation. Args: sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level timestamps in seconds indicating when each token was emitted. Only returned by TDT models - when `return_timestamps=True` is passed to `generate()`. 
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for @@ -514,12 +511,37 @@ class ParakeetGenerateOutput(ModelOutput): """ sequences: torch.LongTensor - token_timestamps: torch.FloatTensor | None = None logits: tuple[torch.FloatTensor] | None = None attentions: tuple[tuple[torch.FloatTensor]] | None = None hidden_states: tuple[tuple[torch.FloatTensor]] | None = None +@dataclass +class ParakeetTDTGenerateOutput(ModelOutput): + """ + Outputs of Parakeet TDT model generation. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level timestamps in seconds indicating when each token was emitted. Only returned when + `return_timestamps=True` is passed to `generate()`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor + token_timestamps: torch.FloatTensor | None = None + attentions: tuple[tuple[torch.FloatTensor]] | None = None + hidden_states: tuple[tuple[torch.FloatTensor]] | None = None + + @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -616,7 +638,7 @@ def generate( attention_mask: torch.Tensor | None = None, return_dict_in_generate: bool = False, **kwargs: Unpack[TransformersKwargs], - ) -> ParakeetGenerateOutput | torch.LongTensor: + ) -> ParakeetCTCGenerateOutput | torch.LongTensor: r""" Example: @@ -654,7 +676,7 @@ def generate( sequences[~attention_mask] = self.config.pad_token_id if return_dict_in_generate: - return ParakeetGenerateOutput( + return ParakeetCTCGenerateOutput( sequences=sequences, logits=outputs.logits, attentions=outputs.attentions, @@ -826,7 +848,7 @@ def generate( return_timestamps: bool = False, return_dict_in_generate: bool = False, **kwargs: Unpack[TransformersKwargs], - ) -> ParakeetGenerateOutput | torch.LongTensor: + ) -> ParakeetTDTGenerateOutput | torch.LongTensor: r""" Perform TDT greedy decoding to generate token sequences. 
@@ -973,10 +995,9 @@ def generate( ) if return_dict_in_generate: - return ParakeetGenerateOutput( + return ParakeetTDTGenerateOutput( sequences=sequences, token_timestamps=token_timestamps, - logits=None, attentions=outputs.attentions, hidden_states=outputs.hidden_states, ) From bb5ff331738f6325708430eefcffe7473a2951fd Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 25 Feb 2026 19:41:14 +0100 Subject: [PATCH 05/67] Work with auto device, better init, --- docs/source/en/model_doc/parakeet.md | 8 +-- src/transformers/modeling_utils.py | 6 ++ .../models/encodec/modeling_encodec.py | 15 +---- .../models/parakeet/modeling_parakeet.py | 53 +++++++----------- .../models/parakeet/modular_parakeet.py | 56 +++++++------------ 5 files changed, 51 insertions(+), 87 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index e709f9f54ce0..68f53aea372c 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -68,10 +68,8 @@ from transformers import AutoModelForCTC, AutoProcessor from datasets import load_dataset, Audio import torch -device = "cuda" if torch.cuda.is_available() else "cpu" - processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b") -model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device) +model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map="auto") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) @@ -107,10 +105,8 @@ from transformers import AutoModelForTDT, AutoProcessor from datasets import load_dataset, Audio import torch -device = "cuda" if torch.cuda.is_available() else "cpu" - processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") -model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", dtype="auto", device_map=device) +model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", dtype="auto", device_map="auto") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 729d7569f4d8..c7f4c3dc3ab1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2280,6 +2280,12 @@ def _init_weights(self, module): init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: init.zeros_(module.bias) + elif isinstance(module, nn.LSTM): + for name, param in module.named_parameters(): + if "weight" in name: + init.xavier_uniform_(param) + elif "bias" in name: + init.constant_(param, 0.0) elif isinstance(module, nn.Embedding): init.normal_(module.weight, mean=0.0, std=std) # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index 352a1e94006c..6af8e2d8c968 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -455,23 +455,12 @@ class EncodecPreTrainedModel(PreTrainedAudioTokenizerBase): @torch.no_grad() def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.GroupNorm): - init.zeros_(module.bias) - 
init.ones_(module.weight) - elif isinstance(module, nn.Conv1d): + super()._init_weights(module) + if isinstance(module, nn.Conv1d): init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) init.uniform_(module.bias, a=-k, b=k) - elif isinstance(module, nn.ConvTranspose1d): - module.reset_parameters() - elif isinstance(module, nn.LSTM): - for name, param in module.named_parameters(): - if "weight" in name: - init.xavier_uniform_(param) - elif "bias" in name: - init.constant_(param, 0.0) elif isinstance(module, EncodecConv1d): kernel_size = module.conv.kernel_size[0] stride = torch.tensor(module.conv.stride[0], dtype=torch.int64) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 312a67bc9bc9..df46a7227c2e 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -35,6 +35,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_torchaudio_available from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoModel from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig @@ -49,8 +50,6 @@ class ParakeetEncoderModelOutput(BaseModelOutput): class ParakeetEncoderRelPositionalEncoding(nn.Module): - """Relative positional encoding for Parakeet.""" - inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: ParakeetEncoderConfig, device=None): @@ -495,15 +494,9 @@ class ParakeetPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - - if hasattr(self.config, "initializer_range"): - std = self.config.initializer_range - else: - # 0.02 is the standard default value across the library - std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + std = getattr(self.config, "initializer_range", 0.02) if isinstance(module, ParakeetEncoderAttention): - # Initialize positional bias parameters init.normal_(module.bias_u, mean=0.0, std=std) init.normal_(module.bias_v, mean=0.0, std=std) elif isinstance(module, ParakeetEncoderRelPositionalEncoding): @@ -513,12 +506,6 @@ def _init_weights(self, module): ** (torch.arange(0, encoder_config.hidden_size, 2, dtype=torch.int64) / encoder_config.hidden_size) ) init.copy_(module.inv_freq, inv_freq) - elif isinstance(module, nn.LSTM): - for name, param in module.named_parameters(): - if "weight" in name: - init.normal_(param, mean=0.0, std=std) - elif "bias" in name: - init.zeros_(param) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): encoder_config = getattr(self.config, "encoder_config", self.config) @@ -713,7 +700,7 @@ class ParakeetForCTC(ParakeetPreTrainedModel): def __init__(self, config: ParakeetCTCConfig): super().__init__(config) - self.encoder = ParakeetEncoder(config.encoder_config) + self.encoder = AutoModel.from_config(config.encoder_config) # Conv rather than linear to be consistent with NeMO decoding layer self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1) @@ -898,7 +885,7 @@ def forward( @auto_docstring( custom_intro=""" - Parakeet model with TDT (Token Duration Transducer) head for speech recognition. + Parakeet Encoder with a TDT (Token Duration Transducer) head. 
""" ) class ParakeetForTDT(ParakeetPreTrainedModel): @@ -906,7 +893,7 @@ class ParakeetForTDT(ParakeetPreTrainedModel): def __init__(self, config: ParakeetTDTConfig): super().__init__(config) - self.encoder = ParakeetEncoder(config.encoder_config) + self.encoder = AutoModel.from_config(config.encoder_config) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) @@ -1039,22 +1026,19 @@ def generate( >>> print(output.token_timestamps) ``` """ + kwargs["return_dict"] = True if return_timestamps: return_dict_in_generate = True - blank_id = self.config.pad_token_id - max_symbols_per_step = self.config.max_symbols_per_step - device = input_features.device batch_size = input_features.shape[0] - - kwargs["return_dict"] = True - outputs: CausalLMOutput = self( + outputs: CausalLMOutput = self.forward( input_features=input_features, attention_mask=attention_mask, **kwargs, ) encoder_hidden_states = outputs.logits + device = encoder_hidden_states.device sequence_length = encoder_hidden_states.shape[1] if attention_mask is not None: encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) @@ -1067,14 +1051,16 @@ def generate( self.config.num_decoder_layers, batch_size, self.config.decoder_hidden_size, - device=device, dtype=encoder_hidden_states.dtype, ) cell_state = torch.zeros_like(hidden_state) # Initialize with blank token - prev_tokens = torch.full((batch_size, 1), blank_id, dtype=torch.long, device=device) + prev_tokens = torch.full((batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=device) decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) + decoder_output = decoder_output.to(device) + hidden_state = hidden_state.to(device) + cell_state = cell_state.to(device) all_tokens = [[] for _ in range(batch_size)] token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None @@ -1088,16 +1074,14 @@ def generate( ].unsqueeze(1) symbols_added = 0 - while symbols_added < max_symbols_per_step: + while symbols_added < self.config.max_symbols_per_step: token_logits, duration_logits = self.joint(encoder_frames, decoder_output) - token_logits = token_logits.squeeze(1) - duration_logits = duration_logits.squeeze(1) + token_logits = token_logits.squeeze(1).to(device) + duration_logits = duration_logits.squeeze(1).to(device) tokens = token_logits.argmax(dim=-1) durations = duration_logits.argmax(dim=-1) - - is_blank = tokens == blank_id - emit_mask = active_mask & ~is_blank + emit_mask = active_mask & ~(tokens == self.config.pad_token_id) for i in range(batch_size): if emit_mask[i]: @@ -1110,6 +1094,9 @@ def generate( new_decoder_output, new_hidden_state, new_cell_state = self.decoder( new_prev_tokens, hidden_state, cell_state ) + new_decoder_output = new_decoder_output.to(device) + new_hidden_state = new_hidden_state.to(device) + new_cell_state = new_cell_state.to(device) emit_mask_expanded = emit_mask.view(batch_size, 1, 1) decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) @@ -1122,7 +1109,7 @@ def generate( stay_mask = active_mask & (durations == 0) if stay_mask.any(): symbols_added += 1 - if symbols_added >= max_symbols_per_step: + if symbols_added >= self.config.max_symbols_per_step: time_indices = time_indices + 1 break continue diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 29553da39255..49e12e09b4da 100644 --- 
a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -29,6 +29,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_torchaudio_available from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoModel from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule from ..llama.modeling_llama import LlamaAttention, eager_attention_forward from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig @@ -45,8 +46,6 @@ class ParakeetEncoderModelOutput(BaseModelOutput): class ParakeetEncoderRelPositionalEncoding(nn.Module): - """Relative positional encoding for Parakeet.""" - inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: ParakeetEncoderConfig, device=None): @@ -334,30 +333,17 @@ class ParakeetPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - - if hasattr(self.config, "initializer_range"): - std = self.config.initializer_range - else: - # 0.02 is the standard default value across the library - std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + std = getattr(self.config, "initializer_range", 0.02) if isinstance(module, ParakeetEncoderAttention): - # Initialize positional bias parameters init.normal_(module.bias_u, mean=0.0, std=std) init.normal_(module.bias_v, mean=0.0, std=std) elif isinstance(module, ParakeetEncoderRelPositionalEncoding): encoder_config = getattr(self.config, "encoder_config", self.config) inv_freq = 1.0 / ( - 10000.0 - ** (torch.arange(0, encoder_config.hidden_size, 2, dtype=torch.int64) / encoder_config.hidden_size) + 10000.0 ** (torch.arange(0, encoder_config.hidden_size, 2, dtype=torch.int64) / encoder_config.hidden_size) ) init.copy_(module.inv_freq, inv_freq) - elif isinstance(module, nn.LSTM): - for name, param in module.named_parameters(): - if "weight" in name: - init.normal_(param, mean=0.0, std=std) - elif "bias" in name: - init.zeros_(param) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): encoder_config = getattr(self.config, "encoder_config", self.config) @@ -552,7 +538,7 @@ class ParakeetForCTC(ParakeetPreTrainedModel): def __init__(self, config: ParakeetCTCConfig): super().__init__(config) - self.encoder = ParakeetEncoder(config.encoder_config) + self.encoder = AutoModel.from_config(config.encoder_config) # Conv rather than linear to be consistent with NeMO decoding layer self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1) @@ -737,7 +723,7 @@ def forward( @auto_docstring( custom_intro=""" - Parakeet model with TDT (Token Duration Transducer) head for speech recognition. + Parakeet Encoder with a TDT (Token Duration Transducer) head. 
""" ) class ParakeetForTDT(ParakeetPreTrainedModel): @@ -745,7 +731,7 @@ class ParakeetForTDT(ParakeetPreTrainedModel): def __init__(self, config: ParakeetTDTConfig): super().__init__(config) - self.encoder = ParakeetEncoder(config.encoder_config) + self.encoder = AutoModel.from_config(config.encoder_config) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) @@ -878,22 +864,19 @@ def generate( >>> print(output.token_timestamps) ``` """ + kwargs["return_dict"] = True if return_timestamps: return_dict_in_generate = True - blank_id = self.config.pad_token_id - max_symbols_per_step = self.config.max_symbols_per_step - device = input_features.device batch_size = input_features.shape[0] - - kwargs["return_dict"] = True - outputs: CausalLMOutput = self( + outputs: CausalLMOutput = self.forward( input_features=input_features, attention_mask=attention_mask, **kwargs, ) encoder_hidden_states = outputs.logits + device = encoder_hidden_states.device sequence_length = encoder_hidden_states.shape[1] if attention_mask is not None: encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) @@ -906,14 +889,16 @@ def generate( self.config.num_decoder_layers, batch_size, self.config.decoder_hidden_size, - device=device, dtype=encoder_hidden_states.dtype, ) cell_state = torch.zeros_like(hidden_state) # Initialize with blank token - prev_tokens = torch.full((batch_size, 1), blank_id, dtype=torch.long, device=device) + prev_tokens = torch.full((batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=device) decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) + decoder_output = decoder_output.to(device) + hidden_state = hidden_state.to(device) + cell_state = cell_state.to(device) all_tokens = [[] for _ in range(batch_size)] token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None @@ -927,16 +912,14 @@ def generate( ].unsqueeze(1) symbols_added = 0 - while symbols_added < max_symbols_per_step: + while symbols_added < self.config.max_symbols_per_step: token_logits, duration_logits = self.joint(encoder_frames, decoder_output) - token_logits = token_logits.squeeze(1) - duration_logits = duration_logits.squeeze(1) + token_logits = token_logits.squeeze(1).to(device) + duration_logits = duration_logits.squeeze(1).to(device) tokens = token_logits.argmax(dim=-1) durations = duration_logits.argmax(dim=-1) - - is_blank = tokens == blank_id - emit_mask = active_mask & ~is_blank + emit_mask = active_mask & ~(tokens == self.config.pad_token_id) for i in range(batch_size): if emit_mask[i]: @@ -949,6 +932,9 @@ def generate( new_decoder_output, new_hidden_state, new_cell_state = self.decoder( new_prev_tokens, hidden_state, cell_state ) + new_decoder_output = new_decoder_output.to(device) + new_hidden_state = new_hidden_state.to(device) + new_cell_state = new_cell_state.to(device) emit_mask_expanded = emit_mask.view(batch_size, 1, 1) decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) @@ -961,7 +947,7 @@ def generate( stay_mask = active_mask & (durations == 0) if stay_mask.any(): symbols_added += 1 - if symbols_added >= max_symbols_per_step: + if symbols_added >= self.config.max_symbols_per_step: time_indices = time_indices + 1 break continue From 9ec79b02c23006cae828358785e2b12ca262b576 Mon Sep 17 00:00:00 2001 From: Eric B Date: Thu, 26 Feb 2026 15:28:43 +0100 Subject: [PATCH 06/67] Test timestamps and expose token duration. 
--- .../models/parakeet/configuration_parakeet.py | 21 +++++---- .../models/parakeet/modeling_parakeet.py | 47 +++++++++++++------ .../models/parakeet/modular_parakeet.py | 47 +++++++++++++------ .../expected_results_batch_tdt_timestamp.json | 1 + .../models/parakeet/test_modeling_parakeet.py | 41 ++++++++++++++-- 5 files changed, 115 insertions(+), 42 deletions(-) create mode 100644 tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 3abd3b897fc8..270c608cf597 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -51,6 +51,10 @@ class ParakeetEncoderConfig(PreTrainedConfig): The number of channels in the subsampling convolution layers. num_mel_bins (`int`, *optional*, defaults to 80): Number of mel features. + hop_length (`int`, *optional*, defaults to 160): + Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). subsampling_conv_kernel_size (`int`, *optional*, defaults to 3): The kernel size of the subsampling convolution layers. subsampling_conv_stride (`int`, *optional*, defaults to 2): @@ -106,6 +110,8 @@ def __init__( subsampling_factor=8, subsampling_conv_channels=256, num_mel_bins=80, + hop_length=160, + sampling_rate=16000, subsampling_conv_kernel_size=3, subsampling_conv_stride=2, dropout=0.1, @@ -134,6 +140,8 @@ def __init__( self.subsampling_factor = subsampling_factor self.subsampling_conv_channels = subsampling_conv_channels self.num_mel_bins = num_mel_bins + self.hop_length = hop_length + self.sampling_rate = sampling_rate self.dropout = dropout self.dropout_positions = dropout_positions @@ -144,9 +152,7 @@ def __init__( self.scale_input = scale_input self.initializer_range = initializer_range - super().__init__( - **kwargs, - ) + super().__init__(**kwargs) class ParakeetCTCConfig(PreTrainedConfig): @@ -252,9 +258,6 @@ class ParakeetTDTConfig(PreTrainedConfig): The activation function in the joint network. max_symbols_per_step (`int`, *optional*, defaults to 10): Maximum number of symbols to emit per encoder time step during greedy decoding. - seconds_per_frame (`float`, *optional*, defaults to 0.08): - Duration in seconds of each encoder output frame. Used for computing token timestamps. - Computed as `hop_length * subsampling_factor / sampling_rate` (e.g. 160 * 8 / 16000 = 0.08). encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): The config object or dictionary of the encoder. 
pad_token_id (`int`, *optional*, defaults to 8192): @@ -286,7 +289,6 @@ def __init__( num_duration_bins=5, hidden_act="relu", max_symbols_per_step=10, - seconds_per_frame=0.08, encoder_config: dict | ParakeetEncoderConfig = None, pad_token_id=8192, **kwargs, @@ -297,7 +299,6 @@ def __init__( self.num_duration_bins = num_duration_bins self.hidden_act = hidden_act self.max_symbols_per_step = max_symbols_per_step - self.seconds_per_frame = seconds_per_frame if isinstance(encoder_config, dict): self.encoder_config = ParakeetEncoderConfig(**encoder_config) @@ -311,6 +312,10 @@ def __init__( super().__init__(**kwargs) + @property + def frame_rate(self): + return self.encoder_config.sampling_rate / (self.encoder_config.hop_length * self.encoder_config.subsampling_factor) + @classmethod def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): r""" diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index df46a7227c2e..ae27435c9b78 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -676,6 +676,9 @@ class ParakeetTDTGenerateOutput(ModelOutput): token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Token-level timestamps in seconds indicating when each token was emitted. Only returned when `return_timestamps=True` is passed to `generate()`. + token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level durations in frames indicating how many frames each token spans. Only returned when + `return_timestamps=True` is passed to `generate()`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. 
@@ -686,6 +689,7 @@ class ParakeetTDTGenerateOutput(ModelOutput): sequences: torch.LongTensor token_timestamps: torch.FloatTensor | None = None + token_durations: torch.LongTensor | None = None attentions: tuple[tuple[torch.FloatTensor]] | None = None hidden_states: tuple[tuple[torch.FloatTensor]] | None = None @@ -839,6 +843,7 @@ class ParakeetTDTDecoder(nn.Module): def __init__(self, config: ParakeetTDTConfig): super().__init__() + self.config = config self.embedding = nn.Embedding(config.vocab_size + 1, config.decoder_hidden_size) self.lstm = nn.LSTM( input_size=config.decoder_hidden_size, @@ -854,9 +859,21 @@ def forward( hidden_state: torch.Tensor | None = None, cell_state: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input_ids = input_ids.to(self.decoder_projector.weight.device) + if hidden_state is None or cell_state is None: + hidden_state = torch.zeros( + self.config.num_decoder_layers, + input_ids.shape[0], + self.config.decoder_hidden_size, + device=self.decoder_projector.weight.device, + dtype=self.decoder_projector.weight.dtype, + ) + cell_state = torch.zeros_like(hidden_state) + hidden_state = hidden_state.to(self.decoder_projector.weight.device) + cell_state = cell_state.to(self.decoder_projector.weight.device) + embeddings = self.embedding(input_ids) - lstm_state = (hidden_state, cell_state) if hidden_state is not None else None - lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, lstm_state) + lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, (hidden_state, cell_state)) decoder_output = self.decoder_projector(lstm_output) return decoder_output, hidden_state, cell_state @@ -1046,16 +1063,8 @@ def generate( else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) - # Initialize decoder LSTM state - hidden_state = torch.zeros( - self.config.num_decoder_layers, - batch_size, - self.config.decoder_hidden_size, - dtype=encoder_hidden_states.dtype, - ) - cell_state = torch.zeros_like(hidden_state) - - # Initialize with blank token + # Initialize decoder + hidden_state, cell_state = None, None prev_tokens = torch.full((batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=device) decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) decoder_output = decoder_output.to(device) @@ -1064,6 +1073,7 @@ def generate( all_tokens = [[] for _ in range(batch_size)] token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None + token_durations_list = [[] for _ in range(batch_size)] if return_timestamps else None time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) active_mask = time_indices < valid_lengths @@ -1088,6 +1098,8 @@ def generate( all_tokens[i].append(tokens[i].item()) if token_frame_indices is not None: token_frame_indices[i].append(time_indices[i].item()) + if token_durations_list is not None: + token_durations_list[i].append(durations[i].item()) if emit_mask.any(): new_prev_tokens = tokens.unsqueeze(1) @@ -1110,7 +1122,7 @@ def generate( if stay_mask.any(): symbols_added += 1 if symbols_added >= self.config.max_symbols_per_step: - time_indices = time_indices + 1 + time_indices[active_mask & stay_mask] += 1 break continue @@ -1132,20 +1144,25 @@ def generate( sequences[i, :seq_len] = torch.tensor(all_tokens[i], dtype=torch.long, device=device) token_timestamps = None + token_durations = None if return_timestamps: - seconds_per_frame = self.config.seconds_per_frame token_timestamps 
= torch.full((batch_size, max_len), 0.0, dtype=torch.float, device=device) + token_durations = torch.full((batch_size, max_len), 0, dtype=torch.long, device=device) for i in range(batch_size): num_tokens = len(token_frame_indices[i]) if num_tokens > 0: token_timestamps[i, :num_tokens] = ( - torch.tensor(token_frame_indices[i], dtype=torch.float, device=device) * seconds_per_frame + torch.tensor(token_frame_indices[i], dtype=torch.float, device=device) / self.config.frame_rate + ) + token_durations[i, :num_tokens] = torch.tensor( + token_durations_list[i], dtype=torch.long, device=device ) if return_dict_in_generate: return ParakeetTDTGenerateOutput( sequences=sequences, token_timestamps=token_timestamps, + token_durations=token_durations, attentions=outputs.attentions, hidden_states=outputs.hidden_states, ) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 49e12e09b4da..70fbf31540ac 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -514,6 +514,9 @@ class ParakeetTDTGenerateOutput(ModelOutput): token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Token-level timestamps in seconds indicating when each token was emitted. Only returned when `return_timestamps=True` is passed to `generate()`. + token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level durations in frames indicating how many frames each token spans. Only returned when + `return_timestamps=True` is passed to `generate()`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. 
@@ -524,6 +527,7 @@ class ParakeetTDTGenerateOutput(ModelOutput): sequences: torch.LongTensor token_timestamps: torch.FloatTensor | None = None + token_durations: torch.LongTensor | None = None attentions: tuple[tuple[torch.FloatTensor]] | None = None hidden_states: tuple[tuple[torch.FloatTensor]] | None = None @@ -677,6 +681,7 @@ class ParakeetTDTDecoder(nn.Module): def __init__(self, config: ParakeetTDTConfig): super().__init__() + self.config = config self.embedding = nn.Embedding(config.vocab_size + 1, config.decoder_hidden_size) self.lstm = nn.LSTM( input_size=config.decoder_hidden_size, @@ -692,9 +697,21 @@ def forward( hidden_state: torch.Tensor | None = None, cell_state: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input_ids = input_ids.to(self.decoder_projector.weight.device) + if hidden_state is None or cell_state is None: + hidden_state = torch.zeros( + self.config.num_decoder_layers, + input_ids.shape[0], + self.config.decoder_hidden_size, + device=self.decoder_projector.weight.device, + dtype=self.decoder_projector.weight.dtype, + ) + cell_state = torch.zeros_like(hidden_state) + hidden_state = hidden_state.to(self.decoder_projector.weight.device) + cell_state = cell_state.to(self.decoder_projector.weight.device) + embeddings = self.embedding(input_ids) - lstm_state = (hidden_state, cell_state) if hidden_state is not None else None - lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, lstm_state) + lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, (hidden_state, cell_state)) decoder_output = self.decoder_projector(lstm_output) return decoder_output, hidden_state, cell_state @@ -884,16 +901,8 @@ def generate( else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) - # Initialize decoder LSTM state - hidden_state = torch.zeros( - self.config.num_decoder_layers, - batch_size, - self.config.decoder_hidden_size, - dtype=encoder_hidden_states.dtype, - ) - cell_state = torch.zeros_like(hidden_state) - - # Initialize with blank token + # Initialize decoder + hidden_state, cell_state = None, None prev_tokens = torch.full((batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=device) decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) decoder_output = decoder_output.to(device) @@ -902,6 +911,7 @@ def generate( all_tokens = [[] for _ in range(batch_size)] token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None + token_durations_list = [[] for _ in range(batch_size)] if return_timestamps else None time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) active_mask = time_indices < valid_lengths @@ -926,6 +936,8 @@ def generate( all_tokens[i].append(tokens[i].item()) if token_frame_indices is not None: token_frame_indices[i].append(time_indices[i].item()) + if token_durations_list is not None: + token_durations_list[i].append(durations[i].item()) if emit_mask.any(): new_prev_tokens = tokens.unsqueeze(1) @@ -948,7 +960,7 @@ def generate( if stay_mask.any(): symbols_added += 1 if symbols_added >= self.config.max_symbols_per_step: - time_indices = time_indices + 1 + time_indices[active_mask & stay_mask] += 1 break continue @@ -970,20 +982,25 @@ def generate( sequences[i, :seq_len] = torch.tensor(all_tokens[i], dtype=torch.long, device=device) token_timestamps = None + token_durations = None if return_timestamps: - seconds_per_frame = self.config.seconds_per_frame token_timestamps = 
torch.full((batch_size, max_len), 0.0, dtype=torch.float, device=device) + token_durations = torch.full((batch_size, max_len), 0, dtype=torch.long, device=device) for i in range(batch_size): num_tokens = len(token_frame_indices[i]) if num_tokens > 0: token_timestamps[i, :num_tokens] = ( - torch.tensor(token_frame_indices[i], dtype=torch.float, device=device) * seconds_per_frame + torch.tensor(token_frame_indices[i], dtype=torch.float, device=device) / self.config.frame_rate + ) + token_durations[i, :num_tokens] = torch.tensor( + token_durations_list[i], dtype=torch.long, device=device ) if return_dict_in_generate: return ParakeetTDTGenerateOutput( sequences=sequences, token_timestamps=token_timestamps, + token_durations=token_durations, attentions=outputs.attentions, hidden_states=outputs.hidden_states, ) diff --git a/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json new file mode 100644 index 000000000000..0acb4bae061b --- /dev/null +++ b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json @@ -0,0 +1 @@ +{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind."], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883]], "token_timestamps": [[0.23999999463558197, 0.47999998927116394, 0.6399999856948853, 0.8799999952316284, 1.1200000047683716, 1.3600000143051147, 1.440000057220459, 1.600000023841858, 1.7599999904632568, 2.0, 2.1600000858306885, 2.240000009536743, 2.4000000953674316, 2.4800000190734863, 2.559999942779541, 2.7200000286102295, 2.880000114440918, 3.0399999618530273, 3.119999885559082, 3.2799999713897705, 3.440000057220459, 3.5999999046325684, 3.759999990463257, 3.9200000762939453, 4.079999923706055, 4.239999771118164, 4.400000095367432, 4.480000019073486, 4.71999979019165, 4.960000038146973, 5.360000133514404, 5.599999904632568, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3199999928474426, 0.6399999856948853, 0.8799999952316284, 1.0399999618530273, 1.2000000476837158, 1.440000057220459, 1.6799999475479126, 1.840000033378601, 1.9199999570846558, 2.0, 2.1600000858306885, 2.4000000953674316, 
2.559999942779541, 2.7200000286102295, 2.9600000381469727, 3.119999885559082, 3.359999895095825, 3.5999999046325684, 3.9200000762939453, 4.159999847412109, 4.320000171661377, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3199999928474426, 0.6399999856948853, 0.7200000286102295, 0.9599999785423279, 1.1200000047683716, 1.3600000143051147, 1.600000023841858, 1.840000033378601, 2.0799999237060547, 2.240000009536743, 2.4800000190734863, 2.640000104904175, 2.799999952316284, 2.880000114440918, 3.0399999618530273, 3.200000047683716, 3.440000057220459, 3.680000066757202, 3.8399999141693115, 4.079999923706055, 4.400000095367432, 4.559999942779541, 4.71999979019165, 4.960000038146973, 5.119999885559082, 5.360000133514404, 5.519999980926514, 5.679999828338623, 5.920000076293945, 6.159999847412109, 6.239999771118164, 6.400000095367432, 6.559999942779541, 6.71999979019165, 6.960000038146973, 7.28000020980835, 7.599999904632568, 7.920000076293945, 8.15999984741211, 8.319999694824219, 8.479999542236328, 8.720000267028809, 8.880000114440918, 8.960000038146973, 9.119999885559082, 9.279999732971191, 9.4399995803833, 9.680000305175781, 9.760000228881836, 9.920000076293945, 10.15999984741211, 10.239999771118164, 10.399999618530273, 10.640000343322754, 10.880000114440918, 10.960000038146973, 11.199999809265137, 11.359999656677246, 11.520000457763672, 11.84000015258789, 12.15999984741211]], "token_durations": [[3, 2, 3, 3, 3, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 3, 3, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 3, 2, 2, 3, 3, 2, 1, 1, 2, 3, 2, 2, 3, 2, 3, 3, 4, 3, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 1, 3, 2, 3, 3, 3, 3, 2, 3, 2, 2, 1, 2, 2, 3, 3, 2, 3, 4, 2, 2, 3, 2, 3, 2, 2, 3, 3, 1, 2, 2, 2, 3, 4, 4, 4, 3, 1, 2, 3, 2, 1, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 3, 1, 3, 2, 2, 4, 4, 2]]} \ No newline at end of file diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index abd1cf10cc3c..c966c43a550a 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -529,12 +529,12 @@ def test_tdt_model_integration(self): EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] - samples = self._load_datasamples(1) + samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) model.eval() model.to(torch_device) - inputs = self.processor(samples) + inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) inputs.to(torch_device, dtype=self.dtype) output = model.generate(**inputs, return_dict_in_generate=True) torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) @@ -552,14 +552,47 @@ def test_tdt_model_integration_batched(self): EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] - samples = self._load_datasamples(5) + samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) model.eval() 
model.to(torch_device) - inputs = self.processor(samples) + inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) inputs.to(torch_device, dtype=self.dtype) output = model.generate(**inputs, return_dict_in_generate=True) torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) + + @slow + def test_tdt_model_integration_timestamps(self): + """ + reproducer: tests/models/parakeet/reproducer_batch_tdt_timestamps.py + """ + RESULTS_PATH = ( + Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch_tdt_timestamp.json" + ) + with open(RESULTS_PATH, "r") as f: + raw_data = json.load(f) + EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) + EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] + EXPECTED_TIMESTAMPS = torch.tensor(raw_data["token_timestamps"]) + EXPECTED_DURATIONS = torch.tensor(raw_data["token_durations"]) + + # Dynamically determine number of samples from expected results + samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) + model.eval() + model.to(torch_device) + + inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) + inputs.to(torch_device, dtype=self.dtype) + output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) + torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) + predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) + self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) + + # Check timestamps and durations + self.assertIsNotNone(output.token_timestamps, "token_timestamps should be returned when return_timestamps=True") + torch.testing.assert_close(output.token_timestamps.cpu(), EXPECTED_TIMESTAMPS) + torch.testing.assert_close(output.token_durations.cpu(), EXPECTED_DURATIONS) From 33f128ecda39ff5081ac79a99803c0a2a4024713 Mon Sep 17 00:00:00 2001 From: Eric B Date: Thu, 26 Feb 2026 16:21:52 +0100 Subject: [PATCH 07/67] Add reproducer link. 
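The reproducer drives `ParakeetForTDT.generate()` with `return_timestamps=True` and dumps the decoded
transcriptions together with `token_timestamps` and `token_durations`. A minimal sketch of that flow,
assuming the converted checkpoint and dummy dataset already used by the integration tests (the gist itself
may differ in details such as dtype and device handling):

```python
from datasets import Audio, load_dataset
from transformers import AutoProcessor, ParakeetForTDT

checkpoint = "bezzam/parakeet-tdt-0.6b-v3-hf"  # converted checkpoint used in the tests
processor = AutoProcessor.from_pretrained(checkpoint)
model = ParakeetForTDT.from_pretrained(checkpoint).eval()

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
samples = [el["array"] for el in ds["audio"][:3]]

inputs = processor(samples, sampling_rate=processor.feature_extractor.sampling_rate)
output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True)

print(processor.batch_decode(output.sequences, skip_special_tokens=True))
print(output.token_timestamps)  # seconds: emission frame index divided by config.frame_rate
print(output.token_durations)   # encoder frames spanned by each emitted token
```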
--- tests/models/parakeet/test_modeling_parakeet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index c966c43a550a..4865cfd0e455 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -567,7 +567,7 @@ def test_tdt_model_integration_batched(self): @slow def test_tdt_model_integration_timestamps(self): """ - reproducer: tests/models/parakeet/reproducer_batch_tdt_timestamps.py + reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batch_tdt_timestamps-py """ RESULTS_PATH = ( Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch_tdt_timestamp.json" From 760b4b61e122372014da22aa1b1360cebb289534 Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Fri, 27 Feb 2026 16:14:18 +0100 Subject: [PATCH 08/67] fix: align TDT training and decoding with NeMo implementation - Use -100 label padding for training (HF convention) - Fix timestamp recording in inner blank-seeking loop - Add max_symbols_per_step guard matching NeMo - Clean up decoding loop - Add TDT training example to docs - Use setUpClass for TDT integration tests --- docs/source/en/model_doc/parakeet.md | 23 +++ .../models/lasr/configuration_lasr.py | 6 +- src/transformers/models/lasr/modeling_lasr.py | 3 +- .../models/lasr/processing_lasr.py | 4 +- .../models/parakeet/configuration_parakeet.py | 4 +- .../models/parakeet/modeling_parakeet.py | 145 +++++++++++------ .../models/parakeet/modular_parakeet.py | 148 +++++++++++------- .../models/parakeet/processing_parakeet.py | 4 +- .../models/parakeet/test_modeling_parakeet.py | 10 +- 9 files changed, 229 insertions(+), 118 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index 68f53aea372c..6722f932d631 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -228,6 +228,29 @@ outputs = model(**inputs) outputs.loss.backward() ``` +### TDT Training + +The TDT model uses RNNT loss (requires `torchaudio`). Pass `text` to the processor to prepare labels — padding is automatically handled with `-100`. 
+ +```python +from transformers import AutoModelForTDT, AutoProcessor +from datasets import load_dataset, Audio + +processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") +model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", dtype="auto", device_map="auto") + +ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") +ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) +speech_samples = [el['array'] for el in ds["audio"][:5]] +text_samples = [el for el in ds["text"][:5]] + +inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate) +inputs.to(model.device, dtype=model.dtype) + +outputs = model(**inputs) +outputs.loss.backward() +``` + ## ParakeetTokenizer [[autodoc]] ParakeetTokenizer diff --git a/src/transformers/models/lasr/configuration_lasr.py b/src/transformers/models/lasr/configuration_lasr.py index 4d82b85044a2..60101030f38e 100644 --- a/src/transformers/models/lasr/configuration_lasr.py +++ b/src/transformers/models/lasr/configuration_lasr.py @@ -150,6 +150,8 @@ def __init__( self.subsampling_conv_stride = subsampling_conv_stride self.subsampling_conv_channels = subsampling_conv_channels self.num_mel_bins = num_mel_bins + self.hop_length = hop_length + self.sampling_rate = sampling_rate self.dropout = dropout self.dropout_positions = dropout_positions @@ -159,9 +161,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range - super().__init__( - **kwargs, - ) + super().__init__(**kwargs) class LasrCTCConfig(PreTrainedConfig): diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index 24fa4872a2a8..18fa46657c78 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -36,6 +36,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoModel from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig @@ -591,7 +592,7 @@ class LasrForCTC(LasrPreTrainedModel): def __init__(self, config: LasrCTCConfig): super().__init__(config) - self.encoder = LasrEncoder(config.encoder_config) + self.encoder = AutoModel.from_config(config.encoder_config) # Conv rather than linear to be consistent with NeMO decoding layer self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1) diff --git a/src/transformers/models/lasr/processing_lasr.py b/src/transformers/models/lasr/processing_lasr.py index c1acaebaae07..644cd835936d 100644 --- a/src/transformers/models/lasr/processing_lasr.py +++ b/src/transformers/models/lasr/processing_lasr.py @@ -88,7 +88,9 @@ def __call__( if text is None: return inputs else: - inputs["labels"] = encodings["input_ids"] + labels = encodings["input_ids"] + labels[labels == self.tokenizer.pad_token_id] = -100 + inputs["labels"] = labels return inputs @property diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 270c608cf597..3c233726e36c 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -314,7 +314,9 @@ def __init__( @property def frame_rate(self): - 
return self.encoder_config.sampling_rate / (self.encoder_config.hop_length * self.encoder_config.subsampling_factor) + return self.encoder_config.sampling_rate / ( + self.encoder_config.hop_length * self.encoder_config.subsampling_factor + ) @classmethod def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index ae27435c9b78..b1ff6da52c88 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -680,11 +680,11 @@ class ParakeetTDTGenerateOutput(ModelOutput): Token-level durations in frames indicating how many frames each token spans. Only returned when `return_timestamps=True` is passed to `generate()`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions from the encoder. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`. Hidden states from the encoder. """ sequences: torch.LongTensor @@ -967,10 +967,12 @@ def forward( ) encoder_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) - # Compute target lengths (non-pad tokens) - labels_mask = labels != self.config.pad_token_id + labels_mask = labels != -100 target_lengths = labels_mask.sum(-1) + labels = labels.clone() + labels[labels == -100] = self.config.pad_token_id + # Prepare decoder input: prepend blank token to labels blank_tokens = torch.full( (labels.shape[0], 1), self.config.pad_token_id, dtype=labels.dtype, device=labels.device @@ -980,11 +982,14 @@ def forward( # Run decoder on full label sequence: (batch, U+1, decoder_hidden_size) decoder_output, _, _ = self.decoder(decoder_input) + max_encoder_length = encoder_lengths.max().item() + encoder_hidden_states_trimmed = encoder_hidden_states[:, :max_encoder_length] + # Compute joint output for all (T, U+1) pairs via broadcasting # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) # decoder: (batch, 1, U+1, decoder_hidden_size) token_logits, _ = self.joint( - encoder_hidden_states.unsqueeze(2), + encoder_hidden_states_trimmed.unsqueeze(2), decoder_output.unsqueeze(1), ) # token_logits: (batch, T, U+1, vocab_size+1) @@ -1074,61 +1079,97 @@ def generate( all_tokens = [[] for _ in range(batch_size)] token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None token_durations_list = [[] for _ in range(batch_size)] if return_timestamps else None + batch_indices = torch.arange(batch_size, device=device) time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) + time_indices_current_labels = torch.zeros(batch_size, dtype=torch.long, device=device) active_mask = time_indices < valid_lengths + max_symbols 
= self.config.max_symbols_per_step + symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) + last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) + while active_mask.any(): safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) - encoder_frames = encoder_hidden_states[ - torch.arange(batch_size, device=device), safe_time_indices - ].unsqueeze(1) + encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) + + token_logits, duration_logits = self.joint(encoder_frames, decoder_output) + token_logits = token_logits.squeeze(1).to(device) + duration_logits = duration_logits.squeeze(1).to(device) + + tokens = token_logits.argmax(dim=-1) + durations = duration_logits.argmax(dim=-1) + blank_mask = active_mask & (tokens == self.config.pad_token_id) + + # Force blank duration >= 1 to guarantee forward progress + durations = durations.masked_fill(blank_mask & (durations == 0), 1) + + # Save pre-advance position for timestamp recording + time_indices_current_labels.copy_(time_indices) + + # Advance time for all active elements + time_indices = time_indices + durations * active_mask + safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) + active_mask = time_indices < valid_lengths + advance_mask = active_mask & blank_mask + + # Inner loop: skip past consecutive blanks to find non-blank + while advance_mask.any(): + # Update timestamp tracking to current position + time_indices_current_labels = torch.where(advance_mask, time_indices, time_indices_current_labels) + encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) - symbols_added = 0 - while symbols_added < self.config.max_symbols_per_step: token_logits, duration_logits = self.joint(encoder_frames, decoder_output) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) - tokens = token_logits.argmax(dim=-1) - durations = duration_logits.argmax(dim=-1) - emit_mask = active_mask & ~(tokens == self.config.pad_token_id) - - for i in range(batch_size): - if emit_mask[i]: - all_tokens[i].append(tokens[i].item()) - if token_frame_indices is not None: - token_frame_indices[i].append(time_indices[i].item()) - if token_durations_list is not None: - token_durations_list[i].append(durations[i].item()) - - if emit_mask.any(): - new_prev_tokens = tokens.unsqueeze(1) - new_decoder_output, new_hidden_state, new_cell_state = self.decoder( - new_prev_tokens, hidden_state, cell_state - ) - new_decoder_output = new_decoder_output.to(device) - new_hidden_state = new_hidden_state.to(device) - new_cell_state = new_cell_state.to(device) - - emit_mask_expanded = emit_mask.view(batch_size, 1, 1) - decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) - - emit_mask_state = emit_mask.view(1, batch_size, 1) - hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) - cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) - - # If duration is 0, stay on same frame (emit more tokens) - stay_mask = active_mask & (durations == 0) - if stay_mask.any(): - symbols_added += 1 - if symbols_added >= self.config.max_symbols_per_step: - time_indices[active_mask & stay_mask] += 1 - break - continue - - # Duration > 0: advance time - time_indices = time_indices + torch.where(active_mask, durations, torch.zeros_like(durations)) - break + more_tokens = token_logits.argmax(dim=-1) + more_durations = duration_logits.argmax(dim=-1) + + tokens = 
torch.where(advance_mask, more_tokens, tokens) + durations = torch.where(advance_mask, more_durations, durations) + + blank_mask = tokens == self.config.pad_token_id + durations = durations.masked_fill(blank_mask & (durations == 0), 1) + + time_indices = torch.where(advance_mask, time_indices + durations, time_indices) + safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) + active_mask = time_indices < valid_lengths + advance_mask = active_mask & blank_mask + + # Record results for non-blank tokens found + emit_mask = active_mask & (tokens != self.config.pad_token_id) + for i in range(batch_size): + if emit_mask[i]: + all_tokens[i].append(tokens[i].item()) + if token_frame_indices is not None: + token_frame_indices[i].append(time_indices_current_labels[i].item()) + if token_durations_list is not None: + token_durations_list[i].append(durations[i].item()) + + if emit_mask.any(): + new_prev_tokens = tokens.unsqueeze(1) + new_decoder_output, new_hidden_state, new_cell_state = self.decoder( + new_prev_tokens, hidden_state, cell_state + ) + new_decoder_output = new_decoder_output.to(device) + new_hidden_state = new_hidden_state.to(device) + new_cell_state = new_cell_state.to(device) + + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) + + emit_mask_state = emit_mask.view(1, batch_size, 1) + hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) + cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) + + # Track symbols emitted per time step; force advance when max_symbols reached + time_changed = time_indices_current_labels != last_label_time + symbols_per_step = torch.where(time_changed, torch.zeros_like(symbols_per_step), symbols_per_step) + symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) + last_label_time = torch.where(emit_mask, time_indices_current_labels, last_label_time) + force_advance = active_mask & (symbols_per_step >= max_symbols) + time_indices = time_indices + force_advance.long() + symbols_per_step = symbols_per_step.masked_fill(force_advance, 0) active_mask = time_indices < valid_lengths diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 70fbf31540ac..6f59b829b093 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -341,7 +341,8 @@ def _init_weights(self, module): elif isinstance(module, ParakeetEncoderRelPositionalEncoding): encoder_config = getattr(self.config, "encoder_config", self.config) inv_freq = 1.0 / ( - 10000.0 ** (torch.arange(0, encoder_config.hidden_size, 2, dtype=torch.int64) / encoder_config.hidden_size) + 10000.0 + ** (torch.arange(0, encoder_config.hidden_size, 2, dtype=torch.int64) / encoder_config.hidden_size) ) init.copy_(module.inv_freq, inv_freq) @@ -518,11 +519,11 @@ class ParakeetTDTGenerateOutput(ModelOutput): Token-level durations in frames indicating how many frames each token spans. Only returned when `return_timestamps=True` is passed to `generate()`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. 
+ Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions from the encoder. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`. Hidden states from the encoder. """ sequences: torch.LongTensor @@ -805,10 +806,12 @@ def forward( ) encoder_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) - # Compute target lengths (non-pad tokens) - labels_mask = labels != self.config.pad_token_id + labels_mask = labels != -100 target_lengths = labels_mask.sum(-1) + labels = labels.clone() + labels[labels == -100] = self.config.pad_token_id + # Prepare decoder input: prepend blank token to labels blank_tokens = torch.full( (labels.shape[0], 1), self.config.pad_token_id, dtype=labels.dtype, device=labels.device @@ -818,11 +821,14 @@ def forward( # Run decoder on full label sequence: (batch, U+1, decoder_hidden_size) decoder_output, _, _ = self.decoder(decoder_input) + max_encoder_length = encoder_lengths.max().item() + encoder_hidden_states_trimmed = encoder_hidden_states[:, :max_encoder_length] + # Compute joint output for all (T, U+1) pairs via broadcasting # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) # decoder: (batch, 1, U+1, decoder_hidden_size) token_logits, _ = self.joint( - encoder_hidden_states.unsqueeze(2), + encoder_hidden_states_trimmed.unsqueeze(2), decoder_output.unsqueeze(1), ) # token_logits: (batch, T, U+1, vocab_size+1) @@ -912,61 +918,97 @@ def generate( all_tokens = [[] for _ in range(batch_size)] token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None token_durations_list = [[] for _ in range(batch_size)] if return_timestamps else None + batch_indices = torch.arange(batch_size, device=device) time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) + time_indices_current_labels = torch.zeros(batch_size, dtype=torch.long, device=device) active_mask = time_indices < valid_lengths + max_symbols = self.config.max_symbols_per_step + symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) + last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) + while active_mask.any(): safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) - encoder_frames = encoder_hidden_states[ - torch.arange(batch_size, device=device), safe_time_indices - ].unsqueeze(1) + encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) + + token_logits, duration_logits = self.joint(encoder_frames, decoder_output) + token_logits = token_logits.squeeze(1).to(device) + duration_logits = duration_logits.squeeze(1).to(device) + + tokens = token_logits.argmax(dim=-1) + durations = duration_logits.argmax(dim=-1) + blank_mask = active_mask & (tokens == self.config.pad_token_id) + + # Force blank duration >= 1 to guarantee forward progress + durations = durations.masked_fill(blank_mask & (durations == 0), 1) + + # Save pre-advance position for timestamp recording + time_indices_current_labels.copy_(time_indices) + + # Advance time for all active elements + 
time_indices = time_indices + durations * active_mask + safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) + active_mask = time_indices < valid_lengths + advance_mask = active_mask & blank_mask + + # Inner loop: skip past consecutive blanks to find non-blank + while advance_mask.any(): + # Update timestamp tracking to current position + time_indices_current_labels = torch.where(advance_mask, time_indices, time_indices_current_labels) + encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) - symbols_added = 0 - while symbols_added < self.config.max_symbols_per_step: token_logits, duration_logits = self.joint(encoder_frames, decoder_output) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) - tokens = token_logits.argmax(dim=-1) - durations = duration_logits.argmax(dim=-1) - emit_mask = active_mask & ~(tokens == self.config.pad_token_id) - - for i in range(batch_size): - if emit_mask[i]: - all_tokens[i].append(tokens[i].item()) - if token_frame_indices is not None: - token_frame_indices[i].append(time_indices[i].item()) - if token_durations_list is not None: - token_durations_list[i].append(durations[i].item()) - - if emit_mask.any(): - new_prev_tokens = tokens.unsqueeze(1) - new_decoder_output, new_hidden_state, new_cell_state = self.decoder( - new_prev_tokens, hidden_state, cell_state - ) - new_decoder_output = new_decoder_output.to(device) - new_hidden_state = new_hidden_state.to(device) - new_cell_state = new_cell_state.to(device) - - emit_mask_expanded = emit_mask.view(batch_size, 1, 1) - decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) - - emit_mask_state = emit_mask.view(1, batch_size, 1) - hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) - cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) - - # If duration is 0, stay on same frame (emit more tokens) - stay_mask = active_mask & (durations == 0) - if stay_mask.any(): - symbols_added += 1 - if symbols_added >= self.config.max_symbols_per_step: - time_indices[active_mask & stay_mask] += 1 - break - continue - - # Duration > 0: advance time - time_indices = time_indices + torch.where(active_mask, durations, torch.zeros_like(durations)) - break + more_tokens = token_logits.argmax(dim=-1) + more_durations = duration_logits.argmax(dim=-1) + + tokens = torch.where(advance_mask, more_tokens, tokens) + durations = torch.where(advance_mask, more_durations, durations) + + blank_mask = tokens == self.config.pad_token_id + durations = durations.masked_fill(blank_mask & (durations == 0), 1) + + time_indices = torch.where(advance_mask, time_indices + durations, time_indices) + safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) + active_mask = time_indices < valid_lengths + advance_mask = active_mask & blank_mask + + # Record results for non-blank tokens found + emit_mask = active_mask & (tokens != self.config.pad_token_id) + for i in range(batch_size): + if emit_mask[i]: + all_tokens[i].append(tokens[i].item()) + if token_frame_indices is not None: + token_frame_indices[i].append(time_indices_current_labels[i].item()) + if token_durations_list is not None: + token_durations_list[i].append(durations[i].item()) + + if emit_mask.any(): + new_prev_tokens = tokens.unsqueeze(1) + new_decoder_output, new_hidden_state, new_cell_state = self.decoder( + new_prev_tokens, hidden_state, cell_state + ) + new_decoder_output = new_decoder_output.to(device) + 
new_hidden_state = new_hidden_state.to(device) + new_cell_state = new_cell_state.to(device) + + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) + + emit_mask_state = emit_mask.view(1, batch_size, 1) + hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) + cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) + + # Track symbols emitted per time step; force advance when max_symbols reached + time_changed = time_indices_current_labels != last_label_time + symbols_per_step = torch.where(time_changed, torch.zeros_like(symbols_per_step), symbols_per_step) + symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) + last_label_time = torch.where(emit_mask, time_indices_current_labels, last_label_time) + force_advance = active_mask & (symbols_per_step >= max_symbols) + time_indices = time_indices + force_advance.long() + symbols_per_step = symbols_per_step.masked_fill(force_advance, 0) active_mask = time_indices < valid_lengths diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py index 69734fb055af..5670a9959c92 100644 --- a/src/transformers/models/parakeet/processing_parakeet.py +++ b/src/transformers/models/parakeet/processing_parakeet.py @@ -82,7 +82,9 @@ def __call__( if text is None: return inputs else: - inputs["labels"] = encodings["input_ids"] + labels = encodings["input_ids"] + labels[labels == self.tokenizer.pad_token_id] = -100 + inputs["labels"] = labels return inputs @property diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 4865cfd0e455..d284148744a1 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -492,11 +492,7 @@ class ParakeetForTDTIntegrationTest(unittest.TestCase): _dataset = None @classmethod - def setUp(cls): - # cls.checkpoint_name = "nvidia/parakeet-tdt-0.6b-v3" - # cls.dtype = torch.bfloat16 - # cls.processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") - + def setUpClass(cls): cls.checkpoint_name = "bezzam/parakeet-tdt-0.6b-v3-hf" cls.dtype = torch.bfloat16 cls.processor = AutoProcessor.from_pretrained("bezzam/parakeet-tdt-0.6b-v3-hf") @@ -593,6 +589,8 @@ def test_tdt_model_integration_timestamps(self): self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) # Check timestamps and durations - self.assertIsNotNone(output.token_timestamps, "token_timestamps should be returned when return_timestamps=True") + self.assertIsNotNone( + output.token_timestamps, "token_timestamps should be returned when return_timestamps=True" + ) torch.testing.assert_close(output.token_timestamps.cpu(), EXPECTED_TIMESTAMPS) torch.testing.assert_close(output.token_durations.cpu(), EXPECTED_DURATIONS) From b33002fca7988ec2a98e9413af7bacea6d8772bc Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Fri, 27 Feb 2026 16:36:03 +0100 Subject: [PATCH 09/67] revert: restore lasr generated files to original state --- src/transformers/models/lasr/configuration_lasr.py | 6 +++--- src/transformers/models/lasr/modeling_lasr.py | 3 +-- src/transformers/models/lasr/processing_lasr.py | 4 +--- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/lasr/configuration_lasr.py b/src/transformers/models/lasr/configuration_lasr.py index 60101030f38e..4d82b85044a2 100644 --- 
a/src/transformers/models/lasr/configuration_lasr.py +++ b/src/transformers/models/lasr/configuration_lasr.py @@ -150,8 +150,6 @@ def __init__( self.subsampling_conv_stride = subsampling_conv_stride self.subsampling_conv_channels = subsampling_conv_channels self.num_mel_bins = num_mel_bins - self.hop_length = hop_length - self.sampling_rate = sampling_rate self.dropout = dropout self.dropout_positions = dropout_positions @@ -161,7 +159,9 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range - super().__init__(**kwargs) + super().__init__( + **kwargs, + ) class LasrCTCConfig(PreTrainedConfig): diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index 18fa46657c78..24fa4872a2a8 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -36,7 +36,6 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig @@ -592,7 +591,7 @@ class LasrForCTC(LasrPreTrainedModel): def __init__(self, config: LasrCTCConfig): super().__init__(config) - self.encoder = AutoModel.from_config(config.encoder_config) + self.encoder = LasrEncoder(config.encoder_config) # Conv rather than linear to be consistent with NeMO decoding layer self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1) diff --git a/src/transformers/models/lasr/processing_lasr.py b/src/transformers/models/lasr/processing_lasr.py index 644cd835936d..c1acaebaae07 100644 --- a/src/transformers/models/lasr/processing_lasr.py +++ b/src/transformers/models/lasr/processing_lasr.py @@ -88,9 +88,7 @@ def __call__( if text is None: return inputs else: - labels = encodings["input_ids"] - labels[labels == self.tokenizer.pad_token_id] = -100 - inputs["labels"] = labels + inputs["labels"] = encodings["input_ids"] return inputs @property From 48b39dd1a0f1b6123cf721cefd8afd19b0e0ca7f Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Fri, 27 Feb 2026 17:27:54 +0100 Subject: [PATCH 10/67] warn: torchaudio rnnt_loss does not train duration head --- .../models/parakeet/modeling_parakeet.py | 18 +++++++++++++++++- .../models/parakeet/modular_parakeet.py | 18 +++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index b1ff6da52c88..9909152e9970 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -32,13 +32,23 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_torchaudio_available +from ...utils import ( + ModelOutput, + TransformersKwargs, + auto_docstring, + can_return_tuple, + is_torchaudio_available, + logging, +) from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig +logger = 
logging.get_logger(__name__) + + @dataclass @auto_docstring( custom_intro=""" @@ -959,6 +969,12 @@ def forward( ) from torchaudio.functional import rnnt_loss + logger.warning_once( + "Training uses standard RNNT loss from torchaudio, which does not train the duration head. " + "The model will be trained as a regular RNNT. To train with TDT loss (including duration " + "prediction), use NeMo's TDT loss implementation." + ) + # Compute encoder output lengths attention_mask = ( attention_mask diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 6f59b829b093..6791875e69de 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -26,7 +26,14 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_torchaudio_available +from ...utils import ( + ModelOutput, + TransformersKwargs, + auto_docstring, + can_return_tuple, + is_torchaudio_available, + logging, +) from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -35,6 +42,9 @@ from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig +logger = logging.get_logger(__name__) + + @dataclass @auto_docstring( custom_intro=""" @@ -798,6 +808,12 @@ def forward( ) from torchaudio.functional import rnnt_loss + logger.warning_once( + "Training uses standard RNNT loss from torchaudio, which does not train the duration head. " + "The model will be trained as a regular RNNT. To train with TDT loss (including duration " + "prediction), use NeMo's TDT loss implementation." + ) + # Compute encoder output lengths attention_mask = ( attention_mask From e9f23ab617a13cee9f825171e28a28119c3f766c Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 2 Mar 2026 21:50:06 +0100 Subject: [PATCH 11/67] Relax timestamp test, and test nits. 
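`generate()` derives each timestamp as the emission frame index divided by `config.frame_rate`, where
`frame_rate = sampling_rate / (hop_length * subsampling_factor)`. A short sketch of that arithmetic; the
feature-extractor values below are illustrative assumptions, and only the resulting 0.08 s frame period is
directly visible in the expected-results fixture:

```python
sampling_rate = 16_000        # Hz (assumed)
hop_length = 160              # samples per mel frame (assumed)
subsampling_factor = 8        # encoder subsampling (assumed)

frame_rate = sampling_rate / (hop_length * subsampling_factor)  # 12.5 encoder frames per second
frame_period = 1.0 / frame_rate                                 # 0.08 s per encoder frame

frame_index = 3                        # frame at which a token was emitted
timestamp = frame_index / frame_rate   # 0.24 s, the first entry in the batched fixture

# One frame of drift between a bfloat16 and a float32 run moves a timestamp by
# frame_period (0.08 s); hence the relaxed atol=0.4 comparison for timestamps,
# while token ids and durations are still compared exactly.
```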
--- .../models/parakeet/test_modeling_parakeet.py | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index d284148744a1..acf9718ec2a5 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -273,7 +273,7 @@ class ParakeetForCTCIntegrationTest(unittest.TestCase): def setUp(cls): cls.checkpoint_name = "nvidia/parakeet-ctc-1.1b" cls.dtype = torch.bfloat16 - cls.processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b") + cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name) def tearDown(self): cleanup(torch_device, gc_collect=True) @@ -304,8 +304,7 @@ def test_1b_model_integration(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(1) - model = ParakeetForCTC.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) - model.eval() + model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map=torch_device) model.to(torch_device) inputs = self.processor(samples) @@ -327,8 +326,7 @@ def test_1b_model_integration_batched(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(5) - model = ParakeetForCTC.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) - model.eval() + model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map=torch_device) model.to(torch_device) inputs = self.processor(samples) @@ -492,10 +490,10 @@ class ParakeetForTDTIntegrationTest(unittest.TestCase): _dataset = None @classmethod - def setUpClass(cls): + def setUp(cls): cls.checkpoint_name = "bezzam/parakeet-tdt-0.6b-v3-hf" cls.dtype = torch.bfloat16 - cls.processor = AutoProcessor.from_pretrained("bezzam/parakeet-tdt-0.6b-v3-hf") + cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name) def tearDown(self): cleanup(torch_device, gc_collect=True) @@ -526,8 +524,7 @@ def test_tdt_model_integration(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) - model.eval() + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map=torch_device) model.to(torch_device) inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) @@ -549,8 +546,7 @@ def test_tdt_model_integration_batched(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) - model.eval() + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map=torch_device) model.to(torch_device) inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) @@ -573,16 +569,15 @@ def test_tdt_model_integration_timestamps(self): EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] EXPECTED_TIMESTAMPS = torch.tensor(raw_data["token_timestamps"]) - EXPECTED_DURATIONS = torch.tensor(raw_data["token_durations"]) + EXPECTED_DURATIONS = raw_data["token_durations"] - # Dynamically determine number of samples from expected 
results + # Use larger precision for testing token durations and timestamps samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) - model.eval() + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map=torch_device) model.to(torch_device) inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) - inputs.to(torch_device, dtype=self.dtype) + inputs.to(torch_device, dtype=model.dtype) output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) @@ -592,5 +587,6 @@ def test_tdt_model_integration_timestamps(self): self.assertIsNotNone( output.token_timestamps, "token_timestamps should be returned when return_timestamps=True" ) - torch.testing.assert_close(output.token_timestamps.cpu(), EXPECTED_TIMESTAMPS) - torch.testing.assert_close(output.token_durations.cpu(), EXPECTED_DURATIONS) + # Relax tolerance for timestamps due to potential internal precision differences + torch.testing.assert_close(output.token_timestamps.cpu(), EXPECTED_TIMESTAMPS, atol=0.4, rtol=1e-6) + self.assertListEqual(output.token_durations.cpu().tolist(), EXPECTED_DURATIONS) From e2b97aa1ca18e666c4b98b6ddb85bcc823bcbc53 Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Tue, 3 Mar 2026 15:18:31 +0100 Subject: [PATCH 12/67] feat: TDT training --- .../models/parakeet/configuration_parakeet.py | 5 + .../models/parakeet/convert_nemo_to_hf.py | 1 + .../models/parakeet/modeling_parakeet.py | 146 +++++++++++++--- .../models/parakeet/modular_parakeet.py | 135 +++++++++++++-- .../fixtures/parakeet/expected_tdt_loss.json | 39 +++++ .../parakeet/generate_tdt_loss_fixtures.py | 156 ++++++++++++++++++ .../models/parakeet/test_modeling_parakeet.py | 77 +++++++++ 7 files changed, 517 insertions(+), 42 deletions(-) create mode 100644 tests/fixtures/parakeet/expected_tdt_loss.json create mode 100644 tests/models/parakeet/generate_tdt_loss_fixtures.py diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 3c233726e36c..cc1c7bc31e8d 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -254,6 +254,9 @@ class ParakeetTDTConfig(PreTrainedConfig): Number of LSTM layers in the prediction network. num_duration_bins (`int`, *optional*, defaults to 5): Number of duration bins for predicting token durations. + durations (`list[int]`, *optional*, defaults to `[0, 1, 2, 3, 4]`): + Duration values for TDT loss computation. Each value represents how many frames a token or blank + emission spans. Must have length equal to `num_duration_bins`. hidden_act (`str`, *optional*, defaults to `"relu"`): The activation function in the joint network. 
max_symbols_per_step (`int`, *optional*, defaults to 10): @@ -287,6 +290,7 @@ def __init__( decoder_hidden_size=640, num_decoder_layers=1, num_duration_bins=5, + durations=None, hidden_act="relu", max_symbols_per_step=10, encoder_config: dict | ParakeetEncoderConfig = None, @@ -297,6 +301,7 @@ def __init__( self.decoder_hidden_size = decoder_hidden_size self.num_decoder_layers = num_decoder_layers self.num_duration_bins = num_duration_bins + self.durations = durations if durations is not None else list(range(num_duration_bins)) self.hidden_act = hidden_act self.max_symbols_per_step = max_symbols_per_step diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index f4ace95cf7ed..196fee5e21e6 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -321,6 +321,7 @@ def convert_tdt_config(nemo_config, encoder_config): decoder_hidden_size=decoder_hidden_size, num_decoder_layers=num_decoder_layers, num_duration_bins=num_duration_bins, + durations=durations, hidden_act="relu", max_symbols_per_step=10, seconds_per_frame=seconds_per_frame, diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 9909152e9970..521e2f5b0ed8 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -32,23 +32,13 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - ModelOutput, - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchaudio_available, - logging, -) +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig -logger = logging.get_logger(__name__) - - @dataclass @auto_docstring( custom_intro=""" @@ -910,6 +900,119 @@ def forward( return token_logits, duration_logits +def tdt_loss( + token_logits: torch.Tensor, + duration_logits: torch.Tensor, + targets: torch.Tensor, + logit_lengths: torch.Tensor, + target_lengths: torch.Tensor, + blank: int, + durations: list[int], + sigma: float = 0.0, + reduction: str = "mean", +) -> torch.Tensor: + """ + Compute TDT (Token-and-Duration Transducer) loss. + + Ported from NeMo's `TDTLossPytorch`. Unlike standard RNNT loss, this loss trains both + the token prediction head and the duration prediction head. + + Args: + token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. + duration_logits: Duration logits of shape `(batch, T, U+1, num_durations)`. + targets: Target labels of shape `(batch, U)`. + logit_lengths: Encoder output lengths of shape `(batch,)`. + target_lengths: Target lengths of shape `(batch,)`. + blank: Blank token id. + durations: List of duration values (e.g., `[0, 1, 2, 3, 4]`). + sigma: Logit undernormalization constant (see TDT paper). Defaults to `0.0`. + reduction: Loss reduction method. One of `"mean"`, `"sum"`, or `"none"`. Defaults to `"mean"`. + + Returns: + Scalar loss tensor (or per-example losses if `reduction="none"`). 
+ + Reference: + *Token-and-Duration Transducer (TDT)* — https://arxiv.org/abs/2304.06795 + """ + device = token_logits.device + batch_size, max_t, max_u, _ = token_logits.shape + + # Apply log-softmax to get log probabilities + token_log_probs = torch.log_softmax(token_logits, dim=-1) - sigma + duration_log_probs = torch.log_softmax(duration_logits, dim=-1) + + # Forward variable: log_alpha[b, t, u] = log P(y_{1:u} | x_{1:t}) + log_alpha = torch.full((batch_size, max_t, max_u), -1000.0, device=device) + log_alpha[:, 0, 0] = 0.0 + batch_idx = torch.arange(batch_size, device=device) + + for t in range(max_t): + for u in range(max_u): + if t == 0 and u == 0: + continue + + # Accumulate log-probabilities from all incoming arcs + candidates = [] + + for n, dur in enumerate(durations): + t_prev = t - dur + if t_prev < 0: + continue + + # Blank arc (duration > 0): same label position, skip `dur` frames + if dur > 0: + blank_contribution = ( + log_alpha[:, t_prev, u] + + token_log_probs[:, t_prev, u, blank] + + duration_log_probs[:, t_prev, u, n] + ) + candidates.append(blank_contribution) + + # Label arc (u > 0): emit label y_u from position (t_prev, u-1) + if u > 0: + label_contribution = ( + log_alpha[:, t_prev, u - 1] + + token_log_probs[batch_idx, t_prev, u - 1, targets[:, u - 1]] + + duration_log_probs[:, t_prev, u - 1, n] + ) + candidates.append(label_contribution) + + if candidates: + log_alpha[:, t, u] = torch.logsumexp(torch.stack(candidates, dim=0), dim=0) + + # Terminal probability: sum over blank arcs that reach (T, U) from (T-dur, U) + log_probs = torch.full((batch_size,), -1000.0, device=device) + for n, dur in enumerate(durations): + if dur == 0: + continue + # For each example, check if act_lens[b] - dur >= 0 + t_final = logit_lengths - dur + valid = t_final >= 0 + if not valid.any(): + continue + + t_clamped = t_final.clamp(min=0) + terminal = ( + log_alpha[batch_idx, t_clamped, target_lengths] + + token_log_probs[batch_idx, t_clamped, target_lengths, blank] + + duration_log_probs[batch_idx, t_clamped, target_lengths, n] + ) + # Only update valid entries + combined = torch.stack([log_probs, terminal], dim=0) + log_probs = torch.where(valid, torch.logsumexp(combined, dim=0), log_probs) + + losses = -log_probs + + if reduction == "mean": + return (losses / target_lengths.float()).mean() + elif reduction == "sum": + return losses.sum() + elif reduction == "none": + return losses + else: + return (losses / target_lengths.float()).mean() + + @auto_docstring( custom_intro=""" Parakeet Encoder with a TDT (Token Duration Transducer) head. @@ -963,18 +1066,6 @@ def forward( loss = None if labels is not None: - if not is_torchaudio_available(): - raise ImportError( - "torchaudio is required for TDT loss computation. Install it with: pip install torchaudio" - ) - from torchaudio.functional import rnnt_loss - - logger.warning_once( - "Training uses standard RNNT loss from torchaudio, which does not train the duration head. " - "The model will be trained as a regular RNNT. To train with TDT loss (including duration " - "prediction), use NeMo's TDT loss implementation." 
- ) - # Compute encoder output lengths attention_mask = ( attention_mask @@ -1004,18 +1095,21 @@ def forward( # Compute joint output for all (T, U+1) pairs via broadcasting # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) # decoder: (batch, 1, U+1, decoder_hidden_size) - token_logits, _ = self.joint( + token_logits, duration_logits = self.joint( encoder_hidden_states_trimmed.unsqueeze(2), decoder_output.unsqueeze(1), ) # token_logits: (batch, T, U+1, vocab_size+1) + # duration_logits: (batch, T, U+1, num_duration_bins) - loss = rnnt_loss( - logits=token_logits.float(), + loss = tdt_loss( + token_logits=token_logits.float(), + duration_logits=duration_logits.float(), targets=labels.int(), logit_lengths=encoder_lengths.int(), target_lengths=target_lengths.int(), blank=self.config.pad_token_id, + durations=self.config.durations, reduction="mean", ) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 6791875e69de..0f9135852328 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -31,7 +31,6 @@ TransformersKwargs, auto_docstring, can_return_tuple, - is_torchaudio_available, logging, ) from ...utils.generic import maybe_autocast, merge_with_config_defaults @@ -727,6 +726,119 @@ def forward( return decoder_output, hidden_state, cell_state +def tdt_loss( + token_logits: torch.Tensor, + duration_logits: torch.Tensor, + targets: torch.Tensor, + logit_lengths: torch.Tensor, + target_lengths: torch.Tensor, + blank: int, + durations: list[int], + sigma: float = 0.0, + reduction: str = "mean", +) -> torch.Tensor: + """ + Compute TDT (Token-and-Duration Transducer) loss. + + Ported from NeMo's `TDTLossPytorch`. Unlike standard RNNT loss, this loss trains both + the token prediction head and the duration prediction head. + + Args: + token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. + duration_logits: Duration logits of shape `(batch, T, U+1, num_durations)`. + targets: Target labels of shape `(batch, U)`. + logit_lengths: Encoder output lengths of shape `(batch,)`. + target_lengths: Target lengths of shape `(batch,)`. + blank: Blank token id. + durations: List of duration values (e.g., `[0, 1, 2, 3, 4]`). + sigma: Logit undernormalization constant (see TDT paper). Defaults to `0.0`. + reduction: Loss reduction method. One of `"mean"`, `"sum"`, or `"none"`. Defaults to `"mean"`. + + Returns: + Scalar loss tensor (or per-example losses if `reduction="none"`). 
+ + Reference: + *Token-and-Duration Transducer (TDT)* — https://arxiv.org/abs/2304.06795 + """ + device = token_logits.device + batch_size, max_t, max_u, _ = token_logits.shape + + # Apply log-softmax to get log probabilities + token_log_probs = torch.log_softmax(token_logits, dim=-1) - sigma + duration_log_probs = torch.log_softmax(duration_logits, dim=-1) + + # Forward variable: log_alpha[b, t, u] = log P(y_{1:u} | x_{1:t}) + log_alpha = torch.full((batch_size, max_t, max_u), -1000.0, device=device) + log_alpha[:, 0, 0] = 0.0 + batch_idx = torch.arange(batch_size, device=device) + + for t in range(max_t): + for u in range(max_u): + if t == 0 and u == 0: + continue + + # Accumulate log-probabilities from all incoming arcs + candidates = [] + + for n, dur in enumerate(durations): + t_prev = t - dur + if t_prev < 0: + continue + + # Blank arc (duration > 0): same label position, skip `dur` frames + if dur > 0: + blank_contribution = ( + log_alpha[:, t_prev, u] + + token_log_probs[:, t_prev, u, blank] + + duration_log_probs[:, t_prev, u, n] + ) + candidates.append(blank_contribution) + + # Label arc (u > 0): emit label y_u from position (t_prev, u-1) + if u > 0: + label_contribution = ( + log_alpha[:, t_prev, u - 1] + + token_log_probs[batch_idx, t_prev, u - 1, targets[:, u - 1]] + + duration_log_probs[:, t_prev, u - 1, n] + ) + candidates.append(label_contribution) + + if candidates: + log_alpha[:, t, u] = torch.logsumexp(torch.stack(candidates, dim=0), dim=0) + + # Terminal probability: sum over blank arcs that reach (T, U) from (T-dur, U) + log_probs = torch.full((batch_size,), -1000.0, device=device) + for n, dur in enumerate(durations): + if dur == 0: + continue + # For each example, check if act_lens[b] - dur >= 0 + t_final = logit_lengths - dur + valid = t_final >= 0 + if not valid.any(): + continue + + t_clamped = t_final.clamp(min=0) + terminal = ( + log_alpha[batch_idx, t_clamped, target_lengths] + + token_log_probs[batch_idx, t_clamped, target_lengths, blank] + + duration_log_probs[batch_idx, t_clamped, target_lengths, n] + ) + # Only update valid entries + combined = torch.stack([log_probs, terminal], dim=0) + log_probs = torch.where(valid, torch.logsumexp(combined, dim=0), log_probs) + + losses = -log_probs + + if reduction == "mean": + return (losses / target_lengths.float()).mean() + elif reduction == "sum": + return losses.sum() + elif reduction == "none": + return losses + else: + return (losses / target_lengths.float()).mean() + + class ParakeetTDTJointNetwork(nn.Module): """Joint network that combines encoder and decoder outputs to predict tokens and durations.""" @@ -802,18 +914,6 @@ def forward( loss = None if labels is not None: - if not is_torchaudio_available(): - raise ImportError( - "torchaudio is required for TDT loss computation. Install it with: pip install torchaudio" - ) - from torchaudio.functional import rnnt_loss - - logger.warning_once( - "Training uses standard RNNT loss from torchaudio, which does not train the duration head. " - "The model will be trained as a regular RNNT. To train with TDT loss (including duration " - "prediction), use NeMo's TDT loss implementation." 
- ) - # Compute encoder output lengths attention_mask = ( attention_mask @@ -843,18 +943,21 @@ def forward( # Compute joint output for all (T, U+1) pairs via broadcasting # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) # decoder: (batch, 1, U+1, decoder_hidden_size) - token_logits, _ = self.joint( + token_logits, duration_logits = self.joint( encoder_hidden_states_trimmed.unsqueeze(2), decoder_output.unsqueeze(1), ) # token_logits: (batch, T, U+1, vocab_size+1) + # duration_logits: (batch, T, U+1, num_duration_bins) - loss = rnnt_loss( - logits=token_logits.float(), + loss = tdt_loss( + token_logits=token_logits.float(), + duration_logits=duration_logits.float(), targets=labels.int(), logit_lengths=encoder_lengths.int(), target_lengths=target_lengths.int(), blank=self.config.pad_token_id, + durations=self.config.durations, reduction="mean", ) diff --git a/tests/fixtures/parakeet/expected_tdt_loss.json b/tests/fixtures/parakeet/expected_tdt_loss.json new file mode 100644 index 000000000000..b8177341adcd --- /dev/null +++ b/tests/fixtures/parakeet/expected_tdt_loss.json @@ -0,0 +1,39 @@ +{ + "_comment": "Generated by generate_tdt_loss_fixtures.py using NeMo's TDTLossPytorch (CPU-patched). Inputs use torch.manual_seed(42), batch=2, T=8, U=4, vocab=5, durations=[0,1,2,3,4].", + "seed": 42, + "batch_size": 2, + "max_t": 8, + "max_u": 4, + "vocab_size": 5, + "durations": [ + 0, + 1, + 2, + 3, + 4 + ], + "targets": [ + [ + 4, + 2, + 2, + 1 + ], + [ + 0, + 4, + 2, + 4 + ] + ], + "logit_lengths": [ + 8, + 7 + ], + "target_lengths": [ + 4, + 3 + ], + "expected_loss_sum": 21.978168487548828, + "expected_loss_mean": 3.12455415725708 +} \ No newline at end of file diff --git a/tests/models/parakeet/generate_tdt_loss_fixtures.py b/tests/models/parakeet/generate_tdt_loss_fixtures.py new file mode 100644 index 000000000000..b7eae3639aee --- /dev/null +++ b/tests/models/parakeet/generate_tdt_loss_fixtures.py @@ -0,0 +1,156 @@ +""" +Generate TDT loss reference fixtures using NeMo's TDTLossPytorch. + +Usage (requires NeMo installed, no CUDA needed): + python tests/models/parakeet/generate_tdt_loss_fixtures.py + +Outputs: + tests/fixtures/parakeet/expected_tdt_loss.json + +The fixture contains deterministic inputs and expected loss values +computed by NeMo's TDTLossPytorch. Our tdt_loss implementation is +tested against these values in test_modeling_parakeet.py::TDTLossTest. 
+""" + +import json +import os + +import torch + + +def make_test_inputs(): + torch.manual_seed(42) + batch_size, max_t, max_u, vocab_size, num_durations = 2, 8, 4, 5, 5 + blank = vocab_size + + combined_logits = torch.randn(batch_size, max_t, max_u + 1, vocab_size + 1 + num_durations) + targets = torch.randint(0, vocab_size, (batch_size, max_u)) + logit_lengths = torch.tensor([max_t, max_t - 1]) + target_lengths = torch.tensor([max_u, max_u - 1]) + + return { + "combined_logits": combined_logits, + "token_logits": combined_logits[..., : vocab_size + 1], + "duration_logits": combined_logits[..., vocab_size + 1 :], + "targets": targets, + "logit_lengths": logit_lengths, + "target_lengths": target_lengths, + "blank": blank, + "durations": [0, 1, 2, 3, 4], + } + + +def compute_nemo_reference(inputs): + """Run NeMo's TDTLossPytorch (monkey-patched for CPU).""" + import nemo.collections.asr.losses.rnnt_pytorch as rnnt_mod + + # NeMo hardcodes .cuda() — patch compute_forward_prob for CPU + def patched_compute(self, acts, duration_acts, labels, act_lens, label_lens): + B, T, U, _ = acts.shape + log_alpha = torch.zeros(B, T, U, device=acts.device) + + for b in range(B): + for t in range(T): + for u in range(U): + if u == 0: + if t == 0: + log_alpha[b, t, u] = 0.0 + else: + log_alpha[b, t, u] = -1000.0 + for n, l in enumerate(self.durations): + if t - l >= 0 and l > 0: + tmp = ( + log_alpha[b, t - l, u] + + acts[b, t - l, u, self.blank] + + duration_acts[b, t - l, u, n] + ) + log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + else: + log_alpha[b, t, u] = -1000.0 + for n, l in enumerate(self.durations): + if t - l >= 0: + if l > 0: + tmp = ( + log_alpha[b, t - l, u] + + acts[b, t - l, u, self.blank] + + duration_acts[b, t - l, u, n] + ) + log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + tmp = ( + log_alpha[b, t - l, u - 1] + + acts[b, t - l, u - 1, labels[b, u - 1]] + + duration_acts[b, t - l, u - 1, n] + ) + log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + + log_probs = [] + for b in range(B): + tt = torch.tensor(-1000.0, device=acts.device) + for n, l in enumerate(self.durations): + if act_lens[b] - l >= 0 and l > 0: + bb = ( + log_alpha[b, act_lens[b] - l, label_lens[b]] + + acts[b, act_lens[b] - l, label_lens[b], self.blank] + + duration_acts[b, act_lens[b] - l, label_lens[b], n] + ) + tt = self.logsumexp(bb, 1.0 * tt) + log_probs.append(tt) + + return torch.stack(log_probs), log_alpha + + orig = rnnt_mod.TDTLossPytorch.compute_forward_prob + rnnt_mod.TDTLossPytorch.compute_forward_prob = patched_compute + + results = {} + for reduction in ["sum", "mean"]: + loss_fn = rnnt_mod.TDTLossPytorch( + blank=inputs["blank"], + durations=inputs["durations"], + reduction=reduction, + sigma=0.0, + ) + loss = loss_fn( + acts=inputs["combined_logits"], + labels=inputs["targets"], + act_lens=inputs["logit_lengths"], + label_lens=inputs["target_lengths"], + ) + results[reduction] = loss.item() + print(f"NeMo TDT loss (reduction={reduction}): {loss.item():.10f}") + + rnnt_mod.TDTLossPytorch.compute_forward_prob = orig + return results + + +def main(): + inputs = make_test_inputs() + nemo_results = compute_nemo_reference(inputs) + + fixture = { + "_comment": "Generated by generate_tdt_loss_fixtures.py using NeMo's TDTLossPytorch (CPU-patched). 
" + "Inputs use torch.manual_seed(42), batch=2, T=8, U=4, vocab=5, durations=[0,1,2,3,4].", + "seed": 42, + "batch_size": 2, + "max_t": 8, + "max_u": 4, + "vocab_size": 5, + "durations": [0, 1, 2, 3, 4], + "targets": inputs["targets"].tolist(), + "logit_lengths": inputs["logit_lengths"].tolist(), + "target_lengths": inputs["target_lengths"].tolist(), + "expected_loss_sum": nemo_results["sum"], + "expected_loss_mean": nemo_results["mean"], + } + + output_path = os.path.join(os.path.dirname(__file__), "..", "..", "fixtures", "parakeet", "expected_tdt_loss.json") + output_path = os.path.normpath(output_path) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + with open(output_path, "w") as f: + json.dump(fixture, f, indent=2) + + print(f"\nFixture written to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index acf9718ec2a5..e3d58d9ac2e4 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -40,6 +40,83 @@ ParakeetForTDT, ParakeetTDTConfig, ) + from transformers.models.parakeet.modeling_parakeet import tdt_loss + + +@require_torch +class TDTLossTest(unittest.TestCase): + """Test tdt_loss against reference values generated by NeMo's TDTLossPytorch. + + Fixture generated with: tests/models/parakeet/generate_tdt_loss_fixtures.py + """ + + FIXTURE_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_tdt_loss.json" + + @classmethod + def setUpClass(cls): + with open(cls.FIXTURE_PATH) as f: + cls.fixture = json.load(f) + + def _make_inputs(self): + torch.manual_seed(self.fixture["seed"]) + batch_size = self.fixture["batch_size"] + max_t = self.fixture["max_t"] + max_u = self.fixture["max_u"] + vocab_size = self.fixture["vocab_size"] + num_durations = len(self.fixture["durations"]) + blank = vocab_size + + combined_logits = torch.randn(batch_size, max_t, max_u + 1, vocab_size + 1 + num_durations) + targets = torch.randint(0, vocab_size, (batch_size, max_u)) + logit_lengths = torch.tensor(self.fixture["logit_lengths"]) + target_lengths = torch.tensor(self.fixture["target_lengths"]) + + return { + "token_logits": combined_logits[..., : vocab_size + 1], + "duration_logits": combined_logits[..., vocab_size + 1 :], + "targets": targets, + "logit_lengths": logit_lengths, + "target_lengths": target_lengths, + "blank": blank, + "durations": self.fixture["durations"], + } + + def test_tdt_loss_sum(self): + inputs = self._make_inputs() + loss = tdt_loss(**inputs, reduction="sum") + expected = torch.tensor(self.fixture["expected_loss_sum"]) + torch.testing.assert_close(loss, expected, rtol=1e-4, atol=1e-4) + + def test_tdt_loss_mean(self): + inputs = self._make_inputs() + loss = tdt_loss(**inputs, reduction="mean") + expected = torch.tensor(self.fixture["expected_loss_mean"]) + torch.testing.assert_close(loss, expected, rtol=1e-4, atol=1e-4) + + def test_tdt_loss_none(self): + inputs = self._make_inputs() + losses = tdt_loss(**inputs, reduction="none") + self.assertEqual(losses.shape, (self.fixture["batch_size"],)) + expected_sum = torch.tensor(self.fixture["expected_loss_sum"]) + torch.testing.assert_close(losses.sum(), expected_sum, rtol=1e-4, atol=1e-4) + + def test_tdt_loss_with_sigma(self): + inputs = self._make_inputs() + loss_no_sigma = tdt_loss(**inputs, sigma=0.0, reduction="sum") + loss_with_sigma = tdt_loss(**inputs, sigma=0.05, reduction="sum") + 
self.assertFalse(torch.allclose(loss_no_sigma, loss_with_sigma)) + self.assertGreater(loss_with_sigma.item(), loss_no_sigma.item()) + + def test_tdt_loss_gradient_flows(self): + inputs = self._make_inputs() + inputs["token_logits"] = inputs["token_logits"].requires_grad_(True) + inputs["duration_logits"] = inputs["duration_logits"].requires_grad_(True) + loss = tdt_loss(**inputs, reduction="mean") + loss.backward() + self.assertIsNotNone(inputs["token_logits"].grad) + self.assertIsNotNone(inputs["duration_logits"].grad) + self.assertFalse(torch.all(inputs["token_logits"].grad == 0)) + self.assertFalse(torch.all(inputs["duration_logits"].grad == 0)) class ParakeetEncoderModelTester: From 6b9fc731e9a9758904e3ce03a197d54f0c7381e6 Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Tue, 3 Mar 2026 15:28:27 +0100 Subject: [PATCH 13/67] chore: for cuda detection and run without patching --- .../parakeet/generate_tdt_loss_fixtures.py | 123 ++++++++++-------- 1 file changed, 70 insertions(+), 53 deletions(-) diff --git a/tests/models/parakeet/generate_tdt_loss_fixtures.py b/tests/models/parakeet/generate_tdt_loss_fixtures.py index b7eae3639aee..582ac7e51333 100644 --- a/tests/models/parakeet/generate_tdt_loss_fixtures.py +++ b/tests/models/parakeet/generate_tdt_loss_fixtures.py @@ -40,66 +40,81 @@ def make_test_inputs(): } -def compute_nemo_reference(inputs): - """Run NeMo's TDTLossPytorch (monkey-patched for CPU).""" - import nemo.collections.asr.losses.rnnt_pytorch as rnnt_mod - - # NeMo hardcodes .cuda() — patch compute_forward_prob for CPU - def patched_compute(self, acts, duration_acts, labels, act_lens, label_lens): - B, T, U, _ = acts.shape - log_alpha = torch.zeros(B, T, U, device=acts.device) - - for b in range(B): - for t in range(T): - for u in range(U): - if u == 0: - if t == 0: - log_alpha[b, t, u] = 0.0 - else: - log_alpha[b, t, u] = -1000.0 - for n, l in enumerate(self.durations): - if t - l >= 0 and l > 0: - tmp = ( - log_alpha[b, t - l, u] - + acts[b, t - l, u, self.blank] - + duration_acts[b, t - l, u, n] - ) - log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) +def _patched_compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens): + """NeMo's compute_forward_prob with .cuda() replaced by device-aware allocation. + + This is identical to NeMo's TDTLossPytorch.compute_forward_prob except + `log_alpha = log_alpha.cuda()` is replaced with `device=acts.device`, and + `torch.Tensor([-1000.0]).cuda()[0]` is replaced with `torch.tensor(-1000.0, device=acts.device)`. + The loss math is unchanged. 
+ """ + B, T, U, _ = acts.shape + log_alpha = torch.zeros(B, T, U, device=acts.device) + + for b in range(B): + for t in range(T): + for u in range(U): + if u == 0: + if t == 0: + log_alpha[b, t, u] = 0.0 else: log_alpha[b, t, u] = -1000.0 for n, l in enumerate(self.durations): - if t - l >= 0: - if l > 0: - tmp = ( - log_alpha[b, t - l, u] - + acts[b, t - l, u, self.blank] - + duration_acts[b, t - l, u, n] - ) - log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + if t - l >= 0 and l > 0: + tmp = ( + log_alpha[b, t - l, u] + + acts[b, t - l, u, self.blank] + + duration_acts[b, t - l, u, n] + ) + log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + else: + log_alpha[b, t, u] = -1000.0 + for n, l in enumerate(self.durations): + if t - l >= 0: + if l > 0: tmp = ( - log_alpha[b, t - l, u - 1] - + acts[b, t - l, u - 1, labels[b, u - 1]] - + duration_acts[b, t - l, u - 1, n] + log_alpha[b, t - l, u] + + acts[b, t - l, u, self.blank] + + duration_acts[b, t - l, u, n] ) log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + tmp = ( + log_alpha[b, t - l, u - 1] + + acts[b, t - l, u - 1, labels[b, u - 1]] + + duration_acts[b, t - l, u - 1, n] + ) + log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + + log_probs = [] + for b in range(B): + tt = torch.tensor(-1000.0, device=acts.device) + for n, l in enumerate(self.durations): + if act_lens[b] - l >= 0 and l > 0: + bb = ( + log_alpha[b, act_lens[b] - l, label_lens[b]] + + acts[b, act_lens[b] - l, label_lens[b], self.blank] + + duration_acts[b, act_lens[b] - l, label_lens[b], n] + ) + tt = self.logsumexp(bb, 1.0 * tt) + log_probs.append(tt) + + return torch.stack(log_probs), log_alpha - log_probs = [] - for b in range(B): - tt = torch.tensor(-1000.0, device=acts.device) - for n, l in enumerate(self.durations): - if act_lens[b] - l >= 0 and l > 0: - bb = ( - log_alpha[b, act_lens[b] - l, label_lens[b]] - + acts[b, act_lens[b] - l, label_lens[b], self.blank] - + duration_acts[b, act_lens[b] - l, label_lens[b], n] - ) - tt = self.logsumexp(bb, 1.0 * tt) - log_probs.append(tt) - return torch.stack(log_probs), log_alpha +def compute_nemo_reference(inputs): + """Run NeMo's TDTLossPytorch. - orig = rnnt_mod.TDTLossPytorch.compute_forward_prob - rnnt_mod.TDTLossPytorch.compute_forward_prob = patched_compute + On CPU, monkey-patches compute_forward_prob to avoid NeMo's hardcoded .cuda(). + On CUDA, runs NeMo unmodified. + """ + import nemo.collections.asr.losses.rnnt_pytorch as rnnt_mod + + need_patch = not torch.cuda.is_available() + orig = None + if need_patch: + print("No CUDA available — patching NeMo's compute_forward_prob for CPU (math unchanged)") + orig = rnnt_mod.TDTLossPytorch.compute_forward_prob + rnnt_mod.TDTLossPytorch.compute_forward_prob = _patched_compute_forward_prob results = {} for reduction in ["sum", "mean"]: @@ -118,7 +133,9 @@ def patched_compute(self, acts, duration_acts, labels, act_lens, label_lens): results[reduction] = loss.item() print(f"NeMo TDT loss (reduction={reduction}): {loss.item():.10f}") - rnnt_mod.TDTLossPytorch.compute_forward_prob = orig + if orig is not None: + rnnt_mod.TDTLossPytorch.compute_forward_prob = orig + return results @@ -127,7 +144,7 @@ def main(): nemo_results = compute_nemo_reference(inputs) fixture = { - "_comment": "Generated by generate_tdt_loss_fixtures.py using NeMo's TDTLossPytorch (CPU-patched). " + "_comment": "Generated by generate_tdt_loss_fixtures.py using NeMo's TDTLossPytorch. 
" "Inputs use torch.manual_seed(42), batch=2, T=8, U=4, vocab=5, durations=[0,1,2,3,4].", "seed": 42, "batch_size": 2, From 6c879bc043f03cfdf6204068aec4c382cfbd4fd0 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 3 Mar 2026 18:28:38 +0100 Subject: [PATCH 14/67] Equivalent timestamp processing as Nemo, and various nits/cleanup. --- .../models/parakeet/configuration_parakeet.py | 36 +++++-------- .../models/parakeet/convert_nemo_to_hf.py | 11 +--- .../models/parakeet/modeling_parakeet.py | 6 +-- .../models/parakeet/modular_parakeet.py | 4 +- .../models/parakeet/processing_parakeet.py | 51 +++++++++++++++++++ .../expected_results_batch_tdt_timestamp.json | 2 +- .../models/parakeet/test_modeling_parakeet.py | 24 ++++++--- 7 files changed, 87 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 3c233726e36c..e51c27451060 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -110,8 +110,6 @@ def __init__( subsampling_factor=8, subsampling_conv_channels=256, num_mel_bins=80, - hop_length=160, - sampling_rate=16000, subsampling_conv_kernel_size=3, subsampling_conv_stride=2, dropout=0.1, @@ -140,8 +138,6 @@ def __init__( self.subsampling_factor = subsampling_factor self.subsampling_conv_channels = subsampling_conv_channels self.num_mel_bins = num_mel_bins - self.hop_length = hop_length - self.sampling_rate = sampling_rate self.dropout = dropout self.dropout_positions = dropout_positions @@ -164,19 +160,19 @@ class ParakeetCTCConfig(PreTrainedConfig): documentation from [`PreTrainedConfig`] for more information. Args: - vocab_size (`int`, *optional*, defaults to 1025): - Vocabulary size of the model. - ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`): - Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an - instance of [`ParakeetForCTC`]. - ctc_zero_infinity (`bool`, *optional*, defaults to `True`): - Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly - occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance - of [`ParakeetForCTC`]. - encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): - The config object or dictionary of the encoder. - pad_token_id (`int`, *optional*, defaults to 1024): - Padding token id. Also used as blank token id. + vocab_size (`int`, *optional*, defaults to 1025): + Vocabulary size of the model. + ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`): + Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an + instance of [`ParakeetForCTC`]. + ctc_zero_infinity (`bool`, *optional*, defaults to `True`): + Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly + occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance + of [`ParakeetForCTC`]. + encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): + The config object or dictionary of the encoder. + pad_token_id (`int`, *optional*, defaults to 1024): + Padding token id. Also used as blank token id. 
Example: ```python @@ -312,12 +308,6 @@ def __init__( super().__init__(**kwargs) - @property - def frame_rate(self): - return self.encoder_config.sampling_rate / ( - self.encoder_config.hop_length * self.encoder_config.subsampling_factor - ) - @classmethod def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): r""" diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index f4ace95cf7ed..07a54013fdee 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -303,17 +303,9 @@ def convert_tdt_config(nemo_config, encoder_config): durations = decoding_config.get("durations", [0, 1, 2, 3, 4]) num_duration_bins = len(durations) - preprocessor = nemo_config.get("preprocessor", {}) - sample_rate = preprocessor.get("sample_rate", 16000) - window_stride = preprocessor.get("window_stride", 0.01) - hop_length = int(window_stride * sample_rate) - subsampling_factor = encoder_config.subsampling_factor - seconds_per_frame = (hop_length * subsampling_factor) / sample_rate - print( f"TDT config: vocab_size={vocab_size}, decoder_hidden={decoder_hidden_size}, " f"decoder_layers={num_decoder_layers}, num_durations={num_duration_bins}, " - f"seconds_per_frame={seconds_per_frame}" ) return ParakeetTDTConfig( @@ -323,7 +315,6 @@ def convert_tdt_config(nemo_config, encoder_config): num_duration_bins=num_duration_bins, hidden_act="relu", max_symbols_per_step=10, - seconds_per_frame=seconds_per_frame, encoder_config=encoder_config.to_dict(), pad_token_id=vocab_size, ) @@ -399,7 +390,7 @@ def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_t gc.collect() print("Reloading the model to check if it's saved correctly.") - ParakeetForTDT.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto") + ParakeetForTDT.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") print("Model reloaded successfully.") diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 9909152e9970..b425ae16fc97 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -1203,13 +1203,13 @@ def generate( token_timestamps = None token_durations = None if return_timestamps: - token_timestamps = torch.full((batch_size, max_len), 0.0, dtype=torch.float, device=device) + token_timestamps = torch.full((batch_size, max_len), 0.0, dtype=torch.long, device=device) token_durations = torch.full((batch_size, max_len), 0, dtype=torch.long, device=device) for i in range(batch_size): num_tokens = len(token_frame_indices[i]) if num_tokens > 0: - token_timestamps[i, :num_tokens] = ( - torch.tensor(token_frame_indices[i], dtype=torch.float, device=device) / self.config.frame_rate + token_timestamps[i, :num_tokens] = torch.tensor( + token_frame_indices[i], dtype=torch.long, device=device ) token_durations[i, :num_tokens] = torch.tensor( token_durations_list[i], dtype=torch.long, device=device diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 6791875e69de..ded9c5522852 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -1042,13 +1042,13 @@ def generate( token_timestamps = None token_durations = None if return_timestamps: - token_timestamps = 
torch.full((batch_size, max_len), 0.0, dtype=torch.float, device=device) + token_timestamps = torch.full((batch_size, max_len), 0.0, dtype=torch.long, device=device) token_durations = torch.full((batch_size, max_len), 0, dtype=torch.long, device=device) for i in range(batch_size): num_tokens = len(token_frame_indices[i]) if num_tokens > 0: token_timestamps[i, :num_tokens] = ( - torch.tensor(token_frame_indices[i], dtype=torch.float, device=device) / self.config.frame_rate + torch.tensor(token_frame_indices[i], dtype=torch.long, device=device) ) token_durations[i, :num_tokens] = torch.tensor( token_durations_list[i], dtype=torch.long, device=device diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py index 5670a9959c92..459bf52d90c9 100644 --- a/src/transformers/models/parakeet/processing_parakeet.py +++ b/src/transformers/models/parakeet/processing_parakeet.py @@ -27,6 +27,7 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False): "sampling_rate": 16000, "padding": "longest", "return_attention_mask": True, + "subsampling_factor": 8, }, "text_kwargs": { "padding": True, @@ -92,5 +93,55 @@ def model_input_names(self): feature_extractor_input_names = self.feature_extractor.model_input_names return feature_extractor_input_names + ["labels"] + def decode(self, *args, token_timestamps=None, token_durations=None, **kwargs): + """ + Forward arguments to [`~PreTrainedTokenizer.decode`] and post-process the timestamps (if provided for TDT) as + in the NeMo library. + """ + decoded = self.tokenizer.decode(*args, **kwargs) + + if token_timestamps is not None and token_durations is not None: + token_ids = args[0] + + output_kwargs = self._merge_kwargs( + ParakeetProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + frame_rate = self.feature_extractor.hop_length / self.feature_extractor.sampling_rate * output_kwargs["audio_kwargs"]["subsampling_factor"] + proc_timestamps = [] + for batch_ids, timestamps, durations in zip(token_ids, token_timestamps, token_durations): + # Original NeMo: https://github.com/NVIDIA-NeMo/NeMo/blob/1692a8fb97e1aadc883cfadd2a57c4e8a1b793aa/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L993 + non_blank_indices = [i for i, token_id in enumerate(batch_ids) if token_id != self.tokenizer.vocab_size] + non_blank_ids = [batch_ids[i] for i in non_blank_indices] + decoded_tokens = [self.tokenizer.decode([token_id]) for token_id in non_blank_ids] + timestamp_dict = [ + {"token": token_str, "start": int(timestamps[i]), "end": int(timestamps[i] + durations[i])} + for token_str, i in zip(decoded_tokens, non_blank_indices) + ] + timestamp_dict = self._refine_timestamps_tdt(timestamp_dict) + + # Convert to seconds + for offset in timestamp_dict: + offset["start"] = offset["start"] * frame_rate + offset["end"] = offset["end"] * frame_rate + proc_timestamps.append(timestamp_dict) + + return decoded, proc_timestamps + return decoded + + def _refine_timestamps_tdt( + self, + char_offsets, + supported_punctuation=['?', "'", 'Ā”', 'Āæ', '-', ':', ',', '%', '/', '.', '!'] + ): + for i, offset in enumerate(char_offsets): + # If token is a punctuation mark, set its start and end offset as start and end of previous token + if offset['token'] in supported_punctuation and i > 0: + offset['start'] = char_offsets[i - 1]['end'] + offset['end'] = offset['start'] + + return char_offsets + __all__ = ["ParakeetProcessor"] diff --git 
a/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json index 0acb4bae061b..e27e5f8304e5 100644 --- a/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json +++ b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json @@ -1 +1 @@ -{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind."], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883]], "token_timestamps": [[0.23999999463558197, 0.47999998927116394, 0.6399999856948853, 0.8799999952316284, 1.1200000047683716, 1.3600000143051147, 1.440000057220459, 1.600000023841858, 1.7599999904632568, 2.0, 2.1600000858306885, 2.240000009536743, 2.4000000953674316, 2.4800000190734863, 2.559999942779541, 2.7200000286102295, 2.880000114440918, 3.0399999618530273, 3.119999885559082, 3.2799999713897705, 3.440000057220459, 3.5999999046325684, 3.759999990463257, 3.9200000762939453, 4.079999923706055, 4.239999771118164, 4.400000095367432, 4.480000019073486, 4.71999979019165, 4.960000038146973, 5.360000133514404, 5.599999904632568, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3199999928474426, 0.6399999856948853, 0.8799999952316284, 1.0399999618530273, 1.2000000476837158, 1.440000057220459, 1.6799999475479126, 1.840000033378601, 1.9199999570846558, 2.0, 2.1600000858306885, 2.4000000953674316, 2.559999942779541, 2.7200000286102295, 2.9600000381469727, 3.119999885559082, 3.359999895095825, 3.5999999046325684, 3.9200000762939453, 4.159999847412109, 4.320000171661377, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3199999928474426, 0.6399999856948853, 0.7200000286102295, 0.9599999785423279, 1.1200000047683716, 1.3600000143051147, 1.600000023841858, 1.840000033378601, 2.0799999237060547, 2.240000009536743, 2.4800000190734863, 2.640000104904175, 2.799999952316284, 2.880000114440918, 3.0399999618530273, 3.200000047683716, 3.440000057220459, 3.680000066757202, 3.8399999141693115, 4.079999923706055, 4.400000095367432, 
4.559999942779541, 4.71999979019165, 4.960000038146973, 5.119999885559082, 5.360000133514404, 5.519999980926514, 5.679999828338623, 5.920000076293945, 6.159999847412109, 6.239999771118164, 6.400000095367432, 6.559999942779541, 6.71999979019165, 6.960000038146973, 7.28000020980835, 7.599999904632568, 7.920000076293945, 8.15999984741211, 8.319999694824219, 8.479999542236328, 8.720000267028809, 8.880000114440918, 8.960000038146973, 9.119999885559082, 9.279999732971191, 9.4399995803833, 9.680000305175781, 9.760000228881836, 9.920000076293945, 10.15999984741211, 10.239999771118164, 10.399999618530273, 10.640000343322754, 10.880000114440918, 10.960000038146973, 11.199999809265137, 11.359999656677246, 11.520000457763672, 11.84000015258789, 12.15999984741211]], "token_durations": [[3, 2, 3, 3, 3, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 3, 3, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 3, 2, 2, 3, 3, 2, 1, 1, 2, 3, 2, 2, 3, 2, 3, 3, 4, 3, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 1, 3, 2, 3, 3, 3, 3, 2, 3, 2, 2, 1, 2, 2, 3, 3, 2, 3, 4, 2, 2, 3, 2, 3, 2, 2, 3, 3, 1, 2, 2, 2, 3, 4, 4, 4, 3, 1, 2, 3, 2, 1, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 3, 1, 3, 2, 2, 4, 4, 2]]} \ No newline at end of file +{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind."], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883]], "start_timestamps": [[0.24, 0.48, 0.64, 0.88, 1.12, 1.36, 1.44, 1.6, 1.76, 2.0, 2.16, 2.24, 2.4, 2.48, 2.56, 2.72, 2.88, 3.04, 3.12, 3.2800000000000002, 3.44, 3.6, 3.7600000000000002, 3.92, 4.08, 4.24, 4.4, 4.48, 4.72, 4.96, 5.36, 5.6000000000000005], [0.32, 0.64, 0.88, 1.04, 1.2, 1.44, 1.68, 1.84, 1.92, 2.0, 2.16, 2.4, 2.56, 2.72, 2.96, 3.12, 3.36, 3.6, 3.92, 4.16, 4.32], [0.32, 0.64, 0.72, 0.96, 1.12, 1.36, 1.6, 1.84, 2.08, 2.24, 2.48, 2.64, 2.8000000000000003, 2.88, 3.04, 3.2, 3.44, 3.68, 3.84, 4.08, 4.4, 4.5600000000000005, 4.72, 4.96, 5.12, 5.36, 5.5200000000000005, 5.68, 5.92, 6.16, 6.24, 6.4, 6.5600000000000005, 6.72, 6.96, 7.28, 7.6000000000000005, 7.92, 8.16, 8.32, 8.48, 8.72, 8.88, 8.96, 9.120000000000001, 9.28, 9.44, 9.68, 9.76, 9.92, 10.16, 
10.24, 10.4, 10.64, 10.88, 10.96, 11.200000000000001, 11.36, 11.52, 11.84, 12.16]], "end_timestamps": [[0.48, 0.64, 0.88, 1.12, 1.36, 1.44, 1.6, 1.76, 1.92, 2.16, 2.24, 2.4, 2.48, 2.56, 2.64, 2.88, 3.04, 3.12, 3.12, 3.44, 3.6, 3.7600000000000002, 3.92, 4.08, 4.24, 4.4, 4.48, 4.72, 4.96, 5.12, 5.6000000000000005, 5.6000000000000005], [0.64, 0.88, 1.04, 1.2, 1.44, 1.68, 1.84, 1.84, 2.0, 2.16, 2.4, 2.56, 2.72, 2.96, 3.12, 3.36, 3.6, 3.92, 4.16, 4.32, 4.32], [0.64, 0.72, 0.96, 1.12, 1.36, 1.6, 1.84, 2.08, 2.24, 2.48, 2.64, 2.8000000000000003, 2.88, 3.04, 3.2, 3.44, 3.68, 3.84, 3.84, 4.4, 4.5600000000000005, 4.72, 4.96, 5.12, 5.36, 5.5200000000000005, 5.68, 5.92, 6.16, 6.24, 6.4, 6.5600000000000005, 6.72, 6.96, 7.28, 7.28, 7.92, 8.16, 8.24, 8.48, 8.72, 8.88, 8.96, 9.120000000000001, 9.200000000000001, 9.44, 9.68, 9.76, 9.92, 10.16, 10.24, 10.4, 10.64, 10.88, 10.96, 11.200000000000001, 11.36, 11.52, 11.84, 12.16, 12.16]], "token_durations": [[3, 2, 3, 3, 3, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 3, 3, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 3, 2, 2, 3, 3, 2, 1, 1, 2, 3, 2, 2, 3, 2, 3, 3, 4, 3, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 1, 3, 2, 3, 3, 3, 3, 2, 3, 2, 2, 1, 2, 2, 3, 3, 2, 3, 4, 2, 2, 3, 2, 3, 2, 2, 3, 3, 1, 2, 2, 2, 3, 4, 4, 4, 3, 1, 2, 3, 2, 1, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 3, 1, 3, 2, 2, 4, 4, 2]]} \ No newline at end of file diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index acf9718ec2a5..6080227973b7 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -311,7 +311,7 @@ def test_1b_model_integration(self): inputs.to(torch_device, dtype=self.dtype) predicted_ids = model.generate(**inputs) torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) - predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) + predicted_transcripts = self.processor.decode(predicted_ids, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) @slow @@ -333,7 +333,7 @@ def test_1b_model_integration_batched(self): inputs.to(torch_device, dtype=self.dtype) predicted_ids = model.generate(**inputs) torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) - predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) + predicted_transcripts = self.processor.decode(predicted_ids, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) @@ -531,7 +531,7 @@ def test_tdt_model_integration(self): inputs.to(torch_device, dtype=self.dtype) output = model.generate(**inputs, return_dict_in_generate=True) torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) - predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) + predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) @slow @@ -553,7 +553,7 @@ def test_tdt_model_integration_batched(self): inputs.to(torch_device, dtype=self.dtype) output = model.generate(**inputs, return_dict_in_generate=True) torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) - predicted_transcripts = self.processor.batch_decode(output.sequences, 
skip_special_tokens=True) + predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) @slow @@ -568,7 +568,8 @@ def test_tdt_model_integration_timestamps(self): raw_data = json.load(f) EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] - EXPECTED_TIMESTAMPS = torch.tensor(raw_data["token_timestamps"]) + EXPECTED_START_TIMESTAMPS = raw_data["start_timestamps"] + EXPECTED_END_TIMESTAMPS = raw_data["end_timestamps"] EXPECTED_DURATIONS = raw_data["token_durations"] # Use larger precision for testing token durations and timestamps @@ -580,13 +581,20 @@ def test_tdt_model_integration_timestamps(self): inputs.to(torch_device, dtype=model.dtype) output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) - predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) + predicted_transcripts, predicted_timestamps = self.processor.decode( + output.sequences, + token_timestamps=output.token_timestamps, + token_durations=output.token_durations, + skip_special_tokens=True + ) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) # Check timestamps and durations self.assertIsNotNone( output.token_timestamps, "token_timestamps should be returned when return_timestamps=True" ) - # Relax tolerance for timestamps due to potential internal precision differences - torch.testing.assert_close(output.token_timestamps.cpu(), EXPECTED_TIMESTAMPS, atol=0.4, rtol=1e-6) + predicted_start_times = [[entry['start'] for entry in el] for el in predicted_timestamps] + predicted_end_times = [[entry['end'] for entry in el] for el in predicted_timestamps] + torch.testing.assert_close(predicted_start_times, EXPECTED_START_TIMESTAMPS) + torch.testing.assert_close(predicted_end_times, EXPECTED_END_TIMESTAMPS) self.assertListEqual(output.token_durations.cpu().tolist(), EXPECTED_DURATIONS) From 36bfa6391a90d338ea6e045d7e27d79b976ffb67 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 3 Mar 2026 18:54:52 +0100 Subject: [PATCH 15/67] Simplify durations config. --- .../models/parakeet/configuration_parakeet.py | 14 ++++---------- .../models/parakeet/convert_nemo_to_hf.py | 11 +++-------- .../models/parakeet/modeling_parakeet.py | 2 +- .../models/parakeet/modular_parakeet.py | 2 +- 4 files changed, 9 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index c0452d66499f..cbe9073ee963 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -14,10 +14,6 @@ """Parakeet model configuration.""" from ...configuration_utils import PreTrainedConfig -from ...utils import logging - - -logger = logging.get_logger(__name__) class ParakeetEncoderConfig(PreTrainedConfig): @@ -251,8 +247,8 @@ class ParakeetTDTConfig(PreTrainedConfig): num_duration_bins (`int`, *optional*, defaults to 5): Number of duration bins for predicting token durations. durations (`list[int]`, *optional*, defaults to `[0, 1, 2, 3, 4]`): - Duration values for TDT loss computation. Each value represents how many frames a token or blank - emission spans. Must have length equal to `num_duration_bins`. + Token duration values that can be predicted. 
Each value represents how many frames a token or blank + emission spans. hidden_act (`str`, *optional*, defaults to `"relu"`): The activation function in the joint network. max_symbols_per_step (`int`, *optional*, defaults to 10): @@ -285,8 +281,7 @@ def __init__( vocab_size=8192, decoder_hidden_size=640, num_decoder_layers=1, - num_duration_bins=5, - durations=None, + durations=[0, 1, 2, 3, 4], hidden_act="relu", max_symbols_per_step=10, encoder_config: dict | ParakeetEncoderConfig = None, @@ -296,8 +291,7 @@ def __init__( self.vocab_size = vocab_size self.decoder_hidden_size = decoder_hidden_size self.num_decoder_layers = num_decoder_layers - self.num_duration_bins = num_duration_bins - self.durations = durations if durations is not None else list(range(num_duration_bins)) + self.durations = durations self.hidden_act = hidden_act self.max_symbols_per_step = max_symbols_per_step diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index 05a598118a83..daed5c11c598 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -301,18 +301,15 @@ def convert_tdt_config(nemo_config, encoder_config): num_decoder_layers = prednet.get("pred_rnn_layers", 2) durations = decoding_config.get("durations", [0, 1, 2, 3, 4]) - num_duration_bins = len(durations) - print( f"TDT config: vocab_size={vocab_size}, decoder_hidden={decoder_hidden_size}, " - f"decoder_layers={num_decoder_layers}, num_durations={num_duration_bins}, " + f"decoder_layers={num_decoder_layers}, durations={durations}, " ) return ParakeetTDTConfig( vocab_size=vocab_size, decoder_hidden_size=decoder_hidden_size, num_decoder_layers=num_decoder_layers, - num_duration_bins=num_duration_bins, durations=durations, hidden_act="relu", max_symbols_per_step=10, @@ -321,7 +318,7 @@ def convert_tdt_config(nemo_config, encoder_config): ) -def load_and_convert_tdt_state_dict(model_files, vocab_size, num_duration_bins): +def load_and_convert_tdt_state_dict(model_files, vocab_size): """Load NeMo TDT state dict and convert keys to HF format, splitting combined head.""" state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True) converted_state_dict = {} @@ -361,9 +358,7 @@ def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_t model_config = convert_tdt_config(nemo_config, encoder_config) print(f"Converted TDT config: {model_config}") - converted_state_dict = load_and_convert_tdt_state_dict( - model_files, model_config.vocab_size, model_config.num_duration_bins - ) + converted_state_dict = load_and_convert_tdt_state_dict(model_files, model_config.vocab_size) print("Loading the checkpoint in a Parakeet TDT model.") with torch.device("meta"): diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 5325d5e8ee75..54527c7423df 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -886,7 +886,7 @@ def __init__(self, config: ParakeetTDTConfig): self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.activation = ACT2FN[config.hidden_act] self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size + 1) - self.duration_head = nn.Linear(config.decoder_hidden_size, config.num_duration_bins) + self.duration_head = nn.Linear(config.decoder_hidden_size, 
len(config.durations)) def forward( self, diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 4cbb667aa001..0f61018c1ee7 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -847,7 +847,7 @@ def __init__(self, config: ParakeetTDTConfig): self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.activation = ACT2FN[config.hidden_act] self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size + 1) - self.duration_head = nn.Linear(config.decoder_hidden_size, config.num_duration_bins) + self.duration_head = nn.Linear(config.decoder_hidden_size, len(config.durations)) def forward( self, From 2df0cccae53e1419f3ef2f3b2ecf2a356fc633e5 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 3 Mar 2026 19:15:17 +0100 Subject: [PATCH 16/67] Update training examples. --- docs/source/en/model_doc/parakeet.md | 42 +++++++++++++++++----------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index 6722f932d631..c7906f94a54b 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -203,51 +203,59 @@ with TimerContext("Fourth generation"): print(processor.batch_decode(outputs)) ``` -### Training +### CTC Training ```python -from transformers import AutoModelForCTC, AutoProcessor -from datasets import load_dataset, Audio import torch +from datasets import Audio, load_dataset +from transformers import AutoModelForCTC, AutoProcessor -device = "cuda" if torch.cuda.is_available() else "cpu" +model_id = "nvidia/parakeet-ctc-1.1b" +NUM_SAMPLES = 5 -processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b") -model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device) +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForCTC.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") +model.train() ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) -speech_samples = [el['array'] for el in ds["audio"][:5]] -text_samples = [el for el in ds["text"][:5]] +speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]] +text_samples = [el for el in ds["text"][:NUM_SAMPLES]] # passing `text` to the processor will prepare inputs' `labels` key inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate) -inputs.to(device, dtype=model.dtype) +inputs.to(device=model.device, dtype=model.dtype) outputs = model(**inputs) +print("Loss:", outputs.loss.item()) outputs.loss.backward() ``` ### TDT Training -The TDT model uses RNNT loss (requires `torchaudio`). Pass `text` to the processor to prepare labels — padding is automatically handled with `-100`. 
- ```python +from datasets import Audio, load_dataset +import torch from transformers import AutoModelForTDT, AutoProcessor -from datasets import load_dataset, Audio -processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") -model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", dtype="auto", device_map="auto") +model_id = "bezzam/parakeet-tdt-0.6b-v3-hf" +NUM_SAMPLES = 3 + +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") +model.train() ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) -speech_samples = [el['array'] for el in ds["audio"][:5]] -text_samples = [el for el in ds["text"][:5]] +speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]] +text_samples = [el for el in ds["text"][:NUM_SAMPLES]] +# passing `text` to the processor will prepare inputs' `labels` key inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate) -inputs.to(model.device, dtype=model.dtype) +inputs.to(device=model.device, dtype=model.dtype) outputs = model(**inputs) +print("Loss:", outputs.loss.item()) outputs.loss.backward() ``` From 388c6d36d9d9e43797d5ce4d3d66a03f8930a3b2 Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Tue, 3 Mar 2026 22:58:40 +0100 Subject: [PATCH 17/67] chore: enable parralelism --- src/transformers/models/parakeet/modeling_parakeet.py | 5 +++++ src/transformers/models/parakeet/modular_parakeet.py | 5 +++++ tests/models/parakeet/test_modeling_parakeet.py | 6 +++--- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 54527c7423df..cc8090e97b1c 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -1102,6 +1102,11 @@ def forward( # token_logits: (batch, T, U+1, vocab_size+1) # duration_logits: (batch, T, U+1, num_duration_bins) + # move labels to correct device to enable pipeline parallelism + labels = labels.to(token_logits.device) + encoder_lengths = encoder_lengths.to(token_logits.device) + target_lengths = target_lengths.to(token_logits.device) + loss = tdt_loss( token_logits=token_logits.float(), duration_logits=duration_logits.float(), diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 0f61018c1ee7..4a3131501b54 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -950,6 +950,11 @@ def forward( # token_logits: (batch, T, U+1, vocab_size+1) # duration_logits: (batch, T, U+1, num_duration_bins) + # move labels to correct device to enable pipeline parallelism + labels = labels.to(token_logits.device) + encoder_lengths = encoder_lengths.to(token_logits.device) + target_lengths = target_lengths.to(token_logits.device) + loss = tdt_loss( token_logits=token_logits.float(), duration_logits=duration_logits.float(), diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index eb884ed85421..6104998888c5 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -423,7 +423,7 @@ def __init__( vocab_size=128, 
decoder_hidden_size=64, num_decoder_layers=1, - num_duration_bins=5, + durations=None, hidden_act="relu", max_symbols_per_step=10, pad_token_id=128, @@ -445,7 +445,7 @@ def __init__( self.vocab_size = vocab_size self.decoder_hidden_size = decoder_hidden_size self.num_decoder_layers = num_decoder_layers - self.num_duration_bins = num_duration_bins + self.durations = durations if durations is not None else [0, 1, 2, 3, 4] self.hidden_act = hidden_act self.max_symbols_per_step = max_symbols_per_step self.pad_token_id = pad_token_id @@ -460,7 +460,7 @@ def get_config(self): vocab_size=self.vocab_size, decoder_hidden_size=self.decoder_hidden_size, num_decoder_layers=self.num_decoder_layers, - num_duration_bins=self.num_duration_bins, + durations=self.durations, hidden_act=self.hidden_act, max_symbols_per_step=self.max_symbols_per_step, encoder_config=self.encoder_model_tester.get_config().to_dict(), From 08b2b5588a4d9e3a110a6d91bd1b434d877674a3 Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Wed, 4 Mar 2026 01:54:54 +0100 Subject: [PATCH 18/67] chore: performance optimization --- .../models/parakeet/modeling_parakeet.py | 99 ++++++++++++------- .../models/parakeet/modular_parakeet.py | 99 ++++++++++++------- 2 files changed, 122 insertions(+), 76 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index cc8090e97b1c..2c5f24315659 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -915,7 +915,9 @@ def tdt_loss( Compute TDT (Token-and-Duration Transducer) loss. Ported from NeMo's `TDTLossPytorch`. Unlike standard RNNT loss, this loss trains both - the token prediction head and the duration prediction head. + the token prediction head and the duration prediction head. Uses vectorized anti-diagonal + processing for efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in + parallel as batched tensor operations. Args: token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. 
@@ -941,51 +943,73 @@ def tdt_loss( token_log_probs = torch.log_softmax(token_logits, dim=-1) - sigma duration_log_probs = torch.log_softmax(duration_logits, dim=-1) - # Forward variable: log_alpha[b, t, u] = log P(y_{1:u} | x_{1:t}) - log_alpha = torch.full((batch_size, max_t, max_u), -1000.0, device=device) + log_alpha = torch.full((batch_size, max_t, max_u), float("-inf"), device=device) log_alpha[:, 0, 0] = 0.0 - batch_idx = torch.arange(batch_size, device=device) - for t in range(max_t): - for u in range(max_u): - if t == 0 and u == 0: - continue + # Precompute blank and label log-probs for vectorized access + blank_log_probs = token_log_probs[:, :, :, blank] - # Accumulate log-probabilities from all incoming arcs - candidates = [] + if max_u > 1: + targets_expanded = targets.unsqueeze(1).expand(-1, max_t, -1) # (batch, T, U_labels) + label_log_probs = torch.gather( + token_log_probs[:, :, : max_u - 1, :], # (batch, T, U-1, vocab) + dim=3, + index=targets_expanded.unsqueeze(-1), + ).squeeze(-1) # (batch, T, U-1) - for n, dur in enumerate(durations): - t_prev = t - dur - if t_prev < 0: - continue + # Process anti-diagonals: all (t, u) with t + u = n have no mutual dependencies + for n in range(1, max_t + max_u - 1): + u_start = max(0, n - max_t + 1) + u_end = min(n + 1, max_u) + u_indices = torch.arange(u_start, u_end, device=device) + t_indices = n - u_indices - # Blank arc (duration > 0): same label position, skip `dur` frames - if dur > 0: - blank_contribution = ( - log_alpha[:, t_prev, u] - + token_log_probs[:, t_prev, u, blank] - + duration_log_probs[:, t_prev, u, n] - ) - candidates.append(blank_contribution) - - # Label arc (u > 0): emit label y_u from position (t_prev, u-1) - if u > 0: - label_contribution = ( - log_alpha[:, t_prev, u - 1] - + token_log_probs[batch_idx, t_prev, u - 1, targets[:, u - 1]] - + duration_log_probs[:, t_prev, u - 1, n] - ) - candidates.append(label_contribution) + all_candidates = [] + + for i, dur in enumerate(durations): + t_prev = t_indices - dur + valid_t = t_prev >= 0 + + if not valid_t.any(): + continue + + t_src = t_prev.clamp(min=0) + + # Blank arcs (dur > 0): from (t-dur, u) to (t, u) + if dur > 0: + contrib = ( + log_alpha[:, t_src, u_indices] + + blank_log_probs[:, t_src, u_indices] + + duration_log_probs[:, t_src, u_indices, i] + ) + contrib = torch.where(valid_t.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) + all_candidates.append(contrib) + + # Label arcs: from (t-dur, u-1) to (t, u), only if u > 0 + valid_u = u_indices > 0 + valid_both = valid_t & valid_u + if valid_both.any(): + u_src = (u_indices - 1).clamp(min=0) + u_src_label = u_src.clamp(max=max_u - 2) if max_u > 1 else u_src + + contrib = ( + log_alpha[:, t_src, u_src] + + label_log_probs[:, t_src, u_src_label] + + duration_log_probs[:, t_src, u_src, i] + ) + contrib = torch.where(valid_both.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) + all_candidates.append(contrib) - if candidates: - log_alpha[:, t, u] = torch.logsumexp(torch.stack(candidates, dim=0), dim=0) + if all_candidates: + stacked = torch.stack(all_candidates, dim=0) + log_alpha[:, t_indices, u_indices] = torch.logsumexp(stacked, dim=0) # Terminal probability: sum over blank arcs that reach (T, U) from (T-dur, U) - log_probs = torch.full((batch_size,), -1000.0, device=device) - for n, dur in enumerate(durations): + batch_idx = torch.arange(batch_size, device=device) + log_probs = torch.full((batch_size,), float("-inf"), device=device) + for i, dur in enumerate(durations): if 
dur == 0: continue - # For each example, check if act_lens[b] - dur >= 0 t_final = logit_lengths - dur valid = t_final >= 0 if not valid.any(): @@ -995,9 +1019,8 @@ def tdt_loss( terminal = ( log_alpha[batch_idx, t_clamped, target_lengths] + token_log_probs[batch_idx, t_clamped, target_lengths, blank] - + duration_log_probs[batch_idx, t_clamped, target_lengths, n] + + duration_log_probs[batch_idx, t_clamped, target_lengths, i] ) - # Only update valid entries combined = torch.stack([log_probs, terminal], dim=0) log_probs = torch.where(valid, torch.logsumexp(combined, dim=0), log_probs) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 4a3131501b54..16affe808803 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -741,7 +741,9 @@ def tdt_loss( Compute TDT (Token-and-Duration Transducer) loss. Ported from NeMo's `TDTLossPytorch`. Unlike standard RNNT loss, this loss trains both - the token prediction head and the duration prediction head. + the token prediction head and the duration prediction head. Uses vectorized anti-diagonal + processing for efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in + parallel as batched tensor operations. Args: token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. @@ -767,51 +769,73 @@ def tdt_loss( token_log_probs = torch.log_softmax(token_logits, dim=-1) - sigma duration_log_probs = torch.log_softmax(duration_logits, dim=-1) - # Forward variable: log_alpha[b, t, u] = log P(y_{1:u} | x_{1:t}) - log_alpha = torch.full((batch_size, max_t, max_u), -1000.0, device=device) + log_alpha = torch.full((batch_size, max_t, max_u), float("-inf"), device=device) log_alpha[:, 0, 0] = 0.0 - batch_idx = torch.arange(batch_size, device=device) - for t in range(max_t): - for u in range(max_u): - if t == 0 and u == 0: - continue + # Precompute blank and label log-probs for vectorized access + blank_log_probs = token_log_probs[:, :, :, blank] - # Accumulate log-probabilities from all incoming arcs - candidates = [] + if max_u > 1: + targets_expanded = targets.unsqueeze(1).expand(-1, max_t, -1) # (batch, T, U_labels) + label_log_probs = torch.gather( + token_log_probs[:, :, : max_u - 1, :], # (batch, T, U-1, vocab) + dim=3, + index=targets_expanded.unsqueeze(-1), + ).squeeze(-1) # (batch, T, U-1) - for n, dur in enumerate(durations): - t_prev = t - dur - if t_prev < 0: - continue + # Process anti-diagonals: all (t, u) with t + u = n have no mutual dependencies + for n in range(1, max_t + max_u - 1): + u_start = max(0, n - max_t + 1) + u_end = min(n + 1, max_u) + u_indices = torch.arange(u_start, u_end, device=device) + t_indices = n - u_indices - # Blank arc (duration > 0): same label position, skip `dur` frames - if dur > 0: - blank_contribution = ( - log_alpha[:, t_prev, u] - + token_log_probs[:, t_prev, u, blank] - + duration_log_probs[:, t_prev, u, n] - ) - candidates.append(blank_contribution) - - # Label arc (u > 0): emit label y_u from position (t_prev, u-1) - if u > 0: - label_contribution = ( - log_alpha[:, t_prev, u - 1] - + token_log_probs[batch_idx, t_prev, u - 1, targets[:, u - 1]] - + duration_log_probs[:, t_prev, u - 1, n] - ) - candidates.append(label_contribution) + all_candidates = [] + + for i, dur in enumerate(durations): + t_prev = t_indices - dur + valid_t = t_prev >= 0 + + if not valid_t.any(): + continue + + t_src = t_prev.clamp(min=0) + + # Blank arcs (dur > 0): from 
(t-dur, u) to (t, u) + if dur > 0: + contrib = ( + log_alpha[:, t_src, u_indices] + + blank_log_probs[:, t_src, u_indices] + + duration_log_probs[:, t_src, u_indices, i] + ) + contrib = torch.where(valid_t.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) + all_candidates.append(contrib) + + # Label arcs: from (t-dur, u-1) to (t, u), only if u > 0 + valid_u = u_indices > 0 + valid_both = valid_t & valid_u + if valid_both.any(): + u_src = (u_indices - 1).clamp(min=0) + u_src_label = u_src.clamp(max=max_u - 2) if max_u > 1 else u_src + + contrib = ( + log_alpha[:, t_src, u_src] + + label_log_probs[:, t_src, u_src_label] + + duration_log_probs[:, t_src, u_src, i] + ) + contrib = torch.where(valid_both.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) + all_candidates.append(contrib) - if candidates: - log_alpha[:, t, u] = torch.logsumexp(torch.stack(candidates, dim=0), dim=0) + if all_candidates: + stacked = torch.stack(all_candidates, dim=0) + log_alpha[:, t_indices, u_indices] = torch.logsumexp(stacked, dim=0) # Terminal probability: sum over blank arcs that reach (T, U) from (T-dur, U) - log_probs = torch.full((batch_size,), -1000.0, device=device) - for n, dur in enumerate(durations): + batch_idx = torch.arange(batch_size, device=device) + log_probs = torch.full((batch_size,), float("-inf"), device=device) + for i, dur in enumerate(durations): if dur == 0: continue - # For each example, check if act_lens[b] - dur >= 0 t_final = logit_lengths - dur valid = t_final >= 0 if not valid.any(): @@ -821,9 +845,8 @@ def tdt_loss( terminal = ( log_alpha[batch_idx, t_clamped, target_lengths] + token_log_probs[batch_idx, t_clamped, target_lengths, blank] - + duration_log_probs[batch_idx, t_clamped, target_lengths, n] + + duration_log_probs[batch_idx, t_clamped, target_lengths, i] ) - # Only update valid entries combined = torch.stack([log_probs, terminal], dim=0) log_probs = torch.where(valid, torch.logsumexp(combined, dim=0), log_probs) From 0c4e05a82d3b06ae71cb5aa72acefeaeef871cb5 Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Wed, 4 Mar 2026 02:03:59 +0100 Subject: [PATCH 19/67] fix: formatting --- .../models/parakeet/modular_parakeet.py | 4 ++-- .../models/parakeet/processing_parakeet.py | 20 +++++++++++-------- .../models/parakeet/test_modeling_parakeet.py | 6 +++--- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 16affe808803..cf9f32d9aadf 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -1178,8 +1178,8 @@ def generate( for i in range(batch_size): num_tokens = len(token_frame_indices[i]) if num_tokens > 0: - token_timestamps[i, :num_tokens] = ( - torch.tensor(token_frame_indices[i], dtype=torch.long, device=device) + token_timestamps[i, :num_tokens] = torch.tensor( + token_frame_indices[i], dtype=torch.long, device=device ) token_durations[i, :num_tokens] = torch.tensor( token_durations_list[i], dtype=torch.long, device=device diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py index 459bf52d90c9..dca9e75b0769 100644 --- a/src/transformers/models/parakeet/processing_parakeet.py +++ b/src/transformers/models/parakeet/processing_parakeet.py @@ -108,11 +108,17 @@ def decode(self, *args, token_timestamps=None, token_durations=None, **kwargs): 
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
- frame_rate = self.feature_extractor.hop_length / self.feature_extractor.sampling_rate * output_kwargs["audio_kwargs"]["subsampling_factor"]
+ frame_rate = (
+ self.feature_extractor.hop_length
+ / self.feature_extractor.sampling_rate
+ * output_kwargs["audio_kwargs"]["subsampling_factor"]
+ )
proc_timestamps = []
for batch_ids, timestamps, durations in zip(token_ids, token_timestamps, token_durations):
# Original NeMo: https://github.com/NVIDIA-NeMo/NeMo/blob/1692a8fb97e1aadc883cfadd2a57c4e8a1b793aa/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L993
- non_blank_indices = [i for i, token_id in enumerate(batch_ids) if token_id != self.tokenizer.vocab_size]
+ non_blank_indices = [
+ i for i, token_id in enumerate(batch_ids) if token_id != self.tokenizer.vocab_size
+ ]
non_blank_ids = [batch_ids[i] for i in non_blank_indices]
decoded_tokens = [self.tokenizer.decode([token_id]) for token_id in non_blank_ids]
timestamp_dict = [
@@ -131,15 +137,13 @@ def decode(self, *args, token_timestamps=None, token_durations=None, **kwargs):
return decoded
def _refine_timestamps_tdt(
- self,
- char_offsets,
- supported_punctuation=['?', "'", '¡', '¿', '-', ':', ',', '%', '/', '.', '!']
+ self, char_offsets, supported_punctuation=["?", "'", "¡", "¿", "-", ":", ",", "%", "/", ".", "!"]
):
for i, offset in enumerate(char_offsets):
# If token is a punctuation mark, set its start and end offset as start and end of previous token
- if offset['token'] in supported_punctuation and i > 0:
- offset['start'] = char_offsets[i - 1]['end']
- offset['end'] = offset['start']
+ if offset["token"] in supported_punctuation and i > 0:
+ offset["start"] = char_offsets[i - 1]["end"]
+ offset["end"] = offset["start"]
return char_offsets
diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py
index 6104998888c5..1b948363536a 100644
--- a/tests/models/parakeet/test_modeling_parakeet.py
+++ b/tests/models/parakeet/test_modeling_parakeet.py
@@ -662,7 +662,7 @@ def test_tdt_model_integration_timestamps(self):
output.sequences,
token_timestamps=output.token_timestamps,
token_durations=output.token_durations,
- skip_special_tokens=True
+ skip_special_tokens=True,
)
self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
@@ -670,8 +670,8 @@ def test_tdt_model_integration_timestamps(self):
self.assertIsNotNone(
output.token_timestamps, "token_timestamps should be returned when return_timestamps=True"
)
- predicted_start_times = [[entry['start'] for entry in el] for el in predicted_timestamps]
- predicted_end_times = [[entry['end'] for entry in el] for el in predicted_timestamps]
+ predicted_start_times = [[entry["start"] for entry in el] for el in predicted_timestamps]
+ predicted_end_times = [[entry["end"] for entry in el] for el in predicted_timestamps]
torch.testing.assert_close(predicted_start_times, EXPECTED_START_TIMESTAMPS)
torch.testing.assert_close(predicted_end_times, EXPECTED_END_TIMESTAMPS)
self.assertListEqual(output.token_durations.cpu().tolist(), EXPECTED_DURATIONS)
From 1ddd804979b120978907ba0ab4e207f7a5bcf602 Mon Sep 17 00:00:00 2001
From: Eric B
Date: Thu, 5 Mar 2026 19:30:54 +0100
Subject: [PATCH 20/67] Doc and testing nits
---
 docs/source/en/model_doc/parakeet.md | 66 +++++++++++++++----
 .../models/parakeet/test_modeling_parakeet.py | 31 ++++-----
 2 files changed, 63 insertions(+), 34 deletions(-)
diff --git a/docs/source/en/model_doc/parakeet.md
b/docs/source/en/model_doc/parakeet.md
index c7906f94a54b..9dd03ad00bfc 100644
--- a/docs/source/en/model_doc/parakeet.md
+++ b/docs/source/en/model_doc/parakeet.md
@@ -66,10 +66,10 @@ print(out)
```py
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
-import torch
-processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
-model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map="auto")
+model_id = "nvidia/parakeet-ctc-1.1b"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForCTC.from_pretrained(model_id, dtype="auto", device_map="auto")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
@@ -78,7 +78,7 @@ speech_samples = [el['array'] for el in ds["audio"][:5]]
inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
-print(processor.batch_decode(outputs))
+print(processor.decode(outputs))
```
@@ -89,6 +89,8 @@ print(processor.batch_decode(outputs))
+
+
+Parakeet TDT transcripts include casing, and the model can also perform token timestamping.
+
```py
from transformers import pipeline
@@ -103,10 +105,10 @@ print(out)
```py
from transformers import AutoModelForTDT, AutoProcessor
from datasets import load_dataset, Audio
-import torch
-processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3")
-model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", dtype="auto", device_map="auto")
+model_id = "nvidia/parakeet-tdt-0.6b-v3"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
@@ -115,7 +117,44 @@ speech_samples = [el['array'] for el in ds["audio"][:5]]
inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
output = model.generate(**inputs, return_dict_in_generate=True)
-print(processor.batch_decode(output.sequences, skip_special_tokens=True))
+print(processor.decode(output.sequences, skip_special_tokens=True))
+```
+
+
+
+
+
+
+```py
+from datasets import Audio, load_dataset
+from transformers import AutoModelForTDT, AutoProcessor
+
+model_id = "nvidia/parakeet-tdt-0.6b-v3"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto")
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+speech_samples = [el['array'] for el in ds["audio"][:1]]
+
+inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
+inputs.to(model.device, dtype=model.dtype)
+output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True)
+decoded_output, decoded_timestamps = processor.decode(
+    output.sequences,
+    token_timestamps=output.token_timestamps,
+    token_durations=output.token_durations,
+    skip_special_tokens=True
+)
+print("Transcription:", decoded_output)
+print("\nTimestamped tokens:", decoded_timestamps)
+
+"""
+Transcription: ['mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'] + +Timestamped tokens: [[{'token': 'm', 'start': 0.24, 'end': 0.48}, {'token': 'ister', 'start': 0.48, 'end': 0.64}, {'token': 'Qu', 'start': 0.64, 'end': 0.88}, {'token': 'il', 'start': 0.88, 'end': 1.12}, {'token': 'ter', 'start': 1.12, 'end': 1.36}, {'token': 'is', 'start': 1.36, 'end': 1.44}, {'token': 'the', 'start': 1.44, 'end': 1.6}, {'token': 'ap', 'start': 1.6, 'end': 1.76}, {'token': 'ost', 'start': 1.76, 'end': 1.92}, {'token': 'le', 'start': 2.0, 'end': 2.16}, {'token': 'of', 'start': 2.16, 'end': 2.24}, {'token': 'the', 'start': 2.24, 'end': 2.4}, {'token': 'mid', 'start': 2.4, 'end': 2.48}, {'token': 'd', 'start': 2.48, 'end': 2.56}, {'token': 'le', 'start': 2.56, 'end': 2.64}, {'token': 'clas', 'start': 2.72, 'end': 2.88}, {'token': 's', 'start': 2.88, 'end': 3.04}, {'token': 'es', 'start': 3.04, 'end': 3.12}, {'token': ',', 'start': 3.12, 'end': 3.12}, {'token': 'and', 'start': 3.2800000000000002, 'end': 3.44}, {'token': 'we', 'start': 3.44, 'end': 3.6}, {'token': 'are', 'start': 3.6, 'end': 3.7600000000000002}, {'token': 'gl', 'start': 3.7600000000000002, 'end': 3.92}, {'token': 'ad', 'start': 3.92, 'end': 4.08}, {'token': 'to', 'start': 4.08, 'end': 4.24}, {'token': 'wel', 'start': 4.24, 'end': 4.4}, {'token': 'c', 'start': 4.4, 'end': 4.48}, {'token': 'ome', 'start': 4.48, 'end': 4.72}, {'token': 'his', 'start': 4.72, 'end': 4.96}, {'token': 'gos', 'start': 4.96, 'end': 5.12}, {'token': 'pel', 'start': 5.36, 'end': 5.6000000000000005}, {'token': '.', 'start': 5.6000000000000005, 'end': 5.6000000000000005}]] +""" ``` @@ -176,7 +215,7 @@ print("First generation - compiling...") # Generate with the compiled model with TimerContext("First generation"): outputs = model.generate(**inputs) -print(processor.batch_decode(outputs)) +print(processor.decode(outputs)) inputs = processor(speech_samples[1], **processor_kwargs) inputs.to(device, dtype=model.dtype) @@ -184,7 +223,7 @@ print("\n" + "="*50) print("Second generation - recording CUDA graphs...") with TimerContext("Second generation"): outputs = model.generate(**inputs) -print(processor.batch_decode(outputs)) +print(processor.decode(outputs)) inputs = processor(speech_samples[2], **processor_kwargs) inputs.to(device, dtype=model.dtype) @@ -192,7 +231,7 @@ print("\n" + "="*50) print("Third generation - fast !!!") with TimerContext("Third generation"): outputs = model.generate(**inputs) -print(processor.batch_decode(outputs)) +print(processor.decode(outputs)) inputs = processor(speech_samples[3], **processor_kwargs) inputs.to(device, dtype=model.dtype) @@ -200,7 +239,7 @@ print("\n" + "="*50) print("Fourth generation - still fast !!!") with TimerContext("Fourth generation"): outputs = model.generate(**inputs) -print(processor.batch_decode(outputs)) +print(processor.decode(outputs)) ``` ### CTC Training @@ -238,7 +277,7 @@ from datasets import Audio, load_dataset import torch from transformers import AutoModelForTDT, AutoProcessor -model_id = "bezzam/parakeet-tdt-0.6b-v3-hf" +model_id = "nvidia/parakeet-tdt-0.6b-v3-hf" NUM_SAMPLES = 3 processor = AutoProcessor.from_pretrained(model_id) @@ -272,7 +311,6 @@ outputs.loss.backward() [[autodoc]] ParakeetProcessor - __call__ - - batch_decode - decode ## ParakeetEncoderConfig diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 1b948363536a..3591edd8b0d4 100644 --- 
a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -294,9 +294,7 @@ class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase): ) test_attention_outputs = False - test_resize_embeddings = False - _is_composite = True def setUp(self): @@ -381,11 +379,10 @@ def test_1b_model_integration(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(1) - model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map=torch_device) - model.to(torch_device) + model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto") inputs = self.processor(samples) - inputs.to(torch_device, dtype=self.dtype) + inputs.to(model.device, dtype=self.dtype) predicted_ids = model.generate(**inputs) torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.decode(predicted_ids, skip_special_tokens=True) @@ -403,11 +400,10 @@ def test_1b_model_integration_batched(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(5) - model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map=torch_device) - model.to(torch_device) + model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto") inputs = self.processor(samples) - inputs.to(torch_device, dtype=self.dtype) + inputs.to(model.device, dtype=self.dtype) predicted_ids = model.generate(**inputs) torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.decode(predicted_ids, skip_special_tokens=True) @@ -501,9 +497,7 @@ class ParakeetForTDTModelTest(ModelTesterMixin, unittest.TestCase): ) test_attention_outputs = False - test_resize_embeddings = False - _is_composite = True def setUp(self): @@ -530,7 +524,7 @@ def test_sdpa_can_dispatch_composite_models(self): self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config, _ = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: @@ -601,11 +595,10 @@ def test_tdt_model_integration(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map=torch_device) - model.to(torch_device) + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto") inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) - inputs.to(torch_device, dtype=self.dtype) + inputs.to(model.device, dtype=self.dtype) output = model.generate(**inputs, return_dict_in_generate=True) torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True) @@ -623,11 +616,10 @@ def test_tdt_model_integration_batched(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map=torch_device) - model.to(torch_device) + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto") inputs = 
self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) - inputs.to(torch_device, dtype=self.dtype) + inputs.to(model.device, dtype=self.dtype) output = model.generate(**inputs, return_dict_in_generate=True) torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True) @@ -651,11 +643,10 @@ def test_tdt_model_integration_timestamps(self): # Use larger precision for testing token durations and timestamps samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map=torch_device) - model.to(torch_device) + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto") inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) - inputs.to(torch_device, dtype=model.dtype) + inputs.to(model.device, dtype=model.dtype) output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts, predicted_timestamps = self.processor.decode( From f51267034cf8daf4aa05727cfd991051cc248c5b Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 6 Mar 2026 15:11:07 +0100 Subject: [PATCH 21/67] Use active mask from current step, and nits. --- .../models/parakeet/modeling_parakeet.py | 48 +++++++++---------- .../models/parakeet/modular_parakeet.py | 48 +++++++++---------- 2 files changed, 46 insertions(+), 50 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 2c5f24315659..c19f79bf6a68 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -900,6 +900,7 @@ def forward( return token_logits, duration_logits +# TODO (ebezzam) eventually move to audio_utils or loss_utils for common usage? def tdt_loss( token_logits: torch.Tensor, duration_logits: torch.Tensor, @@ -912,7 +913,7 @@ def tdt_loss( reduction: str = "mean", ) -> torch.Tensor: """ - Compute TDT (Token-and-Duration Transducer) loss. + Compute TDT (Token-and-Duration Transducer) loss (https://arxiv.org/abs/2304.06795). Ported from NeMo's `TDTLossPytorch`. Unlike standard RNNT loss, this loss trains both the token prediction head and the duration prediction head. Uses vectorized anti-diagonal @@ -933,8 +934,6 @@ def tdt_loss( Returns: Scalar loss tensor (or per-example losses if `reduction="none"`). 
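(Editorial gloss, not text from the patch: both the original double loop and the vectorized version compute the same forward recursion, written here with the function's own tensor names, where `d_i` ranges over the configured `durations` and `blank` is the blank index.)

    log_alpha[b, t, u] = logsumexp over duration bins d_i of
        log_alpha[b, t - d_i, u]     + token_log_probs[b, t - d_i, u, blank]                 + duration_log_probs[b, t - d_i, u, i]        # blank arc, requires d_i >= 1
        log_alpha[b, t - d_i, u - 1] + token_log_probs[b, t - d_i, u - 1, targets[b, u - 1]] + duration_log_probs[b, t - d_i, u - 1, i]    # label arc, requires u >= 1

    with log_alpha[b, 0, 0] = 0 and terms with t - d_i < 0 dropped.

Blank arcs require d_i >= 1 and label arcs decrease u, so every predecessor of a cell with t + u = n lies on a strictly earlier anti-diagonal; that is why all cells sharing the same t + u can be updated in one batched operation.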
- Reference: - *Token-and-Duration Transducer (TDT)* — https://arxiv.org/abs/2304.06795 """ device = token_logits.device batch_size, max_t, max_u, _ = token_logits.shape @@ -1122,20 +1121,13 @@ def forward( encoder_hidden_states_trimmed.unsqueeze(2), decoder_output.unsqueeze(1), ) - # token_logits: (batch, T, U+1, vocab_size+1) - # duration_logits: (batch, T, U+1, num_duration_bins) - - # move labels to correct device to enable pipeline parallelism - labels = labels.to(token_logits.device) - encoder_lengths = encoder_lengths.to(token_logits.device) - target_lengths = target_lengths.to(token_logits.device) loss = tdt_loss( token_logits=token_logits.float(), duration_logits=duration_logits.float(), - targets=labels.int(), - logit_lengths=encoder_lengths.int(), - target_lengths=target_lengths.int(), + targets=labels.to(token_logits.device).int(), + logit_lengths=encoder_lengths.to(token_logits.device).int(), + target_lengths=target_lengths.to(token_logits.device).int(), blank=self.config.pad_token_id, durations=self.config.durations, reduction="mean", @@ -1162,8 +1154,8 @@ def generate( Args: return_timestamps (`bool`, *optional*, defaults to `False`): - Whether to return per-token timestamps in seconds. When `True`, forces - `return_dict_in_generate=True` and includes `token_timestamps` in the output. + Whether to return per-token timestamps and durations. When `True`, forces + `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. Example: @@ -1178,28 +1170,33 @@ def generate( >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) - >>> inputs = processor(ds[0]["audio"]["array"]) + >>> inputs = processor(ds[0]["audio"]["array"], sampling_rate=processor.feature_extractor.sampling_rate) + >>> inputs = inputs.to(model.device, dtype=model.dtype) >>> output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) - >>> transcription = processor.batch_decode(output.sequences, skip_special_tokens=True) - >>> print(transcription) - >>> print(output.token_timestamps) + >>> decoded_output, decoded_timestamps = processor.decode( + ... output.sequences, + ... token_timestamps=output.token_timestamps, + ... token_durations=output.token_durations, + ... skip_special_tokens=True + ... 
) + >>> print("Transcription:", decoded_output) + >>> print("Timestamped tokens:", decoded_timestamps) ``` """ kwargs["return_dict"] = True if return_timestamps: return_dict_in_generate = True - - batch_size = input_features.shape[0] outputs: CausalLMOutput = self.forward( input_features=input_features, attention_mask=attention_mask, **kwargs, ) - encoder_hidden_states = outputs.logits + # greedy TDT decoding, `GreedyBatchedTDTLabelLoopingComputer.torch_impl` in NeMo + encoder_hidden_states = outputs.logits + batch_size, sequence_length = encoder_hidden_states.shape[:2] device = encoder_hidden_states.device - sequence_length = encoder_hidden_states.shape[1] if attention_mask is not None: encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) valid_lengths = encoder_attention_mask.sum(dim=1).int() @@ -1227,6 +1224,7 @@ def generate( last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) while active_mask.any(): + active_mask_prev = active_mask.clone() safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) @@ -1236,7 +1234,7 @@ def generate( tokens = token_logits.argmax(dim=-1) durations = duration_logits.argmax(dim=-1) - blank_mask = active_mask & (tokens == self.config.pad_token_id) + blank_mask = active_mask_prev & (tokens == self.config.pad_token_id) # Force blank duration >= 1 to guarantee forward progress durations = durations.masked_fill(blank_mask & (durations == 0), 1) @@ -1275,7 +1273,7 @@ def generate( advance_mask = active_mask & blank_mask # Record results for non-blank tokens found - emit_mask = active_mask & (tokens != self.config.pad_token_id) + emit_mask = active_mask_prev & (tokens != self.config.pad_token_id) for i in range(batch_size): if emit_mask[i]: all_tokens[i].append(tokens[i].item()) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index cf9f32d9aadf..c10afe6b7b64 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -726,6 +726,7 @@ def forward( return decoder_output, hidden_state, cell_state +# TODO (ebezzam) eventually move to audio_utils or loss_utils for common usage? def tdt_loss( token_logits: torch.Tensor, duration_logits: torch.Tensor, @@ -738,7 +739,7 @@ def tdt_loss( reduction: str = "mean", ) -> torch.Tensor: """ - Compute TDT (Token-and-Duration Transducer) loss. + Compute TDT (Token-and-Duration Transducer) loss (https://arxiv.org/abs/2304.06795). Ported from NeMo's `TDTLossPytorch`. Unlike standard RNNT loss, this loss trains both the token prediction head and the duration prediction head. Uses vectorized anti-diagonal @@ -759,8 +760,6 @@ def tdt_loss( Returns: Scalar loss tensor (or per-example losses if `reduction="none"`). 
- Reference: - *Token-and-Duration Transducer (TDT)* — https://arxiv.org/abs/2304.06795 """ device = token_logits.device batch_size, max_t, max_u, _ = token_logits.shape @@ -970,20 +969,13 @@ def forward( encoder_hidden_states_trimmed.unsqueeze(2), decoder_output.unsqueeze(1), ) - # token_logits: (batch, T, U+1, vocab_size+1) - # duration_logits: (batch, T, U+1, num_duration_bins) - - # move labels to correct device to enable pipeline parallelism - labels = labels.to(token_logits.device) - encoder_lengths = encoder_lengths.to(token_logits.device) - target_lengths = target_lengths.to(token_logits.device) loss = tdt_loss( token_logits=token_logits.float(), duration_logits=duration_logits.float(), - targets=labels.int(), - logit_lengths=encoder_lengths.int(), - target_lengths=target_lengths.int(), + targets=labels.to(token_logits.device).int(), + logit_lengths=encoder_lengths.to(token_logits.device).int(), + target_lengths=target_lengths.to(token_logits.device).int(), blank=self.config.pad_token_id, durations=self.config.durations, reduction="mean", @@ -1010,8 +1002,8 @@ def generate( Args: return_timestamps (`bool`, *optional*, defaults to `False`): - Whether to return per-token timestamps in seconds. When `True`, forces - `return_dict_in_generate=True` and includes `token_timestamps` in the output. + Whether to return per-token timestamps and durations. When `True`, forces + `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. Example: @@ -1026,28 +1018,33 @@ def generate( >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) - >>> inputs = processor(ds[0]["audio"]["array"]) + >>> inputs = processor(ds[0]["audio"]["array"], sampling_rate=processor.feature_extractor.sampling_rate) + >>> inputs = inputs.to(model.device, dtype=model.dtype) >>> output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) - >>> transcription = processor.batch_decode(output.sequences, skip_special_tokens=True) - >>> print(transcription) - >>> print(output.token_timestamps) + >>> decoded_output, decoded_timestamps = processor.decode( + ... output.sequences, + ... token_timestamps=output.token_timestamps, + ... token_durations=output.token_durations, + ... skip_special_tokens=True + ... 
) + >>> print("Transcription:", decoded_output) + >>> print("Timestamped tokens:", decoded_timestamps) ``` """ kwargs["return_dict"] = True if return_timestamps: return_dict_in_generate = True - - batch_size = input_features.shape[0] outputs: CausalLMOutput = self.forward( input_features=input_features, attention_mask=attention_mask, **kwargs, ) - encoder_hidden_states = outputs.logits + # greedy TDT decoding, `GreedyBatchedTDTLabelLoopingComputer.torch_impl` in NeMo + encoder_hidden_states = outputs.logits + batch_size, sequence_length = encoder_hidden_states.shape[:2] device = encoder_hidden_states.device - sequence_length = encoder_hidden_states.shape[1] if attention_mask is not None: encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) valid_lengths = encoder_attention_mask.sum(dim=1).int() @@ -1075,6 +1072,7 @@ def generate( last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) while active_mask.any(): + active_mask_prev = active_mask.clone() safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) @@ -1084,7 +1082,7 @@ def generate( tokens = token_logits.argmax(dim=-1) durations = duration_logits.argmax(dim=-1) - blank_mask = active_mask & (tokens == self.config.pad_token_id) + blank_mask = active_mask_prev & (tokens == self.config.pad_token_id) # Force blank duration >= 1 to guarantee forward progress durations = durations.masked_fill(blank_mask & (durations == 0), 1) @@ -1123,7 +1121,7 @@ def generate( advance_mask = active_mask & blank_mask # Record results for non-blank tokens found - emit_mask = active_mask & (tokens != self.config.pad_token_id) + emit_mask = active_mask_prev & (tokens != self.config.pad_token_id) for i in range(batch_size): if emit_mask[i]: all_tokens[i].append(tokens[i].item()) From 07d8e35e5c79a9290cfb7f3d8fdf44c1dd3973a9 Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 6 Mar 2026 16:58:59 +0100 Subject: [PATCH 22/67] Better pre-allocate. 
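Editorial expansion of the one-line subject: the refactor below sizes the output buffers at the worst case `sequence_length * max_symbols_per_step` and scatters emitted tokens into them with boolean masks, instead of growing per-example Python lists and re-padding at the end. A minimal sketch of the pattern, with made-up shapes and token ids (not code from this commit):

```py
import torch

batch_size, seq_len, max_symbols_per_step, pad_id = 2, 4, 3, 0
max_out = seq_len * max_symbols_per_step                  # upper bound on emitted tokens
sequences = torch.full((batch_size, max_out), pad_id)     # pre-allocated output buffer
token_counts = torch.zeros(batch_size, dtype=torch.long)  # next free slot per example

# one decoding step: suppose only example 0 emits a (fake) token id 42
emit_mask = torch.tensor([True, False])
tokens = torch.tensor([42, 7])
sequences[emit_mask, token_counts[emit_mask]] = tokens[emit_mask]
token_counts += emit_mask.long()

# after the loop, trim to the longest emitted sequence (at least 1 to avoid empty tensors)
sequences = sequences[:, : max(int(token_counts.max()), 1)]
print(sequences)
```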
--- .../models/parakeet/modeling_parakeet.py | 134 ++++++++---------- .../models/parakeet/modular_parakeet.py | 133 ++++++++--------- 2 files changed, 123 insertions(+), 144 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index c19f79bf6a68..ebb3417f51b0 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -890,14 +890,16 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, - encoder_output: torch.Tensor, decoder_output: torch.Tensor, + encoder_output: torch.Tensor | None = None, + projected_encoder_output: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - encoder_projected = self.encoder_projector(encoder_output) - joint_output = self.activation(encoder_projected + decoder_output) - token_logits = self.token_head(joint_output) - duration_logits = self.duration_head(joint_output) - return token_logits, duration_logits + if projected_encoder_output is None: + if encoder_output is None: + raise ValueError("Either encoder_output or projected_encoder_output must be provided.") + projected_encoder_output = self.encoder_projector(encoder_output) + joint_output = self.activation(projected_encoder_output + decoder_output) + return self.token_head(joint_output), self.duration_head(joint_output) # TODO (ebezzam) eventually move to audio_utils or loss_utils for common usage? @@ -1118,8 +1120,8 @@ def forward( # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) # decoder: (batch, 1, U+1, decoder_hidden_size) token_logits, duration_logits = self.joint( - encoder_hidden_states_trimmed.unsqueeze(2), - decoder_output.unsqueeze(1), + decoder_output=decoder_output.unsqueeze(1), + encoder_output=encoder_hidden_states_trimmed.unsqueeze(2), ) loss = tdt_loss( @@ -1203,7 +1205,7 @@ def generate( else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) - # Initialize decoder + # Initialization hidden_state, cell_state = None, None prev_tokens = torch.full((batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=device) decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) @@ -1211,56 +1213,68 @@ def generate( hidden_state = hidden_state.to(device) cell_state = cell_state.to(device) - all_tokens = [[] for _ in range(batch_size)] - token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None - token_durations_list = [[] for _ in range(batch_size)] if return_timestamps else None batch_indices = torch.arange(batch_size, device=device) time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) time_indices_current_labels = torch.zeros(batch_size, dtype=torch.long, device=device) active_mask = time_indices < valid_lengths + active_mask_prev = torch.zeros_like(active_mask) - max_symbols = self.config.max_symbols_per_step + zeros_symbols = torch.zeros(batch_size, dtype=torch.long, device=device) symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) + max_output_len = sequence_length * self.config.max_symbols_per_step + all_tokens_tensor = torch.full( + (batch_size, max_output_len), self.config.pad_token_id, dtype=torch.long, device=device + ) + token_counts = torch.zeros(batch_size, dtype=torch.long, device=device) + if return_timestamps: + 
all_frame_indices = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) + all_durations_tensor = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) + + # separately call encoder projection to avoid redundant computation inside loop + projected_encoder_output = self.joint.encoder_projector(encoder_hidden_states).to(device) while active_mask.any(): - active_mask_prev = active_mask.clone() + active_mask_prev.copy_(active_mask) safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) - encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) + projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) - token_logits, duration_logits = self.joint(encoder_frames, decoder_output) + token_logits, duration_logits = self.joint( + decoder_output, + projected_encoder_output=projected_encoder_frames, + ) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) tokens = token_logits.argmax(dim=-1) durations = duration_logits.argmax(dim=-1) - blank_mask = active_mask_prev & (tokens == self.config.pad_token_id) # Force blank duration >= 1 to guarantee forward progress + blank_mask = active_mask_prev & (tokens == self.config.pad_token_id) durations = durations.masked_fill(blank_mask & (durations == 0), 1) # Save pre-advance position for timestamp recording time_indices_current_labels.copy_(time_indices) # Advance time for all active elements - time_indices = time_indices + durations * active_mask + time_indices = time_indices + durations.masked_fill(~active_mask, 0) safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) active_mask = time_indices < valid_lengths advance_mask = active_mask & blank_mask # Inner loop: skip past consecutive blanks to find non-blank while advance_mask.any(): - # Update timestamp tracking to current position time_indices_current_labels = torch.where(advance_mask, time_indices, time_indices_current_labels) - encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) + projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) - token_logits, duration_logits = self.joint(encoder_frames, decoder_output) + token_logits, duration_logits = self.joint( + decoder_output, projected_encoder_output=projected_encoder_frames + ) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) more_tokens = token_logits.argmax(dim=-1) more_durations = duration_logits.argmax(dim=-1) - tokens = torch.where(advance_mask, more_tokens, tokens) durations = torch.where(advance_mask, more_durations, durations) @@ -1274,66 +1288,43 @@ def generate( # Record results for non-blank tokens found emit_mask = active_mask_prev & (tokens != self.config.pad_token_id) - for i in range(batch_size): - if emit_mask[i]: - all_tokens[i].append(tokens[i].item()) - if token_frame_indices is not None: - token_frame_indices[i].append(time_indices_current_labels[i].item()) - if token_durations_list is not None: - token_durations_list[i].append(durations[i].item()) - - if emit_mask.any(): - new_prev_tokens = tokens.unsqueeze(1) - new_decoder_output, new_hidden_state, new_cell_state = self.decoder( - new_prev_tokens, hidden_state, cell_state - ) - new_decoder_output = new_decoder_output.to(device) - new_hidden_state = new_hidden_state.to(device) - new_cell_state = new_cell_state.to(device) - - emit_mask_expanded = 
emit_mask.view(batch_size, 1, 1) - decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) + emit_indices = token_counts[emit_mask] + all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] + if return_timestamps: + all_frame_indices[emit_mask, emit_indices] = time_indices_current_labels[emit_mask] + all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] + token_counts += emit_mask.long() + + new_decoder_output, new_hidden_state, new_cell_state = self.decoder( + tokens.unsqueeze(1), hidden_state, cell_state + ) + new_decoder_output = new_decoder_output.to(device) + new_hidden_state = new_hidden_state.to(device) + new_cell_state = new_cell_state.to(device) - emit_mask_state = emit_mask.view(1, batch_size, 1) - hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) - cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) + emit_mask_state = emit_mask.view(1, batch_size, 1) + hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) + cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) # Track symbols emitted per time step; force advance when max_symbols reached time_changed = time_indices_current_labels != last_label_time - symbols_per_step = torch.where(time_changed, torch.zeros_like(symbols_per_step), symbols_per_step) + symbols_per_step = torch.where(time_changed, zeros_symbols, symbols_per_step) symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) last_label_time = torch.where(emit_mask, time_indices_current_labels, last_label_time) - force_advance = active_mask & (symbols_per_step >= max_symbols) + force_advance = active_mask & (symbols_per_step >= self.config.max_symbols_per_step) time_indices = time_indices + force_advance.long() symbols_per_step = symbols_per_step.masked_fill(force_advance, 0) - active_mask = time_indices < valid_lengths - # Pad sequences to same length - max_len = max((len(seq) for seq in all_tokens), default=0) - if max_len == 0: - max_len = 1 - - sequences = torch.full((batch_size, max_len), self.config.pad_token_id, dtype=torch.long, device=device) - for i in range(batch_size): - seq_len = len(all_tokens[i]) - if seq_len > 0: - sequences[i, :seq_len] = torch.tensor(all_tokens[i], dtype=torch.long, device=device) - - token_timestamps = None - token_durations = None + # Guard against edge case where no tokens were decoded (e.g. 
silent audio) + max_len = max(token_counts.max().item(), 1) + sequences = all_tokens_tensor[:, :max_len] + token_timestamps, token_durations = None, None if return_timestamps: - token_timestamps = torch.full((batch_size, max_len), 0.0, dtype=torch.long, device=device) - token_durations = torch.full((batch_size, max_len), 0, dtype=torch.long, device=device) - for i in range(batch_size): - num_tokens = len(token_frame_indices[i]) - if num_tokens > 0: - token_timestamps[i, :num_tokens] = torch.tensor( - token_frame_indices[i], dtype=torch.long, device=device - ) - token_durations[i, :num_tokens] = torch.tensor( - token_durations_list[i], dtype=torch.long, device=device - ) + token_timestamps = all_frame_indices[:, :max_len] + token_durations = all_durations_tensor[:, :max_len] if return_dict_in_generate: return ParakeetTDTGenerateOutput( @@ -1343,7 +1334,6 @@ def generate( attentions=outputs.attentions, hidden_states=outputs.hidden_states, ) - return sequences diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index c10afe6b7b64..294468ed640c 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -873,14 +873,16 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, - encoder_output: torch.Tensor, decoder_output: torch.Tensor, + encoder_output: torch.Tensor | None = None, + projected_encoder_output: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - encoder_projected = self.encoder_projector(encoder_output) - joint_output = self.activation(encoder_projected + decoder_output) - token_logits = self.token_head(joint_output) - duration_logits = self.duration_head(joint_output) - return token_logits, duration_logits + if projected_encoder_output is None: + if encoder_output is None: + raise ValueError("Either encoder_output or projected_encoder_output must be provided.") + projected_encoder_output = self.encoder_projector(encoder_output) + joint_output = self.activation(projected_encoder_output + decoder_output) + return self.token_head(joint_output), self.duration_head(joint_output) @auto_docstring( @@ -966,8 +968,8 @@ def forward( # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) # decoder: (batch, 1, U+1, decoder_hidden_size) token_logits, duration_logits = self.joint( - encoder_hidden_states_trimmed.unsqueeze(2), - decoder_output.unsqueeze(1), + decoder_output=decoder_output.unsqueeze(1), + encoder_output=encoder_hidden_states_trimmed.unsqueeze(2), ) loss = tdt_loss( @@ -1051,7 +1053,7 @@ def generate( else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) - # Initialize decoder + # Initialization hidden_state, cell_state = None, None prev_tokens = torch.full((batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=device) decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) @@ -1059,56 +1061,67 @@ def generate( hidden_state = hidden_state.to(device) cell_state = cell_state.to(device) - all_tokens = [[] for _ in range(batch_size)] - token_frame_indices = [[] for _ in range(batch_size)] if return_timestamps else None - token_durations_list = [[] for _ in range(batch_size)] if return_timestamps else None batch_indices = torch.arange(batch_size, device=device) time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) time_indices_current_labels = torch.zeros(batch_size, 
dtype=torch.long, device=device) active_mask = time_indices < valid_lengths + active_mask_prev = torch.zeros_like(active_mask) - max_symbols = self.config.max_symbols_per_step + zeros_symbols = torch.zeros(batch_size, dtype=torch.long, device=device) symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) + max_output_len = sequence_length * self.config.max_symbols_per_step + all_tokens_tensor = torch.full((batch_size, max_output_len), self.config.pad_token_id, dtype=torch.long, device=device) + token_counts = torch.zeros(batch_size, dtype=torch.long, device=device) + if return_timestamps: + all_frame_indices = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) + all_durations_tensor = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) + # separately call encoder projection to avoid redundant computation inside loop + projected_encoder_output = self.joint.encoder_projector(encoder_hidden_states).to(device) + while active_mask.any(): - active_mask_prev = active_mask.clone() + active_mask_prev.copy_(active_mask) safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) - encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) + projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) - token_logits, duration_logits = self.joint(encoder_frames, decoder_output) + token_logits, duration_logits = self.joint( + decoder_output, + projected_encoder_output=projected_encoder_frames, + ) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) tokens = token_logits.argmax(dim=-1) durations = duration_logits.argmax(dim=-1) - blank_mask = active_mask_prev & (tokens == self.config.pad_token_id) # Force blank duration >= 1 to guarantee forward progress + blank_mask = active_mask_prev & (tokens == self.config.pad_token_id) durations = durations.masked_fill(blank_mask & (durations == 0), 1) # Save pre-advance position for timestamp recording time_indices_current_labels.copy_(time_indices) # Advance time for all active elements - time_indices = time_indices + durations * active_mask + time_indices = time_indices + durations.masked_fill(~active_mask, 0) safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) active_mask = time_indices < valid_lengths advance_mask = active_mask & blank_mask # Inner loop: skip past consecutive blanks to find non-blank while advance_mask.any(): - # Update timestamp tracking to current position time_indices_current_labels = torch.where(advance_mask, time_indices, time_indices_current_labels) - encoder_frames = encoder_hidden_states[batch_indices, safe_time_indices].unsqueeze(1) + projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) - token_logits, duration_logits = self.joint(encoder_frames, decoder_output) + token_logits, duration_logits = self.joint( + decoder_output, + projected_encoder_output=projected_encoder_frames + ) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) more_tokens = token_logits.argmax(dim=-1) more_durations = duration_logits.argmax(dim=-1) - tokens = torch.where(advance_mask, more_tokens, tokens) durations = torch.where(advance_mask, more_durations, durations) @@ -1122,66 +1135,43 @@ def generate( # Record results for non-blank tokens found emit_mask = 
active_mask_prev & (tokens != self.config.pad_token_id) - for i in range(batch_size): - if emit_mask[i]: - all_tokens[i].append(tokens[i].item()) - if token_frame_indices is not None: - token_frame_indices[i].append(time_indices_current_labels[i].item()) - if token_durations_list is not None: - token_durations_list[i].append(durations[i].item()) - - if emit_mask.any(): - new_prev_tokens = tokens.unsqueeze(1) - new_decoder_output, new_hidden_state, new_cell_state = self.decoder( - new_prev_tokens, hidden_state, cell_state - ) - new_decoder_output = new_decoder_output.to(device) - new_hidden_state = new_hidden_state.to(device) - new_cell_state = new_cell_state.to(device) - - emit_mask_expanded = emit_mask.view(batch_size, 1, 1) - decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) + emit_indices = token_counts[emit_mask] + all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] + if return_timestamps: + all_frame_indices[emit_mask, emit_indices] = time_indices_current_labels[emit_mask] + all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] + token_counts += emit_mask.long() + + new_decoder_output, new_hidden_state, new_cell_state = self.decoder( + tokens.unsqueeze(1), hidden_state, cell_state + ) + new_decoder_output = new_decoder_output.to(device) + new_hidden_state = new_hidden_state.to(device) + new_cell_state = new_cell_state.to(device) - emit_mask_state = emit_mask.view(1, batch_size, 1) - hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) - cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) + emit_mask_state = emit_mask.view(1, batch_size, 1) + hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) + cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) # Track symbols emitted per time step; force advance when max_symbols reached time_changed = time_indices_current_labels != last_label_time - symbols_per_step = torch.where(time_changed, torch.zeros_like(symbols_per_step), symbols_per_step) + symbols_per_step = torch.where(time_changed, zeros_symbols, symbols_per_step) symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) last_label_time = torch.where(emit_mask, time_indices_current_labels, last_label_time) - force_advance = active_mask & (symbols_per_step >= max_symbols) + force_advance = active_mask & (symbols_per_step >= self.config.max_symbols_per_step) time_indices = time_indices + force_advance.long() symbols_per_step = symbols_per_step.masked_fill(force_advance, 0) - active_mask = time_indices < valid_lengths - # Pad sequences to same length - max_len = max((len(seq) for seq in all_tokens), default=0) - if max_len == 0: - max_len = 1 - - sequences = torch.full((batch_size, max_len), self.config.pad_token_id, dtype=torch.long, device=device) - for i in range(batch_size): - seq_len = len(all_tokens[i]) - if seq_len > 0: - sequences[i, :seq_len] = torch.tensor(all_tokens[i], dtype=torch.long, device=device) - - token_timestamps = None - token_durations = None + # Guard against edge case where no tokens were decoded (e.g. 
silent audio) + max_len = max(token_counts.max().item(), 1) + sequences = all_tokens_tensor[:, :max_len] + token_timestamps, token_durations = None, None if return_timestamps: - token_timestamps = torch.full((batch_size, max_len), 0.0, dtype=torch.long, device=device) - token_durations = torch.full((batch_size, max_len), 0, dtype=torch.long, device=device) - for i in range(batch_size): - num_tokens = len(token_frame_indices[i]) - if num_tokens > 0: - token_timestamps[i, :num_tokens] = torch.tensor( - token_frame_indices[i], dtype=torch.long, device=device - ) - token_durations[i, :num_tokens] = torch.tensor( - token_durations_list[i], dtype=torch.long, device=device - ) + token_timestamps = all_frame_indices[:, :max_len] + token_durations = all_durations_tensor[:, :max_len] if return_dict_in_generate: return ParakeetTDTGenerateOutput( @@ -1191,7 +1181,6 @@ def generate( attentions=outputs.attentions, hidden_states=outputs.hidden_states, ) - return sequences From fab050a3cfe7d4f0c7f90464db023f99f9baebe4 Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 6 Mar 2026 22:58:11 +0100 Subject: [PATCH 23/67] TDT has separate pad token and blank token. --- .../models/parakeet/configuration_parakeet.py | 14 ++++-- .../models/parakeet/convert_nemo_to_hf.py | 47 +++++++++++-------- .../models/parakeet/modeling_parakeet.py | 39 +++++---------- .../models/parakeet/modular_parakeet.py | 45 +++++++----------- .../models/parakeet/processing_parakeet.py | 9 ++-- .../parakeet/expected_results_batch_tdt.json | 2 +- .../expected_results_batch_tdt_timestamp.json | 2 +- .../models/parakeet/test_modeling_parakeet.py | 7 ++- 8 files changed, 77 insertions(+), 88 deletions(-) diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index cbe9073ee963..ea3cc1f9afe8 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -238,7 +238,7 @@ class ParakeetTDTConfig(PreTrainedConfig): documentation from [`PreTrainedConfig`] for more information. Args: - vocab_size (`int`, *optional*, defaults to 8192): + vocab_size (`int`, *optional*, defaults to 8193): Vocabulary size of the model. decoder_hidden_size (`int`, *optional*, defaults to 640): Hidden size of the LSTM prediction network and joint network. @@ -255,8 +255,10 @@ class ParakeetTDTConfig(PreTrainedConfig): Maximum number of symbols to emit per encoder time step during greedy decoding. encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): The config object or dictionary of the encoder. - pad_token_id (`int`, *optional*, defaults to 8192): - Padding token id. Also used as blank token id for TDT decoding. + pad_token_id (`int`, *optional*, defaults to 2): + Padding token id. + blank_token_id (`int`, *optional*, defaults to 8192): + Blank token id. Different from `pad_token_id` for TDT. 
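A hypothetical illustration of the distinction (made-up label ids, not part of the configuration docstring): with the defaults above, `pad_token_id=2` only pads label tensors, while `blank_token_id=8192` primes the prediction network and is the symbol the joint network emits to advance time without producing a token.

    labels = [[17, 23, 5], [9, 4, 2]]    # second sequence padded with pad_token_id = 2
    decoder_input = [[8192, 17, 23, 5],  # blank_token_id is prepended before the LSTM prediction network
                     [8192, 9, 4, 2]]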
Example: ```python @@ -278,14 +280,15 @@ class ParakeetTDTConfig(PreTrainedConfig): def __init__( self, - vocab_size=8192, + vocab_size=8193, decoder_hidden_size=640, num_decoder_layers=1, durations=[0, 1, 2, 3, 4], hidden_act="relu", max_symbols_per_step=10, encoder_config: dict | ParakeetEncoderConfig = None, - pad_token_id=8192, + pad_token_id=2, + blank_token_id=8192, **kwargs, ): self.vocab_size = vocab_size @@ -303,6 +306,7 @@ def __init__( self.encoder_config = encoder_config self.initializer_range = self.encoder_config.initializer_range + self.blank_token_id = blank_token_id self.pad_token_id = pad_token_id super().__init__(**kwargs) diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index daed5c11c598..4fb17653e59c 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -141,15 +141,25 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str return model_files -def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id=None): +def write_processor(nemo_config: dict, model_files, output_dir, model_type, push_to_repo_id=None): tokenizer_converted = ParakeetConverter(model_files["tokenizer_model_file"]).converted() tokenizer_converted_fast = ParakeetTokenizer( tokenizer_object=tokenizer_converted, clean_up_tokenization_spaces=False, ) - tokenizer_converted_fast.add_tokens( - [AddedToken("", normalized=False, special=True), AddedToken("", normalized=False, special=True)] - ) + + if tokenizer_converted_fast.convert_tokens_to_ids("") is None: + # Normally CTC and TDT already have + tokenizer_converted_fast.add_tokens([AddedToken("", normalized=False, special=True)]) + print(f"Added token at ID: {tokenizer_converted_fast.convert_tokens_to_ids('')}") + if tokenizer_converted_fast.convert_tokens_to_ids("") is None: + # Normally CTC doesn't have while TDT has at token id = 2 + tokenizer_converted_fast.add_tokens([AddedToken("", normalized=False, special=True)]) + print(f"Added token at ID: {tokenizer_converted_fast.convert_tokens_to_ids('')}") + if model_type == "tdt": + # TDT needs a separate blank token + tokenizer_converted_fast.add_tokens([AddedToken("", normalized=False, special=True)]) + print(f"Added token at ID: {tokenizer_converted_fast.convert_tokens_to_ids('')}") tokenizer_converted_fast.add_special_tokens( { "pad_token": AddedToken("", normalized=False, special=True), @@ -186,7 +196,6 @@ def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id= raise ValueError(f"Key {key} not found in feature_extractor_keys_mapping") feature_extractor = ParakeetFeatureExtractor(**converted_feature_extractor_config) - processor = ParakeetProcessor( feature_extractor=feature_extractor, tokenizer=tokenizer_converted_fast, @@ -290,19 +299,19 @@ def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_re def convert_tdt_config(nemo_config, encoder_config): """Convert NeMo TDT config to HF TDT config.""" - decoder_config = nemo_config.get("decoder", {}) - decoding_config = nemo_config.get("decoding", {}) - - labels = nemo_config.get("labels", []) - vocab_size = len(labels) if labels else decoder_config.get("vocab_size", 1024) + decoder_config = nemo_config["decoder"] + decoding_config = nemo_config["decoding"] + labels = nemo_config["labels"] + blank_token_id = len(labels) + vocab_size = len(labels) + 1 # +1 for blank token, which is added to tokenizer prednet = 
decoder_config.get("prednet", {}) decoder_hidden_size = prednet.get("pred_hidden", 640) num_decoder_layers = prednet.get("pred_rnn_layers", 2) - durations = decoding_config.get("durations", [0, 1, 2, 3, 4]) print( - f"TDT config: vocab_size={vocab_size}, decoder_hidden={decoder_hidden_size}, " + f"TDT config: vocab_size={vocab_size} (including blank token), " + f"decoder_hidden={decoder_hidden_size}, " f"decoder_layers={num_decoder_layers}, durations={durations}, " ) @@ -314,7 +323,8 @@ def convert_tdt_config(nemo_config, encoder_config): hidden_act="relu", max_symbols_per_step=10, encoder_config=encoder_config.to_dict(), - pad_token_id=vocab_size, + pad_token_id=labels.index(""), + blank_token_id=blank_token_id, # blank token is different from pad token for TDT ) @@ -330,18 +340,17 @@ def load_and_convert_tdt_state_dict(model_files, vocab_size): print(f"Skipping preprocessing weight: {key}") continue - # Handle combined output head split if key == "joint.joint_net.2.weight": - token_weight = value[: vocab_size + 1, :] - duration_weight = value[vocab_size + 1 :, :] + token_weight = value[:vocab_size, :] + duration_weight = value[vocab_size:, :] converted_state_dict["joint.token_head.weight"] = token_weight converted_state_dict["joint.duration_head.weight"] = duration_weight print(f"Split combined weight: token_head {token_weight.shape}, duration_head {duration_weight.shape}") continue if key == "joint.joint_net.2.bias": - token_bias = value[: vocab_size + 1] - duration_bias = value[vocab_size + 1 :] + token_bias = value[:vocab_size] + duration_bias = value[vocab_size:] converted_state_dict["joint.token_head.bias"] = token_bias converted_state_dict["joint.duration_head.bias"] = duration_bias print(f"Split combined bias: token_head {token_bias.shape}, duration_head {duration_bias.shape}") @@ -416,7 +425,7 @@ def main( model_files = extract_nemo_archive(filepath, os.path.dirname(filepath)) nemo_config = yaml.load(open(model_files["model_config"], "r"), Loader=yaml.FullLoader) - write_processor(nemo_config, model_files, output_dir, push_to_repo_id) + write_processor(nemo_config, model_files, output_dir, model_type, push_to_repo_id) write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index ebb3417f51b0..fc4926caf39a 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -756,8 +756,7 @@ def forward( ) input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) - # assuming that padded tokens are filled with -100 - # when not being attended to + # assuming that padded tokens are filled with pad_token_id when not being attended to labels_mask = labels != self.config.pad_token_id target_lengths = labels_mask.sum(-1) flattened_targets = labels.masked_select(labels_mask) @@ -844,7 +843,7 @@ class ParakeetTDTDecoder(nn.Module): def __init__(self, config: ParakeetTDTConfig): super().__init__() self.config = config - self.embedding = nn.Embedding(config.vocab_size + 1, config.decoder_hidden_size) + self.embedding = nn.Embedding(config.vocab_size, config.decoder_hidden_size) self.lstm = nn.LSTM( input_size=config.decoder_hidden_size, hidden_size=config.decoder_hidden_size, @@ -885,7 +884,7 @@ def __init__(self, config: ParakeetTDTConfig): super().__init__() self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) 
self.activation = ACT2FN[config.hidden_act] - self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size + 1) + self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size) self.duration_head = nn.Linear(config.decoder_hidden_size, len(config.durations)) def forward( @@ -1098,30 +1097,18 @@ def forward( ) encoder_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) - labels_mask = labels != -100 - target_lengths = labels_mask.sum(-1) - - labels = labels.clone() - labels[labels == -100] = self.config.pad_token_id + # Prepare labels for TDT loss + target_lengths = (labels != self.config.pad_token_id).sum(-1) - # Prepare decoder input: prepend blank token to labels + # Get joint decoder outputs blank_tokens = torch.full( - (labels.shape[0], 1), self.config.pad_token_id, dtype=labels.dtype, device=labels.device + (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device ) decoder_input = torch.cat([blank_tokens, labels], dim=1) - - # Run decoder on full label sequence: (batch, U+1, decoder_hidden_size) decoder_output, _, _ = self.decoder(decoder_input) - - max_encoder_length = encoder_lengths.max().item() - encoder_hidden_states_trimmed = encoder_hidden_states[:, :max_encoder_length] - - # Compute joint output for all (T, U+1) pairs via broadcasting - # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) - # decoder: (batch, 1, U+1, decoder_hidden_size) token_logits, duration_logits = self.joint( decoder_output=decoder_output.unsqueeze(1), - encoder_output=encoder_hidden_states_trimmed.unsqueeze(2), + encoder_output=encoder_hidden_states.unsqueeze(2), ) loss = tdt_loss( @@ -1130,7 +1117,7 @@ def forward( targets=labels.to(token_logits.device).int(), logit_lengths=encoder_lengths.to(token_logits.device).int(), target_lengths=target_lengths.to(token_logits.device).int(), - blank=self.config.pad_token_id, + blank=self.config.blank_token_id, durations=self.config.durations, reduction="mean", ) @@ -1207,7 +1194,7 @@ def generate( # Initialization hidden_state, cell_state = None, None - prev_tokens = torch.full((batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=device) + prev_tokens = torch.full((batch_size, 1), self.config.blank_token_id, dtype=torch.long, device=device) decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) decoder_output = decoder_output.to(device) hidden_state = hidden_state.to(device) @@ -1250,7 +1237,7 @@ def generate( durations = duration_logits.argmax(dim=-1) # Force blank duration >= 1 to guarantee forward progress - blank_mask = active_mask_prev & (tokens == self.config.pad_token_id) + blank_mask = active_mask_prev & (tokens == self.config.blank_token_id) durations = durations.masked_fill(blank_mask & (durations == 0), 1) # Save pre-advance position for timestamp recording @@ -1278,7 +1265,7 @@ def generate( tokens = torch.where(advance_mask, more_tokens, tokens) durations = torch.where(advance_mask, more_durations, durations) - blank_mask = tokens == self.config.pad_token_id + blank_mask = tokens == self.config.blank_token_id durations = durations.masked_fill(blank_mask & (durations == 0), 1) time_indices = torch.where(advance_mask, time_indices + durations, time_indices) @@ -1287,7 +1274,7 @@ def generate( advance_mask = active_mask & blank_mask # Record results for non-blank tokens found - emit_mask = active_mask_prev & (tokens != self.config.pad_token_id) + emit_mask = active_mask_prev & (tokens != 
self.config.blank_token_id) emit_indices = token_counts[emit_mask] all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] if return_timestamps: diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 294468ed640c..6e075dd4393d 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -604,8 +604,7 @@ def forward( ) input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) - # assuming that padded tokens are filled with -100 - # when not being attended to + # assuming that padded tokens are filled with pad_token_id when not being attended to labels_mask = labels != self.config.pad_token_id target_lengths = labels_mask.sum(-1) flattened_targets = labels.masked_select(labels_mask) @@ -692,7 +691,7 @@ class ParakeetTDTDecoder(nn.Module): def __init__(self, config: ParakeetTDTConfig): super().__init__() self.config = config - self.embedding = nn.Embedding(config.vocab_size + 1, config.decoder_hidden_size) + self.embedding = nn.Embedding(config.vocab_size, config.decoder_hidden_size) self.lstm = nn.LSTM( input_size=config.decoder_hidden_size, hidden_size=config.decoder_hidden_size, @@ -868,7 +867,7 @@ def __init__(self, config: ParakeetTDTConfig): super().__init__() self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.activation = ACT2FN[config.hidden_act] - self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size + 1) + self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size) self.duration_head = nn.Linear(config.decoder_hidden_size, len(config.durations)) def forward( @@ -946,30 +945,18 @@ def forward( ) encoder_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) - labels_mask = labels != -100 - target_lengths = labels_mask.sum(-1) - - labels = labels.clone() - labels[labels == -100] = self.config.pad_token_id + # Prepare labels for TDT loss + target_lengths = (labels != self.config.pad_token_id).sum(-1) - # Prepare decoder input: prepend blank token to labels + # Get joint decoder outputs blank_tokens = torch.full( - (labels.shape[0], 1), self.config.pad_token_id, dtype=labels.dtype, device=labels.device + (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device ) decoder_input = torch.cat([blank_tokens, labels], dim=1) - - # Run decoder on full label sequence: (batch, U+1, decoder_hidden_size) decoder_output, _, _ = self.decoder(decoder_input) - - max_encoder_length = encoder_lengths.max().item() - encoder_hidden_states_trimmed = encoder_hidden_states[:, :max_encoder_length] - - # Compute joint output for all (T, U+1) pairs via broadcasting - # encoder: (batch, T, 1, encoder_hidden) -> projected to (batch, T, 1, decoder_hidden_size) - # decoder: (batch, 1, U+1, decoder_hidden_size) token_logits, duration_logits = self.joint( decoder_output=decoder_output.unsqueeze(1), - encoder_output=encoder_hidden_states_trimmed.unsqueeze(2), + encoder_output=encoder_hidden_states.unsqueeze(2), ) loss = tdt_loss( @@ -978,7 +965,7 @@ def forward( targets=labels.to(token_logits.device).int(), logit_lengths=encoder_lengths.to(token_logits.device).int(), target_lengths=target_lengths.to(token_logits.device).int(), - blank=self.config.pad_token_id, + blank=self.config.blank_token_id, durations=self.config.durations, reduction="mean", ) @@ -1055,7 +1042,7 @@ def generate( # Initialization hidden_state, cell_state 
= None, None - prev_tokens = torch.full((batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=device) + prev_tokens = torch.full((batch_size, 1), self.config.blank_token_id, dtype=torch.long, device=device) decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) decoder_output = decoder_output.to(device) hidden_state = hidden_state.to(device) @@ -1079,14 +1066,14 @@ def generate( # separately call encoder projection to avoid redundant computation inside loop projected_encoder_output = self.joint.encoder_projector(encoder_hidden_states).to(device) - + while active_mask.any(): active_mask_prev.copy_(active_mask) safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) - projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) + projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) token_logits, duration_logits = self.joint( - decoder_output, + decoder_output, projected_encoder_output=projected_encoder_frames, ) token_logits = token_logits.squeeze(1).to(device) @@ -1096,7 +1083,7 @@ def generate( durations = duration_logits.argmax(dim=-1) # Force blank duration >= 1 to guarantee forward progress - blank_mask = active_mask_prev & (tokens == self.config.pad_token_id) + blank_mask = active_mask_prev & (tokens == self.config.blank_token_id) durations = durations.masked_fill(blank_mask & (durations == 0), 1) # Save pre-advance position for timestamp recording @@ -1125,7 +1112,7 @@ def generate( tokens = torch.where(advance_mask, more_tokens, tokens) durations = torch.where(advance_mask, more_durations, durations) - blank_mask = tokens == self.config.pad_token_id + blank_mask = tokens == self.config.blank_token_id durations = durations.masked_fill(blank_mask & (durations == 0), 1) time_indices = torch.where(advance_mask, time_indices + durations, time_indices) @@ -1134,7 +1121,7 @@ def generate( advance_mask = active_mask & blank_mask # Record results for non-blank tokens found - emit_mask = active_mask_prev & (tokens != self.config.pad_token_id) + emit_mask = active_mask_prev & (tokens != self.config.blank_token_id) emit_indices = token_counts[emit_mask] all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] if return_timestamps: diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py index dca9e75b0769..0b662f56af34 100644 --- a/src/transformers/models/parakeet/processing_parakeet.py +++ b/src/transformers/models/parakeet/processing_parakeet.py @@ -83,9 +83,7 @@ def __call__( if text is None: return inputs else: - labels = encodings["input_ids"] - labels[labels == self.tokenizer.pad_token_id] = -100 - inputs["labels"] = labels + inputs["labels"] = encodings["input_ids"] return inputs @property @@ -115,9 +113,10 @@ def decode(self, *args, token_timestamps=None, token_durations=None, **kwargs): ) proc_timestamps = [] for batch_ids, timestamps, durations in zip(token_ids, token_timestamps, token_durations): - # Original NeMo: https://github.com/NVIDIA-NeMo/NeMo/blob/1692a8fb97e1aadc883cfadd2a57c4e8a1b793aa/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L993 + # See `compute_rnnt_timestamps` in NeMo: https://github.com/NVIDIA-NeMo/NeMo/blob/1692a8fb97e1aadc883cfadd2a57c4e8a1b793aa/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L993 + # Filter padding (unwritten positions in `all_tokens_tensor` in `generate`) non_blank_indices = [ - i for i, token_id in 
enumerate(batch_ids) if token_id != self.tokenizer.vocab_size + i for i, token_id in enumerate(batch_ids) if token_id != self.tokenizer.pad_token_id ] non_blank_ids = [batch_ids[i] for i in non_blank_indices] decoded_tokens = [self.tokenizer.decode([token_id]) for token_id in non_blank_ids] diff --git a/tests/fixtures/parakeet/expected_results_batch_tdt.json b/tests/fixtures/parakeet/expected_results_batch_tdt.json index c3f46c17321d..54f5198fd834 100644 --- a/tests/fixtures/parakeet/expected_results_batch_tdt.json +++ b/tests/fixtures/parakeet/expected_results_batch_tdt.json @@ -1 +1 @@ -{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.", "He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and can discover in it but little of Rocky Ithaca.", "Linnell's pictures are a sort of up guards an atom paintings, and Mason's exquisite idols are as national as a jingo poem. mister Burkett Foster's landscapes smile at one much in the same way that mister Carker used to flash his teeth. And mister John Collier gives his sitter a cheerful slap on the back, before he says, like a shampooer in a Turkish bath Next man"], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 
8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 2281, 1969, 507, 3362, 7886, 769, 328, 1299, 1239, 7319, 6447, 901, 1413, 1333, 3720, 289, 7931, 7870, 6182, 508, 5600, 4190, 377, 799, 441, 1111, 7877, 575, 2059, 5371, 3230, 334, 869, 2681, 7052, 592, 3341, 725, 7893, 2336, 7882, 566, 7865, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [439, 1538, 530, 7931, 7870, 5970, 7868, 4147, 1714, 279, 275, 621, 592, 1840, 1980, 961, 7870, 411, 407, 313, 849, 942, 2399, 7877, 575, 2945, 289, 7931, 7870, 743, 341, 290, 582, 312, 7874, 324, 7870, 1714, 618, 285, 5858, 618, 279, 300, 381, 7869, 408, 311, 7883, 282, 3459, 426, 344, 7876, 861, 515, 308, 441, 7931, 7870, 3650, 7870, 7880, 474, 283, 1530, 787, 407, 2678, 4457, 334, 506, 766, 7864, 7195, 1050, 282, 3459, 3551, 1684, 1441, 326, 366, 309, 1028, 7882, 2745, 478, 291, 7882, 7883, 1976, 282, 3459, 3483, 4003, 332, 277, 317, 416, 283, 2745, 3488, 441, 279, 774, 277, 5346, 275, 4226, 431, 506, 6507, 7877, 555, 786, 7864, 813, 498, 676, 7877, 2656, 279, 275, 3930, 726, 7869, 277, 334, 279, 5183, 7876, 2739, 302, 7152, 1030, 3127, 698]]} \ No newline at end of file +{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.", "He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and can discover in it but little of Rocky Ithaca.", "Linnell's pictures are a sort of up guards an atom paintings, and Mason's exquisite idols are as national as a jingo poem. mister Burkett Foster's landscapes smile at one much in the same way that mister Carker used to flash his teeth. 
And mister John Collier gives his sitter a cheerful slap on the back, before he says, like a shampooer in a Turkish bath Next man"], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1876, 2281, 1969, 507, 3362, 7886, 769, 328, 1299, 1239, 7319, 6447, 901, 1413, 1333, 3720, 289, 7931, 7870, 6182, 508, 5600, 4190, 377, 799, 441, 1111, 7877, 575, 2059, 5371, 3230, 334, 869, 2681, 7052, 592, 3341, 725, 7893, 2336, 7882, 566, 7865, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [439, 1538, 530, 7931, 7870, 5970, 7868, 4147, 1714, 279, 275, 621, 592, 1840, 1980, 961, 7870, 411, 407, 313, 849, 942, 2399, 7877, 575, 2945, 289, 7931, 7870, 743, 341, 290, 582, 312, 7874, 324, 7870, 1714, 618, 285, 5858, 618, 279, 300, 381, 7869, 408, 311, 7883, 282, 3459, 426, 344, 7876, 861, 515, 308, 441, 7931, 7870, 3650, 7870, 7880, 474, 283, 1530, 787, 407, 2678, 4457, 334, 506, 766, 7864, 7195, 1050, 282, 3459, 3551, 1684, 1441, 326, 366, 309, 1028, 7882, 2745, 478, 291, 7882, 7883, 1976, 282, 3459, 3483, 4003, 332, 277, 317, 416, 283, 2745, 3488, 441, 279, 774, 277, 5346, 275, 4226, 431, 506, 6507, 7877, 555, 786, 7864, 813, 498, 676, 7877, 2656, 279, 275, 3930, 726, 7869, 277, 334, 279, 5183, 7876, 2739, 302, 7152, 1030, 3127, 698]]} \ No newline at end of file diff --git a/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json index e27e5f8304e5..0a9b2180b4cb 100644 --- a/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json +++ b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json @@ -1 +1 @@ -{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this 
festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind."], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883]], "start_timestamps": [[0.24, 0.48, 0.64, 0.88, 1.12, 1.36, 1.44, 1.6, 1.76, 2.0, 2.16, 2.24, 2.4, 2.48, 2.56, 2.72, 2.88, 3.04, 3.12, 3.2800000000000002, 3.44, 3.6, 3.7600000000000002, 3.92, 4.08, 4.24, 4.4, 4.48, 4.72, 4.96, 5.36, 5.6000000000000005], [0.32, 0.64, 0.88, 1.04, 1.2, 1.44, 1.68, 1.84, 1.92, 2.0, 2.16, 2.4, 2.56, 2.72, 2.96, 3.12, 3.36, 3.6, 3.92, 4.16, 4.32], [0.32, 0.64, 0.72, 0.96, 1.12, 1.36, 1.6, 1.84, 2.08, 2.24, 2.48, 2.64, 2.8000000000000003, 2.88, 3.04, 3.2, 3.44, 3.68, 3.84, 4.08, 4.4, 4.5600000000000005, 4.72, 4.96, 5.12, 5.36, 5.5200000000000005, 5.68, 5.92, 6.16, 6.24, 6.4, 6.5600000000000005, 6.72, 6.96, 7.28, 7.6000000000000005, 7.92, 8.16, 8.32, 8.48, 8.72, 8.88, 8.96, 9.120000000000001, 9.28, 9.44, 9.68, 9.76, 9.92, 10.16, 10.24, 10.4, 10.64, 10.88, 10.96, 11.200000000000001, 11.36, 11.52, 11.84, 12.16]], "end_timestamps": [[0.48, 0.64, 0.88, 1.12, 1.36, 1.44, 1.6, 1.76, 1.92, 2.16, 2.24, 2.4, 2.48, 2.56, 2.64, 2.88, 3.04, 3.12, 3.12, 3.44, 3.6, 3.7600000000000002, 3.92, 4.08, 4.24, 4.4, 4.48, 4.72, 4.96, 5.12, 5.6000000000000005, 5.6000000000000005], [0.64, 0.88, 1.04, 1.2, 1.44, 1.68, 1.84, 1.84, 2.0, 2.16, 2.4, 2.56, 2.72, 2.96, 3.12, 3.36, 3.6, 3.92, 4.16, 4.32, 4.32], [0.64, 0.72, 0.96, 1.12, 1.36, 1.6, 1.84, 2.08, 2.24, 2.48, 2.64, 2.8000000000000003, 2.88, 3.04, 3.2, 3.44, 3.68, 3.84, 3.84, 4.4, 4.5600000000000005, 4.72, 4.96, 5.12, 5.36, 5.5200000000000005, 5.68, 5.92, 6.16, 6.24, 6.4, 6.5600000000000005, 6.72, 6.96, 7.28, 7.28, 7.92, 8.16, 8.24, 8.48, 8.72, 8.88, 8.96, 9.120000000000001, 9.200000000000001, 9.44, 9.68, 9.76, 9.92, 10.16, 10.24, 10.4, 10.64, 10.88, 10.96, 11.200000000000001, 11.36, 11.52, 11.84, 12.16, 12.16]], "token_durations": [[3, 2, 3, 3, 3, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 3, 3, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 3, 2, 2, 3, 3, 2, 1, 1, 2, 3, 2, 2, 3, 2, 3, 3, 4, 3, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 1, 3, 2, 3, 3, 3, 3, 2, 3, 2, 2, 1, 2, 2, 3, 3, 2, 3, 4, 2, 2, 3, 2, 3, 2, 2, 3, 3, 1, 2, 2, 2, 3, 4, 4, 4, 3, 1, 2, 3, 2, 1, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 3, 1, 3, 2, 2, 4, 4, 2]]} \ No newline at end of file +{"transcriptions": ["mister Quilter is the 
apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind."], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883]], "start_timestamps": [[0.24, 0.48, 0.64, 0.88, 1.12, 1.36, 1.44, 1.6, 1.76, 2.0, 2.16, 2.24, 2.4, 2.48, 2.56, 2.72, 2.88, 3.04, 3.12, 3.2800000000000002, 3.44, 3.6, 3.7600000000000002, 3.92, 4.08, 4.24, 4.4, 4.48, 4.72, 4.96, 5.36, 5.6000000000000005], [0.32, 0.64, 0.88, 1.04, 1.2, 1.44, 1.68, 1.84, 1.92, 2.0, 2.16, 2.4, 2.56, 2.72, 2.96, 3.12, 3.36, 3.6, 3.92, 4.16, 4.32], [0.32, 0.64, 0.72, 0.96, 1.12, 1.36, 1.6, 1.84, 2.08, 2.24, 2.48, 2.64, 2.8000000000000003, 2.88, 3.04, 3.2, 3.44, 3.68, 3.84, 4.08, 4.4, 4.5600000000000005, 4.72, 4.96, 5.12, 5.36, 5.5200000000000005, 5.68, 5.92, 6.16, 6.24, 6.4, 6.5600000000000005, 6.72, 6.96, 7.28, 7.6000000000000005, 7.92, 8.16, 8.32, 8.48, 8.72, 8.88, 8.96, 9.120000000000001, 9.28, 9.44, 9.68, 9.76, 9.92, 10.16, 10.24, 10.4, 10.64, 10.88, 10.96, 11.200000000000001, 11.36, 11.52, 11.84, 12.16]], "end_timestamps": [[0.48, 0.64, 0.88, 1.12, 1.36, 1.44, 1.6, 1.76, 1.92, 2.16, 2.24, 2.4, 2.48, 2.56, 2.64, 2.88, 3.04, 3.12, 3.12, 3.44, 3.6, 3.7600000000000002, 3.92, 4.08, 4.24, 4.4, 4.48, 4.72, 4.96, 5.12, 5.6000000000000005, 5.6000000000000005], [0.64, 0.88, 1.04, 1.2, 1.44, 1.68, 1.84, 1.84, 2.0, 2.16, 2.4, 2.56, 2.72, 2.96, 3.12, 3.36, 3.6, 3.92, 4.16, 4.32, 4.32], [0.64, 0.72, 0.96, 1.12, 1.36, 1.6, 1.84, 2.08, 2.24, 2.48, 2.64, 2.8000000000000003, 2.88, 3.04, 3.2, 3.44, 3.68, 3.84, 3.84, 4.4, 4.5600000000000005, 4.72, 4.96, 5.12, 5.36, 5.5200000000000005, 5.68, 5.92, 6.16, 6.24, 6.4, 6.5600000000000005, 6.72, 6.96, 7.28, 7.28, 7.92, 8.16, 8.24, 8.48, 8.72, 8.88, 8.96, 9.120000000000001, 9.200000000000001, 9.44, 9.68, 9.76, 9.92, 10.16, 10.24, 10.4, 10.64, 10.88, 10.96, 11.200000000000001, 11.36, 11.52, 11.84, 12.16, 12.16]], "token_durations": [[3, 2, 3, 3, 3, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 3, 3, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 3, 2, 2, 3, 3, 2, 1, 1, 2, 3, 2, 2, 3, 2, 3, 3, 4, 3, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 1, 3, 2, 3, 3, 3, 3, 2, 3, 2, 2, 1, 2, 2, 3, 3, 2, 3, 4, 2, 2, 3, 2, 3, 2, 2, 3, 3, 1, 2, 2, 2, 3, 4, 4, 4, 3, 1, 2, 3, 2, 1, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 3, 1, 3, 2, 2, 4, 4, 2]]} \ No newline at end of file diff --git a/tests/models/parakeet/test_modeling_parakeet.py 
b/tests/models/parakeet/test_modeling_parakeet.py index 3591edd8b0d4..b4f6e69190f3 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -416,13 +416,14 @@ def __init__( parent, encoder_kwargs=None, is_training=True, - vocab_size=128, + vocab_size=129, decoder_hidden_size=64, num_decoder_layers=1, durations=None, hidden_act="relu", max_symbols_per_step=10, - pad_token_id=128, + pad_token_id=2, + blank_token_id=128, ): if encoder_kwargs is None: encoder_kwargs = {} @@ -445,6 +446,7 @@ def __init__( self.hidden_act = hidden_act self.max_symbols_per_step = max_symbols_per_step self.pad_token_id = pad_token_id + self.blank_token_id = blank_token_id def prepare_config_and_inputs(self): _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs() @@ -461,6 +463,7 @@ def get_config(self): max_symbols_per_step=self.max_symbols_per_step, encoder_config=self.encoder_model_tester.get_config().to_dict(), pad_token_id=self.pad_token_id, + blank_token_id=self.blank_token_id, ) def create_and_check_model(self, config, input_features, attention_mask): From 86d980c11b3a6fa42f2754991fdbedeb4f92ea0c Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 6 Mar 2026 23:15:53 +0100 Subject: [PATCH 24/67] Regenerate lasr. --- .../models/lasr/configuration_lasr.py | 4 +-- src/transformers/models/lasr/modeling_lasr.py | 6 ++--- src/transformers/models/lasr/modular_lasr.py | 25 +++++++++++++++++-- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/lasr/configuration_lasr.py b/src/transformers/models/lasr/configuration_lasr.py index 4d82b85044a2..3cb525e20df6 100644 --- a/src/transformers/models/lasr/configuration_lasr.py +++ b/src/transformers/models/lasr/configuration_lasr.py @@ -159,9 +159,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range - super().__init__( - **kwargs, - ) + super().__init__(**kwargs) class LasrCTCConfig(PreTrainedConfig): diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index 24fa4872a2a8..199686ee3d7d 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -36,6 +36,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoModel from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig @@ -591,7 +592,7 @@ class LasrForCTC(LasrPreTrainedModel): def __init__(self, config: LasrCTCConfig): super().__init__(config) - self.encoder = LasrEncoder(config.encoder_config) + self.encoder = AutoModel.from_config(config.encoder_config) # Conv rather than linear to be consistent with NeMO decoding layer self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1) @@ -643,8 +644,7 @@ def forward( ) input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) - # assuming that padded tokens are filled with -100 - # when not being attended to + # assuming that padded tokens are filled with pad_token_id when not being attended to labels_mask = labels != self.config.pad_token_id target_lengths = labels_mask.sum(-1) flattened_targets = labels.masked_select(labels_mask) diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py index 
7435ef3c43cd..9bfeb25c128b 100644 --- a/src/transformers/models/lasr/modular_lasr.py +++ b/src/transformers/models/lasr/modular_lasr.py @@ -23,7 +23,7 @@ from ...masking_utils import create_bidirectional_mask from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack +from ...processing_utils import ProcessingKwargs, Unpack from ...tokenization_utils_tokenizers import TokenizersBackend from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import merge_with_config_defaults @@ -95,8 +95,29 @@ def _decode( ) +class LasrProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "audio_kwargs": { + "sampling_rate": 16000, + "padding": "longest", + "return_attention_mask": True, + }, + "text_kwargs": { + "padding": True, + "padding_side": "right", + "add_special_tokens": False, + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + class LasrProcessor(ParakeetProcessor): - pass + + def decode(self, *args, **kwargs): + raise NotImplementedError("Not needed") + + def _refine_timestamps_tdt(self, *args, **kwargs): + raise NotImplementedError("Not needed") class LasrEncoderConfig(ParakeetEncoderConfig): From ab21380ba2ca17087ff090ea08bb5e000045bf40 Mon Sep 17 00:00:00 2001 From: Eric B Date: Sat, 7 Mar 2026 09:09:21 +0100 Subject: [PATCH 25/67] Style checks and nits --- src/transformers/models/lasr/modular_lasr.py | 4 +- .../models/lasr/processing_lasr.py | 4 + .../models/parakeet/configuration_parakeet.py | 46 ++--- .../models/parakeet/modular_parakeet.py | 7 +- .../pipelines/automatic_speech_recognition.py | 26 +-- .../fixtures/parakeet/expected_tdt_loss.json | 6 +- .../parakeet/generate_tdt_loss_fixtures.py | 173 ------------------ .../models/parakeet/test_modeling_parakeet.py | 21 ++- 8 files changed, 47 insertions(+), 240 deletions(-) delete mode 100644 tests/models/parakeet/generate_tdt_loss_fixtures.py diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py index 57cb25d86617..6665d38cde14 100644 --- a/src/transformers/models/lasr/modular_lasr.py +++ b/src/transformers/models/lasr/modular_lasr.py @@ -160,9 +160,9 @@ class LasrProcessorKwargs(ProcessingKwargs, total=False): class LasrProcessor(ParakeetProcessor): - def decode(self, *args, **kwargs): - raise NotImplementedError("Not needed") + """Forward arguments to [`~PreTrainedTokenizer.decode`].""" + self.tokenizer.decode(*args, **kwargs) def _refine_timestamps_tdt(self, *args, **kwargs): raise NotImplementedError("Not needed") diff --git a/src/transformers/models/lasr/processing_lasr.py b/src/transformers/models/lasr/processing_lasr.py index c1acaebaae07..b7216ae08a65 100644 --- a/src/transformers/models/lasr/processing_lasr.py +++ b/src/transformers/models/lasr/processing_lasr.py @@ -96,5 +96,9 @@ def model_input_names(self): feature_extractor_input_names = self.feature_extractor.model_input_names return feature_extractor_input_names + ["labels"] + def decode(self, *args, **kwargs): + """Forward arguments to [`~PreTrainedTokenizer.decode`].""" + self.tokenizer.decode(*args, **kwargs) + __all__ = ["LasrProcessor"] diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 88c1fb6613eb..8a41ab817865 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -182,38 +182,24 @@ def 
from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): return cls(encoder_config=encoder_config.to_dict(), **kwargs) +@auto_docstring(checkpoint="bezzam/parakeet-tdt-0.6b-v3-hf") class ParakeetTDTConfig(PreTrainedConfig): r""" - This is the configuration class to store the configuration of a [`ParakeetForTDT`]. It is used to instantiate a - Parakeet TDT model according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the Parakeet TDT - [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) architecture. - - Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PreTrainedConfig`] for more information. - - Args: - vocab_size (`int`, *optional*, defaults to 8193): - Vocabulary size of the model. - decoder_hidden_size (`int`, *optional*, defaults to 640): - Hidden size of the LSTM prediction network and joint network. - num_decoder_layers (`int`, *optional*, defaults to 1): - Number of LSTM layers in the prediction network. - num_duration_bins (`int`, *optional*, defaults to 5): - Number of duration bins for predicting token durations. - durations (`list[int]`, *optional*, defaults to `[0, 1, 2, 3, 4]`): - Token duration values that can be predicted. Each value represents how many frames a token or blank - emission spans. - hidden_act (`str`, *optional*, defaults to `"relu"`): - The activation function in the joint network. - max_symbols_per_step (`int`, *optional*, defaults to 10): - Maximum number of symbols to emit per encoder time step during greedy decoding. - encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): - The config object or dictionary of the encoder. - pad_token_id (`int`, *optional*, defaults to 2): - Padding token id. - blank_token_id (`int`, *optional*, defaults to 8192): - Blank token id. Different from `pad_token_id` for TDT. + decoder_hidden_size (`int`, *optional*, defaults to 640): + Hidden size of the LSTM prediction network and joint network. + num_decoder_layers (`int`, *optional*, defaults to 1): + Number of LSTM layers in the prediction network. + num_duration_bins (`int`, *optional*, defaults to 5): + Number of duration bins for predicting token durations. + durations (`list[int]`, *optional*, defaults to `[0, 1, 2, 3, 4]`): + Token duration values that can be predicted. Each value represents how many frames a token or blank + emission spans. + max_symbols_per_step (`int`, *optional*, defaults to 10): + Maximum number of symbols to emit per encoder time step during greedy decoding. + encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): + The config object or dictionary of the encoder. + blank_token_id (`int`, *optional*, defaults to 8192): + Blank token id. Different from `pad_token_id` for TDT. 
Example: ```python diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 6e075dd4393d..e39fd7829e86 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -1058,7 +1058,9 @@ def generate( symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) max_output_len = sequence_length * self.config.max_symbols_per_step - all_tokens_tensor = torch.full((batch_size, max_output_len), self.config.pad_token_id, dtype=torch.long, device=device) + all_tokens_tensor = torch.full( + (batch_size, max_output_len), self.config.pad_token_id, dtype=torch.long, device=device + ) token_counts = torch.zeros(batch_size, dtype=torch.long, device=device) if return_timestamps: all_frame_indices = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) @@ -1101,8 +1103,7 @@ def generate( projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) token_logits, duration_logits = self.joint( - decoder_output, - projected_encoder_output=projected_encoder_frames + decoder_output, projected_encoder_output=projected_encoder_frames ) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index f7af0df8fe69..9b5ab3c7ff0f 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -137,37 +137,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): model ([`PreTrainedModel`]): The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from [`PreTrainedModel`]. - feature_extractor ([`SequenceFeatureExtractor`]): + feature_extractor ([`SequenceFeatureExtractor`], *optional*): The feature extractor that will be used by the pipeline to encode waveform for the model. - tokenizer ([`PreTrainedTokenizer`]): + tokenizer ([`PreTrainedTokenizer`], *optional*): The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from [`PreTrainedTokenizer`]. decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*): [PyCTCDecode's BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180) can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information. - chunk_length_s (`float`, *optional*, defaults to 0): - The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default). - - - - For more information on how to effectively use `chunk_length_s`, please have a look at the [ASR chunking - blog post](https://huggingface.co/blog/asr-chunking). - - - - stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`): - The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables - the model to *see* more context and infer letters better than without this context but the pipeline - discards the stride bits at the end to make the final reconstitution as perfect as possible. 
- - - - For more information on how to effectively use `stride_length_s`, please have a look at the [ASR chunking - blog post](https://huggingface.co/blog/asr-chunking). - - - device (Union[`int`, `torch.device`], *optional*): Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the model on the associated CUDA device id. diff --git a/tests/fixtures/parakeet/expected_tdt_loss.json b/tests/fixtures/parakeet/expected_tdt_loss.json index b8177341adcd..f129fd5f01ac 100644 --- a/tests/fixtures/parakeet/expected_tdt_loss.json +++ b/tests/fixtures/parakeet/expected_tdt_loss.json @@ -1,5 +1,5 @@ { - "_comment": "Generated by generate_tdt_loss_fixtures.py using NeMo's TDTLossPytorch (CPU-patched). Inputs use torch.manual_seed(42), batch=2, T=8, U=4, vocab=5, durations=[0,1,2,3,4].", + "_comment": "Generated by generate_tdt_loss_fixtures.py using NeMo's TDTLossPytorch. Inputs use torch.manual_seed(42), batch=2, T=8, U=4, vocab=5, durations=[0,1,2,3,4].", "seed": 42, "batch_size": 2, "max_t": 8, @@ -34,6 +34,6 @@ 4, 3 ], - "expected_loss_sum": 21.978168487548828, - "expected_loss_mean": 3.12455415725708 + "expected_loss_sum": 21.978166580200195, + "expected_loss_mean": 3.124553918838501 } \ No newline at end of file diff --git a/tests/models/parakeet/generate_tdt_loss_fixtures.py b/tests/models/parakeet/generate_tdt_loss_fixtures.py deleted file mode 100644 index 582ac7e51333..000000000000 --- a/tests/models/parakeet/generate_tdt_loss_fixtures.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -Generate TDT loss reference fixtures using NeMo's TDTLossPytorch. - -Usage (requires NeMo installed, no CUDA needed): - python tests/models/parakeet/generate_tdt_loss_fixtures.py - -Outputs: - tests/fixtures/parakeet/expected_tdt_loss.json - -The fixture contains deterministic inputs and expected loss values -computed by NeMo's TDTLossPytorch. Our tdt_loss implementation is -tested against these values in test_modeling_parakeet.py::TDTLossTest. -""" - -import json -import os - -import torch - - -def make_test_inputs(): - torch.manual_seed(42) - batch_size, max_t, max_u, vocab_size, num_durations = 2, 8, 4, 5, 5 - blank = vocab_size - - combined_logits = torch.randn(batch_size, max_t, max_u + 1, vocab_size + 1 + num_durations) - targets = torch.randint(0, vocab_size, (batch_size, max_u)) - logit_lengths = torch.tensor([max_t, max_t - 1]) - target_lengths = torch.tensor([max_u, max_u - 1]) - - return { - "combined_logits": combined_logits, - "token_logits": combined_logits[..., : vocab_size + 1], - "duration_logits": combined_logits[..., vocab_size + 1 :], - "targets": targets, - "logit_lengths": logit_lengths, - "target_lengths": target_lengths, - "blank": blank, - "durations": [0, 1, 2, 3, 4], - } - - -def _patched_compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens): - """NeMo's compute_forward_prob with .cuda() replaced by device-aware allocation. - - This is identical to NeMo's TDTLossPytorch.compute_forward_prob except - `log_alpha = log_alpha.cuda()` is replaced with `device=acts.device`, and - `torch.Tensor([-1000.0]).cuda()[0]` is replaced with `torch.tensor(-1000.0, device=acts.device)`. - The loss math is unchanged. 
- """ - B, T, U, _ = acts.shape - log_alpha = torch.zeros(B, T, U, device=acts.device) - - for b in range(B): - for t in range(T): - for u in range(U): - if u == 0: - if t == 0: - log_alpha[b, t, u] = 0.0 - else: - log_alpha[b, t, u] = -1000.0 - for n, l in enumerate(self.durations): - if t - l >= 0 and l > 0: - tmp = ( - log_alpha[b, t - l, u] - + acts[b, t - l, u, self.blank] - + duration_acts[b, t - l, u, n] - ) - log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) - else: - log_alpha[b, t, u] = -1000.0 - for n, l in enumerate(self.durations): - if t - l >= 0: - if l > 0: - tmp = ( - log_alpha[b, t - l, u] - + acts[b, t - l, u, self.blank] - + duration_acts[b, t - l, u, n] - ) - log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) - tmp = ( - log_alpha[b, t - l, u - 1] - + acts[b, t - l, u - 1, labels[b, u - 1]] - + duration_acts[b, t - l, u - 1, n] - ) - log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) - - log_probs = [] - for b in range(B): - tt = torch.tensor(-1000.0, device=acts.device) - for n, l in enumerate(self.durations): - if act_lens[b] - l >= 0 and l > 0: - bb = ( - log_alpha[b, act_lens[b] - l, label_lens[b]] - + acts[b, act_lens[b] - l, label_lens[b], self.blank] - + duration_acts[b, act_lens[b] - l, label_lens[b], n] - ) - tt = self.logsumexp(bb, 1.0 * tt) - log_probs.append(tt) - - return torch.stack(log_probs), log_alpha - - -def compute_nemo_reference(inputs): - """Run NeMo's TDTLossPytorch. - - On CPU, monkey-patches compute_forward_prob to avoid NeMo's hardcoded .cuda(). - On CUDA, runs NeMo unmodified. - """ - import nemo.collections.asr.losses.rnnt_pytorch as rnnt_mod - - need_patch = not torch.cuda.is_available() - orig = None - if need_patch: - print("No CUDA available — patching NeMo's compute_forward_prob for CPU (math unchanged)") - orig = rnnt_mod.TDTLossPytorch.compute_forward_prob - rnnt_mod.TDTLossPytorch.compute_forward_prob = _patched_compute_forward_prob - - results = {} - for reduction in ["sum", "mean"]: - loss_fn = rnnt_mod.TDTLossPytorch( - blank=inputs["blank"], - durations=inputs["durations"], - reduction=reduction, - sigma=0.0, - ) - loss = loss_fn( - acts=inputs["combined_logits"], - labels=inputs["targets"], - act_lens=inputs["logit_lengths"], - label_lens=inputs["target_lengths"], - ) - results[reduction] = loss.item() - print(f"NeMo TDT loss (reduction={reduction}): {loss.item():.10f}") - - if orig is not None: - rnnt_mod.TDTLossPytorch.compute_forward_prob = orig - - return results - - -def main(): - inputs = make_test_inputs() - nemo_results = compute_nemo_reference(inputs) - - fixture = { - "_comment": "Generated by generate_tdt_loss_fixtures.py using NeMo's TDTLossPytorch. 
" - "Inputs use torch.manual_seed(42), batch=2, T=8, U=4, vocab=5, durations=[0,1,2,3,4].", - "seed": 42, - "batch_size": 2, - "max_t": 8, - "max_u": 4, - "vocab_size": 5, - "durations": [0, 1, 2, 3, 4], - "targets": inputs["targets"].tolist(), - "logit_lengths": inputs["logit_lengths"].tolist(), - "target_lengths": inputs["target_lengths"].tolist(), - "expected_loss_sum": nemo_results["sum"], - "expected_loss_mean": nemo_results["mean"], - } - - output_path = os.path.join(os.path.dirname(__file__), "..", "..", "fixtures", "parakeet", "expected_tdt_loss.json") - output_path = os.path.normpath(output_path) - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - with open(output_path, "w") as f: - json.dump(fixture, f, indent=2) - - print(f"\nFixture written to {output_path}") - - -if __name__ == "__main__": - main() diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index b4f6e69190f3..b1382bae7be5 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -46,8 +46,7 @@ @require_torch class TDTLossTest(unittest.TestCase): """Test tdt_loss against reference values generated by NeMo's TDTLossPytorch. - - Fixture generated with: tests/models/parakeet/generate_tdt_loss_fixtures.py + reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-generate_tdt_loss_fixtures-py """ FIXTURE_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_tdt_loss.json" @@ -85,20 +84,20 @@ def test_tdt_loss_sum(self): inputs = self._make_inputs() loss = tdt_loss(**inputs, reduction="sum") expected = torch.tensor(self.fixture["expected_loss_sum"]) - torch.testing.assert_close(loss, expected, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(loss, expected) def test_tdt_loss_mean(self): inputs = self._make_inputs() loss = tdt_loss(**inputs, reduction="mean") expected = torch.tensor(self.fixture["expected_loss_mean"]) - torch.testing.assert_close(loss, expected, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(loss, expected) def test_tdt_loss_none(self): inputs = self._make_inputs() losses = tdt_loss(**inputs, reduction="none") self.assertEqual(losses.shape, (self.fixture["batch_size"],)) expected_sum = torch.tensor(self.fixture["expected_loss_sum"]) - torch.testing.assert_close(losses.sum(), expected_sum, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(losses.sum(), expected_sum) def test_tdt_loss_with_sigma(self): inputs = self._make_inputs() @@ -218,6 +217,10 @@ class ParakeetEncoderModelTest(ModelTesterMixin, unittest.TestCase): test_resize_embeddings = False + @unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup") + def test_sdpa_can_dispatch_on_flash(self): + pass + def setUp(self): self.model_tester = ParakeetEncoderModelTester(self) self.config_tester = ConfigTester(self, config_class=ParakeetEncoderConfig, has_text_modality=False) @@ -297,6 +300,10 @@ class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase): test_resize_embeddings = False _is_composite = True + @unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup") + def test_sdpa_can_dispatch_on_flash(self): + pass + def setUp(self): self.model_tester = ParakeetForCTCModelTester(self) self.config_tester = ConfigTester(self, config_class=ParakeetCTCConfig) @@ -503,6 +510,10 @@ class ParakeetForTDTModelTest(ModelTesterMixin, unittest.TestCase): test_resize_embeddings = False _is_composite = True + 
@unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup") + def test_sdpa_can_dispatch_on_flash(self): + pass + def setUp(self): self.model_tester = ParakeetForTDTModelTester(self) self.config_tester = ConfigTester(self, config_class=ParakeetTDTConfig) From d0141d5f0aff154f1ae5ca2051e265de7969d2e7 Mon Sep 17 00:00:00 2001 From: Eric B Date: Sat, 7 Mar 2026 09:40:24 +0100 Subject: [PATCH 26/67] Nits, put back ctc loss test --- docs/source/en/model_doc/parakeet.md | 4 +-- .../models/parakeet/test_modeling_parakeet.py | 34 ++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index 9dd03ad00bfc..d0cd1ffe9c34 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -259,7 +259,7 @@ model.train() ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]] -text_samples = [el for el in ds["text"][:NUM_SAMPLES]] +text_samples = ds["text"][:NUM_SAMPLES] # passing `text` to the processor will prepare inputs' `labels` key inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate) @@ -287,7 +287,7 @@ model.train() ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]] -text_samples = [el for el in ds["text"][:NUM_SAMPLES]] +text_samples = ds["text"][:NUM_SAMPLES] # passing `text` to the processor will prepare inputs' `labels` key inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate) diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index b1382bae7be5..b92244ade41e 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -22,7 +22,7 @@ from transformers.testing_utils import cleanup, require_torch, slow, torch_device from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, random_attention_mask +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask if is_datasets_available(): @@ -210,6 +210,34 @@ def prepare_config_and_inputs_for_common(self): } return config, inputs_dict + def check_ctc_loss(self, config, input_values, *args): + model = ParakeetForCTC(config=config) + model.to(torch_device) + + # make sure that dropout is disabled + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0 + + model.config.ctc_loss_reduction = "sum" + sum_loss = model(input_values, attention_mask=attention_mask, 
labels=labels).loss.item() + + model.config.ctc_loss_reduction = "mean" + mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() + + self.parent.assertTrue(isinstance(sum_loss, float)) + self.parent.assertTrue(isinstance(mean_loss, float)) + @require_torch class ParakeetEncoderModelTest(ModelTesterMixin, unittest.TestCase): @@ -283,6 +311,10 @@ def prepare_config_and_inputs_for_common(self): } return config, inputs_dict + def test_ctc_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.encoder_model_tester.check_ctc_loss(*config_and_inputs) + @require_torch class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase): From f7529d410fbaedd6580d56f39476939f5dae0b4d Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 10 Mar 2026 11:44:07 +0100 Subject: [PATCH 27/67] More standard model output. --- docs/source/en/model_doc/parakeet.md | 78 ++++++++- .../models/parakeet/modeling_parakeet.py | 162 +++++++++++------- .../models/parakeet/modular_parakeet.py | 162 +++++++++++------- .../fixtures/parakeet/expected_tdt_loss.json | 8 +- .../models/parakeet/test_modeling_parakeet.py | 22 +-- 5 files changed, 283 insertions(+), 149 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index d0cd1ffe9c34..7c8a7d099fab 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -120,9 +120,6 @@ output = model.generate(**inputs, return_dict_in_generate=True) print(processor.decode(output.sequences, skip_special_tokens=True)) ``` - - - @@ -272,13 +269,18 @@ outputs.loss.backward() ### TDT Training -```python +The TDT loss has been implemented within Transformers to enable training. For faster training (around 10-50x depending on batch size), consider using NeMo's `TDTLossNumba`. Note that this requires installing the NeMo toolkit with `pip install nemo_toolkit[asr]`. 
+ + + + +```py from datasets import Audio, load_dataset import torch from transformers import AutoModelForTDT, AutoProcessor model_id = "nvidia/parakeet-tdt-0.6b-v3-hf" -NUM_SAMPLES = 3 +NUM_SAMPLES = 4 processor = AutoProcessor.from_pretrained(model_id) model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") @@ -298,6 +300,72 @@ print("Loss:", outputs.loss.item()) outputs.loss.backward() ``` + + + +```py +import torch +from datasets import Audio, load_dataset +from nemo.collections.asr.losses.rnnt import TDTLossNumba +from transformers import AutoModelForTDT, AutoProcessor + + +model_id = "nvidia/parakeet-tdt-0.6b-v3-hf" +NUM_SAMPLES = 4 + +# Load model and processor +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") +model.train() + +# Initialize NeMo TDT loss +# NOTE: NeMo's TDTLossNumba doesn't seem to do normalization with target lengths as suggested by its docstring so doing manually: +# - Docstring: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py#L373 +# - Normalization: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py#L247-L253 +loss_fn = TDTLossNumba( + blank=model.config.blank_token_id, + durations=model.config.durations, + reduction="none", +) + +# Load dataset +ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") +ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) +speech_samples = [el["array"] for el in ds["audio"][:NUM_SAMPLES]] +text_samples = ds["text"][:NUM_SAMPLES] + +# Prepare inputs +inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate) +inputs.to(device=model.device, dtype=model.dtype) + +# Forward pass without computing loss +outputs = model(**inputs, compute_loss=False) + +# Prepare inputs for NeMo TDT loss +# -- NOTE: convert to float32 for NeMo loss since Numba doesn't support float16/bfloat16, but keep labels as integers +encoder_lengths = torch.full((outputs.last_hidden_state.shape[0],), outputs.last_hidden_state.shape[1], dtype=torch.long, device=model.device) +labels = inputs["labels"] +target_lengths = (labels != model.config.pad_token_id).sum(-1) +losses = loss_fn( + acts=outputs.logits.float(), + labels=labels.long(), + act_lens=encoder_lengths.long(), + label_lens=target_lengths.long(), +) + +# Normalize by target lengths +loss = (losses / target_lengths.float()).mean() +print(f"Loss (NeMo TDTLossNumba): {loss.item():.6f}") + +# Backward pass +loss.backward() +print("\nāœ“ Successfully computed loss and gradients using NeMo's fast TDT loss!") +``` + + + + + ## ParakeetTokenizer [[autodoc]] ParakeetTokenizer diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index fc4926caf39a..eead9d080ff1 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -664,36 +664,6 @@ class ParakeetCTCGenerateOutput(ModelOutput): hidden_states: tuple[tuple[torch.FloatTensor]] | None = None -@dataclass -class ParakeetTDTGenerateOutput(ModelOutput): - """ - Outputs of Parakeet TDT model generation. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The generated sequences. 
The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level timestamps in seconds indicating when each token was emitted. Only returned when - `return_timestamps=True` is passed to `generate()`. - token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level durations in frames indicating how many frames each token spans. Only returned when - `return_timestamps=True` is passed to `generate()`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape - `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions from the encoder. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape - `(batch_size, sequence_length, hidden_size)`. Hidden states from the encoder. - """ - - sequences: torch.LongTensor - token_timestamps: torch.FloatTensor | None = None - token_durations: torch.LongTensor | None = None - attentions: tuple[tuple[torch.FloatTensor]] | None = None - hidden_states: tuple[tuple[torch.FloatTensor]] | None = None - - @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -901,6 +871,61 @@ def forward( return self.token_head(joint_output), self.duration_head(joint_output) +@dataclass +class ParakeetTDTGenerateOutput(ModelOutput): + """ + Outputs of Parakeet TDT model generation. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level timestamps in seconds indicating when each token was emitted. Only returned when + `return_timestamps=True` is passed to `generate()`. + token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level durations in frames indicating how many frames each token spans. Only returned when + `return_timestamps=True` is passed to `generate()`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions from the encoder. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`. Hidden states from the encoder. + """ + + sequences: torch.LongTensor + token_timestamps: torch.FloatTensor | None = None + token_durations: torch.LongTensor | None = None + attentions: tuple[tuple[torch.FloatTensor]] | None = None + hidden_states: tuple[tuple[torch.FloatTensor]] | None = None + + +@dataclass +class ParakeetTDTOutput(ModelOutput): + """ + Output structure for Parakeet TDT forward pass. 
+ + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Last hidden state from the encoder. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Hidden states from the encoder. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Attention mask for the encoder. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size + num_durations)`, *optional*): + Joint token and duration logits computed from the encoder and decoder outputs. Only returned when `labels` are provided to the forward pass. + loss (`torch.FloatTensor`, *optional*): + The loss computed from the TDT loss function. Only returned when `labels` are provided to the forward pass. + """ + + last_hidden_state: torch.Tensor + hidden_states: tuple[tuple[torch.FloatTensor]] | None = None + attentions: tuple[tuple[torch.FloatTensor]] | None = None + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + + # TODO (ebezzam) eventually move to audio_utils or loss_utils for common usage? def tdt_loss( token_logits: torch.Tensor, @@ -908,7 +933,7 @@ def tdt_loss( targets: torch.Tensor, logit_lengths: torch.Tensor, target_lengths: torch.Tensor, - blank: int, + blank_token_id: int, durations: list[int], sigma: float = 0.0, reduction: str = "mean", @@ -916,10 +941,9 @@ def tdt_loss( """ Compute TDT (Token-and-Duration Transducer) loss (https://arxiv.org/abs/2304.06795). - Ported from NeMo's `TDTLossPytorch`. Unlike standard RNNT loss, this loss trains both - the token prediction head and the duration prediction head. Uses vectorized anti-diagonal - processing for efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in - parallel as batched tensor operations. + Ported from NeMo's `TDTLossPytorch` with anti-diagonal processing. Unlike standard RNNT loss, this loss trains both + the token prediction head and the duration prediction head. It uses vectorized anti-diagonal processing for + efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in parallel as batched tensor operations. Args: token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. @@ -927,7 +951,7 @@ def tdt_loss( targets: Target labels of shape `(batch, U)`. logit_lengths: Encoder output lengths of shape `(batch,)`. target_lengths: Target lengths of shape `(batch,)`. - blank: Blank token id. + blank_token_id: Blank token id. durations: List of duration values (e.g., `[0, 1, 2, 3, 4]`). sigma: Logit undernormalization constant (see TDT paper). Defaults to `0.0`. reduction: Loss reduction method. One of `"mean"`, `"sum"`, or `"none"`. Defaults to `"mean"`. 
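For reference, here is a minimal standalone sketch (an editorial illustration, not part of the patch) of calling `tdt_loss` directly. Shapes mirror the test fixture (batch=2, T=8, U=4, vocab=5, durations=[0, 1, 2, 3, 4]); the import path is an assumption based on where the helper is defined.

```python
import torch

# Assumed import path; the helper lives in the Parakeet modeling module.
from transformers.models.parakeet.modeling_parakeet import tdt_loss

batch, max_t, max_u, vocab_size = 2, 8, 4, 5
durations = [0, 1, 2, 3, 4]
blank_token_id = vocab_size  # blank uses the extra, last token index

token_logits = torch.randn(batch, max_t, max_u + 1, vocab_size + 1)
duration_logits = torch.randn(batch, max_t, max_u + 1, len(durations))
targets = torch.randint(0, vocab_size, (batch, max_u))
logit_lengths = torch.full((batch,), max_t, dtype=torch.int)
target_lengths = torch.full((batch,), max_u, dtype=torch.int)

per_sample = tdt_loss(
    token_logits=token_logits,
    duration_logits=duration_logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    blank_token_id=blank_token_id,
    durations=durations,
    reduction="none",
)
# "mean" divides each per-sample loss by its target length before averaging over the batch.
mean_loss = (per_sample / target_lengths.float()).mean()
print(per_sample, mean_loss.item())
```

With the stored fixture values (per-sample losses of roughly 12.9234 and 9.0548 for target lengths 4 and 3), this normalization reproduces `expected_loss_mean` of about 3.1246, while a plain sum gives `expected_loss_sum` of about 21.9782.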
@@ -947,7 +971,7 @@ def tdt_loss( log_alpha[:, 0, 0] = 0.0 # Precompute blank and label log-probs for vectorized access - blank_log_probs = token_log_probs[:, :, :, blank] + blank_log_probs = token_log_probs[:, :, :, blank_token_id] if max_u > 1: targets_expanded = targets.unsqueeze(1).expand(-1, max_t, -1) # (batch, T, U_labels) @@ -962,17 +986,14 @@ def tdt_loss( u_start = max(0, n - max_t + 1) u_end = min(n + 1, max_u) u_indices = torch.arange(u_start, u_end, device=device) - t_indices = n - u_indices + t_indices = n - u_indices all_candidates = [] - for i, dur in enumerate(durations): t_prev = t_indices - dur valid_t = t_prev >= 0 - if not valid_t.any(): continue - t_src = t_prev.clamp(min=0) # Blank arcs (dur > 0): from (t-dur, u) to (t, u) @@ -1018,7 +1039,7 @@ def tdt_loss( t_clamped = t_final.clamp(min=0) terminal = ( log_alpha[batch_idx, t_clamped, target_lengths] - + token_log_probs[batch_idx, t_clamped, target_lengths, blank] + + token_log_probs[batch_idx, t_clamped, target_lengths, blank_token_id] + duration_log_probs[batch_idx, t_clamped, target_lengths, i] ) combined = torch.stack([log_probs, terminal], dim=0) @@ -1030,10 +1051,7 @@ def tdt_loss( return (losses / target_lengths.float()).mean() elif reduction == "sum": return losses.sum() - elif reduction == "none": - return losses - else: - return (losses / target_lengths.float()).mean() + return losses @auto_docstring( @@ -1059,9 +1077,17 @@ def forward( input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, + compute_loss: bool | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutput: + ) -> ParakeetTDTOutput: r""" + Args: + compute_loss (`bool`, *optional*, defaults to `False`): + Whether to compute the loss when the `labels` argument is provided. If `False`, the model will compute + the joint token and duration logits but will not compute the TDT loss, even if `labels` are provided. + This can be useful for cases where you want to compute the loss separately, e.g. with NeMo's TDT loss + implementation. 
+ Example: ```python @@ -1085,10 +1111,11 @@ def forward( **kwargs, ) - encoder_hidden_states = encoder_outputs.last_hidden_state - - loss = None + loss, logits = None, None if labels is not None: + if compute_loss is None: + compute_loss = True + # Compute encoder output lengths attention_mask = ( attention_mask @@ -1108,23 +1135,26 @@ def forward( decoder_output, _, _ = self.decoder(decoder_input) token_logits, duration_logits = self.joint( decoder_output=decoder_output.unsqueeze(1), - encoder_output=encoder_hidden_states.unsqueeze(2), - ) - - loss = tdt_loss( - token_logits=token_logits.float(), - duration_logits=duration_logits.float(), - targets=labels.to(token_logits.device).int(), - logit_lengths=encoder_lengths.to(token_logits.device).int(), - target_lengths=target_lengths.to(token_logits.device).int(), - blank=self.config.blank_token_id, - durations=self.config.durations, - reduction="mean", + encoder_output=encoder_outputs.last_hidden_state.unsqueeze(2), ) + logits = torch.cat([token_logits, duration_logits], dim=-1) + + if compute_loss: + loss = tdt_loss( + token_logits=token_logits.float(), + duration_logits=duration_logits.float(), + targets=labels.to(token_logits.device).int(), + logit_lengths=encoder_lengths.to(token_logits.device).int(), + target_lengths=target_lengths.to(token_logits.device).int(), + blank_token_id=self.config.blank_token_id, + durations=self.config.durations, + reduction="mean", + ) - return CausalLMOutput( + return ParakeetTDTOutput( loss=loss, - logits=encoder_hidden_states, + logits=logits, + last_hidden_state=encoder_outputs.last_hidden_state, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) @@ -1176,14 +1206,14 @@ def generate( kwargs["return_dict"] = True if return_timestamps: return_dict_in_generate = True - outputs: CausalLMOutput = self.forward( + outputs = self.forward( input_features=input_features, attention_mask=attention_mask, **kwargs, ) # greedy TDT decoding, `GreedyBatchedTDTLabelLoopingComputer.torch_impl` in NeMo - encoder_hidden_states = outputs.logits + encoder_hidden_states = outputs.last_hidden_state batch_size, sequence_length = encoder_hidden_states.shape[:2] device = encoder_hidden_states.device if attention_mask is not None: diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index e39fd7829e86..6a42e243e0f7 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -512,36 +512,6 @@ class ParakeetCTCGenerateOutput(ModelOutput): hidden_states: tuple[tuple[torch.FloatTensor]] | None = None -@dataclass -class ParakeetTDTGenerateOutput(ModelOutput): - """ - Outputs of Parakeet TDT model generation. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level timestamps in seconds indicating when each token was emitted. Only returned when - `return_timestamps=True` is passed to `generate()`. - token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level durations in frames indicating how many frames each token spans. Only returned when - `return_timestamps=True` is passed to `generate()`. 
- attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape - `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions from the encoder. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape - `(batch_size, sequence_length, hidden_size)`. Hidden states from the encoder. - """ - - sequences: torch.LongTensor - token_timestamps: torch.FloatTensor | None = None - token_durations: torch.LongTensor | None = None - attentions: tuple[tuple[torch.FloatTensor]] | None = None - hidden_states: tuple[tuple[torch.FloatTensor]] | None = None - - @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -732,7 +702,7 @@ def tdt_loss( targets: torch.Tensor, logit_lengths: torch.Tensor, target_lengths: torch.Tensor, - blank: int, + blank_token_id: int, durations: list[int], sigma: float = 0.0, reduction: str = "mean", @@ -740,10 +710,9 @@ def tdt_loss( """ Compute TDT (Token-and-Duration Transducer) loss (https://arxiv.org/abs/2304.06795). - Ported from NeMo's `TDTLossPytorch`. Unlike standard RNNT loss, this loss trains both - the token prediction head and the duration prediction head. Uses vectorized anti-diagonal - processing for efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in - parallel as batched tensor operations. + Ported from NeMo's `TDTLossPytorch` with anti-diagonal processing. Unlike standard RNNT loss, this loss trains both + the token prediction head and the duration prediction head. It uses vectorized anti-diagonal processing for + efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in parallel as batched tensor operations. Args: token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. @@ -751,7 +720,7 @@ def tdt_loss( targets: Target labels of shape `(batch, U)`. logit_lengths: Encoder output lengths of shape `(batch,)`. target_lengths: Target lengths of shape `(batch,)`. - blank: Blank token id. + blank_token_id: Blank token id. durations: List of duration values (e.g., `[0, 1, 2, 3, 4]`). sigma: Logit undernormalization constant (see TDT paper). Defaults to `0.0`. reduction: Loss reduction method. One of `"mean"`, `"sum"`, or `"none"`. Defaults to `"mean"`. 
@@ -771,7 +740,7 @@ def tdt_loss( log_alpha[:, 0, 0] = 0.0 # Precompute blank and label log-probs for vectorized access - blank_log_probs = token_log_probs[:, :, :, blank] + blank_log_probs = token_log_probs[:, :, :, blank_token_id] if max_u > 1: targets_expanded = targets.unsqueeze(1).expand(-1, max_t, -1) # (batch, T, U_labels) @@ -786,17 +755,14 @@ def tdt_loss( u_start = max(0, n - max_t + 1) u_end = min(n + 1, max_u) u_indices = torch.arange(u_start, u_end, device=device) - t_indices = n - u_indices + t_indices = n - u_indices all_candidates = [] - for i, dur in enumerate(durations): t_prev = t_indices - dur valid_t = t_prev >= 0 - if not valid_t.any(): continue - t_src = t_prev.clamp(min=0) # Blank arcs (dur > 0): from (t-dur, u) to (t, u) @@ -842,7 +808,7 @@ def tdt_loss( t_clamped = t_final.clamp(min=0) terminal = ( log_alpha[batch_idx, t_clamped, target_lengths] - + token_log_probs[batch_idx, t_clamped, target_lengths, blank] + + token_log_probs[batch_idx, t_clamped, target_lengths, blank_token_id] + duration_log_probs[batch_idx, t_clamped, target_lengths, i] ) combined = torch.stack([log_probs, terminal], dim=0) @@ -854,10 +820,7 @@ def tdt_loss( return (losses / target_lengths.float()).mean() elif reduction == "sum": return losses.sum() - elif reduction == "none": - return losses - else: - return (losses / target_lengths.float()).mean() + return losses class ParakeetTDTJointNetwork(nn.Module): @@ -884,6 +847,61 @@ def forward( return self.token_head(joint_output), self.duration_head(joint_output) +@dataclass +class ParakeetTDTGenerateOutput(ModelOutput): + """ + Outputs of Parakeet TDT model generation. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level timestamps in seconds indicating when each token was emitted. Only returned when + `return_timestamps=True` is passed to `generate()`. + token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Token-level durations in frames indicating how many frames each token spans. Only returned when + `return_timestamps=True` is passed to `generate()`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions from the encoder. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`. Hidden states from the encoder. + """ + + sequences: torch.LongTensor + token_timestamps: torch.FloatTensor | None = None + token_durations: torch.LongTensor | None = None + attentions: tuple[tuple[torch.FloatTensor]] | None = None + hidden_states: tuple[tuple[torch.FloatTensor]] | None = None + + +@dataclass +class ParakeetTDTOutput(ModelOutput): + """ + Output structure for Parakeet TDT forward pass. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Last hidden state from the encoder. 
+ hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Hidden states from the encoder. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Attention mask for the encoder. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size + num_durations)`, *optional*): + Joint token and duration logits computed from the encoder and decoder outputs. Only returned when `labels` are provided to the forward pass. + loss (`torch.FloatTensor`, *optional*): + The loss computed from the TDT loss function. Only returned when `labels` are provided to the forward pass. + """ + + last_hidden_state: torch.Tensor + hidden_states: tuple[tuple[torch.FloatTensor]] | None = None + attentions: tuple[tuple[torch.FloatTensor]] | None = None + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" Parakeet Encoder with a TDT (Token Duration Transducer) head. @@ -907,9 +925,17 @@ def forward( input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, + compute_loss: bool | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutput: + ) -> ParakeetTDTOutput: r""" + Args: + compute_loss (`bool`, *optional*, defaults to `False`): + Whether to compute the loss when the `labels` argument is provided. If `False`, the model will compute + the joint token and duration logits but will not compute the TDT loss, even if `labels` are provided. + This can be useful for cases where you want to compute the loss separately, e.g. with NeMo's TDT loss + implementation. + Example: ```python @@ -933,10 +959,11 @@ def forward( **kwargs, ) - encoder_hidden_states = encoder_outputs.last_hidden_state - - loss = None + loss, logits = None, None if labels is not None: + if compute_loss is None: + compute_loss = True + # Compute encoder output lengths attention_mask = ( attention_mask @@ -956,23 +983,26 @@ def forward( decoder_output, _, _ = self.decoder(decoder_input) token_logits, duration_logits = self.joint( decoder_output=decoder_output.unsqueeze(1), - encoder_output=encoder_hidden_states.unsqueeze(2), - ) - - loss = tdt_loss( - token_logits=token_logits.float(), - duration_logits=duration_logits.float(), - targets=labels.to(token_logits.device).int(), - logit_lengths=encoder_lengths.to(token_logits.device).int(), - target_lengths=target_lengths.to(token_logits.device).int(), - blank=self.config.blank_token_id, - durations=self.config.durations, - reduction="mean", + encoder_output=encoder_outputs.last_hidden_state.unsqueeze(2), ) + logits = torch.cat([token_logits, duration_logits], dim=-1) + + if compute_loss: + loss = tdt_loss( + token_logits=token_logits.float(), + duration_logits=duration_logits.float(), + targets=labels.to(token_logits.device).int(), + logit_lengths=encoder_lengths.to(token_logits.device).int(), + target_lengths=target_lengths.to(token_logits.device).int(), + blank_token_id=self.config.blank_token_id, + durations=self.config.durations, + reduction="mean", + ) - return CausalLMOutput( + return ParakeetTDTOutput( loss=loss, - logits=encoder_hidden_states, + logits=logits, + last_hidden_state=encoder_outputs.last_hidden_state, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) @@ -1024,14 +1054,14 @@ def generate( kwargs["return_dict"] = True if return_timestamps: return_dict_in_generate = True - outputs: CausalLMOutput = self.forward( 
+ outputs = self.forward( input_features=input_features, attention_mask=attention_mask, **kwargs, ) # greedy TDT decoding, `GreedyBatchedTDTLabelLoopingComputer.torch_impl` in NeMo - encoder_hidden_states = outputs.logits + encoder_hidden_states = outputs.last_hidden_state batch_size, sequence_length = encoder_hidden_states.shape[:2] device = encoder_hidden_states.device if attention_mask is not None: diff --git a/tests/fixtures/parakeet/expected_tdt_loss.json b/tests/fixtures/parakeet/expected_tdt_loss.json index f129fd5f01ac..7c3ff498483f 100644 --- a/tests/fixtures/parakeet/expected_tdt_loss.json +++ b/tests/fixtures/parakeet/expected_tdt_loss.json @@ -1,5 +1,4 @@ { - "_comment": "Generated by generate_tdt_loss_fixtures.py using NeMo's TDTLossPytorch. Inputs use torch.manual_seed(42), batch=2, T=8, U=4, vocab=5, durations=[0,1,2,3,4].", "seed": 42, "batch_size": 2, "max_t": 8, @@ -35,5 +34,10 @@ 3 ], "expected_loss_sum": 21.978166580200195, - "expected_loss_mean": 3.124553918838501 + "expected_loss_mean": 3.124553918838501, + "expected_loss_none": [ + 12.923372268676758, + 9.054794311523438 + ], + "expected_loss_mean_sigma_0p05": 3.1921849250793457 } \ No newline at end of file diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index b92244ade41e..80d8c519fc46 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -63,7 +63,7 @@ def _make_inputs(self): max_u = self.fixture["max_u"] vocab_size = self.fixture["vocab_size"] num_durations = len(self.fixture["durations"]) - blank = vocab_size + blank_token_id = vocab_size combined_logits = torch.randn(batch_size, max_t, max_u + 1, vocab_size + 1 + num_durations) targets = torch.randint(0, vocab_size, (batch_size, max_u)) @@ -76,7 +76,7 @@ def _make_inputs(self): "targets": targets, "logit_lengths": logit_lengths, "target_lengths": target_lengths, - "blank": blank, + "blank_token_id": blank_token_id, "durations": self.fixture["durations"], } @@ -94,18 +94,20 @@ def test_tdt_loss_mean(self): def test_tdt_loss_none(self): inputs = self._make_inputs() - losses = tdt_loss(**inputs, reduction="none") - self.assertEqual(losses.shape, (self.fixture["batch_size"],)) - expected_sum = torch.tensor(self.fixture["expected_loss_sum"]) - torch.testing.assert_close(losses.sum(), expected_sum) + losses = tdt_loss(**inputs, reduction=None) + expected = torch.tensor(self.fixture["expected_loss_none"]) + torch.testing.assert_close(losses, expected) def test_tdt_loss_with_sigma(self): inputs = self._make_inputs() - loss_no_sigma = tdt_loss(**inputs, sigma=0.0, reduction="sum") - loss_with_sigma = tdt_loss(**inputs, sigma=0.05, reduction="sum") + loss_no_sigma = tdt_loss(**inputs, sigma=0.0, reduction="mean") + loss_with_sigma = tdt_loss(**inputs, sigma=0.05, reduction="mean") self.assertFalse(torch.allclose(loss_no_sigma, loss_with_sigma)) self.assertGreater(loss_with_sigma.item(), loss_no_sigma.item()) + expected = torch.tensor(self.fixture["expected_loss_mean_sigma_0p05"]) + torch.testing.assert_close(loss_with_sigma, expected) + def test_tdt_loss_gradient_flows(self): inputs = self._make_inputs() inputs["token_logits"] = inputs["token_logits"].requires_grad_(True) @@ -512,9 +514,9 @@ def create_and_check_model(self, config, input_features, attention_mask): with torch.no_grad(): result = model(input_features, attention_mask=attention_mask) - # forward() returns encoder hidden states as logits + # Check encoder last hidden state 
self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.output_seq_length, self.encoder_model_tester.hidden_size) + result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.encoder_model_tester.hidden_size) ) def prepare_config_and_inputs_for_common(self): From 77b95d7301fcc7fd03e7e87f2a68af834699db40 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 10 Mar 2026 15:42:19 +0100 Subject: [PATCH 28/67] Style --- tests/models/parakeet/test_modeling_parakeet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 80d8c519fc46..3f6c416a62af 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -516,7 +516,8 @@ def create_and_check_model(self, config, input_features, attention_mask): # Check encoder last hidden state self.parent.assertEqual( - result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.encoder_model_tester.hidden_size) + result.last_hidden_state.shape, + (self.batch_size, self.output_seq_length, self.encoder_model_tester.hidden_size), ) def prepare_config_and_inputs_for_common(self): From 94eae66fd6ba4ef6bb92764e93f589c555e9916b Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 23 Mar 2026 16:45:52 +0100 Subject: [PATCH 29/67] Remove compute_loss flag and allow monkey patching to tdt loss --- docs/source/en/model_doc/parakeet.md | 47 +++++++++++-------- .../models/parakeet/modeling_parakeet.py | 33 +++++-------- .../models/parakeet/modular_parakeet.py | 33 +++++-------- 3 files changed, 50 insertions(+), 63 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index 7c8a7d099fab..e588f2bbd1b4 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -328,6 +328,31 @@ loss_fn = TDTLossNumba( reduction="none", ) +# Create wrapper to adapt NeMo loss to Transformers signature +def nemo_loss_wrapper(token_logits, duration_logits, targets, logit_lengths, target_lengths, **kwargs): + """Adapter function that converts Transformers loss signature to NeMo signature.""" + # Concatenate token and duration logits (NeMo expects combined logits) + acts = torch.cat([token_logits, duration_logits], dim=-1) + + # Use actual tensor shape for act_lens (NeMo requires T dimension to match max(act_lens)) + # The logit_lengths may not exactly match due to padding/masking edge cases + batch_size, T, U = acts.shape[:3] + act_lens = torch.full((batch_size,), T, dtype=torch.long, device=acts.device) + + # NeMo requires float32 (Numba doesn't support float16/bfloat16) and int64 + per_sample_losses = nemo_loss_fn( + acts=acts.float(), + labels=targets.long(), + act_lens=act_lens, + label_lens=target_lengths.long(), + ) + + # Normalize by target lengths and take mean across batch + return (per_sample_losses / target_lengths.float()).mean() + +# Monkey-patch the model's loss function +model.loss_function = nemo_loss_wrapper + # Load dataset ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) @@ -338,26 +363,10 @@ text_samples = ds["text"][:NUM_SAMPLES] inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate) inputs.to(device=model.device, dtype=model.dtype) -# Forward pass without computing loss -outputs 
= model(**inputs, compute_loss=False) - -# Prepare inputs for NeMo TDT loss -# -- NOTE: convert to float32 for NeMo loss since Numba doesn't support float16/bfloat16, but keep labels as integers -encoder_lengths = torch.full((outputs.last_hidden_state.shape[0],), outputs.last_hidden_state.shape[1], dtype=torch.long, device=model.device) -labels = inputs["labels"] -target_lengths = (labels != model.config.pad_token_id).sum(-1) -losses = loss_fn( - acts=outputs.logits.float(), - labels=labels.long(), - act_lens=encoder_lengths.long(), - label_lens=target_lengths.long(), -) - -# Normalize by target lengths -loss = (losses / target_lengths.float()).mean() +# Forward and backward +outputs = model(**inputs) +loss = outputs.loss print(f"Loss (NeMo TDTLossNumba): {loss.item():.6f}") - -# Backward pass loss.backward() print("\nāœ“ Successfully computed loss and gradients using NeMo's fast TDT loss!") ``` diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index eead9d080ff1..78d3ab63a33f 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -1067,6 +1067,7 @@ def __init__(self, config: ParakeetTDTConfig): self.encoder = AutoModel.from_config(config.encoder_config) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) + self.loss_function = tdt_loss self.post_init() @@ -1077,17 +1078,9 @@ def forward( input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, - compute_loss: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: r""" - Args: - compute_loss (`bool`, *optional*, defaults to `False`): - Whether to compute the loss when the `labels` argument is provided. If `False`, the model will compute - the joint token and duration logits but will not compute the TDT loss, even if `labels` are provided. - This can be useful for cases where you want to compute the loss separately, e.g. with NeMo's TDT loss - implementation. 
- Example: ```python @@ -1113,9 +1106,6 @@ def forward( loss, logits = None, None if labels is not None: - if compute_loss is None: - compute_loss = True - # Compute encoder output lengths attention_mask = ( attention_mask @@ -1139,17 +1129,16 @@ def forward( ) logits = torch.cat([token_logits, duration_logits], dim=-1) - if compute_loss: - loss = tdt_loss( - token_logits=token_logits.float(), - duration_logits=duration_logits.float(), - targets=labels.to(token_logits.device).int(), - logit_lengths=encoder_lengths.to(token_logits.device).int(), - target_lengths=target_lengths.to(token_logits.device).int(), - blank_token_id=self.config.blank_token_id, - durations=self.config.durations, - reduction="mean", - ) + loss = self.loss_function( + token_logits=token_logits.float(), + duration_logits=duration_logits.float(), + targets=labels.to(token_logits.device).int(), + logit_lengths=encoder_lengths.to(token_logits.device).int(), + target_lengths=target_lengths.to(token_logits.device).int(), + blank_token_id=self.config.blank_token_id, + durations=self.config.durations, + reduction="mean", + ) return ParakeetTDTOutput( loss=loss, diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 6a42e243e0f7..f102314869cf 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -915,6 +915,7 @@ def __init__(self, config: ParakeetTDTConfig): self.encoder = AutoModel.from_config(config.encoder_config) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) + self.loss_function = tdt_loss self.post_init() @@ -925,17 +926,9 @@ def forward( input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, - compute_loss: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: r""" - Args: - compute_loss (`bool`, *optional*, defaults to `False`): - Whether to compute the loss when the `labels` argument is provided. If `False`, the model will compute - the joint token and duration logits but will not compute the TDT loss, even if `labels` are provided. - This can be useful for cases where you want to compute the loss separately, e.g. with NeMo's TDT loss - implementation. 
- Example: ```python @@ -961,9 +954,6 @@ def forward( loss, logits = None, None if labels is not None: - if compute_loss is None: - compute_loss = True - # Compute encoder output lengths attention_mask = ( attention_mask @@ -987,17 +977,16 @@ def forward( ) logits = torch.cat([token_logits, duration_logits], dim=-1) - if compute_loss: - loss = tdt_loss( - token_logits=token_logits.float(), - duration_logits=duration_logits.float(), - targets=labels.to(token_logits.device).int(), - logit_lengths=encoder_lengths.to(token_logits.device).int(), - target_lengths=target_lengths.to(token_logits.device).int(), - blank_token_id=self.config.blank_token_id, - durations=self.config.durations, - reduction="mean", - ) + loss = self.loss_function( + token_logits=token_logits.float(), + duration_logits=duration_logits.float(), + targets=labels.to(token_logits.device).int(), + logit_lengths=encoder_lengths.to(token_logits.device).int(), + target_lengths=target_lengths.to(token_logits.device).int(), + blank_token_id=self.config.blank_token_id, + durations=self.config.durations, + reduction="mean", + ) return ParakeetTDTOutput( loss=loss, From f7d40675d21997128a8b38e76f0cb85cfa1d91f6 Mon Sep 17 00:00:00 2001 From: Eric Bezzam <4757445+ebezzam@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:53:10 +0100 Subject: [PATCH 30/67] Update src/transformers/models/parakeet/modular_parakeet.py Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com> --- src/transformers/models/parakeet/modular_parakeet.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index f102314869cf..3852e43b0a37 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -486,7 +486,12 @@ def forward( @dataclass -class ParakeetCTCGenerateOutput(ModelOutput): +class ParakeetGenerateOutput(ParakeetCTCGenerateOutput): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + logger.warning_once( + "`ParakeetGenerateOutput` is deprecated and removed starting from version 5.5.0; please use `ParakeetCTCGenerateOutput` instead.", + ) """ Outputs of Parakeet CTC model generation. From f75c17b66eac15ca53dba40a93a4018d404e351b Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 23 Mar 2026 19:40:51 +0100 Subject: [PATCH 31/67] Address various comments. 
--- .../models/parakeet/convert_nemo_to_hf.py | 17 +---- .../models/parakeet/modeling_parakeet.py | 73 +++++++++--------- .../models/parakeet/modular_parakeet.py | 75 +++++++++---------- 3 files changed, 76 insertions(+), 89 deletions(-) diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index 4fb17653e59c..632bc4c88aac 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -55,6 +55,7 @@ r"decoder\.prediction\.dec_rnn\.lstm\.": r"decoder.lstm.", r"joint\.enc\.": r"joint.encoder_projector.", r"joint\.pred\.": r"decoder.decoder_projector.", + r"joint\.joint_net\.2\.": r"joint.head.", } @@ -340,22 +341,6 @@ def load_and_convert_tdt_state_dict(model_files, vocab_size): print(f"Skipping preprocessing weight: {key}") continue - if key == "joint.joint_net.2.weight": - token_weight = value[:vocab_size, :] - duration_weight = value[vocab_size:, :] - converted_state_dict["joint.token_head.weight"] = token_weight - converted_state_dict["joint.duration_head.weight"] = duration_weight - print(f"Split combined weight: token_head {token_weight.shape}, duration_head {duration_weight.shape}") - continue - - if key == "joint.joint_net.2.bias": - token_bias = value[:vocab_size] - duration_bias = value[vocab_size:] - converted_state_dict["joint.token_head.bias"] = token_bias - converted_state_dict["joint.duration_head.bias"] = duration_bias - print(f"Split combined bias: token_head {token_bias.shape}, duration_head {duration_bias.shape}") - continue - converted_key = convert_key(key, all_mappings) converted_state_dict[converted_key] = value diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 78d3ab63a33f..3f28b028b86a 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -32,13 +32,16 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig +logger = logging.get_logger(__name__) + + @dataclass @auto_docstring( custom_intro=""" @@ -664,6 +667,19 @@ class ParakeetCTCGenerateOutput(ModelOutput): hidden_states: tuple[tuple[torch.FloatTensor]] | None = None +@dataclass +class ParakeetGenerateOutput(ParakeetCTCGenerateOutput): + """ + Deprecated alias for ParakeetCTCGenerateOutput. Use ParakeetCTCGenerateOutput instead. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + logger.warning_once( + "`ParakeetGenerateOutput` is deprecated and removed starting from version 5.5.0; please use `ParakeetCTCGenerateOutput` instead.", + ) + + @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. 
@@ -709,6 +725,8 @@ def forward( >>> print(outputs.loss) ```""" + if labels is not None: + kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, @@ -720,11 +738,7 @@ def forward( loss = None if labels is not None: - # retrieve loss input_lengths from attention_mask - attention_mask = ( - attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long) - ) - input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) + encoder_lengths = encoder_outputs.attention_mask.sum(-1) # assuming that padded tokens are filled with pad_token_id when not being attended to labels_mask = labels != self.config.pad_token_id @@ -738,7 +752,7 @@ def forward( loss = nn.functional.ctc_loss( log_probs, flattened_targets, - input_lengths, + encoder_lengths, target_lengths, blank=self.config.pad_token_id, reduction=self.config.ctc_loss_reduction, @@ -829,20 +843,13 @@ def forward( cell_state: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input_ids = input_ids.to(self.decoder_projector.weight.device) - if hidden_state is None or cell_state is None: - hidden_state = torch.zeros( - self.config.num_decoder_layers, - input_ids.shape[0], - self.config.decoder_hidden_size, - device=self.decoder_projector.weight.device, - dtype=self.decoder_projector.weight.dtype, - ) - cell_state = torch.zeros_like(hidden_state) - hidden_state = hidden_state.to(self.decoder_projector.weight.device) - cell_state = cell_state.to(self.decoder_projector.weight.device) + if hidden_state is not None and cell_state is not None: + hidden_cell_states = (hidden_state, cell_state) + else: + hidden_cell_states = None embeddings = self.embedding(input_ids) - lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, (hidden_state, cell_state)) + lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states) decoder_output = self.decoder_projector(lstm_output) return decoder_output, hidden_state, cell_state @@ -854,8 +861,9 @@ def __init__(self, config: ParakeetTDTConfig): super().__init__() self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.activation = ACT2FN[config.hidden_act] - self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size) - self.duration_head = nn.Linear(config.decoder_hidden_size, len(config.durations)) + # Combined head outputs both token logits and duration logits + self.head = nn.Linear(config.decoder_hidden_size, config.vocab_size + len(config.durations)) + self.vocab_size = config.vocab_size def forward( self, @@ -868,7 +876,10 @@ def forward( raise ValueError("Either encoder_output or projected_encoder_output must be provided.") projected_encoder_output = self.encoder_projector(encoder_output) joint_output = self.activation(projected_encoder_output + decoder_output) - return self.token_head(joint_output), self.duration_head(joint_output) + logits = self.head(joint_output) + token_logits = logits[..., : self.vocab_size] + duration_logits = logits[..., self.vocab_size :] + return token_logits, duration_logits @dataclass @@ -1061,6 +1072,7 @@ def tdt_loss( ) class ParakeetForTDT(ParakeetPreTrainedModel): config: ParakeetTDTConfig + _no_split_modules = ["ParakeetTDTDecoder"] def __init__(self, config: ParakeetTDTConfig): super().__init__(config) @@ -1098,6 +1110,8 @@ def forward( >>> outputs = model(**inputs) ``` """ + if labels is not None: + 
kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, @@ -1106,13 +1120,7 @@ def forward( loss, logits = None, None if labels is not None: - # Compute encoder output lengths - attention_mask = ( - attention_mask - if attention_mask is not None - else torch.ones(input_features.shape[:-1], dtype=torch.long, device=input_features.device) - ) - encoder_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) + encoder_lengths = encoder_outputs.attention_mask.sum(-1) # Prepare labels for TDT loss target_lengths = (labels != self.config.pad_token_id).sum(-1) @@ -1127,7 +1135,6 @@ def forward( decoder_output=decoder_output.unsqueeze(1), encoder_output=encoder_outputs.last_hidden_state.unsqueeze(2), ) - logits = torch.cat([token_logits, duration_logits], dim=-1) loss = self.loss_function( token_logits=token_logits.float(), @@ -1139,6 +1146,7 @@ def forward( durations=self.config.durations, reduction="mean", ) + logits = torch.cat([token_logits, duration_logits], dim=-1) return ParakeetTDTOutput( loss=loss, @@ -1212,9 +1220,8 @@ def generate( valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) # Initialization - hidden_state, cell_state = None, None prev_tokens = torch.full((batch_size, 1), self.config.blank_token_id, dtype=torch.long, device=device) - decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) + decoder_output, hidden_state, cell_state = self.decoder(prev_tokens) decoder_output = decoder_output.to(device) hidden_state = hidden_state.to(device) cell_state = cell_state.to(device) @@ -1251,7 +1258,6 @@ def generate( ) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) - tokens = token_logits.argmax(dim=-1) durations = duration_logits.argmax(dim=-1) @@ -1278,7 +1284,6 @@ def generate( ) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) - more_tokens = token_logits.argmax(dim=-1) more_durations = duration_logits.argmax(dim=-1) tokens = torch.where(advance_mask, more_tokens, tokens) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 3852e43b0a37..b0c3d00faafd 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -486,12 +486,7 @@ def forward( @dataclass -class ParakeetGenerateOutput(ParakeetCTCGenerateOutput): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - logger.warning_once( - "`ParakeetGenerateOutput` is deprecated and removed starting from version 5.5.0; please use `ParakeetCTCGenerateOutput` instead.", - ) +class ParakeetCTCGenerateOutput(ModelOutput): """ Outputs of Parakeet CTC model generation. @@ -517,6 +512,19 @@ def __init__(self, *args, **kwargs): hidden_states: tuple[tuple[torch.FloatTensor]] | None = None +@dataclass +class ParakeetGenerateOutput(ParakeetCTCGenerateOutput): + """ + Deprecated alias for ParakeetCTCGenerateOutput. Use ParakeetCTCGenerateOutput instead. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + logger.warning_once( + "`ParakeetGenerateOutput` is deprecated and removed starting from version 5.5.0; please use `ParakeetCTCGenerateOutput` instead.", + ) + + @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. 
@@ -562,6 +570,8 @@ def forward( >>> print(outputs.loss) ```""" + if labels is not None: + kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, @@ -573,11 +583,7 @@ def forward( loss = None if labels is not None: - # retrieve loss input_lengths from attention_mask - attention_mask = ( - attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long) - ) - input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) + encoder_lengths = encoder_outputs.attention_mask.sum(-1) # assuming that padded tokens are filled with pad_token_id when not being attended to labels_mask = labels != self.config.pad_token_id @@ -591,7 +597,7 @@ def forward( loss = nn.functional.ctc_loss( log_probs, flattened_targets, - input_lengths, + encoder_lengths, target_lengths, blank=self.config.pad_token_id, reduction=self.config.ctc_loss_reduction, @@ -682,20 +688,13 @@ def forward( cell_state: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input_ids = input_ids.to(self.decoder_projector.weight.device) - if hidden_state is None or cell_state is None: - hidden_state = torch.zeros( - self.config.num_decoder_layers, - input_ids.shape[0], - self.config.decoder_hidden_size, - device=self.decoder_projector.weight.device, - dtype=self.decoder_projector.weight.dtype, - ) - cell_state = torch.zeros_like(hidden_state) - hidden_state = hidden_state.to(self.decoder_projector.weight.device) - cell_state = cell_state.to(self.decoder_projector.weight.device) + if hidden_state is not None and cell_state is not None: + hidden_cell_states = (hidden_state, cell_state) + else: + hidden_cell_states = None embeddings = self.embedding(input_ids) - lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, (hidden_state, cell_state)) + lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states) decoder_output = self.decoder_projector(lstm_output) return decoder_output, hidden_state, cell_state @@ -835,8 +834,9 @@ def __init__(self, config: ParakeetTDTConfig): super().__init__() self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.activation = ACT2FN[config.hidden_act] - self.token_head = nn.Linear(config.decoder_hidden_size, config.vocab_size) - self.duration_head = nn.Linear(config.decoder_hidden_size, len(config.durations)) + # Combined head outputs both token logits and duration logits + self.head = nn.Linear(config.decoder_hidden_size, config.vocab_size + len(config.durations)) + self.vocab_size = config.vocab_size def forward( self, @@ -849,7 +849,10 @@ def forward( raise ValueError("Either encoder_output or projected_encoder_output must be provided.") projected_encoder_output = self.encoder_projector(encoder_output) joint_output = self.activation(projected_encoder_output + decoder_output) - return self.token_head(joint_output), self.duration_head(joint_output) + logits = self.head(joint_output) + token_logits = logits[..., : self.vocab_size] + duration_logits = logits[..., self.vocab_size :] + return token_logits, duration_logits @dataclass @@ -914,6 +917,7 @@ class ParakeetTDTOutput(ModelOutput): ) class ParakeetForTDT(ParakeetPreTrainedModel): config: ParakeetTDTConfig + _no_split_modules = ["ParakeetTDTDecoder"] def __init__(self, config: ParakeetTDTConfig): super().__init__(config) @@ -951,6 +955,8 @@ def forward( >>> outputs = model(**inputs) ``` """ + if labels is not None: + 
kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, @@ -959,13 +965,7 @@ def forward( loss, logits = None, None if labels is not None: - # Compute encoder output lengths - attention_mask = ( - attention_mask - if attention_mask is not None - else torch.ones(input_features.shape[:-1], dtype=torch.long, device=input_features.device) - ) - encoder_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) + encoder_lengths = encoder_outputs.attention_mask.sum(-1) # Prepare labels for TDT loss target_lengths = (labels != self.config.pad_token_id).sum(-1) @@ -980,7 +980,6 @@ def forward( decoder_output=decoder_output.unsqueeze(1), encoder_output=encoder_outputs.last_hidden_state.unsqueeze(2), ) - logits = torch.cat([token_logits, duration_logits], dim=-1) loss = self.loss_function( token_logits=token_logits.float(), @@ -992,6 +991,7 @@ def forward( durations=self.config.durations, reduction="mean", ) + logits = torch.cat([token_logits, duration_logits], dim=-1) return ParakeetTDTOutput( loss=loss, @@ -1065,9 +1065,8 @@ def generate( valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) # Initialization - hidden_state, cell_state = None, None prev_tokens = torch.full((batch_size, 1), self.config.blank_token_id, dtype=torch.long, device=device) - decoder_output, hidden_state, cell_state = self.decoder(prev_tokens, hidden_state, cell_state) + decoder_output, hidden_state, cell_state = self.decoder(prev_tokens) decoder_output = decoder_output.to(device) hidden_state = hidden_state.to(device) cell_state = cell_state.to(device) @@ -1104,7 +1103,6 @@ def generate( ) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) - tokens = token_logits.argmax(dim=-1) durations = duration_logits.argmax(dim=-1) @@ -1131,7 +1129,6 @@ def generate( ) token_logits = token_logits.squeeze(1).to(device) duration_logits = duration_logits.squeeze(1).to(device) - more_tokens = token_logits.argmax(dim=-1) more_durations = duration_logits.argmax(dim=-1) tokens = torch.where(advance_mask, more_tokens, tokens) From 5a49b651b475560690fa331a142745ef0e3b70af Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 24 Mar 2026 16:19:31 +0100 Subject: [PATCH 32/67] More compatible with Transformers forward/generate approach --- .../models/parakeet/configuration_parakeet.py | 1 + .../models/parakeet/convert_nemo_to_hf.py | 2 +- .../models/parakeet/modeling_parakeet.py | 290 ++++++++++-------- .../models/parakeet/modular_parakeet.py | 290 ++++++++++-------- 4 files changed, 334 insertions(+), 249 deletions(-) diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 8a41ab817865..4e92698ba35e 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -249,6 +249,7 @@ def __init__( self.initializer_range = self.encoder_config.initializer_range self.blank_token_id = blank_token_id self.pad_token_id = pad_token_id + self.is_encoder_decoder = True super().__init__(**kwargs) diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index 632bc4c88aac..ccbec5fcb245 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -53,7 +53,7 @@ NEMO_TDT_WEIGHT_MAPPING = 
{ r"decoder\.prediction\.embed\.": r"decoder.embedding.", r"decoder\.prediction\.dec_rnn\.lstm\.": r"decoder.lstm.", - r"joint\.enc\.": r"joint.encoder_projector.", + r"joint\.enc\.": r"encoder_projector.", r"joint\.pred\.": r"decoder.decoder_projector.", r"joint\.joint_net\.2\.": r"joint.head.", } diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 3f28b028b86a..203e75ae11b0 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...integrations import use_kernel_func_from_hub, use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging @@ -45,10 +45,11 @@ @dataclass @auto_docstring( custom_intro=""" - Extends [~modeling_outputs.BaseModelOutput] to include the output attention mask since sequence length is not preserved in the model's forward. + Extends [~modeling_outputs.BaseModelOutputWithPooling] to include the output attention mask since sequence length + is not preserved in the model's forward. """ ) -class ParakeetEncoderModelOutput(BaseModelOutput): +class ParakeetEncoderModelOutput(BaseModelOutputWithPooling): attention_mask: torch.Tensor | None = None @@ -843,11 +844,9 @@ def forward( cell_state: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input_ids = input_ids.to(self.decoder_projector.weight.device) - if hidden_state is not None and cell_state is not None: - hidden_cell_states = (hidden_state, cell_state) - else: - hidden_cell_states = None - + hidden_cell_states = ( + (hidden_state, cell_state) if hidden_state is not None and cell_state is not None else None + ) embeddings = self.embedding(input_ids) lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states) decoder_output = self.decoder_projector(lstm_output) @@ -859,23 +858,16 @@ class ParakeetTDTJointNetwork(nn.Module): def __init__(self, config: ParakeetTDTConfig): super().__init__() - self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.activation = ACT2FN[config.hidden_act] - # Combined head outputs both token logits and duration logits self.head = nn.Linear(config.decoder_hidden_size, config.vocab_size + len(config.durations)) self.vocab_size = config.vocab_size def forward( self, decoder_output: torch.Tensor, - encoder_output: torch.Tensor | None = None, - projected_encoder_output: torch.Tensor | None = None, + encoder_output: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - if projected_encoder_output is None: - if encoder_output is None: - raise ValueError("Either encoder_output or projected_encoder_output must be provided.") - projected_encoder_output = self.encoder_projector(encoder_output) - joint_output = self.activation(projected_encoder_output + decoder_output) + joint_output = self.activation(encoder_output + decoder_output) logits = self.head(joint_output) token_logits = logits[..., : self.vocab_size] duration_logits = logits[..., self.vocab_size :] @@ -885,24 +877,19 @@ def forward( @dataclass class 
ParakeetTDTGenerateOutput(ModelOutput): """ - Outputs of Parakeet TDT model generation. + Outputs of Parakeet TDT generation. Args: sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. + Generated token sequences. token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level timestamps in seconds indicating when each token was emitted. Only returned when - `return_timestamps=True` is passed to `generate()`. + Per-token frame indices. Returned when `return_timestamps=True`. token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level durations in frames indicating how many frames each token spans. Only returned when - `return_timestamps=True` is passed to `generate()`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape - `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions from the encoder. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape - `(batch_size, sequence_length, hidden_size)`. Hidden states from the encoder. + Per-token durations in frames. Returned when `return_timestamps=True`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*): + Encoder attention weights per layer. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*): + Encoder hidden states per layer. """ sequences: torch.LongTensor @@ -915,26 +902,30 @@ class ParakeetTDTGenerateOutput(ModelOutput): @dataclass class ParakeetTDTOutput(ModelOutput): """ - Output structure for Parakeet TDT forward pass. + Output of the Parakeet TDT forward pass. Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Last hidden state from the encoder. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Hidden states from the encoder. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Attention mask for the encoder. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size + num_durations)`, *optional*): - Joint token and duration logits computed from the encoder and decoder outputs. Only returned when `labels` are provided to the forward pass. loss (`torch.FloatTensor`, *optional*): - The loss computed from the TDT loss function. Only returned when `labels` are provided to the forward pass. + TDT loss, returned when `labels` are provided. + logits (`torch.FloatTensor`): + Joint token and duration logits. Shape is `(batch, T, U+1, vocab+durations)` for training + or `(batch, 1, 1, vocab+durations)` for single-step inference. + encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): + Encoder outputs with `pooler_output` containing projected hidden states. + decoder_output (`torch.FloatTensor`, *optional*): + Decoder LSTM output, reused during blank-skipping in generation. + decoder_hidden_state (`torch.FloatTensor`, *optional*): + Decoder LSTM hidden state. + decoder_cell_state (`torch.FloatTensor`, *optional*): + Decoder LSTM cell state. 
""" - last_hidden_state: torch.Tensor - hidden_states: tuple[tuple[torch.FloatTensor]] | None = None - attentions: tuple[tuple[torch.FloatTensor]] | None = None loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None + encoder_outputs: "ParakeetEncoderModelOutput | None" = None + decoder_output: torch.FloatTensor | None = None + decoder_hidden_state: torch.FloatTensor | None = None + decoder_cell_state: torch.FloatTensor | None = None # TODO (ebezzam) eventually move to audio_utils or loss_utils for common usage? @@ -1077,22 +1068,56 @@ class ParakeetForTDT(ParakeetPreTrainedModel): def __init__(self, config: ParakeetTDTConfig): super().__init__(config) self.encoder = AutoModel.from_config(config.encoder_config) + self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) self.loss_function = tdt_loss self.post_init() + def get_audio_features( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> ParakeetEncoderModelOutput: + encoder_outputs = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + **kwargs, + ) + encoder_outputs.pooler_output = self.encoder_projector(encoder_outputs.last_hidden_state) + return encoder_outputs + @auto_docstring @can_return_tuple def forward( self, - input_features: torch.Tensor, + input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, + input_ids: torch.LongTensor | None = None, + encoder_outputs: ParakeetEncoderModelOutput | None = None, + encoder_frame_ids: torch.LongTensor | None = None, + decoder_output: torch.Tensor | None = None, + decoder_hidden_state: torch.Tensor | None = None, + decoder_cell_state: torch.Tensor | None = None, labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: r""" + input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Decoder input token ids for single-step inference. + encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): + Pre-computed encoder outputs with `pooler_output` containing projected hidden states. + encoder_frame_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Encoder frame indices for the joint network during generation. + decoder_output (`torch.Tensor`, *optional*): + Pre-computed decoder LSTM output, reused during blank-skipping. + decoder_hidden_state (`torch.Tensor`, *optional*): + Decoder LSTM hidden state from a previous step. + decoder_cell_state (`torch.Tensor`, *optional*): + Decoder LSTM cell state from a previous step. + Example: ```python @@ -1110,32 +1135,58 @@ def forward( >>> outputs = model(**inputs) ``` """ - if labels is not None: - kwargs.setdefault("output_attention_mask", True) - encoder_outputs = self.encoder( - input_features=input_features, - attention_mask=attention_mask, - **kwargs, - ) + # 1. 
Encode + project + if encoder_outputs is None: + if input_features is None: + raise ValueError("Either `input_features` or `encoder_outputs` must be provided.") + if labels is not None: + kwargs.setdefault("output_attention_mask", True) + encoder_outputs = self.get_audio_features( + input_features=input_features, + attention_mask=attention_mask, + **kwargs, + ) + projected_encoder_output = encoder_outputs.pooler_output - loss, logits = None, None if labels is not None: - encoder_lengths = encoder_outputs.attention_mask.sum(-1) - - # Prepare labels for TDT loss - target_lengths = (labels != self.config.pad_token_id).sum(-1) - - # Get joint decoder outputs + # for training: [blank, labels...] for training blank_tokens = torch.full( (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device ) - decoder_input = torch.cat([blank_tokens, labels], dim=1) - decoder_output, _, _ = self.decoder(decoder_input) - token_logits, duration_logits = self.joint( - decoder_output=decoder_output.unsqueeze(1), - encoder_output=encoder_outputs.last_hidden_state.unsqueeze(2), + input_ids = torch.cat([blank_tokens, labels], dim=1) + elif input_ids is None and decoder_output is None: + # for inference: start with blank token if not provided + input_ids = torch.full( + (projected_encoder_output.shape[0], 1), + self.config.blank_token_id, + dtype=torch.long, + device=projected_encoder_output.device, ) + if decoder_output is None: + decoder_output, decoder_hidden_state, decoder_cell_state = self.decoder( + input_ids, decoder_hidden_state, decoder_cell_state + ) + + if encoder_frame_ids is not None: + batch_indices = torch.arange(projected_encoder_output.shape[0], device=projected_encoder_output.device) + safe_frame_ids = torch.clamp(encoder_frame_ids, max=projected_encoder_output.shape[1] - 1) + encoder_for_joint = projected_encoder_output[batch_indices, safe_frame_ids].unsqueeze(1) + decoder_for_joint = decoder_output + else: + encoder_for_joint = projected_encoder_output.unsqueeze(2) + decoder_for_joint = decoder_output.unsqueeze(1) + + token_logits, duration_logits = self.joint( + decoder_output=decoder_for_joint, + encoder_output=encoder_for_joint, + ) + logits = torch.cat([token_logits, duration_logits], dim=-1) + + loss = None + if labels is not None: + encoder_lengths = encoder_outputs.attention_mask.sum(-1) + target_lengths = (labels != self.config.pad_token_id).sum(-1) loss = self.loss_function( token_logits=token_logits.float(), duration_logits=duration_logits.float(), @@ -1146,14 +1197,14 @@ def forward( durations=self.config.durations, reduction="mean", ) - logits = torch.cat([token_logits, duration_logits], dim=-1) return ParakeetTDTOutput( loss=loss, logits=logits, - last_hidden_state=encoder_outputs.last_hidden_state, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, + encoder_outputs=encoder_outputs, + decoder_output=decoder_output, + decoder_hidden_state=decoder_hidden_state, + decoder_cell_state=decoder_cell_state, ) @torch.no_grad() @@ -1200,39 +1251,35 @@ def generate( >>> print("Timestamped tokens:", decoded_timestamps) ``` """ - kwargs["return_dict"] = True if return_timestamps: return_dict_in_generate = True - outputs = self.forward( + + # Initial forward: encode + blank prediction + outputs: ParakeetTDTOutput = self.forward( input_features=input_features, attention_mask=attention_mask, + return_dict=True, **kwargs, ) + encoder_outputs = outputs.encoder_outputs + batch_size, sequence_length = 
encoder_outputs.pooler_output.shape[:2] + device = encoder_outputs.pooler_output.device - # greedy TDT decoding, `GreedyBatchedTDTLabelLoopingComputer.torch_impl` in NeMo - encoder_hidden_states = outputs.last_hidden_state - batch_size, sequence_length = encoder_hidden_states.shape[:2] - device = encoder_hidden_states.device if attention_mask is not None: encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) valid_lengths = encoder_attention_mask.sum(dim=1).int() else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) + decoder_output = outputs.decoder_output + decoder_hidden_state = outputs.decoder_hidden_state + decoder_cell_state = outputs.decoder_cell_state - # Initialization - prev_tokens = torch.full((batch_size, 1), self.config.blank_token_id, dtype=torch.long, device=device) - decoder_output, hidden_state, cell_state = self.decoder(prev_tokens) - decoder_output = decoder_output.to(device) - hidden_state = hidden_state.to(device) - cell_state = cell_state.to(device) - - batch_indices = torch.arange(batch_size, device=device) + vocab_size = self.config.vocab_size time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) time_indices_current_labels = torch.zeros(batch_size, dtype=torch.long, device=device) active_mask = time_indices < valid_lengths active_mask_prev = torch.zeros_like(active_mask) - zeros_symbols = torch.zeros(batch_size, dtype=torch.long, device=device) symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) max_output_len = sequence_length * self.config.max_symbols_per_step @@ -1244,48 +1291,44 @@ def generate( all_frame_indices = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) all_durations_tensor = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) - # separately call encoder projection to avoid redundant computation inside loop - projected_encoder_output = self.joint.encoder_projector(encoder_hidden_states).to(device) - while active_mask.any(): active_mask_prev.copy_(active_mask) - safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) - projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) - token_logits, duration_logits = self.joint( - decoder_output, - projected_encoder_output=projected_encoder_frames, + outputs = self.forward( + encoder_outputs=encoder_outputs, + encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), + decoder_output=decoder_output, + decoder_hidden_state=decoder_hidden_state, + decoder_cell_state=decoder_cell_state, + return_dict=True, ) - token_logits = token_logits.squeeze(1).to(device) - duration_logits = duration_logits.squeeze(1).to(device) - tokens = token_logits.argmax(dim=-1) - durations = duration_logits.argmax(dim=-1) + logits = outputs.logits.squeeze(1) + tokens = logits[..., :vocab_size].argmax(dim=-1) + durations = logits[..., vocab_size:].argmax(dim=-1) - # Force blank duration >= 1 to guarantee forward progress blank_mask = active_mask_prev & (tokens == self.config.blank_token_id) - durations = durations.masked_fill(blank_mask & (durations == 0), 1) + durations = durations.masked_fill(blank_mask & (durations == 0), 1) # ensure forward progress - # Save pre-advance position for timestamp recording time_indices_current_labels.copy_(time_indices) - - # Advance time for all active elements time_indices = 
time_indices + durations.masked_fill(~active_mask, 0) - safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) active_mask = time_indices < valid_lengths advance_mask = active_mask & blank_mask - # Inner loop: skip past consecutive blanks to find non-blank + # Skip consecutive blanks while advance_mask.any(): time_indices_current_labels = torch.where(advance_mask, time_indices, time_indices_current_labels) - projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) - token_logits, duration_logits = self.joint( - decoder_output, projected_encoder_output=projected_encoder_frames + outputs = self.forward( + encoder_outputs=encoder_outputs, + encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), + decoder_output=decoder_output, + decoder_hidden_state=decoder_hidden_state, + decoder_cell_state=decoder_cell_state, + return_dict=True, ) - token_logits = token_logits.squeeze(1).to(device) - duration_logits = duration_logits.squeeze(1).to(device) - more_tokens = token_logits.argmax(dim=-1) - more_durations = duration_logits.argmax(dim=-1) + logits = outputs.logits.squeeze(1) + more_tokens = logits[..., :vocab_size].argmax(dim=-1) + more_durations = logits[..., vocab_size:].argmax(dim=-1) tokens = torch.where(advance_mask, more_tokens, tokens) durations = torch.where(advance_mask, more_durations, durations) @@ -1293,11 +1336,9 @@ def generate( durations = durations.masked_fill(blank_mask & (durations == 0), 1) time_indices = torch.where(advance_mask, time_indices + durations, time_indices) - safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) active_mask = time_indices < valid_lengths advance_mask = active_mask & blank_mask - # Record results for non-blank tokens found emit_mask = active_mask_prev & (tokens != self.config.blank_token_id) emit_indices = token_counts[emit_mask] all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] @@ -1306,22 +1347,24 @@ def generate( all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] token_counts += emit_mask.long() - new_decoder_output, new_hidden_state, new_cell_state = self.decoder( - tokens.unsqueeze(1), hidden_state, cell_state + # Update decoder state for emitted tokens + outputs = self.forward( + input_ids=tokens.unsqueeze(1), + encoder_outputs=encoder_outputs, + encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), + decoder_hidden_state=decoder_hidden_state, + decoder_cell_state=decoder_cell_state, + return_dict=True, ) - new_decoder_output = new_decoder_output.to(device) - new_hidden_state = new_hidden_state.to(device) - new_cell_state = new_cell_state.to(device) - emit_mask_expanded = emit_mask.view(batch_size, 1, 1) - decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) emit_mask_state = emit_mask.view(1, batch_size, 1) - hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) - cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) + decoder_hidden_state = torch.where(emit_mask_state, outputs.decoder_hidden_state, decoder_hidden_state) + decoder_cell_state = torch.where(emit_mask_state, outputs.decoder_cell_state, decoder_cell_state) + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_output = torch.where(emit_mask_expanded, outputs.decoder_output, decoder_output) - # Track symbols emitted per time step; force advance when max_symbols reached time_changed = time_indices_current_labels != last_label_time - symbols_per_step = 
torch.where(time_changed, zeros_symbols, symbols_per_step) + symbols_per_step = torch.where(time_changed, 0, symbols_per_step) symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) last_label_time = torch.where(emit_mask, time_indices_current_labels, last_label_time) force_advance = active_mask & (symbols_per_step >= self.config.max_symbols_per_step) @@ -1329,7 +1372,6 @@ def generate( symbols_per_step = symbols_per_step.masked_fill(force_advance, 0) active_mask = time_indices < valid_lengths - # Guard against edge case where no tokens were decoded (e.g. silent audio) max_len = max(token_counts.max().item(), 1) sequences = all_tokens_tensor[:, :max_len] token_timestamps, token_durations = None, None @@ -1342,8 +1384,8 @@ def generate( sequences=sequences, token_timestamps=token_timestamps, token_durations=token_durations, - attentions=outputs.attentions, - hidden_states=outputs.hidden_states, + attentions=encoder_outputs.attentions, + hidden_states=encoder_outputs.hidden_states, ) return sequences diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index b0c3d00faafd..466db46b1533 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -23,7 +23,7 @@ from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -47,10 +47,11 @@ @dataclass @auto_docstring( custom_intro=""" - Extends [~modeling_outputs.BaseModelOutput] to include the output attention mask since sequence length is not preserved in the model's forward. + Extends [~modeling_outputs.BaseModelOutputWithPooling] to include the output attention mask since sequence length + is not preserved in the model's forward. 
""" ) -class ParakeetEncoderModelOutput(BaseModelOutput): +class ParakeetEncoderModelOutput(BaseModelOutputWithPooling): attention_mask: torch.Tensor | None = None @@ -688,11 +689,9 @@ def forward( cell_state: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input_ids = input_ids.to(self.decoder_projector.weight.device) - if hidden_state is not None and cell_state is not None: - hidden_cell_states = (hidden_state, cell_state) - else: - hidden_cell_states = None - + hidden_cell_states = ( + (hidden_state, cell_state) if hidden_state is not None and cell_state is not None else None + ) embeddings = self.embedding(input_ids) lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states) decoder_output = self.decoder_projector(lstm_output) @@ -832,23 +831,16 @@ class ParakeetTDTJointNetwork(nn.Module): def __init__(self, config: ParakeetTDTConfig): super().__init__() - self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.activation = ACT2FN[config.hidden_act] - # Combined head outputs both token logits and duration logits self.head = nn.Linear(config.decoder_hidden_size, config.vocab_size + len(config.durations)) self.vocab_size = config.vocab_size def forward( self, decoder_output: torch.Tensor, - encoder_output: torch.Tensor | None = None, - projected_encoder_output: torch.Tensor | None = None, + encoder_output: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - if projected_encoder_output is None: - if encoder_output is None: - raise ValueError("Either encoder_output or projected_encoder_output must be provided.") - projected_encoder_output = self.encoder_projector(encoder_output) - joint_output = self.activation(projected_encoder_output + decoder_output) + joint_output = self.activation(encoder_output + decoder_output) logits = self.head(joint_output) token_logits = logits[..., : self.vocab_size] duration_logits = logits[..., self.vocab_size :] @@ -858,24 +850,19 @@ def forward( @dataclass class ParakeetTDTGenerateOutput(ModelOutput): """ - Outputs of Parakeet TDT model generation. + Outputs of Parakeet TDT generation. Args: sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. + Generated token sequences. token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level timestamps in seconds indicating when each token was emitted. Only returned when - `return_timestamps=True` is passed to `generate()`. + Per-token frame indices. Returned when `return_timestamps=True`. token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Token-level durations in frames indicating how many frames each token spans. Only returned when - `return_timestamps=True` is passed to `generate()`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape - `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions from the encoder. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple of tuples (one element for each layer of the encoder) of `torch.FloatTensor` of shape - `(batch_size, sequence_length, hidden_size)`. 
Hidden states from the encoder. + Per-token durations in frames. Returned when `return_timestamps=True`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*): + Encoder attention weights per layer. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*): + Encoder hidden states per layer. """ sequences: torch.LongTensor @@ -888,26 +875,30 @@ class ParakeetTDTGenerateOutput(ModelOutput): @dataclass class ParakeetTDTOutput(ModelOutput): """ - Output structure for Parakeet TDT forward pass. + Output of the Parakeet TDT forward pass. Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Last hidden state from the encoder. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Hidden states from the encoder. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Attention mask for the encoder. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size + num_durations)`, *optional*): - Joint token and duration logits computed from the encoder and decoder outputs. Only returned when `labels` are provided to the forward pass. loss (`torch.FloatTensor`, *optional*): - The loss computed from the TDT loss function. Only returned when `labels` are provided to the forward pass. + TDT loss, returned when `labels` are provided. + logits (`torch.FloatTensor`): + Joint token and duration logits. Shape is `(batch, T, U+1, vocab+durations)` for training + or `(batch, 1, 1, vocab+durations)` for single-step inference. + encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): + Encoder outputs with `pooler_output` containing projected hidden states. + decoder_output (`torch.FloatTensor`, *optional*): + Decoder LSTM output, reused during blank-skipping in generation. + decoder_hidden_state (`torch.FloatTensor`, *optional*): + Decoder LSTM hidden state. + decoder_cell_state (`torch.FloatTensor`, *optional*): + Decoder LSTM cell state. 
""" - last_hidden_state: torch.Tensor - hidden_states: tuple[tuple[torch.FloatTensor]] | None = None - attentions: tuple[tuple[torch.FloatTensor]] | None = None loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None + encoder_outputs: "ParakeetEncoderModelOutput | None" = None + decoder_output: torch.FloatTensor | None = None + decoder_hidden_state: torch.FloatTensor | None = None + decoder_cell_state: torch.FloatTensor | None = None @auto_docstring( @@ -922,22 +913,56 @@ class ParakeetForTDT(ParakeetPreTrainedModel): def __init__(self, config: ParakeetTDTConfig): super().__init__(config) self.encoder = AutoModel.from_config(config.encoder_config) + self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) self.loss_function = tdt_loss self.post_init() + def get_audio_features( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> ParakeetEncoderModelOutput: + encoder_outputs = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + **kwargs, + ) + encoder_outputs.pooler_output = self.encoder_projector(encoder_outputs.last_hidden_state) + return encoder_outputs + @auto_docstring @can_return_tuple def forward( self, - input_features: torch.Tensor, + input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, + input_ids: torch.LongTensor | None = None, + encoder_outputs: ParakeetEncoderModelOutput | None = None, + encoder_frame_ids: torch.LongTensor | None = None, + decoder_output: torch.Tensor | None = None, + decoder_hidden_state: torch.Tensor | None = None, + decoder_cell_state: torch.Tensor | None = None, labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: r""" + input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Decoder input token ids for single-step inference. + encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): + Pre-computed encoder outputs with `pooler_output` containing projected hidden states. + encoder_frame_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Encoder frame indices for the joint network during generation. + decoder_output (`torch.Tensor`, *optional*): + Pre-computed decoder LSTM output, reused during blank-skipping. + decoder_hidden_state (`torch.Tensor`, *optional*): + Decoder LSTM hidden state from a previous step. + decoder_cell_state (`torch.Tensor`, *optional*): + Decoder LSTM cell state from a previous step. + Example: ```python @@ -955,32 +980,58 @@ def forward( >>> outputs = model(**inputs) ``` """ - if labels is not None: - kwargs.setdefault("output_attention_mask", True) - encoder_outputs = self.encoder( - input_features=input_features, - attention_mask=attention_mask, - **kwargs, - ) + # 1. 
Encode + project + if encoder_outputs is None: + if input_features is None: + raise ValueError("Either `input_features` or `encoder_outputs` must be provided.") + if labels is not None: + kwargs.setdefault("output_attention_mask", True) + encoder_outputs = self.get_audio_features( + input_features=input_features, + attention_mask=attention_mask, + **kwargs, + ) + projected_encoder_output = encoder_outputs.pooler_output - loss, logits = None, None if labels is not None: - encoder_lengths = encoder_outputs.attention_mask.sum(-1) - - # Prepare labels for TDT loss - target_lengths = (labels != self.config.pad_token_id).sum(-1) - - # Get joint decoder outputs + # for training: [blank, labels...] for training blank_tokens = torch.full( (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device ) - decoder_input = torch.cat([blank_tokens, labels], dim=1) - decoder_output, _, _ = self.decoder(decoder_input) - token_logits, duration_logits = self.joint( - decoder_output=decoder_output.unsqueeze(1), - encoder_output=encoder_outputs.last_hidden_state.unsqueeze(2), + input_ids = torch.cat([blank_tokens, labels], dim=1) + elif input_ids is None and decoder_output is None: + # for inference: start with blank token if not provided + input_ids = torch.full( + (projected_encoder_output.shape[0], 1), + self.config.blank_token_id, + dtype=torch.long, + device=projected_encoder_output.device, ) + if decoder_output is None: + decoder_output, decoder_hidden_state, decoder_cell_state = self.decoder( + input_ids, decoder_hidden_state, decoder_cell_state + ) + + if encoder_frame_ids is not None: + batch_indices = torch.arange(projected_encoder_output.shape[0], device=projected_encoder_output.device) + safe_frame_ids = torch.clamp(encoder_frame_ids, max=projected_encoder_output.shape[1] - 1) + encoder_for_joint = projected_encoder_output[batch_indices, safe_frame_ids].unsqueeze(1) + decoder_for_joint = decoder_output + else: + encoder_for_joint = projected_encoder_output.unsqueeze(2) + decoder_for_joint = decoder_output.unsqueeze(1) + + token_logits, duration_logits = self.joint( + decoder_output=decoder_for_joint, + encoder_output=encoder_for_joint, + ) + logits = torch.cat([token_logits, duration_logits], dim=-1) + + loss = None + if labels is not None: + encoder_lengths = encoder_outputs.attention_mask.sum(-1) + target_lengths = (labels != self.config.pad_token_id).sum(-1) loss = self.loss_function( token_logits=token_logits.float(), duration_logits=duration_logits.float(), @@ -991,14 +1042,14 @@ def forward( durations=self.config.durations, reduction="mean", ) - logits = torch.cat([token_logits, duration_logits], dim=-1) return ParakeetTDTOutput( loss=loss, logits=logits, - last_hidden_state=encoder_outputs.last_hidden_state, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, + encoder_outputs=encoder_outputs, + decoder_output=decoder_output, + decoder_hidden_state=decoder_hidden_state, + decoder_cell_state=decoder_cell_state, ) @torch.no_grad() @@ -1045,39 +1096,35 @@ def generate( >>> print("Timestamped tokens:", decoded_timestamps) ``` """ - kwargs["return_dict"] = True if return_timestamps: return_dict_in_generate = True - outputs = self.forward( + + # Initial forward: encode + blank prediction + outputs: ParakeetTDTOutput = self.forward( input_features=input_features, attention_mask=attention_mask, + return_dict=True, **kwargs, ) + encoder_outputs = outputs.encoder_outputs + batch_size, sequence_length = 
encoder_outputs.pooler_output.shape[:2] + device = encoder_outputs.pooler_output.device - # greedy TDT decoding, `GreedyBatchedTDTLabelLoopingComputer.torch_impl` in NeMo - encoder_hidden_states = outputs.last_hidden_state - batch_size, sequence_length = encoder_hidden_states.shape[:2] - device = encoder_hidden_states.device if attention_mask is not None: encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) valid_lengths = encoder_attention_mask.sum(dim=1).int() else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) + decoder_output = outputs.decoder_output + decoder_hidden_state = outputs.decoder_hidden_state + decoder_cell_state = outputs.decoder_cell_state - # Initialization - prev_tokens = torch.full((batch_size, 1), self.config.blank_token_id, dtype=torch.long, device=device) - decoder_output, hidden_state, cell_state = self.decoder(prev_tokens) - decoder_output = decoder_output.to(device) - hidden_state = hidden_state.to(device) - cell_state = cell_state.to(device) - - batch_indices = torch.arange(batch_size, device=device) + vocab_size = self.config.vocab_size time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) time_indices_current_labels = torch.zeros(batch_size, dtype=torch.long, device=device) active_mask = time_indices < valid_lengths active_mask_prev = torch.zeros_like(active_mask) - zeros_symbols = torch.zeros(batch_size, dtype=torch.long, device=device) symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) max_output_len = sequence_length * self.config.max_symbols_per_step @@ -1089,48 +1136,44 @@ def generate( all_frame_indices = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) all_durations_tensor = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) - # separately call encoder projection to avoid redundant computation inside loop - projected_encoder_output = self.joint.encoder_projector(encoder_hidden_states).to(device) - while active_mask.any(): active_mask_prev.copy_(active_mask) - safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) - projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) - token_logits, duration_logits = self.joint( - decoder_output, - projected_encoder_output=projected_encoder_frames, + outputs = self.forward( + encoder_outputs=encoder_outputs, + encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), + decoder_output=decoder_output, + decoder_hidden_state=decoder_hidden_state, + decoder_cell_state=decoder_cell_state, + return_dict=True, ) - token_logits = token_logits.squeeze(1).to(device) - duration_logits = duration_logits.squeeze(1).to(device) - tokens = token_logits.argmax(dim=-1) - durations = duration_logits.argmax(dim=-1) + logits = outputs.logits.squeeze(1) + tokens = logits[..., :vocab_size].argmax(dim=-1) + durations = logits[..., vocab_size:].argmax(dim=-1) - # Force blank duration >= 1 to guarantee forward progress blank_mask = active_mask_prev & (tokens == self.config.blank_token_id) - durations = durations.masked_fill(blank_mask & (durations == 0), 1) + durations = durations.masked_fill(blank_mask & (durations == 0), 1) # ensure forward progress - # Save pre-advance position for timestamp recording time_indices_current_labels.copy_(time_indices) - - # Advance time for all active elements time_indices = 
time_indices + durations.masked_fill(~active_mask, 0) - safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) active_mask = time_indices < valid_lengths advance_mask = active_mask & blank_mask - # Inner loop: skip past consecutive blanks to find non-blank + # Skip consecutive blanks while advance_mask.any(): time_indices_current_labels = torch.where(advance_mask, time_indices, time_indices_current_labels) - projected_encoder_frames = projected_encoder_output[batch_indices, safe_time_indices].unsqueeze(1) - token_logits, duration_logits = self.joint( - decoder_output, projected_encoder_output=projected_encoder_frames + outputs = self.forward( + encoder_outputs=encoder_outputs, + encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), + decoder_output=decoder_output, + decoder_hidden_state=decoder_hidden_state, + decoder_cell_state=decoder_cell_state, + return_dict=True, ) - token_logits = token_logits.squeeze(1).to(device) - duration_logits = duration_logits.squeeze(1).to(device) - more_tokens = token_logits.argmax(dim=-1) - more_durations = duration_logits.argmax(dim=-1) + logits = outputs.logits.squeeze(1) + more_tokens = logits[..., :vocab_size].argmax(dim=-1) + more_durations = logits[..., vocab_size:].argmax(dim=-1) tokens = torch.where(advance_mask, more_tokens, tokens) durations = torch.where(advance_mask, more_durations, durations) @@ -1138,11 +1181,9 @@ def generate( durations = durations.masked_fill(blank_mask & (durations == 0), 1) time_indices = torch.where(advance_mask, time_indices + durations, time_indices) - safe_time_indices = torch.clamp(time_indices, max=sequence_length - 1) active_mask = time_indices < valid_lengths advance_mask = active_mask & blank_mask - # Record results for non-blank tokens found emit_mask = active_mask_prev & (tokens != self.config.blank_token_id) emit_indices = token_counts[emit_mask] all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] @@ -1151,22 +1192,24 @@ def generate( all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] token_counts += emit_mask.long() - new_decoder_output, new_hidden_state, new_cell_state = self.decoder( - tokens.unsqueeze(1), hidden_state, cell_state + # Update decoder state for emitted tokens + outputs = self.forward( + input_ids=tokens.unsqueeze(1), + encoder_outputs=encoder_outputs, + encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), + decoder_hidden_state=decoder_hidden_state, + decoder_cell_state=decoder_cell_state, + return_dict=True, ) - new_decoder_output = new_decoder_output.to(device) - new_hidden_state = new_hidden_state.to(device) - new_cell_state = new_cell_state.to(device) - emit_mask_expanded = emit_mask.view(batch_size, 1, 1) - decoder_output = torch.where(emit_mask_expanded, new_decoder_output, decoder_output) emit_mask_state = emit_mask.view(1, batch_size, 1) - hidden_state = torch.where(emit_mask_state, new_hidden_state, hidden_state) - cell_state = torch.where(emit_mask_state, new_cell_state, cell_state) + decoder_hidden_state = torch.where(emit_mask_state, outputs.decoder_hidden_state, decoder_hidden_state) + decoder_cell_state = torch.where(emit_mask_state, outputs.decoder_cell_state, decoder_cell_state) + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_output = torch.where(emit_mask_expanded, outputs.decoder_output, decoder_output) - # Track symbols emitted per time step; force advance when max_symbols reached time_changed = time_indices_current_labels != last_label_time - symbols_per_step = 
torch.where(time_changed, zeros_symbols, symbols_per_step) + symbols_per_step = torch.where(time_changed, 0, symbols_per_step) symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) last_label_time = torch.where(emit_mask, time_indices_current_labels, last_label_time) force_advance = active_mask & (symbols_per_step >= self.config.max_symbols_per_step) @@ -1174,7 +1217,6 @@ def generate( symbols_per_step = symbols_per_step.masked_fill(force_advance, 0) active_mask = time_indices < valid_lengths - # Guard against edge case where no tokens were decoded (e.g. silent audio) max_len = max(token_counts.max().item(), 1) sequences = all_tokens_tensor[:, :max_len] token_timestamps, token_durations = None, None @@ -1187,8 +1229,8 @@ def generate( sequences=sequences, token_timestamps=token_timestamps, token_durations=token_durations, - attentions=outputs.attentions, - hidden_states=outputs.hidden_states, + attentions=encoder_outputs.attentions, + hidden_states=encoder_outputs.hidden_states, ) return sequences From 881233fd746f1b53c97f79c3bfe39b76476f56f0 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 24 Mar 2026 16:20:45 +0100 Subject: [PATCH 33/67] compile option for generation and decoder cache --- .../models/parakeet/configuration_parakeet.py | 4 +- .../models/parakeet/modeling_parakeet.py | 229 +++++++++++------- .../models/parakeet/modular_parakeet.py | 221 ++++++++++------- 3 files changed, 281 insertions(+), 173 deletions(-) diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 4e92698ba35e..2172ac924f07 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -187,7 +187,7 @@ class ParakeetTDTConfig(PreTrainedConfig): r""" decoder_hidden_size (`int`, *optional*, defaults to 640): Hidden size of the LSTM prediction network and joint network. - num_decoder_layers (`int`, *optional*, defaults to 1): + num_decoder_layers (`int`, *optional*, defaults to 2): Number of LSTM layers in the prediction network. num_duration_bins (`int`, *optional*, defaults to 5): Number of duration bins for predicting token durations. @@ -223,7 +223,7 @@ def __init__( self, vocab_size=8193, decoder_hidden_size=640, - num_decoder_layers=1, + num_decoder_layers=2, durations=[0, 1, 2, 3, 4], hidden_act="relu", max_symbols_per_step=10, diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 203e75ae11b0..84cf31b07782 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -27,12 +27,20 @@ from ... 
import initialization as init from ...activations import ACT2FN +from ...generation import CompileConfig from ...integrations import use_kernel_func_from_hub, use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils import ( + ModelOutput, + TransformersKwargs, + auto_docstring, + can_return_tuple, + is_torchdynamo_compiling, + logging, +) from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -822,6 +830,69 @@ def generate( return sequences +class ParakeetTDTDecoderCache: + def __init__(self): + self.cache: torch.Tensor | None = None + self.hidden_state: torch.Tensor | None = None + self.cell_state: torch.Tensor | None = None + self.is_initialized: bool = False + + def lazy_initialization(self, hidden_states, lstm_module): + self.cache = torch.zeros( + hidden_states.shape[0], 1, lstm_module.hidden_size, device=hidden_states.device, dtype=hidden_states.dtype + ) + self.hidden_state = torch.zeros( + lstm_module.num_layers, + hidden_states.shape[0], + lstm_module.hidden_size, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + self.cell_state = torch.zeros( + lstm_module.num_layers, + hidden_states.shape[0], + lstm_module.hidden_size, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if not is_torchdynamo_compiling(): + torch._dynamo.mark_static_address(self.cache) + torch._dynamo.mark_static_address(self.hidden_state) + torch._dynamo.mark_static_address(self.cell_state) + + self.is_initialized = True + + def update( + self, + decoder_output, + hidden_state, + cell_state, + lstm_module=None, + mask=None, + ): + if not self.is_initialized and lstm_module is not None: + self.lazy_initialization(decoder_output, lstm_module) + elif not self.is_initialized: + raise ValueError( + "ParakeetTDTDecoderCache is not initialized. Make sure to provide lstm_module to the update method." 
+ ) + + if mask is None: + self.hidden_state.copy_(hidden_state) + self.cell_state.copy_(cell_state) + self.cache.copy_(decoder_output) + else: + # Mask to update specific batch elements + mask = mask.to(decoder_output.device) + batch_size = decoder_output.shape[0] + mask_h = mask.view(1, batch_size, 1) + mask_d = mask.view(batch_size, 1, 1) + self.cache = torch.where(mask_d, decoder_output, self.cache) + self.hidden_state = torch.where(mask_h, hidden_state, self.hidden_state) + self.cell_state = torch.where(mask_h, cell_state, self.cell_state) + + class ParakeetTDTDecoder(nn.Module): """LSTM-based prediction network for TDT.""" @@ -840,17 +911,23 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, input_ids: torch.LongTensor, - hidden_state: torch.Tensor | None = None, - cell_state: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + decoder_cache: ParakeetTDTDecoderCache | None = None, + decoder_cache_update_mask: torch.BoolTensor | None = None, + ) -> torch.Tensor: input_ids = input_ids.to(self.decoder_projector.weight.device) hidden_cell_states = ( - (hidden_state, cell_state) if hidden_state is not None and cell_state is not None else None + (decoder_cache.hidden_state, decoder_cache.cell_state) + if decoder_cache is not None and decoder_cache.is_initialized + else None ) embeddings = self.embedding(input_ids) lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states) decoder_output = self.decoder_projector(lstm_output) - return decoder_output, hidden_state, cell_state + if decoder_cache is not None: + decoder_cache.update( + decoder_output, hidden_state, cell_state, lstm_module=self.lstm, mask=decoder_cache_update_mask + ) + return decoder_output class ParakeetTDTJointNetwork(nn.Module): @@ -912,20 +989,15 @@ class ParakeetTDTOutput(ModelOutput): or `(batch, 1, 1, vocab+durations)` for single-step inference. encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): Encoder outputs with `pooler_output` containing projected hidden states. - decoder_output (`torch.FloatTensor`, *optional*): - Decoder LSTM output, reused during blank-skipping in generation. - decoder_hidden_state (`torch.FloatTensor`, *optional*): - Decoder LSTM hidden state. - decoder_cell_state (`torch.FloatTensor`, *optional*): - Decoder LSTM cell state. + decoder_cache (`ParakeetTDTDecoderCache`, *optional*): + Decoder LSTM cache containing hidden state, cell state, and decoder output. + Updated in-place during generation. """ loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None encoder_outputs: "ParakeetEncoderModelOutput | None" = None - decoder_output: torch.FloatTensor | None = None - decoder_hidden_state: torch.FloatTensor | None = None - decoder_cell_state: torch.FloatTensor | None = None + decoder_cache: ParakeetTDTDecoderCache | None = None # TODO (ebezzam) eventually move to audio_utils or loss_utils for common usage? 
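A minimal sketch of the cache contract introduced above, using a plain `nn.LSTM` as a stand-in for the prediction network. The hidden size, batch size, mask values, and the import path are illustrative assumptions; the point is that the first unmasked `update` lazily allocates the buffers from the LSTM geometry, while later masked updates only overwrite rows for samples that emitted a token.

```python
import torch
from torch import nn

# Assumed import path once this patch lands; the class is defined in this file.
from transformers.models.parakeet.modeling_parakeet import ParakeetTDTDecoderCache

hidden = 640  # hypothetical decoder hidden size
batch_size = 4
lstm = nn.LSTM(input_size=hidden, hidden_size=hidden, num_layers=2, batch_first=True)

cache = ParakeetTDTDecoderCache()

# Step 1: no previous state; the unmasked update lazily allocates cache, hidden_state, cell_state.
out, (h, c) = lstm(torch.randn(batch_size, 1, hidden))
cache.update(out, h, c, lstm_module=lstm)

# Step 2: feed the cached state back in, then write back only rows that emitted a non-blank token.
out, (h, c) = lstm(torch.randn(batch_size, 1, hidden), (cache.hidden_state, cache.cell_state))
emit_mask = torch.tensor([True, False, True, False])  # hypothetical per-sample emit decisions
cache.update(out, h, c, mask=emit_mask)  # rows where emit_mask is False keep their previous state
```

This masked write is what lets the generation loop below keep a single decoder call per step while samples that predicted blank hold on to their previous prediction-network state.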
@@ -1098,9 +1170,9 @@ def forward( input_ids: torch.LongTensor | None = None, encoder_outputs: ParakeetEncoderModelOutput | None = None, encoder_frame_ids: torch.LongTensor | None = None, - decoder_output: torch.Tensor | None = None, - decoder_hidden_state: torch.Tensor | None = None, - decoder_cell_state: torch.Tensor | None = None, + decoder_cache: ParakeetTDTDecoderCache | None = None, + decoder_cache_update_mask: torch.BoolTensor | None = None, + use_decoder_cache: bool | None = None, labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: @@ -1111,12 +1183,18 @@ def forward( Pre-computed encoder outputs with `pooler_output` containing projected hidden states. encoder_frame_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Encoder frame indices for the joint network during generation. - decoder_output (`torch.Tensor`, *optional*): - Pre-computed decoder LSTM output, reused during blank-skipping. - decoder_hidden_state (`torch.Tensor`, *optional*): - Decoder LSTM hidden state from a previous step. - decoder_cell_state (`torch.Tensor`, *optional*): - Decoder LSTM cell state from a previous step. + decoder_cache (`ParakeetTDTDecoderCache`, *optional*): + Decoder LSTM cache. When provided and initialized, the cached `decoder_output` is reused + (e.g. during blank-skipping) instead of running the decoder. When `input_ids` is provided, + the decoder runs and the cache is updated in-place. + decoder_cache_update_mask (`torch.BoolTensor` of shape `(batch_size,)`, *optional*): + Boolean mask controlling which batch elements have their decoder cache updated. + When provided, only elements where the mask is `True` are written to the cache; + other elements retain their previous cached state. Used during generation to + preserve cache for samples that predicted blank tokens. + use_decoder_cache (`bool`, *optional*): + Whether to use a decoder cache. When `True` and `decoder_cache` is `None`, a new cache + is created automatically during the forward pass. 
Example: @@ -1154,7 +1232,7 @@ def forward( (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device ) input_ids = torch.cat([blank_tokens, labels], dim=1) - elif input_ids is None and decoder_output is None: + elif input_ids is None and decoder_cache is None: # for inference: start with blank token if not provided input_ids = torch.full( (projected_encoder_output.shape[0], 1), @@ -1163,10 +1241,15 @@ def forward( device=projected_encoder_output.device, ) - if decoder_output is None: - decoder_output, decoder_hidden_state, decoder_cell_state = self.decoder( - input_ids, decoder_hidden_state, decoder_cell_state - ) + if use_decoder_cache and decoder_cache is None: + decoder_cache = ParakeetTDTDecoderCache() + + # Run decoder if we have input_ids (initial step or after emitting a token) + if input_ids is not None: + decoder_output = self.decoder(input_ids, decoder_cache, decoder_cache_update_mask) + else: + # Reuse cached decoder_output (blank-skipping path) + decoder_output = decoder_cache.cache if encoder_frame_ids is not None: batch_indices = torch.arange(projected_encoder_output.shape[0], device=projected_encoder_output.device) @@ -1202,9 +1285,7 @@ def forward( loss=loss, logits=logits, encoder_outputs=encoder_outputs, - decoder_output=decoder_output, - decoder_hidden_state=decoder_hidden_state, - decoder_cell_state=decoder_cell_state, + decoder_cache=decoder_cache, ) @torch.no_grad() @@ -1214,6 +1295,7 @@ def generate( attention_mask: torch.Tensor | None = None, return_timestamps: bool = False, return_dict_in_generate: bool = False, + compile_config: CompileConfig | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTGenerateOutput | torch.LongTensor: r""" @@ -1223,6 +1305,8 @@ def generate( return_timestamps (`bool`, *optional*, defaults to `False`): Whether to return per-token timestamps and durations. When `True`, forces `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. + compile_config ([`~generation.CompileConfig`], *optional*): + If provided, `torch.compile` will be applied to the forward calls in the decoding loop. Example: @@ -1254,92 +1338,71 @@ def generate( if return_timestamps: return_dict_in_generate = True - # Initial forward: encode + blank prediction - outputs: ParakeetTDTOutput = self.forward( + model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__ + + # Initial forward: encode + decoder initialization + outputs = model_forward( input_features=input_features, attention_mask=attention_mask, + use_decoder_cache=True, return_dict=True, **kwargs, ) encoder_outputs = outputs.encoder_outputs + decoder_cache = outputs.decoder_cache batch_size, sequence_length = encoder_outputs.pooler_output.shape[:2] device = encoder_outputs.pooler_output.device + # TODO use encoder attention mask like in loss computation? 
if attention_mask is not None: encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) valid_lengths = encoder_attention_mask.sum(dim=1).int() else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) - decoder_output = outputs.decoder_output - decoder_hidden_state = outputs.decoder_hidden_state - decoder_cell_state = outputs.decoder_cell_state - vocab_size = self.config.vocab_size time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) time_indices_current_labels = torch.zeros(batch_size, dtype=torch.long, device=device) active_mask = time_indices < valid_lengths - active_mask_prev = torch.zeros_like(active_mask) - symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) max_output_len = sequence_length * self.config.max_symbols_per_step all_tokens_tensor = torch.full( (batch_size, max_output_len), self.config.pad_token_id, dtype=torch.long, device=device ) + tokens = torch.zeros(batch_size, dtype=torch.long, device=device) + durations = torch.zeros(batch_size, dtype=torch.long, device=device) token_counts = torch.zeros(batch_size, dtype=torch.long, device=device) if return_timestamps: all_frame_indices = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) all_durations_tensor = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) while active_mask.any(): - active_mask_prev.copy_(active_mask) + active_at_start = active_mask.clone() - outputs = self.forward( + time_indices_current_labels = torch.where(active_at_start, time_indices, time_indices_current_labels) + outputs = model_forward( encoder_outputs=encoder_outputs, encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_output=decoder_output, - decoder_hidden_state=decoder_hidden_state, - decoder_cell_state=decoder_cell_state, + decoder_cache=decoder_cache, return_dict=True, ) logits = outputs.logits.squeeze(1) - tokens = logits[..., :vocab_size].argmax(dim=-1) - durations = logits[..., vocab_size:].argmax(dim=-1) + tokens = torch.where(active_at_start, logits[..., : self.config.vocab_size].argmax(dim=-1), tokens) + durations = torch.where(active_at_start, logits[..., self.config.vocab_size :].argmax(dim=-1), durations) - blank_mask = active_mask_prev & (tokens == self.config.blank_token_id) + blank_mask = active_at_start & (tokens == self.config.blank_token_id) durations = durations.masked_fill(blank_mask & (durations == 0), 1) # ensure forward progress - time_indices_current_labels.copy_(time_indices) - time_indices = time_indices + durations.masked_fill(~active_mask, 0) + # Advance time for all active samples + time_indices = time_indices + durations.masked_fill(~active_at_start, 0) active_mask = time_indices < valid_lengths - advance_mask = active_mask & blank_mask - - # Skip consecutive blanks - while advance_mask.any(): - time_indices_current_labels = torch.where(advance_mask, time_indices, time_indices_current_labels) - - outputs = self.forward( - encoder_outputs=encoder_outputs, - encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_output=decoder_output, - decoder_hidden_state=decoder_hidden_state, - decoder_cell_state=decoder_cell_state, - return_dict=True, - ) - logits = outputs.logits.squeeze(1) - more_tokens = logits[..., :vocab_size].argmax(dim=-1) - more_durations = logits[..., vocab_size:].argmax(dim=-1) - tokens = 
torch.where(advance_mask, more_tokens, tokens) - durations = torch.where(advance_mask, more_durations, durations) - blank_mask = tokens == self.config.blank_token_id - durations = durations.masked_fill(blank_mask & (durations == 0), 1) - - time_indices = torch.where(advance_mask, time_indices + durations, time_indices) - active_mask = time_indices < valid_lengths - advance_mask = active_mask & blank_mask + # If all remaining active samples predicted blank, skip emit + decoder update + emit_mask = active_at_start & ~blank_mask + if not emit_mask.any(): + continue - emit_mask = active_mask_prev & (tokens != self.config.blank_token_id) + # Emit non-blank tokens emit_indices = token_counts[emit_mask] all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] if return_timestamps: @@ -1347,22 +1410,16 @@ def generate( all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] token_counts += emit_mask.long() - # Update decoder state for emitted tokens - outputs = self.forward( + # Run decoder for emitted tokens — only update cache for samples that emitted + model_forward( input_ids=tokens.unsqueeze(1), encoder_outputs=encoder_outputs, encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_hidden_state=decoder_hidden_state, - decoder_cell_state=decoder_cell_state, + decoder_cache=decoder_cache, + decoder_cache_update_mask=emit_mask, return_dict=True, ) - emit_mask_state = emit_mask.view(1, batch_size, 1) - decoder_hidden_state = torch.where(emit_mask_state, outputs.decoder_hidden_state, decoder_hidden_state) - decoder_cell_state = torch.where(emit_mask_state, outputs.decoder_cell_state, decoder_cell_state) - emit_mask_expanded = emit_mask.view(batch_size, 1, 1) - decoder_output = torch.where(emit_mask_expanded, outputs.decoder_output, decoder_output) - time_changed = time_indices_current_labels != last_label_time symbols_per_step = torch.where(time_changed, 0, symbols_per_step) symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 466db46b1533..71ed104f1a44 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -22,6 +22,7 @@ from ... 
import initialization as init from ...activations import ACT2FN +from ...generation import CompileConfig from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -31,6 +32,7 @@ TransformersKwargs, auto_docstring, can_return_tuple, + is_torchdynamo_compiling, logging, ) from ...utils.generic import maybe_autocast, merge_with_config_defaults @@ -667,6 +669,69 @@ def generate( return sequences +class ParakeetTDTDecoderCache: + def __init__(self): + self.cache: torch.Tensor | None = None + self.hidden_state: torch.Tensor | None = None + self.cell_state: torch.Tensor | None = None + self.is_initialized: bool = False + + def lazy_initialization(self, hidden_states, lstm_module): + self.cache = torch.zeros( + hidden_states.shape[0], 1, lstm_module.hidden_size, device=hidden_states.device, dtype=hidden_states.dtype + ) + self.hidden_state = torch.zeros( + lstm_module.num_layers, + hidden_states.shape[0], + lstm_module.hidden_size, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + self.cell_state = torch.zeros( + lstm_module.num_layers, + hidden_states.shape[0], + lstm_module.hidden_size, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if not is_torchdynamo_compiling(): + torch._dynamo.mark_static_address(self.cache) + torch._dynamo.mark_static_address(self.hidden_state) + torch._dynamo.mark_static_address(self.cell_state) + + self.is_initialized = True + + def update( + self, + decoder_output, + hidden_state, + cell_state, + lstm_module=None, + mask=None, + ): + if not self.is_initialized and lstm_module is not None: + self.lazy_initialization(decoder_output, lstm_module) + elif not self.is_initialized: + raise ValueError( + "ParakeetTDTDecoderCache is not initialized. Make sure to provide lstm_module to the update method." 
+ ) + + if mask is None: + self.hidden_state.copy_(hidden_state) + self.cell_state.copy_(cell_state) + self.cache.copy_(decoder_output) + else: + # Mask to update specific batch elements + mask = mask.to(decoder_output.device) + batch_size = decoder_output.shape[0] + mask_h = mask.view(1, batch_size, 1) + mask_d = mask.view(batch_size, 1, 1) + self.cache = torch.where(mask_d, decoder_output, self.cache) + self.hidden_state = torch.where(mask_h, hidden_state, self.hidden_state) + self.cell_state = torch.where(mask_h, cell_state, self.cell_state) + + class ParakeetTDTDecoder(nn.Module): """LSTM-based prediction network for TDT.""" @@ -685,17 +750,23 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, input_ids: torch.LongTensor, - hidden_state: torch.Tensor | None = None, - cell_state: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + decoder_cache: ParakeetTDTDecoderCache | None = None, + decoder_cache_update_mask: torch.BoolTensor | None = None, + ) -> torch.Tensor: input_ids = input_ids.to(self.decoder_projector.weight.device) hidden_cell_states = ( - (hidden_state, cell_state) if hidden_state is not None and cell_state is not None else None + (decoder_cache.hidden_state, decoder_cache.cell_state) + if decoder_cache is not None and decoder_cache.is_initialized + else None ) embeddings = self.embedding(input_ids) lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states) decoder_output = self.decoder_projector(lstm_output) - return decoder_output, hidden_state, cell_state + if decoder_cache is not None: + decoder_cache.update( + decoder_output, hidden_state, cell_state, lstm_module=self.lstm, mask=decoder_cache_update_mask + ) + return decoder_output # TODO (ebezzam) eventually move to audio_utils or loss_utils for common usage? @@ -885,20 +956,15 @@ class ParakeetTDTOutput(ModelOutput): or `(batch, 1, 1, vocab+durations)` for single-step inference. encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): Encoder outputs with `pooler_output` containing projected hidden states. - decoder_output (`torch.FloatTensor`, *optional*): - Decoder LSTM output, reused during blank-skipping in generation. - decoder_hidden_state (`torch.FloatTensor`, *optional*): - Decoder LSTM hidden state. - decoder_cell_state (`torch.FloatTensor`, *optional*): - Decoder LSTM cell state. + decoder_cache (`ParakeetTDTDecoderCache`, *optional*): + Decoder LSTM cache containing hidden state, cell state, and decoder output. + Updated in-place during generation. 
""" loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None encoder_outputs: "ParakeetEncoderModelOutput | None" = None - decoder_output: torch.FloatTensor | None = None - decoder_hidden_state: torch.FloatTensor | None = None - decoder_cell_state: torch.FloatTensor | None = None + decoder_cache: ParakeetTDTDecoderCache | None = None @auto_docstring( @@ -943,9 +1009,9 @@ def forward( input_ids: torch.LongTensor | None = None, encoder_outputs: ParakeetEncoderModelOutput | None = None, encoder_frame_ids: torch.LongTensor | None = None, - decoder_output: torch.Tensor | None = None, - decoder_hidden_state: torch.Tensor | None = None, - decoder_cell_state: torch.Tensor | None = None, + decoder_cache: ParakeetTDTDecoderCache | None = None, + decoder_cache_update_mask: torch.BoolTensor | None = None, + use_decoder_cache: bool | None = None, labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: @@ -956,12 +1022,18 @@ def forward( Pre-computed encoder outputs with `pooler_output` containing projected hidden states. encoder_frame_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Encoder frame indices for the joint network during generation. - decoder_output (`torch.Tensor`, *optional*): - Pre-computed decoder LSTM output, reused during blank-skipping. - decoder_hidden_state (`torch.Tensor`, *optional*): - Decoder LSTM hidden state from a previous step. - decoder_cell_state (`torch.Tensor`, *optional*): - Decoder LSTM cell state from a previous step. + decoder_cache (`ParakeetTDTDecoderCache`, *optional*): + Decoder LSTM cache. When provided and initialized, the cached `decoder_output` is reused + (e.g. during blank-skipping) instead of running the decoder. When `input_ids` is provided, + the decoder runs and the cache is updated in-place. + decoder_cache_update_mask (`torch.BoolTensor` of shape `(batch_size,)`, *optional*): + Boolean mask controlling which batch elements have their decoder cache updated. + When provided, only elements where the mask is `True` are written to the cache; + other elements retain their previous cached state. Used during generation to + preserve cache for samples that predicted blank tokens. + use_decoder_cache (`bool`, *optional*): + Whether to use a decoder cache. When `True` and `decoder_cache` is `None`, a new cache + is created automatically during the forward pass. 
Example: @@ -999,7 +1071,7 @@ def forward( (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device ) input_ids = torch.cat([blank_tokens, labels], dim=1) - elif input_ids is None and decoder_output is None: + elif input_ids is None and decoder_cache is None: # for inference: start with blank token if not provided input_ids = torch.full( (projected_encoder_output.shape[0], 1), @@ -1008,10 +1080,15 @@ def forward( device=projected_encoder_output.device, ) - if decoder_output is None: - decoder_output, decoder_hidden_state, decoder_cell_state = self.decoder( - input_ids, decoder_hidden_state, decoder_cell_state - ) + if use_decoder_cache and decoder_cache is None: + decoder_cache = ParakeetTDTDecoderCache() + + # Run decoder if we have input_ids (initial step or after emitting a token) + if input_ids is not None: + decoder_output = self.decoder(input_ids, decoder_cache, decoder_cache_update_mask) + else: + # Reuse cached decoder_output (blank-skipping path) + decoder_output = decoder_cache.cache if encoder_frame_ids is not None: batch_indices = torch.arange(projected_encoder_output.shape[0], device=projected_encoder_output.device) @@ -1047,9 +1124,7 @@ def forward( loss=loss, logits=logits, encoder_outputs=encoder_outputs, - decoder_output=decoder_output, - decoder_hidden_state=decoder_hidden_state, - decoder_cell_state=decoder_cell_state, + decoder_cache=decoder_cache, ) @torch.no_grad() @@ -1059,6 +1134,7 @@ def generate( attention_mask: torch.Tensor | None = None, return_timestamps: bool = False, return_dict_in_generate: bool = False, + compile_config: CompileConfig | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTGenerateOutput | torch.LongTensor: r""" @@ -1068,6 +1144,8 @@ def generate( return_timestamps (`bool`, *optional*, defaults to `False`): Whether to return per-token timestamps and durations. When `True`, forces `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. + compile_config ([`~generation.CompileConfig`], *optional*): + If provided, `torch.compile` will be applied to the forward calls in the decoding loop. Example: @@ -1099,92 +1177,71 @@ def generate( if return_timestamps: return_dict_in_generate = True - # Initial forward: encode + blank prediction - outputs: ParakeetTDTOutput = self.forward( + model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__ + + # Initial forward: encode + decoder initialization + outputs = model_forward( input_features=input_features, attention_mask=attention_mask, + use_decoder_cache=True, return_dict=True, **kwargs, ) encoder_outputs = outputs.encoder_outputs + decoder_cache = outputs.decoder_cache batch_size, sequence_length = encoder_outputs.pooler_output.shape[:2] device = encoder_outputs.pooler_output.device + # TODO use encoder attention mask like in loss computation? 
if attention_mask is not None: encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) valid_lengths = encoder_attention_mask.sum(dim=1).int() else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) - decoder_output = outputs.decoder_output - decoder_hidden_state = outputs.decoder_hidden_state - decoder_cell_state = outputs.decoder_cell_state - vocab_size = self.config.vocab_size time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) time_indices_current_labels = torch.zeros(batch_size, dtype=torch.long, device=device) active_mask = time_indices < valid_lengths - active_mask_prev = torch.zeros_like(active_mask) - symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) max_output_len = sequence_length * self.config.max_symbols_per_step all_tokens_tensor = torch.full( (batch_size, max_output_len), self.config.pad_token_id, dtype=torch.long, device=device ) + tokens = torch.zeros(batch_size, dtype=torch.long, device=device) + durations = torch.zeros(batch_size, dtype=torch.long, device=device) token_counts = torch.zeros(batch_size, dtype=torch.long, device=device) if return_timestamps: all_frame_indices = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) all_durations_tensor = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) while active_mask.any(): - active_mask_prev.copy_(active_mask) + active_at_start = active_mask.clone() - outputs = self.forward( + time_indices_current_labels = torch.where(active_at_start, time_indices, time_indices_current_labels) + outputs = model_forward( encoder_outputs=encoder_outputs, encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_output=decoder_output, - decoder_hidden_state=decoder_hidden_state, - decoder_cell_state=decoder_cell_state, + decoder_cache=decoder_cache, return_dict=True, ) logits = outputs.logits.squeeze(1) - tokens = logits[..., :vocab_size].argmax(dim=-1) - durations = logits[..., vocab_size:].argmax(dim=-1) + tokens = torch.where(active_at_start, logits[..., : self.config.vocab_size].argmax(dim=-1), tokens) + durations = torch.where(active_at_start, logits[..., self.config.vocab_size :].argmax(dim=-1), durations) - blank_mask = active_mask_prev & (tokens == self.config.blank_token_id) + blank_mask = active_at_start & (tokens == self.config.blank_token_id) durations = durations.masked_fill(blank_mask & (durations == 0), 1) # ensure forward progress - time_indices_current_labels.copy_(time_indices) - time_indices = time_indices + durations.masked_fill(~active_mask, 0) + # Advance time for all active samples + time_indices = time_indices + durations.masked_fill(~active_at_start, 0) active_mask = time_indices < valid_lengths - advance_mask = active_mask & blank_mask - - # Skip consecutive blanks - while advance_mask.any(): - time_indices_current_labels = torch.where(advance_mask, time_indices, time_indices_current_labels) - - outputs = self.forward( - encoder_outputs=encoder_outputs, - encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_output=decoder_output, - decoder_hidden_state=decoder_hidden_state, - decoder_cell_state=decoder_cell_state, - return_dict=True, - ) - logits = outputs.logits.squeeze(1) - more_tokens = logits[..., :vocab_size].argmax(dim=-1) - more_durations = logits[..., vocab_size:].argmax(dim=-1) - tokens = 
torch.where(advance_mask, more_tokens, tokens) - durations = torch.where(advance_mask, more_durations, durations) - blank_mask = tokens == self.config.blank_token_id - durations = durations.masked_fill(blank_mask & (durations == 0), 1) - - time_indices = torch.where(advance_mask, time_indices + durations, time_indices) - active_mask = time_indices < valid_lengths - advance_mask = active_mask & blank_mask + # If all remaining active samples predicted blank, skip emit + decoder update + emit_mask = active_at_start & ~blank_mask + if not emit_mask.any(): + continue - emit_mask = active_mask_prev & (tokens != self.config.blank_token_id) + # Emit non-blank tokens emit_indices = token_counts[emit_mask] all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] if return_timestamps: @@ -1192,22 +1249,16 @@ def generate( all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] token_counts += emit_mask.long() - # Update decoder state for emitted tokens - outputs = self.forward( + # Run decoder for emitted tokens — only update cache for samples that emitted + model_forward( input_ids=tokens.unsqueeze(1), encoder_outputs=encoder_outputs, encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_hidden_state=decoder_hidden_state, - decoder_cell_state=decoder_cell_state, + decoder_cache=decoder_cache, + decoder_cache_update_mask=emit_mask, return_dict=True, ) - emit_mask_state = emit_mask.view(1, batch_size, 1) - decoder_hidden_state = torch.where(emit_mask_state, outputs.decoder_hidden_state, decoder_hidden_state) - decoder_cell_state = torch.where(emit_mask_state, outputs.decoder_cell_state, decoder_cell_state) - emit_mask_expanded = emit_mask.view(batch_size, 1, 1) - decoder_output = torch.where(emit_mask_expanded, outputs.decoder_output, decoder_output) - time_changed = time_indices_current_labels != last_label_time symbols_per_step = torch.where(time_changed, 0, symbols_per_step) symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) From b41a8ee6ec3c29940e0b9b5bd09ecc29fa67e1e3 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 24 Mar 2026 20:28:13 +0100 Subject: [PATCH 34/67] Cleaner, better conventions. 
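The decoding-loop rewrite in the patch above folds the old inner blank-skipping loop into the single `while active_mask.any()` loop: blanks only advance `time_indices`, and the decoder cache is updated only for samples that emit a non-blank token. The following is a toy, single-utterance sketch of that control flow (the patch itself operates on batched, masked tensors); `joint` here is a made-up stand-in for the encoder-frame/decoder-state/joint-network stack, and all sizes are arbitrary:

```python
import torch

vocab_size, durations, blank_id = 8193, [0, 1, 2, 3, 4], 8192
max_symbols_per_step, num_frames = 10, 20

def joint(frame_idx: int, last_token: int) -> torch.Tensor:
    # stand-in for the real joint network: logits over vocab plus duration bins
    torch.manual_seed(frame_idx * 31 + last_token)
    return torch.randn(vocab_size + len(durations))

hypothesis, t, symbols_at_t = [], 0, 0
last_token = blank_id
while t < num_frames:
    logits = joint(t, last_token)
    token = int(logits[:vocab_size].argmax())
    # the patch reads the duration straight from the argmax index, which coincides with the
    # value because durations == [0, 1, 2, 3, 4]; mapping through the list is the general form
    dur = durations[int(logits[vocab_size:].argmax())]
    if token == blank_id:
        dur = max(dur, 1)  # a zero-duration blank would stall, force progress as in the patch
    else:
        hypothesis.append(token)  # in the real model this also updates the LSTM decoder cache
        last_token = token
        symbols_at_t += 1
        if symbols_at_t >= max_symbols_per_step:
            dur = max(dur, 1)  # cap emissions per frame, mirroring max_symbols_per_step
    if dur > 0:
        symbols_at_t = 0
    t += dur  # blanks just advance time; no separate skip loop is needed
print(hypothesis)
```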
--- .../models/lasr/configuration_lasr.py | 4 +-- src/transformers/models/lasr/modeling_lasr.py | 35 +++++++++++++------ src/transformers/models/lasr/modular_lasr.py | 22 ++++++++---- .../models/parakeet/modeling_parakeet.py | 26 +++++++------- .../models/parakeet/modular_parakeet.py | 26 +++++++------- 5 files changed, 67 insertions(+), 46 deletions(-) diff --git a/src/transformers/models/lasr/configuration_lasr.py b/src/transformers/models/lasr/configuration_lasr.py index b3f7e722c4f3..d7c040dc4cc5 100644 --- a/src/transformers/models/lasr/configuration_lasr.py +++ b/src/transformers/models/lasr/configuration_lasr.py @@ -22,7 +22,7 @@ from ...utils import auto_docstring -@auto_docstring(checkpoint="TODO") +@auto_docstring(checkpoint="google/medasr") class LasrEncoderConfig(PreTrainedConfig): r""" convolution_bias (`bool`, *optional*, defaults to `False`): @@ -124,7 +124,7 @@ def __init__( super().__init__(**kwargs) -@auto_docstring(checkpoint="TODO") +@auto_docstring(checkpoint="google/medasr") class LasrCTCConfig(PreTrainedConfig): r""" ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`): diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index 199686ee3d7d..df6eff9be010 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -29,7 +29,7 @@ from ...integrations import use_kernel_func_from_hub, use_kernelized_func from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -459,6 +459,17 @@ def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length return attention_mask +@dataclass +@auto_docstring( + custom_intro=""" + Extends [~modeling_outputs.BaseModelOutputWithPooling] to include the output attention mask since sequence length + is not preserved in the model's forward. + """ +) +class LasrEncoderModelOutput(BaseModelOutputWithPooling): + attention_mask: torch.Tensor | None = None + + @auto_docstring( custom_intro=""" The LasrEncoder model, based on the Conformer architecture](https://arxiv.org/abs/2005.08100). 
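The `LasrEncoderModelOutput` added above exists because convolutional subsampling shortens the sequence, so the input `attention_mask` no longer lines up with the encoder output. A minimal sketch of the pattern with a made-up class name; downstream heads recover valid lengths with a simple sum, which is exactly what the CTC loss change below does via `encoder_outputs.attention_mask.sum(-1)`:

```python
from dataclasses import dataclass

import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling


@dataclass
class EncoderOutputWithMask(BaseModelOutputWithPooling):  # illustrative name only
    attention_mask: torch.Tensor | None = None


out = EncoderOutputWithMask(
    last_hidden_state=torch.zeros(2, 10, 8),
    attention_mask=torch.tensor([[1] * 10, [1] * 6 + [0] * 4]),
)
valid_lengths = out.attention_mask.sum(-1)  # per-sample lengths after subsampling -> tensor([10, 6])
```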
@@ -493,8 +504,9 @@ def forward( self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, + output_attention_mask: bool | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutput: + ) -> LasrEncoderModelOutput: r""" Example: @@ -525,8 +537,10 @@ def forward( cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training) sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training) + output_mask = None if attention_mask is not None: - attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1]) + output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1]) + attention_mask = output_mask attention_mask = create_bidirectional_mask( config=self.config, @@ -552,7 +566,10 @@ def forward( hidden_states = self.out_norm(hidden_states) - return BaseModelOutput(last_hidden_state=hidden_states) + return LasrEncoderModelOutput( + last_hidden_state=hidden_states, + attention_mask=output_mask.int() if output_attention_mask and output_mask is not None else None, + ) @dataclass @@ -627,6 +644,8 @@ def forward( >>> print(outputs.loss) ```""" + if labels is not None: + kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, @@ -638,11 +657,7 @@ def forward( loss = None if labels is not None: - # retrieve loss input_lengths from attention_mask - attention_mask = ( - attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long) - ) - input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) + encoder_lengths = encoder_outputs.attention_mask.sum(-1) # assuming that padded tokens are filled with pad_token_id when not being attended to labels_mask = labels != self.config.pad_token_id @@ -656,7 +671,7 @@ def forward( loss = nn.functional.ctc_loss( log_probs, flattened_targets, - input_lengths, + encoder_lengths, target_lengths, blank=self.config.pad_token_id, reduction=self.config.ctc_loss_reduction, diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py index 6665d38cde14..68b1c5a9df65 100644 --- a/src/transformers/models/lasr/modular_lasr.py +++ b/src/transformers/models/lasr/modular_lasr.py @@ -21,7 +21,6 @@ from torch import nn from ...masking_utils import create_bidirectional_mask -from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import ProcessingKwargs, Unpack from ...tokenization_utils_tokenizers import TokenizersBackend @@ -33,6 +32,7 @@ from ..parakeet.modeling_parakeet import ( ParakeetEncoderBlock, ParakeetEncoderConvolutionModule, + ParakeetEncoderModelOutput, ParakeetForCTC, ParakeetPreTrainedModel, ) @@ -168,7 +168,7 @@ def _refine_timestamps_tdt(self, *args, **kwargs): raise NotImplementedError("Not needed") -@auto_docstring(checkpoint="TODO") +@auto_docstring(checkpoint="google/medasr") class LasrEncoderConfig(ParakeetEncoderConfig): r""" convolution_bias (`bool`, *optional*, defaults to `False`): @@ -269,7 +269,7 @@ def __init__( del self.scale_input -@auto_docstring(checkpoint="TODO") +@auto_docstring(checkpoint="google/medasr") class LasrCTCConfig(ParakeetCTCConfig): r""" ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`): @@ -465,6 +465,10 @@ def _get_subsampling_output_length(self, input_lengths: torch.Tensor): return input_lengths +class 
LasrEncoderModelOutput(ParakeetEncoderModelOutput): + pass + + @auto_docstring( custom_intro=""" The LasrEncoder model, based on the Conformer architecture](https://arxiv.org/abs/2005.08100). @@ -499,8 +503,9 @@ def forward( self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, + output_attention_mask: bool | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutput: + ) -> LasrEncoderModelOutput: r""" Example: @@ -531,8 +536,10 @@ def forward( cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training) sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training) + output_mask = None if attention_mask is not None: - attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1]) + output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1]) + attention_mask = output_mask attention_mask = create_bidirectional_mask( config=self.config, @@ -558,7 +565,10 @@ def forward( hidden_states = self.out_norm(hidden_states) - return BaseModelOutput(last_hidden_state=hidden_states) + return LasrEncoderModelOutput( + last_hidden_state=hidden_states, + attention_mask=output_mask.int() if output_attention_mask and output_mask is not None else None, + ) class LasrForCTC(ParakeetForCTC): diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 84cf31b07782..1efc69d73405 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -67,6 +67,15 @@ class ParakeetEncoderRelPositionalEncoding(nn.Module): def __init__(self, config: ParakeetEncoderConfig, device=None): super().__init__() self.max_position_embeddings = config.max_position_embeddings + self.config = config + inv_freq = self.compute_default_relative_positional_parameters(config, device=device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @staticmethod + def compute_default_relative_positional_parameters( + config: ParakeetEncoderConfig | None = None, + device=None, + ) -> torch.Tensor: base = 10000.0 inv_freq = 1.0 / ( base @@ -75,18 +84,11 @@ def __init__(self, config: ParakeetEncoderConfig, device=None): / config.hidden_size ) ) - - self.register_buffer("inv_freq", inv_freq, persistent=False) + return inv_freq @torch.no_grad() def forward(self, hidden_states: torch.Tensor): seq_length = hidden_states.shape[1] - if seq_length > self.max_position_embeddings: - raise ValueError( - f"Sequence Length: {seq_length} has to be less or equal than " - f"config.max_position_embeddings {self.max_position_embeddings}." 
- ) - position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device) inv_freq_expanded = ( self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device) @@ -512,12 +514,8 @@ def _init_weights(self, module): init.normal_(module.bias_u, mean=0.0, std=std) init.normal_(module.bias_v, mean=0.0, std=std) elif isinstance(module, ParakeetEncoderRelPositionalEncoding): - encoder_config = getattr(self.config, "encoder_config", self.config) - inv_freq = 1.0 / ( - 10000.0 - ** (torch.arange(0, encoder_config.hidden_size, 2, dtype=torch.int64) / encoder_config.hidden_size) - ) - init.copy_(module.inv_freq, inv_freq) + buffer_value = module.compute_default_relative_positional_parameters(module.config) + init.copy_(module.inv_freq, buffer_value) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): encoder_config = getattr(self.config, "encoder_config", self.config) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 71ed104f1a44..87c894df2811 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -63,6 +63,15 @@ class ParakeetEncoderRelPositionalEncoding(nn.Module): def __init__(self, config: ParakeetEncoderConfig, device=None): super().__init__() self.max_position_embeddings = config.max_position_embeddings + self.config = config + inv_freq = self.compute_default_relative_positional_parameters(config, device=device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @staticmethod + def compute_default_relative_positional_parameters( + config: ParakeetEncoderConfig | None = None, + device=None, + ) -> torch.Tensor: base = 10000.0 inv_freq = 1.0 / ( base @@ -71,18 +80,11 @@ def __init__(self, config: ParakeetEncoderConfig, device=None): / config.hidden_size ) ) - - self.register_buffer("inv_freq", inv_freq, persistent=False) + return inv_freq @torch.no_grad() def forward(self, hidden_states: torch.Tensor): seq_length = hidden_states.shape[1] - if seq_length > self.max_position_embeddings: - raise ValueError( - f"Sequence Length: {seq_length} has to be less or equal than " - f"config.max_position_embeddings {self.max_position_embeddings}." - ) - position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device) inv_freq_expanded = ( self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device) @@ -351,12 +353,8 @@ def _init_weights(self, module): init.normal_(module.bias_u, mean=0.0, std=std) init.normal_(module.bias_v, mean=0.0, std=std) elif isinstance(module, ParakeetEncoderRelPositionalEncoding): - encoder_config = getattr(self.config, "encoder_config", self.config) - inv_freq = 1.0 / ( - 10000.0 - ** (torch.arange(0, encoder_config.hidden_size, 2, dtype=torch.int64) / encoder_config.hidden_size) - ) - init.copy_(module.inv_freq, inv_freq) + buffer_value = module.compute_default_relative_positional_parameters(module.config) + init.copy_(module.inv_freq, buffer_value) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): encoder_config = getattr(self.config, "encoder_config", self.config) From 6c914dbe665408df3836ff76113ebfdaa321092d Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 24 Mar 2026 21:19:03 +0100 Subject: [PATCH 35/67] Update with main. 
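For reference, the inverse-frequency buffer factored out into `compute_default_relative_positional_parameters` above can be reproduced standalone. The relative position ids span `seq_length - 1` down to `-(seq_length - 1)`, giving `2 * seq_length - 1` positions; the toy sizes below are arbitrary:

```python
import torch

def relpos_inv_freq(hidden_size: int, base: float = 10000.0) -> torch.Tensor:
    # same formula as the static method in the patch
    return 1.0 / (base ** (torch.arange(0, hidden_size, 2, dtype=torch.int64) / hidden_size))

hidden_size, seq_length = 8, 5
inv_freq = relpos_inv_freq(hidden_size)                       # shape (hidden_size // 2,)
position_ids = torch.arange(seq_length - 1, -seq_length, -1)  # 2 * seq_length - 1 relative positions
freqs = position_ids[:, None].float() * inv_freq[None, :]     # angles feeding the sin/cos embeddings
print(inv_freq.shape, position_ids.shape, freqs.shape)        # torch.Size([4]) torch.Size([9]) torch.Size([9, 4])
```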
--- docs/source/en/model_doc/parakeet.md | 9 +- .../models/lasr/configuration_lasr.py | 41 +++---- src/transformers/models/lasr/modeling_lasr.py | 7 +- src/transformers/models/lasr/modular_lasr.py | 48 ++++---- .../models/parakeet/configuration_parakeet.py | 104 +++++++----------- 5 files changed, 84 insertions(+), 125 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index e588f2bbd1b4..3ec4bdfd4433 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -301,7 +301,7 @@ outputs.loss.backward() ``` - + ```py import torch @@ -331,14 +331,9 @@ loss_fn = TDTLossNumba( # Create wrapper to adapt NeMo loss to Transformers signature def nemo_loss_wrapper(token_logits, duration_logits, targets, logit_lengths, target_lengths, **kwargs): """Adapter function that converts Transformers loss signature to NeMo signature.""" - # Concatenate token and duration logits (NeMo expects combined logits) acts = torch.cat([token_logits, duration_logits], dim=-1) - - # Use actual tensor shape for act_lens (NeMo requires T dimension to match max(act_lens)) - # The logit_lengths may not exactly match due to padding/masking edge cases batch_size, T, U = acts.shape[:3] act_lens = torch.full((batch_size,), T, dtype=torch.long, device=acts.device) - # NeMo requires float32 (Numba doesn't support float16/bfloat16) and int64 per_sample_losses = nemo_loss_fn( acts=acts.float(), @@ -346,8 +341,6 @@ def nemo_loss_wrapper(token_logits, duration_logits, targets, logit_lengths, tar act_lens=act_lens, label_lens=target_lengths.long(), ) - - # Normalize by target lengths and take mean across batch return (per_sample_losses / target_lengths.float()).mean() # Monkey-patch the model's loss function diff --git a/src/transformers/models/lasr/configuration_lasr.py b/src/transformers/models/lasr/configuration_lasr.py index 3b35086830a5..d57ae10e424c 100644 --- a/src/transformers/models/lasr/configuration_lasr.py +++ b/src/transformers/models/lasr/configuration_lasr.py @@ -48,21 +48,18 @@ class LasrEncoderConfig(PreTrainedConfig): The momentum for the batch normalization layers Example: - ```python - >>> from transformers import LasrEncoderModel, LasrEncoderConfig + ```python + >>> from transformers import LasrEncoderModel, LasrEncoderConfig - >>> # Initializing a `LasrEncoder` configuration - >>> configuration = LasrEncoderConfig() + >>> # Initializing a `LasrEncoder` configuration + >>> configuration = LasrEncoderConfig() - >>> # Initializing a model from the configuration - >>> model = LasrEncoderModel(configuration) + >>> # Initializing a model from the configuration + >>> model = LasrEncoderModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ``` - - This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details - and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO). + >>> # Accessing the model configuration + >>> configuration = model.config + ``` """ model_type = "lasr_encoder" @@ -111,17 +108,15 @@ class LasrCTCConfig(PreTrainedConfig): of [`LasrForCTC`]. 
Example: - ```python - >>> from transformers import LasrForCTC, LasrCTCConfig - >>> # Initializing a Lasr configuration - >>> configuration = LasrCTCConfig() - >>> # Initializing a model from the configuration - >>> model = LasrForCTC(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ``` - This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details - and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO). + ```python + >>> from transformers import LasrForCTC, LasrCTCConfig + >>> # Initializing a Lasr configuration + >>> configuration = LasrCTCConfig() + >>> # Initializing a model from the configuration + >>> model = LasrForCTC(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` """ model_type = "lasr_ctc" diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index df6eff9be010..699f7911c89d 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -508,13 +508,16 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> LasrEncoderModelOutput: r""" + output_attention_mask (`bool`, *optional*): + Whether to return the output attention mask. + Example: ```python >>> from transformers import AutoProcessor, LasrEncoder >>> from datasets import load_dataset, Audio - >>> model_id = TODO + >>> model_id = "google/medasr" >>> processor = AutoProcessor.from_pretrained(model_id) >>> encoder = ParakeetEncoder.from_pretrained(model_id) @@ -700,7 +703,7 @@ def generate( >>> from transformers import AutoProcessor, LasrForCTC >>> from datasets import load_dataset, Audio - >>> model_id = TODO + >>> model_id = "google/medasr" >>> processor = AutoProcessor.from_pretrained(model_id) >>> model = LasrForCTC.from_pretrained(model_id) diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py index 52ff92e0f51f..5636cefa4676 100644 --- a/src/transformers/models/lasr/modular_lasr.py +++ b/src/transformers/models/lasr/modular_lasr.py @@ -193,21 +193,18 @@ class LasrEncoderConfig(ParakeetEncoderConfig): The momentum for the batch normalization layers Example: - ```python - >>> from transformers import LasrEncoderModel, LasrEncoderConfig - - >>> # Initializing a `LasrEncoder` configuration - >>> configuration = LasrEncoderConfig() + ```python + >>> from transformers import LasrEncoderModel, LasrEncoderConfig - >>> # Initializing a model from the configuration - >>> model = LasrEncoderModel(configuration) + >>> # Initializing a `LasrEncoder` configuration + >>> configuration = LasrEncoderConfig() - >>> # Accessing the model configuration - >>> configuration = model.config - ``` + >>> # Initializing a model from the configuration + >>> model = LasrEncoderModel(configuration) - This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details - and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO). + >>> # Accessing the model configuration + >>> configuration = model.config + ``` """ hidden_size: int = 512 @@ -242,17 +239,15 @@ class LasrCTCConfig(ParakeetCTCConfig): of [`LasrForCTC`]. 
Example: - ```python - >>> from transformers import LasrForCTC, LasrCTCConfig - >>> # Initializing a Lasr configuration - >>> configuration = LasrCTCConfig() - >>> # Initializing a model from the configuration - >>> model = LasrForCTC(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ``` - This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details - and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO). + ```python + >>> from transformers import LasrForCTC, LasrCTCConfig + >>> # Initializing a Lasr configuration + >>> configuration = LasrCTCConfig() + >>> # Initializing a model from the configuration + >>> model = LasrForCTC(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` """ vocab_size: int = 512 @@ -453,13 +448,16 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> LasrEncoderModelOutput: r""" + output_attention_mask (`bool`, *optional*): + Whether to return the output attention mask. + Example: ```python >>> from transformers import AutoProcessor, LasrEncoder >>> from datasets import load_dataset, Audio - >>> model_id = TODO + >>> model_id = "google/medasr" >>> processor = AutoProcessor.from_pretrained(model_id) >>> encoder = ParakeetEncoder.from_pretrained(model_id) @@ -526,7 +524,7 @@ def generate(**super_kwargs): >>> from transformers import AutoProcessor, LasrForCTC >>> from datasets import load_dataset, Audio - >>> model_id = TODO + >>> model_id = "google/medasr" >>> processor = AutoProcessor.from_pretrained(model_id) >>> model = LasrForCTC.from_pretrained(model_id) diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 9cd0be412296..fb6bc1c04d7d 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Parakeet model configuration.""" from huggingface_hub.dataclasses import strict @@ -43,21 +42,18 @@ class ParakeetEncoderConfig(PreTrainedConfig): Whether to scale the input embeddings. Example: - ```python - >>> from transformers import ParakeetEncoderModel, ParakeetEncoderConfig - - >>> # Initializing a `ParakeetEncoder` configuration - >>> configuration = ParakeetEncoderConfig() + ```python + >>> from transformers import ParakeetEncoderModel, ParakeetEncoderConfig - >>> # Initializing a model from the configuration - >>> model = ParakeetEncoderModel(configuration) + >>> # Initializing a `ParakeetEncoder` configuration + >>> configuration = ParakeetEncoderConfig() - >>> # Accessing the model configuration - >>> configuration = model.config - ``` + >>> # Initializing a model from the configuration + >>> model = ParakeetEncoderModel(configuration) - This configuration class is based on the ParakeetEncoder architecture from NVIDIA NeMo. You can find more details - and pre-trained models at [nvidia/parakeet-ctc-1.1b](https://huggingface.co/nvidia/parakeet-ctc-1.1b). 
+ >>> # Accessing the model configuration + >>> configuration = model.config + ``` """ model_type = "parakeet_encoder" @@ -136,85 +132,59 @@ def __post_init__(self, **kwargs): @auto_docstring(checkpoint="bezzam/parakeet-tdt-0.6b-v3-hf") +@strict class ParakeetTDTConfig(PreTrainedConfig): r""" + encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): + The config object or dictionary of the encoder. decoder_hidden_size (`int`, *optional*, defaults to 640): Hidden size of the LSTM prediction network and joint network. num_decoder_layers (`int`, *optional*, defaults to 2): Number of LSTM layers in the prediction network. - num_duration_bins (`int`, *optional*, defaults to 5): - Number of duration bins for predicting token durations. durations (`list[int]`, *optional*, defaults to `[0, 1, 2, 3, 4]`): Token duration values that can be predicted. Each value represents how many frames a token or blank emission spans. max_symbols_per_step (`int`, *optional*, defaults to 10): Maximum number of symbols to emit per encoder time step during greedy decoding. - encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): - The config object or dictionary of the encoder. blank_token_id (`int`, *optional*, defaults to 8192): Blank token id. Different from `pad_token_id` for TDT. Example: - ```python - >>> from transformers import ParakeetForTDT, ParakeetTDTConfig + ```python + >>> from transformers import ParakeetForTDT, ParakeetTDTConfig - >>> # Initializing a Parakeet TDT configuration - >>> configuration = ParakeetTDTConfig() + >>> # Initializing a Parakeet TDT configuration + >>> configuration = ParakeetTDTConfig() - >>> # Initializing a model from the configuration - >>> model = ParakeetForTDT(configuration) + >>> # Initializing a model from the configuration + >>> model = ParakeetForTDT(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ``` + >>> # Accessing the model configuration + >>> configuration = model.config + ``` """ model_type = "parakeet_tdt" sub_configs = {"encoder_config": ParakeetEncoderConfig} - def __init__( - self, - vocab_size=8193, - decoder_hidden_size=640, - num_decoder_layers=2, - durations=[0, 1, 2, 3, 4], - hidden_act="relu", - max_symbols_per_step=10, - encoder_config: dict | ParakeetEncoderConfig = None, - pad_token_id=2, - blank_token_id=8192, - **kwargs, - ): - self.vocab_size = vocab_size - self.decoder_hidden_size = decoder_hidden_size - self.num_decoder_layers = num_decoder_layers - self.durations = durations - self.hidden_act = hidden_act - self.max_symbols_per_step = max_symbols_per_step - - if isinstance(encoder_config, dict): - self.encoder_config = ParakeetEncoderConfig(**encoder_config) - elif encoder_config is None: - self.encoder_config = ParakeetEncoderConfig() - else: - self.encoder_config = encoder_config + vocab_size: int = 8193 + decoder_hidden_size: int = 640 + num_decoder_layers: int = 2 + hidden_act: str = "relu" + max_symbols_per_step: int = 10 + durations: list[int] | tuple[int, ...] 
= (0, 1, 2, 3, 4) + encoder_config: dict | PreTrainedConfig | None = None + pad_token_id: int = 2 + blank_token_id: int = 8192 + is_encoder_decoder: bool = True + def __post_init__(self, **kwargs): + if isinstance(self.encoder_config, dict): + self.encoder_config = ParakeetEncoderConfig(**self.encoder_config) + elif self.encoder_config is None: + self.encoder_config = ParakeetEncoderConfig() self.initializer_range = self.encoder_config.initializer_range - self.blank_token_id = blank_token_id - self.pad_token_id = pad_token_id - self.is_encoder_decoder = True - - super().__init__(**kwargs) - - @classmethod - def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): - r""" - Instantiate a [`ParakeetTDTConfig`] (or a derived class) from parakeet encoder model configuration. - - Returns: - [`ParakeetTDTConfig`]: An instance of a configuration object - """ - return cls(encoder_config=encoder_config.to_dict(), **kwargs) + super().__post_init__(**kwargs) __all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig", "ParakeetTDTConfig"] From 756cee1eb11848289941aaf50a609505223fc309 Mon Sep 17 00:00:00 2001 From: Eric B Date: Thu, 26 Mar 2026 09:49:56 +0100 Subject: [PATCH 36/67] doc nits --- docs/source/en/model_doc/parakeet.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index 3ec4bdfd4433..f90d476cd3cc 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -269,7 +269,7 @@ outputs.loss.backward() ### TDT Training -The TDT loss has been implemented within Transformers to enable training. For faster training (around 10-50x depending on batch size), consider using NeMo's `TDTLossNumba`. Note that this requires installing the NeMo toolkit with `pip install nemo_toolkit[asr]`. +The TDT loss has been implemented within Transformers to enable training. For faster training (around 10x), consider using NeMo's `TDTLossNumba`. Note that this requires installing the NeMo toolkit with `pip install nemo_toolkit[asr]`. 
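On the loss shapes involved here: the built-in TDT loss takes separate token and duration logits, while NeMo's `TDTLossNumba` expects them concatenated on the last dimension, which is what the wrapper edited in the next hunk does. A quick shape check with arbitrary batch, time, and target sizes (vocab and duration counts match the default config):

```python
import torch

B, T, U, V, D = 2, 50, 12, 8193, 5                # D == len(config.durations)
token_logits = torch.randn(B, T, U + 1, V)        # (batch, frames, targets + 1, vocab)
duration_logits = torch.randn(B, T, U + 1, D)     # (batch, frames, targets + 1, duration bins)
acts = torch.cat([token_logits, duration_logits], dim=-1)  # combined logits NeMo expects
print(acts.shape)  # torch.Size([2, 50, 13, 8198])
```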
@@ -319,16 +319,12 @@ model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_m model.train() # Initialize NeMo TDT loss -# NOTE: NeMo's TDTLossNumba doesn't seem to do normalization with target lengths as suggested by its docstring so doing manually: -# - Docstring: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py#L373 -# - Normalization: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py#L247-L253 loss_fn = TDTLossNumba( blank=model.config.blank_token_id, durations=model.config.durations, reduction="none", ) -# Create wrapper to adapt NeMo loss to Transformers signature def nemo_loss_wrapper(token_logits, duration_logits, targets, logit_lengths, target_lengths, **kwargs): """Adapter function that converts Transformers loss signature to NeMo signature.""" acts = torch.cat([token_logits, duration_logits], dim=-1) @@ -341,6 +337,9 @@ def nemo_loss_wrapper(token_logits, duration_logits, targets, logit_lengths, tar act_lens=act_lens, label_lens=target_lengths.long(), ) + # NOTE: NeMo's TDTLossNumba doesn't do normalization with target lengths as suggested by its docstring so we do manually: + # - Docstring: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py#L373 + # - Expected normalization: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py#L247-L253 return (per_sample_losses / target_lengths.float()).mean() # Monkey-patch the model's loss function From f30c53649c0d6852377e2df4da8d7ab471b0dcf0 Mon Sep 17 00:00:00 2001 From: Eric B Date: Thu, 26 Mar 2026 11:51:41 +0100 Subject: [PATCH 37/67] Imitate whisper for encoder outputs as input --- .../models/parakeet/configuration_parakeet.py | 1 - .../models/parakeet/modeling_parakeet.py | 56 +++++++++++++------ .../models/parakeet/modular_parakeet.py | 56 +++++++++++++------ tests/test_modeling_common.py | 5 +- 4 files changed, 82 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index fb6bc1c04d7d..babc9526f760 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -176,7 +176,6 @@ class ParakeetTDTConfig(PreTrainedConfig): encoder_config: dict | PreTrainedConfig | None = None pad_token_id: int = 2 blank_token_id: int = 8192 - is_encoder_decoder: bool = True def __post_init__(self, **kwargs): if isinstance(self.encoder_config, dict): diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 1efc69d73405..bdec534629b8 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -975,7 +975,7 @@ class ParakeetTDTGenerateOutput(ModelOutput): @dataclass -class ParakeetTDTOutput(ModelOutput): +class ParakeetTDTOutput(BaseModelOutputWithPooling): """ Output of the Parakeet TDT forward pass. @@ -985,8 +985,8 @@ class ParakeetTDTOutput(ModelOutput): logits (`torch.FloatTensor`): Joint token and duration logits. Shape is `(batch, T, U+1, vocab+durations)` for training or `(batch, 1, 1, vocab+durations)` for single-step inference. - encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): - Encoder outputs with `pooler_output` containing projected hidden states. 
+ attention_mask (`torch.Tensor`, *optional*): + Encoder output attention mask after subsampling. decoder_cache (`ParakeetTDTDecoderCache`, *optional*): Decoder LSTM cache containing hidden state, cell state, and decoder output. Updated in-place during generation. @@ -994,7 +994,7 @@ class ParakeetTDTOutput(ModelOutput): loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - encoder_outputs: "ParakeetEncoderModelOutput | None" = None + attention_mask: torch.Tensor | None = None decoder_cache: ParakeetTDTDecoderCache | None = None @@ -1145,6 +1145,7 @@ def __init__(self, config: ParakeetTDTConfig): self.post_init() + @can_return_tuple def get_audio_features( self, input_features: torch.Tensor, @@ -1166,7 +1167,7 @@ def forward( input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, input_ids: torch.LongTensor | None = None, - encoder_outputs: ParakeetEncoderModelOutput | None = None, + encoder_outputs: tuple[torch.FloatTensor] | None = None, encoder_frame_ids: torch.LongTensor | None = None, decoder_cache: ParakeetTDTDecoderCache | None = None, decoder_cache_update_mask: torch.BoolTensor | None = None, @@ -1177,8 +1178,9 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Decoder input token ids for single-step inference. - encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): - Pre-computed encoder outputs with `pooler_output` containing projected hidden states. + encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): + Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). + Can be a tuple or `ParakeetEncoderModelOutput`. encoder_frame_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Encoder frame indices for the joint network during generation. 
decoder_cache (`ParakeetTDTDecoderCache`, *optional*): @@ -1222,6 +1224,14 @@ def forward( attention_mask=attention_mask, **kwargs, ) + elif not isinstance(encoder_outputs, ParakeetEncoderModelOutput): + encoder_outputs = ParakeetEncoderModelOutput( + last_hidden_state=encoder_outputs[0], + pooler_output=encoder_outputs[1], + hidden_states=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None, + attention_mask=encoder_outputs[4] if len(encoder_outputs) > 4 else None, + ) projected_encoder_output = encoder_outputs.pooler_output if labels is not None: @@ -1282,7 +1292,11 @@ def forward( return ParakeetTDTOutput( loss=loss, logits=logits, - encoder_outputs=encoder_outputs, + last_hidden_state=encoder_outputs.last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + pooler_output=encoder_outputs.pooler_output, + attention_mask=encoder_outputs.attention_mask, decoder_cache=decoder_cache, ) @@ -1339,6 +1353,7 @@ def generate( model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__ # Initial forward: encode + decoder initialization + kwargs.setdefault("output_attention_mask", True) outputs = model_forward( input_features=input_features, attention_mask=attention_mask, @@ -1346,15 +1361,22 @@ def generate( return_dict=True, **kwargs, ) - encoder_outputs = outputs.encoder_outputs + + # Reconstruct encoder_outputs for subsequent forward calls + encoder_outputs = ParakeetEncoderModelOutput( + last_hidden_state=outputs.last_hidden_state, + pooler_output=outputs.pooler_output, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + attention_mask=outputs.attention_mask, + ) decoder_cache = outputs.decoder_cache - batch_size, sequence_length = encoder_outputs.pooler_output.shape[:2] - device = encoder_outputs.pooler_output.device + batch_size, sequence_length = outputs.pooler_output.shape[:2] + device = outputs.pooler_output.device - # TODO use encoder attention mask like in loss computation? - if attention_mask is not None: - encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) - valid_lengths = encoder_attention_mask.sum(dim=1).int() + # Use encoder attention mask for valid lengths + if outputs.attention_mask is not None: + valid_lengths = outputs.attention_mask.sum(dim=1).int() else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) @@ -1439,8 +1461,8 @@ def generate( sequences=sequences, token_timestamps=token_timestamps, token_durations=token_durations, - attentions=encoder_outputs.attentions, - hidden_states=encoder_outputs.hidden_states, + attentions=outputs.attentions, + hidden_states=outputs.hidden_states, ) return sequences diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 87c894df2811..80e0e61ad70d 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -942,7 +942,7 @@ class ParakeetTDTGenerateOutput(ModelOutput): @dataclass -class ParakeetTDTOutput(ModelOutput): +class ParakeetTDTOutput(BaseModelOutputWithPooling): """ Output of the Parakeet TDT forward pass. @@ -952,8 +952,8 @@ class ParakeetTDTOutput(ModelOutput): logits (`torch.FloatTensor`): Joint token and duration logits. 
Shape is `(batch, T, U+1, vocab+durations)` for training or `(batch, 1, 1, vocab+durations)` for single-step inference. - encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): - Encoder outputs with `pooler_output` containing projected hidden states. + attention_mask (`torch.Tensor`, *optional*): + Encoder output attention mask after subsampling. decoder_cache (`ParakeetTDTDecoderCache`, *optional*): Decoder LSTM cache containing hidden state, cell state, and decoder output. Updated in-place during generation. @@ -961,7 +961,7 @@ class ParakeetTDTOutput(ModelOutput): loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - encoder_outputs: "ParakeetEncoderModelOutput | None" = None + attention_mask: torch.Tensor | None = None decoder_cache: ParakeetTDTDecoderCache | None = None @@ -984,6 +984,7 @@ def __init__(self, config: ParakeetTDTConfig): self.post_init() + @can_return_tuple def get_audio_features( self, input_features: torch.Tensor, @@ -1005,7 +1006,7 @@ def forward( input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, input_ids: torch.LongTensor | None = None, - encoder_outputs: ParakeetEncoderModelOutput | None = None, + encoder_outputs: tuple[torch.FloatTensor] | None = None, encoder_frame_ids: torch.LongTensor | None = None, decoder_cache: ParakeetTDTDecoderCache | None = None, decoder_cache_update_mask: torch.BoolTensor | None = None, @@ -1016,8 +1017,9 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Decoder input token ids for single-step inference. - encoder_outputs (`ParakeetEncoderModelOutput`, *optional*): - Pre-computed encoder outputs with `pooler_output` containing projected hidden states. + encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): + Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). + Can be a tuple or `ParakeetEncoderModelOutput`. encoder_frame_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Encoder frame indices for the joint network during generation. 
decoder_cache (`ParakeetTDTDecoderCache`, *optional*): @@ -1061,6 +1063,14 @@ def forward( attention_mask=attention_mask, **kwargs, ) + elif not isinstance(encoder_outputs, ParakeetEncoderModelOutput): + encoder_outputs = ParakeetEncoderModelOutput( + last_hidden_state=encoder_outputs[0], + pooler_output=encoder_outputs[1], + hidden_states=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None, + attention_mask=encoder_outputs[4] if len(encoder_outputs) > 4 else None, + ) projected_encoder_output = encoder_outputs.pooler_output if labels is not None: @@ -1121,7 +1131,11 @@ def forward( return ParakeetTDTOutput( loss=loss, logits=logits, - encoder_outputs=encoder_outputs, + last_hidden_state=encoder_outputs.last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + pooler_output=encoder_outputs.pooler_output, + attention_mask=encoder_outputs.attention_mask, decoder_cache=decoder_cache, ) @@ -1178,6 +1192,7 @@ def generate( model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__ # Initial forward: encode + decoder initialization + kwargs.setdefault("output_attention_mask", True) outputs = model_forward( input_features=input_features, attention_mask=attention_mask, @@ -1185,15 +1200,22 @@ def generate( return_dict=True, **kwargs, ) - encoder_outputs = outputs.encoder_outputs + + # Reconstruct encoder_outputs for subsequent forward calls + encoder_outputs = ParakeetEncoderModelOutput( + last_hidden_state=outputs.last_hidden_state, + pooler_output=outputs.pooler_output, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + attention_mask=outputs.attention_mask, + ) decoder_cache = outputs.decoder_cache - batch_size, sequence_length = encoder_outputs.pooler_output.shape[:2] - device = encoder_outputs.pooler_output.device + batch_size, sequence_length = outputs.pooler_output.shape[:2] + device = outputs.pooler_output.device - # TODO use encoder attention mask like in loss computation? 
- if attention_mask is not None: - encoder_attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequence_length) - valid_lengths = encoder_attention_mask.sum(dim=1).int() + # Use encoder attention mask for valid lengths + if outputs.attention_mask is not None: + valid_lengths = outputs.attention_mask.sum(dim=1).int() else: valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) @@ -1278,8 +1300,8 @@ def generate( sequences=sequences, token_timestamps=token_timestamps, token_durations=token_durations, - attentions=encoder_outputs.attentions, - hidden_states=encoder_outputs.hidden_states, + attentions=outputs.attentions, + hidden_states=outputs.hidden_states, ) return sequences diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 13b81855aaa6..72d3cae0986a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -5309,7 +5309,10 @@ def test_get_audio_features_output(self, return_dict: bool | None): elif hasattr(audio_config, "hidden_size"): hidden_size = audio_config.hidden_size elif hasattr(audio_config, "encoder_config"): - hidden_size = audio_config.encoder_config.hidden_dim + if hasattr(audio_config.encoder_config, "hidden_size"): + hidden_size = audio_config.encoder_config.hidden_size + else: + hidden_size = audio_config.encoder_config.hidden_dim elif hasattr(audio_config, "encoder_ffn_dim"): hidden_size = audio_config.encoder_ffn_dim self.assertEqual( From fa95fc8ee04bb6549008f90211f876149f30e32d Mon Sep 17 00:00:00 2001 From: Eric B Date: Thu, 26 Mar 2026 14:04:59 +0100 Subject: [PATCH 38/67] Address tests and nits. --- .../models/parakeet/modeling_parakeet.py | 19 +++++++++---------- .../models/parakeet/modular_parakeet.py | 19 +++++++++---------- .../models/parakeet/test_modeling_parakeet.py | 12 ++++++------ 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index bdec534629b8..db08d90789e4 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -912,7 +912,6 @@ def forward( decoder_cache: ParakeetTDTDecoderCache | None = None, decoder_cache_update_mask: torch.BoolTensor | None = None, ) -> torch.Tensor: - input_ids = input_ids.to(self.decoder_projector.weight.device) hidden_cell_states = ( (decoder_cache.hidden_state, decoder_cache.cell_state) if decoder_cache is not None and decoder_cache.is_initialized @@ -1166,7 +1165,7 @@ def forward( self, input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - input_ids: torch.LongTensor | None = None, + decoder_input_ids: torch.LongTensor | None = None, encoder_outputs: tuple[torch.FloatTensor] | None = None, encoder_frame_ids: torch.LongTensor | None = None, decoder_cache: ParakeetTDTDecoderCache | None = None, @@ -1176,7 +1175,7 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: r""" - input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Decoder input token ids for single-step inference. encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). 
@@ -1239,10 +1238,10 @@ def forward( blank_tokens = torch.full( (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device ) - input_ids = torch.cat([blank_tokens, labels], dim=1) - elif input_ids is None and decoder_cache is None: + decoder_input_ids = torch.cat([blank_tokens, labels], dim=1) + elif decoder_input_ids is None and decoder_cache is None: # for inference: start with blank token if not provided - input_ids = torch.full( + decoder_input_ids = torch.full( (projected_encoder_output.shape[0], 1), self.config.blank_token_id, dtype=torch.long, @@ -1252,9 +1251,9 @@ def forward( if use_decoder_cache and decoder_cache is None: decoder_cache = ParakeetTDTDecoderCache() - # Run decoder if we have input_ids (initial step or after emitting a token) - if input_ids is not None: - decoder_output = self.decoder(input_ids, decoder_cache, decoder_cache_update_mask) + # Run decoder if we have decoder_input_ids (initial step or after emitting a token) + if decoder_input_ids is not None: + decoder_output = self.decoder(decoder_input_ids, decoder_cache, decoder_cache_update_mask) else: # Reuse cached decoder_output (blank-skipping path) decoder_output = decoder_cache.cache @@ -1432,7 +1431,7 @@ def generate( # Run decoder for emitted tokens — only update cache for samples that emitted model_forward( - input_ids=tokens.unsqueeze(1), + decoder_input_ids=tokens.unsqueeze(1), encoder_outputs=encoder_outputs, encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), decoder_cache=decoder_cache, diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 80e0e61ad70d..0d9994a14107 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -751,7 +751,6 @@ def forward( decoder_cache: ParakeetTDTDecoderCache | None = None, decoder_cache_update_mask: torch.BoolTensor | None = None, ) -> torch.Tensor: - input_ids = input_ids.to(self.decoder_projector.weight.device) hidden_cell_states = ( (decoder_cache.hidden_state, decoder_cache.cell_state) if decoder_cache is not None and decoder_cache.is_initialized @@ -1005,7 +1004,7 @@ def forward( self, input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - input_ids: torch.LongTensor | None = None, + decoder_input_ids: torch.LongTensor | None = None, encoder_outputs: tuple[torch.FloatTensor] | None = None, encoder_frame_ids: torch.LongTensor | None = None, decoder_cache: ParakeetTDTDecoderCache | None = None, @@ -1015,7 +1014,7 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: r""" - input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Decoder input token ids for single-step inference. encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). 
@@ -1078,10 +1077,10 @@ def forward( blank_tokens = torch.full( (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device ) - input_ids = torch.cat([blank_tokens, labels], dim=1) - elif input_ids is None and decoder_cache is None: + decoder_input_ids = torch.cat([blank_tokens, labels], dim=1) + elif decoder_input_ids is None and decoder_cache is None: # for inference: start with blank token if not provided - input_ids = torch.full( + decoder_input_ids = torch.full( (projected_encoder_output.shape[0], 1), self.config.blank_token_id, dtype=torch.long, @@ -1091,9 +1090,9 @@ def forward( if use_decoder_cache and decoder_cache is None: decoder_cache = ParakeetTDTDecoderCache() - # Run decoder if we have input_ids (initial step or after emitting a token) - if input_ids is not None: - decoder_output = self.decoder(input_ids, decoder_cache, decoder_cache_update_mask) + # Run decoder if we have decoder_input_ids (initial step or after emitting a token) + if decoder_input_ids is not None: + decoder_output = self.decoder(decoder_input_ids, decoder_cache, decoder_cache_update_mask) else: # Reuse cached decoder_output (blank-skipping path) decoder_output = decoder_cache.cache @@ -1271,7 +1270,7 @@ def generate( # Run decoder for emitted tokens — only update cache for samples that emitted model_forward( - input_ids=tokens.unsqueeze(1), + decoder_input_ids=tokens.unsqueeze(1), encoder_outputs=encoder_outputs, encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), decoder_cache=decoder_cache, diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index b29e26322270..6667bb2ce5a5 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -458,13 +458,12 @@ def __init__( encoder_kwargs=None, is_training=True, vocab_size=129, - decoder_hidden_size=64, + decoder_hidden_size=32, num_decoder_layers=1, - durations=None, + durations=[0, 1, 2, 3, 4], hidden_act="relu", - max_symbols_per_step=10, + max_symbols_per_step=5, pad_token_id=2, - blank_token_id=128, ): if encoder_kwargs is None: encoder_kwargs = {} @@ -483,11 +482,11 @@ def __init__( self.vocab_size = vocab_size self.decoder_hidden_size = decoder_hidden_size self.num_decoder_layers = num_decoder_layers - self.durations = durations if durations is not None else [0, 1, 2, 3, 4] + self.durations = durations self.hidden_act = hidden_act self.max_symbols_per_step = max_symbols_per_step self.pad_token_id = pad_token_id - self.blank_token_id = blank_token_id + self.blank_token_id = vocab_size - 1 def prepare_config_and_inputs(self): _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs() @@ -543,6 +542,7 @@ class ParakeetForTDTModelTest(ModelTesterMixin, unittest.TestCase): test_attention_outputs = False test_resize_embeddings = False + test_torch_exportable = False _is_composite = True @unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup") From 5df7f289677a12412effd6eb57a97e67db1a706b Mon Sep 17 00:00:00 2001 From: Eric B Date: Thu, 26 Mar 2026 18:12:02 +0100 Subject: [PATCH 39/67] Inherit from GenerateMixIn for get_compiled_call --- .../models/parakeet/modeling_parakeet.py | 28 ++++++++++--------- .../models/parakeet/modular_parakeet.py | 28 ++++++++++--------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py 
b/src/transformers/models/parakeet/modeling_parakeet.py index db08d90789e4..14c5e4f0b44f 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -27,7 +27,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...generation import CompileConfig +from ...generation import CompileConfig, GenerationMixin from ...integrations import use_kernel_func_from_hub, use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput @@ -692,7 +692,7 @@ def __init__(self, *args, **kwargs): Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. """ ) -class ParakeetForCTC(ParakeetPreTrainedModel): +class ParakeetForCTC(ParakeetPreTrainedModel, GenerationMixin): config: ParakeetCTCConfig def __init__(self, config: ParakeetCTCConfig): @@ -779,9 +779,13 @@ def generate( input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, return_dict_in_generate: bool = False, + compile_config: CompileConfig | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetCTCGenerateOutput | torch.LongTensor: r""" + compile_config ([`~generation.CompileConfig`], *optional*): + If provided, `torch.compile` will be applied to the forward calls in the decoding loop. + Example: ```python @@ -802,8 +806,10 @@ def generate( >>> print(transcription) ``` """ + model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__ + kwargs["return_dict"] = True - outputs: CausalLMOutput = self.forward( + outputs: CausalLMOutput = model_forward( input_features=input_features, attention_mask=attention_mask, **kwargs, @@ -1130,7 +1136,7 @@ def tdt_loss( Parakeet Encoder with a TDT (Token Duration Transducer) head. """ ) -class ParakeetForTDT(ParakeetPreTrainedModel): +class ParakeetForTDT(ParakeetPreTrainedModel, GenerationMixin): config: ParakeetTDTConfig _no_split_modules = ["ParakeetTDTDecoder"] @@ -1310,14 +1316,11 @@ def generate( **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTGenerateOutput | torch.LongTensor: r""" - Perform TDT greedy decoding to generate token sequences. - - Args: - return_timestamps (`bool`, *optional*, defaults to `False`): - Whether to return per-token timestamps and durations. When `True`, forces - `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. - compile_config ([`~generation.CompileConfig`], *optional*): - If provided, `torch.compile` will be applied to the forward calls in the decoding loop. + return_timestamps (`bool`, *optional*, defaults to `False`): + Whether to return per-token timestamps and durations. When `True`, forces + `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. + compile_config ([`~generation.CompileConfig`], *optional*): + If provided, `torch.compile` will be applied to the forward calls in the decoding loop. 
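The `get_compiled_call` dispatch introduced above swaps the plain `__call__` for a `torch.compile`-wrapped forward when a compile config is supplied. A self-contained sketch of that pattern on a toy module (the lazy caching helper here is illustrative, not the actual Transformers implementation):

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Minimal module illustrating the compiled-vs-eager forward dispatch."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self._compiled_call = None

    def get_compiled_call(self):
        # Compile lazily and cache, so repeated generate() calls reuse the same wrapper.
        if self._compiled_call is None:
            self._compiled_call = torch.compile(self.__call__)
        return self._compiled_call

    def forward(self, x):
        return self.proj(x)


model = ToyModel()
use_compile = False  # flip to True to route forward calls through torch.compile
model_forward = model.get_compiled_call() if use_compile else model.__call__
print(model_forward(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```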
Example: @@ -1373,7 +1376,6 @@ def generate( batch_size, sequence_length = outputs.pooler_output.shape[:2] device = outputs.pooler_output.device - # Use encoder attention mask for valid lengths if outputs.attention_mask is not None: valid_lengths = outputs.attention_mask.sum(dim=1).int() else: diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 0d9994a14107..50e9d21c169b 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -22,7 +22,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...generation import CompileConfig +from ...generation import CompileConfig, GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -531,7 +531,7 @@ def __init__(self, *args, **kwargs): Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. """ ) -class ParakeetForCTC(ParakeetPreTrainedModel): +class ParakeetForCTC(ParakeetPreTrainedModel, GenerationMixin): config: ParakeetCTCConfig def __init__(self, config: ParakeetCTCConfig): @@ -618,9 +618,13 @@ def generate( input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, return_dict_in_generate: bool = False, + compile_config: CompileConfig | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetCTCGenerateOutput | torch.LongTensor: r""" + compile_config ([`~generation.CompileConfig`], *optional*): + If provided, `torch.compile` will be applied to the forward calls in the decoding loop. + Example: ```python @@ -641,8 +645,10 @@ def generate( >>> print(transcription) ``` """ + model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__ + kwargs["return_dict"] = True - outputs: CausalLMOutput = self.forward( + outputs: CausalLMOutput = model_forward( input_features=input_features, attention_mask=attention_mask, **kwargs, @@ -969,7 +975,7 @@ class ParakeetTDTOutput(BaseModelOutputWithPooling): Parakeet Encoder with a TDT (Token Duration Transducer) head. """ ) -class ParakeetForTDT(ParakeetPreTrainedModel): +class ParakeetForTDT(ParakeetPreTrainedModel, GenerationMixin): config: ParakeetTDTConfig _no_split_modules = ["ParakeetTDTDecoder"] @@ -1149,14 +1155,11 @@ def generate( **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTGenerateOutput | torch.LongTensor: r""" - Perform TDT greedy decoding to generate token sequences. - - Args: - return_timestamps (`bool`, *optional*, defaults to `False`): - Whether to return per-token timestamps and durations. When `True`, forces - `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. - compile_config ([`~generation.CompileConfig`], *optional*): - If provided, `torch.compile` will be applied to the forward calls in the decoding loop. + return_timestamps (`bool`, *optional*, defaults to `False`): + Whether to return per-token timestamps and durations. When `True`, forces + `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. + compile_config ([`~generation.CompileConfig`], *optional*): + If provided, `torch.compile` will be applied to the forward calls in the decoding loop. 
Example: @@ -1212,7 +1215,6 @@ def generate( batch_size, sequence_length = outputs.pooler_output.shape[:2] device = outputs.pooler_output.device - # Use encoder attention mask for valid lengths if outputs.attention_mask is not None: valid_lengths = outputs.attention_mask.sum(dim=1).int() else: From cd706d48301fd2ce8bafc1ce1998447dbb6f0195 Mon Sep 17 00:00:00 2001 From: Eric B Date: Thu, 26 Mar 2026 18:48:27 +0100 Subject: [PATCH 40/67] Comment nit --- src/transformers/models/parakeet/modeling_parakeet.py | 2 +- src/transformers/models/parakeet/modular_parakeet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 14c5e4f0b44f..5150d35daeef 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -1431,7 +1431,7 @@ def generate( all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] token_counts += emit_mask.long() - # Run decoder for emitted tokens — only update cache for samples that emitted + # Update decoder cache for emitted tokens (using potentially compiled forward) model_forward( decoder_input_ids=tokens.unsqueeze(1), encoder_outputs=encoder_outputs, diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 50e9d21c169b..44be78f64f8e 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -1270,7 +1270,7 @@ def generate( all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] token_counts += emit_mask.long() - # Run decoder for emitted tokens — only update cache for samples that emitted + # Update decoder cache for emitted tokens (using potentially compiled forward) model_forward( decoder_input_ids=tokens.unsqueeze(1), encoder_outputs=encoder_outputs, From a47ed8a5ae3f7d814b92049e3a0a6d505308a986 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:19:33 +0200 Subject: [PATCH 41/67] forward cleanup --- .../models/parakeet/modular_parakeet.py | 76 ++++--------------- 1 file changed, 13 insertions(+), 63 deletions(-) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 44be78f64f8e..ccb26cfd734f 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -1011,11 +1011,9 @@ def forward( input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, decoder_input_ids: torch.LongTensor | None = None, - encoder_outputs: tuple[torch.FloatTensor] | None = None, - encoder_frame_ids: torch.LongTensor | None = None, decoder_cache: ParakeetTDTDecoderCache | None = None, - decoder_cache_update_mask: torch.BoolTensor | None = None, use_decoder_cache: bool | None = None, + encoder_outputs: ParakeetEncoderModelOutput | tuple[torch.FloatTensor] | None = None, labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: @@ -1025,17 +1023,10 @@ def forward( encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). Can be a tuple or `ParakeetEncoderModelOutput`. 
- encoder_frame_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Encoder frame indices for the joint network during generation. decoder_cache (`ParakeetTDTDecoderCache`, *optional*): Decoder LSTM cache. When provided and initialized, the cached `decoder_output` is reused (e.g. during blank-skipping) instead of running the decoder. When `input_ids` is provided, the decoder runs and the cache is updated in-place. - decoder_cache_update_mask (`torch.BoolTensor` of shape `(batch_size,)`, *optional*): - Boolean mask controlling which batch elements have their decoder cache updated. - When provided, only elements where the mask is `True` are written to the cache; - other elements retain their previous cached state. Used during generation to - preserve cache for samples that predicted blank tokens. use_decoder_cache (`bool`, *optional*): Whether to use a decoder cache. When `True` and `decoder_cache` is `None`, a new cache is created automatically during the forward pass. @@ -1057,12 +1048,7 @@ def forward( >>> outputs = model(**inputs) ``` """ - # 1. Encode + project if encoder_outputs is None: - if input_features is None: - raise ValueError("Either `input_features` or `encoder_outputs` must be provided.") - if labels is not None: - kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.get_audio_features( input_features=input_features, attention_mask=attention_mask, @@ -1070,77 +1056,41 @@ def forward( ) elif not isinstance(encoder_outputs, ParakeetEncoderModelOutput): encoder_outputs = ParakeetEncoderModelOutput( - last_hidden_state=encoder_outputs[0], - pooler_output=encoder_outputs[1], + last_hidden_state=encoder_outputs[0] if len(encoder_outputs) > 0 else None, + pooler_output=encoder_outputs[1] if len(encoder_outputs) > 1 else None, hidden_states=encoder_outputs[2] if len(encoder_outputs) > 2 else None, attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None, attention_mask=encoder_outputs[4] if len(encoder_outputs) > 4 else None, ) - projected_encoder_output = encoder_outputs.pooler_output - - if labels is not None: - # for training: [blank, labels...] 
for training - blank_tokens = torch.full( - (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device - ) - decoder_input_ids = torch.cat([blank_tokens, labels], dim=1) - elif decoder_input_ids is None and decoder_cache is None: - # for inference: start with blank token if not provided - decoder_input_ids = torch.full( - (projected_encoder_output.shape[0], 1), - self.config.blank_token_id, - dtype=torch.long, - device=projected_encoder_output.device, - ) if use_decoder_cache and decoder_cache is None: decoder_cache = ParakeetTDTDecoderCache() - # Run decoder if we have decoder_input_ids (initial step or after emitting a token) - if decoder_input_ids is not None: - decoder_output = self.decoder(decoder_input_ids, decoder_cache, decoder_cache_update_mask) - else: - # Reuse cached decoder_output (blank-skipping path) - decoder_output = decoder_cache.cache - - if encoder_frame_ids is not None: - batch_indices = torch.arange(projected_encoder_output.shape[0], device=projected_encoder_output.device) - safe_frame_ids = torch.clamp(encoder_frame_ids, max=projected_encoder_output.shape[1] - 1) - encoder_for_joint = projected_encoder_output[batch_indices, safe_frame_ids].unsqueeze(1) - decoder_for_joint = decoder_output - else: - encoder_for_joint = projected_encoder_output.unsqueeze(2) - decoder_for_joint = decoder_output.unsqueeze(1) - - token_logits, duration_logits = self.joint( - decoder_output=decoder_for_joint, - encoder_output=encoder_for_joint, + decoder_hidden_states = self.decoder(decoder_input_ids, cache=decoder_cache) + logits = self.joint( + encoder_hidden_states=encoder_outputs.pooler_output, + decoder_hidden_states=decoder_hidden_states, ) - logits = torch.cat([token_logits, duration_logits], dim=-1) loss = None if labels is not None: - encoder_lengths = encoder_outputs.attention_mask.sum(-1) - target_lengths = (labels != self.config.pad_token_id).sum(-1) loss = self.loss_function( - token_logits=token_logits.float(), - duration_logits=duration_logits.float(), - targets=labels.to(token_logits.device).int(), - logit_lengths=encoder_lengths.to(token_logits.device).int(), - target_lengths=target_lengths.to(token_logits.device).int(), + token_logits=logits[..., : self.config.vocab_size], + duration_logits=logits[..., self.config.vocab_size :], + labels=labels, + logit_lengths=encoder_outputs.attention_mask.sum(-1), + label_lengths=(labels != self.config.pad_token_id).sum(-1), blank_token_id=self.config.blank_token_id, durations=self.config.durations, - reduction="mean", ) return ParakeetTDTOutput( loss=loss, logits=logits, last_hidden_state=encoder_outputs.last_hidden_state, + pooler_output=encoder_outputs.pooler_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - pooler_output=encoder_outputs.pooler_output, - attention_mask=encoder_outputs.attention_mask, decoder_cache=decoder_cache, ) From 13b68cec1c9a4958cce6b286df7abca0660a9e3f Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:20:05 +0200 Subject: [PATCH 42/67] generate cleanup + separate generation file --- .../models/parakeet/generation_parakeet.py | 168 ++++++++++++++++++ .../models/parakeet/modular_parakeet.py | 164 +---------------- 2 files changed, 169 insertions(+), 163 deletions(-) create mode 100644 src/transformers/models/parakeet/generation_parakeet.py diff --git a/src/transformers/models/parakeet/generation_parakeet.py b/src/transformers/models/parakeet/generation_parakeet.py new file mode 
100644 index 000000000000..bf7ac32051aa --- /dev/null +++ b/src/transformers/models/parakeet/generation_parakeet.py @@ -0,0 +1,168 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch + +from ...generation import GenerationMixin, GenerationMode, StoppingCriteria +from ...utils import ModelOutput + + +@dataclass +class ParakeetTDTGenerateOutput(ModelOutput): + """ + Outputs of Parakeet TDT generation. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Generated token sequences (including blank tokens). + durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Per-step durations in frames. Combined with `sequences`, this is sufficient + to reconstruct full timestamp information (frame indices are the cumulative sum + of durations). + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*): + Encoder attention weights per layer. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*): + Encoder hidden states per layer. + """ + + sequences: torch.LongTensor + durations: torch.LongTensor | None = None + attentions: tuple[tuple[torch.FloatTensor]] | None = None + hidden_states: tuple[tuple[torch.FloatTensor]] | None = None + + +class EncoderExhaustedCriteria(StoppingCriteria): + """Stops generation when all batch elements have walked past their encoder output length.""" + + def __init__(self, model): + self.model = model + + def __call__(self, input_ids, scores, **kwargs): + if self.model._encoder_finished is None: + return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device) + return self.model._encoder_finished + + +class ParakeetTDTGenerationMixin(GenerationMixin): + """Generation mixin for Parakeet TDT models. + + Handles transducer-specific generation logic: encoder frame tracking, + duration accumulation, and encoder-exhaustion stopping. + """ + def _get_stopping_criteria(self, *args, **kwargs): + criteria = super()._get_stopping_criteria(*args, **kwargs) + criteria.append(EncoderExhaustedCriteria(self)) + return criteria + + def _update_model_kwargs_for_generation(self, outputs, *args, **kwargs): + model_kwargs = super()._update_model_kwargs_for_generation(outputs, *args, **kwargs) + + # Advance encoder frame pointer by the predicted duration + logits = outputs.logits[:, -1, :] + tokens = logits[:, : self.config.vocab_size].argmax(dim=-1) + durations = logits[:, self.config.vocab_size :].argmax(dim=-1) + + # Only force forward progress (duration >= 1) for blank predictions; + blank_mask = tokens == self.config.blank_token_id + durations = torch.where(blank_mask & (durations == 0), torch.ones_like(durations), durations) + model_kwargs["encoder_frame_idxs"] = model_kwargs["encoder_frame_idxs"] + durations + self._step_durations.append(durations) + + # Track which batch elements have exhausted their encoder frames. 
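The frame-advancement rule above is the heart of TDT decoding: the last `len(durations)` joint logits pick how many encoder frames to skip, and a blank prediction with duration 0 is bumped to 1 so the loop always makes progress. A worked numeric sketch (toy vocabulary and logits; duration values assumed to be `[0, 1, 2, 3, 4]`, so the argmax index equals the duration):

```python
import torch

vocab_size, blank_token_id = 5, 4  # toy vocabulary, blank is the last id
num_durations = 5                  # duration values assumed [0, 1, 2, 3, 4]

# Fake joint logits for one decoding step: (batch=2, vocab_size + num_durations).
logits = torch.tensor(
    [
        [0.0, 0.1, 2.0, 0.0, 0.0,  0.0, 0.0, 0.0, 3.0, 0.0],  # token 2, duration 3
        [0.0, 0.0, 0.0, 0.0, 2.0,  3.0, 0.0, 0.0, 0.0, 0.0],  # blank, duration 0
    ]
)
tokens = logits[:, :vocab_size].argmax(dim=-1)
durations = logits[:, vocab_size:].argmax(dim=-1)

# A blank with duration 0 would stall decoding, so it is forced to advance one frame.
blank_mask = tokens == blank_token_id
durations = torch.where(blank_mask & (durations == 0), torch.ones_like(durations), durations)

encoder_frame_idxs = torch.zeros(2, dtype=torch.long) + durations
encoder_valid_lengths = torch.tensor([8, 1])
finished = encoder_frame_idxs >= encoder_valid_lengths
print(tokens.tolist(), durations.tolist(), finished.tolist())  # [2, 4] [3, 1] [False, True]
```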
+ self._encoder_finished = model_kwargs["encoder_frame_idxs"] >= model_kwargs["encoder_valid_lengths"] + + return model_kwargs + + def _prepare_generated_length( + self, generation_config, has_default_max_length, has_default_min_length, + model_input_name, input_ids_length, inputs_tensor, + ): + # When the user hasn't explicitly set max_length/max_new_tokens, derive an upper + # bound from the encoder capacity. The actual stopping is handled by the + # encoder-exhaustion stopping criteria; this just sizes the output buffer. + if has_default_max_length and generation_config.max_new_tokens is None: + encoder_seq_len = self.encoder._get_subsampling_output_length( + torch.tensor([inputs_tensor.shape[1]], device=inputs_tensor.device) + ).item() + generation_config.max_length = self.config.max_symbols_per_step * encoder_seq_len + has_default_max_length = False # prevent super() from overwriting + return super()._prepare_generated_length( + generation_config, has_default_max_length, has_default_min_length, + model_input_name, input_ids_length, inputs_tensor, + ) + + def _prepare_model_inputs(self, *args, **kwargs): + inputs, input_name, model_kwargs = super()._prepare_model_inputs(*args, **kwargs) + + encoder_outputs = self.get_audio_features( + input_features=inputs, + attention_mask=model_kwargs.get("attention_mask", None), + output_attention_mask=True, + ) + model_kwargs["encoder_outputs"] = encoder_outputs + + if encoder_outputs.attention_mask is not None: + encoder_valid_lengths = encoder_outputs.attention_mask.sum(-1) + else: + batch_size = encoder_outputs.shape[0] + encoder_valid_lengths = torch.full( + (batch_size,), encoder_outputs.last_hidden_state.shape[1], dtype=torch.long, device=encoder_outputs.device + ) + model_kwargs["encoder_valid_lengths"] = encoder_valid_lengths + + model_kwargs["encoder_frame_idxs"] = torch.zeros( + inputs.shape[0], + device=inputs.device, + dtype=torch.long, + ) + + return inputs, input_name, model_kwargs + + def _prepare_cache_for_generation(self, generation_config, model_kwargs, *args, **kwargs): + from .modeling_parakeet import ParakeetTDTDecoderCache + + model_kwargs["decoder_cache"] = ParakeetTDTDecoderCache() + + def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): + from .modeling_parakeet import ParakeetEncoderModelOutput + + model_inputs = super().prepare_inputs_for_generation(input_ids, *args, **kwargs) + encoder_frame_idxs = model_inputs.pop("encoder_frame_idxs").to(model_inputs["encoder_outputs"].pooler_output.device) + + pooler_output = model_inputs["encoder_outputs"].pooler_output + batch_size, max_encoder_len = pooler_output.shape[0], pooler_output.shape[1] + encoder_frame_idxs = encoder_frame_idxs.clamp(max=max_encoder_len - 1) + model_inputs["encoder_outputs"] = ParakeetEncoderModelOutput( + pooler_output=pooler_output[torch.arange(batch_size), encoder_frame_idxs, None], + ) + + return model_inputs + + def generate(self, inputs=None, generation_config=None, **kwargs): + # TODO @eustlb: this is temporary — we're going to modularize generate to allow doing this cleanly. 
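`prepare_inputs_for_generation` above gathers one encoder frame per sample via advanced indexing while keeping a length-1 time dimension for the joint network. A shape-level sketch of that gather (illustrative shapes only):

```python
import torch

batch, T, hidden = 2, 5, 3
pooler_output = torch.randn(batch, T, hidden)  # projected encoder output
encoder_frame_idxs = torch.tensor([0, 3])      # current frame pointer per sample

# Advanced indexing picks frame `encoder_frame_idxs[b]` for each sample b; the trailing
# `None` keeps a time dimension of length 1, so the joint still sees (batch, 1, hidden).
frames = pooler_output[torch.arange(batch), encoder_frame_idxs, None]
print(frames.shape)  # torch.Size([2, 1, 3])
```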
+ self._step_durations = [] + self._encoder_finished = None + + outputs = super().generate(inputs=inputs, generation_config=generation_config, **kwargs) + durations = torch.stack(self._step_durations, dim=1) # (batch, steps) + # Prepend a zero duration for the decoder_start_token_id that super().generate() prepends to sequences + durations = torch.cat([torch.zeros(durations.shape[0], 1, dtype=durations.dtype, device=durations.device), durations], dim=1) + del self._step_durations, self._encoder_finished + + return ParakeetTDTGenerateOutput( + sequences=outputs.sequences if isinstance(outputs, ModelOutput) else outputs, + durations=durations, + ) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index ccb26cfd734f..404fd61719bc 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -975,7 +975,7 @@ class ParakeetTDTOutput(BaseModelOutputWithPooling): Parakeet Encoder with a TDT (Token Duration Transducer) head. """ ) -class ParakeetForTDT(ParakeetPreTrainedModel, GenerationMixin): +class ParakeetForTDT(ParakeetPreTrainedModel, ParakeetTDTGenerationMixin): config: ParakeetTDTConfig _no_split_modules = ["ParakeetTDTDecoder"] @@ -1094,167 +1094,5 @@ def forward( decoder_cache=decoder_cache, ) - @torch.no_grad() - def generate( - self, - input_features: torch.Tensor, - attention_mask: torch.Tensor | None = None, - return_timestamps: bool = False, - return_dict_in_generate: bool = False, - compile_config: CompileConfig | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> ParakeetTDTGenerateOutput | torch.LongTensor: - r""" - return_timestamps (`bool`, *optional*, defaults to `False`): - Whether to return per-token timestamps and durations. When `True`, forces - `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. - compile_config ([`~generation.CompileConfig`], *optional*): - If provided, `torch.compile` will be applied to the forward calls in the decoding loop. - - Example: - - ```python - >>> from transformers import AutoProcessor, ParakeetForTDT - >>> from datasets import load_dataset, Audio - - >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = ParakeetForTDT.from_pretrained(model_id) - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) - - >>> inputs = processor(ds[0]["audio"]["array"], sampling_rate=processor.feature_extractor.sampling_rate) - >>> inputs = inputs.to(model.device, dtype=model.dtype) - >>> output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) - - >>> decoded_output, decoded_timestamps = processor.decode( - ... output.sequences, - ... token_timestamps=output.token_timestamps, - ... token_durations=output.token_durations, - ... skip_special_tokens=True - ... 
) - >>> print("Transcription:", decoded_output) - >>> print("Timestamped tokens:", decoded_timestamps) - ``` - """ - if return_timestamps: - return_dict_in_generate = True - - model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__ - - # Initial forward: encode + decoder initialization - kwargs.setdefault("output_attention_mask", True) - outputs = model_forward( - input_features=input_features, - attention_mask=attention_mask, - use_decoder_cache=True, - return_dict=True, - **kwargs, - ) - - # Reconstruct encoder_outputs for subsequent forward calls - encoder_outputs = ParakeetEncoderModelOutput( - last_hidden_state=outputs.last_hidden_state, - pooler_output=outputs.pooler_output, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - attention_mask=outputs.attention_mask, - ) - decoder_cache = outputs.decoder_cache - batch_size, sequence_length = outputs.pooler_output.shape[:2] - device = outputs.pooler_output.device - - if outputs.attention_mask is not None: - valid_lengths = outputs.attention_mask.sum(dim=1).int() - else: - valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) - - time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) - time_indices_current_labels = torch.zeros(batch_size, dtype=torch.long, device=device) - active_mask = time_indices < valid_lengths - symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) - last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) - max_output_len = sequence_length * self.config.max_symbols_per_step - all_tokens_tensor = torch.full( - (batch_size, max_output_len), self.config.pad_token_id, dtype=torch.long, device=device - ) - tokens = torch.zeros(batch_size, dtype=torch.long, device=device) - durations = torch.zeros(batch_size, dtype=torch.long, device=device) - token_counts = torch.zeros(batch_size, dtype=torch.long, device=device) - if return_timestamps: - all_frame_indices = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) - all_durations_tensor = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) - - while active_mask.any(): - active_at_start = active_mask.clone() - - time_indices_current_labels = torch.where(active_at_start, time_indices, time_indices_current_labels) - outputs = model_forward( - encoder_outputs=encoder_outputs, - encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_cache=decoder_cache, - return_dict=True, - ) - logits = outputs.logits.squeeze(1) - tokens = torch.where(active_at_start, logits[..., : self.config.vocab_size].argmax(dim=-1), tokens) - durations = torch.where(active_at_start, logits[..., self.config.vocab_size :].argmax(dim=-1), durations) - - blank_mask = active_at_start & (tokens == self.config.blank_token_id) - durations = durations.masked_fill(blank_mask & (durations == 0), 1) # ensure forward progress - - # Advance time for all active samples - time_indices = time_indices + durations.masked_fill(~active_at_start, 0) - active_mask = time_indices < valid_lengths - - # If all remaining active samples predicted blank, skip emit + decoder update - emit_mask = active_at_start & ~blank_mask - if not emit_mask.any(): - continue - - # Emit non-blank tokens - emit_indices = token_counts[emit_mask] - all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] - if return_timestamps: - all_frame_indices[emit_mask, emit_indices] = time_indices_current_labels[emit_mask] 
- all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] - token_counts += emit_mask.long() - - # Update decoder cache for emitted tokens (using potentially compiled forward) - model_forward( - decoder_input_ids=tokens.unsqueeze(1), - encoder_outputs=encoder_outputs, - encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_cache=decoder_cache, - decoder_cache_update_mask=emit_mask, - return_dict=True, - ) - - time_changed = time_indices_current_labels != last_label_time - symbols_per_step = torch.where(time_changed, 0, symbols_per_step) - symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) - last_label_time = torch.where(emit_mask, time_indices_current_labels, last_label_time) - force_advance = active_mask & (symbols_per_step >= self.config.max_symbols_per_step) - time_indices = time_indices + force_advance.long() - symbols_per_step = symbols_per_step.masked_fill(force_advance, 0) - active_mask = time_indices < valid_lengths - - max_len = max(token_counts.max().item(), 1) - sequences = all_tokens_tensor[:, :max_len] - token_timestamps, token_durations = None, None - if return_timestamps: - token_timestamps = all_frame_indices[:, :max_len] - token_durations = all_durations_tensor[:, :max_len] - - if return_dict_in_generate: - return ParakeetTDTGenerateOutput( - sequences=sequences, - token_timestamps=token_timestamps, - token_durations=token_durations, - attentions=outputs.attentions, - hidden_states=outputs.hidden_states, - ) - return sequences - __all__ = ["ParakeetForCTC", "ParakeetForTDT", "ParakeetEncoder", "ParakeetPreTrainedModel"] From 72c1ad002fc98bac84b2d169bf9e492ee0c4daf3 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:21:52 +0200 Subject: [PATCH 43/67] generate: add _supported_generation_modes --- src/transformers/generation/utils.py | 7 +++++++ src/transformers/models/parakeet/modular_parakeet.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8a55c184b0f0..391a83704b24 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1444,6 +1444,13 @@ def compute_transition_scores( def _validate_generation_mode( self: "GenerativePreTrainedModel", generation_mode, generation_config, generation_mode_kwargs ): + supported_modes = getattr(self, "_supported_generation_modes", None) + if supported_modes is not None and generation_mode not in supported_modes: + raise ValueError( + f"{self.__class__.__name__} only supports {supported_modes}, but got " + f"generation mode '{generation_mode}'." + ) + if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs: raise ValueError( "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 404fd61719bc..8cea8eb5cd21 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -22,7 +22,7 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...generation import CompileConfig, GenerationMixin +from ...generation import CompileConfig, GenerationMixin, GenerationMode from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -978,6 +978,7 @@ class ParakeetTDTOutput(BaseModelOutputWithPooling): class ParakeetForTDT(ParakeetPreTrainedModel, ParakeetTDTGenerationMixin): config: ParakeetTDTConfig _no_split_modules = ["ParakeetTDTDecoder"] + _supported_generation_modes = [GenerationMode.GREEDY_SEARCH] def __init__(self, config: ParakeetTDTConfig): super().__init__(config) From 8e23b3df76e57483d9c4a9d687235c3c5ef211e2 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:22:55 +0200 Subject: [PATCH 44/67] automatic init of the loss --- src/transformers/loss/loss_tdt.py | 167 ++++++++++++++++++ src/transformers/loss/loss_utils.py | 2 + .../models/parakeet/modular_parakeet.py | 128 -------------- 3 files changed, 169 insertions(+), 128 deletions(-) create mode 100644 src/transformers/loss/loss_tdt.py diff --git a/src/transformers/loss/loss_tdt.py b/src/transformers/loss/loss_tdt.py new file mode 100644 index 000000000000..27389e10b725 --- /dev/null +++ b/src/transformers/loss/loss_tdt.py @@ -0,0 +1,167 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def tdt_loss( + token_logits: torch.Tensor, + duration_logits: torch.Tensor, + targets: torch.Tensor, + logit_lengths: torch.Tensor, + target_lengths: torch.Tensor, + blank_token_id: int, + durations: list[int], + sigma: float = 0.0, + reduction: str = "mean", +) -> torch.Tensor: + """ + Compute TDT (Token-and-Duration Transducer) loss (https://arxiv.org/abs/2304.06795). + + Ported from NeMo's `TDTLossPytorch` with anti-diagonal processing. Unlike standard RNNT loss, this loss trains both + the token prediction head and the duration prediction head. It uses vectorized anti-diagonal processing for + efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in parallel as batched tensor operations. + + Args: + token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. + duration_logits: Duration logits of shape `(batch, T, U+1, num_durations)`. + targets: Target labels of shape `(batch, U)`. + logit_lengths: Encoder output lengths of shape `(batch,)`. + target_lengths: Target lengths of shape `(batch,)`. + blank_token_id: Blank token id. + durations: List of duration values (e.g., `[0, 1, 2, 3, 4]`). + sigma: Logit undernormalization constant (see TDT paper). Defaults to `0.0`. + reduction: Loss reduction method. One of `"mean"`, `"sum"`, or `"none"`. Defaults to `"mean"`. + + Returns: + Scalar loss tensor (or per-example losses if `reduction="none"`). 
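For orientation, the tensor shapes the TDT loss consumes follow directly from the Args described above: for training, the joint is evaluated on the full (T, U+1) grid. A shape-level sketch with random tensors (the call itself is left commented out because `tdt_loss` is the helper introduced by this patch, not a released API):

```python
import torch

batch, T, U, vocab = 2, 6, 3, 10
durations = [0, 1, 2, 3, 4]

token_logits = torch.randn(batch, T, U + 1, vocab + 1)          # +1 slot for the blank symbol
duration_logits = torch.randn(batch, T, U + 1, len(durations))  # one logit per duration value
targets = torch.randint(0, vocab, (batch, U))
logit_lengths = torch.tensor([6, 5])    # valid encoder frames per sample
target_lengths = torch.tensor([3, 2])   # valid labels per sample

# loss = tdt_loss(token_logits, duration_logits, targets, logit_lengths, target_lengths,
#                 blank_token_id=vocab, durations=durations)
```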
+ + """ + device = token_logits.device + batch_size, max_t, max_u, _ = token_logits.shape + + token_logits = token_logits.float() + duration_logits = duration_logits.float() + + # Apply log-softmax to get log probabilities + token_log_probs = torch.log_softmax(token_logits, dim=-1) - sigma + duration_log_probs = torch.log_softmax(duration_logits, dim=-1) + + log_alpha = torch.full((batch_size, max_t, max_u), float("-inf"), device=device) + log_alpha[:, 0, 0] = 0.0 + + # Precompute blank and label log-probs for vectorized access + blank_log_probs = token_log_probs[:, :, :, blank_token_id] + + if max_u > 1: + targets_expanded = targets.unsqueeze(1).expand(-1, max_t, -1) # (batch, T, U_labels) + label_log_probs = torch.gather( + token_log_probs[:, :, : max_u - 1, :], # (batch, T, U-1, vocab) + dim=3, + index=targets_expanded.unsqueeze(-1), + ).squeeze(-1) # (batch, T, U-1) + + # Process anti-diagonals: all (t, u) with t + u = n have no mutual dependencies + for n in range(1, max_t + max_u - 1): + u_start = max(0, n - max_t + 1) + u_end = min(n + 1, max_u) + u_indices = torch.arange(u_start, u_end, device=device) + + t_indices = n - u_indices + all_candidates = [] + for i, dur in enumerate(durations): + t_prev = t_indices - dur + valid_t = t_prev >= 0 + if not valid_t.any(): + continue + t_src = t_prev.clamp(min=0) + + # Blank arcs (dur > 0): from (t-dur, u) to (t, u) + if dur > 0: + contrib = ( + log_alpha[:, t_src, u_indices] + + blank_log_probs[:, t_src, u_indices] + + duration_log_probs[:, t_src, u_indices, i] + ) + contrib = torch.where(valid_t.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) + all_candidates.append(contrib) + + # Label arcs: from (t-dur, u-1) to (t, u), only if u > 0 + valid_u = u_indices > 0 + valid_both = valid_t & valid_u + if valid_both.any(): + u_src = (u_indices - 1).clamp(min=0) + u_src_label = u_src.clamp(max=max_u - 2) if max_u > 1 else u_src + + contrib = ( + log_alpha[:, t_src, u_src] + + label_log_probs[:, t_src, u_src_label] + + duration_log_probs[:, t_src, u_src, i] + ) + contrib = torch.where(valid_both.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) + all_candidates.append(contrib) + + if all_candidates: + stacked = torch.stack(all_candidates, dim=0) + log_alpha[:, t_indices, u_indices] = torch.logsumexp(stacked, dim=0) + + # Terminal probability: sum over blank arcs that reach (T, U) from (T-dur, U) + batch_idx = torch.arange(batch_size, device=device) + log_probs = torch.full((batch_size,), float("-inf"), device=device) + for i, dur in enumerate(durations): + if dur == 0: + continue + t_final = logit_lengths - dur + valid = t_final >= 0 + if not valid.any(): + continue + + t_clamped = t_final.clamp(min=0) + terminal = ( + log_alpha[batch_idx, t_clamped, target_lengths] + + token_log_probs[batch_idx, t_clamped, target_lengths, blank_token_id] + + duration_log_probs[batch_idx, t_clamped, target_lengths, i] + ) + combined = torch.stack([log_probs, terminal], dim=0) + log_probs = torch.where(valid, torch.logsumexp(combined, dim=0), log_probs) + + losses = -log_probs + + if reduction == "mean": + return (losses / target_lengths.float()).mean() + elif reduction == "sum": + return losses.sum() + return losses + + +def ParakeetForTDTLoss( + token_logits, + duration_logits, + labels, + logit_lengths, + label_lengths, + blank_token_id, + durations, + **kwargs, +): + device = token_logits.device + return tdt_loss( + token_logits=token_logits.float(), + duration_logits=duration_logits.float(), + targets=labels.to(device).int(), 
+ logit_lengths=logit_lengths.to(device).int(), + target_lengths=label_lengths.to(device).int(), + blank_token_id=blank_token_id, + durations=durations, + ) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index df269477e9ec..e0aa92b50808 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -23,6 +23,7 @@ from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss from .loss_lw_detr import LwDetrForObjectDetectionLoss from .loss_rt_detr import RTDetrForObjectDetectionLoss +from .loss_tdt import ParakeetForTDTLoss def fixed_cross_entropy( @@ -165,4 +166,5 @@ def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs): "DFineForObjectDetection": DFineForObjectDetectionLoss, "CsmForConditionalGeneration": ForCausalLMLoss, "LwDetrForObjectDetection": LwDetrForObjectDetectionLoss, + "ParakeetForTDT": ParakeetForTDTLoss, } diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 8cea8eb5cd21..8f8fb8d500f2 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -772,133 +772,6 @@ def forward( return decoder_output -# TODO (ebezzam) eventually move to audio_utils or loss_utils for common usage? -def tdt_loss( - token_logits: torch.Tensor, - duration_logits: torch.Tensor, - targets: torch.Tensor, - logit_lengths: torch.Tensor, - target_lengths: torch.Tensor, - blank_token_id: int, - durations: list[int], - sigma: float = 0.0, - reduction: str = "mean", -) -> torch.Tensor: - """ - Compute TDT (Token-and-Duration Transducer) loss (https://arxiv.org/abs/2304.06795). - - Ported from NeMo's `TDTLossPytorch` with anti-diagonal processing. Unlike standard RNNT loss, this loss trains both - the token prediction head and the duration prediction head. It uses vectorized anti-diagonal processing for - efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in parallel as batched tensor operations. - - Args: - token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. - duration_logits: Duration logits of shape `(batch, T, U+1, num_durations)`. - targets: Target labels of shape `(batch, U)`. - logit_lengths: Encoder output lengths of shape `(batch,)`. - target_lengths: Target lengths of shape `(batch,)`. - blank_token_id: Blank token id. - durations: List of duration values (e.g., `[0, 1, 2, 3, 4]`). - sigma: Logit undernormalization constant (see TDT paper). Defaults to `0.0`. - reduction: Loss reduction method. One of `"mean"`, `"sum"`, or `"none"`. Defaults to `"mean"`. - - Returns: - Scalar loss tensor (or per-example losses if `reduction="none"`). 
- - """ - device = token_logits.device - batch_size, max_t, max_u, _ = token_logits.shape - - # Apply log-softmax to get log probabilities - token_log_probs = torch.log_softmax(token_logits, dim=-1) - sigma - duration_log_probs = torch.log_softmax(duration_logits, dim=-1) - - log_alpha = torch.full((batch_size, max_t, max_u), float("-inf"), device=device) - log_alpha[:, 0, 0] = 0.0 - - # Precompute blank and label log-probs for vectorized access - blank_log_probs = token_log_probs[:, :, :, blank_token_id] - - if max_u > 1: - targets_expanded = targets.unsqueeze(1).expand(-1, max_t, -1) # (batch, T, U_labels) - label_log_probs = torch.gather( - token_log_probs[:, :, : max_u - 1, :], # (batch, T, U-1, vocab) - dim=3, - index=targets_expanded.unsqueeze(-1), - ).squeeze(-1) # (batch, T, U-1) - - # Process anti-diagonals: all (t, u) with t + u = n have no mutual dependencies - for n in range(1, max_t + max_u - 1): - u_start = max(0, n - max_t + 1) - u_end = min(n + 1, max_u) - u_indices = torch.arange(u_start, u_end, device=device) - - t_indices = n - u_indices - all_candidates = [] - for i, dur in enumerate(durations): - t_prev = t_indices - dur - valid_t = t_prev >= 0 - if not valid_t.any(): - continue - t_src = t_prev.clamp(min=0) - - # Blank arcs (dur > 0): from (t-dur, u) to (t, u) - if dur > 0: - contrib = ( - log_alpha[:, t_src, u_indices] - + blank_log_probs[:, t_src, u_indices] - + duration_log_probs[:, t_src, u_indices, i] - ) - contrib = torch.where(valid_t.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) - all_candidates.append(contrib) - - # Label arcs: from (t-dur, u-1) to (t, u), only if u > 0 - valid_u = u_indices > 0 - valid_both = valid_t & valid_u - if valid_both.any(): - u_src = (u_indices - 1).clamp(min=0) - u_src_label = u_src.clamp(max=max_u - 2) if max_u > 1 else u_src - - contrib = ( - log_alpha[:, t_src, u_src] - + label_log_probs[:, t_src, u_src_label] - + duration_log_probs[:, t_src, u_src, i] - ) - contrib = torch.where(valid_both.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) - all_candidates.append(contrib) - - if all_candidates: - stacked = torch.stack(all_candidates, dim=0) - log_alpha[:, t_indices, u_indices] = torch.logsumexp(stacked, dim=0) - - # Terminal probability: sum over blank arcs that reach (T, U) from (T-dur, U) - batch_idx = torch.arange(batch_size, device=device) - log_probs = torch.full((batch_size,), float("-inf"), device=device) - for i, dur in enumerate(durations): - if dur == 0: - continue - t_final = logit_lengths - dur - valid = t_final >= 0 - if not valid.any(): - continue - - t_clamped = t_final.clamp(min=0) - terminal = ( - log_alpha[batch_idx, t_clamped, target_lengths] - + token_log_probs[batch_idx, t_clamped, target_lengths, blank_token_id] - + duration_log_probs[batch_idx, t_clamped, target_lengths, i] - ) - combined = torch.stack([log_probs, terminal], dim=0) - log_probs = torch.where(valid, torch.logsumexp(combined, dim=0), log_probs) - - losses = -log_probs - - if reduction == "mean": - return (losses / target_lengths.float()).mean() - elif reduction == "sum": - return losses.sum() - return losses - class ParakeetTDTJointNetwork(nn.Module): """Joint network that combines encoder and decoder outputs to predict tokens and durations.""" @@ -986,7 +859,6 @@ def __init__(self, config: ParakeetTDTConfig): self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) - 
self.loss_function = tdt_loss self.post_init() From 1cc39fd85fa3396eae587fcdc938c99a683b096b Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:28:05 +0200 Subject: [PATCH 45/67] modular cleanups --- .../models/parakeet/modular_parakeet.py | 68 ++++++------------- 1 file changed, 21 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 8f8fb8d500f2..a395e5c896af 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -41,6 +41,7 @@ from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule from ..llama.modeling_llama import LlamaAttention, eager_attention_forward from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig +from .generation_parakeet import ParakeetTDTGenerationMixin logger = logging.get_logger(__name__) @@ -741,7 +742,7 @@ class ParakeetTDTDecoder(nn.Module): def __init__(self, config: ParakeetTDTConfig): super().__init__() - self.config = config + self.blank_token_id = config.blank_token_id self.embedding = nn.Embedding(config.vocab_size, config.decoder_hidden_size) self.lstm = nn.LSTM( input_size=config.decoder_hidden_size, @@ -754,21 +755,26 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, input_ids: torch.LongTensor, - decoder_cache: ParakeetTDTDecoderCache | None = None, - decoder_cache_update_mask: torch.BoolTensor | None = None, + cache: ParakeetTDTDecoderCache | None = None, ) -> torch.Tensor: + # All-blank fast path + if cache is not None and cache.is_initialized: + blank_mask = input_ids[:, -1] == self.blank_token_id + if blank_mask.all(): + return cache.cache + hidden_cell_states = ( - (decoder_cache.hidden_state, decoder_cache.cell_state) - if decoder_cache is not None and decoder_cache.is_initialized - else None + (cache.hidden_state, cache.cell_state) if cache is not None and cache.is_initialized else None ) embeddings = self.embedding(input_ids) lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states) decoder_output = self.decoder_projector(lstm_output) - if decoder_cache is not None: - decoder_cache.update( - decoder_output, hidden_state, cell_state, lstm_module=self.lstm, mask=decoder_cache_update_mask - ) + + if cache is not None: + mask = ~blank_mask if cache.is_initialized else None + cache.update(decoder_output, hidden_state, cell_state, lstm_module=self.lstm, mask=mask) + return cache.cache + return decoder_output @@ -784,39 +790,11 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, - decoder_output: torch.Tensor, - encoder_output: torch.Tensor, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - joint_output = self.activation(encoder_output + decoder_output) - logits = self.head(joint_output) - token_logits = logits[..., : self.vocab_size] - duration_logits = logits[..., self.vocab_size :] - return token_logits, duration_logits - - -@dataclass -class ParakeetTDTGenerateOutput(ModelOutput): - """ - Outputs of Parakeet TDT generation. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Generated token sequences. - token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Per-token frame indices. Returned when `return_timestamps=True`. 
- token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Per-token durations in frames. Returned when `return_timestamps=True`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*): - Encoder attention weights per layer. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*): - Encoder hidden states per layer. - """ - - sequences: torch.LongTensor - token_timestamps: torch.FloatTensor | None = None - token_durations: torch.LongTensor | None = None - attentions: tuple[tuple[torch.FloatTensor]] | None = None - hidden_states: tuple[tuple[torch.FloatTensor]] | None = None + joint_output = self.activation(encoder_hidden_states + decoder_hidden_states) + return self.head(joint_output) @dataclass @@ -830,16 +808,12 @@ class ParakeetTDTOutput(BaseModelOutputWithPooling): logits (`torch.FloatTensor`): Joint token and duration logits. Shape is `(batch, T, U+1, vocab+durations)` for training or `(batch, 1, 1, vocab+durations)` for single-step inference. - attention_mask (`torch.Tensor`, *optional*): - Encoder output attention mask after subsampling. decoder_cache (`ParakeetTDTDecoderCache`, *optional*): - Decoder LSTM cache containing hidden state, cell state, and decoder output. - Updated in-place during generation. + Decoder LSTM cache containing hidden state, cell state, and last output. """ loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - attention_mask: torch.Tensor | None = None decoder_cache: ParakeetTDTDecoderCache | None = None From 531f297ec1ea1b6e2017e7d18e637176798951ce Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:28:54 +0200 Subject: [PATCH 46/67] use is_encoder_decoder --- src/transformers/models/parakeet/configuration_parakeet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index babc9526f760..fb6bc1c04d7d 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -176,6 +176,7 @@ class ParakeetTDTConfig(PreTrainedConfig): encoder_config: dict | PreTrainedConfig | None = None pad_token_id: int = 2 blank_token_id: int = 8192 + is_encoder_decoder: bool = True def __post_init__(self, **kwargs): if isinstance(self.encoder_config, dict): From 2c0f23afd9abe4bfd87dd9590f926e3e2bb69a71 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:39:05 +0200 Subject: [PATCH 47/67] timestamp processing fully from tokens + durations --- .../models/parakeet/processing_parakeet.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py index 0b662f56af34..91d502784828 100644 --- a/src/transformers/models/parakeet/processing_parakeet.py +++ b/src/transformers/models/parakeet/processing_parakeet.py @@ -91,15 +91,17 @@ def model_input_names(self): feature_extractor_input_names = self.feature_extractor.model_input_names return feature_extractor_input_names + ["labels"] - def decode(self, *args, token_timestamps=None, token_durations=None, **kwargs): + def decode(self, *args, durations=None, **kwargs): """ Forward arguments to [`~PreTrainedTokenizer.decode`] and post-process the timestamps (if provided for TDT) as in the NeMo library. 
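With this change, explicit timestamps are no longer passed around; the cumulative-sum rule implemented just below recovers each step's start frame from the durations alone. A worked numeric sketch:

```python
import torch

# Per-step durations emitted during generation (toy values).
durations = torch.tensor([[0, 3, 1, 4]])

# Exclusive cumulative sum: each step starts where the previous durations end,
# mirroring `timestamps = durations.cumsum(dim=-1) - durations` in `decode`.
timestamps = durations.cumsum(dim=-1) - durations
print(timestamps)  # tensor([[0, 0, 3, 4]])
# Step i then spans [timestamps[i], timestamps[i] + durations[i]) in encoder frames.
```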
""" decoded = self.tokenizer.decode(*args, **kwargs) - if token_timestamps is not None and token_durations is not None: + if durations is not None: token_ids = args[0] + # Derive per-step frame indices from cumulative sum of durations. + timestamps = durations.cumsum(dim=-1) - durations output_kwargs = self._merge_kwargs( ParakeetProcessorKwargs, @@ -112,16 +114,18 @@ def decode(self, *args, token_timestamps=None, token_durations=None, **kwargs): * output_kwargs["audio_kwargs"]["subsampling_factor"] ) proc_timestamps = [] - for batch_ids, timestamps, durations in zip(token_ids, token_timestamps, token_durations): + for batch_ids, batch_timestamps, batch_durations in zip(token_ids, timestamps, durations): # See `compute_rnnt_timestamps` in NeMo: https://github.com/NVIDIA-NeMo/NeMo/blob/1692a8fb97e1aadc883cfadd2a57c4e8a1b793aa/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L993 - # Filter padding (unwritten positions in `all_tokens_tensor` in `generate`) + # Filter padding and blank tokens + blank_token_id = self.tokenizer.convert_tokens_to_ids("") + skip_ids = {self.tokenizer.pad_token_id, blank_token_id} non_blank_indices = [ - i for i, token_id in enumerate(batch_ids) if token_id != self.tokenizer.pad_token_id + i for i, token_id in enumerate(batch_ids) if int(token_id) not in skip_ids ] non_blank_ids = [batch_ids[i] for i in non_blank_indices] decoded_tokens = [self.tokenizer.decode([token_id]) for token_id in non_blank_ids] timestamp_dict = [ - {"token": token_str, "start": int(timestamps[i]), "end": int(timestamps[i] + durations[i])} + {"token": token_str, "start": int(batch_timestamps[i]), "end": int(batch_timestamps[i] + batch_durations[i])} for token_str, i in zip(decoded_tokens, non_blank_indices) ] timestamp_dict = self._refine_timestamps_tdt(timestamp_dict) From cef6639e58ab564c512630b18e4f823213ab7a04 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:39:32 +0200 Subject: [PATCH 48/67] convertion script update --- src/transformers/models/parakeet/convert_nemo_to_hf.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index ccbec5fcb245..8cea24f4a0cc 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -370,6 +370,11 @@ def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_t del model.config._name_or_path + model.generation_config.decoder_start_token_id = model.config.blank_token_id + model.generation_config.suppress_tokens = list( + range(model.config.vocab_size, model.config.vocab_size + len(model.config.durations)) + ) + print("Saving the model.") model.save_pretrained(output_dir) From fd3cf9b237e185c0e43d19100151e445e363a166 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:43:01 +0200 Subject: [PATCH 49/67] test update --- tests/models/parakeet/test_modeling_parakeet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 6667bb2ce5a5..76f1aaaa4ac9 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -40,7 +40,7 @@ ParakeetForTDT, ParakeetTDTConfig, ) - from transformers.models.parakeet.modeling_parakeet import tdt_loss + from transformers.loss.loss_tdt import 
tdt_loss @require_torch From e63a5bf1cce44b47b477db300f632143a6d9300a Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:44:09 +0200 Subject: [PATCH 50/67] make --- .../models/parakeet/generation_parakeet.py | 4 +- .../models/parakeet/modeling_parakeet.py | 480 ++---------------- 2 files changed, 47 insertions(+), 437 deletions(-) diff --git a/src/transformers/models/parakeet/generation_parakeet.py b/src/transformers/models/parakeet/generation_parakeet.py index bf7ac32051aa..b714f4dcc277 100644 --- a/src/transformers/models/parakeet/generation_parakeet.py +++ b/src/transformers/models/parakeet/generation_parakeet.py @@ -16,7 +16,7 @@ import torch -from ...generation import GenerationMixin, GenerationMode, StoppingCriteria +from ...generation import GenerationMixin, StoppingCriteria from ...utils import ModelOutput @@ -74,7 +74,7 @@ def _update_model_kwargs_for_generation(self, outputs, *args, **kwargs): logits = outputs.logits[:, -1, :] tokens = logits[:, : self.config.vocab_size].argmax(dim=-1) durations = logits[:, self.config.vocab_size :].argmax(dim=-1) - + # Only force forward progress (duration >= 1) for blank predictions; blank_mask = tokens == self.config.blank_token_id durations = torch.where(blank_mask & (durations == 0), torch.ones_like(durations), durations) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 5150d35daeef..e1b0006619c6 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -27,7 +27,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...generation import CompileConfig, GenerationMixin +from ...generation import CompileConfig, GenerationMixin, GenerationMode from ...integrations import use_kernel_func_from_hub, use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput @@ -45,6 +45,7 @@ from ...utils.output_capturing import capture_outputs from ..auto import AutoModel from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig +from .generation_parakeet import ParakeetTDTGenerationMixin logger = logging.get_logger(__name__) @@ -898,11 +899,18 @@ def update( class ParakeetTDTDecoder(nn.Module): - """LSTM-based prediction network for TDT.""" + """LSTM-based prediction network for TDT. + + During generation the decoder is called once per step. When a blank token + is fed back (i.e. the model predicted blank at the previous step), the LSTM + state must *not* change — only the encoder frame advances. The blank- + skipping logic restores the previous cache state for those batch elements + using ``torch.where`` so that callers can treat the decoder as a black box. 
+ """ def __init__(self, config: ParakeetTDTConfig): super().__init__() - self.config = config + self.blank_token_id = config.blank_token_id self.embedding = nn.Embedding(config.vocab_size, config.decoder_hidden_size) self.lstm = nn.LSTM( input_size=config.decoder_hidden_size, @@ -915,21 +923,26 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, input_ids: torch.LongTensor, - decoder_cache: ParakeetTDTDecoderCache | None = None, - decoder_cache_update_mask: torch.BoolTensor | None = None, + cache: ParakeetTDTDecoderCache | None = None, ) -> torch.Tensor: + # All-blank fast path + if cache is not None and cache.is_initialized: + blank_mask = input_ids[:, -1] == self.blank_token_id + if blank_mask.all(): + return cache.cache + hidden_cell_states = ( - (decoder_cache.hidden_state, decoder_cache.cell_state) - if decoder_cache is not None and decoder_cache.is_initialized - else None + (cache.hidden_state, cache.cell_state) if cache is not None and cache.is_initialized else None ) embeddings = self.embedding(input_ids) lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states) decoder_output = self.decoder_projector(lstm_output) - if decoder_cache is not None: - decoder_cache.update( - decoder_output, hidden_state, cell_state, lstm_module=self.lstm, mask=decoder_cache_update_mask - ) + + if cache is not None: + # Use ~blank_mask so only non-blank elements are updated; blank elements keep previous state. + mask = ~blank_mask if cache.is_initialized else None + cache.update(decoder_output, hidden_state, cell_state, lstm_module=self.lstm, mask=mask) + return cache.cache return decoder_output @@ -944,39 +957,11 @@ def __init__(self, config: ParakeetTDTConfig): def forward( self, - decoder_output: torch.Tensor, - encoder_output: torch.Tensor, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - joint_output = self.activation(encoder_output + decoder_output) - logits = self.head(joint_output) - token_logits = logits[..., : self.vocab_size] - duration_logits = logits[..., self.vocab_size :] - return token_logits, duration_logits - - -@dataclass -class ParakeetTDTGenerateOutput(ModelOutput): - """ - Outputs of Parakeet TDT generation. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Generated token sequences. - token_timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Per-token frame indices. Returned when `return_timestamps=True`. - token_durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Per-token durations in frames. Returned when `return_timestamps=True`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*): - Encoder attention weights per layer. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*): - Encoder hidden states per layer. - """ - - sequences: torch.LongTensor - token_timestamps: torch.FloatTensor | None = None - token_durations: torch.LongTensor | None = None - attentions: tuple[tuple[torch.FloatTensor]] | None = None - hidden_states: tuple[tuple[torch.FloatTensor]] | None = None + joint_output = self.activation(encoder_hidden_states + decoder_hidden_states) + return self.head(joint_output) @dataclass @@ -990,155 +975,24 @@ class ParakeetTDTOutput(BaseModelOutputWithPooling): logits (`torch.FloatTensor`): Joint token and duration logits. 
Shape is `(batch, T, U+1, vocab+durations)` for training or `(batch, 1, 1, vocab+durations)` for single-step inference. - attention_mask (`torch.Tensor`, *optional*): - Encoder output attention mask after subsampling. decoder_cache (`ParakeetTDTDecoderCache`, *optional*): - Decoder LSTM cache containing hidden state, cell state, and decoder output. - Updated in-place during generation. + Decoder LSTM cache containing hidden state, cell state, and last output. """ loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - attention_mask: torch.Tensor | None = None decoder_cache: ParakeetTDTDecoderCache | None = None -# TODO (ebezzam) eventually move to audio_utils or loss_utils for common usage? -def tdt_loss( - token_logits: torch.Tensor, - duration_logits: torch.Tensor, - targets: torch.Tensor, - logit_lengths: torch.Tensor, - target_lengths: torch.Tensor, - blank_token_id: int, - durations: list[int], - sigma: float = 0.0, - reduction: str = "mean", -) -> torch.Tensor: - """ - Compute TDT (Token-and-Duration Transducer) loss (https://arxiv.org/abs/2304.06795). - - Ported from NeMo's `TDTLossPytorch` with anti-diagonal processing. Unlike standard RNNT loss, this loss trains both - the token prediction head and the duration prediction head. It uses vectorized anti-diagonal processing for - efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in parallel as batched tensor operations. - - Args: - token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. - duration_logits: Duration logits of shape `(batch, T, U+1, num_durations)`. - targets: Target labels of shape `(batch, U)`. - logit_lengths: Encoder output lengths of shape `(batch,)`. - target_lengths: Target lengths of shape `(batch,)`. - blank_token_id: Blank token id. - durations: List of duration values (e.g., `[0, 1, 2, 3, 4]`). - sigma: Logit undernormalization constant (see TDT paper). Defaults to `0.0`. - reduction: Loss reduction method. One of `"mean"`, `"sum"`, or `"none"`. Defaults to `"mean"`. - - Returns: - Scalar loss tensor (or per-example losses if `reduction="none"`). 
- - """ - device = token_logits.device - batch_size, max_t, max_u, _ = token_logits.shape - - # Apply log-softmax to get log probabilities - token_log_probs = torch.log_softmax(token_logits, dim=-1) - sigma - duration_log_probs = torch.log_softmax(duration_logits, dim=-1) - - log_alpha = torch.full((batch_size, max_t, max_u), float("-inf"), device=device) - log_alpha[:, 0, 0] = 0.0 - - # Precompute blank and label log-probs for vectorized access - blank_log_probs = token_log_probs[:, :, :, blank_token_id] - - if max_u > 1: - targets_expanded = targets.unsqueeze(1).expand(-1, max_t, -1) # (batch, T, U_labels) - label_log_probs = torch.gather( - token_log_probs[:, :, : max_u - 1, :], # (batch, T, U-1, vocab) - dim=3, - index=targets_expanded.unsqueeze(-1), - ).squeeze(-1) # (batch, T, U-1) - - # Process anti-diagonals: all (t, u) with t + u = n have no mutual dependencies - for n in range(1, max_t + max_u - 1): - u_start = max(0, n - max_t + 1) - u_end = min(n + 1, max_u) - u_indices = torch.arange(u_start, u_end, device=device) - - t_indices = n - u_indices - all_candidates = [] - for i, dur in enumerate(durations): - t_prev = t_indices - dur - valid_t = t_prev >= 0 - if not valid_t.any(): - continue - t_src = t_prev.clamp(min=0) - - # Blank arcs (dur > 0): from (t-dur, u) to (t, u) - if dur > 0: - contrib = ( - log_alpha[:, t_src, u_indices] - + blank_log_probs[:, t_src, u_indices] - + duration_log_probs[:, t_src, u_indices, i] - ) - contrib = torch.where(valid_t.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) - all_candidates.append(contrib) - - # Label arcs: from (t-dur, u-1) to (t, u), only if u > 0 - valid_u = u_indices > 0 - valid_both = valid_t & valid_u - if valid_both.any(): - u_src = (u_indices - 1).clamp(min=0) - u_src_label = u_src.clamp(max=max_u - 2) if max_u > 1 else u_src - - contrib = ( - log_alpha[:, t_src, u_src] - + label_log_probs[:, t_src, u_src_label] - + duration_log_probs[:, t_src, u_src, i] - ) - contrib = torch.where(valid_both.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) - all_candidates.append(contrib) - - if all_candidates: - stacked = torch.stack(all_candidates, dim=0) - log_alpha[:, t_indices, u_indices] = torch.logsumexp(stacked, dim=0) - - # Terminal probability: sum over blank arcs that reach (T, U) from (T-dur, U) - batch_idx = torch.arange(batch_size, device=device) - log_probs = torch.full((batch_size,), float("-inf"), device=device) - for i, dur in enumerate(durations): - if dur == 0: - continue - t_final = logit_lengths - dur - valid = t_final >= 0 - if not valid.any(): - continue - - t_clamped = t_final.clamp(min=0) - terminal = ( - log_alpha[batch_idx, t_clamped, target_lengths] - + token_log_probs[batch_idx, t_clamped, target_lengths, blank_token_id] - + duration_log_probs[batch_idx, t_clamped, target_lengths, i] - ) - combined = torch.stack([log_probs, terminal], dim=0) - log_probs = torch.where(valid, torch.logsumexp(combined, dim=0), log_probs) - - losses = -log_probs - - if reduction == "mean": - return (losses / target_lengths.float()).mean() - elif reduction == "sum": - return losses.sum() - return losses - - @auto_docstring( custom_intro=""" Parakeet Encoder with a TDT (Token Duration Transducer) head. 
""" ) -class ParakeetForTDT(ParakeetPreTrainedModel, GenerationMixin): +class ParakeetForTDT(ParakeetPreTrainedModel, ParakeetTDTGenerationMixin): config: ParakeetTDTConfig _no_split_modules = ["ParakeetTDTDecoder"] + _supported_generation_modes = [GenerationMode.GREEDY_SEARCH] def __init__(self, config: ParakeetTDTConfig): super().__init__(config) @@ -1146,7 +1000,6 @@ def __init__(self, config: ParakeetTDTConfig): self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) - self.loss_function = tdt_loss self.post_init() @@ -1172,58 +1025,13 @@ def forward( input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, decoder_input_ids: torch.LongTensor | None = None, - encoder_outputs: tuple[torch.FloatTensor] | None = None, - encoder_frame_ids: torch.LongTensor | None = None, decoder_cache: ParakeetTDTDecoderCache | None = None, - decoder_cache_update_mask: torch.BoolTensor | None = None, use_decoder_cache: bool | None = None, + encoder_outputs: ParakeetEncoderModelOutput | tuple[torch.FloatTensor] | None = None, labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: - r""" - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): - Decoder input token ids for single-step inference. - encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): - Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). - Can be a tuple or `ParakeetEncoderModelOutput`. - encoder_frame_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Encoder frame indices for the joint network during generation. - decoder_cache (`ParakeetTDTDecoderCache`, *optional*): - Decoder LSTM cache. When provided and initialized, the cached `decoder_output` is reused - (e.g. during blank-skipping) instead of running the decoder. When `input_ids` is provided, - the decoder runs and the cache is updated in-place. - decoder_cache_update_mask (`torch.BoolTensor` of shape `(batch_size,)`, *optional*): - Boolean mask controlling which batch elements have their decoder cache updated. - When provided, only elements where the mask is `True` are written to the cache; - other elements retain their previous cached state. Used during generation to - preserve cache for samples that predicted blank tokens. - use_decoder_cache (`bool`, *optional*): - Whether to use a decoder cache. When `True` and `decoder_cache` is `None`, a new cache - is created automatically during the forward pass. - - Example: - - ```python - >>> from transformers import AutoProcessor, ParakeetForTDT - >>> from datasets import load_dataset, Audio - - >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = ParakeetForTDT.from_pretrained(model_id) - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) - - >>> inputs = processor(ds[0]["audio"]["array"]) - >>> outputs = model(**inputs) - ``` - """ - # 1. 
Encode + project if encoder_outputs is None: - if input_features is None: - raise ValueError("Either `input_features` or `encoder_outputs` must be provided.") - if labels is not None: - kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.get_audio_features( input_features=input_features, attention_mask=attention_mask, @@ -1231,241 +1039,43 @@ def forward( ) elif not isinstance(encoder_outputs, ParakeetEncoderModelOutput): encoder_outputs = ParakeetEncoderModelOutput( - last_hidden_state=encoder_outputs[0], - pooler_output=encoder_outputs[1], + last_hidden_state=encoder_outputs[0] if len(encoder_outputs) > 0 else None, + pooler_output=encoder_outputs[1] if len(encoder_outputs) > 1 else None, hidden_states=encoder_outputs[2] if len(encoder_outputs) > 2 else None, attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None, attention_mask=encoder_outputs[4] if len(encoder_outputs) > 4 else None, ) - projected_encoder_output = encoder_outputs.pooler_output - - if labels is not None: - # for training: [blank, labels...] for training - blank_tokens = torch.full( - (labels.shape[0], 1), self.config.blank_token_id, dtype=labels.dtype, device=labels.device - ) - decoder_input_ids = torch.cat([blank_tokens, labels], dim=1) - elif decoder_input_ids is None and decoder_cache is None: - # for inference: start with blank token if not provided - decoder_input_ids = torch.full( - (projected_encoder_output.shape[0], 1), - self.config.blank_token_id, - dtype=torch.long, - device=projected_encoder_output.device, - ) if use_decoder_cache and decoder_cache is None: decoder_cache = ParakeetTDTDecoderCache() - # Run decoder if we have decoder_input_ids (initial step or after emitting a token) - if decoder_input_ids is not None: - decoder_output = self.decoder(decoder_input_ids, decoder_cache, decoder_cache_update_mask) - else: - # Reuse cached decoder_output (blank-skipping path) - decoder_output = decoder_cache.cache - - if encoder_frame_ids is not None: - batch_indices = torch.arange(projected_encoder_output.shape[0], device=projected_encoder_output.device) - safe_frame_ids = torch.clamp(encoder_frame_ids, max=projected_encoder_output.shape[1] - 1) - encoder_for_joint = projected_encoder_output[batch_indices, safe_frame_ids].unsqueeze(1) - decoder_for_joint = decoder_output - else: - encoder_for_joint = projected_encoder_output.unsqueeze(2) - decoder_for_joint = decoder_output.unsqueeze(1) - - token_logits, duration_logits = self.joint( - decoder_output=decoder_for_joint, - encoder_output=encoder_for_joint, + decoder_hidden_states = self.decoder(decoder_input_ids, cache=decoder_cache) + logits = self.joint( + encoder_hidden_states=encoder_outputs.pooler_output, + decoder_hidden_states=decoder_hidden_states, ) - logits = torch.cat([token_logits, duration_logits], dim=-1) loss = None if labels is not None: - encoder_lengths = encoder_outputs.attention_mask.sum(-1) - target_lengths = (labels != self.config.pad_token_id).sum(-1) loss = self.loss_function( - token_logits=token_logits.float(), - duration_logits=duration_logits.float(), - targets=labels.to(token_logits.device).int(), - logit_lengths=encoder_lengths.to(token_logits.device).int(), - target_lengths=target_lengths.to(token_logits.device).int(), + token_logits=logits[..., : self.config.vocab_size], + duration_logits=logits[..., self.config.vocab_size :], + labels=labels, + logit_lengths=encoder_outputs.attention_mask.sum(-1), + label_lengths=(labels != self.config.pad_token_id).sum(-1), 
blank_token_id=self.config.blank_token_id, durations=self.config.durations, - reduction="mean", ) return ParakeetTDTOutput( loss=loss, logits=logits, last_hidden_state=encoder_outputs.last_hidden_state, + pooler_output=encoder_outputs.pooler_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - pooler_output=encoder_outputs.pooler_output, - attention_mask=encoder_outputs.attention_mask, decoder_cache=decoder_cache, ) - @torch.no_grad() - def generate( - self, - input_features: torch.Tensor, - attention_mask: torch.Tensor | None = None, - return_timestamps: bool = False, - return_dict_in_generate: bool = False, - compile_config: CompileConfig | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> ParakeetTDTGenerateOutput | torch.LongTensor: - r""" - return_timestamps (`bool`, *optional*, defaults to `False`): - Whether to return per-token timestamps and durations. When `True`, forces - `return_dict_in_generate=True` and includes `token_timestamps` and `token_durations` in the output. - compile_config ([`~generation.CompileConfig`], *optional*): - If provided, `torch.compile` will be applied to the forward calls in the decoding loop. - - Example: - - ```python - >>> from transformers import AutoProcessor, ParakeetForTDT - >>> from datasets import load_dataset, Audio - - >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = ParakeetForTDT.from_pretrained(model_id) - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) - - >>> inputs = processor(ds[0]["audio"]["array"], sampling_rate=processor.feature_extractor.sampling_rate) - >>> inputs = inputs.to(model.device, dtype=model.dtype) - >>> output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) - - >>> decoded_output, decoded_timestamps = processor.decode( - ... output.sequences, - ... token_timestamps=output.token_timestamps, - ... token_durations=output.token_durations, - ... skip_special_tokens=True - ... 
) - >>> print("Transcription:", decoded_output) - >>> print("Timestamped tokens:", decoded_timestamps) - ``` - """ - if return_timestamps: - return_dict_in_generate = True - - model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__ - - # Initial forward: encode + decoder initialization - kwargs.setdefault("output_attention_mask", True) - outputs = model_forward( - input_features=input_features, - attention_mask=attention_mask, - use_decoder_cache=True, - return_dict=True, - **kwargs, - ) - - # Reconstruct encoder_outputs for subsequent forward calls - encoder_outputs = ParakeetEncoderModelOutput( - last_hidden_state=outputs.last_hidden_state, - pooler_output=outputs.pooler_output, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - attention_mask=outputs.attention_mask, - ) - decoder_cache = outputs.decoder_cache - batch_size, sequence_length = outputs.pooler_output.shape[:2] - device = outputs.pooler_output.device - - if outputs.attention_mask is not None: - valid_lengths = outputs.attention_mask.sum(dim=1).int() - else: - valid_lengths = torch.full((batch_size,), sequence_length, dtype=torch.int, device=device) - - time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) - time_indices_current_labels = torch.zeros(batch_size, dtype=torch.long, device=device) - active_mask = time_indices < valid_lengths - symbols_per_step = torch.zeros(batch_size, dtype=torch.long, device=device) - last_label_time = torch.full((batch_size,), -1, dtype=torch.long, device=device) - max_output_len = sequence_length * self.config.max_symbols_per_step - all_tokens_tensor = torch.full( - (batch_size, max_output_len), self.config.pad_token_id, dtype=torch.long, device=device - ) - tokens = torch.zeros(batch_size, dtype=torch.long, device=device) - durations = torch.zeros(batch_size, dtype=torch.long, device=device) - token_counts = torch.zeros(batch_size, dtype=torch.long, device=device) - if return_timestamps: - all_frame_indices = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) - all_durations_tensor = torch.zeros((batch_size, max_output_len), dtype=torch.long, device=device) - - while active_mask.any(): - active_at_start = active_mask.clone() - - time_indices_current_labels = torch.where(active_at_start, time_indices, time_indices_current_labels) - outputs = model_forward( - encoder_outputs=encoder_outputs, - encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_cache=decoder_cache, - return_dict=True, - ) - logits = outputs.logits.squeeze(1) - tokens = torch.where(active_at_start, logits[..., : self.config.vocab_size].argmax(dim=-1), tokens) - durations = torch.where(active_at_start, logits[..., self.config.vocab_size :].argmax(dim=-1), durations) - - blank_mask = active_at_start & (tokens == self.config.blank_token_id) - durations = durations.masked_fill(blank_mask & (durations == 0), 1) # ensure forward progress - - # Advance time for all active samples - time_indices = time_indices + durations.masked_fill(~active_at_start, 0) - active_mask = time_indices < valid_lengths - - # If all remaining active samples predicted blank, skip emit + decoder update - emit_mask = active_at_start & ~blank_mask - if not emit_mask.any(): - continue - - # Emit non-blank tokens - emit_indices = token_counts[emit_mask] - all_tokens_tensor[emit_mask, emit_indices] = tokens[emit_mask] - if return_timestamps: - all_frame_indices[emit_mask, emit_indices] = time_indices_current_labels[emit_mask] 
- all_durations_tensor[emit_mask, emit_indices] = durations[emit_mask] - token_counts += emit_mask.long() - - # Update decoder cache for emitted tokens (using potentially compiled forward) - model_forward( - decoder_input_ids=tokens.unsqueeze(1), - encoder_outputs=encoder_outputs, - encoder_frame_ids=torch.clamp(time_indices, max=sequence_length - 1), - decoder_cache=decoder_cache, - decoder_cache_update_mask=emit_mask, - return_dict=True, - ) - - time_changed = time_indices_current_labels != last_label_time - symbols_per_step = torch.where(time_changed, 0, symbols_per_step) - symbols_per_step = torch.where(emit_mask, symbols_per_step + 1, symbols_per_step) - last_label_time = torch.where(emit_mask, time_indices_current_labels, last_label_time) - force_advance = active_mask & (symbols_per_step >= self.config.max_symbols_per_step) - time_indices = time_indices + force_advance.long() - symbols_per_step = symbols_per_step.masked_fill(force_advance, 0) - active_mask = time_indices < valid_lengths - - max_len = max(token_counts.max().item(), 1) - sequences = all_tokens_tensor[:, :max_len] - token_timestamps, token_durations = None, None - if return_timestamps: - token_timestamps = all_frame_indices[:, :max_len] - token_durations = all_durations_tensor[:, :max_len] - - if return_dict_in_generate: - return ParakeetTDTGenerateOutput( - sequences=sequences, - token_timestamps=token_timestamps, - token_durations=token_durations, - attentions=outputs.attentions, - hidden_states=outputs.hidden_states, - ) - return sequences - __all__ = ["ParakeetForCTC", "ParakeetForTDT", "ParakeetEncoder", "ParakeetPreTrainedModel"] From 43ee7cd7f5fb352cc875e007777378f79116e0a8 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 17:33:30 +0200 Subject: [PATCH 51/67] test update --- .../models/parakeet/test_modeling_parakeet.py | 58 ++++++++++++++----- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 76f1aaaa4ac9..d1407c3633f2 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -506,12 +506,12 @@ def get_config(self): blank_token_id=self.blank_token_id, ) - def create_and_check_model(self, config, input_features, attention_mask): + def create_and_check_model(self, config, inputs_dict): model = ParakeetForTDT(config=config) model.to(torch_device) model.eval() with torch.no_grad(): - result = model(input_features, attention_mask=attention_mask) + result = model(**inputs_dict) # Check encoder last hidden state self.parent.assertEqual( @@ -521,9 +521,11 @@ def create_and_check_model(self, config, input_features, attention_mask): def prepare_config_and_inputs_for_common(self): config, input_features, attention_mask = self.prepare_config_and_inputs() + decoder_input_ids = ids_tensor([self.batch_size, 1], self.vocab_size) inputs_dict = { "input_features": input_features, "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, } return config, inputs_dict @@ -564,6 +566,44 @@ def test_model(self): def test_model_get_set_embeddings(self): pass + @unittest.skip( + reason="ParakeetForTDT is a transducer, not a standard encoder-decoder: no separate text config to set" + ) + def test_attn_implementation_composite_models(self): + pass + + @unittest.skip( + reason="ParakeetForTDT is a transducer with an LSTM prediction network; " + "it does not expose encoder_hidden_states in the 
standard encoder-decoder sense" + ) + def test_hidden_states_output(self): + pass + + @unittest.skip( + reason="ParakeetForTDT is a transducer with an LSTM prediction network; " + "it does not expose encoder_hidden_states in the standard encoder-decoder sense" + ) + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip( + reason="ParakeetForTDT has a custom generate() that is not fully compatible with GenerationTesterMixin" + ) + def test_generation_tester_mixin_inheritance(self): + pass + + @unittest.skip( + reason="ParakeetForTDT is a flat composite model without a separate base_model sub-module" + ) + def test_model_base_model_prefix(self): + pass + + @unittest.skip( + reason="ParakeetForTDT decoder is an LSTM prediction network without attention" + ) + def test_flex_attention_with_grads(self): + pass + # Original function assumes vision+text model, so overwrite since Parakeet is audio+text def test_sdpa_can_dispatch_composite_models(self): if not self.has_attentions: @@ -590,20 +630,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - def test_generate(self): - """Test that generate() produces valid output.""" - config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() - model = ParakeetForTDT(config=config) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - sequences = model.generate(input_features, attention_mask=attention_mask) - - self.assertIsInstance(sequences, torch.Tensor) - self.assertEqual(sequences.dim(), 2) - self.assertEqual(sequences.shape[0], self.model_tester.batch_size) - @require_torch class ParakeetForTDTIntegrationTest(unittest.TestCase): From c2a0f781ca7bc23f1158ea53fd729b6b1356a1a5 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 15 Apr 2026 17:34:00 +0200 Subject: [PATCH 52/67] test update --- tests/models/parakeet/test_modeling_parakeet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index d1407c3633f2..41eac202e014 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -559,7 +559,7 @@ def test_config(self): self.config_tester.run_common_tests() def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.create_and_check_model(*config_and_inputs) @unittest.skip(reason="ParakeetForTDT does not use inputs_embeds") From 1fd7ed78151be80426c893e275e312ceedcac753 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:00:37 +0200 Subject: [PATCH 53/67] ensure correct loss computation --- src/transformers/loss/loss_tdt.py | 2 +- .../models/parakeet/modeling_parakeet.py | 6 +++--- .../models/parakeet/modular_parakeet.py | 7 +++---- .../models/parakeet/processing_parakeet.py | 14 +++++++++++--- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/transformers/loss/loss_tdt.py b/src/transformers/loss/loss_tdt.py index 27389e10b725..ae7afa1e7edc 100644 --- a/src/transformers/loss/loss_tdt.py +++ b/src/transformers/loss/loss_tdt.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The HuggingFace Team. 
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 66fd971aec39..367f75984303 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -1051,9 +1051,9 @@ def forward( decoder_hidden_states = self.decoder(decoder_input_ids, cache=decoder_cache) logits = self.joint( - encoder_hidden_states=encoder_outputs.pooler_output, - decoder_hidden_states=decoder_hidden_states, - ) + encoder_hidden_states=encoder_outputs.pooler_output[:, :, None, :], + decoder_hidden_states=decoder_hidden_states[:, None, :, :], + ).squeeze(2) loss = None if labels is not None: diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 37a6065fe49d..0f413aa088ff 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -778,7 +778,6 @@ def forward( return decoder_output - class ParakeetTDTJointNetwork(nn.Module): """Joint network that combines encoder and decoder outputs to predict tokens and durations.""" @@ -915,9 +914,9 @@ def forward( decoder_hidden_states = self.decoder(decoder_input_ids, cache=decoder_cache) logits = self.joint( - encoder_hidden_states=encoder_outputs.pooler_output, - decoder_hidden_states=decoder_hidden_states, - ) + encoder_hidden_states=encoder_outputs.pooler_output[:, :, None, :], + decoder_hidden_states=decoder_hidden_states[:, None, :, :], + ).squeeze(2) loss = None if labels is not None: diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py index 91d502784828..2a691deaea76 100644 --- a/src/transformers/models/parakeet/processing_parakeet.py +++ b/src/transformers/models/parakeet/processing_parakeet.py @@ -40,7 +40,9 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False): @auto_docstring class ParakeetProcessor(ProcessorMixin): - def __init__(self, feature_extractor, tokenizer): + def __init__(self, feature_extractor, tokenizer, blank_token=""): + self.blank_token = blank_token + self.blank_token_id = tokenizer.convert_tokens_to_ids(blank_token) super().__init__(feature_extractor, tokenizer) @auto_docstring @@ -84,6 +86,13 @@ def __call__( return inputs else: inputs["labels"] = encodings["input_ids"] + # Prepend blank token to labels to form decoder_input_ids. 
+ # The TDT decoder expects [blank, label_0, ..., label_{U-1}] as input, + if isinstance(text, str): + text = [text] + decoder_text = [self.blank_token + t for t in text] + decoder_encodings = self.tokenizer(decoder_text, **output_kwargs["text_kwargs"]) + inputs["decoder_input_ids"] = decoder_encodings["input_ids"] return inputs @property @@ -106,7 +115,6 @@ def decode(self, *args, durations=None, **kwargs): output_kwargs = self._merge_kwargs( ParakeetProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, - **kwargs, ) frame_rate = ( self.feature_extractor.hop_length @@ -117,7 +125,7 @@ def decode(self, *args, durations=None, **kwargs): for batch_ids, batch_timestamps, batch_durations in zip(token_ids, timestamps, durations): # See `compute_rnnt_timestamps` in NeMo: https://github.com/NVIDIA-NeMo/NeMo/blob/1692a8fb97e1aadc883cfadd2a57c4e8a1b793aa/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L993 # Filter padding and blank tokens - blank_token_id = self.tokenizer.convert_tokens_to_ids("") + blank_token_id = self.blank_token_id skip_ids = {self.tokenizer.pad_token_id, blank_token_id} non_blank_indices = [ i for i, token_id in enumerate(batch_ids) if int(token_id) not in skip_ids From 7cc9d2e7dc3c4866c519854034c8a27ef9d0f747 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 16:23:45 +0200 Subject: [PATCH 54/67] kernel loss --- src/transformers/integrations/hub_kernels.py | 1 + src/transformers/loss/loss_tdt.py | 50 ++++++++++++++++++-- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 88aff578fdc6..2894209173d3 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -286,6 +286,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, + "tdt-loss": {"repo_id": "eustlb/tdt-loss", "version": 1}, } _KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {} diff --git a/src/transformers/loss/loss_tdt.py b/src/transformers/loss/loss_tdt.py index ae7afa1e7edc..3172c0175291 100644 --- a/src/transformers/loss/loss_tdt.py +++ b/src/transformers/loss/loss_tdt.py @@ -1,4 +1,4 @@ -# Copyright 2026 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,24 @@ import torch +from ..utils import logging + + +logger = logging.get_logger(__name__) + + +def _load_tdt_kernel(): + """Try to load the TDT loss CUDA kernel from the Hub. Returns None on failure.""" + try: + from ..integrations.hub_kernels import lazy_load_kernel + + return lazy_load_kernel("tdt-loss") + except (ImportError, ModuleNotFoundError): + return None + except Exception as e: + logger.warning_once(f"Failed to load TDT CUDA kernel: {e}. Falling back to pure PyTorch implementation.") + return None + def tdt_loss( token_logits: torch.Tensor, @@ -33,6 +51,9 @@ def tdt_loss( the token prediction head and the duration prediction head. 
It uses vectorized anti-diagonal processing for efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in parallel as batched tensor operations. + When the ``kernels-community/tdt-loss`` CUDA kernel is installed, it is used automatically for GPU tensors, + and falls back to the pure PyTorch implementation otherwise. + Args: token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`. duration_logits: Duration logits of shape `(batch, T, U+1, num_durations)`. @@ -48,6 +69,18 @@ def tdt_loss( Scalar loss tensor (or per-example losses if `reduction="none"`). """ + kernel = _load_tdt_kernel() if token_logits.is_cuda else None + if kernel is not None and hasattr(kernel, "tdt_loss"): + durations_t = torch.tensor(durations, dtype=torch.int32, device=token_logits.device) + return kernel.tdt_loss( + token_logits, duration_logits, targets, + logit_lengths, target_lengths, durations_t, + blank_token_id, sigma, reduction, + ) + + if reduction not in ("mean", "sum", "none"): + raise ValueError(f'Invalid reduction mode "{reduction}". Expected one of "mean", "sum", or "none".') + device = token_logits.device batch_size, max_t, max_u, _ = token_logits.shape @@ -55,6 +88,7 @@ def tdt_loss( duration_logits = duration_logits.float() # Apply log-softmax to get log probabilities + # sigma only applies to token logits (undernormalization constant from the TDT paper) token_log_probs = torch.log_softmax(token_logits, dim=-1) - sigma duration_log_probs = torch.log_softmax(duration_logits, dim=-1) @@ -72,6 +106,8 @@ def tdt_loss( index=targets_expanded.unsqueeze(-1), ).squeeze(-1) # (batch, T, U-1) + neg_inf = torch.tensor(float("-inf"), device=device) + # Process anti-diagonals: all (t, u) with t + u = n have no mutual dependencies for n in range(1, max_t + max_u - 1): u_start = max(0, n - max_t + 1) @@ -94,7 +130,7 @@ def tdt_loss( + blank_log_probs[:, t_src, u_indices] + duration_log_probs[:, t_src, u_indices, i] ) - contrib = torch.where(valid_t.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) + contrib = torch.where(valid_t.unsqueeze(0), contrib, neg_inf) all_candidates.append(contrib) # Label arcs: from (t-dur, u-1) to (t, u), only if u > 0 @@ -109,7 +145,7 @@ def tdt_loss( + label_log_probs[:, t_src, u_src_label] + duration_log_probs[:, t_src, u_src, i] ) - contrib = torch.where(valid_both.unsqueeze(0), contrib, torch.tensor(float("-inf"), device=device)) + contrib = torch.where(valid_both.unsqueeze(0), contrib, neg_inf) all_candidates.append(contrib) if all_candidates: @@ -153,15 +189,19 @@ def ParakeetForTDTLoss( label_lengths, blank_token_id, durations, + sigma=0.0, + reduction="mean", **kwargs, ): device = token_logits.device return tdt_loss( - token_logits=token_logits.float(), - duration_logits=duration_logits.float(), + token_logits=token_logits, + duration_logits=duration_logits, targets=labels.to(device).int(), logit_lengths=logit_lengths.to(device).int(), target_lengths=label_lengths.to(device).int(), blank_token_id=blank_token_id, durations=durations, + sigma=sigma, + reduction=reduction, ) From e753eab145452f055f2f00958f0f98d46afc2211 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 16:24:20 +0200 Subject: [PATCH 55/67] test loss integration --- .../fixtures/parakeet/expected_loss_tdt.json | 5 ++ .../models/parakeet/test_modeling_parakeet.py | 57 +++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 tests/fixtures/parakeet/expected_loss_tdt.json diff --git
a/tests/fixtures/parakeet/expected_loss_tdt.json b/tests/fixtures/parakeet/expected_loss_tdt.json new file mode 100644 index 000000000000..aee3c3f16c2b --- /dev/null +++ b/tests/fixtures/parakeet/expected_loss_tdt.json @@ -0,0 +1,5 @@ +{ + "num_samples": 2, + "expected_mean_loss": 0.528089, + "comment": "NeMo reference with sigma=0, HF-style mean reduction (per-sample / target_length, then average). Generated with https://gist.github.com/883ea42bf7d8ce2af42f3055627476a7" +} diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 41eac202e014..9f1882bb6719 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -741,3 +741,60 @@ def test_tdt_model_integration_timestamps(self): torch.testing.assert_close(predicted_start_times, EXPECTED_START_TIMESTAMPS) torch.testing.assert_close(predicted_end_times, EXPECTED_END_TIMESTAMPS) self.assertListEqual(output.token_durations.cpu().tolist(), EXPECTED_DURATIONS) + + @slow + def test_tdt_model_integration_loss(self): + """ + Verify that ParakeetForTDT loss matches NeMo's TDT loss (sigma=0) for both + the CUDA kernel and the pure PyTorch implementation. + reproducer: https://gist.github.com/883ea42bf7d8ce2af42f3055627476a7 + """ + from transformers.loss.loss_tdt import _load_tdt_kernel + + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_loss_tdt.json" + with open(RESULTS_PATH, "r") as f: + raw_data = json.load(f) + EXPECTED_MEAN_LOSS = torch.tensor(raw_data["expected_mean_loss"]) + num_samples = raw_data["num_samples"] + + samples = self._load_datasamples(num_samples) + transcripts = self._dataset.sort("id")[:num_samples]["text"] + transcripts = [t.lower() for t in transcripts] + + # Use float32 for loss precision + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto") + + inputs = self.processor( + audio=samples, + text=transcripts, + sampling_rate=self.processor.feature_extractor.sampling_rate, + ) + inputs.to(model.device) + + # Test both backends: kernel (if available) and pure PyTorch + has_kernel = _load_tdt_kernel() is not None + backends = [("kernel", None), ("torch", patch("transformers.loss.loss_tdt._load_tdt_kernel", return_value=None))] + if not has_kernel: + backends = backends[1:] # skip kernel test when not installed + + for backend_name, ctx in backends: + with self.subTest(backend=backend_name): + ctx_manager = ctx if ctx is not None else nullcontext() + with ctx_manager: + # Forward in eval mode — check loss matches NeMo + model.eval() + with torch.no_grad(): + outputs = model(**inputs) + self.assertIsNotNone(outputs.loss, "Loss must be computed when labels are provided") + self.assertEqual(outputs.logits.dim(), 4, "Training logits must be 4D (B, T, U+1, V+D)") + torch.testing.assert_close(outputs.loss.cpu(), EXPECTED_MEAN_LOSS, rtol=1e-3, atol=1e-3) + + # Backward — verify gradients flow + del outputs + torch.cuda.empty_cache() + model.train() + model.zero_grad() + outputs = model(**inputs) + outputs.loss.backward() + n_with_grad = sum(1 for p in model.parameters() if p.grad is not None) + self.assertGreater(n_with_grad, 0, "No gradients after backward") From ed3fa4dca3128788f9eb8c688c32a9cd94f19d52 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 16:44:05 +0200 Subject: [PATCH 56/67] push to hub pr --- .../models/parakeet/convert_nemo_to_hf.py | 48 ++++++++++++++----- 1 file 
changed, 37 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index 8cea24f4a0cc..a7874e4996a0 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -142,7 +142,7 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str return model_files -def write_processor(nemo_config: dict, model_files, output_dir, model_type, push_to_repo_id=None): +def write_processor(nemo_config: dict, model_files, output_dir, model_type, push_to_repo_id=None, create_pr=True, revision=None): tokenizer_converted = ParakeetConverter(model_files["tokenizer_model_file"]).converted() tokenizer_converted_fast = ParakeetTokenizer( tokenizer_object=tokenizer_converted, @@ -204,7 +204,12 @@ def write_processor(nemo_config: dict, model_files, output_dir, model_type, push processor.save_pretrained(output_dir) if push_to_repo_id: - processor.push_to_hub(push_to_repo_id) + commit_info = processor.push_to_hub(push_to_repo_id, create_pr=create_pr, revision=revision) + if create_pr and hasattr(commit_info, "pr_url") and commit_info.pr_url: + pr_num = commit_info.pr_url.rstrip("/").split("/")[-1] + return f"refs/pr/{pr_num}" + + return revision def convert_encoder_config(nemo_config): @@ -273,7 +278,7 @@ def load_and_convert_state_dict(model_files): return converted_state_dict -def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None): +def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None, revision=None): """Write CTC model using encoder config and converted state dict.""" model_config = ParakeetCTCConfig.from_encoder_config(encoder_config) @@ -288,7 +293,7 @@ def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_re model.save_pretrained(output_dir) if push_to_repo_id: - model.push_to_hub(push_to_repo_id) + model.push_to_hub(push_to_repo_id, revision=revision) del model @@ -347,7 +352,7 @@ def load_and_convert_tdt_state_dict(model_files, vocab_size): return converted_state_dict -def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id=None): +def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id=None, revision=None): """Write TDT model using encoder config, TDT config, and converted state dict.""" model_config = convert_tdt_config(nemo_config, encoder_config) print(f"Converted TDT config: {model_config}") @@ -379,7 +384,7 @@ def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_t model.save_pretrained(output_dir) if push_to_repo_id: - model.push_to_hub(push_to_repo_id) + model.push_to_hub(push_to_repo_id, revision=revision) del model @@ -389,16 +394,16 @@ def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_t print("Model reloaded successfully.") -def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None): +def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None, revision=None): """Main model conversion function.""" encoder_config = convert_encoder_config(nemo_config) print(f"Converted encoder config: {encoder_config}") if model_type == "ctc": converted_state_dict = load_and_convert_state_dict(model_files) - write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id) + write_ctc_model(encoder_config, converted_state_dict, 
output_dir, push_to_repo_id, revision) elif model_type == "tdt": - write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id) + write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id, revision) else: raise ValueError(f"Model type {model_type} not supported.") @@ -408,6 +413,8 @@ def main( output_dir, model_type, push_to_repo_id=None, + create_pr=True, + revision=None, ): nemo_filename = f"{hf_repo_id.split('/')[-1]}.nemo" filepath = cached_file(hf_repo_id, nemo_filename) @@ -415,8 +422,14 @@ def main( model_files = extract_nemo_archive(filepath, os.path.dirname(filepath)) nemo_config = yaml.load(open(model_files["model_config"], "r"), Loader=yaml.FullLoader) - write_processor(nemo_config, model_files, output_dir, model_type, push_to_repo_id) - write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id) + # When revision is given (e.g. "refs/pr/3"), both pushes target that existing PR branch. + # Otherwise, write_processor creates a new PR and returns its revision for write_model. + pr_revision = write_processor( + nemo_config, model_files, output_dir, model_type, push_to_repo_id, + create_pr=create_pr if revision is None else False, + revision=revision, + ) + write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id, pr_revision) """ @@ -444,10 +457,23 @@ def main( parser.add_argument("--model_type", required=True, choices=["ctc", "tdt"], help="Model type (`ctc`, `tdt`)") parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model") parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub") + parser.add_argument( + "--create_pr", + default=True, + action=argparse.BooleanOptionalAction, + help="Create a PR when pushing to the Hub (default: True). Use --no-create_pr to push directly.", + ) + parser.add_argument( + "--revision", + default=None, + help='Push to an existing Hub PR branch (e.g. "refs/pr/3"). 
Overrides --create_pr.', + ) args = parser.parse_args() main( args.hf_repo_id, args.output_dir, args.model_type, args.push_to_repo_id, + args.create_pr, + args.revision, ) From ab66b23978ec291eca45a5d5cc1e616c5e032cbc Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 16:50:49 +0200 Subject: [PATCH 57/67] integration tests to rely fully on transcripts --- .../models/parakeet/test_modeling_parakeet.py | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 9f1882bb6719..de1bff8ff222 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -16,7 +16,9 @@ import json import tempfile import unittest +from contextlib import nullcontext from pathlib import Path +from unittest.mock import patch from transformers import is_datasets_available, is_torch_available from transformers.testing_utils import cleanup, require_torch, slow, torch_device @@ -94,7 +96,7 @@ def test_tdt_loss_mean(self): def test_tdt_loss_none(self): inputs = self._make_inputs() - losses = tdt_loss(**inputs, reduction=None) + losses = tdt_loss(**inputs, reduction="none") expected = torch.tensor(self.fixture["expected_loss_none"]) torch.testing.assert_close(losses, expected) @@ -637,9 +639,10 @@ class ParakeetForTDTIntegrationTest(unittest.TestCase): @classmethod def setUp(cls): - cls.checkpoint_name = "bezzam/parakeet-tdt-0.6b-v3-hf" + cls.checkpoint_name = "nvidia/parakeet-tdt-0.6b-v3" + cls.revision = "refs/pr/39" cls.dtype = torch.bfloat16 - cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name) + cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name, revision=cls.revision) def tearDown(self): cleanup(torch_device, gc_collect=True) @@ -666,16 +669,14 @@ def test_tdt_model_integration(self): RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single_tdt.json" with open(RESULTS_PATH, "r") as f: raw_data = json.load(f) - EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto") + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, revision=self.revision, dtype=self.dtype, device_map="auto") inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) inputs.to(model.device, dtype=self.dtype) output = model.generate(**inputs, return_dict_in_generate=True) - torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) @@ -687,16 +688,14 @@ def test_tdt_model_integration_batched(self): RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch_tdt.json" with open(RESULTS_PATH, "r") as f: raw_data = json.load(f) - EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto") + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, revision=self.revision, 
dtype=self.dtype, device_map="auto") inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) inputs.to(model.device, dtype=self.dtype) output = model.generate(**inputs, return_dict_in_generate=True) - torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) @@ -710,37 +709,30 @@ def test_tdt_model_integration_timestamps(self): ) with open(RESULTS_PATH, "r") as f: raw_data = json.load(f) - EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] EXPECTED_START_TIMESTAMPS = raw_data["start_timestamps"] EXPECTED_END_TIMESTAMPS = raw_data["end_timestamps"] - EXPECTED_DURATIONS = raw_data["token_durations"] # Use larger precision for testing token durations and timestamps samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto") + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, revision=self.revision, dtype=torch.float32, device_map="auto") inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) inputs.to(model.device, dtype=model.dtype) - output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) - torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) + output = model.generate(**inputs, return_dict_in_generate=True) predicted_transcripts, predicted_timestamps = self.processor.decode( output.sequences, - token_timestamps=output.token_timestamps, - token_durations=output.token_durations, + durations=output.durations, skip_special_tokens=True, ) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) # Check timestamps and durations - self.assertIsNotNone( - output.token_timestamps, "token_timestamps should be returned when return_timestamps=True" - ) + self.assertIsNotNone(output.durations, "durations should be returned") predicted_start_times = [[entry["start"] for entry in el] for el in predicted_timestamps] predicted_end_times = [[entry["end"] for entry in el] for el in predicted_timestamps] torch.testing.assert_close(predicted_start_times, EXPECTED_START_TIMESTAMPS) torch.testing.assert_close(predicted_end_times, EXPECTED_END_TIMESTAMPS) - self.assertListEqual(output.token_durations.cpu().tolist(), EXPECTED_DURATIONS) @slow def test_tdt_model_integration_loss(self): @@ -762,7 +754,7 @@ def test_tdt_model_integration_loss(self): transcripts = [t.lower() for t in transcripts] # Use float32 for loss precision - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto") + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, revision=self.revision, dtype=torch.float32, device_map="auto") inputs = self.processor( audio=samples, From a5ba0c618bcd9a96c415e6c60f1b9acf3026603d Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:13:17 +0200 Subject: [PATCH 58/67] udpate fixtures --- .../parakeet/expected_results_batch_tdt.json | 10 +- .../expected_results_batch_tdt_timestamp.json | 252 +++++++++++++++++- .../parakeet/expected_results_single_tdt.json | 6 +- 3 files changed, 265 insertions(+), 3 deletions(-) diff --git a/tests/fixtures/parakeet/expected_results_batch_tdt.json 
b/tests/fixtures/parakeet/expected_results_batch_tdt.json index 54f5198fd834..c6a37bad56e8 100644 --- a/tests/fixtures/parakeet/expected_results_batch_tdt.json +++ b/tests/fixtures/parakeet/expected_results_batch_tdt.json @@ -1 +1,9 @@ -{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.", "He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and can discover in it but little of Rocky Ithaca.", "Linnell's pictures are a sort of up guards an atom paintings, and Mason's exquisite idols are as national as a jingo poem. mister Burkett Foster's landscapes smile at one much in the same way that mister Carker used to flash his teeth. And mister John Collier gives his sitter a cheerful slap on the back, before he says, like a shampooer in a Turkish bath Next man"], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1876, 2281, 1969, 507, 3362, 7886, 769, 328, 1299, 1239, 7319, 6447, 901, 1413, 1333, 3720, 289, 7931, 7870, 6182, 508, 5600, 4190, 377, 799, 441, 1111, 7877, 575, 2059, 5371, 3230, 334, 869, 2681, 7052, 592, 3341, 725, 7893, 2336, 7882, 566, 7865, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [439, 1538, 530, 7931, 7870, 5970, 7868, 4147, 1714, 279, 275, 621, 592, 1840, 1980, 961, 7870, 411, 407, 313, 849, 942, 2399, 7877, 575, 2945, 289, 7931, 7870, 743, 341, 290, 582, 312, 7874, 324, 7870, 1714, 618, 285, 5858, 618, 279, 300, 381, 7869, 408, 311, 7883, 282, 3459, 426, 344, 7876, 861, 515, 308, 441, 7931, 7870, 3650, 7870, 7880, 474, 283, 1530, 787, 
407, 2678, 4457, 334, 506, 766, 7864, 7195, 1050, 282, 3459, 3551, 1684, 1441, 326, 366, 309, 1028, 7882, 2745, 478, 291, 7882, 7883, 1976, 282, 3459, 3483, 4003, 332, 277, 317, 416, 283, 2745, 3488, 441, 279, 774, 277, 5346, 275, 4226, 431, 506, 6507, 7877, 555, 786, 7864, 813, 498, 676, 7877, 2656, 279, 275, 3930, 726, 7869, 277, 334, 279, 5183, 7876, 2739, 302, 7152, 1030, 3127, 698]]} \ No newline at end of file +{ + "transcriptions": [ + "mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", + "Nor is mister Quilter's manner less interesting than his matter.", + "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.", + "He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and can discover in it but little of Rocky Ithaca.", + "Linnell's pictures are a sort of up guards an atom paintings, and Mason's exquisite idols are as national as a jingo poem. mister Burkett Foster's landscapes smile at one much in the same way that mister Carker used to flash his teeth. And mister John Collier gives his sitter a cheerful slap on the back, before he says, like a shampooer in a Turkish bath Next man" + ] +} diff --git a/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json index 0a9b2180b4cb..f13d5aee8b5f 100644 --- a/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json +++ b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json @@ -1 +1,251 @@ -{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind."], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883]], "start_timestamps": [[0.24, 0.48, 0.64, 0.88, 1.12, 1.36, 1.44, 1.6, 1.76, 2.0, 2.16, 2.24, 2.4, 2.48, 2.56, 2.72, 2.88, 3.04, 3.12, 3.2800000000000002, 3.44, 3.6, 3.7600000000000002, 3.92, 4.08, 4.24, 4.4, 4.48, 4.72, 4.96, 5.36, 5.6000000000000005], [0.32, 0.64, 0.88, 1.04, 1.2, 1.44, 1.68, 1.84, 1.92, 2.0, 2.16, 2.4, 2.56, 2.72, 2.96, 3.12, 3.36, 3.6, 3.92, 4.16, 4.32], [0.32, 0.64, 0.72, 0.96, 1.12, 1.36, 1.6, 1.84, 2.08, 2.24, 2.48, 2.64, 2.8000000000000003, 2.88, 3.04, 3.2, 3.44, 3.68, 3.84, 4.08, 4.4, 4.5600000000000005, 4.72, 4.96, 5.12, 5.36, 5.5200000000000005, 5.68, 5.92, 6.16, 6.24, 6.4, 6.5600000000000005, 6.72, 6.96, 7.28, 7.6000000000000005, 7.92, 8.16, 
8.32, 8.48, 8.72, 8.88, 8.96, 9.120000000000001, 9.28, 9.44, 9.68, 9.76, 9.92, 10.16, 10.24, 10.4, 10.64, 10.88, 10.96, 11.200000000000001, 11.36, 11.52, 11.84, 12.16]], "end_timestamps": [[0.48, 0.64, 0.88, 1.12, 1.36, 1.44, 1.6, 1.76, 1.92, 2.16, 2.24, 2.4, 2.48, 2.56, 2.64, 2.88, 3.04, 3.12, 3.12, 3.44, 3.6, 3.7600000000000002, 3.92, 4.08, 4.24, 4.4, 4.48, 4.72, 4.96, 5.12, 5.6000000000000005, 5.6000000000000005], [0.64, 0.88, 1.04, 1.2, 1.44, 1.68, 1.84, 1.84, 2.0, 2.16, 2.4, 2.56, 2.72, 2.96, 3.12, 3.36, 3.6, 3.92, 4.16, 4.32, 4.32], [0.64, 0.72, 0.96, 1.12, 1.36, 1.6, 1.84, 2.08, 2.24, 2.48, 2.64, 2.8000000000000003, 2.88, 3.04, 3.2, 3.44, 3.68, 3.84, 3.84, 4.4, 4.5600000000000005, 4.72, 4.96, 5.12, 5.36, 5.5200000000000005, 5.68, 5.92, 6.16, 6.24, 6.4, 6.5600000000000005, 6.72, 6.96, 7.28, 7.28, 7.92, 8.16, 8.24, 8.48, 8.72, 8.88, 8.96, 9.120000000000001, 9.200000000000001, 9.44, 9.68, 9.76, 9.92, 10.16, 10.24, 10.4, 10.64, 10.88, 10.96, 11.200000000000001, 11.36, 11.52, 11.84, 12.16, 12.16]], "token_durations": [[3, 2, 3, 3, 3, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 3, 3, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 3, 2, 2, 3, 3, 2, 1, 1, 2, 3, 2, 2, 3, 2, 3, 3, 4, 3, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [4, 1, 3, 2, 3, 3, 3, 3, 2, 3, 2, 2, 1, 2, 2, 3, 3, 2, 3, 4, 2, 2, 3, 2, 3, 2, 2, 3, 3, 1, 2, 2, 2, 3, 4, 4, 4, 3, 1, 2, 3, 2, 1, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 3, 1, 3, 2, 2, 4, 4, 2]]} \ No newline at end of file +{ + "transcriptions": [ + "mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", + "Nor is mister Quilter's manner less interesting than his matter.", + "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind." 
+ ], + "start_timestamps": [ + [ + 0.24, + 0.48, + 0.64, + 0.88, + 1.12, + 1.36, + 1.44, + 1.6, + 1.76, + 2.0, + 2.16, + 2.24, + 2.4, + 2.48, + 2.56, + 2.72, + 2.88, + 3.04, + 3.12, + 3.2800000000000002, + 3.44, + 3.6, + 3.7600000000000002, + 3.92, + 4.08, + 4.24, + 4.4, + 4.48, + 4.72, + 4.96, + 5.36, + 5.6000000000000005 + ], + [ + 0.32, + 0.64, + 0.88, + 1.04, + 1.2, + 1.44, + 1.68, + 1.84, + 1.92, + 2.0, + 2.16, + 2.4, + 2.56, + 2.72, + 2.96, + 3.12, + 3.36, + 3.6, + 3.92, + 4.16, + 4.32 + ], + [ + 0.32, + 0.64, + 0.72, + 0.96, + 1.12, + 1.36, + 1.6, + 1.84, + 2.08, + 2.24, + 2.48, + 2.64, + 2.8000000000000003, + 2.88, + 3.04, + 3.2, + 3.44, + 3.68, + 3.84, + 4.08, + 4.4, + 4.5600000000000005, + 4.72, + 4.96, + 5.12, + 5.36, + 5.5200000000000005, + 5.68, + 5.92, + 6.16, + 6.24, + 6.4, + 6.5600000000000005, + 6.72, + 6.96, + 7.28, + 7.6000000000000005, + 7.92, + 8.16, + 8.32, + 8.48, + 8.72, + 8.88, + 8.96, + 9.120000000000001, + 9.28, + 9.44, + 9.68, + 9.76, + 9.92, + 10.16, + 10.24, + 10.4, + 10.64, + 10.88, + 10.96, + 11.200000000000001, + 11.36, + 11.52, + 11.84, + 12.16 + ] + ], + "end_timestamps": [ + [ + 0.48, + 0.64, + 0.88, + 1.12, + 1.36, + 1.44, + 1.6, + 1.76, + 1.92, + 2.16, + 2.24, + 2.4, + 2.48, + 2.56, + 2.64, + 2.88, + 3.04, + 3.12, + 3.12, + 3.44, + 3.6, + 3.7600000000000002, + 3.92, + 4.08, + 4.24, + 4.4, + 4.48, + 4.72, + 4.96, + 5.12, + 5.6000000000000005, + 5.6000000000000005 + ], + [ + 0.64, + 0.88, + 1.04, + 1.2, + 1.44, + 1.68, + 1.84, + 1.84, + 2.0, + 2.16, + 2.4, + 2.56, + 2.72, + 2.96, + 3.12, + 3.36, + 3.6, + 3.92, + 4.16, + 4.32, + 4.32 + ], + [ + 0.64, + 0.72, + 0.96, + 1.12, + 1.36, + 1.6, + 1.84, + 2.08, + 2.24, + 2.48, + 2.64, + 2.8000000000000003, + 2.88, + 3.04, + 3.2, + 3.44, + 3.68, + 3.84, + 3.84, + 4.4, + 4.5600000000000005, + 4.72, + 4.96, + 5.12, + 5.36, + 5.5200000000000005, + 5.68, + 5.92, + 6.16, + 6.24, + 6.4, + 6.5600000000000005, + 6.72, + 6.96, + 7.28, + 7.28, + 7.92, + 8.16, + 8.24, + 8.48, + 8.72, + 8.88, + 8.96, + 9.120000000000001, + 9.200000000000001, + 9.44, + 9.68, + 9.76, + 9.92, + 10.16, + 10.24, + 10.4, + 10.64, + 10.88, + 10.96, + 11.200000000000001, + 11.36, + 11.52, + 11.84, + 12.16, + 12.16 + ] + ] +} diff --git a/tests/fixtures/parakeet/expected_results_single_tdt.json b/tests/fixtures/parakeet/expected_results_single_tdt.json index 93a43c9fa9e8..a757d763b6a3 100644 --- a/tests/fixtures/parakeet/expected_results_single_tdt.json +++ b/tests/fixtures/parakeet/expected_results_single_tdt.json @@ -1 +1,5 @@ -{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."], "scores": [-90.4653091430664], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883]]} \ No newline at end of file +{ + "transcriptions": [ + "mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ] +} From 48279a67e25933dedb24de1e2431fddd8331249b Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:19:02 +0200 Subject: [PATCH 59/67] we don't need to monkey patch with numba anymore! 
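The monkey-patching section in the docs is dropped: `ParakeetForTDT` now computes its loss through the built-in `tdt_loss`, which uses the `tdt-loss` hub kernel when available and otherwise falls back to the pure PyTorch implementation. For reference, a rough standalone sketch of calling the loss directly is below — the keyword names follow `src/transformers/loss/loss_tdt.py`, while the tensor sizes, `sigma`, and blank id are illustrative assumptions rather than values taken from this patch.

```py
import torch

from transformers.loss.loss_tdt import tdt_loss

# Toy dimensions; real logits come from the TDT joint network with shape
# (batch, encoder_frames, target_len + 1, vocab) for tokens and
# (batch, encoder_frames, target_len + 1, len(durations)) for durations.
batch, frames, target_len, vocab = 2, 12, 5, 32
durations = [0, 1, 2, 3, 4]

token_logits = torch.randn(batch, frames, target_len + 1, vocab)
duration_logits = torch.randn(batch, frames, target_len + 1, len(durations))
targets = torch.randint(0, vocab - 1, (batch, target_len))
logit_lengths = torch.full((batch,), frames, dtype=torch.long)
target_lengths = torch.full((batch,), target_len, dtype=torch.long)

loss = tdt_loss(
    token_logits=token_logits,
    duration_logits=duration_logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    durations=durations,
    blank_token_id=vocab - 1,  # assumed blank id for this toy vocab
    sigma=0.05,  # illustrative logit under-normalization constant
    reduction="mean",
)
print(loss)
```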
--- docs/source/en/model_doc/parakeet.md | 90 ++++------------------------ 1 file changed, 11 insertions(+), 79 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index f90d476cd3cc..87ea6f0e2b5b 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -94,7 +94,7 @@ Parakeet TDT transcripts include casing, and the model can also performk token t ```py from transformers import pipeline -pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-tdt-0.6b-v3") +pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-tdt-0.6b-v3", revision="refs/pr/39") out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") print(out) ``` @@ -107,8 +107,9 @@ from transformers import AutoModelForTDT, AutoProcessor from datasets import load_dataset, Audio model_id = "nvidia/parakeet-tdt-0.6b-v3" -processor = AutoProcessor.from_pretrained(model_id) -model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto") +revision = "refs/pr/39" +processor = AutoProcessor.from_pretrained(model_id, revision=revision) +model = AutoModelForTDT.from_pretrained(model_id, revision=revision, dtype="auto", device_map="auto") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) @@ -128,8 +129,9 @@ from datasets import Audio, load_dataset from transformers import AutoModelForTDT, AutoProcessor model_id = "nvidia/parakeet-tdt-0.6b-v3" -processor = AutoProcessor.from_pretrained(model_id) -model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto") +revision = "refs/pr/39" +processor = AutoProcessor.from_pretrained(model_id, revision=revision) +model = AutoModelForTDT.from_pretrained(model_id, revision=revision, dtype="auto", device_map="auto") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) @@ -269,21 +271,17 @@ outputs.loss.backward() ### TDT Training -The TDT loss has been implemented within Transformers to enable training. For faster training (around 10x), consider using NeMo's `TDTLossNumba`. Note that this requires installing the NeMo toolkit with `pip install nemo_toolkit[asr]`. 
- - - - ```py from datasets import Audio, load_dataset import torch from transformers import AutoModelForTDT, AutoProcessor -model_id = "nvidia/parakeet-tdt-0.6b-v3-hf" +model_id = "nvidia/parakeet-tdt-0.6b-v3" +revision = "refs/pr/39" NUM_SAMPLES = 4 -processor = AutoProcessor.from_pretrained(model_id) -model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") +processor = AutoProcessor.from_pretrained(model_id, revision=revision) +model = AutoModelForTDT.from_pretrained(model_id, revision=revision, dtype=torch.bfloat16, device_map="auto") model.train() ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") @@ -300,72 +298,6 @@ print("Loss:", outputs.loss.item()) outputs.loss.backward() ``` - - - -```py -import torch -from datasets import Audio, load_dataset -from nemo.collections.asr.losses.rnnt import TDTLossNumba -from transformers import AutoModelForTDT, AutoProcessor - - -model_id = "nvidia/parakeet-tdt-0.6b-v3-hf" -NUM_SAMPLES = 4 - -# Load model and processor -processor = AutoProcessor.from_pretrained(model_id) -model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") -model.train() - -# Initialize NeMo TDT loss -loss_fn = TDTLossNumba( - blank=model.config.blank_token_id, - durations=model.config.durations, - reduction="none", -) - -def nemo_loss_wrapper(token_logits, duration_logits, targets, logit_lengths, target_lengths, **kwargs): - """Adapter function that converts Transformers loss signature to NeMo signature.""" - acts = torch.cat([token_logits, duration_logits], dim=-1) - batch_size, T, U = acts.shape[:3] - act_lens = torch.full((batch_size,), T, dtype=torch.long, device=acts.device) - # NeMo requires float32 (Numba doesn't support float16/bfloat16) and int64 - per_sample_losses = nemo_loss_fn( - acts=acts.float(), - labels=targets.long(), - act_lens=act_lens, - label_lens=target_lengths.long(), - ) - # NOTE: NeMo's TDTLossNumba doesn't do normalization with target lengths as suggested by its docstring so we do manually: - # - Docstring: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py#L373 - # - Expected normalization: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py#L247-L253 - return (per_sample_losses / target_lengths.float()).mean() - -# Monkey-patch the model's loss function -model.loss_function = nemo_loss_wrapper - -# Load dataset -ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") -ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) -speech_samples = [el["array"] for el in ds["audio"][:NUM_SAMPLES]] -text_samples = ds["text"][:NUM_SAMPLES] - -# Prepare inputs -inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate) -inputs.to(device=model.device, dtype=model.dtype) - -# Forward and backward -outputs = model(**inputs) -loss = outputs.loss -print(f"Loss (NeMo TDTLossNumba): {loss.item():.6f}") -loss.backward() -print("\nāœ“ Successfully computed loss and gradients using NeMo's fast TDT loss!") -``` - - - - ## ParakeetTokenizer From 1d7680d41abe9ab29f6c614c7aaa8d77e2ef85d7 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:29:44 +0200 Subject: [PATCH 60/67] fix pipeline usage --- src/transformers/models/parakeet/generation_parakeet.py | 4 ++-- 
src/transformers/pipelines/automatic_speech_recognition.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/parakeet/generation_parakeet.py b/src/transformers/models/parakeet/generation_parakeet.py index b714f4dcc277..60d165d5acb5 100644 --- a/src/transformers/models/parakeet/generation_parakeet.py +++ b/src/transformers/models/parakeet/generation_parakeet.py @@ -117,9 +117,9 @@ def _prepare_model_inputs(self, *args, **kwargs): if encoder_outputs.attention_mask is not None: encoder_valid_lengths = encoder_outputs.attention_mask.sum(-1) else: - batch_size = encoder_outputs.shape[0] + batch_size = encoder_outputs.last_hidden_state.shape[0] encoder_valid_lengths = torch.full( - (batch_size,), encoder_outputs.last_hidden_state.shape[1], dtype=torch.long, device=encoder_outputs.device + (batch_size,), encoder_outputs.last_hidden_state.shape[1], dtype=torch.long, device=encoder_outputs.last_hidden_state.device ) model_kwargs["encoder_valid_lengths"] = encoder_valid_lengths diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 9b5ab3c7ff0f..f71f4c4bd62c 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -564,7 +564,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): if "attention_mask" in model_inputs: inputs["attention_mask"] = model_inputs.pop("attention_mask") outputs = self.model.generate(**inputs) - out = {"tokens": outputs} + out = {"tokens": outputs.sequences} else: raise ValueError("Unsupported model type {self.type}.") From 59ddcedb00f39610ee53fffe55f9d111a64113a2 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:43:50 +0200 Subject: [PATCH 61/67] nit --- .../models/parakeet/modeling_parakeet.py | 44 ++++++++++++++----- .../models/parakeet/modular_parakeet.py | 2 - 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 367f75984303..78ec234b66bc 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -733,8 +733,6 @@ def forward( >>> print(outputs.loss) ```""" - if labels is not None: - kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, @@ -899,14 +897,7 @@ def update( class ParakeetTDTDecoder(nn.Module): - """LSTM-based prediction network for TDT. - - During generation the decoder is called once per step. When a blank token - is fed back (i.e. the model predicted blank at the previous step), the LSTM - state must *not* change — only the encoder frame advances. The blank- - skipping logic restores the previous cache state for those batch elements - using ``torch.where`` so that callers can treat the decoder as a black box. - """ + """LSTM-based prediction network for TDT.""" def __init__(self, config: ParakeetTDTConfig): super().__init__() @@ -939,10 +930,10 @@ def forward( decoder_output = self.decoder_projector(lstm_output) if cache is not None: - # Use ~blank_mask so only non-blank elements are updated; blank elements keep previous state. 
mask = ~blank_mask if cache.is_initialized else None cache.update(decoder_output, hidden_state, cell_state, lstm_module=self.lstm, mask=mask) return cache.cache + return decoder_output @@ -1031,6 +1022,37 @@ def forward( labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ParakeetTDTOutput: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Decoder input token ids for single-step inference. + encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): + Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). + Can be a tuple or `ParakeetEncoderModelOutput`. + decoder_cache (`ParakeetTDTDecoderCache`, *optional*): + Decoder LSTM cache. When provided and initialized, the cached `decoder_output` is reused + (e.g. during blank-skipping) instead of running the decoder. When `input_ids` is provided, + the decoder runs and the cache is updated in-place. + use_decoder_cache (`bool`, *optional*): + Whether to use a decoder cache. When `True` and `decoder_cache` is `None`, a new cache + is created automatically during the forward pass. + + Example: + + ```python + >>> from transformers import AutoProcessor, ParakeetForTDT + >>> from datasets import load_dataset, Audio + + >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = ParakeetForTDT.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + >>> inputs = processor(ds[0]["audio"]["array"]) + >>> outputs = model(**inputs) + ``` + """ if encoder_outputs is None: encoder_outputs = self.get_audio_features( input_features=input_features, diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 0f413aa088ff..d98f788770d9 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -572,8 +572,6 @@ def forward( >>> print(outputs.loss) ```""" - if labels is not None: - kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, From 31490d19de2801002a768c658205cb447e54cbe8 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 16 Apr 2026 18:13:22 +0200 Subject: [PATCH 62/67] fix usage --- docs/source/en/model_doc/parakeet.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index 87ea6f0e2b5b..d7bedba44562 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -89,7 +89,7 @@ print(processor.decode(outputs)) -Parakeet TDT transcripts include casing, and the model can also performk token timestamping. +Parakeet TDT transcripts include casing, and the model can also perform token timestamping. 
```py from transformers import pipeline @@ -139,12 +139,11 @@ speech_samples = [el['array'] for el in ds["audio"][:1]] inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate) inputs.to(model.device, dtype=model.dtype) -output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) +output = model.generate(**inputs, return_dict_in_generate=True) decoded_output, decoded_timestamps = processor.decode( output.sequences, - token_timestamps=output.token_timestamps, - token_durations=output.token_durations, - skip_special_tokens=True + durations=output.durations, + skip_special_tokens=True, ) print("Transcription:", decoded_output) print("\nTimestamped tokens:", decoded_timestamps) From d8eb1b6f669553f6ae649a5a9a5e913e4428225f Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 17 Apr 2026 17:04:19 +0200 Subject: [PATCH 63/67] Pass through tests and examples: improve kernel fallback, update with nvidia checkpoint, style checks. --- docs/source/en/model_doc/parakeet.md | 19 +++-- src/transformers/integrations/hub_kernels.py | 8 ++- src/transformers/loss/loss_tdt.py | 20 ++++-- src/transformers/models/lasr/modeling_lasr.py | 10 +-- src/transformers/models/lasr/modular_lasr.py | 71 ++++++++++++++++--- .../models/lasr/processing_lasr.py | 5 +- .../models/parakeet/configuration_parakeet.py | 8 +-- .../models/parakeet/convert_nemo_to_hf.py | 10 ++- .../models/parakeet/generation_parakeet.py | 35 ++++++--- .../models/parakeet/modeling_parakeet.py | 7 +- .../models/parakeet/modular_parakeet.py | 7 +- .../models/parakeet/processing_parakeet.py | 19 +++-- .../models/parakeet/test_modeling_parakeet.py | 25 ++++--- 13 files changed, 168 insertions(+), 76 deletions(-) diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index d7bedba44562..cca7d395f2d2 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -58,6 +58,7 @@ from transformers import pipeline pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-ctc-1.1b") out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") print(out) +# {'text': 'yesterday it was thirty five degrees in barcelona but today the temperature will go down to minus twenty degrees'} ``` @@ -94,9 +95,10 @@ Parakeet TDT transcripts include casing, and the model can also perform token ti ```py from transformers import pipeline -pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-tdt-0.6b-v3", revision="refs/pr/39") +pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-tdt-0.6b-v3") out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") print(out) +# {'text': 'Yesterday it was 35 degrees in Barcelona, but today the temperature will go down to minus 20 degrees.'} ``` @@ -107,9 +109,8 @@ from transformers import AutoModelForTDT, AutoProcessor from datasets import load_dataset, Audio model_id = "nvidia/parakeet-tdt-0.6b-v3" -revision = "refs/pr/39" -processor = AutoProcessor.from_pretrained(model_id, revision=revision) -model = AutoModelForTDT.from_pretrained(model_id, revision=revision, dtype="auto", device_map="auto") +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", 
Audio(sampling_rate=processor.feature_extractor.sampling_rate)) @@ -129,9 +130,8 @@ from datasets import Audio, load_dataset from transformers import AutoModelForTDT, AutoProcessor model_id = "nvidia/parakeet-tdt-0.6b-v3" -revision = "refs/pr/39" -processor = AutoProcessor.from_pretrained(model_id, revision=revision) -model = AutoModelForTDT.from_pretrained(model_id, revision=revision, dtype="auto", device_map="auto") +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) @@ -276,11 +276,10 @@ import torch from transformers import AutoModelForTDT, AutoProcessor model_id = "nvidia/parakeet-tdt-0.6b-v3" -revision = "refs/pr/39" NUM_SAMPLES = 4 -processor = AutoProcessor.from_pretrained(model_id, revision=revision) -model = AutoModelForTDT.from_pretrained(model_id, revision=revision, dtype=torch.bfloat16, device_map="auto") +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") model.train() ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 2894209173d3..c0db0822b962 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -286,7 +286,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, - "tdt-loss": {"repo_id": "eustlb/tdt-loss", "version": 1}, + "tdt-loss": {"repo_id": "eustlb/tdt-loss", "revision": "v1"}, } _KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {} @@ -373,10 +373,12 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) - kernel = get_kernel(repo_id, revision=revision, version=version) + # Since we only read from `_HUB_KERNEL_MAPPING`, we can allow all kernels + kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True) mapping[kernel_name] = kernel - except FileNotFoundError: + except FileNotFoundError as e: mapping[kernel_name] = None + logger.warning_once(f"Failed to load kernel {kernel_name}: {e}") except AssertionError: # Happens when torch is built without an accelerator backend; fall back to slow path. mapping[kernel_name] = None diff --git a/src/transformers/loss/loss_tdt.py b/src/transformers/loss/loss_tdt.py index 3172c0175291..6a128f18583c 100644 --- a/src/transformers/loss/loss_tdt.py +++ b/src/transformers/loss/loss_tdt.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -25,7 +25,11 @@ def _load_tdt_kernel(): try: from ..integrations.hub_kernels import lazy_load_kernel - return lazy_load_kernel("tdt-loss") + kernel = lazy_load_kernel("tdt-loss") + if kernel is None or not hasattr(kernel, "tdt_loss"): + logger.warning_once("Falling back to pure PyTorch implementation.") + return None + return kernel except (ImportError, ModuleNotFoundError): return None except Exception as e: @@ -73,9 +77,15 @@ def tdt_loss( if kernel is not None and hasattr(kernel, "tdt_loss"): durations_t = torch.tensor(durations, dtype=torch.int32, device=token_logits.device) return kernel.tdt_loss( - token_logits, duration_logits, targets, - logit_lengths, target_lengths, durations_t, - blank_token_id, sigma, reduction, + token_logits, + duration_logits, + targets, + logit_lengths, + target_lengths, + durations_t, + blank_token_id, + sigma, + reduction, ) if reduction not in ("mean", "sum", "none"): diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index 699f7911c89d..4a2700ea79ed 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -26,6 +26,7 @@ from torch import nn from ...activations import ACT2FN +from ...generation import CompileConfig, GenerationMixin from ...integrations import use_kernel_func_from_hub, use_kernelized_func from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer @@ -607,7 +608,7 @@ class LasrCTCGenerateOutput(ModelOutput): Lasr Encoder with a Connectionist Temporal Classification (CTC) head. """ ) -class LasrForCTC(LasrPreTrainedModel): +class LasrForCTC(LasrPreTrainedModel, GenerationMixin): config: LasrCTCConfig def __init__(self, config: LasrCTCConfig): @@ -647,8 +648,6 @@ def forward( >>> print(outputs.loss) ```""" - if labels is not None: - kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, @@ -694,6 +693,7 @@ def generate( input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, return_dict_in_generate: bool = False, + compile_config: CompileConfig | None = None, **kwargs: Unpack[TransformersKwargs], ) -> LasrCTCGenerateOutput | torch.LongTensor: r""" @@ -717,8 +717,10 @@ def generate( >>> print(transcription) ``` """ + model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__ + kwargs["return_dict"] = True - outputs: CausalLMOutput = self.forward( + outputs: CausalLMOutput = model_forward( input_features=input_features, attention_mask=attention_mask, **kwargs, diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py index fd279383f12e..1329c5c0a2af 100644 --- a/src/transformers/models/lasr/modular_lasr.py +++ b/src/transformers/models/lasr/modular_lasr.py @@ -21,11 +21,13 @@ from tokenizers.models import Unigram from torch import nn +from ...audio_utils import AudioInput, make_list_of_audio from ...masking_utils import create_bidirectional_mask from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import ProcessingKwargs, Unpack +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...tokenization_utils_tokenizers import TokenizersBackend -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils import 
TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward @@ -37,10 +39,12 @@ ParakeetForCTC, ParakeetPreTrainedModel, ) -from ..parakeet.processing_parakeet import ParakeetProcessor from ..t5.tokenization_t5 import T5Tokenizer +logger = logging.get_logger(__name__) + + class LasrTokenizer(T5Tokenizer, TokenizersBackend): def __init__( self, @@ -160,13 +164,58 @@ class LasrProcessorKwargs(ProcessingKwargs, total=False): } -class LasrProcessor(ParakeetProcessor): - def decode(self, *args, **kwargs): - """Forward arguments to [`~PreTrainedTokenizer.decode`].""" - self.tokenizer.decode(*args, **kwargs) +@auto_docstring +class LasrProcessor(ProcessorMixin): + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + @auto_docstring + def __call__( + self, + audio: AudioInput, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None, + sampling_rate: int | None = None, + **kwargs: Unpack[LasrProcessorKwargs], + ): + r""" + sampling_rate (`int`, *optional*): + The sampling rate of the input audio in Hz. This should match the sampling rate expected by the feature + extractor (defaults to 16000 Hz). If provided, it will be validated against the processor's expected + sampling rate, and an error will be raised if they don't match. If not provided, a warning will be + issued and the default sampling rate will be assumed. + """ + audio = make_list_of_audio(audio) + + output_kwargs = self._merge_kwargs( + LasrProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if sampling_rate is None: + logger.warning_once( + f"You've provided audio without specifying the sampling rate. It will be assumed to be {output_kwargs['audio_kwargs']['sampling_rate']}, which can result in silent errors." + ) + elif sampling_rate != output_kwargs["audio_kwargs"]["sampling_rate"]: + raise ValueError( + f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({output_kwargs['audio_kwargs']['sampling_rate']}). Please resample the audio to the expected sampling rate." + ) + + if audio is not None: + inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + if text is not None: + encodings = self.tokenizer(text, **output_kwargs["text_kwargs"]) - def _refine_timestamps_tdt(self, *args, **kwargs): - raise NotImplementedError("Not needed") + if text is None: + return inputs + else: + inputs["labels"] = encodings["input_ids"] + return inputs + + @property + def model_input_names(self): + feature_extractor_input_names = self.feature_extractor.model_input_names + return feature_extractor_input_names + ["labels"] @auto_docstring(checkpoint="google/medasr") @@ -202,6 +251,10 @@ class LasrEncoderConfig(ParakeetEncoderConfig): >>> # Initializing a model from the configuration >>> model = LasrEncoderModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
""" diff --git a/src/transformers/models/lasr/processing_lasr.py b/src/transformers/models/lasr/processing_lasr.py index b7216ae08a65..9eb093a49c7a 100644 --- a/src/transformers/models/lasr/processing_lasr.py +++ b/src/transformers/models/lasr/processing_lasr.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from ...audio_utils import AudioInput, make_list_of_audio from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput @@ -96,9 +97,5 @@ def model_input_names(self): feature_extractor_input_names = self.feature_extractor.model_input_names return feature_extractor_input_names + ["labels"] - def decode(self, *args, **kwargs): - """Forward arguments to [`~PreTrainedTokenizer.decode`].""" - self.tokenizer.decode(*args, **kwargs) - __all__ = ["LasrProcessor"] diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 44b0dfd7f402..60d782ad0e4b 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -135,17 +135,17 @@ def __post_init__(self, **kwargs): @strict class ParakeetTDTConfig(PreTrainedConfig): r""" - encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): - The config object or dictionary of the encoder. decoder_hidden_size (`int`, *optional*, defaults to 640): Hidden size of the LSTM prediction network and joint network. num_decoder_layers (`int`, *optional*, defaults to 2): Number of LSTM layers in the prediction network. + max_symbols_per_step (`int`, *optional*, defaults to 10): + Maximum number of symbols to emit per encoder time step during greedy decoding. durations (`list[int]`, *optional*, defaults to `[0, 1, 2, 3, 4]`): Token duration values that can be predicted. Each value represents how many frames a token or blank emission spans. - max_symbols_per_step (`int`, *optional*, defaults to 10): - Maximum number of symbols to emit per encoder time step during greedy decoding. + encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*): + The config object or dictionary of the encoder. blank_token_id (`int`, *optional*, defaults to 8192): Blank token id. Different from `pad_token_id` for TDT. diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index a7874e4996a0..b1be27fe5dcf 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -142,7 +142,9 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str return model_files -def write_processor(nemo_config: dict, model_files, output_dir, model_type, push_to_repo_id=None, create_pr=True, revision=None): +def write_processor( + nemo_config: dict, model_files, output_dir, model_type, push_to_repo_id=None, create_pr=True, revision=None +): tokenizer_converted = ParakeetConverter(model_files["tokenizer_model_file"]).converted() tokenizer_converted_fast = ParakeetTokenizer( tokenizer_object=tokenizer_converted, @@ -425,7 +427,11 @@ def main( # When revision is given (e.g. "refs/pr/3"), both pushes target that existing PR branch. # Otherwise, write_processor creates a new PR and returns its revision for write_model. 
pr_revision = write_processor( - nemo_config, model_files, output_dir, model_type, push_to_repo_id, + nemo_config, + model_files, + output_dir, + model_type, + push_to_repo_id, create_pr=create_pr if revision is None else False, revision=revision, ) diff --git a/src/transformers/models/parakeet/generation_parakeet.py b/src/transformers/models/parakeet/generation_parakeet.py index 60d165d5acb5..fe422f3dd3a8 100644 --- a/src/transformers/models/parakeet/generation_parakeet.py +++ b/src/transformers/models/parakeet/generation_parakeet.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -62,6 +62,7 @@ class ParakeetTDTGenerationMixin(GenerationMixin): Handles transducer-specific generation logic: encoder frame tracking, duration accumulation, and encoder-exhaustion stopping. """ + def _get_stopping_criteria(self, *args, **kwargs): criteria = super()._get_stopping_criteria(*args, **kwargs) criteria.append(EncoderExhaustedCriteria(self)) @@ -87,8 +88,13 @@ def _update_model_kwargs_for_generation(self, outputs, *args, **kwargs): return model_kwargs def _prepare_generated_length( - self, generation_config, has_default_max_length, has_default_min_length, - model_input_name, input_ids_length, inputs_tensor, + self, + generation_config, + has_default_max_length, + has_default_min_length, + model_input_name, + input_ids_length, + inputs_tensor, ): # When the user hasn't explicitly set max_length/max_new_tokens, derive an upper # bound from the encoder capacity. The actual stopping is handled by the @@ -97,11 +103,15 @@ def _prepare_generated_length( encoder_seq_len = self.encoder._get_subsampling_output_length( torch.tensor([inputs_tensor.shape[1]], device=inputs_tensor.device) ).item() - generation_config.max_length = self.config.max_symbols_per_step * encoder_seq_len + generation_config.max_length = self.max_symbols_per_step * encoder_seq_len has_default_max_length = False # prevent super() from overwriting return super()._prepare_generated_length( - generation_config, has_default_max_length, has_default_min_length, - model_input_name, input_ids_length, inputs_tensor, + generation_config, + has_default_max_length, + has_default_min_length, + model_input_name, + input_ids_length, + inputs_tensor, ) def _prepare_model_inputs(self, *args, **kwargs): @@ -119,7 +129,10 @@ def _prepare_model_inputs(self, *args, **kwargs): else: batch_size = encoder_outputs.last_hidden_state.shape[0] encoder_valid_lengths = torch.full( - (batch_size,), encoder_outputs.last_hidden_state.shape[1], dtype=torch.long, device=encoder_outputs.last_hidden_state.device + (batch_size,), + encoder_outputs.last_hidden_state.shape[1], + dtype=torch.long, + device=encoder_outputs.last_hidden_state.device, ) model_kwargs["encoder_valid_lengths"] = encoder_valid_lengths @@ -140,7 +153,9 @@ def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): from .modeling_parakeet import ParakeetEncoderModelOutput model_inputs = super().prepare_inputs_for_generation(input_ids, *args, **kwargs) - encoder_frame_idxs = model_inputs.pop("encoder_frame_idxs").to(model_inputs["encoder_outputs"].pooler_output.device) + encoder_frame_idxs = model_inputs.pop("encoder_frame_idxs").to( + model_inputs["encoder_outputs"].pooler_output.device + ) pooler_output = model_inputs["encoder_outputs"].pooler_output batch_size, 
max_encoder_len = pooler_output.shape[0], pooler_output.shape[1] @@ -159,7 +174,9 @@ def generate(self, inputs=None, generation_config=None, **kwargs): outputs = super().generate(inputs=inputs, generation_config=generation_config, **kwargs) durations = torch.stack(self._step_durations, dim=1) # (batch, steps) # Prepend a zero duration for the decoder_start_token_id that super().generate() prepends to sequences - durations = torch.cat([torch.zeros(durations.shape[0], 1, dtype=durations.dtype, device=durations.device), durations], dim=1) + durations = torch.cat( + [torch.zeros(durations.shape[0], 1, dtype=durations.dtype, device=durations.device), durations], dim=1 + ) del self._step_durations, self._encoder_finished return ParakeetTDTGenerateOutput( diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 78ec234b66bc..0fb362edfd49 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -991,6 +991,7 @@ def __init__(self, config: ParakeetTDTConfig): self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) + self.max_symbols_per_step = config.max_symbols_per_step # used in generation self.post_init() @@ -1025,9 +1026,6 @@ def forward( r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Decoder input token ids for single-step inference. - encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): - Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). - Can be a tuple or `ParakeetEncoderModelOutput`. decoder_cache (`ParakeetTDTDecoderCache`, *optional*): Decoder LSTM cache. When provided and initialized, the cached `decoder_output` is reused (e.g. during blank-skipping) instead of running the decoder. When `input_ids` is provided, @@ -1035,6 +1033,9 @@ def forward( use_decoder_cache (`bool`, *optional*): Whether to use a decoder cache. When `True` and `decoder_cache` is `None`, a new cache is created automatically during the forward pass. + encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): + Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). + Can be a tuple or `ParakeetEncoderModelOutput`. Example: diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index d98f788770d9..31c3a23e046f 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -830,6 +830,7 @@ def __init__(self, config: ParakeetTDTConfig): self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size) self.decoder = ParakeetTDTDecoder(config) self.joint = ParakeetTDTJointNetwork(config) + self.max_symbols_per_step = config.max_symbols_per_step # used in generation self.post_init() @@ -864,9 +865,6 @@ def forward( r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Decoder input token ids for single-step inference. - encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): - Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). - Can be a tuple or `ParakeetEncoderModelOutput`. decoder_cache (`ParakeetTDTDecoderCache`, *optional*): Decoder LSTM cache. 
When provided and initialized, the cached `decoder_output` is reused (e.g. during blank-skipping) instead of running the decoder. When `input_ids` is provided, @@ -874,6 +872,9 @@ use_decoder_cache (`bool`, *optional*): Whether to use a decoder cache. When `True` and `decoder_cache` is `None`, a new cache is created automatically during the forward pass. + encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): + Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask). + Can be a tuple or `ParakeetEncoderModelOutput`. Example: diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py index 2a691deaea76..85b63f396765 100644 --- a/src/transformers/models/parakeet/processing_parakeet.py +++ b/src/transformers/models/parakeet/processing_parakeet.py @@ -41,6 +41,10 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False): @auto_docstring class ParakeetProcessor(ProcessorMixin): def __init__(self, feature_extractor, tokenizer, blank_token=""): + """ + blank_token (`str`, *optional*, defaults to `""`): + Blank token for TDT decoding. + """ self.blank_token = blank_token self.blank_token_id = tokenizer.convert_tokens_to_ids(blank_token) super().__init__(feature_extractor, tokenizer) @@ -98,7 +102,7 @@ def __call__( @property def model_input_names(self): feature_extractor_input_names = self.feature_extractor.model_input_names - return feature_extractor_input_names + ["labels"] + return feature_extractor_input_names + ["labels", "decoder_input_ids"] def decode(self, *args, durations=None, **kwargs): """ @@ -125,15 +129,16 @@ def decode(self, *args, durations=None, **kwargs): for batch_ids, batch_timestamps, batch_durations in zip(token_ids, timestamps, durations): # See `compute_rnnt_timestamps` in NeMo: https://github.com/NVIDIA-NeMo/NeMo/blob/1692a8fb97e1aadc883cfadd2a57c4e8a1b793aa/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L993 # Filter padding and blank tokens - blank_token_id = self.blank_token_id - skip_ids = {self.tokenizer.pad_token_id, blank_token_id} - non_blank_indices = [ - i for i, token_id in enumerate(batch_ids) if int(token_id) not in skip_ids - ] + skip_ids = {self.tokenizer.pad_token_id, self.blank_token_id} + non_blank_indices = [i for i, token_id in enumerate(batch_ids) if int(token_id) not in skip_ids] non_blank_ids = [batch_ids[i] for i in non_blank_indices] decoded_tokens = [self.tokenizer.decode([token_id]) for token_id in non_blank_ids] timestamp_dict = [ - {"token": token_str, "start": int(batch_timestamps[i]), "end": int(batch_timestamps[i] + batch_durations[i])} + { + "token": token_str, + "start": int(batch_timestamps[i]), + "end": int(batch_timestamps[i] + batch_durations[i]), + } for token_str, i in zip(decoded_tokens, non_blank_indices) ] timestamp_dict = self._refine_timestamps_tdt(timestamp_dict) diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index de1bff8ff222..2c6d219797aa 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -323,6 +323,7 @@ def test_ctc_loss_inference(self): @require_torch class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (ParakeetForCTC,) if is_torch_available() else () + all_generative_model_classes = () # ParakeetForCTC has a custom generate method pipeline_model_mapping = ( { "feature-extraction": ParakeetEncoder, @@ -594,15
+595,11 @@ def test_retain_grad_hidden_states_attentions(self): def test_generation_tester_mixin_inheritance(self): pass - @unittest.skip( - reason="ParakeetForTDT is a flat composite model without a separate base_model sub-module" - ) + @unittest.skip(reason="ParakeetForTDT is a flat composite model without a separate base_model sub-module") def test_model_base_model_prefix(self): pass - @unittest.skip( - reason="ParakeetForTDT decoder is an LSTM prediction network without attention" - ) + @unittest.skip(reason="ParakeetForTDT decoder is an LSTM prediction network without attention") def test_flex_attention_with_grads(self): pass @@ -640,9 +637,8 @@ class ParakeetForTDTIntegrationTest(unittest.TestCase): @classmethod def setUp(cls): cls.checkpoint_name = "nvidia/parakeet-tdt-0.6b-v3" - cls.revision = "refs/pr/39" cls.dtype = torch.bfloat16 - cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name, revision=cls.revision) + cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name) def tearDown(self): cleanup(torch_device, gc_collect=True) @@ -672,7 +668,7 @@ def test_tdt_model_integration(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, revision=self.revision, dtype=self.dtype, device_map="auto") + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto") inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) inputs.to(model.device, dtype=self.dtype) @@ -691,7 +687,7 @@ def test_tdt_model_integration_batched(self): EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, revision=self.revision, dtype=self.dtype, device_map="auto") + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto") inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) inputs.to(model.device, dtype=self.dtype) @@ -715,7 +711,7 @@ def test_tdt_model_integration_timestamps(self): # Use larger precision for testing token durations and timestamps samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS)) - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, revision=self.revision, dtype=torch.float32, device_map="auto") + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto") inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate) inputs.to(model.device, dtype=model.dtype) @@ -754,7 +750,7 @@ def test_tdt_model_integration_loss(self): transcripts = [t.lower() for t in transcripts] # Use float32 for loss precision - model = ParakeetForTDT.from_pretrained(self.checkpoint_name, revision=self.revision, dtype=torch.float32, device_map="auto") + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto") inputs = self.processor( audio=samples, @@ -765,7 +761,10 @@ def test_tdt_model_integration_loss(self): # Test both backends: kernel (if available) and pure PyTorch has_kernel = _load_tdt_kernel() is not None - backends = [("kernel", None), ("torch", patch("transformers.loss.loss_tdt._load_tdt_kernel", return_value=None))] + backends = [ + ("kernel", None), + ("torch", patch("transformers.loss.loss_tdt._load_tdt_kernel", return_value=None)), + ] if not 
has_kernel: backends = backends[1:] # skip kernel test when not installed From 1f1b912d38fd301b776e5186c1993db1c96a8e9a Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 17 Apr 2026 17:19:39 +0200 Subject: [PATCH 64/67] Update checkpoint --- src/transformers/models/parakeet/configuration_parakeet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 60d782ad0e4b..4b7c5b0fb526 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -131,7 +131,7 @@ def __post_init__(self, **kwargs): super().__post_init__(**kwargs) -@auto_docstring(checkpoint="bezzam/parakeet-tdt-0.6b-v3-hf") +@auto_docstring(checkpoint="nvidia/parakeet-tdt-0.6b-v3") @strict class ParakeetTDTConfig(PreTrainedConfig): r""" From fd9f8b1baa7618eb2d8e9dc0fedb82d4e3b00ff4 Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 17 Apr 2026 18:04:17 +0200 Subject: [PATCH 65/67] Add TDT to mapping after merge. --- src/transformers/models/auto/auto_mappings.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 10e376b65956..24db9a947411 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -393,6 +393,7 @@ ("paligemma", "PaliGemmaConfig"), ("parakeet_ctc", "ParakeetCTCConfig"), ("parakeet_encoder", "ParakeetEncoderConfig"), + ("parakeet_tdt", "ParakeetTDTConfig"), ("patchtsmixer", "PatchTSMixerConfig"), ("patchtst", "PatchTSTConfig"), ("pe_audio", "PeAudioConfig"), @@ -755,6 +756,7 @@ ("paddleocr_vl_vision", "paddleocr_vl"), ("parakeet_ctc", "parakeet"), ("parakeet_encoder", "parakeet"), + ("parakeet_tdt", "parakeet"), ("pe_audio_encoder", "pe_audio"), ("pe_audio_video_encoder", "pe_audio_video"), ("pe_video_encoder", "pe_video"), From 136f67688056566532c06f2201b367eebf2652bb Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 20 Apr 2026 11:34:00 +0200 Subject: [PATCH 66/67] Fix lasr generate test. --- tests/models/lasr/test_modeling_lasr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/lasr/test_modeling_lasr.py b/tests/models/lasr/test_modeling_lasr.py index 36060eecac3b..d212730676f9 100644 --- a/tests/models/lasr/test_modeling_lasr.py +++ b/tests/models/lasr/test_modeling_lasr.py @@ -245,6 +245,7 @@ def test_ctc_loss_inference(self): @require_torch class LasrForCTCModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (LasrForCTC,) if is_torch_available() else () + all_generative_model_classes = () # LasrForCTC has a custom generate method pipeline_model_mapping = ( { "feature-extraction": LasrEncoder, From 833d2890417d53e145075e43a8f571ed844ad49a Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 20 Apr 2026 12:07:57 +0200 Subject: [PATCH 67/67] Output attention mask if labels provided for computing loss.
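With this change, passing `labels` makes the CTC/TDT forward request the encoder attention mask by default, so per-sample encoder lengths reach the loss without the caller setting `output_attention_mask` themselves. A minimal training-style forward pass for the TDT model, adapted from the example in `docs/source/en/model_doc/parakeet.md` (the dataset slice size and dtype below are arbitrary choices, not part of this patch):

```py
import torch
from datasets import Audio, load_dataset

from transformers import AutoModelForTDT, AutoProcessor

model_id = "nvidia/parakeet-tdt-0.6b-v3"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.float32, device_map="auto")
model.train()

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el["array"] for el in ds["audio"][:2]]
text_samples = ds["text"][:2]

# `text=` yields labels; with labels present, forward() now defaults
# output_attention_mask to True so the loss sees the true encoder lengths.
inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(device=model.device, dtype=model.dtype)

outputs = model(**inputs)
outputs.loss.backward()
```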
--- src/transformers/models/lasr/modeling_lasr.py | 2 ++ src/transformers/models/parakeet/modeling_parakeet.py | 2 ++ src/transformers/models/parakeet/modular_parakeet.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index 4a2700ea79ed..19054874b1e1 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -648,6 +648,8 @@ def forward( >>> print(outputs.loss) ```""" + if labels is not None: + kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 0fb362edfd49..4672dcab0cb2 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -733,6 +733,8 @@ def forward( >>> print(outputs.loss) ```""" + if labels is not None: + kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask, diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 31c3a23e046f..22fce9362648 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -572,6 +572,8 @@ def forward( >>> print(outputs.loss) ```""" + if labels is not None: + kwargs.setdefault("output_attention_mask", True) encoder_outputs = self.encoder( input_features=input_features, attention_mask=attention_mask,