TDT for HF #41545

Closed

hainan-xv wants to merge 1 commit into huggingface:main from hainan-xv:hf_transformer_pr

Conversation

@hainan-xv hainan-xv commented Oct 13, 2025

What does this PR do?

Parakeet TDT model integration.

Fixes # (issue)
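
For reviewers, a rough sketch of the intended user-facing flow once this lands (the checkpoint id, the AutoModelForTDT entry point, and the generate call below are assumptions based on this PR, not a released API):

import numpy as np
import torch
from transformers import AutoModelForTDT, AutoProcessor

model_id = "nvidia/parakeet-tdt-1.1b"  # hypothetical Transformers-compatible checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForTDT.from_pretrained(model_id)

audio = np.zeros(16000, dtype=np.float32)  # placeholder 1-second waveform at 16 kHz
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    predicted_ids = model.generate(**inputs)
print(processor.batch_decode(predicted_ids))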

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1 (Member)

cc @eustlb @ebezzam for audio

@hainan-xv hainan-xv force-pushed the hf_transformer_pr branch 3 times, most recently from ade8e2c to 93977e7 on October 23, 2025 at 19:20
@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, fastspeech2_conformer, parakeet

@hainan-xv hainan-xv marked this pull request as ready for review October 23, 2025 19:26
@eustlb eustlb self-assigned this Oct 23, 2025
@eustlb eustlb added the Audio label Oct 23, 2025
@ebezzam ebezzam (Contributor) left a comment

@hainan-xv thanks for the PR to add the TDT variant!

It may seem like a lot of comments at first glance, but they are mainly about Transformers conventions which aren't obvious, so I've tried to be explicit to help you with the changes (let me know if anything is unclear). It's already a great start that you've implemented the changes through modular 👏

A couple other points that come to mind:

  • I suppose this is the Transformers-compatible checkpoint you've created with the conversion script?
  • We should also update the documentation file to mention the TDT variant

Thanks 🤗

@auto_docstring
class ParakeetPreTrainedModel(PreTrainedModel):
-    config: ParakeetCTCConfig
+    config: PreTrainedConfig

Makes sense to change this. Just wondering if you've tried removing the line altogether? Or did you have to set config to something?

"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
"MODEL_FOR_TDT_MAPPING",

Can you place this in alphabetical order?

Comment on lines +493 to +498
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias)
        self.depthwise_conv = nn.Conv1d(
-           channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
+           channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=config.attention_bias
        )
        self.norm = nn.BatchNorm1d(channels)
-       self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
+       self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias)

This PR seems to have already made this change and merged into main with the name config.convolution_bias. Can you sync with main?
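
For reference, a minimal sketch of what the synced module could look like once it uses the config.convolution_bias name from main; the class name and the other config attribute names here are assumptions for illustration:

import torch.nn as nn

class ConvolutionModule(nn.Module):  # illustrative stand-in for the Parakeet conv block
    def __init__(self, config):
        super().__init__()
        channels = config.hidden_size          # assumed attribute name
        kernel_size = config.conv_kernel_size  # assumed attribute name
        self.padding = (kernel_size - 1) // 2
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias)
        self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=config.convolution_bias)
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias)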

"AutoModelForAudioXVector",
"AutoModelForCausalLM",
"AutoModelForCTC",
"AutoModelForTDT",

Can you also place this in alphabetical order?

dropout=0,
vocab_size=1024,
forget_gate_bias=1.0,
t_max=None,

(Transformers convention) Is this parameter used? I see it is only used here. If the final model checkpoint does not use it, the Transformers convention is to remove such variables and unused code paths.

If it is used, could you use a more explicit name than t_max? Namely, be more verbose about what t refers to.


Similarly, if t_max is always used, then from what I understand forget_gate_bias would not be used here, so that code path could be removed instead.
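
To make the two code paths concrete, here is a hedged sketch of the usual LSTM bias initialization these parameters control (NeMo-style); the helper name and the exact chrono-init formula are assumptions for illustration:

import torch
import torch.nn as nn

def init_lstm_forget_gate(lstm: nn.LSTM, forget_gate_bias=1.0, t_max=None):
    for name, bias in lstm.named_parameters():
        if "bias" not in name:
            continue
        hidden_size = bias.numel() // 4  # LSTM biases are [input, forget, cell, output] blocks
        with torch.no_grad():
            if t_max is not None:
                # "chrono" init: forget-gate bias ~ log(Uniform(1, t_max - 1)),
                # which makes forget_gate_bias irrelevant whenever t_max is set
                bias[hidden_size : 2 * hidden_size].uniform_(1.0, t_max - 1).log_()
            elif forget_gate_bias is not None:
                # plain constant forget-gate bias
                bias[hidden_size : 2 * hidden_size].fill_(forget_gate_bias)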

Comment on lines +606 to +607
enc: torch.Tensor,
pred: torch.Tensor,

Can we use more verbose variable names? E.g. encoder_output and decoder_output.

Comment on lines +306 to +307
("parakeet_tdt_decoder", "ParakeetTDTDecoderConfig"),
("parakeet_tdt_joint", "ParakeetTDTJointConfig"),

Such mappings can be removed after removing ParakeetTDTDecoderConfig and ParakeetTDTJointConfig

Comment on lines +617 to +619
encoder_kwargs=None,
decoder_kwargs=None,
joint_kwargs=None,

These should be adapted to the new configuration structure.
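
For illustration, one way the new configuration structure could look is a single composed config instead of separate *_kwargs; every name and default below is an assumption, not the agreed design:

from transformers import PretrainedConfig

class ParakeetTDTConfig(PretrainedConfig):  # hypothetical composed config
    model_type = "parakeet_tdt"

    def __init__(self, encoder_config=None, decoder_hidden_size=640, joint_hidden_size=640, durations=(0, 1, 2, 3, 4), **kwargs):
        super().__init__(**kwargs)
        # encoder settings nested under one key (arrives as a plain dict when reloaded from JSON)
        self.encoder_config = encoder_config if encoder_config is not None else {}
        self.decoder_hidden_size = decoder_hidden_size
        self.joint_hidden_size = joint_hidden_size
        self.durations = list(durations)  # TDT duration outputs; values are an assumption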

pass


class ParakeetTDTDecoderModelTester:

Such tests won't be needed anymore for the decoder and joint "models" after integrating into the TDT model!

return [x["array"] for x in speech_samples]

@slow
def test_1b_model_integration(self):
@ebezzam ebezzam Dec 11, 2025

I suppose the reproducers and integration tests still need to be done?

The ones for CTC are a good example
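
For reference, a hedged sketch of what such an integration test might look like, modeled loosely on the CTC ones; the checkpoint id is left as a placeholder and the expected transcription would come from a reproducer:

import unittest

import torch
from datasets import load_dataset
from transformers import AutoModelForTDT, AutoProcessor
from transformers.testing_utils import slow

class ParakeetTDTModelIntegrationTest(unittest.TestCase):
    @slow
    def test_tdt_model_integration(self):
        model_id = "..."  # Transformers-compatible TDT checkpoint, to be filled in
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForTDT.from_pretrained(model_id)

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        inputs = processor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt")

        with torch.no_grad():
            predicted_ids = model.generate(**inputs)
        transcription = processor.batch_decode(predicted_ids)[0]
        # self.assertEqual(transcription, EXPECTED_TRANSCRIPTION) once a reference is pinned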

@ebezzam ebezzam (Contributor) left a comment

A couple more thoughts after going through the paper

return BaseModelOutput(last_hidden_state=output)


class ParakeetTDTPredictor(ParakeetPreTrainedModel):
@ebezzam ebezzam Dec 11, 2025

(Transformers convention) Similar to ParakeetTDTJoint, this can inherit from nn.Module instead.
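
A minimal sketch of the suggested change, with the prediction network as a plain nn.Module submodule; the internals here are assumptions, just to show the shape of it:

import torch.nn as nn

class ParakeetTDTPredictor(nn.Module):  # instead of ParakeetPreTrainedModel
    def __init__(self, config):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.decoder_hidden_size)  # assumed attribute names
        self.lstm = nn.LSTM(config.decoder_hidden_size, config.decoder_hidden_size, batch_first=True)

    def forward(self, input_ids, state=None):
        hidden, state = self.lstm(self.embed(input_ids), state)
        return hidden, state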

logits = self.joint.joint_net(self.joint.enc(encoder_outputs.last_hidden_state)) #[:,:,:self.joint.vocab_size]

return CausalLMOutput(
loss=torch.sum(encoder_outputs.last_hidden_state), # a fake loss here.

Eq. 4 of the paper gives the loss.

Also, I just noticed that the forward method isn't being called by generate? It should be (see CTC), so we'll have to rethink how the components are called within forward. We can revisit after a first iteration of changes.
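
For orientation, a hedged sketch of the factorized joint the TDT paper describes, where token and duration distributions are combined as independent factors; head names and shapes are assumptions, and the full Eq. 4 loss additionally marginalizes over alignments, which is omitted here:

import torch
import torch.nn as nn

def tdt_joint_log_probs(joint_hidden: torch.Tensor, token_head: nn.Linear, duration_head: nn.Linear):
    # joint_hidden: (batch, time, label, hidden) output of the joint network
    token_log_probs = torch.log_softmax(token_head(joint_hidden), dim=-1)        # log P(token)
    duration_log_probs = torch.log_softmax(duration_head(joint_hidden), dim=-1)  # log P(duration)
    # factorized joint: log P(token, duration) = log P(token) + log P(duration)
    return token_log_probs.unsqueeze(-1) + duration_log_probs.unsqueeze(-2)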

Comment on lines +676 to +677
hidden_state = None,
**kwargs: Unpack[TransformersKwargs],
@ebezzam ebezzam Dec 11, 2025

Can be dropped: hidden_state is unused, and **kwargs should no longer be necessary once ParakeetTDTPredictor inherits from nn.Module.

@ebezzam ebezzam self-assigned this Dec 18, 2025
@ebezzam ebezzam mentioned this pull request Jan 21, 2026
@lmaksym lmaksym mentioned this pull request Feb 20, 2026

ebezzam commented Mar 9, 2026

#44171 is the current PR for adding TDT

@ebezzam ebezzam closed this Mar 9, 2026