From bf52a00f7c9b36d643fb22e881ff9fa971342a66 Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Mon, 19 Jan 2026 22:52:35 +0100 Subject: [PATCH 1/2] Add Parakeet TDT model support --- docs/source/en/model_doc/parakeet.md | 14 + src/transformers/convert_slow_tokenizer.py | 3 +- .../models/auto/configuration_auto.py | 3 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/parakeet/__init__.py | 1 + .../models/parakeet/configuration_parakeet.py | 111 +++- .../models/parakeet/convert_nemo_to_hf.py | 166 +++++- .../models/parakeet/modeling_parakeet.py | 475 +++++++++++++++++- .../models/parakeet/modular_parakeet.py | 475 +++++++++++++++++- .../parakeet/expected_results_tdt_batch.json | 1 + .../parakeet/expected_results_tdt_single.json | 1 + .../models/parakeet/generate_tdt_fixtures.py | 106 ++++ .../models/parakeet/test_modeling_parakeet.py | 394 +++++++++++++++ 16 files changed, 1739 insertions(+), 15 deletions(-) create mode 100644 tests/fixtures/parakeet/expected_results_tdt_batch.json create mode 100644 tests/fixtures/parakeet/expected_results_tdt_single.json create mode 100644 tests/models/parakeet/generate_tdt_fixtures.py diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md index b075e6d5ccf7..f37130d4fb48 100644 --- a/docs/source/en/model_doc/parakeet.md +++ b/docs/source/en/model_doc/parakeet.md @@ -34,6 +34,12 @@ Parakeet models, [introduced by NVIDIA NeMo](https://developer.nvidia.com/blog/p - 1D convolution projection from encoder hidden size to vocabulary size (for optimal NeMo compatibility). - CTC loss computation for training. - Greedy CTC decoding for inference. 
+- [**ParakeetForTDT**](#parakeetfortdt): a Fast Conformer Encoder + a Token-and-Duration Transducer (TDT) decoder + - **TDT Decoder**: A transducer-based decoder that predicts both tokens and their durations: + - Prediction network (LSTM-based) for autoregressive token prediction. + - Joint network combining encoder and decoder outputs. + - Separate token and duration heads for parallel prediction. + - Greedy TDT decoding with optional timestamp generation. The original implementation can be found in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo). Model checkpoints are to be found under [the NVIDIA organization](https://huggingface.co/nvidia/models?search=parakeet). @@ -219,3 +225,11 @@ outputs.loss.backward() ## ParakeetForCTC [[autodoc]] ParakeetForCTC + +## ParakeetTDTConfig + +[[autodoc]] ParakeetTDTConfig + +## ParakeetForTDT + +[[autodoc]] ParakeetForTDT diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 0e4201f6553b..286eace6ad79 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1770,9 +1770,8 @@ def __init__(self, vocab_file=None, *args): def tokenizer(self, proto): vocab_scores = self.vocab(proto) - - _, merges = self.SpmExtractor(self.vocab_file).extract(vocab_scores) bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)} + merges = generate_merges(bpe_vocab, vocab_scores) tokenizer = Tokenizer( BPE( bpe_vocab, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c2e33d97d73a..34aad49722b2 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -316,6 +316,7 @@ ("paligemma", "PaliGemmaConfig"), ("parakeet_ctc", "ParakeetCTCConfig"), ("parakeet_encoder", "ParakeetEncoderConfig"), + ("parakeet_tdt", "ParakeetTDTConfig"), ("patchtsmixer", "PatchTSMixerConfig"), ("patchtst", "PatchTSTConfig"), 
("pe_audio", "PeAudioConfig"), @@ -792,6 +793,7 @@ ("parakeet", "Parakeet"), ("parakeet_ctc", "Parakeet"), ("parakeet_encoder", "ParakeetEncoder"), + ("parakeet_tdt", "ParakeetTDT"), ("patchtsmixer", "PatchTSMixer"), ("patchtst", "PatchTST"), ("pe_audio", "PeAudio"), @@ -1032,6 +1034,7 @@ ("parakeet_encoder", "parakeet"), ("lw_detr_vit", "lw_detr"), ("parakeet_ctc", "parakeet"), + ("parakeet_tdt", "parakeet"), ("lasr_encoder", "lasr"), ("lasr_ctc", "lasr"), ("wav2vec2-bert", "wav2vec2_bert"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 25b980443cf8..3d58d00e3596 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -59,6 +59,7 @@ ("musicgen_melody", "MusicgenMelodyFeatureExtractor"), ("parakeet_ctc", "ParakeetFeatureExtractor"), ("parakeet_encoder", "ParakeetFeatureExtractor"), + ("parakeet_tdt", "ParakeetFeatureExtractor"), ("pe_audio", "PeAudioFeatureExtractor"), ("pe_audio_video", "PeAudioFeatureExtractor"), ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a68baf338830..05324928a575 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -312,6 +312,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("paligemma", "PaliGemmaModel"), ("parakeet_ctc", "ParakeetForCTC"), ("parakeet_encoder", "ParakeetEncoder"), + ("parakeet_tdt", "ParakeetForTDT"), ("patchtsmixer", "PatchTSMixerModel"), ("patchtst", "PatchTSTModel"), ("pe_audio", "PeAudioModel"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c27e5612e171..9ad548da6c06 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ 
-121,6 +121,7 @@ ("owlvit", "OwlViTProcessor"), ("paddleocr_vl", "PaddleOCRVLProcessor"), ("paligemma", "PaliGemmaProcessor"), + ("parakeet_tdt", "ParakeetProcessor"), ("perception_lm", "PerceptionLMProcessor"), ("phi4_multimodal", "Phi4MultimodalProcessor"), ("pix2struct", "Pix2StructProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 701d4be849bf..370c0c334c8e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -237,6 +237,7 @@ ("ovis2", "Qwen2Tokenizer" if is_tokenizers_available() else None), ("owlv2", "CLIPTokenizer" if is_tokenizers_available() else None), ("owlvit", "CLIPTokenizer" if is_tokenizers_available() else None), + ("parakeet_tdt", "ParakeetTokenizer" if is_tokenizers_available() else None), ("pegasus", "PegasusTokenizer" if is_tokenizers_available() else None), ("pegasus_x", "PegasusTokenizer" if is_tokenizers_available() else None), ("perceiver", "PerceiverTokenizer"), diff --git a/src/transformers/models/parakeet/__init__.py b/src/transformers/models/parakeet/__init__.py index 5c54b2e2eadb..594eb9f7e099 100644 --- a/src/transformers/models/parakeet/__init__.py +++ b/src/transformers/models/parakeet/__init__.py @@ -21,6 +21,7 @@ from .configuration_parakeet import * from .feature_extraction_parakeet import * from .modeling_parakeet import * + from .processing_parakeet import * from .tokenization_parakeet_fast import * else: import sys diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 802da0b208fb..e1d5aec8c32d 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -231,4 +231,113 @@ def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): return cls(encoder_config=encoder_config.to_dict(), **kwargs) 
class ParakeetTDTConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ParakeetForTDT`]. It is used to instantiate a
    Parakeet TDT (Token-and-Duration Transducer) model according to the specified arguments, defining the model
    architecture.

    TDT models jointly predict tokens and their durations, enabling accurate speech recognition with word-level
    timestamps. Unlike CTC which only provides character-level timing, TDT predicts how many frames each token spans.

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 8192):
            Vocabulary size of the model (SentencePiece tokenizer). TDT uses a larger vocabulary than CTC.
        decoder_hidden_size (`int`, *optional*, defaults to 640):
            Hidden size of the LSTM prediction network (decoder).
        decoder_num_layers (`int`, *optional*, defaults to 1):
            Number of LSTM layers in the prediction network.
        joint_hidden_size (`int`, *optional*, defaults to 640):
            Hidden size of the joint network that combines encoder and decoder outputs.
        num_duration_bins (`int`, *optional*, defaults to 5):
            Number of duration bins for predicting token durations. Each bin represents how many frames
            to advance (e.g., [0, 1, 2, 3, 4] means 0-4 frames).
        encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*):
            The config object or dictionary of the encoder. TDT reuses the same FastConformer encoder as CTC.
        blank_token_id (`int`, *optional*, defaults to 8192):
            Token ID for the blank symbol. In TDT, blank indicates "no token emission, advance frames".
        pad_token_id (`int`, *optional*, defaults to 8192):
            Padding token id. Defaults to blank_token_id.

    Example:

    ```python
    >>> from transformers import ParakeetForTDT, ParakeetTDTConfig

    >>> # Initializing a Parakeet TDT configuration
    >>> configuration = ParakeetTDTConfig()

    >>> # Initializing a model from the configuration
    >>> model = ParakeetForTDT(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```

    This configuration class is based on the Parakeet TDT architecture from NVIDIA NeMo. TDT (Token-and-Duration
    Transducer) extends RNN-T by jointly predicting token durations, enabling efficient frame skipping during
    decoding. You can find more details and pre-trained models at:
    - [nvidia/parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2) (English)
    - [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) (25 languages)

    References:
        - TDT Paper: https://arxiv.org/abs/2304.06795
        - FastConformer Paper: https://arxiv.org/abs/2305.05084
    """

    model_type = "parakeet_tdt"
    sub_configs = {"encoder_config": ParakeetEncoderConfig}

    def __init__(
        self,
        vocab_size=8192,
        decoder_hidden_size=640,
        decoder_num_layers=1,
        joint_hidden_size=640,
        num_duration_bins=5,
        encoder_config: dict | ParakeetEncoderConfig | None = None,
        blank_token_id=8192,
        pad_token_id=8192,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_num_layers = decoder_num_layers
        self.joint_hidden_size = joint_hidden_size
        self.num_duration_bins = num_duration_bins
        self.blank_token_id = blank_token_id

        # The decoder embedding table is sized vocab_size + 1 so the blank token
        # lives at index vocab_size; any other blank id would mis-index the table.
        if blank_token_id != vocab_size:
            logger.warning(
                f"blank_token_id ({blank_token_id}) should equal vocab_size ({vocab_size}) "
                "for correct embedding table sizing. The embedding table size is vocab_size + 1 "
                "to accommodate the blank token at index vocab_size."
            )

        if isinstance(encoder_config, dict):
            self.encoder_config = ParakeetEncoderConfig(**encoder_config)
        elif encoder_config is None:
            self.encoder_config = ParakeetEncoderConfig()
        else:
            self.encoder_config = encoder_config

        # Weight-init std is taken from the encoder config so encoder and TDT
        # heads share a single initializer scale.
        self.initializer_range = self.encoder_config.initializer_range

        super().__init__(
            pad_token_id=pad_token_id,
            **kwargs,
        )

    @classmethod
    def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs):
        r"""
        Instantiate a [`ParakeetTDTConfig`] (or a derived class) from a Parakeet encoder model configuration.

        Returns:
            [`ParakeetTDTConfig`]: An instance of a configuration object
        """
        return cls(encoder_config=encoder_config.to_dict(), **kwargs)


__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig", "ParakeetTDTConfig"]


# Additional key-renaming rules for the TDT decoder and joint network, applied on
# top of NEMO_TO_HF_WEIGHT_MAPPING when converting NeMo TDT checkpoints.
NEMO_TDT_WEIGHT_MAPPING = {
    # Decoder embedding
    r"decoder\.prediction\.embed\.": r"decoder.embedding.",
    # Decoder LSTM - remove extra nesting
    r"decoder\.prediction\.dec_rnn\.lstm\.": r"decoder.lstm.",
    # Joint network encoder projection
    r"joint\.enc\.": r"joint.encoder_proj.",
    # Decoder output projection (NeMo puts this in joint, HF puts it in the decoder)
    r"joint\.pred\.": r"decoder.output_proj.",
    # Note: joint.joint_net.2 (combined head) needs special handling - see write_tdt_model
}
def convert_tdt_config(nemo_config, encoder_config):
    """Convert a NeMo TDT config dict to a HF `ParakeetTDTConfig`.

    Args:
        nemo_config: Parsed NeMo model_config.yaml as a dict.
        encoder_config: Already-converted `ParakeetEncoderConfig`.

    Returns:
        A `ParakeetTDTConfig` with blank_token_id == vocab_size (blank is the
        last row of the token head).
    """
    decoder_config = nemo_config.get("decoder", {})
    joint_config = nemo_config.get("joint", {})
    decoding_config = nemo_config.get("decoding", {})

    # Vocab size comes from the label list when present, else the decoder config.
    labels = nemo_config.get("labels", [])
    vocab_size = len(labels) if labels else decoder_config.get("vocab_size", 1024)

    # Prediction network (LSTM) hyperparameters.
    prednet = decoder_config.get("prednet", {})
    decoder_hidden_size = prednet.get("pred_hidden", 640)
    decoder_num_layers = prednet.get("pred_rnn_layers", 2)

    # Joint network hyperparameters.
    jointnet = joint_config.get("jointnet", {})
    joint_hidden_size = jointnet.get("joint_hidden", 640)

    # Duration bins come from the decoding section of the NeMo config.
    durations = decoding_config.get("durations", [0, 1, 2, 3, 4])
    num_duration_bins = len(durations)

    print(
        f"TDT config: vocab_size={vocab_size}, decoder_hidden={decoder_hidden_size}, "
        f"decoder_layers={decoder_num_layers}, joint_hidden={joint_hidden_size}, "
        f"num_durations={num_duration_bins}"
    )

    return ParakeetTDTConfig(
        vocab_size=vocab_size,
        decoder_hidden_size=decoder_hidden_size,
        decoder_num_layers=decoder_num_layers,
        joint_hidden_size=joint_hidden_size,
        num_duration_bins=num_duration_bins,
        encoder_config=encoder_config.to_dict(),
        blank_token_id=vocab_size,
    )


def load_and_convert_tdt_state_dict(model_files, vocab_size, num_duration_bins):
    """Load a NeMo TDT state dict and convert keys to HF format.

    NeMo stores a single combined output head of shape
    [vocab_size + 1 + num_duration_bins, joint_hidden]; it is split here into
    separate token and duration heads. `num_duration_bins` is used to sanity-check
    that split.
    """
    state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True)
    converted_state_dict = {}

    # TDT-specific renames take precedence over / extend the shared encoder renames.
    all_mappings = {**NEMO_TO_HF_WEIGHT_MAPPING, **NEMO_TDT_WEIGHT_MAPPING}

    for key, value in state_dict.items():
        # Preprocessing (mel featurizer) weights are recreated by the HF feature
        # extractor and must not be loaded into the model.
        if key.endswith("featurizer.window") or key.endswith("featurizer.fb"):
            print(f"Skipping preprocessing weight: {key}")
            continue

        # Combined output head: first vocab_size+1 rows are token logits
        # (incl. blank), remaining rows are duration logits.
        if key == "joint.joint_net.2.weight":
            token_weight = value[: vocab_size + 1, :]
            duration_weight = value[vocab_size + 1 :, :]
            if duration_weight.shape[0] != num_duration_bins:
                print(
                    f"Warning: duration head has {duration_weight.shape[0]} rows, "
                    f"expected num_duration_bins={num_duration_bins}"
                )
            converted_state_dict["joint.token_head.weight"] = token_weight
            converted_state_dict["joint.duration_head.weight"] = duration_weight
            print(f"Split combined weight: token_head {token_weight.shape}, duration_head {duration_weight.shape}")
            continue

        if key == "joint.joint_net.2.bias":
            # Same split for the bias vector.
            token_bias = value[: vocab_size + 1]
            duration_bias = value[vocab_size + 1 :]
            converted_state_dict["joint.token_head.bias"] = token_bias
            converted_state_dict["joint.duration_head.bias"] = duration_bias
            print(f"Split combined bias: token_head {token_bias.shape}, duration_head {duration_bias.shape}")
            continue

        # Standard regex-based key conversion.
        converted_key = convert_key(key, all_mappings)
        converted_state_dict[converted_key] = value

    return converted_state_dict


def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id=None):
    """Convert, load, save (and optionally push) a Parakeet TDT model."""
    # Step 1: Convert TDT config.
    model_config = convert_tdt_config(nemo_config, encoder_config)
    print(f"Converted TDT config: {model_config}")

    # Step 2: Load and convert the state dict with TDT-specific head splitting.
    converted_state_dict = load_and_convert_tdt_state_dict(
        model_files, model_config.vocab_size, model_config.num_duration_bins
    )

    print("Loading the checkpoint in a Parakeet TDT model.")
    with torch.device("meta"):
        model = ParakeetForTDT(model_config)

    # Load weights (assign=True materializes the meta tensors).
    missing_keys, unexpected_keys = model.load_state_dict(converted_state_dict, strict=False, assign=True)

    if missing_keys:
        print(f"Warning: Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"Warning: Unexpected keys: {unexpected_keys}")

    if not missing_keys and not unexpected_keys:
        print("All weights loaded successfully!")
    else:
        # Re-try with strict loading purely to surface a detailed error message.
        try:
            model.load_state_dict(converted_state_dict, strict=True, assign=True)
        except Exception as e:
            print(f"Strict loading failed: {e}")
            print("Continuing with partial weights...")

    del model.config._name_or_path

    print("Saving the model.")
    model.save_pretrained(output_dir)

    if push_to_repo_id:
        model.push_to_hub(push_to_repo_id)

    del model

    # Safety check: reload the converted model.
    gc.collect()
    print("Reloading the model to check if it's saved correctly.")
    ParakeetForTDT.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
    print("Model reloaded successfully.")


def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None):
    """Main model conversion function: dispatch on model_type."""
    # Step 1: Convert encoder config (shared across all model types).
    encoder_config = convert_encoder_config(nemo_config)
    print(f"Converted encoder config: {encoder_config}")

    # Step 2: Write model based on type.
    if model_type == "encoder":
        converted_state_dict = load_and_convert_state_dict(model_files)
        write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id)
    elif model_type == "ctc":
        converted_state_dict = load_and_convert_state_dict(model_files)
        write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id)
    elif model_type == "tdt":
        # TDT has its own state dict loading with combined head splitting.
        write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id)
    else:
        raise ValueError(f"Model type {model_type} not supported.")


def main(
    hf_repo_id,
    output_dir,
    model_type,
    push_to_repo_id=None,
    skip_processor=False,
):
    nemo_filename = f"{hf_repo_id.split('/')[-1]}.nemo"
    filepath = cached_file(hf_repo_id, nemo_filename)

    model_files = extract_nemo_archive(filepath, os.path.dirname(filepath))
    # Use a context manager so the config file handle is always closed.
    # FullLoader is kept (not safe_load) because NeMo configs may use YAML tags.
    with open(model_files["model_config"], "r") as config_file:
        nemo_config = yaml.load(config_file, Loader=yaml.FullLoader)

    if not skip_processor:
        write_processor(nemo_config, model_files, output_dir, push_to_repo_id)
    else:
        print("Skipping processor conversion (--skip_processor flag set)")

    write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co")
    parser.add_argument(
        "--model_type", required=True, choices=["encoder", "ctc", "tdt"], help="Model type (`encoder`, `ctc`, `tdt`)"
    )
    parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model")
    parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub")
    parser.add_argument("--skip_processor", action="store_true", help="Skip processor conversion")

    args = parser.parse_args()
    main(
        args.hf_repo_id,
        args.output_dir,
        args.model_type,
        args.push_to_repo_id,
        args.skip_processor,
    )
@dataclass
class ParakeetTDTOutput(ModelOutput):
    """
    Output type of [`ParakeetForTDT`].

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated token sequences.
        timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            The timestamps (in seconds) for each generated token.
        token_scores (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            The confidence scores for each generated token.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            The last hidden state of the encoder.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*):
            Encoder attention weights.
    """

    sequences: torch.LongTensor
    timestamps: torch.FloatTensor | None = None
    token_scores: torch.FloatTensor | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None


class ParakeetTDTDecoder(nn.Module):
    """
    LSTM-based prediction network for TDT (Token-and-Duration Transducer).

    The decoder maintains language model context across token predictions.
    It takes the previous token embedding and outputs a hidden representation
    that is combined with the encoder output in the joint network.

    Key insight: The decoder only needs to be updated when emitting non-blank tokens,
    as blank tokens (silence) don't contribute to language context.
    """

    def __init__(self, config: ParakeetTDTConfig):
        super().__init__()
        self.config = config

        # Token embedding layer; +1 row for the blank token at index vocab_size.
        self.embedding = nn.Embedding(config.vocab_size + 1, config.decoder_hidden_size)

        # LSTM prediction network.
        self.lstm = nn.LSTM(
            input_size=config.decoder_hidden_size,
            hidden_size=config.decoder_hidden_size,
            num_layers=config.decoder_num_layers,
            batch_first=True,
        )

        # Projection into the joint-network dimension.
        self.output_proj = nn.Linear(config.decoder_hidden_size, config.joint_hidden_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        hidden_state: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            input_ids: Previous token IDs of shape (batch_size, 1)
            hidden_state: Tuple of (h, c) LSTM states, each of shape
                (num_layers, batch_size, hidden_size); None means zero state.

        Returns:
            output: Decoder output of shape (batch_size, 1, joint_hidden_size)
            hidden_state: Updated (h, c) tuple
        """
        # (batch, 1) -> (batch, 1, decoder_hidden)
        embeddings = self.embedding(input_ids)

        # (batch, 1, decoder_hidden); LSTM treats None hidden_state as zeros.
        lstm_out, hidden_state = self.lstm(embeddings, hidden_state)

        # (batch, 1, joint_hidden)
        output = self.output_proj(lstm_out)

        return output, hidden_state

    def init_state(
        self, batch_size: int, device: torch.device, dtype: torch.dtype | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Initialize LSTM hidden and cell states to zeros."""
        h = torch.zeros(
            self.config.decoder_num_layers,
            batch_size,
            self.config.decoder_hidden_size,
            device=device,
            dtype=dtype,
        )
        c = torch.zeros(
            self.config.decoder_num_layers,
            batch_size,
            self.config.decoder_hidden_size,
            device=device,
            dtype=dtype,
        )
        return (h, c)


class ParakeetTDTJointNetwork(nn.Module):
    """
    Joint network that combines encoder and decoder outputs to predict tokens and durations.

    The joint network takes:
    - Encoder frame output at time t
    - Decoder output based on previous token(s)

    And produces:
    - Token logits (vocab_size + 1 for blank)
    - Duration logits (num_duration_bins)
    """

    def __init__(self, config: ParakeetTDTConfig):
        super().__init__()
        self.config = config

        # Project encoder hidden size into the joint hidden size.
        self.encoder_proj = nn.Linear(config.encoder_config.hidden_size, config.joint_hidden_size)

        # Joint activation (NeMo uses relu in the joint network).
        self.activation = ACT2FN["relu"]

        # Output heads; token head has +1 column for blank.
        self.token_head = nn.Linear(config.joint_hidden_size, config.vocab_size + 1)
        self.duration_head = nn.Linear(config.joint_hidden_size, config.num_duration_bins)

    def forward(
        self,
        encoder_output: torch.Tensor,
        decoder_output: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            encoder_output: Encoder frame of shape (batch_size, 1, encoder_hidden_size)
            decoder_output: Decoder output of shape (batch_size, 1, joint_hidden_size)

        Returns:
            token_logits: Shape (batch_size, vocab_size + 1)
            duration_logits: Shape (batch_size, num_duration_bins)

        NOTE: the squeeze(1) below assumes a time dimension of exactly 1
        (single-frame decoding step).
        """
        # Project encoder output into the joint space: (batch, 1, joint_hidden).
        encoder_proj = self.encoder_proj(encoder_output)

        # Additive joint combination followed by the non-linearity.
        joint_out = encoder_proj + decoder_output  # (batch, 1, joint_hidden)
        joint_out = self.activation(joint_out)

        # Predict token and duration distributions for this step.
        token_logits = self.token_head(joint_out).squeeze(1)  # (batch, vocab+1)
        duration_logits = self.duration_head(joint_out).squeeze(1)  # (batch, num_bins)

        return token_logits, duration_logits


class ParakeetTDTPreTrainedModel(PreTrainedModel):
    """Base class for TDT models with TDT-specific config."""

    config: ParakeetTDTConfig
    config_class = ParakeetTDTConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    input_modalities = "audio"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParakeetEncoderBlock"]
    _supports_flat_attention_mask = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_flash_attn = False

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize Linear/Embedding/LSTM weights with the config's std; zeros for biases."""
        std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.LSTM):
            # LSTM parameters are flat; classify by name.
            for name, param in module.named_parameters():
                if "weight" in name:
                    init.normal_(param, mean=0.0, std=std)
                elif "bias" in name:
                    init.zeros_(param)
        elif isinstance(module, ParakeetForTDT):
            # (Re)initialize the non-persistent duration_bins buffer: bin i -> i frames.
            duration_bins = torch.arange(self.config.num_duration_bins, dtype=torch.long)
            init.copy_(module.duration_bins, duration_bins)

    def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
        """Compute encoder output lengths after the conv subsampling stack."""
        encoder_config = self.config.encoder_config

        kernel_size = encoder_config.subsampling_conv_kernel_size
        stride = encoder_config.subsampling_conv_stride
        # Each conv layer halves (by `stride`) the time dimension; the stack depth
        # is log2 of the total subsampling factor.
        num_layers = int(math.log2(encoder_config.subsampling_factor))

        all_paddings = (kernel_size - 1) // 2 * 2
        add_pad = all_paddings - kernel_size
        lengths = input_lengths

        for _ in range(num_layers):
            lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
            lengths = torch.floor(lengths)

        return lengths.to(dtype=torch.int)

    def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: int | None = None):
        """Build a boolean mask over encoder output frames from the input-feature mask."""
        output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
        max_length = target_length if target_length is not None else output_lengths.max()
        attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
        return attention_mask
+ + Example: + + ```python + >>> from transformers import AutoProcessor, ParakeetForTDT + >>> from datasets import load_dataset, Audio + + >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = ParakeetForTDT.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + >>> inputs = processor(ds[0]["audio"]["array"]) + >>> encoder_outputs = model(**inputs) + + >>> print(encoder_outputs.last_hidden_state.shape) + ``` + """ + encoder_outputs = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + output_attention_mask=True, + **kwargs, + ) + + return encoder_outputs + + @torch.no_grad() + def generate( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + return_timestamps: bool = False, + return_dict_in_generate: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> ParakeetTDTOutput | torch.LongTensor: + r""" + Perform TDT greedy decoding to generate token sequences. + + Args: + input_features (`torch.Tensor`): + Mel spectrogram features of shape `(batch_size, num_mel_bins, sequence_length)`. + attention_mask (`torch.Tensor`, *optional*): + Attention mask for the input features. + return_timestamps (`bool`, *optional*, defaults to `False`): + Whether to return timestamps for each token. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether to return a `ParakeetTDTOutput` instead of just sequences. + + Returns: + `ParakeetTDTOutput` or `torch.LongTensor`: Generated sequences and optional metadata. 
+ + Example: + + ```python + >>> from transformers import AutoProcessor, ParakeetForTDT + >>> from datasets import load_dataset, Audio + + >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = ParakeetForTDT.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + >>> inputs = processor(ds[0]["audio"]["array"]) + >>> output = model.generate(**inputs, return_timestamps=True, return_dict_in_generate=True) + + >>> transcription = processor.batch_decode(output.sequences, skip_special_tokens=True) + >>> print(transcription) + ``` + """ + blank_id = self.config.blank_token_id + device = input_features.device + batch_size = input_features.shape[0] + + # Get encoder outputs + encoder_outputs = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + output_attention_mask=True, + **kwargs, + ) + encoder_hidden = encoder_outputs.last_hidden_state # (batch, time, hidden) + encoder_mask = encoder_outputs.attention_mask # (batch, time) or None + + num_frames = encoder_hidden.shape[1] + if encoder_mask is not None: + # Get valid lengths per batch item + valid_lengths = encoder_mask.sum(dim=1).int() # (batch,) + else: + valid_lengths = torch.full((batch_size,), num_frames, dtype=torch.int, device=device) + + # Initialize decoder state with same dtype as encoder output + decoder_state = self.decoder.init_state(batch_size, device, dtype=encoder_hidden.dtype) + + # Initialize with blank token + prev_tokens = torch.full((batch_size, 1), blank_id, dtype=torch.long, device=device) + + # Prime the decoder with initial blank + decoder_out, decoder_state = self.decoder(prev_tokens, decoder_state) + + # Output storage (will be padded/truncated at the end) + all_tokens = [[] for _ in range(batch_size)] + all_timestamps = [[] for _ in 
range(batch_size)] + all_scores = [[] for _ in range(batch_size)] + + # Frame indices for each batch item + time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) + + # Active mask - which batch items are still being processed + active_mask = time_indices < valid_lengths + + while active_mask.any(): + # Get current encoder frames for active items + # We process all items but only use results for active ones + safe_time_indices = torch.clamp(time_indices, max=num_frames - 1) + encoder_frames = encoder_hidden[torch.arange(batch_size, device=device), safe_time_indices].unsqueeze( + 1 + ) # (batch, 1, hidden) + + # Run joint network + token_logits, duration_logits = self.joint(encoder_frames, decoder_out) + + # Greedy selection + tokens = token_logits.argmax(dim=-1) # (batch,) + token_probs = torch.softmax(token_logits, dim=-1) + token_scores = token_probs.gather(1, tokens.unsqueeze(1)).squeeze(1) # (batch,) + + durations = duration_logits.argmax(dim=-1) # (batch,) + # Map to actual frame counts using duration bins + duration_frames = self.duration_bins[durations] # (batch,) + + # Ensure minimum duration of 1 to prevent infinite loops + # In trained models, duration 0 means "emit and stay" for multi-token words + # But with random weights, this causes infinite loops + is_blank = tokens == blank_id + duration_frames = torch.clamp(duration_frames, min=1) + + # Process non-blank tokens for active items + emit_mask = active_mask & ~is_blank + + for i in range(batch_size): + if emit_mask[i]: + all_tokens[i].append(tokens[i].item()) + all_timestamps[i].append(time_indices[i].item() * self._frame_duration_seconds) + all_scores[i].append(token_scores[i].item()) + + # Update decoder for items that emitted tokens + if emit_mask.any(): + # Only update decoder state for items that emitted non-blank + new_prev_tokens = tokens.unsqueeze(1) # (batch, 1) + new_decoder_out, new_decoder_state = self.decoder(new_prev_tokens, decoder_state) + + # Selectively update 
decoder output and state + # For items that emitted, use new state; otherwise keep old + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_out = torch.where(emit_mask_expanded, new_decoder_out, decoder_out) + + # Update LSTM states selectively + h_old, c_old = decoder_state + h_new, c_new = new_decoder_state + emit_mask_state = emit_mask.view(1, batch_size, 1) + decoder_state = ( + torch.where(emit_mask_state, h_new, h_old), + torch.where(emit_mask_state, c_new, c_old), + ) + + # Advance time indices by duration + time_indices = time_indices + duration_frames + + # Update active mask + active_mask = time_indices < valid_lengths + + # Pad sequences to same length + max_len = max(len(seq) for seq in all_tokens) if all_tokens[0] else 0 + if max_len == 0: + max_len = 1 # At least one token + + sequences = torch.full((batch_size, max_len), self.config.pad_token_id, dtype=torch.long, device=device) + timestamps = torch.zeros(batch_size, max_len, dtype=torch.float, device=device) + scores = torch.zeros(batch_size, max_len, dtype=torch.float, device=device) + + for i in range(batch_size): + seq_len = len(all_tokens[i]) + if seq_len > 0: + sequences[i, :seq_len] = torch.tensor(all_tokens[i], dtype=torch.long, device=device) + timestamps[i, :seq_len] = torch.tensor(all_timestamps[i], dtype=torch.float, device=device) + scores[i, :seq_len] = torch.tensor(all_scores[i], dtype=torch.float, device=device) + + if return_dict_in_generate: + return ParakeetTDTOutput( + sequences=sequences, + timestamps=timestamps if return_timestamps else None, + token_scores=scores if return_timestamps else None, + encoder_last_hidden_state=encoder_hidden, + encoder_attentions=encoder_outputs.attentions, + ) + + return sequences + + +__all__ = [ + "ParakeetForCTC", + "ParakeetForTDT", + "ParakeetEncoder", + "ParakeetPreTrainedModel", + "ParakeetTDTPreTrainedModel", +] diff --git a/src/transformers/models/parakeet/modular_parakeet.py 
b/src/transformers/models/parakeet/modular_parakeet.py index cf5ad1be8dc8..c8664ea91482 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -30,7 +30,7 @@ from ...utils.generic import check_model_inputs, maybe_autocast from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule from ..llama.modeling_llama import LlamaAttention, eager_attention_forward -from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig +from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig @dataclass @@ -450,6 +450,7 @@ def forward( position_embeddings, p=self.dropout_positions, training=self.training ) + output_mask = None if attention_mask is not None: output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1]) attention_mask = output_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1) @@ -473,7 +474,8 @@ def forward( ) return ParakeetEncoderModelOutput( - last_hidden_state=hidden_states, attention_mask=output_mask.int() if output_attention_mask else None + last_hidden_state=hidden_states, + attention_mask=output_mask.int() if output_attention_mask and output_mask is not None else None, ) @@ -648,4 +650,471 @@ def generate( return sequences -__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"] +@dataclass +class ParakeetTDTOutput(ModelOutput): + """ + Output type of [`ParakeetForTDT`]. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated token sequences. + timestamps (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + The timestamps (in seconds) for each generated token. + token_scores (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + The confidence scores for each generated token. 
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + The last hidden state of the encoder. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*): + Encoder attention weights. + """ + + sequences: torch.LongTensor + timestamps: torch.FloatTensor | None = None + token_scores: torch.FloatTensor | None = None + encoder_last_hidden_state: torch.FloatTensor | None = None + encoder_attentions: tuple[torch.FloatTensor] | None = None + + +class ParakeetTDTDecoder(nn.Module): + """ + LSTM-based prediction network for TDT (Token Duration Transducer). + + The decoder maintains language model context across token predictions. + It takes the previous token embedding and outputs a hidden representation + that is combined with the encoder output in the joint network. + + Key insight: The decoder only needs to be updated when emitting non-blank tokens, + as blank tokens (silence) don't contribute to language context. + """ + + def __init__(self, config: ParakeetTDTConfig): + super().__init__() + self.config = config + + # Token embedding layer + self.embedding = nn.Embedding(config.vocab_size + 1, config.decoder_hidden_size) # +1 for blank + + # LSTM prediction network + self.lstm = nn.LSTM( + input_size=config.decoder_hidden_size, + hidden_size=config.decoder_hidden_size, + num_layers=config.decoder_num_layers, + batch_first=True, + ) + + # Output projection + self.output_proj = nn.Linear(config.decoder_hidden_size, config.joint_hidden_size) + + def forward( + self, + input_ids: torch.LongTensor, + hidden_state: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + input_ids: Previous token IDs of shape (batch_size, 1) + hidden_state: Tuple of (h, c) LSTM states, each of shape + (num_layers, batch_size, hidden_size) + + Returns: + output: Decoder output of shape (batch_size, 1, joint_hidden_size) + hidden_state: Updated (h, c) 
tuple + """ + # Get token embeddings + embeddings = self.embedding(input_ids) # (batch, 1, decoder_hidden) + + # Run through LSTM + lstm_out, hidden_state = self.lstm(embeddings, hidden_state) # (batch, 1, decoder_hidden) + + # Project to joint network dimension + output = self.output_proj(lstm_out) # (batch, 1, joint_hidden) + + return output, hidden_state + + def init_state( + self, batch_size: int, device: torch.device, dtype: torch.dtype = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Initialize LSTM hidden and cell states to zeros.""" + h = torch.zeros( + self.config.decoder_num_layers, + batch_size, + self.config.decoder_hidden_size, + device=device, + dtype=dtype, + ) + c = torch.zeros( + self.config.decoder_num_layers, + batch_size, + self.config.decoder_hidden_size, + device=device, + dtype=dtype, + ) + return (h, c) + + +class ParakeetTDTJointNetwork(nn.Module): + """ + Joint network that combines encoder and decoder outputs to predict tokens and durations. + + The joint network takes: + - Encoder frame output at time t + - Decoder output based on previous token(s) + + And produces: + - Token logits (vocab_size + 1 for blank) + - Duration logits (num_duration_bins) + """ + + def __init__(self, config: ParakeetTDTConfig): + super().__init__() + self.config = config + + # Encoder projection to joint hidden size + self.encoder_proj = nn.Linear(config.encoder_config.hidden_size, config.joint_hidden_size) + + # Joint activation (NeMo uses relu in joint network) + self.activation = ACT2FN["relu"] + + # Output heads + self.token_head = nn.Linear(config.joint_hidden_size, config.vocab_size + 1) # +1 for blank + self.duration_head = nn.Linear(config.joint_hidden_size, config.num_duration_bins) + + def forward( + self, + encoder_output: torch.Tensor, + decoder_output: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + encoder_output: Encoder frame of shape (batch_size, 1, encoder_hidden_size) + decoder_output: Decoder output of shape 
(batch_size, 1, joint_hidden_size) + + Returns: + token_logits: Shape (batch_size, vocab_size + 1) + duration_logits: Shape (batch_size, num_duration_bins) + """ + # Project encoder output + encoder_proj = self.encoder_proj(encoder_output) # (batch, 1, joint_hidden) + + # Combine encoder and decoder outputs (additive joint) + joint_out = encoder_proj + decoder_output # (batch, 1, joint_hidden) + joint_out = self.activation(joint_out) + + # Predict token and duration + token_logits = self.token_head(joint_out).squeeze(1) # (batch, vocab+1) + duration_logits = self.duration_head(joint_out).squeeze(1) # (batch, num_bins) + + return token_logits, duration_logits + + +class ParakeetTDTPreTrainedModel(PreTrainedModel): + """Base class for TDT models with TDT-specific config.""" + + config: ParakeetTDTConfig + config_class = ParakeetTDTConfig + base_model_prefix = "model" + main_input_name = "input_features" + input_modalities = "audio" + supports_gradient_checkpointing = True + _no_split_modules = ["ParakeetEncoderBlock"] + _supports_flat_attention_mask = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_flash_attn = False + + @torch.no_grad() + def _init_weights(self, module): + std = self.config.initializer_range + + if isinstance(module, nn.Linear): + init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + init.normal_(module.weight, mean=0.0, std=std) + elif isinstance(module, nn.LSTM): + for name, param in module.named_parameters(): + if "weight" in name: + init.normal_(param, mean=0.0, std=std) + elif "bias" in name: + init.zeros_(param) + elif isinstance(module, ParakeetForTDT): + # Initialize duration_bins buffer + duration_bins = torch.arange(self.config.num_duration_bins, dtype=torch.long) + init.copy_(module.duration_bins, duration_bins) + + def _get_subsampling_output_length(self, input_lengths: torch.Tensor): + encoder_config = 
self.config.encoder_config + + kernel_size = encoder_config.subsampling_conv_kernel_size + stride = encoder_config.subsampling_conv_stride + num_layers = int(math.log2(encoder_config.subsampling_factor)) + + all_paddings = (kernel_size - 1) // 2 * 2 + add_pad = all_paddings - kernel_size + lengths = input_lengths + + for _ in range(num_layers): + lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0 + lengths = torch.floor(lengths) + + return lengths.to(dtype=torch.int) + + def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: int | None = None): + output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1)) + max_length = target_length if target_length is not None else output_lengths.max() + attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None] + return attention_mask + + +@auto_docstring( + custom_intro=""" + Parakeet model with TDT (Token Duration Transducer) head for speech recognition. + + TDT jointly predicts tokens and their durations, enabling efficient decoding with + accurate word-level timestamps. Unlike CTC, TDT can skip multiple frames at once + based on the predicted duration. 
+ """ +) +class ParakeetForTDT(ParakeetTDTPreTrainedModel): + config: ParakeetTDTConfig + + # Frame duration in seconds (after 8x subsampling: 80ms per frame) + _frame_duration_seconds = 0.08 + + def __init__(self, config: ParakeetTDTConfig): + super().__init__(config) + self.encoder = ParakeetEncoder(config.encoder_config) + self.decoder = ParakeetTDTDecoder(config) + self.joint = ParakeetTDTJointNetwork(config) + + # Duration bins mapping: index -> number of frames to skip + self.register_buffer( + "duration_bins", + torch.arange(config.num_duration_bins, dtype=torch.long), + persistent=False, + ) + + self.post_init() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> ParakeetEncoderModelOutput: + r""" + Returns the encoder outputs. For TDT decoding, use the `generate` method. + + Example: + + ```python + >>> from transformers import AutoProcessor, ParakeetForTDT + >>> from datasets import load_dataset, Audio + + >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = ParakeetForTDT.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + >>> inputs = processor(ds[0]["audio"]["array"]) + >>> encoder_outputs = model(**inputs) + + >>> print(encoder_outputs.last_hidden_state.shape) + ``` + """ + encoder_outputs = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + output_attention_mask=True, + **kwargs, + ) + + return encoder_outputs + + @torch.no_grad() + def generate( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + return_timestamps: bool = False, + return_dict_in_generate: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) 
-> ParakeetTDTOutput | torch.LongTensor: + r""" + Perform TDT greedy decoding to generate token sequences. + + Args: + input_features (`torch.Tensor`): + Mel spectrogram features of shape `(batch_size, num_mel_bins, sequence_length)`. + attention_mask (`torch.Tensor`, *optional*): + Attention mask for the input features. + return_timestamps (`bool`, *optional*, defaults to `False`): + Whether to return timestamps for each token. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether to return a `ParakeetTDTOutput` instead of just sequences. + + Returns: + `ParakeetTDTOutput` or `torch.LongTensor`: Generated sequences and optional metadata. + + Example: + + ```python + >>> from transformers import AutoProcessor, ParakeetForTDT + >>> from datasets import load_dataset, Audio + + >>> model_id = "nvidia/parakeet-tdt-0.6b-v3" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = ParakeetForTDT.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + >>> inputs = processor(ds[0]["audio"]["array"]) + >>> output = model.generate(**inputs, return_timestamps=True, return_dict_in_generate=True) + + >>> transcription = processor.batch_decode(output.sequences, skip_special_tokens=True) + >>> print(transcription) + ``` + """ + blank_id = self.config.blank_token_id + device = input_features.device + batch_size = input_features.shape[0] + + # Get encoder outputs + encoder_outputs = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + output_attention_mask=True, + **kwargs, + ) + encoder_hidden = encoder_outputs.last_hidden_state # (batch, time, hidden) + encoder_mask = encoder_outputs.attention_mask # (batch, time) or None + + num_frames = encoder_hidden.shape[1] + if encoder_mask is not None: + # Get valid lengths per batch item + 
valid_lengths = encoder_mask.sum(dim=1).int() # (batch,) + else: + valid_lengths = torch.full((batch_size,), num_frames, dtype=torch.int, device=device) + + # Initialize decoder state with same dtype as encoder output + decoder_state = self.decoder.init_state(batch_size, device, dtype=encoder_hidden.dtype) + + # Initialize with blank token + prev_tokens = torch.full((batch_size, 1), blank_id, dtype=torch.long, device=device) + + # Prime the decoder with initial blank + decoder_out, decoder_state = self.decoder(prev_tokens, decoder_state) + + # Output storage (will be padded/truncated at the end) + all_tokens = [[] for _ in range(batch_size)] + all_timestamps = [[] for _ in range(batch_size)] + all_scores = [[] for _ in range(batch_size)] + + # Frame indices for each batch item + time_indices = torch.zeros(batch_size, dtype=torch.long, device=device) + + # Active mask - which batch items are still being processed + active_mask = time_indices < valid_lengths + + while active_mask.any(): + # Get current encoder frames for active items + # We process all items but only use results for active ones + safe_time_indices = torch.clamp(time_indices, max=num_frames - 1) + encoder_frames = encoder_hidden[torch.arange(batch_size, device=device), safe_time_indices].unsqueeze( + 1 + ) # (batch, 1, hidden) + + # Run joint network + token_logits, duration_logits = self.joint(encoder_frames, decoder_out) + + # Greedy selection + tokens = token_logits.argmax(dim=-1) # (batch,) + token_probs = torch.softmax(token_logits, dim=-1) + token_scores = token_probs.gather(1, tokens.unsqueeze(1)).squeeze(1) # (batch,) + + durations = duration_logits.argmax(dim=-1) # (batch,) + # Map to actual frame counts using duration bins + duration_frames = self.duration_bins[durations] # (batch,) + + # Ensure minimum duration of 1 to prevent infinite loops + # In trained models, duration 0 means "emit and stay" for multi-token words + # But with random weights, this causes infinite loops + is_blank = 
tokens == blank_id + duration_frames = torch.clamp(duration_frames, min=1) + + # Process non-blank tokens for active items + emit_mask = active_mask & ~is_blank + + for i in range(batch_size): + if emit_mask[i]: + all_tokens[i].append(tokens[i].item()) + all_timestamps[i].append(time_indices[i].item() * self._frame_duration_seconds) + all_scores[i].append(token_scores[i].item()) + + # Update decoder for items that emitted tokens + if emit_mask.any(): + # Only update decoder state for items that emitted non-blank + new_prev_tokens = tokens.unsqueeze(1) # (batch, 1) + new_decoder_out, new_decoder_state = self.decoder(new_prev_tokens, decoder_state) + + # Selectively update decoder output and state + # For items that emitted, use new state; otherwise keep old + emit_mask_expanded = emit_mask.view(batch_size, 1, 1) + decoder_out = torch.where(emit_mask_expanded, new_decoder_out, decoder_out) + + # Update LSTM states selectively + h_old, c_old = decoder_state + h_new, c_new = new_decoder_state + emit_mask_state = emit_mask.view(1, batch_size, 1) + decoder_state = ( + torch.where(emit_mask_state, h_new, h_old), + torch.where(emit_mask_state, c_new, c_old), + ) + + # Advance time indices by duration + time_indices = time_indices + duration_frames + + # Update active mask + active_mask = time_indices < valid_lengths + + # Pad sequences to same length + max_len = max(len(seq) for seq in all_tokens) if all_tokens[0] else 0 + if max_len == 0: + max_len = 1 # At least one token + + sequences = torch.full((batch_size, max_len), self.config.pad_token_id, dtype=torch.long, device=device) + timestamps = torch.zeros(batch_size, max_len, dtype=torch.float, device=device) + scores = torch.zeros(batch_size, max_len, dtype=torch.float, device=device) + + for i in range(batch_size): + seq_len = len(all_tokens[i]) + if seq_len > 0: + sequences[i, :seq_len] = torch.tensor(all_tokens[i], dtype=torch.long, device=device) + timestamps[i, :seq_len] = torch.tensor(all_timestamps[i], 
dtype=torch.float, device=device) + scores[i, :seq_len] = torch.tensor(all_scores[i], dtype=torch.float, device=device) + + if return_dict_in_generate: + return ParakeetTDTOutput( + sequences=sequences, + timestamps=timestamps if return_timestamps else None, + token_scores=scores if return_timestamps else None, + encoder_last_hidden_state=encoder_hidden, + encoder_attentions=encoder_outputs.attentions, + ) + + return sequences + + +__all__ = [ + "ParakeetForCTC", + "ParakeetForTDT", + "ParakeetEncoder", + "ParakeetPreTrainedModel", + "ParakeetTDTPreTrainedModel", +] diff --git a/tests/fixtures/parakeet/expected_results_tdt_batch.json b/tests/fixtures/parakeet/expected_results_tdt_batch.json new file mode 100644 index 000000000000..c3f46c17321d --- /dev/null +++ b/tests/fixtures/parakeet/expected_results_tdt_batch.json @@ -0,0 +1 @@ +{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "Nor is mister Quilter's manner less interesting than his matter.", "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.", "He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and can discover in it but little of Rocky Ithaca.", "Linnell's pictures are a sort of up guards an atom paintings, and Mason's exquisite idols are as national as a jingo poem. mister Burkett Foster's landscapes smile at one much in the same way that mister Carker used to flash his teeth. 
And mister John Collier gives his sitter a cheerful slap on the back, before he says, like a shampooer in a Turkish bath Next man"], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [5685, 508, 282, 3459, 1382, 305, 441, 7931, 7870, 698, 1742, 293, 561, 1091, 365, 381, 7098, 2745, 1544, 441, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 280, 530, 7870, 1441, 1050, 407, 1974, 309, 940, 507, 347, 297, 289, 592, 506, 4070, 287, 7877, 1868, 4959, 398, 2037, 575, 603, 534, 555, 7124, 818, 313, 381, 555, 786, 7864, 1441, 7877, 1622, 305, 283, 2324, 
1471, 3109, 325, 296, 381, 575, 5404, 1021, 355, 769, 2090, 7880, 344, 3110, 427, 319, 4838, 366, 506, 1737, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [1876, 2281, 1969, 507, 3362, 7886, 769, 328, 1299, 1239, 7319, 6447, 901, 1413, 1333, 3720, 289, 7931, 7870, 6182, 508, 5600, 4190, 377, 799, 441, 1111, 7877, 575, 2059, 5371, 3230, 334, 869, 2681, 7052, 592, 3341, 725, 7893, 2336, 7882, 566, 7865, 7883, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], [439, 1538, 530, 7931, 7870, 5970, 7868, 4147, 1714, 279, 275, 621, 592, 1840, 1980, 961, 7870, 411, 407, 313, 849, 942, 2399, 7877, 575, 2945, 289, 7931, 7870, 743, 341, 290, 582, 312, 7874, 324, 7870, 1714, 618, 285, 5858, 618, 279, 300, 381, 7869, 408, 311, 7883, 282, 3459, 426, 344, 7876, 861, 515, 308, 441, 7931, 7870, 3650, 7870, 7880, 474, 283, 1530, 787, 407, 2678, 4457, 334, 506, 766, 7864, 7195, 1050, 282, 3459, 3551, 1684, 1441, 326, 366, 309, 1028, 7882, 2745, 478, 291, 7882, 7883, 1976, 282, 3459, 3483, 4003, 332, 277, 317, 416, 283, 2745, 3488, 441, 279, 774, 277, 5346, 275, 4226, 431, 
506, 6507, 7877, 555, 786, 7864, 813, 498, 676, 7877, 2656, 279, 275, 3930, 726, 7869, 277, 334, 279, 5183, 7876, 2739, 302, 7152, 1030, 3127, 698]]} \ No newline at end of file diff --git a/tests/fixtures/parakeet/expected_results_tdt_single.json b/tests/fixtures/parakeet/expected_results_tdt_single.json new file mode 100644 index 000000000000..d1d1a4857d7f --- /dev/null +++ b/tests/fixtures/parakeet/expected_results_tdt_single.json @@ -0,0 +1 @@ +{"transcriptions": ["mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."], "token_ids": [[282, 3459, 1382, 305, 441, 508, 506, 767, 487, 337, 592, 506, 3414, 7874, 337, 6046, 7870, 283, 7877, 575, 750, 1714, 1627, 319, 366, 4446, 7880, 1901, 2745, 3576, 5871, 7883]]} \ No newline at end of file diff --git a/tests/models/parakeet/generate_tdt_fixtures.py b/tests/models/parakeet/generate_tdt_fixtures.py new file mode 100644 index 000000000000..e09d6c56e03d --- /dev/null +++ b/tests/models/parakeet/generate_tdt_fixtures.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Script to generate expected test fixtures for ParakeetForTDT integration tests. + +This script runs the TDT model on LibriSpeech samples and saves the outputs +to JSON fixtures that are used by the integration tests. 
+ +Usage: + python tests/models/parakeet/generate_tdt_fixtures.py + +Requirements: + - torch + - transformers + - datasets +""" + +import json +from pathlib import Path + +import torch +from datasets import Audio, load_dataset + +from transformers import AutoProcessor, ParakeetForTDT + + +def main(): + # TODO: Change to "nvidia/parakeet-tdt-0.6b-v3" once NVIDIA adds HF format to their repo + checkpoint_name = "MaksL/parakeet-tdt-0.6b-v3" + dtype = torch.bfloat16 + device = "cuda" if torch.cuda.is_available() else "cpu" + + print(f"Loading model {checkpoint_name}...") + processor = AutoProcessor.from_pretrained(checkpoint_name) + model = ParakeetForTDT.from_pretrained(checkpoint_name, torch_dtype=dtype, device_map=device) + model.eval() + + print("Loading dataset...") + dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + dataset = dataset.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)) + + # Sort by ID to ensure reproducibility + speech_samples = dataset.sort("id") + + fixtures_dir = Path(__file__).parent.parent.parent / "fixtures" / "parakeet" + fixtures_dir.mkdir(parents=True, exist_ok=True) + + # Generate single sample fixture + print("Generating single sample fixture...") + single_sample = [speech_samples[0]["audio"]["array"]] + inputs = processor(single_sample) + inputs.to(device, dtype=dtype) + + with torch.no_grad(): + output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) + + single_fixture = { + "transcriptions": processor.batch_decode(output.sequences, skip_special_tokens=True), + "token_ids": output.sequences.cpu().tolist(), + } + + single_path = fixtures_dir / "expected_results_tdt_single.json" + with open(single_path, "w") as f: + json.dump(single_fixture, f) + print(f"Saved: {single_path}") + + # Generate batch fixture (5 samples) + print("Generating batch fixture...") + batch_samples = [speech_samples[i]["audio"]["array"] for i in 
range(5)] + inputs = processor(batch_samples) + inputs.to(device, dtype=dtype) + + with torch.no_grad(): + output = model.generate(**inputs, return_dict_in_generate=True, return_timestamps=True) + + batch_fixture = { + "transcriptions": processor.batch_decode(output.sequences, skip_special_tokens=True), + "token_ids": output.sequences.cpu().tolist(), + } + + batch_path = fixtures_dir / "expected_results_tdt_batch.json" + with open(batch_path, "w") as f: + json.dump(batch_fixture, f) + print(f"Saved: {batch_path}") + + print("\nFixtures generated successfully!") + print(f"\nSingle sample transcription:\n {single_fixture['transcriptions'][0]}") + print("\nBatch transcriptions:") + for i, t in enumerate(batch_fixture["transcriptions"]): + print(f" [{i}] {t}") + + +if __name__ == "__main__": + main() diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 7bd35946574f..9154d0ea78a2 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -37,6 +37,8 @@ ParakeetEncoder, ParakeetEncoderConfig, ParakeetForCTC, + ParakeetForTDT, + ParakeetTDTConfig, ) @@ -375,3 +377,395 @@ def test_1b_model_integration_batched(self): torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) + + +class ParakeetForTDTModelTester: + def __init__( + self, + parent, + encoder_kwargs=None, + is_training=True, + vocab_size=128, + decoder_hidden_size=64, + decoder_num_layers=1, + joint_hidden_size=64, + num_duration_bins=5, + blank_token_id=128, # Must equal vocab_size for embedding table + pad_token_id=128, + ): + if encoder_kwargs is None: + encoder_kwargs = {} + + self.parent = parent + self.encoder_model_tester = ParakeetEncoderModelTester(parent, **encoder_kwargs) + self.is_training = is_training + + 
self.batch_size = self.encoder_model_tester.batch_size + self.output_seq_length = self.encoder_model_tester.output_seq_length + self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers + self.hidden_size = self.encoder_model_tester.hidden_size + + self.vocab_size = vocab_size + self.seq_length = vocab_size # Required by test_hidden_states_output + self.decoder_hidden_size = decoder_hidden_size + self.decoder_num_layers = decoder_num_layers + self.joint_hidden_size = joint_hidden_size + self.num_duration_bins = num_duration_bins + self.blank_token_id = blank_token_id + self.pad_token_id = pad_token_id + + def prepare_config_and_inputs(self): + _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs() + config = self.get_config() + return config, input_features, attention_mask + + def get_config(self): + return ParakeetTDTConfig( + encoder_config=self.encoder_model_tester.get_config().to_dict(), + vocab_size=self.vocab_size, + decoder_hidden_size=self.decoder_hidden_size, + decoder_num_layers=self.decoder_num_layers, + joint_hidden_size=self.joint_hidden_size, + num_duration_bins=self.num_duration_bins, + blank_token_id=self.blank_token_id, + pad_token_id=self.pad_token_id, + ) + + def create_and_check_model(self, config, input_features, attention_mask): + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_features, attention_mask=attention_mask) + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size) + ) + + def create_and_check_generate(self, config, input_features, attention_mask): + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model.generate( + input_features, attention_mask=attention_mask, return_dict_in_generate=True, return_timestamps=True + ) + # Check sequences shape - batch size should match + 
self.parent.assertEqual(result.sequences.shape[0], self.batch_size) + # Check timestamps are returned + self.parent.assertIsNotNone(result.timestamps) + self.parent.assertEqual(result.timestamps.shape[0], self.batch_size) + + def prepare_config_and_inputs_for_common(self): + config, input_features, attention_mask = self.prepare_config_and_inputs() + inputs_dict = { + "input_features": input_features, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class ParakeetForTDTModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (ParakeetForTDT,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": ParakeetEncoder, + "automatic-speech-recognition": ParakeetForTDT, + } + if is_torch_available() + else {} + ) + + test_attention_outputs = False + test_resize_embeddings = False + test_torch_exportable = True + + _is_composite = True + + def setUp(self): + self.model_tester = ParakeetForTDTModelTester(self) + self.config_tester = ConfigTester(self, config_class=ParakeetTDTConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_generate(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate(*config_and_inputs) + + def test_generate_returns_valid_output(self): + """Test that generate() returns valid sequences and timestamps.""" + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model.generate( + input_features, + attention_mask=attention_mask, + return_dict_in_generate=True, + return_timestamps=True, + ) + + batch_size = input_features.shape[0] + + # Check output structure + 
self.assertIsNotNone(output.sequences) + self.assertIsNotNone(output.timestamps) + self.assertEqual(output.sequences.shape[0], batch_size) + self.assertEqual(output.timestamps.shape[0], batch_size) + + # Check timestamps are non-negative + self.assertTrue(torch.all(output.timestamps >= 0)) + + # Check tokens are within valid range (0 to vocab_size, excluding blank which is vocab_size) + non_pad_mask = output.sequences != config.pad_token_id + if non_pad_mask.any(): + valid_tokens = output.sequences[non_pad_mask] + self.assertTrue(torch.all(valid_tokens >= 0)) + self.assertTrue(torch.all(valid_tokens < config.vocab_size)) + + def test_generate_timestamps_are_monotonic(self): + """Test that timestamps are monotonically non-decreasing within each sequence.""" + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model.generate( + input_features, + attention_mask=attention_mask, + return_dict_in_generate=True, + return_timestamps=True, + ) + + # For each sequence, check timestamps are monotonically non-decreasing + for i in range(output.timestamps.shape[0]): + seq_timestamps = output.timestamps[i] + # Get non-zero timestamps (zero padding at end) + non_zero_mask = seq_timestamps > 0 + if non_zero_mask.sum() > 1: + valid_timestamps = seq_timestamps[non_zero_mask] + # Check monotonicity: each timestamp >= previous + diffs = valid_timestamps[1:] - valid_timestamps[:-1] + self.assertTrue(torch.all(diffs >= 0), f"Timestamps not monotonic for batch {i}") + + def test_generate_without_timestamps(self): + """Test that timestamps are None when not requested.""" + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model.generate( + input_features, + attention_mask=attention_mask, 
+ return_dict_in_generate=True, + return_timestamps=False, + ) + + self.assertIsNone(output.timestamps) + self.assertIsNotNone(output.sequences) + + def test_generate_returns_tensor_without_dict(self): + """Test that generate() returns just sequences when return_dict_in_generate=False.""" + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model.generate( + input_features, + attention_mask=attention_mask, + return_dict_in_generate=False, + ) + + # Should return just the tensor, not a dataclass + self.assertIsInstance(output, torch.Tensor) + self.assertEqual(output.shape[0], input_features.shape[0]) + + def test_generate_batch_independence(self): + """Test that batched inference produces same results as individual inference.""" + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + + # Run batched inference + with torch.no_grad(): + batched_output = model.generate( + input_features, + attention_mask=attention_mask, + return_dict_in_generate=True, + return_timestamps=True, + ) + + # Run individual inference for first item + with torch.no_grad(): + single_output = model.generate( + input_features[0:1], + attention_mask=attention_mask[0:1] if attention_mask is not None else None, + return_dict_in_generate=True, + return_timestamps=True, + ) + + # Results should match for the first item + min_len = min(batched_output.sequences.shape[1], single_output.sequences.shape[1]) + torch.testing.assert_close( + batched_output.sequences[0, :min_len], + single_output.sequences[0, :min_len], + ) + + @unittest.skip(reason="ParakeetForTDT does not use inputs_embeds") + def test_model_get_set_embeddings(self): + pass + + # Override to handle composite model SDPA dispatch + def 
test_sdpa_can_dispatch_composite_models(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + +@require_torch +class ParakeetForTDTIntegrationTest(unittest.TestCase): + _dataset = None + + @classmethod + def setUp(cls): + # TODO: Change to "nvidia/parakeet-tdt-0.6b-v3" once NVIDIA adds HF format to their repo + cls.checkpoint_name = "MaksL/parakeet-tdt-0.6b-v3" + cls.dtype = torch.bfloat16 + cls.processor = AutoProcessor.from_pretrained("MaksL/parakeet-tdt-0.6b-v3") + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @classmethod + def _load_dataset(cls): + # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process. 
+ if cls._dataset is None: + cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + cls._dataset = cls._dataset.cast_column( + "audio", Audio(sampling_rate=cls.processor.feature_extractor.sampling_rate) + ) + + def _load_datasamples(self, num_samples): + self._load_dataset() + ds = self._dataset + speech_samples = ds.sort("id")[:num_samples]["audio"] + return [x["array"] for x in speech_samples] + + @slow + def test_tdt_model_integration(self): + """ + Test TDT model inference on a single sample. + Tests that the model produces valid sequences and timestamps. + + Fixture generated by: python tests/models/parakeet/generate_tdt_fixtures.py + """ + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_tdt_single.json" + with open(RESULTS_PATH, "r") as f: + raw_data = json.load(f) + EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) + EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] + + samples = self._load_datasamples(1) + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) + model.eval() + model.to(torch_device) + + # -- apply + inputs = self.processor(samples) + inputs.to(torch_device, dtype=self.dtype) + + with torch.no_grad(): + output = model.generate( + **inputs, + return_dict_in_generate=True, + return_timestamps=True, + ) + + # Check sequences match expected + torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) + predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) + self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) + + # Validate timestamps are monotonically non-decreasing + timestamps = output.timestamps[0].cpu() + non_zero_mask = timestamps > 0 + if non_zero_mask.sum() > 1: + valid_timestamps = timestamps[non_zero_mask] + diffs = valid_timestamps[1:] - valid_timestamps[:-1] + self.assertTrue(torch.all(diffs >= 0), "Timestamps 
should be monotonically non-decreasing") + + @slow + def test_tdt_model_integration_batched(self): + """ + Test TDT model inference on a batch of samples. + + Fixture generated by: python tests/models/parakeet/generate_tdt_fixtures.py + """ + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_tdt_batch.json" + with open(RESULTS_PATH, "r") as f: + raw_data = json.load(f) + EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) + EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] + + samples = self._load_datasamples(5) + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) + model.eval() + model.to(torch_device) + + # -- apply + inputs = self.processor(samples) + inputs.to(torch_device, dtype=self.dtype) + + with torch.no_grad(): + output = model.generate( + **inputs, + return_dict_in_generate=True, + return_timestamps=True, + ) + + # Check sequences match expected + torch.testing.assert_close(output.sequences.cpu(), EXPECTED_TOKEN_IDS) + predicted_transcripts = self.processor.batch_decode(output.sequences, skip_special_tokens=True) + self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) + + # Validate timestamps for all samples in batch + for i in range(output.timestamps.shape[0]): + timestamps = output.timestamps[i].cpu() + non_zero_mask = timestamps > 0 + if non_zero_mask.sum() > 1: + valid_timestamps = timestamps[non_zero_mask] + diffs = valid_timestamps[1:] - valid_timestamps[:-1] + self.assertTrue(torch.all(diffs >= 0), f"Timestamps not monotonic for sample {i}") From 6e942252272cffe600c482ce83230b21b9c79703 Mon Sep 17 00:00:00 2001 From: Maksym Lypivskyi Date: Wed, 21 Jan 2026 21:01:37 +0100 Subject: [PATCH 2/2] fix: provide proper naming --- src/transformers/models/parakeet/modeling_parakeet.py | 6 +++--- src/transformers/models/parakeet/modular_parakeet.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git 
a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 9794c394829f..d353ed5f0d9e 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -897,21 +897,21 @@ def init_state( self, batch_size: int, device: torch.device, dtype: torch.dtype = None ) -> tuple[torch.Tensor, torch.Tensor]: """Initialize LSTM hidden and cell states to zeros.""" - h = torch.zeros( + hidden_state = torch.zeros( self.config.decoder_num_layers, batch_size, self.config.decoder_hidden_size, device=device, dtype=dtype, ) - c = torch.zeros( + cell_state = torch.zeros( self.config.decoder_num_layers, batch_size, self.config.decoder_hidden_size, device=device, dtype=dtype, ) - return (h, c) + return (hidden_state, cell_state) class ParakeetTDTJointNetwork(nn.Module): diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index c8664ea91482..1162d1da0432 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -735,21 +735,21 @@ def init_state( self, batch_size: int, device: torch.device, dtype: torch.dtype = None ) -> tuple[torch.Tensor, torch.Tensor]: """Initialize LSTM hidden and cell states to zeros.""" - h = torch.zeros( + hidden_state = torch.zeros( self.config.decoder_num_layers, batch_size, self.config.decoder_hidden_size, device=device, dtype=dtype, ) - c = torch.zeros( + cell_state = torch.zeros( self.config.decoder_num_layers, batch_size, self.config.decoder_hidden_size, device=device, dtype=dtype, ) - return (h, c) + return (hidden_state, cell_state) class ParakeetTDTJointNetwork(nn.Module):