14 changes: 14 additions & 0 deletions docs/source/en/model_doc/parakeet.md
@@ -34,6 +34,12 @@ Parakeet models, [introduced by NVIDIA NeMo](https://developer.nvidia.com/blog/p
- 1D convolution projection from encoder hidden size to vocabulary size (for optimal NeMo compatibility).
- CTC loss computation for training.
- Greedy CTC decoding for inference.
- [**ParakeetForTDT**](#parakeetfortdt): a Fast Conformer Encoder + a Token-and-Duration Transducer (TDT) decoder
- **TDT Decoder**: A transducer-based decoder that predicts both tokens and their durations:
- Prediction network (LSTM-based) for autoregressive token prediction.
- Joint network combining encoder and decoder outputs.
- Separate token and duration heads for parallel prediction.
- Greedy TDT decoding with optional timestamp generation (see the sketch below).
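
To make the decoding scheme concrete, here is a minimal greedy TDT decoding sketch. It is illustrative only: `predict` and `joint` are hypothetical stand-ins for the prediction and joint networks, not the actual `ParakeetForTDT` API.

```python
def greedy_tdt_decode(enc, predict, joint, blank_id, durations=(0, 1, 2, 3, 4), max_symbols=10):
    """enc: torch.Tensor of shape (time, hidden) with encoder frames.
    `predict(token, state)` and `joint(frame, dec)` are hypothetical callables
    standing in for the prediction and joint networks."""
    tokens, t, emitted = [], 0, 0
    dec, state = predict(blank_id, None)  # prime the decoder with the blank/SOS symbol
    while t < enc.size(0):
        token_logits, dur_logits = joint(enc[t], dec)  # two parallel heads
        token = int(token_logits.argmax())
        step = durations[int(dur_logits.argmax())]  # how many frames to advance
        if token != blank_id:
            tokens.append((token, t))  # the frame index doubles as a timestamp
            dec, state = predict(token, state)  # autoregressive update
            emitted += 1
        if token == blank_id or emitted >= max_symbols:
            step = max(step, 1)  # force progress so the loop terminates
        if step > 0:
            emitted = 0
        t += step
    return tokens
```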

The original implementation can be found in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).
Model checkpoints can be found under [the NVIDIA organization](https://huggingface.co/nvidia/models?search=parakeet).
@@ -219,3 +225,11 @@ outputs.loss.backward()
## ParakeetForCTC

[[autodoc]] ParakeetForCTC

## ParakeetTDTConfig

[[autodoc]] ParakeetTDTConfig
Contributor comment on lines +229 to +231: Could you shift this up with the other configs?


## ParakeetForTDT

[[autodoc]] ParakeetForTDT
3 changes: 1 addition & 2 deletions src/transformers/convert_slow_tokenizer.py
@@ -1770,9 +1770,8 @@ def __init__(self, vocab_file=None, *args):

def tokenizer(self, proto):
vocab_scores = self.vocab(proto)

- _, merges = self.SpmExtractor(self.vocab_file).extract(vocab_scores)
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
+ merges = generate_merges(bpe_vocab, vocab_scores)
tokenizer = Tokenizer(
BPE(
bpe_vocab,
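For context, here is a rough sketch of how BPE merges can be recovered from a scored vocabulary, approximating the idea behind a helper like `generate_merges` (an assumption-laden sketch, not that function's actual implementation):

```python
def derive_merges(bpe_vocab, vocab_scores):
    # bpe_vocab: {token: id}; vocab_scores: [(token, score), ...].
    # A merge (left, right) is plausible when both halves are themselves
    # vocab tokens; rank candidates by the merged token's score so that
    # higher-scoring merges are applied first.
    candidates = []
    for token, score in vocab_scores:
        for i in range(1, len(token)):
            left, right = token[:i], token[i:]
            if left in bpe_vocab and right in bpe_vocab:
                candidates.append((score, left, right))
    candidates.sort(key=lambda c: c[0], reverse=True)
    return [(left, right) for _, left, right in candidates]
```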
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -316,6 +316,7 @@
("paligemma", "PaliGemmaConfig"),
("parakeet_ctc", "ParakeetCTCConfig"),
("parakeet_encoder", "ParakeetEncoderConfig"),
("parakeet_tdt", "ParakeetTDTConfig"),
("patchtsmixer", "PatchTSMixerConfig"),
("patchtst", "PatchTSTConfig"),
("pe_audio", "PeAudioConfig"),
@@ -792,6 +793,7 @@
("parakeet", "Parakeet"),
("parakeet_ctc", "Parakeet"),
("parakeet_encoder", "ParakeetEncoder"),
("parakeet_tdt", "ParakeetTDT"),
("patchtsmixer", "PatchTSMixer"),
("patchtst", "PatchTST"),
("pe_audio", "PeAudio"),
@@ -1032,6 +1034,7 @@
("parakeet_encoder", "parakeet"),
("lw_detr_vit", "lw_detr"),
("parakeet_ctc", "parakeet"),
("parakeet_tdt", "parakeet"),
("lasr_encoder", "lasr"),
("lasr_ctc", "lasr"),
("wav2vec2-bert", "wav2vec2_bert"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/feature_extraction_auto.py
@@ -59,6 +59,7 @@
("musicgen_melody", "MusicgenMelodyFeatureExtractor"),
("parakeet_ctc", "ParakeetFeatureExtractor"),
("parakeet_encoder", "ParakeetFeatureExtractor"),
("parakeet_tdt", "ParakeetFeatureExtractor"),
("pe_audio", "PeAudioFeatureExtractor"),
("pe_audio_video", "PeAudioFeatureExtractor"),
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -312,6 +312,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("paligemma", "PaliGemmaModel"),
("parakeet_ctc", "ParakeetForCTC"),
("parakeet_encoder", "ParakeetEncoder"),
("parakeet_tdt", "ParakeetForTDT"),
Contributor comment: In the other PR, you'll see that he added some code for loading AutoModelForTDT. We may want to keep that (still need to think about it) when you apply your changes there.

("patchtsmixer", "PatchTSMixerModel"),
("patchtst", "PatchTSTModel"),
("pe_audio", "PeAudioModel"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -121,6 +121,7 @@
("owlvit", "OwlViTProcessor"),
("paddleocr_vl", "PaddleOCRVLProcessor"),
("paligemma", "PaliGemmaProcessor"),
("parakeet_tdt", "ParakeetProcessor"),
("perception_lm", "PerceptionLMProcessor"),
("phi4_multimodal", "Phi4MultimodalProcessor"),
("pix2struct", "Pix2StructProcessor"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -237,6 +237,7 @@
("ovis2", "Qwen2Tokenizer" if is_tokenizers_available() else None),
("owlv2", "CLIPTokenizer" if is_tokenizers_available() else None),
("owlvit", "CLIPTokenizer" if is_tokenizers_available() else None),
("parakeet_tdt", "ParakeetTokenizer" if is_tokenizers_available() else None),
("pegasus", "PegasusTokenizer" if is_tokenizers_available() else None),
("pegasus_x", "PegasusTokenizer" if is_tokenizers_available() else None),
("perceiver", "PerceiverTokenizer"),
1 change: 1 addition & 0 deletions src/transformers/models/parakeet/__init__.py
@@ -21,6 +21,7 @@
from .configuration_parakeet import *
from .feature_extraction_parakeet import *
from .modeling_parakeet import *
from .processing_parakeet import *
from .tokenization_parakeet_fast import *
Contributor comment: Shouldn't this import also be fixed, as there is no tokenization_parakeet_fast?

else:
import sys
111 changes: 110 additions & 1 deletion src/transformers/models/parakeet/configuration_parakeet.py
@@ -229,4 +229,113 @@ def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs):
return cls(encoder_config=encoder_config.to_dict(), **kwargs)


__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig"]
class ParakeetTDTConfig(PreTrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ParakeetForTDT`]. It is used to instantiate a
Parakeet TDT (Token-and-Duration Transducer) model according to the specified arguments, defining the model architecture.

TDT models jointly predict tokens and their durations, enabling accurate speech recognition with word-level
timestamps. Unlike CTC, which only provides character-level timing, TDT predicts how many frames each token spans.

Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.

Args:
vocab_size (`int`, *optional*, defaults to 8192):
Vocabulary size of the model (SentencePiece tokenizer). TDT uses a larger vocabulary than CTC.
Contributor comment: We can be more concise here: "Vocabulary size of the model."

decoder_hidden_size (`int`, *optional*, defaults to 640):
Hidden size of the LSTM prediction network (decoder).
decoder_num_layers (`int`, *optional*, defaults to 1):
Contributor comment: Let's rename to num_decoder_layers.

Number of LSTM layers in the prediction network.
joint_hidden_size (`int`, *optional*, defaults to 640):
Hidden size of the joint network that combines encoder and decoder outputs.
num_duration_bins (`int`, *optional*, defaults to 5):
Number of duration bins for predicting token durations. Each bin represents how many frames
to advance (e.g., [0, 1, 2, 3, 4] means 0-4 frames).
encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*):
The config object or dictionary of the encoder. TDT reuses the same FastConformer encoder as CTC.
Contributor comment: We can be more concise here: "The config object or dictionary of the encoder."

blank_token_id (`int`, *optional*, defaults to 8192):
Token ID for the blank symbol. In TDT, blank indicates "no token emission, advance frames".
pad_token_id (`int`, *optional*, defaults to 8192):
Padding token id. Defaults to blank_token_id.
Contributor comment on lines +257 to +260: Do we need both? Or can we just define pad_token_id like in the CTC model?


Example:
```python
>>> from transformers import ParakeetForTDT, ParakeetTDTConfig

>>> # Initializing a Parakeet TDT configuration
>>> configuration = ParakeetTDTConfig()

>>> # Initializing a model from the configuration
>>> model = ParakeetForTDT(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```

This configuration class is based on the Parakeet TDT architecture from NVIDIA NeMo. TDT (Token-and-Duration
Transducer) extends RNN-T by jointly predicting token durations, enabling efficient frame skipping during
decoding. You can find more details and pre-trained models at:
- [nvidia/parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2) (English)
- [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) (25 languages)

References:
- TDT Paper: https://arxiv.org/abs/2304.06795
- FastConformer Paper: https://arxiv.org/abs/2305.05084
Contributor comment on lines +282 to +284: We can remove this.

"""

model_type = "parakeet_tdt"
sub_configs = {"encoder_config": ParakeetEncoderConfig}

def __init__(
self,
vocab_size=8192,
decoder_hidden_size=640,
decoder_num_layers=1,
joint_hidden_size=640,
Contributor comment: Is there a case where decoder_hidden_size and joint_hidden_size are different? If not, let's use one hidden_size parameter.

num_duration_bins=5,
encoder_config: dict | ParakeetEncoderConfig = None,
blank_token_id=8192,
pad_token_id=8192,
**kwargs,
):
self.vocab_size = vocab_size
self.decoder_hidden_size = decoder_hidden_size
self.decoder_num_layers = decoder_num_layers
self.joint_hidden_size = joint_hidden_size
self.num_duration_bins = num_duration_bins
self.blank_token_id = blank_token_id

if blank_token_id != vocab_size:
logger.warning(
f"blank_token_id ({blank_token_id}) should equal vocab_size ({vocab_size}) "
"for correct embedding table sizing. The embedding table size is vocab_size + 1 "
"to accommodate the blank token at index vocab_size."
)

if isinstance(encoder_config, dict):
self.encoder_config = ParakeetEncoderConfig(**encoder_config)
elif encoder_config is None:
self.encoder_config = ParakeetEncoderConfig()
else:
self.encoder_config = encoder_config

self.initializer_range = self.encoder_config.initializer_range

super().__init__(
pad_token_id=pad_token_id,
**kwargs,
)

@classmethod
def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs):
r"""
Instantiate a [`ParakeetTDTConfig`] (or a derived class) from a Parakeet encoder model configuration.

Returns:
[`ParakeetTDTConfig`]: An instance of a configuration object
"""
return cls(encoder_config=encoder_config.to_dict(), **kwargs)


__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig", "ParakeetTDTConfig"]
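
Assuming this PR's classes are available, here is a short usage sketch for `from_encoder_config`, relying only on the defaults shown in the diff above:

```python
from transformers import ParakeetEncoderConfig, ParakeetTDTConfig

# Reuse a FastConformer encoder configuration for a TDT model.
encoder_config = ParakeetEncoderConfig()
config = ParakeetTDTConfig.from_encoder_config(encoder_config, num_duration_bins=5)

# With the defaults, blank_token_id == vocab_size: the embedding table has
# vocab_size + 1 rows, and the blank token lives at index vocab_size.
assert config.blank_token_id == config.vocab_size == 8192
```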