Parakeet tdt #44171

Open

lmaksym wants to merge 73 commits into huggingface:main from lmaksym:parakeet-tdt

Conversation


@lmaksym lmaksym commented Feb 20, 2026

What does this PR do?

This PR adds TDT decoder support for Parakeet ASR models, extending the existing CTC-only implementation.
It incorporates the initial TDT integration work from #41545 by @hainan-xv (which was not merged) and addresses all review feedback from both #41545 and #43357.

Changes

  • ParakeetForTDT model with greedy TDT decoding in generate()
  • ParakeetTDTDecoder (LSTM prediction network) and ParakeetTDTJointNetwork as nn.Module subclasses
  • Per-token timestamp generation via return_timestamps=True (see the usage sketch after this list)
  • AutoModelForTDT auto class with pipeline, processor, and tokenizer integration
  • Flat ParakeetTDTConfig matching the CTC pattern (no nested decoder/joint configs)
  • Shared ParakeetPreTrainedModel base between CTC and TDT (no separate TDT base class)
  • NeMo-to-HF weight conversion script for TDT models
  • Documentation and tests following existing CTC patterns
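
A minimal usage sketch of the new TDT path (hypothetical, not code from this PR; it assumes the TDT generation and processor APIs mirror the existing Parakeet CTC ones, and uses the checkpoint listed under Validation below):

```python
# Hypothetical usage sketch, assuming AutoModelForTDT / the Parakeet processor
# behave like the existing Parakeet CTC path; adjust to the final API if it differs.
import torch
from datasets import load_dataset
from transformers import AutoModelForTDT, AutoProcessor

model_id = "MaksL/parakeet-tdt-0.6b-v3"  # checkpoint referenced in this PR
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForTDT.from_pretrained(model_id)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(**inputs)  # greedy TDT decoding

print(processor.batch_decode(outputs, skip_special_tokens=True))

# Per-token timestamps (as described in the Changes list above):
# outputs_with_timestamps = model.generate(**inputs, return_timestamps=True)
```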

Validation

  • 278 unit tests pass, make check-repo passes
  • CTC model unaffected by changes
  • LibriSpeech test-clean: 2.09% WER (matches NVIDIA's published ~2-3%; see the WER sketch after this list)
  • Timestamps validated against commercial ASR (94.3% within 2 frames)
  • Model: MaksL/parakeet-tdt-0.6b-v3
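
A WER figure like the one above can be reproduced with the evaluate library (a generic sketch, not the exact script used for this PR; the strings are placeholders):

```python
# Generic WER computation sketch using the `evaluate` library (requires `jiwer`);
# the predictions/references below are placeholders, not LibriSpeech outputs.
import evaluate

wer_metric = evaluate.load("wer")
predictions = ["hello world", "the cat sat on the mat"]
references = ["hello world", "the cat sat on a mat"]
print(wer_metric.compute(predictions=predictions, references=references))
```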

Before submitting

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ebezzam and @hainan-xv please review


Hainan Xu and others added 2 commits February 20, 2026 09:45
Implement Token-and-Duration Transducer (TDT) decoding for Parakeet models,
extending the existing CTC-only support. This adds ParakeetForTDT with greedy
TDT decoding in generate(), per-token timestamp generation, and full
integration with AutoModelForTDT, processors, and ASR pipeline.
Contributor

@ebezzam ebezzam left a comment


@lmaksym thank you for putting the PRs together so cleanly! I pushed a few changes to adapt to Transformers conventions and added integration tests comparing against the original model from NeMo.

@hainan-xv and @nithinraok, your input could be useful for the TDT decoding, and also the loss computation.

Comment thread src/transformers/models/parakeet/modular_parakeet.py Outdated
Comment thread src/transformers/models/parakeet/modular_parakeet.py Outdated
- Use -100 label padding for training (HF convention); see the sketch after this list
- Fix timestamp recording in inner blank-seeking loop
- Add max_symbols_per_step guard matching NeMo
- Clean up decoding loop
- Add TDT training example to docs
- Use setUpClass for TDT integration tests
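
For reference, the -100 convention means padded label positions are ignored by the loss; a small self-contained illustration (hypothetical helper, not code from this PR):

```python
# Pad variable-length label sequences with -100 so padded positions are ignored
# by the loss (standard HF convention); hypothetical helper for illustration only.
import torch

def pad_labels(label_lists, pad_value=-100):
    max_len = max(len(labels) for labels in label_lists)
    padded = torch.full((len(label_lists), max_len), pad_value, dtype=torch.long)
    for i, labels in enumerate(label_lists):
        padded[i, : len(labels)] = torch.tensor(labels, dtype=torch.long)
    return padded

print(pad_labels([[5, 12, 7], [3, 9]]))
# tensor([[   5,   12,    7],
#         [   3,    9, -100]])
```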

@hainan-xv hainan-xv left a comment


Left a comment on the loss computation part.

Comment thread src/transformers/models/parakeet/modeling_parakeet.py Outdated
Comment thread tests/models/parakeet/test_modeling_parakeet.py Outdated
Contributor

@ebezzam ebezzam left a comment


@lmaksym thanks for porting the TDT loss! it's nice (1) to not have to depend on torchaudio and (2) to make the TDT loss available in Transformers!

It is functional with this example (single GPU): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-tdt_training_snippet-py
But quite slow...

I wonder if there is a custom gradient computation in NeMo? As I noticed in the paper (Section 3.1), they say "We derive an analytical solution for the gradient of the TDT loss, since automatic differentiation for transducer loss is highly inefficient."

FYI I can test/fix on my side for multi-GPU compatibility.
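
For context on why naive autodiff is heavy here: the joint network materializes a 4-D logits tensor over batch, encoder frames, label positions, and output classes, which autograd keeps around for the backward pass. A rough back-of-envelope with illustrative sizes (the vocabulary size is a made-up number):

```python
# Back-of-envelope memory estimate for the transducer/TDT joint logits tensor
# of shape (B, T, U, V); B/T/U are illustrative sizes for a large batch,
# V is a hypothetical vocabulary size.
B, T, U, V = 8, 400, 100, 1024
bytes_fp32 = B * T * U * V * 4
print(f"~{bytes_fp32 / 1e9:.2f} GB for the joint logits alone")  # ~1.31 GB
```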

Comment thread src/transformers/models/parakeet/modular_parakeet.py Outdated
Author

lmaksym commented Mar 3, 2026

> @lmaksym thanks for porting the TDT loss! it's nice (1) to not have to depend on torchaudio and (2) to make the TDT loss available in Transformers!
>
> It is functional with this example (single GPU): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-tdt_training_snippet-py But quite slow...
>
> I wonder if there is a custom gradient computation in NeMo? As I noticed in the paper (Section 3.1), they say "We derive an analytical solution for the gradient of the TDT loss, since automatic differentiation for transducer loss is highly inefficient."
>
> FYI I can test/fix on my side for multi-GPU compatibility.

I'll look into that

Contributor

@ebezzam ebezzam left a comment


Reminders to update/check with final checkpoint and nit

Comment thread docs/source/en/model_doc/parakeet.md Outdated
Comment thread src/transformers/models/parakeet/modular_parakeet.py
Comment thread docs/source/en/model_doc/parakeet.md
Comment thread docs/source/en/model_doc/parakeet.md Outdated
Comment thread src/transformers/models/parakeet/processing_parakeet.py Outdated
eustlb approved these changes Apr 16, 2026
Contributor

@eustlb eustlb left a comment


LGTM 🚀 very nice work @ebezzam and @lmaksym

  • for the loss, I used kernels so we can have something as good as the numba implementation. Benchmarked with this script, it's looking good! Tested via loss and gradient comparison.
| Config | Kernel vs PyTorch (speed) | Kernel vs PyTorch (memory) | Kernel vs NeMo (speed) |
| --- | --- | --- | --- |
| B=1 T=50 U=15 | 309x faster | 225x less | 7.8x faster |
| B=2 T=50 U=20 | 311x faster | 250x less | 7.4x faster |
| B=4 T=100 U=30 | 296x faster | 255x less | 5.1x faster |
| B=4 T=200 U=60 | 259x faster | 256x less | 3.7x faster |
| B=8 T=200 U=60 | 245x faster | 256x less | 3.8x faster |
| B=8 T=400 U=100 | 201x faster | 241x less | 3.6x faster |
  • as you pointed out @ebezzam, it looks like LSTM layers are not compatible with torch.compile, so we can't get much more performance out of it compared to direct CUDA graphing as in the NeMo repo. I suggest we explore a solution for this in a subsequent PR

Comment on lines 288 to 290
"deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1},
"tdt-loss": {"repo_id": "eustlb/tdt-loss", "version": 1},
}
Contributor


@ErikKaum pinging you here because your YouTube kernel tutorial helped a lot for this 😊 What are the next steps to move my tdt kernel from my repo to kernels-community and compile for other environments?

Contributor


@eustlb thanks for creating the kernel! BTW I changed "version": 1 to "revision": 1, since your kernel actually lives in a v1 branch; otherwise it wasn't loading as expected because the main branch is empty.

And maybe we should also add the source to the main branch? I was a bit confused about where the content was at first 😝

I guess @ErikKaum will have best-practice tips!

Contributor


Here I just used the same convention as for other hub kernels ("version": 1 corresponding to a v1 branch), so I am not so sure about changing "version": 1 to "revision": 1.

Comment on lines +1462 to +1468
```python
supported_modes = getattr(self, "_supported_generation_modes", None)
if supported_modes is not None and generation_mode not in supported_modes:
    raise ValueError(
        f"{self.__class__.__name__} only supports {supported_modes}, but got "
        f"generation mode '{generation_mode}'."
    )
```

Contributor


added this to be able to do:

```python
class ParakeetForTDT(ParakeetPreTrainedModel, ParakeetTDTGenerationMixin):
    _supported_generation_modes = [GenerationMode.GREEDY_SEARCH]
```
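
To make the intent concrete, here is a small self-contained toy version of the same guard pattern (illustrative only, not transformers code):

```python
# Toy illustration of the getattr-based guard above: a mixin checks an optional
# class attribute and rejects unsupported generation modes with a clear error.
class GenerationGuardMixin:
    def _validate_mode(self, generation_mode: str) -> None:
        supported_modes = getattr(self, "_supported_generation_modes", None)
        if supported_modes is not None and generation_mode not in supported_modes:
            raise ValueError(
                f"{self.__class__.__name__} only supports {supported_modes}, but got "
                f"generation mode '{generation_mode}'."
            )

class ToyTDTModel(GenerationGuardMixin):
    _supported_generation_modes = ["greedy_search"]

ToyTDTModel()._validate_mode("greedy_search")   # passes silently
# ToyTDTModel()._validate_mode("sample")        # would raise ValueError
```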

Contributor

@ebezzam ebezzam left a comment


@eustlb I did a run-through of the tests and examples, and everything is passing!

Except for the kernel, but that's down to my torch setup. On that note, I think we can improve the PyTorch fallback handling?


Comment on lines +376 to +377
```python
# Since we only read from `_HUB_KERNEL_MAPPING`, we can allow all kernels
kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True)
```
Contributor


Can we hardcode allow_all_kernels=True since we only read kernels from the library-defined _HUB_KERNEL_MAPPING?

```diff
-except FileNotFoundError:
+except FileNotFoundError as e:
     mapping[kernel_name] = None
+    logger.warning_once(f"Failed to load kernel {kernel_name}: {e}")
```
Contributor


Adding a helpful error message; otherwise the kernel may silently fail to load without notifying the user, e.g. due to a different torch version.

For example it will now print:

```
[transformers] Failed to load kernel tdt-loss: Cannot find a build variant for this system in eustlb/tdt-loss (revision: v1). Available variants: torch211-cxx11-cu128-x86_64-linux
```

Comment on lines +28 to +37
```python
    kernel = lazy_load_kernel("tdt-loss")
    if kernel is None or not hasattr(kernel, "tdt_loss"):
        logger.warning_once("Falling back to pure PyTorch implementation.")
        return None
    return kernel
except (ImportError, ModuleNotFoundError):
    return None
except Exception as e:
    logger.warning_once(f"Failed to load TDT CUDA kernel: {e}. Falling back to pure PyTorch implementation.")
    return None
```
Contributor


Since there is already error handling in lazy_load_kernel, maybe we don't need error handling here as well? Or try to upstream it to lazy_load_kernel.
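
A rough sketch of what that simplification could look like (hypothetical, assuming lazy_load_kernel already logs and returns None on any load failure; the import path for lazy_load_kernel is omitted since it is defined by this PR):

```python
# Hypothetical simplification sketch: rely on lazy_load_kernel for all error
# handling/logging and only check the result here. Not code from this PR.
from transformers.utils import logging

logger = logging.get_logger(__name__)

def get_tdt_loss_kernel():
    # lazy_load_kernel is the helper discussed above (assumed to return None and
    # warn on any load failure, including missing build variants).
    kernel = lazy_load_kernel("tdt-loss")
    if kernel is None or not hasattr(kernel, "tdt_loss"):
        logger.warning_once("TDT CUDA kernel unavailable; falling back to pure PyTorch implementation.")
        return None
    return kernel
```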

Contributor


yep agree here



```python
@auto_docstring
class LasrProcessor(ProcessorMixin):
```
Contributor

@ebezzam ebezzam Apr 17, 2026


Now that the Parakeet processor handles TDT decoding, it's simpler to create a new LasrProcessor than to override nearly everything from Parakeet's processor.

Contributor

ebezzam commented Apr 20, 2026

run-slow: parakeet

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/parakeet"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit | Description |
| --- | --- | --- |
| RUN | 5e6d1f99 | workflow commit (merge commit) |
| PR | fd9f8b1b | branch commit (from PR) |
| main | ad0c0f9a | base commit (on main) |

✅ No failing test specific to this PR 🎉 👏 !

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, encodec, lasr, parakeet
