TDT for HF #41545
Changes from all commits
```diff
@@ -301,7 +301,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("owlvit", "OwlViTModel"),
         ("paligemma", "PaliGemmaModel"),
         ("parakeet_ctc", "ParakeetForCTC"),
+        ("parakeet_tdt", "ParakeetForTDT"),
         ("parakeet_encoder", "ParakeetEncoder"),
+        ("parakeet_tdt_decoder", "ParakeetTDTDecoder"),
+        ("parakeet_tdt_joint", "ParakeetTDTJoint"),
         ("patchtsmixer", "PatchTSMixerModel"),
         ("patchtst", "PatchTSTModel"),
         ("pegasus", "PegasusModel"),
```
```diff
@@ -1624,6 +1627,14 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     ]
 )


+MODEL_FOR_TDT_MAPPING_NAMES = OrderedDict(
+    [
+        # Model for Token-and-Duration Transducer (TDT) mapping.
+        ("parakeet_tdt", "ParakeetForTDT"),
+    ]
+)
+
+
 MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Audio Classification mapping
```
```diff
@@ -1883,6 +1894,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
 )
 MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
+MODEL_FOR_TDT_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TDT_MAPPING_NAMES)
 MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
 MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
```
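For context, `_LazyAutoMapping` keys model classes by their config class, so registering `MODEL_FOR_TDT_MAPPING` is what lets the new auto class below resolve a TDT config to `ParakeetForTDT` without eagerly importing the model. A minimal sketch of the resulting behavior, assuming the PR's exports land as shown in this diff:

```python
from transformers import AutoModelForTDT, ParakeetTDTConfig  # exports assumed from this PR

# The lazy mapping resolves ParakeetTDTConfig -> ParakeetForTDT under the hood.
config = ParakeetTDTConfig()
model = AutoModelForTDT.from_config(config)
print(type(model).__name__)  # "ParakeetForTDT"
```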
```diff
@@ -2200,6 +2212,11 @@ class AutoModelForCTC(_BaseAutoModelClass):


 AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")


+class AutoModelForTDT(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_TDT_MAPPING
+
+
+AutoModelForTDT = auto_class_update(AutoModelForTDT, head_doc="token-and-duration transducer")


 class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
```
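Once a TDT checkpoint is published, usage should mirror `AutoModelForCTC`. A hedged end-to-end sketch; the repo id is a placeholder and the processor call assumes the usual audio feature-extractor interface:

```python
import torch
from transformers import AutoModelForTDT, AutoProcessor

repo = "nvidia/parakeet-tdt-0.6b"  # hypothetical checkpoint id
processor = AutoProcessor.from_pretrained(repo)
model = AutoModelForTDT.from_pretrained(repo)

# One second of silence at 16 kHz stands in for real speech.
inputs = processor(torch.zeros(16000).numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
```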
```diff
@@ -2305,6 +2322,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
     "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
     "MODEL_FOR_CAUSAL_LM_MAPPING",
     "MODEL_FOR_CTC_MAPPING",
+    "MODEL_FOR_TDT_MAPPING",
     "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
     "MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
     "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
```

Contributor: Can you place this in alphabetical order?
```diff
@@ -2352,6 +2370,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
     "AutoModelForAudioXVector",
     "AutoModelForCausalLM",
     "AutoModelForCTC",
+    "AutoModelForTDT",
     "AutoModelForDepthEstimation",
     "AutoModelForImageClassification",
     "AutoModelForImageSegmentation",
```

Contributor: Can you also place this in alphabetical order?
```diff
@@ -490,12 +490,12 @@ def __init__(self, config: FastSpeech2ConformerConfig, module_config=None):
         kernel_size = module_config["kernel_size"]
         self.activation = ACT2FN[module_config.get("activation", "silu")]
         self.padding = (kernel_size - 1) // 2
-        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
+        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias)
         self.depthwise_conv = nn.Conv1d(
-            channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
+            channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=config.attention_bias
         )
         self.norm = nn.BatchNorm1d(channels)
-        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
+        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias)

     def forward(self, hidden_states, attention_mask=None):
         """
```

Contributor (on lines +493 to +498): This PR seems to have already made this change and merged into main with the name …
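The net effect of this hunk: the conformer convolution biases follow a single config flag instead of being hard-coded to `True`. The same pattern in plain PyTorch (a toy stand-in, not the transformers module):

```python
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, channels: int, use_bias: bool):
        super().__init__()
        # One flag drives every conv bias, mirroring bias=config.attention_bias above.
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, bias=use_bias)
        self.depthwise_conv = nn.Conv1d(channels, channels, 3, padding=1, groups=channels, bias=use_bias)
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=use_bias)
```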
```diff
@@ -150,6 +150,65 @@ def __init__(
         self.initializer_range = initializer_range


+class ParakeetTDTDecoderConfig(PreTrainedConfig):
+    model_type = "parakeet_tdt_decoder"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    output_hidden_states = False
+
+    def __init__(
+        self,
+        hidden_size=640,
+        num_hidden_layers=1,
+        dropout=0,
+        vocab_size=1024,
+        forget_gate_bias=1.0,
+        t_max=None,
```
Contributor: (Transformers convention) Is this parameter used? I see it is only used here. If the final model checkpoint does not use it, the Transformers convention is to remove such variables and unused code paths. If it is used, could you use a more explicit name than `t_max`?

Contributor: Similarly, if …
```diff
+        weights_init_scale=1.0,
+        hidden_hidden_bias_scale=0,
```
Contributor: just `hidden_bias_scale`?
```diff
+        **kwargs,
+    ):
+        super().__init__(
+            **kwargs,
+        )
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.dropout = dropout
+        self.vocab_size = vocab_size
+        self.forget_gate_bias=forget_gate_bias
+        self.t_max=t_max
```
Contributor: If used, also renaming …
```diff
+        self.weights_init_scale=weights_init_scale
+        self.hidden_hidden_bias_scale=hidden_hidden_bias_scale
```
Contributor: As above, renaming to `hidden_bias_scale`.
```diff
+
+
+class ParakeetTDTJointConfig(PreTrainedConfig):
+    model_type = "parakeet_tdt_joint"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        enc_hidden_size=1024,
+        pred_hidden_size=640,
+        hidden_size=640,
```
Contributor (on lines +191 to +192): Is there a case where `pred_hidden_size` differs from `hidden_size`? If so, can we rename it to `decoder_hidden_size`?
```diff
+        vocab_size=1024,
+        durations=[0,1,2,3,4],
+        norm=None,
```
Contributor: Can we remove `norm`?
```diff
+        dropout=0.0,
+        activation='relu',
+        **kwargs,
+    ):
+        super().__init__(
+            **kwargs,
+        )
+        self.enc_hidden_size = enc_hidden_size
+        self.pred_hidden_size = pred_hidden_size
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.durations = durations
+        self.dropout = dropout
+        self.activation = activation
+
+
 class ParakeetCTCConfig(PreTrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`ParakeetForCTC`]. It is used to instantiate a
```
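The `durations=[0, 1, 2, 3, 4]` list in the joint config above is what distinguishes TDT from a plain RNN-T: the joint network predicts a duration alongside each token, and greedy decoding jumps ahead by that many encoder frames rather than stepping one frame at a time. A rough sketch of that loop, with hypothetical `predictor`/`joint` callables (not this PR's actual API):

```python
import torch

def tdt_greedy_decode(encoder_out: torch.Tensor, predictor, joint,
                      durations=(0, 1, 2, 3, 4), blank_id=1024):
    # encoder_out: (T, enc_hidden_size). `predictor` and `joint` stand in for the
    # PR's ParakeetTDTDecoder / ParakeetTDTJoint; their signatures are assumed.
    tokens, state, t = [], None, 0
    while t < encoder_out.shape[0]:
        last = tokens[-1] if tokens else blank_id      # blank doubles as start-of-sequence
        pred_out, new_state = predictor(last, state)
        token_logits, duration_logits = joint(encoder_out[t], pred_out)
        token = int(token_logits.argmax())
        skip = durations[int(duration_logits.argmax())]
        if token != blank_id:
            tokens.append(token)
            state = new_state   # the prediction network advances only on real tokens
            t += skip           # duration 0 lets several tokens share one frame
        else:
            t += max(skip, 1)   # force progress: blank with duration 0 must not loop
    return tokens
```

Real decoders also cap the number of symbols emitted per frame; the sketch omits that guard.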
```diff
@@ -232,4 +291,83 @@ def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs):
         return cls(encoder_config=encoder_config.to_dict(), **kwargs)


-__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig"]
+class ParakeetTDTConfig(PreTrainedConfig):
+    model_type = "parakeet_tdt"
+    sub_configs = {"encoder_config": ParakeetEncoderConfig, "decoder_config": ParakeetTDTDecoderConfig, "joint_config": ParakeetTDTJointConfig}
+
+    def __init__(
+        self,
+        # bos_token_id=1,
+        # eos_token_id=2,
+        # pad_token_id=1024,
+        tdt_loss_reduction="mean",
```
Contributor: Seems to be unused and can be removed?
```diff
+        encoder_config: Union[dict, ParakeetEncoderConfig] = None,
+        decoder_config: Union[dict, ParakeetTDTDecoderConfig] = None,
+        joint_config: Union[dict, ParakeetTDTJointConfig] = None,
+        **kwargs,
+    ):
```
+297
to
+309
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Transformers convention) Let's do something similar to I don't expect For example: class ParakeetTDTConfig(PreTrainedConfig):
""" TODO docstring"
model_type = "parakeet_tdt"
sub_configs = {"encoder_config": ParakeetEncoderConfig}
def __init__(
self,
encoder_config: Union[dict, ParakeetEncoderConfig] = None,
hidden_size=640,
decoder_hidden_size=640, # only if there is a case where different from `hidden_size`
hidden_act="relu",
num_hidden_layers=1,
dropout=0.0,
[forget_gate_bias and/or t_max as discussed above,
weights_init_scale=1.0,
durations=[0, 1, 2, 3, 4],
hidden_bias_scale=0,
vocab_size=1024,
blank_token_id=1024, # can pad_token_id naming be used instead and passed to `super().__init__`?
**kwargs,
):
# TODO set parameters like in ParakeetCTCConfig
...
super().__init__(**kwargs)
@property
def encoder_hidden_size(self) -> int:
return self.encoder_config.hidden_sizeNote that we can define properties for attributes like Moreover, in your conversion script, I notice that you are setting the I also notice that you don't use |
```diff
+
+        if encoder_config is None:
+            self.encoder_config = ParakeetEncoderConfig()
+        elif isinstance(encoder_config, dict):
+            self.encoder_config = ParakeetEncoderConfig(**encoder_config)
+        elif isinstance(encoder_config, ParakeetEncoderConfig):
+            self.encoder_config = encoder_config
+        else:
+            raise ValueError(
+                f"`encoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}"
+            )
```
Contributor (on lines +317 to +320): No need to raise ValueError. You can do it like in the CTC config here.
```diff
+
+        if decoder_config is None:
+            self.decoder_config = ParakeetTDTDecoderConfig()
+        elif isinstance(decoder_config, dict):
+            self.decoder_config = ParakeetTDTDecoderConfig(**decoder_config)
+        elif isinstance(decoder_config, ParakeetTDTDecoderConfig):
+            self.decoder_config = decoder_config
+        else:
+            raise ValueError(
+                f"`decoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}"
+            )
+
+        if joint_config is None:
+            self.joint_config = ParakeetTDTJointConfig()
+        elif isinstance(joint_config, dict):
+            self.joint_config = ParakeetTDTJointConfig(**joint_config)
+        elif isinstance(joint_config, ParakeetTDTJointConfig):
+            self.joint_config = joint_config
+        else:
+            raise ValueError(
+                f"`decoder_config` must be a dictionary or an instance of `ParakeetEncoderConfig`, got {type(encoder_config)}"
+            )
+
+        vocab_size = self.joint_config.vocab_size
+        self.vocab_size = vocab_size
+
+        self.blank_token_id = vocab_size
+        super().__init__(
+            # pad_token_id=self.blank_token_id,
+            **kwargs,
+        )
```
```diff
+
+    @classmethod
+    def from_configs(
+        cls,
+        encoder_config: ParakeetEncoderConfig,
+        decoder_config: ParakeetTDTDecoderConfig,
+        joint_config: ParakeetTDTJointConfig,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a [`ParakeetConfig`] (or a derived class) from Parakeet encoder, decoder, and joint configurations.
+
+        Returns:
+            [`ParakeetConfig`]: An instance of a configuration object
+        """
+        return cls(
+            encoder_config=encoder_config.to_dict(),
+            decoder_config=decoder_config.to_dict(),
+            joint_config=joint_config.to_dict(),
+            **kwargs,
+        )
```
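As a quick illustration of the composition pattern above (a sketch against this PR's proposed classes, using the defaults shown in the diff; the top-level imports assume the exports below land as-is):

```python
from transformers import (
    ParakeetEncoderConfig,
    ParakeetTDTConfig,
    ParakeetTDTDecoderConfig,
    ParakeetTDTJointConfig,
)

encoder_config = ParakeetEncoderConfig()
decoder_config = ParakeetTDTDecoderConfig(hidden_size=640, vocab_size=1024)
joint_config = ParakeetTDTJointConfig(enc_hidden_size=encoder_config.hidden_size)

# Sub-configs are serialized to dicts and re-validated inside __init__.
config = ParakeetTDTConfig.from_configs(encoder_config, decoder_config, joint_config)
assert config.blank_token_id == config.vocab_size  # blank id sits just past the vocab
```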
```diff
+
+
+__all__ = ["ParakeetCTCConfig", "ParakeetTDTConfig", "ParakeetEncoderConfig", "ParakeetTDTDecoderConfig", "ParakeetTDTJointConfig"]
```
Contributor: Such mappings can be removed after removing `ParakeetTDTDecoderConfig` and `ParakeetTDTJointConfig`.