Parakeet TDT #44171
Conversation
Implement Token-and-Duration Transducer (TDT) decoding for Parakeet models, extending the existing CTC-only support. This adds ParakeetForTDT with greedy TDT decoding in generate(), per-token timestamp generation, and full integration with AutoModelForTDT, processors, and ASR pipeline.
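For readers unfamiliar with TDT: unlike plain RNN-T greedy decoding, the model predicts a duration alongside each token, so the decoder can skip several frames at once. Below is a toy, pure-Python sketch of such a greedy loop; all names are hypothetical and `step_fn` stands in for the joint network's argmax over token and duration heads (the real implementation lives in `ParakeetForTDT.generate()`):

```python
def greedy_tdt_decode(step_fn, num_frames, blank_id, max_symbols_per_step=10):
    """Toy greedy TDT decoding loop (illustrative sketch, not the HF code).

    step_fn(t, prev_token) -> (token_id, duration): the argmax token and
    argmax duration predicted at frame t given the previous token.
    """
    tokens, timestamps = [], []
    t, prev = 0, blank_id
    while t < num_frames:
        # Inner loop: several tokens may be emitted at the same frame when
        # the predicted duration is 0; the guard prevents infinite loops.
        for _ in range(max_symbols_per_step):
            token, duration = step_fn(t, prev)
            if token != blank_id:
                tokens.append(token)
                timestamps.append(t)  # per-token timestamp = current frame
                prev = token
            if duration > 0 or token == blank_id:
                break
        t += max(duration, 1)  # TDT advances by the predicted duration
    return tokens, timestamps
```

The key difference from plain RNN-T is the `t += max(duration, 1)` jump: blank steps can skip many frames at once, which is where TDT gets its decoding speedup.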
Force-pushed from 6c98cb8 to f2b4938
@lmaksym thank you for putting together the PRs cleanly! I pushed a few changes for adapting to Transformers conventions and added integration tests to compare with the original model from NeMo.
@hainan-xv and @nithinraok, your input could be useful for the TDT decoding, and also the loss computation.
- Use -100 label padding for training (HF convention)
- Fix timestamp recording in inner blank-seeking loop
- Add max_symbols_per_step guard matching NeMo
- Clean up decoding loop
- Add TDT training example to docs
- Use setUpClass for TDT integration tests
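On the -100 convention mentioned above: HF loss functions skip label positions set to -100 (the default `ignore_index` of `CrossEntropyLoss`), so variable-length transcripts are right-padded with -100 rather than with the tokenizer's pad token. A minimal collator sketch (hypothetical helper name):

```python
def pad_labels(batch, pad_value=-100):
    """Right-pad variable-length label sequences to a common length.

    Positions filled with -100 are ignored by HF losses (the same
    ignore_index convention as torch.nn.CrossEntropyLoss), so padding
    never contributes to the training loss.
    """
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_value] * (max_len - len(seq)) for seq in batch]
```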
Force-pushed from 7f70c24 to 760b4b6
hainan-xv left a comment
Left a comment on the loss computation part.
ebezzam left a comment
@lmaksym thanks for porting the TDT loss! it's nice (1) to not have to depend on torchaudio and (2) to make the TDT loss available in Transformers!
It is functional with this example (single GPU): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-tdt_training_snippet-py
But quite slow...
I wonder if there is a custom gradient computation in NeMo? As I noticed in the paper (Section 3.1), they say "We derive an analytical solution for the gradient of the TDT loss, since automatic differentiation for transducer loss is highly inefficient."
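For context on why autodiff is costly here: the transducer loss is a dynamic program over the whole T×U lattice, and naive automatic differentiation must backprop through every log-add in that lattice. The toy log-space forward recursion below (pure Python, heavily simplified: the blank score depends only on t, whereas the real RNN-T/TDT joint is conditioned on (t, u), and TDT additionally sums over duration choices) shows the structure:

```python
import math

def logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)) with -inf handling.
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def transducer_forward(log_emit, log_blank):
    """Toy forward DP over a T x (U+1) lattice in the log domain.

    log_emit[t][u]: log-prob of emitting label u+1 at frame t.
    log_blank[t]:   log-prob of blank at frame t (simplified to be
                    independent of u). Returns log P(labels | frames).
    Autodiff has to backprop through every logaddexp below, which is
    why an analytical gradient is much cheaper.
    """
    T, U = len(log_blank), len(log_emit[0])
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t > 0:  # arrive via blank from the previous frame
                alpha[t][u] = logaddexp(alpha[t][u], alpha[t - 1][u] + log_blank[t - 1])
            if u > 0:  # arrive by emitting label u at this frame
                alpha[t][u] = logaddexp(alpha[t][u], alpha[t][u - 1] + log_emit[t][u - 1])
    return alpha[T - 1][U] + log_blank[T - 1]  # final blank ends the path
```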
FYI I can test/fix on my side for multi-GPU compatibility.
I'll look into that
ebezzam left a comment
Reminders to update/check with final checkpoint and nit
eustlb left a comment
LGTM 🚀 very nice work @ebezzam and @lmaksym
- for the loss, I used kernels to allow us to have something as good as the numba implementation. Benchmarked with this script, it's looking good! Tested via loss and gradients comparison.
| Config | Kernel vs PyTorch (speed) | Kernel vs PyTorch (memory) | Kernel vs NeMo (speed) |
|---|---|---|---|
| B=1 T=50 U=15 | 309x faster | 225x less | 7.8x faster |
| B=2 T=50 U=20 | 311x faster | 250x less | 7.4x faster |
| B=4 T=100 U=30 | 296x faster | 255x less | 5.1x faster |
| B=4 T=200 U=60 | 259x faster | 256x less | 3.7x faster |
| B=8 T=200 U=60 | 245x faster | 256x less | 3.8x faster |
| B=8 T=400 U=100 | 201x faster | 241x less | 3.6x faster |
- as you pointed out @ebezzam, it looks like LSTM layers are not compatible with compile, meaning we cannot get much more performance from it compared to direct CUDA graphing as in the NeMo repo. I suggest we explore a solution for this in a subsequent PR
| "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, | ||
| "tdt-loss": {"repo_id": "eustlb/tdt-loss", "version": 1}, | ||
| } |
@ErikKaum pinging you here because your YouTube kernel tutorial helped a lot for this 😊 What are the next steps to move my tdt kernel from my repo to kernels-community and compile for other environments?
@eustlb thanks for creating the kernel! btw I changed from "version": 1 to "revision": 1 as your kernel is rather in a v1 branch. Otherwise it wasn't loading as expected since the main branch is empty.
And maybe we need to also add the source to the main branch? I was a bit confused where the content was at first 😝
I guess @ErikKaum will have best-practice tips!
here I just used the same convention as for other hub kernels: "version": 1 corresponding to a v1 branch so I am not so sure about changing "version": 1 to "revision": 1
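To make the convention concrete: under the scheme described here, `"version": 1` is resolved to a `v1` branch on the Hub, while an explicit `"revision"` would be used verbatim. A hypothetical resolver sketch (not the actual kernels-library code, just the mapping being discussed):

```python
def resolve_kernel_revision(spec):
    """Resolve a _HUB_KERNEL_MAPPING entry to a git revision.

    Hypothetical sketch of the convention under discussion:
    an explicit "revision" wins, "version": N maps to branch "vN",
    and otherwise the main branch is used.
    """
    if "revision" in spec:
        return spec["revision"]
    if "version" in spec:
        return f"v{spec['version']}"
    return "main"
```

This makes the confusion above visible: with `"version": 1`, nothing is ever fetched from `main`, so an empty `main` branch is expected under this convention.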
supported_modes = getattr(self, "_supported_generation_modes", None)
if supported_modes is not None and generation_mode not in supported_modes:
    raise ValueError(
        f"{self.__class__.__name__} only supports {supported_modes}, but got "
        f"generation mode '{generation_mode}'."
    )
added this to be able to do

class ParakeetForTDT(ParakeetPreTrainedModel, ParakeetTDTGenerationMixin):
    _supported_generation_modes = [GenerationMode.GREEDY_SEARCH]
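The guard lets any subclass declare a whitelist of generation modes it accepts. A stripped-down sketch of the mechanism (hypothetical class names, plain strings standing in for `GenerationMode` members):

```python
class RestrictedGenerationMixin:
    """Sketch: subclasses opt in to a whitelist of generation modes."""

    _supported_generation_modes = None  # None means "no restriction"

    def _validate_generation_mode(self, generation_mode):
        supported = getattr(self, "_supported_generation_modes", None)
        if supported is not None and generation_mode not in supported:
            raise ValueError(
                f"{self.__class__.__name__} only supports {supported}, "
                f"but got generation mode '{generation_mode}'."
            )

class ToyParakeetForTDT(RestrictedGenerationMixin):
    # TDT models only implement greedy decoding, so any other mode
    # (e.g. beam search) should fail loudly instead of silently misbehaving.
    _supported_generation_modes = ["greedy_search"]
```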
| "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, | ||
| "tdt-loss": {"repo_id": "eustlb/tdt-loss", "version": 1}, | ||
| } |
There was a problem hiding this comment.
@eustlb thanks for creating the kernel! btw I changed from "version": 1 to "revision": 1 as your kernel is rather in a v1 branch. Otherwise it wasn't loading as expected since the main branch is empty.
And maybe we need to also add the source to the main branch? I was a bit confused where the content was at first 😝
I guess @ErikKaum will have have best practice tips!
# Since we only read from `_HUB_KERNEL_MAPPING`, we can allow all kernels
kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True)
Can we hardcode allow_all_kernels=True since we only read kernels from the library defined _HUB_KERNEL_MAPPING?
- except FileNotFoundError:
+ except FileNotFoundError as e:
      mapping[kernel_name] = None
+     logger.warning_once(f"Failed to load kernel {kernel_name}: {e}")
Adding a helpful error message; otherwise the kernel may fail to load without notifying the user, e.g. due to a different Torch version.
For example it will now print:
[transformers] Failed to load kernel tdt-loss: Cannot find a build variant for this system in eustlb/tdt-loss (revision: v1). Available variants: torch211-cxx11-cu128-x86_64-linux
    kernel = lazy_load_kernel("tdt-loss")
    if kernel is None or not hasattr(kernel, "tdt_loss"):
        logger.warning_once("Falling back to pure PyTorch implementation.")
        return None
    return kernel
except (ImportError, ModuleNotFoundError):
    return None
except Exception as e:
    logger.warning_once(f"Failed to load TDT CUDA kernel: {e}. Falling back to pure PyTorch implementation.")
    return None
Since there is error handling in lazy_load_kernel, maybe we don't need error handling here as well? Or try to upstream to lazy_load_kernel
@auto_docstring
class LasrProcessor(ProcessorMixin):
Now that the Parakeet processor is handling TDT decoding, it's simpler to just create a new LasrProcessor than to override nearly everything from Parakeet's processor
run-slow: parakeet
This comment contains models: ["models/parakeet"]
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, encodec, lasr, parakeet
What does this PR do?
This PR adds TDT decoder support for Parakeet ASR models, extending the existing CTC-only implementation.
It incorporates the initial TDT integration work from #41545 by @hainan-xv (which was not merged) and addresses all review feedback from both #41545 and #43357.
Changes
- `ParakeetForTDT` model with greedy TDT decoding in `generate()`
- `ParakeetTDTDecoder` (LSTM prediction network) and `ParakeetTDTJointNetwork` as `nn.Module` subclasses
- Per-token timestamps via `return_timestamps=True`
- `AutoModelForTDT` auto class with pipeline, processor, and tokenizer integration
- `ParakeetTDTConfig` matching the CTC pattern (no nested decoder/joint configs)
- Shared `ParakeetPreTrainedModel` base between CTC and TDT (no separate TDT base class)

Validation
- `make check-repo` passes

Before submitting
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ebezzam and @hainan-xv please review