From 8a3e1cdb762f2b0334fec00193f4d740c04a3ae7 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Sun, 12 Oct 2025 20:04:38 -0400 Subject: [PATCH] parakeet tdt integration --- .../models/auto/configuration_auto.py | 9 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 19 + .../modeling_fastspeech2_conformer.py | 6 +- .../models/parakeet/configuration_parakeet.py | 140 +++++- .../models/parakeet/convert_nemo_to_hf.py | 96 +++- .../models/parakeet/modeling_parakeet.py | 411 ++++++++++++++++- .../models/parakeet/modular_parakeet.py | 387 +++++++++++++++- src/transformers/pipelines/__init__.py | 3 +- .../pipelines/automatic_speech_recognition.py | 13 +- .../models/parakeet/test_modeling_parakeet.py | 425 ++++++++++++++++++ 11 files changed, 1483 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7e2e84a445ef..d7a41f0812ae 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -301,7 +301,10 @@ ("owlvit", "OwlViTConfig"), ("paligemma", "PaliGemmaConfig"), ("parakeet_ctc", "ParakeetCTCConfig"), + ("parakeet_tdt", "ParakeetTDTConfig"), ("parakeet_encoder", "ParakeetEncoderConfig"), + ("parakeet_tdt_decoder", "ParakeetTDTDecoderConfig"), + ("parakeet_tdt_joint", "ParakeetTDTJointConfig"), ("patchtsmixer", "PatchTSMixerConfig"), ("patchtst", "PatchTSTConfig"), ("pegasus", "PegasusConfig"), @@ -759,7 +762,10 @@ ("paligemma", "PaliGemma"), ("parakeet", "Parakeet"), ("parakeet_ctc", "Parakeet"), + ("parakeet_tdt", "ParakeetTDT"), ("parakeet_encoder", "ParakeetEncoder"), + ("parakeet_tdt_decoder", "ParakeetTDTDecoder"), + ("parakeet_tdt_joint", "ParakeetTDTJoint"), ("patchtsmixer", "PatchTSMixer"), ("patchtst", "PatchTST"), ("pegasus", "Pegasus"), @@ -1002,7 +1008,10 @@ ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"), ("video_llama_3_vision", "video_llama_3"), ("parakeet_encoder", "parakeet"), + ("parakeet_tdt_decoder", "parakeet"), + ("parakeet_tdt_joint", "parakeet"), ("parakeet_ctc", "parakeet"), + ("parakeet_tdt", "parakeet"), ] ) diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 746f14dd52fc..be4769dfcbdc 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -54,6 +54,7 @@ ("moonshine", "Wav2Vec2FeatureExtractor"), ("moshi", "EncodecFeatureExtractor"), ("parakeet_ctc", "ParakeetFeatureExtractor"), + ("parakeet_tdt", "ParakeetFeatureExtractor"), ("parakeet_encoder", "ParakeetFeatureExtractor"), ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"), ("pop2piano", "Pop2PianoFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 197029464efd..064974e6cb3e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -301,7 +301,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("owlvit", "OwlViTModel"), ("paligemma", "PaliGemmaModel"), ("parakeet_ctc", "ParakeetForCTC"), + ("parakeet_tdt", "ParakeetForTDT"), ("parakeet_encoder", "ParakeetEncoder"), + ("parakeet_tdt_decoder", "ParakeetTDTDecoder"), + ("parakeet_tdt_joint", "ParakeetTDTJoint"), ("patchtsmixer", "PatchTSMixerModel"), ("patchtst", "PatchTSTModel"), ("pegasus", "PegasusModel"), @@ -1624,6 +1627,14 @@
class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ] ) +MODEL_FOR_TDT_MAPPING_NAMES = OrderedDict( + [ + # Model for Token-and-Duration Transducer (TDT) mapping. + ("parakeet_tdt", "ParakeetForTDT"), + ] +) + + MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Audio Classification mapping @@ -1883,6 +1894,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES ) MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES) +MODEL_FOR_TDT_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TDT_MAPPING_NAMES) MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES @@ -2200,6 +2212,11 @@ class AutoModelForCTC(_BaseAutoModelClass): AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") +class AutoModelForTDT(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TDT_MAPPING + + +AutoModelForTDT = auto_class_update(AutoModelForTDT, head_doc="token-and-duration transducer") class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING @@ -2305,6 +2322,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", "MODEL_FOR_CAUSAL_LM_MAPPING", "MODEL_FOR_CTC_MAPPING", + "MODEL_FOR_TDT_MAPPING", "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", @@ -2352,6 +2370,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): "AutoModelForAudioXVector", "AutoModelForCausalLM", "AutoModelForCTC", + "AutoModelForTDT", "AutoModelForDepthEstimation", "AutoModelForImageClassification", "AutoModelForImageSegmentation", diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index 5a2dc39385b3..34b90ee6af28 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -490,12 +490,12 @@ def __init__(self, config: FastSpeech2ConformerConfig, module_config=None): kernel_size = module_config["kernel_size"] self.activation = ACT2FN[module_config.get("activation", "silu")] self.padding = (kernel_size - 1) // 2 - self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True) + self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias) self.depthwise_conv = nn.Conv1d( - channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True + channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=config.attention_bias ) self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True) + self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias) def forward(self, hidden_states, attention_mask=None): """ diff --git a/src/transformers/models/parakeet/configuration_parakeet.py 
b/src/transformers/models/parakeet/configuration_parakeet.py index 1e3d97b4182e..b5e294d784ed 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -150,6 +150,65 @@ def __init__( self.initializer_range = initializer_range + +class ParakeetTDTDecoderConfig(PreTrainedConfig): + model_type = "parakeet_tdt_decoder" + keys_to_ignore_at_inference = ["past_key_values"] + output_hidden_states = False + + def __init__( + self, + hidden_size=640, + num_hidden_layers=1, + dropout=0, + vocab_size=1024, + forget_gate_bias=1.0, + t_max=None, + weights_init_scale=1.0, + hidden_hidden_bias_scale=0, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.dropout = dropout + self.vocab_size = vocab_size + self.forget_gate_bias = forget_gate_bias + self.t_max = t_max + self.weights_init_scale = weights_init_scale + self.hidden_hidden_bias_scale = hidden_hidden_bias_scale + + +class ParakeetTDTJointConfig(PreTrainedConfig): + model_type = "parakeet_tdt_joint" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + enc_hidden_size=1024, + pred_hidden_size=640, + hidden_size=640, + vocab_size=1024, + durations=[0, 1, 2, 3, 4], + norm=None, + dropout=0.0, + activation="relu", + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.enc_hidden_size = enc_hidden_size + self.pred_hidden_size = pred_hidden_size + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.durations = durations + self.norm = norm + self.dropout = dropout + self.activation = activation + + class ParakeetCTCConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`ParakeetForCTC`]. It is used to instantiate a @@ -232,4 +291,83 @@ def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs): return cls(encoder_config=encoder_config.to_dict(), **kwargs) -__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig"] +class ParakeetTDTConfig(PreTrainedConfig): + model_type = "parakeet_tdt" + sub_configs = {"encoder_config": ParakeetEncoderConfig, "decoder_config": ParakeetTDTDecoderConfig, "joint_config": ParakeetTDTJointConfig} + + def __init__( + self, + tdt_loss_reduction="mean", + encoder_config: Union[dict, ParakeetEncoderConfig] = None, + decoder_config: Union[dict, ParakeetTDTDecoderConfig] = None, + joint_config: Union[dict, ParakeetTDTJointConfig] = None, + **kwargs, + ): + if encoder_config is None: + self.encoder_config = ParakeetEncoderConfig() + elif isinstance(encoder_config, dict): + self.encoder_config = ParakeetEncoderConfig(**encoder_config) + elif isinstance(encoder_config, ParakeetEncoderConfig): + self.encoder_config = encoder_config + else: + raise ValueError( + f"`encoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}" + ) + + if decoder_config is None: + self.decoder_config = ParakeetTDTDecoderConfig() + elif isinstance(decoder_config, dict): + self.decoder_config = ParakeetTDTDecoderConfig(**decoder_config) + elif isinstance(decoder_config, ParakeetTDTDecoderConfig): + self.decoder_config = decoder_config + else: + raise ValueError( + f"`decoder_config` must be a dictionary or an instance of `ParakeetTDTDecoderConfig`, got {type(decoder_config)}" + ) + + if joint_config is None: + self.joint_config = ParakeetTDTJointConfig() + elif isinstance(joint_config, dict): + self.joint_config
= ParakeetTDTJointConfig(**joint_config) + elif isinstance(joint_config, ParakeetTDTJointConfig): + self.joint_config = joint_config + else: + raise ValueError( + f"`joint_config` must be a dictionary or an instance of `ParakeetTDTJointConfig`, got {type(joint_config)}" + ) + + vocab_size = self.joint_config.vocab_size + self.vocab_size = vocab_size + + self.tdt_loss_reduction = tdt_loss_reduction + self.blank_token_id = vocab_size + super().__init__( + **kwargs, + ) + + @classmethod + def from_configs( + cls, + encoder_config: ParakeetEncoderConfig, + decoder_config: ParakeetTDTDecoderConfig, + joint_config: ParakeetTDTJointConfig, + **kwargs): + r""" + Instantiate a [`ParakeetTDTConfig`] (or a derived class) from Parakeet encoder, decoder and joint model configurations. + + Returns: + [`ParakeetTDTConfig`]: An instance of a configuration object + """ + + return cls( + encoder_config=encoder_config.to_dict(), + decoder_config=decoder_config.to_dict(), + joint_config=joint_config.to_dict(), + **kwargs) + +__all__ = ["ParakeetCTCConfig", "ParakeetTDTConfig", "ParakeetEncoderConfig", "ParakeetTDTDecoderConfig", "ParakeetTDTJointConfig"] diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index f1998fbd81b8..39469ef59375 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -25,8 +25,10 @@ from transformers import ( ParakeetCTCConfig, + ParakeetTDTConfig, ParakeetFeatureExtractor, ParakeetForCTC, + ParakeetForTDT, ParakeetProcessor, ParakeetTokenizerFast, ) @@ -220,8 +222,11 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i "stochastic_depth_mode", "conv_context_size", "dropout_pre_encoder", + "reduction", + "reduction_position", + "reduction_factor", ] - enocder_config_keys_mapping = { + encoder_config_keys_mapping = { "d_model": "hidden_size", "n_heads": "num_attention_heads", "n_layers": "num_hidden_layers", @@ -234,16 +239,65 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i "dropout_emb": "dropout_positions", "dropout_att": "attention_dropout", "xscaling": "scale_input", + "use_bias": "attention_bias", } converted_encoder_config = {} + decoder_keys_to_ignore = [ + "_target_", + "normalization_mode", + "random_state_sampling", + "blank_as_pad", + "prednet", + ] + decoder_config_keys_mapping = { + "vocab_size": "vocab_size", + } + converted_decoder_config = {} + + joint_keys_to_ignore = [ + "_target_", + "log_softmax", + "preserve_memory", + "fuse_loss_wer", + "fused_batch_size", + "jointnet", + "vocabulary", + ] + joint_config_keys_mapping = { + "vocab_size": "vocab_size", + "num_classes": "num_classes", + "num_extra_outputs": "num_extra_outputs", + } + converted_joint_config = {} + + for key, value in nemo_config["encoder"].items(): if key in encoder_keys_to_ignore: continue - if key in enocder_config_keys_mapping: - converted_encoder_config[enocder_config_keys_mapping[key]] = value + if key in encoder_config_keys_mapping: + converted_encoder_config[encoder_config_keys_mapping[key]] = value else: - raise ValueError(f"Key {key} not found in enocder_config_keys_mapping") + raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") + + if model_type == "tdt": + for key, value in nemo_config["decoder"].items(): + if key in decoder_keys_to_ignore: + continue + if key in decoder_config_keys_mapping: + converted_decoder_config[decoder_config_keys_mapping[key]] = value + else: + raise ValueError(f"Key
{key} not found in encoder_config_keys_mapping") + + for key, value in nemo_config["joint"].items(): + if key in joint_keys_to_ignore: + continue + if key in joint_config_keys_mapping: + converted_joint_config[joint_config_keys_mapping[key]] = value + else: + raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") + + converted_joint_config["vocab_size"] = converted_joint_config["num_classes"] state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True) converted_state_dict = {} @@ -280,6 +334,38 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") print("Model reloaded successfully.") + elif model_type == "tdt": + num_classes = converted_joint_config["num_classes"] + model_config = ParakeetTDTConfig( + pad_token_id=num_classes, + vocab_size=num_classes+1, + blank_token_id=num_classes, + encoder_config=converted_encoder_config, + decoder_config=converted_decoder_config, + joint_config=converted_joint_config, + ) + print("Loading the checkpoint in a Parakeet TDT model.") + with torch.device("meta"): + model = ParakeetForTDT(model_config) + model.load_state_dict(converted_state_dict, strict=True, assign=True) + print("Checkpoint loaded successfully.") + del model.config._name_or_path + + print("Saving the model.") + model.save_pretrained(output_dir) + + if push_to_repo_id: + model.push_to_hub(push_to_repo_id) + + del converted_state_dict, model + + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + ParakeetForTDT.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") + print("Model reloaded successfully.") + + else: raise ValueError(f"Model type {model_type} not supported.") @@ -303,7 +389,7 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co") - parser.add_argument("--model_type", required=True, choices=["ctc"], help="Model type (`ctc`, `tdt`)") + parser.add_argument("--model_type", required=True, choices=["ctc","tdt"], help="Model type (`ctc`, `tdt`)") parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model") parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub") args = parser.parse_args() diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 7f74cf33418e..d890c6a7bcbc 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -29,12 +29,19 @@ from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithNoAttention, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import check_model_inputs -from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig +from .configuration_parakeet import ( + ParakeetCTCConfig, + ParakeetEncoderConfig, + ParakeetTDTConfig, + ParakeetTDTDecoderConfig, + ParakeetTDTJointConfig, + PreTrainedConfig, +) class 
ParakeetEncoderRelPositionalEncoding(nn.Module): @@ -120,12 +127,22 @@ def __init__(self, config: ParakeetEncoderConfig, module_config=None): kernel_size = module_config["kernel_size"] self.activation = ACT2FN[module_config.get("activation", "silu")] self.padding = (kernel_size - 1) // 2 - self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True) + self.pointwise_conv1 = nn.Conv1d( + channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias + ) self.depthwise_conv = nn.Conv1d( - channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True + channels, + channels, + kernel_size, + stride=1, + padding=self.padding, + groups=channels, + bias=config.attention_bias, ) self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True) + self.pointwise_conv2 = nn.Conv1d( + channels, channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias + ) def forward(self, hidden_states, attention_mask=None): """ @@ -225,7 +242,9 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) # W_{k,R} projection - self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.relative_k_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) # global content bias self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim)) # global positional bias @@ -415,7 +434,7 @@ def forward( @auto_docstring class ParakeetPreTrainedModel(PreTrainedModel): - config: ParakeetCTCConfig + config: PreTrainedConfig base_model_prefix = "model" main_input_name = "input_features" input_modalities = "audio" @@ -450,7 +469,11 @@ def _init_weights(self, module): module.bias_v.data.normal_(mean=0.0, std=std) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): - encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config + encoder_config = ( + self.config.encoder_config + if isinstance(self.config, (ParakeetCTCConfig, ParakeetTDTConfig)) + else self.config + ) kernel_size = encoder_config.subsampling_conv_kernel_size stride = encoder_config.subsampling_conv_stride @@ -597,6 +620,270 @@ class ParakeetGenerateOutput(ModelOutput): hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None +class ParakeetLSTM(torch.nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + dropout: Optional[float], + forget_gate_bias: Optional[float], + t_max: Optional[int] = None, + weights_init_scale: float = 1.0, + hidden_hidden_bias_scale: float = 0.0, + proj_size: int = 0, + ): + """Returns an LSTM with forget gate bias init to `forget_gate_bias`. + Args: + input_size: See `torch.nn.LSTM`. + hidden_size: See `torch.nn.LSTM`. + num_layers: See `torch.nn.LSTM`. + dropout: See `torch.nn.LSTM`. + + forget_gate_bias: float, set by default to 1.0, which constructs a forget gate + initialized to 1.0. + Reference: + [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) + + t_max: int value, set to None by default. 
If an int is specified, performs Chrono Initialization + of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course + of training. + Reference: + [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) + + weights_init_scale: Float scale of the weights after initialization. Setting to lower than one + sometimes helps reduce variance between runs. + + hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for + the default behaviour. + + Returns: + A `torch.nn.LSTM`. + """ + super().__init__() + + self.lstm = torch.nn.LSTM( + input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, proj_size=proj_size + ) + + if t_max is not None: + # apply chrono init + for name, v in self.lstm.named_parameters(): + if "bias" in name: + p = getattr(self.lstm, name) + n = p.nelement() + hidden_size = n // 4 + p.data.fill_(0) + p.data[hidden_size : 2 * hidden_size] = torch.log( + torch.nn.init.uniform_(p.data[0:hidden_size], 1, t_max - 1) + ) + # forget gate biases = log(uniform(1, Tmax-1)) + p.data[0:hidden_size] = -p.data[hidden_size : 2 * hidden_size] + # input gate biases = -(forget gate biases) + + elif forget_gate_bias is not None: + for name, v in self.lstm.named_parameters(): + if "bias_ih" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias) + if "bias_hh" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale) + + self.dropout = torch.nn.Dropout(dropout) if dropout else None + + for name, v in self.named_parameters(): + if "weight" in name or "bias" in name: + v.data *= float(weights_init_scale) + + def forward( + self, x: torch.Tensor, h: Optional[tuple[torch.Tensor, torch.Tensor]] = None + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + x, h = self.lstm(x, h) + + if self.dropout: + x = self.dropout(x) + + return x, h + + +class ParakeetTDTJoint(ParakeetPreTrainedModel): + config: ParakeetTDTJointConfig + base_model_prefix = "" + main_input_name = "enc" + _supports_flat_attention_mask = False + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = False + _can_record_outputs = {} + _no_split_modules = None + + def __init__(self, config: ParakeetTDTJointConfig): + super().__init__(config) + self.config = config + self.gradient_checkpointing = False + + self.enc = torch.nn.Linear(config.enc_hidden_size, config.hidden_size) + self.pred = torch.nn.Linear(config.pred_hidden_size, config.hidden_size) + + num_classes = config.vocab_size + 1 + len(config.durations) + + layers = ( + [torch.nn.ReLU(inplace=True)] + + [torch.nn.Dropout(p=self.config.dropout)] + + [torch.nn.Linear(config.hidden_size, num_classes)] + ) + self.joint_net = torch.nn.Sequential(*layers) + self.post_init() + + @auto_docstring + @check_model_inputs() + def forward( + self, + enc: torch.Tensor, + pred: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithNoAttention: + # Right now we only support joint for inference.
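+ # The fused representation is passed through `joint_net`, whose output holds + # `vocab_size + 1` token logits (the vocabulary plus the blank) followed by + # `len(durations)` duration logits; `greedy_decode` splits it at `blank_token_id + 1`.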
+ + pred = pred.view([-1, self.config.pred_hidden_size]) # making it B, D + enc = enc.view([-1, self.config.enc_hidden_size]) # making it B, D + enc = self.enc(enc) + pred = self.pred(pred) + + assert enc.shape[0] == pred.shape[0] + output = self.joint_net(enc + pred) + return BaseModelOutput(last_hidden_state=output) + + +class ParakeetTDTPredictor(ParakeetPreTrainedModel): + def __init__(self, config: ParakeetTDTDecoderConfig): + super().__init__(config) + self.gradient_checkpointing = False + self.config = config + + self.embed = torch.nn.Embedding(config.vocab_size + 1, config.hidden_size) # +1 for blank + self.dec_rnn = self.rnn( + config.hidden_size, + config.hidden_size, + config.num_hidden_layers + 1, + config.forget_gate_bias, + config.dropout, + config.t_max, + config.weights_init_scale, + config.hidden_hidden_bias_scale, + ) + self.post_init() + + def rnn( + self, + input_size: int, + hidden_size: int, + num_layers: int, + forget_gate_bias: Optional[float] = 1.0, + dropout: Optional[float] = 0.0, + t_max: Optional[int] = None, + weights_init_scale: float = 1.0, + hidden_hidden_bias_scale: float = 0.0, + proj_size: int = 0, + ) -> torch.nn.Module: + return ParakeetLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + weights_init_scale=weights_init_scale, + hidden_hidden_bias_scale=hidden_hidden_bias_scale, + proj_size=proj_size, + ) + + @auto_docstring + @check_model_inputs() + @can_return_tuple + def forward( + self, + input_token, + states, + hidden_state=None, + **kwargs: Unpack[TransformersKwargs], + ): + assert input_token is not None + + device = self.embed.weight.device + if input_token.device != device: + input_token = input_token.to(device) + return self.predict(input_token, state=states) + + def predict(self, y, state): + # (B, U) -> (B, U, H) + y = self.embed(y).transpose(0, 1) # (U + 1, B, H) + + g, hid = self.dec_rnn(y, state) + g = g.transpose(0, 1).transpose(1, 2) # (B, H, U + 1) + + return g, hid + + +@auto_docstring( + custom_intro=""" + The Parakeet TDT decoder (prediction network) for TDT models. The joint network is implemented separately as [`ParakeetTDTJoint`].
+ """ +) +class ParakeetTDTDecoder(ParakeetPreTrainedModel): + config: ParakeetTDTDecoderConfig + base_model_prefix = "decoder" + main_input_name = "input_token" + _supports_flat_attention_mask = False + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = False + _can_record_outputs = {} + _no_split_modules = None + + def __init__(self, config: ParakeetTDTDecoderConfig): + super().__init__(config) + self.config = config + self.gradient_checkpointing = False + self.prediction = ParakeetTDTPredictor(config) + self.post_init() + + def _init_weights(self, module): + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + else: + # 0.02 is the standard default value accross the library + std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + + module.prediction.embed.weight.data.normal_(mean=0.0, std=std) + for param in module.prediction.dec_rnn.lstm.parameters(): + param.data.normal_(mean=0.0, std=std) + + def get_input_embeddings(self): + return self.prediction.embed + + def set_input_embeddings(self, embed): + self.prediction.embed = embed + + @auto_docstring + @check_model_inputs() + @can_return_tuple + def forward( + self, + input_token, + hidden_state=None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithNoAttention: + if hidden_state is not None: + hidden_state = tuple(hidden_state.unbind(dim=0)) + + h_out, h_state = self.prediction(input_token, hidden_state, **kwargs) + return BaseModelOutputWithNoAttention(h_out, torch.stack(h_state, dim=0)) + + @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -741,4 +1028,110 @@ def generate( return sequences -__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"] +@auto_docstring( + custom_intro=""" + Parakeet TDT model. + """ +) +class ParakeetForTDT(ParakeetPreTrainedModel): + config: ParakeetTDTConfig + + def __init__(self, config: ParakeetTDTConfig): + super().__init__(config) + self.encoder = ParakeetEncoder(config.encoder_config) + self.decoder = ParakeetTDTDecoder(config.decoder_config) + self.joint = ParakeetTDTJoint(config.joint_config) + self.blank_token_id = config.blank_token_id + self.max_token_per_frame = 2 + self.post_init() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_features: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + encoder_outputs = self.encoder( + input_features=input_features, + **kwargs, + ) + + logits = self.joint.joint_net( + self.joint.enc(encoder_outputs.last_hidden_state) + ) # [:,:,:self.joint.vocab_size] + + return CausalLMOutput( + loss=torch.sum(encoder_outputs.last_hidden_state), # a fake loss here. 
+ logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + encoder_outputs = self.encoder( + input_features=input_features, + **kwargs, + ) + output = self.greedy_decode(encoder_outputs.last_hidden_state) + + return output + + def greedy_decode(self, encoder_output): + # Greedy TDT decoding over a single utterance; currently assumes batch size 1. + T = encoder_output.shape[1] + t = 0 + hyp = [] + last_label = torch.LongTensor([[self.blank_token_id]]) + dec_out = self.decoder(input_token=last_label) + g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states + + symbols_added = 0 + while t < T: + enc = encoder_output[0, t, :] + while symbols_added < self.max_token_per_frame: + logits = self.joint(enc, g).last_hidden_state + + logits = logits.view([-1]) + + token_logits = logits[: self.blank_token_id + 1].softmax(-1) + duration_logits = logits[self.blank_token_id + 1 :].softmax(-1) + + v, token = token_logits.max(-1) + v_duration, duration = duration_logits.max(-1) + token = token.item() + duration = duration.item() + + if token != self.blank_token_id: + hyp.append(token) + last_label = torch.LongTensor([[token]]) + dec_out = self.decoder(last_label, hidden_prime) + g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states + + if duration == 0: + symbols_added += 1 + else: + t += duration + symbols_added = 0 + break + + if symbols_added == self.max_token_per_frame: + t += 1 + symbols_added = 0 + + return hyp + + +__all__ = [ + "ParakeetForCTC", + "ParakeetForTDT", + "ParakeetEncoder", + "ParakeetTDTDecoder", + "ParakeetTDTJoint", + "ParakeetPreTrainedModel", +] diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index b8366c8cd086..080ad9dbd18d 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -17,21 +17,21 @@ import math from collections.abc import Callable from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Union, Tuple import torch from torch import nn from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, BaseModelOutputWithNoAttention from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import check_model_inputs from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule from ..llama.modeling_llama import LlamaAttention, eager_attention_forward -from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig +from .configuration_parakeet import PreTrainedConfig, ParakeetCTCConfig, ParakeetTDTConfig, ParakeetEncoderConfig, ParakeetTDTDecoderConfig, ParakeetTDTJointConfig class ParakeetEncoderRelPositionalEncoding(nn.Module): @@ -111,7 +111,7 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): super().__init__(config, layer_idx=layer_idx) self.is_causal = False # W_{k,R} projection - self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.relative_k_proj = nn.Linear(config.hidden_size,
config.num_attention_heads * self.head_dim, bias=config.attention_bias) # global content bias self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim)) # global positional bias @@ -301,7 +301,7 @@ def forward( @auto_docstring class ParakeetPreTrainedModel(PreTrainedModel): - config: ParakeetCTCConfig + config: PreTrainedConfig base_model_prefix = "model" main_input_name = "input_features" input_modalities = "audio" @@ -336,7 +336,7 @@ def _init_weights(self, module): module.bias_v.data.normal_(mean=0.0, std=std) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): - encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config + encoder_config = self.config.encoder_config if isinstance(self.config, (ParakeetCTCConfig, ParakeetTDTConfig)) else self.config kernel_size = encoder_config.subsampling_conv_kernel_size stride = encoder_config.subsampling_conv_stride @@ -483,6 +483,277 @@ class ParakeetGenerateOutput(ModelOutput): hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None +class ParakeetLSTM(torch.nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + dropout: Optional[float], + forget_gate_bias: Optional[float], + t_max: Optional[int] = None, + weights_init_scale: float = 1.0, + hidden_hidden_bias_scale: float = 0.0, + proj_size: int = 0, + ): + """Returns an LSTM with forget gate bias init to `forget_gate_bias`. + Args: + input_size: See `torch.nn.LSTM`. + hidden_size: See `torch.nn.LSTM`. + num_layers: See `torch.nn.LSTM`. + dropout: See `torch.nn.LSTM`. + + forget_gate_bias: float, set by default to 1.0, which constructs a forget gate + initialized to 1.0. + Reference: + [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) + + t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization + of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course + of training. + Reference: + [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) + + weights_init_scale: Float scale of the weights after initialization. Setting to lower than one + sometimes helps reduce variance between runs. + + hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for + the default behaviour. + + Returns: + A `torch.nn.LSTM`. 
+ """ + super(ParakeetLSTM, self).__init__() + + self.lstm = torch.nn.LSTM( + input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, proj_size=proj_size + ) + + if t_max is not None: + # apply chrono init + for name, v in self.lstm.named_parameters(): + if 'bias' in name: + p = getattr(self.lstm, name) + n = p.nelement() + hidden_size = n // 4 + p.data.fill_(0) + p.data[hidden_size : 2 * hidden_size] = torch.log( + torch.nn.init.uniform_(p.data[0:hidden_size], 1, t_max - 1) + ) + # forget gate biases = log(uniform(1, Tmax-1)) + p.data[0:hidden_size] = -p.data[hidden_size : 2 * hidden_size] + # input gate biases = -(forget gate biases) + + elif forget_gate_bias is not None: + for name, v in self.lstm.named_parameters(): + if "bias_ih" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias) + if "bias_hh" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale) + + self.dropout = torch.nn.Dropout(dropout) if dropout else None + + for name, v in self.named_parameters(): + if 'weight' in name or 'bias' in name: + v.data *= float(weights_init_scale) + + def forward( + self, x: torch.Tensor, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + x, h = self.lstm(x, h) + + if self.dropout: + x = self.dropout(x) + + return x, h + +class ParakeetTDTJoint(ParakeetPreTrainedModel): + config: ParakeetTDTJointConfig + base_model_prefix = "" #joint" + main_input_name = "enc" + _supports_flat_attention_mask = False + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = False + _can_record_outputs = {} + _no_split_modules = None + + def __init__(self, config: ParakeetTDTJointConfig): + super().__init__(config) + self.config = config + self.gradient_checkpointing = False + + self.enc = torch.nn.Linear(config.enc_hidden_size, config.hidden_size) + self.pred = torch.nn.Linear(config.pred_hidden_size, config.hidden_size) + + num_classes = config.vocab_size + 1 + len(config.durations) + + layers = ( + [torch.nn.ReLU(inplace=True)] + + ([torch.nn.Dropout(p=self.config.dropout)]) + + [torch.nn.Linear(config.hidden_size, num_classes)] + ) + self.joint_net = torch.nn.Sequential(*layers) + self.post_init() + + @auto_docstring + @check_model_inputs() + def forward( + self, + enc: torch.Tensor, + pred: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithNoAttention: + + # Right now we only support joint for inference. 
+ + pred = pred.view([-1, self.config.pred_hidden_size]) # making it B, D + enc = enc.view([-1, self.config.enc_hidden_size]) # making it B, D + enc = self.enc(enc) + pred = self.pred(pred) + + assert enc.shape[0] == pred.shape[0] + output = self.joint_net(enc + pred) + return BaseModelOutput(last_hidden_state=output) + + +class ParakeetTDTPredictor(ParakeetPreTrainedModel): + def __init__(self, config: ParakeetTDTDecoderConfig): + super().__init__(config) + self.gradient_checkpointing = False + self.config = config + + self.embed = torch.nn.Embedding(config.vocab_size + 1, config.hidden_size) # +1 for blank + self.dec_rnn = self.rnn( + config.hidden_size, + config.hidden_size, + config.num_hidden_layers + 1, + config.forget_gate_bias, + config.dropout, + config.t_max, + config.weights_init_scale, + config.hidden_hidden_bias_scale, + ) + self.post_init() + + def rnn( + self, + input_size: int, + hidden_size: int, + num_layers: int, + forget_gate_bias: Optional[float] = 1.0, + dropout: Optional[float] = 0.0, + t_max: Optional[int] = None, + weights_init_scale: float = 1.0, + hidden_hidden_bias_scale: float = 0.0, + proj_size: int = 0, + ) -> torch.nn.Module: + return ParakeetLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + weights_init_scale=weights_init_scale, + hidden_hidden_bias_scale=hidden_hidden_bias_scale, + proj_size=proj_size, + ) + + @auto_docstring + @check_model_inputs() + @can_return_tuple + def forward( + self, + input_token, + states, + hidden_state=None, + **kwargs: Unpack[TransformersKwargs], + ): + assert input_token is not None + + device = self.embed.weight.device + if input_token.device != device: + input_token = input_token.to(device) + return self.predict(input_token, state=states) + + def predict(self, y, state): + # (B, U) -> (B, U, H) + y = self.embed(y).transpose(0, 1) # (U + 1, B, H) + + g, hid = self.dec_rnn(y, state) + g = g.transpose(0, 1).transpose(1, 2) # (B, H, U + 1) + + return g, hid + + +@auto_docstring( + custom_intro=""" + The Parakeet TDT decoder (prediction network) for TDT models. The joint network is implemented separately as [`ParakeetTDTJoint`].
+ """ +) +class ParakeetTDTDecoder(ParakeetPreTrainedModel): + config: ParakeetTDTDecoderConfig + base_model_prefix = "decoder" + main_input_name = "input_token" + _supports_flat_attention_mask = False + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = False + _can_record_outputs = {} + _no_split_modules = None + + def __init__(self, config: ParakeetTDTDecoderConfig): + super().__init__(config) + self.config = config + self.gradient_checkpointing = False + self.prediction = ParakeetTDTPredictor(config) + self.post_init() + + def _init_weights(self, module): + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + else: + # 0.02 is the standard default value accross the library + std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + + module.prediction.embed.weight.data.normal_(mean=0.0, std=std) + for param in module.prediction.dec_rnn.lstm.parameters(): + param.data.normal_(mean=0.0, std=std) + + def get_input_embeddings(self): + return self.prediction.embed + + def set_input_embeddings(self, embed): + self.prediction.embed = embed + + @auto_docstring + @check_model_inputs() + @can_return_tuple + def forward( + self, + input_token, + hidden_state = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithNoAttention: + + if hidden_state is not None: + hidden_state = tuple(hidden_state.unbind(dim=0)) + + h_out, h_state = self.prediction(input_token, hidden_state, **kwargs) + return BaseModelOutputWithNoAttention(h_out, torch.stack(h_state, dim=0)) + + + @auto_docstring( custom_intro=""" Parakeet Encoder with a Connectionist Temporal Classification (CTC) head. @@ -627,4 +898,106 @@ def generate( return sequences -__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"] +@auto_docstring( + custom_intro=""" + Parakeet TDT model. + """ +) +class ParakeetForTDT(ParakeetPreTrainedModel): + config: ParakeetTDTConfig + + def __init__(self, config: ParakeetTDTConfig): + super().__init__(config) + self.encoder = ParakeetEncoder(config.encoder_config) + self.decoder = ParakeetTDTDecoder(config.decoder_config) + self.joint = ParakeetTDTJoint(config.joint_config) + self.blank_token_id = config.blank_token_id + self.max_token_per_frame = 2 + self.post_init() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_features: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + encoder_outputs = self.encoder( + input_features=input_features, + **kwargs, + ) + + logits = self.joint.joint_net(self.joint.enc(encoder_outputs.last_hidden_state)) #[:,:,:self.joint.vocab_size] + + return CausalLMOutput( + loss=torch.sum(encoder_outputs.last_hidden_state), # a fake loss here. 
+ logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + encoder_outputs = self.encoder( + input_features=input_features, + **kwargs, + ) + output = self.greedy_decode(encoder_outputs.last_hidden_state) + + return output + + def greedy_decode(self, encoder_output): + # Greedy TDT decoding over a single utterance; currently assumes batch size 1. + T = encoder_output.shape[1] + t = 0 + hyp = [] + last_label = torch.LongTensor([[self.blank_token_id]]) + dec_out = self.decoder(input_token=last_label) + g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states + + symbols_added = 0 + while t < T: + enc = encoder_output[0, t, :] + while symbols_added < self.max_token_per_frame: + logits = self.joint(enc, g).last_hidden_state + + logits = logits.view([-1]) + + token_logits = logits[: self.blank_token_id + 1].softmax(-1) + duration_logits = logits[self.blank_token_id + 1 :].softmax(-1) + + v, token = token_logits.max(-1) + v_duration, duration = duration_logits.max(-1) + token = token.item() + duration = duration.item() + + if token != self.blank_token_id: + hyp.append(token) + last_label = torch.LongTensor([[token]]) + dec_out = self.decoder(last_label, hidden_prime) + g, hidden_prime = dec_out.last_hidden_state, dec_out.hidden_states + + if duration == 0: + symbols_added += 1 + else: + t += duration + symbols_added = 0 + break + + if symbols_added == self.max_token_per_frame: + t += 1 + symbols_added = 0 + + return hyp + + +__all__ = ["ParakeetForCTC", "ParakeetForTDT", "ParakeetEncoder", "ParakeetTDTDecoder", "ParakeetTDTJoint", "ParakeetPreTrainedModel"] diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 87db1981cbc0..2a79f42521fd 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -100,6 +100,7 @@ AutoModelForAudioClassification, AutoModelForCausalLM, AutoModelForCTC, + AutoModelForTDT, AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, @@ -148,7 +149,7 @@ }, "automatic-speech-recognition": { "impl": AutomaticSpeechRecognitionPipeline, - "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (), + "pt": (AutoModelForCTC, AutoModelForTDT, AutoModelForSpeechSeq2Seq) if is_torch_available() else (), "default": {"model": ("facebook/wav2vec2-base-960h", "22aad52")}, "type": "multimodal", }, diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 37034ad94f94..223c7bd20127 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -198,6 +198,8 @@ def __init__( self.type = "seq2seq_whisper" elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values(): self.type = "seq2seq" + elif model.config.model_type == "parakeet_tdt": + self.type = "tdt" elif ( feature_extractor._processor_class and feature_extractor._processor_class.endswith("WithLM") @@ -551,7 +553,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): if stride is not None: out["stride"] = stride - else: + elif self.type in {"ctc", "ctc_with_lm"}: inputs = { self.model.main_input_name: model_inputs.pop(self.model.main_input_name), "attention_mask": attention_mask, @@ -572,6 +574,15 @@ def _forward(self, model_inputs,
return_timestamps=False, **generate_kwargs): out["stride"] = rescale_stride([stride], ratio)[0] else: out["stride"] = rescale_stride(stride, ratio) + elif self.type == "tdt": + inputs = { + self.model.main_input_name: model_inputs.pop(self.model.main_input_name), + } + outputs = self.model.generate(**inputs) + out = {"tokens": torch.LongTensor(outputs).view([1, -1])} + else: + raise ValueError(f"Unsupported model type {self.type}.") + + # Leftover extra = model_inputs return {"is_last": is_last, **out, **extra} diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py index 7bd35946574f..f14eb7f7196f 100644 --- a/tests/models/parakeet/test_modeling_parakeet.py +++ b/tests/models/parakeet/test_modeling_parakeet.py @@ -14,6 +14,7 @@ """Testing suite for the PyTorch Parakeet model.""" +import copy import json import tempfile import unittest from pathlib import Path @@ -34,9 +35,15 @@ from transformers import ( AutoProcessor, ParakeetCTCConfig, + ParakeetTDTConfig, ParakeetEncoder, ParakeetEncoderConfig, + ParakeetTDTDecoder, + ParakeetTDTDecoderConfig, + ParakeetTDTJoint, + ParakeetTDTJointConfig, ParakeetForCTC, + ParakeetForTDT, ) @@ -184,6 +191,232 @@ def test_model_get_set_embeddings(self): pass + +class ParakeetTDTDecoderModelTester: + def __init__( + self, + parent, + batch_size=16, + vocab_size=128, + hidden_size=64, + num_hidden_layers=2, + seq_length=32, + is_training=True, + dropout=0, # so gradient checkpointing doesn't fail + ): + # testing suite parameters + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + + # config parameters + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.seq_length = seq_length + self.output_seq_length = seq_length + self.vocab_size = vocab_size + + def prepare_config_and_inputs(self): + input_token = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + config = self.get_config() + + return config, input_token + + def get_config(self): + return ParakeetTDTDecoderConfig( + num_hidden_layers=self.num_hidden_layers, + hidden_size=self.hidden_size, + vocab_size=self.vocab_size, + ) + + def create_and_check_model(self, config, input_token): + model = ParakeetTDTDecoder(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_token) + + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, config.hidden_size, self.output_seq_length) + ) + + def prepare_config_and_inputs_for_common(self): + config, input_token = self.prepare_config_and_inputs() + inputs_dict = { + "input_token": input_token, + } + return config, inputs_dict + + +@require_torch +class ParakeetTDTDecoderModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (ParakeetTDTDecoder,) if is_torch_available() else () + + test_resize_embeddings = False + test_torch_exportable = True + has_attentions = False + is_encoder_decoder = False + + def setUp(self): + self.model_tester = ParakeetTDTDecoderModelTester(self) + self.config_tester = ConfigTester(self, config_class=ParakeetTDTDecoderConfig, has_text_modality=False, common_properties=["hidden_size", "num_hidden_layers"]) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config,
model_class): + model = model_class(copy.deepcopy(config)) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(hidden_states.shape[1], expected_num_layers) + + if hasattr(self.model_tester, "encoder_seq_length"): + seq_length = self.model_tester.encoder_seq_length + if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1: + seq_length = seq_length * self.model_tester.chunk_length + else: + seq_length = self.model_tester.seq_length + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + for k in config.sub_configs: + if getattr(config, k) is not None: + getattr(config, k).output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + @unittest.skip(reason="this class only returns the last hidden state, not prior ones, and there is no gradient on the last hidden state w.r.t. the output.") + def test_retain_grad_hidden_states_attentions(self): + pass + + +class ParakeetTDTJointModelTester: + def __init__( + self, + parent, + batch_size=16, + vocab_size=128, + hidden_size=64, + pred_hidden_size=64, + enc_hidden_size=64, + num_hidden_layers=2, + durations=[0, 1, 2, 3, 4], + is_training=True, + dropout=0.1, # so gradient checkpointing doesn't fail + ): + # testing suite parameters + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + + # config parameters + self.hidden_size = hidden_size + self.pred_hidden_size = pred_hidden_size + self.enc_hidden_size = enc_hidden_size + self.num_hidden_layers = num_hidden_layers + self.t_length = 1 # so far only support 1 + self.u_length = 1 # so far only support 1 + self.output_seq_length = -1 + self.vocab_size = vocab_size + self.durations = durations + + def prepare_config_and_inputs(self): + enc = floats_tensor([self.batch_size, self.t_length, self.enc_hidden_size]) + pred = floats_tensor([self.batch_size, self.u_length, self.pred_hidden_size]) + config = self.get_config() + + return config, enc, pred + + def get_config(self): + return ParakeetTDTJointConfig( + num_hidden_layers=self.num_hidden_layers, + hidden_size=self.hidden_size, + pred_hidden_size=self.pred_hidden_size, + enc_hidden_size=self.enc_hidden_size, + vocab_size=self.vocab_size, + durations=self.durations, + ) + + def create_and_check_model(self, config, enc, pred): + model = ParakeetTDTJoint(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(enc, pred) + + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, config.vocab_size + 1 + len(config.durations)) + ) + + def prepare_config_and_inputs_for_common(self): + config, enc, pred = self.prepare_config_and_inputs() + inputs_dict = { + "enc": enc, + "pred": pred, + } + return config, inputs_dict + + +@require_torch +class ParakeetTDTJointModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (ParakeetTDTJoint,) if is_torch_available() else () + + test_resize_embeddings = False +
test_torch_exportable = True + has_attentions = False + is_encoder_decoder = False + + def setUp(self): + self.model_tester = ParakeetTDTJointModelTester(self) + self.config_tester = ConfigTester(self, config_class=ParakeetTDTJointConfig, has_text_modality=False, common_properties=['hidden_size']) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="this class doesn't have hidden states.") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="this class doesn't have hidden states.") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="ParakeetJoint does not use inputs_embeds") + def test_model_get_set_embeddings(self): + pass + + + class ParakeetForCTCModelTester: def __init__(self, parent, encoder_kwargs=None, is_training=True, vocab_size=128, pad_token_id=0): if encoder_kwargs is None: @@ -375,3 +608,195 @@ def test_1b_model_integration_batched(self): torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS) + + + +class ParakeetForTDTModelTester: + def __init__(self, + parent, + encoder_kwargs=None, + decoder_kwargs=None, + joint_kwargs=None, + is_training=True, + vocab_size=128, + durations=[0,1,2,3,4], + pad_token_id=0 + ): + if encoder_kwargs is None: + encoder_kwargs = {} + if decoder_kwargs is None: + decoder_kwargs = {} + if joint_kwargs is None: + joint_kwargs = {} + + self.parent = parent + self.encoder_model_tester = ParakeetEncoderModelTester(parent, **encoder_kwargs) + self.decoder_model_tester = ParakeetTDTDecoderModelTester(parent, **decoder_kwargs) + self.joint_model_tester = ParakeetTDTJointModelTester(parent, **joint_kwargs) + self.is_training = is_training + + self.batch_size = self.encoder_model_tester.batch_size + self.output_seq_length = self.encoder_model_tester.output_seq_length + self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers + self.seq_length = vocab_size + self.enc_hidden_size = self.encoder_model_tester.hidden_size + self.hidden_size = self.encoder_model_tester.hidden_size # this field is needed for test class + self.pred_hidden_size = self.decoder_model_tester.hidden_size + self.joint_hidden_size = self.joint_model_tester.hidden_size + + self.durations = durations + + self.vocab_size = vocab_size + len(self.durations) + 1 + self.pad_token_id = pad_token_id + + + def prepare_config_and_inputs(self): + _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs() + config = self.get_config() + return config, input_features, attention_mask + + def get_config(self): + return ParakeetTDTConfig.from_configs( + encoder_config=self.encoder_model_tester.get_config(), + decoder_config=self.decoder_model_tester.get_config(), + joint_config=self.joint_model_tester.get_config(), + vocab_size=self.vocab_size, + durations=self.durations, + ) + + def create_and_check_model(self, config, input_features, attention_mask): + model = ParakeetForTDT(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_features, attention_mask=attention_mask) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size)) + + def 
prepare_config_and_inputs_for_common(self): + config, input_features, attention_mask = self.prepare_config_and_inputs() + inputs_dict = { + "input_features": input_features, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class ParakeetForTDTModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (ParakeetForTDT,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": ParakeetEncoder, + "automatic-speech-recognition": ParakeetForTDT, + } + if is_torch_available() + else {} + ) + + test_attention_outputs = False + + test_resize_embeddings = False + test_torch_exportable = True + + _is_composite = True + + def setUp(self): + self.model_tester = ParakeetForTDTModelTester(self) + self.config_tester = ConfigTester(self, config_class=ParakeetTDTConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="ParakeetEncoder does not use inputs_embeds") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="batching not supported") + def test_batching_equivalence(self): + pass + + # Original function assumes vision+text model, so overwrite since Parakeet is audio+text + # Below is modified from `tests/models/granite_speech/test_modeling_granite_speech.py` + def test_sdpa_can_dispatch_composite_models(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + +@require_torch +class ParakeetForTDTIntegrationTest(unittest.TestCase): + _dataset = None + + @classmethod + def setUp(cls): + cls.checkpoint_name = "hainanx/parakeet-tdt-0.6b-v3" + cls.dtype = torch.bfloat16 + cls.processor = AutoProcessor.from_pretrained("hainanx/parakeet-tdt-0.6b-v3") + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @classmethod + def _load_dataset(cls): + # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process. 
+ if cls._dataset is None: + cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + cls._dataset = cls._dataset.cast_column( + "audio", Audio(sampling_rate=cls.processor.feature_extractor.sampling_rate) + ) + + def _load_datasamples(self, num_samples): + self._load_dataset() + ds = self._dataset + speech_samples = ds.sort("id")[:num_samples]["audio"] + return [x["array"] for x in speech_samples] + + @slow + def test_1b_model_integration(self): + """ + bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py + eustlb reproducer: https://gist.github.com/eustlb/6e9e3aa85de3f7c340ec3c36e65f2fe6 + """ + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single.json" + with open(RESULTS_PATH, "r") as f: + raw_data = json.load(f) + EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"]) + EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"] + + samples = self._load_datasamples(1) + model = ParakeetForTDT.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device) + model.eval() + model.to(torch_device) + + # -- apply + inputs = self.processor(samples) + inputs.to(torch_device, dtype=self.dtype) + predicted_ids = model.generate(**inputs) + torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS) + predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) + self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
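For reference, a minimal usage sketch of how a checkpoint converted with `convert_nemo_to_hf.py --model_type tdt` can be exercised through the `automatic-speech-recognition` pipeline wired up in this patch. The checkpoint id is the one used in the integration tests above; `sample.wav` is a placeholder path.

```python
from transformers import pipeline

# This patch registers AutoModelForTDT for the "automatic-speech-recognition" task,
# so the pipeline resolves ParakeetForTDT from the "parakeet_tdt" model type.
asr = pipeline("automatic-speech-recognition", model="hainanx/parakeet-tdt-0.6b-v3")

# Greedy TDT decoding currently assumes batch size 1, so transcribe one file at a time.
print(asr("sample.wav")["text"])
```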