TDT for HF#41545
Force-pushed from ade8e2c to 93977e7
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, fastspeech2_conformer, parakeet
Force-pushed from 93977e7 to 8a3e1cd
ebezzam left a comment:
@hainan-xv thanks for the PR to add the TDT variant!
It may seem like a lot of comments at first glance, but they are mainly about Transformers conventions which aren't obvious. So I've tried to be explicit to help you with the changes (but let me know if something is unclear). It's already a great start that you've implemented the changes through modular 👏
A couple of other points that come to mind:
- I suppose this is the Transformers-compatible checkpoint you've created with the conversion script?
- We should also update the documentation file to mention the TDT variant
Thanks 🤗
@auto_docstring
class ParakeetPreTrainedModel(PreTrainedModel):
-    config: ParakeetCTCConfig
+    config: PreTrainedConfig
Makes sense to change this. Just wondering: have you tried removing the line altogether, or did you have to set `config` to something?
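For reference, a minimal sketch of the two options being discussed (illustrative only, not a tested change; it assumes `PreTrainedConfig` is importable as used in the diff above):

```python
# Sketch only: comparing the two variants the question is about.
from transformers import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig as PreTrainedConfig  # assumption: generic config alias


# Option A: keep a generic class-level annotation (as in this PR's diff)
class ParakeetPreTrainedModelA(PreTrainedModel):
    config: PreTrainedConfig


# Option B: drop the annotation entirely and rely on the config instance
# passed to __init__ (the question is whether this also works here)
class ParakeetPreTrainedModelB(PreTrainedModel):
    pass
```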
| "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", | ||
| "MODEL_FOR_CAUSAL_LM_MAPPING", | ||
| "MODEL_FOR_CTC_MAPPING", | ||
| "MODEL_FOR_TDT_MAPPING", |
Can you place this in alphabetical order?
self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias)
self.depthwise_conv = nn.Conv1d(
-    channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
+    channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=config.attention_bias
)
self.norm = nn.BatchNorm1d(channels)
- self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
+ self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias)
It looks like this change was already made in another PR and merged into main, under the name `config.convolution_bias`. Can you sync with main?
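For reference, a rough sketch of what this module might look like after syncing with main, assuming the flag on main is named `config.convolution_bias` (the class name and constructor signature below are illustrative, not the exact code on main):

```python
import torch.nn as nn


# Sketch only: the same three convolutions as in the diff above, but with the bias
# flag read from config.convolution_bias (assumed name on main) instead of
# config.attention_bias / hard-coded True.
class ParakeetConvolutionModule(nn.Module):  # illustrative name
    def __init__(self, config, channels: int, kernel_size: int, padding: int):
        super().__init__()
        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
        )
        self.depthwise_conv = nn.Conv1d(
            channels, channels, kernel_size, stride=1, padding=padding, groups=channels, bias=config.convolution_bias
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
        )
```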
| "AutoModelForAudioXVector", | ||
| "AutoModelForCausalLM", | ||
| "AutoModelForCTC", | ||
| "AutoModelForTDT", |
Can you also place this in alphabetical order?
dropout=0,
vocab_size=1024,
forget_gate_bias=1.0,
t_max=None,
(Transformers convention) Is this parameter used? I see it is only used here. If the final model checkpoint does not use it, the Transformers convention is to remove such variables and their unused code paths.
If it is used, could you use a more explicit name than `t_max`? That is, be more verbose about what `t` refers to.
Similarly, if `t_max` is always used, then from what I understand `forget_gate_bias` would not be used here, so that code path could be removed instead.
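For context, a sketch of how these two parameters typically interact in NeMo-style prediction networks (an assumption about the reference implementation, not this PR's code): when `t_max` is set, the LSTM biases get a chrono-style initialization and `forget_gate_bias` is effectively ignored, which is why only one of the two code paths would survive.

```python
from typing import Optional

import torch.nn as nn


# Sketch only (assumption: NeMo-style LSTM bias initialization). If t_max is set,
# the forget-gate bias is drawn from a chrono-init distribution and the input-gate
# bias is set to its negative; otherwise the forget-gate bias is a constant.
# Only one of the two branches is ever exercised for a given checkpoint.
def init_lstm_biases(lstm: nn.LSTM, hidden_size: int, forget_gate_bias: float, t_max: Optional[int] = None) -> None:
    for name, bias in lstm.named_parameters():
        if "bias" not in name:
            continue
        # PyTorch LSTM bias layout: [input, forget, cell, output] gates
        forget_bias = bias.data[hidden_size : 2 * hidden_size]
        if t_max is not None:
            # chrono initialization: b_f ~ log(U(1, t_max - 1)), b_i = -b_f
            forget_bias.uniform_(1.0, float(t_max) - 1.0).log_()
            bias.data[0:hidden_size] = -forget_bias
        else:
            forget_bias.fill_(forget_gate_bias)
```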
enc: torch.Tensor,
pred: torch.Tensor,
Can we use more verbose variable names, e.g. `encoder_output` and `decoder_output`?
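i.e. something along these lines (a sketch of the rename only; the method name and the joint computation itself are whatever the current code uses):

```python
import torch


# Sketch of the suggested rename; the body is unchanged and elided here.
def forward(
    self,
    encoder_output: torch.Tensor,  # was `enc`
    decoder_output: torch.Tensor,  # was `pred`
) -> torch.Tensor:
    ...
```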
| ("parakeet_tdt_decoder", "ParakeetTDTDecoderConfig"), | ||
| ("parakeet_tdt_joint", "ParakeetTDTJointConfig"), |
Such mappings can be removed after removing `ParakeetTDTDecoderConfig` and `ParakeetTDTJointConfig`.
encoder_kwargs=None,
decoder_kwargs=None,
joint_kwargs=None,
These should be adapted to the new configuration structure; see the sketch below for one way that could look.
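For reference, composite Transformers configs usually nest sub-configs (or keep the remaining fields flat) rather than taking `*_kwargs` dicts. A rough sketch of that pattern follows; all names and defaults below are illustrative rather than the final API of this PR, and it assumes the existing `ParakeetEncoderConfig` is reused for the encoder:

```python
# Sketch only: nest the encoder config and keep decoder/joint fields flat,
# instead of passing encoder_kwargs / decoder_kwargs / joint_kwargs dicts.
from transformers import PretrainedConfig
from transformers.models.parakeet import ParakeetEncoderConfig  # assumption: reused from the CTC model


class ParakeetTDTConfig(PretrainedConfig):
    model_type = "parakeet_tdt"

    def __init__(
        self,
        encoder_config=None,
        decoder_hidden_size=640,    # illustrative default
        joint_hidden_size=640,      # illustrative default
        vocab_size=1024,
        durations=(0, 1, 2, 3, 4),  # illustrative TDT duration outputs
        **kwargs,
    ):
        super().__init__(**kwargs)
        if encoder_config is None:
            encoder_config = ParakeetEncoderConfig()
        elif isinstance(encoder_config, dict):
            encoder_config = ParakeetEncoderConfig(**encoder_config)
        self.encoder_config = encoder_config
        self.decoder_hidden_size = decoder_hidden_size
        self.joint_hidden_size = joint_hidden_size
        self.vocab_size = vocab_size
        self.durations = list(durations)
```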
pass


class ParakeetTDTDecoderModelTester:
Such tests won't be needed anymore for the decoder and joint "models" once they are integrated into the TDT model!
| return [x["array"] for x in speech_samples] | ||
|
|
||
| @slow | ||
| def test_1b_model_integration(self): |
I suppose the reproducers and integration tests still need to be done? The ones for CTC are a good example to follow.
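For illustration, a skeleton in the spirit of the CTC integration tests. Everything here is a placeholder: the checkpoint id, the expected transcription, and whether decoding goes through `generate()` or a dedicated decode method all still need to match the real reproducer.

```python
import torch
from datasets import load_dataset

from transformers import AutoModelForTDT, AutoProcessor  # assumes the auto classes added in this PR


def test_tdt_model_integration():
    # placeholder checkpoint id, to be replaced by the converted Transformers checkpoint
    checkpoint = "org/parakeet-tdt-placeholder"
    processor = AutoProcessor.from_pretrained(checkpoint)
    model = AutoModelForTDT.from_pretrained(checkpoint).eval()

    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    audio = ds[0]["audio"]["array"]

    inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")
    with torch.no_grad():
        predicted_ids = model.generate(**inputs)  # assumption: decoding exposed via generate()

    transcription = processor.batch_decode(predicted_ids)[0]
    assert transcription == "EXPECTED TRANSCRIPT FROM THE NEMO REFERENCE"
```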
ebezzam left a comment:
A couple more thoughts after going through the paper
return BaseModelOutput(last_hidden_state=output)


class ParakeetTDTPredictor(ParakeetPreTrainedModel):
(Transformers convention) Similar to `ParakeetTDTJoint`, this can inherit from `nn.Module` instead.
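i.e. roughly like this (a sketch only; the embedding/LSTM sizes and the exact forward signature are placeholders for whatever the predictor actually uses):

```python
import torch
import torch.nn as nn


# Sketch of the suggested change: the prediction network becomes a plain nn.Module
# owned by the top-level TDT model, rather than a separate PreTrainedModel.
class ParakeetTDTPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.decoder_hidden_size)  # illustrative sizes
        self.lstm = nn.LSTM(config.decoder_hidden_size, config.decoder_hidden_size, batch_first=True)

    def forward(self, labels: torch.Tensor, state=None):
        hidden, state = self.lstm(self.embed(labels), state)
        return hidden, state
```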
logits = self.joint.joint_net(self.joint.enc(encoder_outputs.last_hidden_state))  #[:,:,:self.joint.vocab_size]

return CausalLMOutput(
    loss=torch.sum(encoder_outputs.last_hidden_state),  # a fake loss here.
hidden_state = None,
**kwargs: Unpack[TransformersKwargs],
Both can be dropped: `hidden_state` is unused, and `**kwargs` should no longer be necessary once `ParakeetTDTPredictor` inherits from `nn.Module`.
#44171 is the current PR for adding TDT |
What does this PR do?
Parakeet TDT model integration.
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.