
Add Parakeet TDT model support #43357

Closed
lmaksym wants to merge 4 commits into huggingface:main from lmaksym:add_support_parakeet_tdt

Conversation

@lmaksym

@lmaksym lmaksym commented Jan 19, 2026

What does this PR do?

Add TDT (Token-and-Duration Transducer) decoder support for Parakeet ASR models.

TDT is a transducer-based architecture that jointly predicts tokens and their durations, enabling efficient decoding with accurate word-level timestamps. Unlike CTC, TDT can skip multiple frames at once based on predicted duration.
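
For intuition, here is a minimal, self-contained sketch of the greedy TDT decoding loop described above. All names (the additive joint, the duration bins, the helper itself) are illustrative and not the actual implementation in this PR:

```python
import torch

def greedy_tdt_decode(encoder_out, decoder, token_head, dur_head,
                      blank_id=0, durations=(0, 1, 2, 3, 4)):
    """Toy greedy TDT loop: at each step the joint output predicts a token
    and a duration; the duration says how many encoder frames to skip
    before the next prediction (unlike CTC, which visits every frame).
    Real implementations also cap the number of tokens emitted per frame."""
    tokens, timestamps = [], []
    t, num_frames = 0, encoder_out.size(1)
    dec_out, state = decoder(torch.full((1, 1), blank_id, dtype=torch.long), None)
    while t < num_frames:
        joint_in = encoder_out[:, t] + dec_out[:, -1]          # simple additive joint for the sketch
        token = token_head(joint_in).argmax(-1).item()
        skip = durations[dur_head(joint_in).argmax(-1).item()]
        if token != blank_id:
            tokens.append(token)
            timestamps.append(t)                               # frame index -> word-level timestamp
            dec_out, state = decoder(torch.full((1, 1), token, dtype=torch.long), state)
        t += skip if token != blank_id else max(skip, 1)       # never stall on blank + duration 0
    return tokens, timestamps
```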

Changes

  • Add ParakeetForTDT model class with:
    • LSTM-based prediction network (decoder)
    • Joint network combining encoder and decoder outputs
    • Separate token and duration heads
    • Greedy TDT decoding with optional timestamp generation
  • Add ParakeetTDTConfig configuration class
  • Add ParakeetForTDTIntegrationTest with exact output matching
  • Add fixture generation script for reproducible tests
  • Update documentation with TDT model description and API reference

Notes

  • Integration tests currently use MaksL/parakeet-tdt-0.6b-v3 (converted HF format)
  • TODO: Update to nvidia/parakeet-tdt-0.6b-v3 once NVIDIA adds HF format to their repo

References

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@eustlb @ebezzam @vasqu - audio models

@ebezzam
Contributor

ebezzam commented Jan 21, 2026

@lmaksym Thanks for the PR! There's already an ongoing one for TDT #41545

Have you taken a look at it?

@lmaksym
Author

lmaksym commented Jan 21, 2026

@ebezzam Sorry, I hadn't found it earlier. It appears to be stuck. What do you suggest we do next? Should I check whether I'm following the principles outlined in the PR review, or should I wait until it's accepted?

@lmaksym
Author

lmaksym commented Jan 21, 2026

I mean wait until there's progress on the ongoing PR.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, parakeet

@lmaksym lmaksym closed this Jan 21, 2026
@lmaksym lmaksym reopened this Jan 21, 2026
@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43357&sha=6e9422

@lmaksym
Author

lmaksym commented Jan 21, 2026

> @lmaksym Thanks for the PR! There's already an ongoing one for TDT #41545
>
> Have you taken a look at it?

Reviewed it and added a fix for the parts that went against the principles mentioned in the PR comments.

@ebezzam
Contributor

ebezzam commented Jan 21, 2026

@lmaksym yes that branch is still waiting on updates. But since it was already started, could you branch off from the fork/branch of that PR? Namely from here: https://github.com/hainan-xv/transformers/tree/hf_transformer_pr

And add your changes / address my comments on #41545

You can then open a new PR so both yours and @hainan-xv's contributions are taken into account. Thanks and let me know if it's unclear!

@lmaksym
Author

lmaksym commented Jan 22, 2026

> @lmaksym yes that branch is still waiting on updates. But since it was already started, could you branch off from the fork/branch of that PR? Namely from here: https://github.com/hainan-xv/transformers/tree/hf_transformer_pr
>
> And add your changes / address my comments on #41545
>
> You can then open a new PR so both yours and @hainan-xv's contributions are taken into account. Thanks and let me know if it's unclear!

@ebezzam Sounds good. Would you mind taking a quick look before I move the code? It’ll help me move faster with that PR, if that works for you.

Contributor

@ebezzam ebezzam left a comment


@lmaksym thanks for the PR! I've left some initial comments. Let's have another iteration after you have:

  1. Forked the other fork/branch.
  2. Moved your code and addressed these comments.
  3. Opened a new PR.

I had a small chat with @hainan-xv who started the other PR, and he's happy to have your contributions and also provide his feedback (as he's from NVIDIA).

Thanks 🤗

("paligemma", "PaliGemmaModel"),
("parakeet_ctc", "ParakeetForCTC"),
("parakeet_encoder", "ParakeetEncoder"),
("parakeet_tdt", "ParakeetForTDT"),
Contributor


In the other PR, you'll see that he added some code for loading AutoModelForTDT. We may want to keep that (still need to think about it) when you apply your changes there.

Comment on lines 68 to 84
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]

inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
Contributor


Could you add example usage like this for the TDT model?

Also, I'm noticing that processor.batch_decode should use skip_special_tokens=True in the example so the output isn't full of <pad> tokens.
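
For reference, one possible shape for that TDT example, mirroring the CTC snippet above. The ParakeetForTDT class and the MaksL/parakeet-tdt-0.6b-v3 checkpoint come from this PR, and the exact generate() behavior may change as the PR evolves:

```python
from transformers import AutoProcessor, ParakeetForTDT
from datasets import load_dataset, Audio
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("MaksL/parakeet-tdt-0.6b-v3")
model = ParakeetForTDT.from_pretrained("MaksL/parakeet-tdt-0.6b-v3", dtype="auto", device_map=device)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el["array"] for el in ds["audio"][:5]]

inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs, skip_special_tokens=True))
```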

Comment on lines +229 to +231
## ParakeetTDTConfig

[[autodoc]] ParakeetTDTConfig
Contributor


Could you shift this up with the other configs?

from .feature_extraction_parakeet import *
from .modeling_parakeet import *
from .processing_parakeet import *
from .tokenization_parakeet_fast import *
Contributor


This import should also be fixed, as there is no tokenization_parakeet_fast module.


Args:
vocab_size (`int`, *optional*, defaults to 8192):
Vocabulary size of the model (SentencePiece tokenizer). TDT uses a larger vocabulary than CTC.
Contributor


We can be more concise here: "Vocabulary size of the model."


@auto_docstring
@can_return_tuple
def forward(
Contributor


See my comments for the forward method on the other PR (here). Normally we call the forward method in generate and we'd like a loss to be computed.
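
To make the suggestion concrete, here's a toy sketch of a forward() that computes a loss when labels are passed. The modules and the cross-entropy term are placeholders, not the PR's real layers or an actual transducer/TDT loss:

```python
import torch
from torch import nn

class TinyTDTForwardSketch(nn.Module):
    """Toy stand-in showing a forward() that returns a loss when labels
    are given, so generate() can call forward() directly."""

    def __init__(self, hidden_size=8, vocab_size=16):
        super().__init__()
        self.encoder = nn.Linear(hidden_size, hidden_size)   # placeholder encoder
        self.lm_head = nn.Linear(hidden_size, vocab_size)    # placeholder token head

    def forward(self, input_features, attention_mask=None, labels=None):
        hidden = self.encoder(input_features)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # stand-in loss; a TDT model would use a transducer-style loss here
            loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten())
        return {"loss": loss, "logits": logits}

# usage: out = TinyTDTForwardSketch()(torch.randn(2, 5, 8), labels=torch.randint(0, 16, (2, 5)))
```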

def forward(
self,
input_ids: torch.LongTensor,
hidden_state: tuple[torch.Tensor, torch.Tensor] | None = None,
Contributor


Split these into individual tensor inputs, namely hidden_state and cell_state.

And what we can do to avoid init_state (see here) is to initialize those tensors (as you do in init_state) when they are None.
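
A minimal sketch of what that could look like (class and argument names are illustrative, not the actual modules in this PR):

```python
import torch
from torch import nn

class PredictionNetworkSketch(nn.Module):
    """Takes hidden_state and cell_state as separate optional tensors and
    creates zero states inside forward() when they are None, so no separate
    init_state() method is needed."""

    def __init__(self, vocab_size=16, hidden_size=8, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)

    def forward(self, input_ids, hidden_state=None, cell_state=None):
        embeddings = self.embed(input_ids)
        if hidden_state is None or cell_state is None:
            shape = (self.lstm.num_layers, input_ids.size(0), self.lstm.hidden_size)
            hidden_state = embeddings.new_zeros(shape)   # zero LSTM hidden state
            cell_state = embeddings.new_zeros(shape)     # zero LSTM cell state
        output, (hidden_state, cell_state) = self.lstm(embeddings, (hidden_state, cell_state))
        return output, hidden_state, cell_state

# usage: out, h, c = PredictionNetworkSketch()(torch.randint(0, 16, (2, 3)))
```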


return output, hidden_state

def init_state(
Contributor


(Transformers convention) Let's remove this method. As much as possible we try to define modules that only have an __init__ and a forward method

Comment on lines +1005 to +1006
# Initialize decoder state with same dtype as encoder output
decoder_state = self.decoder.init_state(batch_size, device, dtype=encoder_hidden.dtype)
Contributor


Let's try to remove this method, see here

Contributor


Thanks a bunch for the reproducer script! We normally don't add them to Transformers; could you make it into a Gist and put a link to it like this?

@ebezzam
Contributor

ebezzam commented Feb 11, 2026

hi @lmaksym, checking in if you've had time to work on this?

@lmaksym
Author

lmaksym commented Feb 11, 2026

> hi @lmaksym, checking in if you've had time to work on this?

Hey @ebezzam, sorry for the delay. Yep, had some things to sort out. Going to raise the new PR by the end of the week. Is that ok?

@ebezzam
Contributor

ebezzam commented Feb 11, 2026

@lmaksym yes that works! Just wanted to make sure you were still interested in working on it.

@lmaksym lmaksym mentioned this pull request Feb 20, 2026
@lmaksym
Author

lmaksym commented Feb 20, 2026

@ebezzam New PR created: #44171, so closing this one.

@lmaksym lmaksym closed this Feb 20, 2026