Add Parakeet TDT model support #43357
Conversation
@ebezzam Sorry, I didn't find it earlier. It appears to be stuck. What do you suggest we do next? Should I check whether I'm following the principles outlined in the PR review, or should I wait until it's accepted?
I mean, should I wait until there's progress on the ongoing PR?
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, parakeet
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43357&sha=6e9422
@lmaksym yes, that branch is still waiting on updates. But since it was already started, could you branch off from the fork/branch of that PR? Namely from here: https://github.com/hainan-xv/transformers/tree/hf_transformer_pr — add your changes and address my comments on #41545. You can then open a new PR so both yours and @hainan-xv's contributions are taken into account. Thanks, and let me know if anything is unclear!
@ebezzam Sounds good. Would you mind taking a quick look before I move the code? It'll help me move faster with that PR, if that works for you.
ebezzam left a comment
@lmaksym thanks for the PR! I've left some initial comments. Let's have another iteration after you have:
- Forked the other fork/branch.
- Moved your code and addressed these comments.
- Opened a new PR.
I had a small chat with @hainan-xv who started the other PR, and he's happy to have your contributions and also provide his feedback (as he's from NVIDIA).
Thanks 🤗
| ("paligemma", "PaliGemmaModel"), | ||
| ("parakeet_ctc", "ParakeetForCTC"), | ||
| ("parakeet_encoder", "ParakeetEncoder"), | ||
| ("parakeet_tdt", "ParakeetForTDT"), |
In the other PR, you'll see that he added some code for loading AutoModelForTDT. We may want to keep that (still need to think about it) when you apply your changes there.
```python
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]

inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
```
Could you add example usage like this for the TDT model?
Also, I'm noticing that processor.batch_decode should have skip_special_tokens=True in the example so the output doesn't include all the <pad> tokens.
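For reference, a rough sketch of what the requested TDT example might look like, mirroring the CTC snippet above. The ParakeetForTDT class and the MaksL/parakeet-tdt-0.6b-v3 checkpoint are taken from this PR's description and are assumptions until the PR is merged; the generate/batch_decode flow is assumed to match the CTC model.

```python
# Hypothetical usage sketch mirroring the CTC example above; ParakeetForTDT and
# the MaksL/parakeet-tdt-0.6b-v3 checkpoint come from this PR and are not yet
# part of a released transformers version.
from transformers import AutoProcessor, ParakeetForTDT
from datasets import load_dataset, Audio
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("MaksL/parakeet-tdt-0.6b-v3")
model = ParakeetForTDT.from_pretrained("MaksL/parakeet-tdt-0.6b-v3", dtype="auto", device_map=device)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el["array"] for el in ds["audio"][:5]]

inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
# skip_special_tokens=True keeps <pad> tokens out of the transcriptions
print(processor.batch_decode(outputs, skip_special_tokens=True))
```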
```md
## ParakeetTDTConfig

[[autodoc]] ParakeetTDTConfig
```
Could you shift this up with the other configs?
```python
from .feature_extraction_parakeet import *
from .modeling_parakeet import *
from .processing_parakeet import *
from .tokenization_parakeet_fast import *
```
This import should also be fixed, as there is no tokenization_parakeet_fast module.
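If the fix is simply to drop the nonexistent module, the import block would presumably end up something like the sketch below (the exact set of files in the package is an assumption):

```python
# Sketch with the nonexistent tokenization_parakeet_fast import removed;
# which modules remain is an assumption about the package layout.
from .feature_extraction_parakeet import *
from .modeling_parakeet import *
from .processing_parakeet import *
```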
```
Args:
    vocab_size (`int`, *optional*, defaults to 8192):
        Vocabulary size of the model (SentencePiece tokenizer). TDT uses a larger vocabulary than CTC.
```
We can be more concise here "Vocabulary size of the model."
```python
@auto_docstring
@can_return_tuple
def forward(
```
See my comments for the forward method on the other PR (here). Normally we call the forward method in generate and we'd like a loss to be computed.
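As a rough illustration of that convention (not the Parakeet implementation), the pattern is that generate goes through forward, and forward returns a loss whenever labels are passed. The toy model below and its cross-entropy stand-in are purely hypothetical; the real model would compute a transducer/TDT loss.

```python
import torch
from torch import nn

class ToyTransducerStyleModel(nn.Module):
    """Toy illustration only: generate reuses forward, and forward returns a
    loss when labels are provided (a cross-entropy stand-in, not a TDT loss)."""

    def __init__(self, feature_dim=80, hidden=64, vocab_size=32):
        super().__init__()
        self.encoder = nn.Linear(feature_dim, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_features, labels=None):
        hidden = torch.relu(self.encoder(input_features))  # (batch, time, hidden)
        logits = self.head(hidden)                          # (batch, time, vocab)
        loss = None
        if labels is not None:
            # Stand-in loss so forward is trainable; the real model would use a TDT loss.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
            )
        return {"loss": loss, "logits": logits}

    @torch.no_grad()
    def generate(self, input_features):
        # generate delegates to forward instead of duplicating the encoder pass
        return self.forward(input_features)["logits"].argmax(dim=-1)
```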
```python
def forward(
    self,
    input_ids: torch.LongTensor,
    hidden_state: tuple[torch.Tensor, torch.Tensor] | None = None,
```
Split this into individual tensor inputs, namely hidden_state and cell_state.
And what we can do to avoid init_state (see here) is to initialize those tensors, as you do in init_state, when they are None.
```python
        return output, hidden_state

    def init_state(
```
(Transformers convention) Let's remove this method. As much as possible we try to define modules that only have an __init__ and a forward method
```python
# Initialize decoder state with same dtype as encoder output
decoder_state = self.decoder.init_state(batch_size, device, dtype=encoder_hidden.dtype)
```
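Putting the two suggestions above together (separate hidden_state/cell_state inputs, no init_state method), a prediction network could follow roughly the pattern below. The class and argument names are hypothetical; only the lazy state initialization inside forward is the point.

```python
import torch
from torch import nn

class ToyPredictionNetwork(nn.Module):
    """Hypothetical sketch: only __init__ and forward, with hidden_state and
    cell_state as separate tensors that are zero-initialized in forward when None."""

    def __init__(self, vocab_size: int, hidden_size: int, num_layers: int = 1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.num_layers = num_layers
        self.hidden_size = hidden_size

    def forward(
        self,
        input_ids: torch.LongTensor,
        hidden_state: torch.Tensor | None = None,
        cell_state: torch.Tensor | None = None,
    ):
        embedded = self.embed(input_ids)  # (batch, seq, hidden)
        if hidden_state is None or cell_state is None:
            # Replaces init_state: zero states built with the batch size, device
            # and dtype of the current inputs.
            zeros = embedded.new_zeros(self.num_layers, input_ids.shape[0], self.hidden_size)
            hidden_state = zeros if hidden_state is None else hidden_state
            cell_state = zeros.clone() if cell_state is None else cell_state
        output, (hidden_state, cell_state) = self.lstm(embedded, (hidden_state, cell_state))
        return output, hidden_state, cell_state
```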
Hi @lmaksym, just checking in: have you had time to work on this?
@lmaksym yes, that works! Just wanted to make sure you were still interested in working on it.
What does this PR do?
Add TDT (Token-and-Duration Transducer) decoder support for Parakeet ASR models.
TDT is a transducer-based architecture that jointly predicts tokens and their durations, enabling efficient decoding with accurate word-level timestamps. Unlike CTC, TDT can skip multiple frames at once based on predicted duration.
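For intuition, here is a toy sketch of that duration-aware greedy loop. The decoder and joint callables and the duration bins are placeholders, not this PR's implementation; the point is only that the predicted duration decides how many encoder frames are skipped.

```python
import torch

def toy_tdt_greedy_decode(encoder_out, decoder, joint, blank_id,
                          duration_bins=(0, 1, 2, 3, 4), max_symbols_per_frame=5):
    """Toy duration-aware greedy loop with placeholder interfaces:
    decoder(token, state) -> (dec_out, state)
    joint(enc_frame, dec_out) -> (token_logits, duration_logits)
    Unlike CTC, the loop does not score every frame: it jumps ahead by the
    predicted duration."""
    tokens, t, num_frames = [], 0, encoder_out.shape[1]
    dec_out, state = decoder(torch.tensor([[blank_id]]), None)
    symbols_at_frame = 0
    while t < num_frames:
        token_logits, duration_logits = joint(encoder_out[:, t], dec_out)
        token = token_logits.argmax(-1).item()
        skip = duration_bins[duration_logits.argmax(-1).item()]
        if token != blank_id and symbols_at_frame < max_symbols_per_frame:
            tokens.append(token)
            dec_out, state = decoder(torch.tensor([[token]]), state)
            symbols_at_frame += 1
        else:
            # blank (or symbol cap hit): always move forward at least one frame
            skip = max(skip, 1)
        if skip > 0:
            symbols_at_frame = 0
        t += skip
    return tokens
```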
Changes
- `ParakeetForTDT` model class
- `ParakeetTDTConfig` configuration class
- `ParakeetForTDTIntegrationTest` with exact output matching

Notes
- Checkpoint: `MaksL/parakeet-tdt-0.6b-v3` (converted HF format)
- Can switch to `nvidia/parakeet-tdt-0.6b-v3` once NVIDIA adds HF format to their repo

References
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@eustlb @ebezzam @vasqu - audio models