# Add Parakeet TDT model support #43357
Changes from all commits
```diff
@@ -312,6 +312,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     ("paligemma", "PaliGemmaModel"),
     ("parakeet_ctc", "ParakeetForCTC"),
     ("parakeet_encoder", "ParakeetEncoder"),
+    ("parakeet_tdt", "ParakeetForTDT"),
     ("patchtsmixer", "PatchTSMixerModel"),
     ("patchtst", "PatchTSTModel"),
     ("pe_audio", "PeAudioModel"),
```

> **Contributor:** In the other PR, you'll see that he added some code for loading
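For context, these entries are `(model_type, class name)` pairs that transformers uses to resolve a checkpoint's `config.model_type` to a concrete model class. A toy sketch of that lookup (simplified for illustration; the real mapping is a lazy `OrderedDict` that also imports the class on demand):

```python
# Toy version of the auto-mapping lookup: (model_type, class_name) pairs
# resolved by the string stored in config.model_type.
MODEL_MAPPING_NAMES = dict(
    [
        ("paligemma", "PaliGemmaModel"),
        ("parakeet_ctc", "ParakeetForCTC"),
        ("parakeet_encoder", "ParakeetEncoder"),
        ("parakeet_tdt", "ParakeetForTDT"),  # new entry from this PR
        ("patchtsmixer", "PatchTSMixerModel"),
    ]
)


def resolve_class_name(model_type: str) -> str:
    # The real implementation lazily imports the class; here we just
    # return the name to show the lookup step.
    try:
        return MODEL_MAPPING_NAMES[model_type]
    except KeyError:
        raise ValueError(f"Unrecognized model type: {model_type}")


print(resolve_class_name("parakeet_tdt"))  # → ParakeetForTDT
```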
```diff
@@ -21,6 +21,7 @@
     from .configuration_parakeet import *
     from .feature_extraction_parakeet import *
     from .modeling_parakeet import *
+    from .processing_parakeet import *
     from .tokenization_parakeet_fast import *
 else:
     import sys
```

> **Contributor:** this import should also be fixed? as there is no
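The `else: import sys` branch belongs to transformers' lazy-import pattern: at runtime the package module is replaced by a `_LazyModule` that only imports `configuration_parakeet`, `processing_parakeet`, etc. when an attribute is first accessed. A minimal stand-in sketch of that idea (`ToyLazyModule` and the `json` mapping are hypothetical, used here only so the example is self-contained):

```python
import importlib
import types


class ToyLazyModule(types.ModuleType):
    """Minimal stand-in for transformers' _LazyModule: defers the real
    submodule import until an attribute is first accessed."""

    def __init__(self, name, attr_to_submodule):
        super().__init__(name)
        self._attr_to_submodule = attr_to_submodule

    def __getattr__(self, attr):
        if attr not in self._attr_to_submodule:
            raise AttributeError(attr)
        # Import the submodule only now, then cache the attribute.
        submodule = importlib.import_module(self._attr_to_submodule[attr])
        value = getattr(submodule, attr)
        setattr(self, attr, value)
        return value


# Stand-in mapping: in the real package this would point at
# configuration_parakeet, processing_parakeet, and so on.
lazy = ToyLazyModule("toy_parakeet", {"dumps": "json"})
print(lazy.dumps({"a": 1}))  # → {"a": 1}
```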
```diff
@@ -229,4 +229,113 @@ def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs):
         return cls(encoder_config=encoder_config.to_dict(), **kwargs)
 
 
-__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig"]
+class ParakeetTDTConfig(PreTrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`ParakeetForTDT`]. It is used to instantiate a
+    Parakeet TDT (Token Duration Transducer) model according to the specified arguments, defining the model
+    architecture.
+
+    TDT models jointly predict tokens and their durations, enabling accurate speech recognition with word-level
+    timestamps. Unlike CTC, which only provides character-level timing, TDT predicts how many frames each token spans.
+
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 8192):
+            Vocabulary size of the model (SentencePiece tokenizer). TDT uses a larger vocabulary than CTC.
```

> **Contributor:** We can be more concise here: "Vocabulary size of the model."
```diff
+        decoder_hidden_size (`int`, *optional*, defaults to 640):
+            Hidden size of the LSTM prediction network (decoder).
+        decoder_num_layers (`int`, *optional*, defaults to 1):
```

> **Contributor:** Let's rename to
```diff
+            Number of LSTM layers in the prediction network.
+        joint_hidden_size (`int`, *optional*, defaults to 640):
+            Hidden size of the joint network that combines encoder and decoder outputs.
+        num_duration_bins (`int`, *optional*, defaults to 5):
+            Number of duration bins for predicting token durations. Each bin represents how many frames
+            to advance (e.g., [0, 1, 2, 3, 4] means 0-4 frames).
+        encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*):
+            The config object or dictionary of the encoder. TDT reuses the same FastConformer encoder as CTC.
```

> **Contributor:** We can be more concise here: "The config object or dictionary of the encoder."

```diff
+        blank_token_id (`int`, *optional*, defaults to 8192):
+            Token ID for the blank symbol. In TDT, blank indicates "no token emission, advance frames".
+        pad_token_id (`int`, *optional*, defaults to 8192):
+            Padding token id. Defaults to `blank_token_id`.
```

> **Contributor** (on lines +257 to +260): Do we need both? Or can we just define
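The duration-bin and blank semantics described in the docstring can be illustrated with a toy greedy TDT decoding loop. This is a sketch under the docstring's assumptions (blank sits at index `vocab_size`, each duration bin index equals the number of frames to advance), not the PR's actual decoder; real TDT also allows a duration of 0 to emit several tokens on the same frame, which this toy avoids by forcing a minimum advance of 1:

```python
import numpy as np


def toy_tdt_greedy_decode(token_logits, duration_logits, blank_id):
    """Pick the best token and the best duration bin at each step; emit the
    token unless it is blank, then jump ahead by the predicted duration."""
    tokens, frame_starts = [], []
    t, num_frames = 0, token_logits.shape[0]
    while t < num_frames:
        token = int(token_logits[t].argmax())
        duration = int(duration_logits[t].argmax())  # bin index == frames to advance
        if token != blank_id:
            tokens.append(token)
            frame_starts.append(t)  # frame-level timestamp for this token
        t += max(duration, 1)  # force progress so the loop terminates
    return tokens, frame_starts


# Mock joint-network outputs: vocabulary of 8 plus blank at index 8, 5 duration bins.
rng = np.random.default_rng(0)
token_logits = rng.normal(size=(12, 9))
duration_logits = rng.normal(size=(12, 5))
tokens, starts = toy_tdt_greedy_decode(token_logits, duration_logits, blank_id=8)
```

Because the frame pointer advances by at least one frame per step, the recorded `frame_starts` are strictly increasing, which is what gives TDT its word-level timestamps and its speed advantage over frame-by-frame RNN-T decoding.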
````diff
+
+    Example:
+    ```python
+    >>> from transformers import ParakeetForTDT, ParakeetTDTConfig
+
+    >>> # Initializing a Parakeet TDT configuration
+    >>> configuration = ParakeetTDTConfig()
+
+    >>> # Initializing a model from the configuration
+    >>> model = ParakeetForTDT(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+
+    This configuration class is based on the Parakeet TDT architecture from NVIDIA NeMo. TDT (Token Duration
+    Transducer) extends RNN-T by jointly predicting token durations, enabling efficient frame skipping during
+    decoding. You can find more details and pre-trained models at:
+    - [nvidia/parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2) (English)
+    - [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) (25 languages)
+
+    References:
+    - TDT Paper: https://arxiv.org/abs/2304.06795
+    - FastConformer Paper: https://arxiv.org/abs/2305.05084
````

> **Contributor** (on lines +282 to +284): We can remove this
```diff
+    """
+
+    model_type = "parakeet_tdt"
+    sub_configs = {"encoder_config": ParakeetEncoderConfig}
+
+    def __init__(
+        self,
+        vocab_size=8192,
+        decoder_hidden_size=640,
+        decoder_num_layers=1,
+        joint_hidden_size=640,
```

> **Contributor:** is there a case where
```diff
+        num_duration_bins=5,
+        encoder_config: dict | ParakeetEncoderConfig = None,
+        blank_token_id=8192,
+        pad_token_id=8192,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.decoder_hidden_size = decoder_hidden_size
+        self.decoder_num_layers = decoder_num_layers
+        self.joint_hidden_size = joint_hidden_size
+        self.num_duration_bins = num_duration_bins
+        self.blank_token_id = blank_token_id
+
+        if blank_token_id != vocab_size:
+            logger.warning(
+                f"blank_token_id ({blank_token_id}) should equal vocab_size ({vocab_size}) "
+                "for correct embedding table sizing. The embedding table size is vocab_size + 1 "
+                "to accommodate the blank token at index vocab_size."
+            )
+
+        if isinstance(encoder_config, dict):
+            self.encoder_config = ParakeetEncoderConfig(**encoder_config)
+        elif encoder_config is None:
+            self.encoder_config = ParakeetEncoderConfig()
+        else:
+            self.encoder_config = encoder_config
+
+        self.initializer_range = self.encoder_config.initializer_range
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            **kwargs,
+        )
+
+    @classmethod
+    def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs):
+        r"""
+        Instantiate a [`ParakeetTDTConfig`] (or a derived class) from parakeet encoder model configuration.
+
+        Returns:
+            [`ParakeetTDTConfig`]: An instance of a configuration object
+        """
+        return cls(encoder_config=encoder_config.to_dict(), **kwargs)
+
+
+__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig", "ParakeetTDTConfig"]
```

> **Contributor:** Could you shift this up with the other configs?
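The coercion and validation logic in `__init__` above can be exercised in isolation. A minimal standalone mock with no transformers dependency (`ToyTDTConfig` and `ToyEncoderConfig` are hypothetical stand-ins for the PR's classes, kept only to the behavior visible in the hunk: the `blank_token_id` warning and the `encoder_config` dict/None/object handling):

```python
import logging

logging.basicConfig()
logger = logging.getLogger("parakeet_tdt_sketch")


class ToyEncoderConfig:
    """Stand-in for ParakeetEncoderConfig."""

    def __init__(self, hidden_size=1024, initializer_range=0.02):
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range


class ToyTDTConfig:
    """Mirrors the coercion/validation in ParakeetTDTConfig.__init__ above."""

    def __init__(self, vocab_size=8192, blank_token_id=8192, encoder_config=None):
        self.vocab_size = vocab_size
        self.blank_token_id = blank_token_id

        # Warn (don't fail) on a mismatched blank id, as the PR does.
        if blank_token_id != vocab_size:
            logger.warning(
                f"blank_token_id ({blank_token_id}) should equal vocab_size ({vocab_size})"
            )

        # Accept a plain dict, an existing config object, or nothing (defaults).
        if isinstance(encoder_config, dict):
            self.encoder_config = ToyEncoderConfig(**encoder_config)
        elif encoder_config is None:
            self.encoder_config = ToyEncoderConfig()
        else:
            self.encoder_config = encoder_config

        # Initialization scale is inherited from the encoder sub-config.
        self.initializer_range = self.encoder_config.initializer_range


cfg = ToyTDTConfig(encoder_config={"hidden_size": 512})
```

Accepting a dict keeps the config round-trippable: `to_dict()` serializes the nested encoder config as a plain dict, and the constructor rebuilds the object on load, which is exactly what `from_encoder_config` relies on.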