From c76381e363d561c5bd7e5907fb624ec418bde104 Mon Sep 17 00:00:00 2001 From: Shikhar Date: Thu, 9 Apr 2026 22:26:24 -0500 Subject: [PATCH 1/6] add pxeus modeling --- .../models/phoneticxeus/__init__.py | 34 + .../configuration_phoneticxeus.py | 177 +++++ .../convert_phoneticxeus_checkpoint.py | 233 ++++++ .../phoneticxeus/modeling_phoneticxeus.py | 673 ++++++++++++++++++ .../phoneticxeus/modular_phoneticxeus.py | 553 ++++++++++++++ .../phoneticxeus/processing_phoneticxeus.py | 92 +++ .../phoneticxeus/tokenization_phoneticxeus.py | 63 ++ 7 files changed, 1825 insertions(+) create mode 100644 src/transformers/models/phoneticxeus/__init__.py create mode 100644 src/transformers/models/phoneticxeus/configuration_phoneticxeus.py create mode 100644 src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py create mode 100644 src/transformers/models/phoneticxeus/modeling_phoneticxeus.py create mode 100644 src/transformers/models/phoneticxeus/modular_phoneticxeus.py create mode 100644 src/transformers/models/phoneticxeus/processing_phoneticxeus.py create mode 100644 src/transformers/models/phoneticxeus/tokenization_phoneticxeus.py diff --git a/src/transformers/models/phoneticxeus/__init__.py b/src/transformers/models/phoneticxeus/__init__.py new file mode 100644 index 000000000000..f8e947bed548 --- /dev/null +++ b/src/transformers/models/phoneticxeus/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule + + +if TYPE_CHECKING: + from .configuration_phoneticxeus import * + from .modeling_phoneticxeus import * + from .processing_phoneticxeus import * + from .tokenization_phoneticxeus import * +else: + import sys + + _file = globals()["__file__"] + modules = { + "configuration_phoneticxeus": ["PhoneticXeusConfig"], + "modeling_phoneticxeus": ["PhoneticXeusForCTC", "PhoneticXeusModel", "PhoneticXeusPreTrainedModel"], + "processing_phoneticxeus": ["PhoneticXeusProcessor"], + "tokenization_phoneticxeus": ["PhoneticXeusTokenizer"], + } + sys.modules[__name__] = _LazyModule(__name__, _file, modules, module_spec=__spec__) diff --git a/src/transformers/models/phoneticxeus/configuration_phoneticxeus.py b/src/transformers/models/phoneticxeus/configuration_phoneticxeus.py new file mode 100644 index 000000000000..3b641459af99 --- /dev/null +++ b/src/transformers/models/phoneticxeus/configuration_phoneticxeus.py @@ -0,0 +1,177 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PhoneticXeus model configuration""" + +import functools +import operator + +from ...configuration_utils import PretrainedConfig +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="changelinglab/PhoneticXeus") +class PhoneticXeusConfig(PretrainedConfig): + r""" + vocab_size (`int`, *optional*, defaults to 428): + Vocabulary size of the PhoneticXeus model (IPA phoneme inventory). + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers. + num_hidden_layers (`int`, *optional*, defaults to 19): + Number of E-Branchformer encoder layers. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the encoder. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimensionality of the feed-forward layers in the encoder. + hidden_act (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the feed-forward layers. + hidden_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for fully connected layers in the encoder. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + feat_proj_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the feature projection layer. + final_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the final projection layer of [`PhoneticXeusForCTC`]. + layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability during training. + initializer_range (`float`, *optional*, defaults to 0.02): + Standard deviation of the truncated normal initializer for weight initialization. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + Epsilon for layer normalization. + normalize_audio (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the raw audio waveform before the CNN feature encoder. + feat_extract_norm (`str`, *optional*, defaults to `"layer"`): + The norm to be applied to 1D convolutional layers in the feature encoder. One of `"group"` for group + normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D + convolutional layers. + feat_extract_activation (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the 1D convolutional layers of the feature extractor. + conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`): + Number of input and output channels of each 1D convolutional layer in the feature encoder. + conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`): + Stride of each 1D convolutional layer in the feature encoder. + conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`): + Kernel size of each 1D convolutional layer in the feature encoder. + conv_bias (`bool`, *optional*, defaults to `True`): + Whether the 1D convolutional layers have a bias. + num_conv_pos_embeddings (`int`, *optional*, defaults to 128): + Kernel size of the convolutional positional embeddings layer. + num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16): + Number of groups of the convolutional positional embeddings layer. 
+ conv_pos_weight_norm (`bool`, *optional*, defaults to `True`): + Whether to apply weight normalization to the convolutional positional embedding layer. + cgmlp_linear_units (`int`, *optional*, defaults to 4096): + Hidden dimensionality of the ConvolutionalGatingMLP in each E-Branchformer layer. + cgmlp_conv_kernel (`int`, *optional*, defaults to 31): + Kernel size of the depthwise convolution in the Convolutional Spatial Gating Unit (CSGU). + use_linear_after_conv (`bool`, *optional*, defaults to `False`): + Whether to apply a linear layer after the depthwise convolution in CSGU. + gate_activation (`str`, *optional*, defaults to `"identity"`): + Activation function for gating in CSGU. + merge_conv_kernel (`int`, *optional*, defaults to 31): + Kernel size of the depthwise convolution used to merge the two branches in each E-Branchformer layer. + use_ffn (`bool`, *optional*, defaults to `True`): + Whether to use feed-forward layers in each E-Branchformer layer. + macaron_ffn (`bool`, *optional*, defaults to `True`): + Whether to use macaron-style pre-branch feed-forward layer (half-step residual). + interctc_layer_idx (`tuple[int]` or `list[int]`, *optional*, defaults to `(4, 8, 12)`): + Layer indices (1-based) at which intermediate CTC self-conditioning is applied. At each specified layer, + the encoder output is projected through the CTC head, softmaxed, and fed back via a conditioning layer. + interctc_use_conditioning (`bool`, *optional*, defaults to `True`): + Whether to enable intermediate CTC self-conditioning in the encoder. + ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`): + Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. + ctc_zero_infinity (`bool`, *optional*, defaults to `True`): + Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. + + Example: + + ```python + >>> from transformers import PhoneticXeusConfig, PhoneticXeusModel + + >>> configuration = PhoneticXeusConfig() + >>> model = PhoneticXeusModel(configuration) + >>> configuration = model.config + ```""" + + model_type = "phoneticxeus" + + vocab_size: int = 428 + hidden_size: int = 1024 + num_hidden_layers: int = 19 + num_attention_heads: int = 8 + intermediate_size: int = 4096 + + hidden_act: str = "swish" + hidden_dropout: float | int = 0.1 + attention_dropout: float | int = 0.1 + feat_proj_dropout: float | int = 0.0 + final_dropout: float | int = 0.1 + layerdrop: float | int = 0.0 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-5 + + normalize_audio: bool = True + feat_extract_norm: str = "layer" + feat_extract_activation: str = "gelu" + conv_dim: list[int] | tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512) + conv_stride: list[int] | tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2) + conv_kernel: list[int] | tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2) + conv_bias: bool = True + + num_conv_pos_embeddings: int = 128 + num_conv_pos_embedding_groups: int = 16 + conv_pos_weight_norm: bool = True + + cgmlp_linear_units: int = 4096 + cgmlp_conv_kernel: int = 31 + use_linear_after_conv: bool = False + gate_activation: str = "identity" + merge_conv_kernel: int = 31 + use_ffn: bool = True + macaron_ffn: bool = True + + interctc_layer_idx: list[int] | tuple[int, ...] 
= (4, 8, 12) + interctc_use_conditioning: bool = True + + ctc_loss_reduction: str = "sum" + ctc_zero_infinity: bool = True + + pad_token_id: int | None = 0 + bos_token_id: int | None = 1 + eos_token_id: int | list[int] | None = 2 + + def __post_init__(self, **kwargs): + self.num_feat_extract_layers = len(self.conv_dim) + super().__post_init__(**kwargs) + + def validate_architecture(self): + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` ==" + " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) =" + f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`," + f" `len(config.conv_kernel) = {len(self.conv_kernel)}`." + ) + + @property + def inputs_to_logits_ratio(self): + return functools.reduce(operator.mul, self.conv_stride, 1) + + +__all__ = ["PhoneticXeusConfig"] diff --git a/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py b/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py new file mode 100644 index 000000000000..9da6fb579e5c --- /dev/null +++ b/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py @@ -0,0 +1,233 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert PhoneticXeus (ESPnet/Lightning) checkpoint to HuggingFace format. + +Usage: + python convert_phoneticxeus_checkpoint.py \ + --checkpoint_path changelinglab/PhoneticXeus \ + --output_dir ./phoneticxeus_hf +""" + +import argparse +import re +import sys + +import torch + +from transformers import PhoneticXeusConfig, PhoneticXeusForCTC, PhoneticXeusTokenizer, Wav2Vec2FeatureExtractor + + +def load_checkpoint(checkpoint_path: str) -> dict: + """Load ESPnet/Lightning checkpoint, handling pickled module references.""" + import importlib.abc + import importlib.util + import types + + # Stub modules that may be pickled in the checkpoint + for mod_name in ["src", "lightning"]: + if mod_name not in sys.modules: + m = types.ModuleType(mod_name) + m.__path__ = [] + m.__file__ = "" + sys.modules[mod_name] = m + + # Recursively stub submodules. This uses the `find_spec` finder API, since the + # legacy `find_module`/`load_module` protocol was removed in Python 3.12. + class _StubLoader(importlib.abc.Loader): + def create_module(self, spec): + m = types.ModuleType(spec.name) + m.__path__ = [] + m.__file__ = "" + return m + + def exec_module(self, module): + pass + + class _StubFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path=None, target=None): + if fullname.startswith(("src.", "lightning.")): + return importlib.util.spec_from_loader(fullname, _StubLoader(), is_package=True) + return None + + sys.meta_path.insert(0, _StubFinder()) + + state = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + # Extract state_dict from Lightning checkpoint + if "state_dict" in state: + sd = state["state_dict"] + # Strip "net."
prefix from Lightning module wrapping + sd = {k.replace("net.", "", 1): v for k, v in sd.items() if k.startswith("net.")} + else: + sd = state + + return sd + + +_PREFIX = PhoneticXeusForCTC.base_model_prefix + + +def convert_key(key: str) -> str | None: + """Map a single ESPnet state_dict key to the HuggingFace equivalent.""" + + if key.startswith("frontend.layers."): + return key.replace("frontend.layers.", f"{_PREFIX}.feature_extractor.conv_layers.") + + if key.startswith("preencoder.linear_out."): + return key.replace("preencoder.linear_out.", f"{_PREFIX}.feature_projection.projection.") + + if key.startswith("encoder.embed.0.convs.0."): + return key.replace("encoder.embed.0.convs.0.", f"{_PREFIX}.encoder.pos_conv_embed.conv.") + + if key.startswith("encoder.after_norm."): + return key.replace("encoder.after_norm.", f"{_PREFIX}.encoder.layer_norm.") + + if key.startswith("encoder.conditioning_layer."): + return key.replace("encoder.conditioning_layer.", f"{_PREFIX}.encoder.conditioning_layer.") + + m = re.match(r"encoder\.encoders\.(\d+)\.(.*)", key) + if m: + layer_idx, rest = m.group(1), m.group(2) + if rest.startswith("attn."): + rest = rest.replace("attn.", "self_attn.", 1) + return f"{_PREFIX}.encoder.layers.{layer_idx}.{rest}" + + if key.startswith("ctc.ctc_lo."): + return key.replace("ctc.ctc_lo.", "lm_head.") + + return None + + +def convert_state_dict(sd: dict) -> dict: + """Convert full ESPnet state dict to HuggingFace format.""" + new_sd = {} + skipped = [] + for key, value in sd.items(): + new_key = convert_key(key) + if new_key is not None: + new_sd[new_key] = value + else: + skipped.append(key) + + if skipped: + print(f"Skipped {len(skipped)} keys: {skipped[:10]}{'...' if len(skipped) > 10 else ''}") + + return new_sd + + +def main(): + parser = argparse.ArgumentParser(description="Convert PhoneticXeus checkpoint to HuggingFace format") + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to .ckpt file or HuggingFace repo (e.g., 'changelinglab/PhoneticXeus')", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Output directory for HuggingFace model files", + ) + parser.add_argument( + "--vocab_file", + type=str, + default=None, + help="Path to ipa_vocab.json. 
If not provided, downloads from HF repo.", + ) + parser.add_argument( + "--push_to_hub", + type=str, + default=None, + help="If set, push model to this HuggingFace Hub repo", + ) + args = parser.parse_args() + + # Resolve checkpoint path + ckpt_path = args.checkpoint_path + if not ckpt_path.endswith(".ckpt") and not ckpt_path.endswith(".pth"): + # Assume HuggingFace repo + from huggingface_hub import hf_hub_download + + ckpt_path = hf_hub_download(ckpt_path, "checkpoint-22000.ckpt") + print(f"Downloaded checkpoint to: {ckpt_path}") + + # Load and convert + print("Loading checkpoint...") + sd = load_checkpoint(ckpt_path) + print(f"Loaded {len(sd)} keys from checkpoint") + + print("Converting state dict...") + hf_sd = convert_state_dict(sd) + print(f"Converted to {len(hf_sd)} HuggingFace keys") + + # Create config + config = PhoneticXeusConfig() + print(f"Config: hidden_size={config.hidden_size}, num_layers={config.num_hidden_layers}, vocab={config.vocab_size}") + + # Create model and load weights + print("Creating HuggingFace model...") + model = PhoneticXeusForCTC(config) + load_info = model.load_state_dict(hf_sd, strict=False) + print(f"Missing keys: {load_info.missing_keys}") + print(f"Unexpected keys: {load_info.unexpected_keys}") + + # Verify with dummy forward pass + print("Running verification forward pass...") + model.eval() + with torch.no_grad(): + dummy_input = torch.randn(1, 16000) # 1 second of audio + output = model(dummy_input) + logits = output.logits + print(f"Output shape: {logits.shape}") # Should be (1, T, 428) + assert logits.shape[-1] == config.vocab_size, f"Expected vocab_size={config.vocab_size}, got {logits.shape[-1]}" + print("Verification passed!") + + # Save + print(f"Saving model to {args.output_dir}...") + model.save_pretrained(args.output_dir) + config.save_pretrained(args.output_dir) + + # Save feature extractor + feature_extractor = Wav2Vec2FeatureExtractor( + feature_size=1, + sampling_rate=16000, + padding_value=0.0, + do_normalize=True, + return_attention_mask=True, + ) + feature_extractor.save_pretrained(args.output_dir) + + # Save tokenizer + processor if vocab provided + if args.vocab_file: + import shutil + + from transformers import PhoneticXeusProcessor + + shutil.copy(args.vocab_file, f"{args.output_dir}/vocab.json") + tokenizer = PhoneticXeusTokenizer(f"{args.output_dir}/vocab.json") + tokenizer.save_pretrained(args.output_dir) + processor = PhoneticXeusProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer) + processor.save_pretrained(args.output_dir) + print(f"Saved tokenizer + processor (vocab_size={tokenizer.vocab_size})") + + if args.push_to_hub: + print(f"Pushing to hub: {args.push_to_hub}") + model.push_to_hub(args.push_to_hub) + config.push_to_hub(args.push_to_hub) + feature_extractor.push_to_hub(args.push_to_hub) + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py b/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py new file mode 100644 index 000000000000..e82b7d3289ff --- /dev/null +++ b/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py @@ -0,0 +1,673 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phoneticxeus/modular_phoneticxeus.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# modular_phoneticxeus.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from torch import nn + +from ... import initialization as init +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring +from .configuration_phoneticxeus import PhoneticXeusConfig + + +class PhoneticXeusNoLayerNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class PhoneticXeusLayerNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class PhoneticXeusGroupNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states 
= self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class PhoneticXeusFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [PhoneticXeusGroupNormConvLayer(config, layer_id=0)] + [ + PhoneticXeusNoLayerNormConvLayer(config, layer_id=i + 1) + for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + PhoneticXeusLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +class PhoneticXeusFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + hidden_states = self.projection(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class PhoneticXeusPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + if config.conv_pos_weight_norm: + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.num_pad_remove = 1 if config.num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.conv(hidden_states) + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + hidden_states = nn.functional.gelu(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class PhoneticXeusSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_heads = config.num_attention_heads + self.head_size = config.hidden_size // config.num_attention_heads + + self.linear_q = 
nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.attention_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + batch_size, seq_len, _ = hidden_states.size() + + query = self.linear_q(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2) + key = self.linear_k(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2) + value = self.linear_v(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + if attention_mask is not None: + scores = scores + attention_mask + + attn_weights = torch.softmax(scores, dim=-1) + attn_weights = self.dropout(attn_weights) + + hidden_states = torch.matmul(attn_weights, value) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, seq_len, -1) + hidden_states = self.linear_out(hidden_states) + + return hidden_states, attn_weights if output_attentions else None + + +class PhoneticXeusFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.w_1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.activation = ACT2FN[config.hidden_act] + self.dropout = nn.Dropout(config.hidden_dropout) + self.w_2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states): + hidden_states = self.w_1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.w_2(hidden_states) + return hidden_states + + +class PhoneticXeusConvolutionalSpatialGatingUnit(nn.Module): + """CSGU: splits input, applies depthwise conv on gate half, then element-wise multiply.""" + + def __init__(self, config): + super().__init__() + n_channels = config.cgmlp_linear_units // 2 + self.norm = nn.LayerNorm(n_channels) + self.conv = nn.Conv1d( + n_channels, + n_channels, + config.cgmlp_conv_kernel, + stride=1, + padding=(config.cgmlp_conv_kernel - 1) // 2, + groups=n_channels, + ) + self.linear = nn.Linear(n_channels, n_channels) if config.use_linear_after_conv else None + + if config.gate_activation == "identity": + self.gate_activation = nn.Identity() + else: + self.gate_activation = ACT2FN[config.gate_activation] + + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + x_r, x_g = hidden_states.chunk(2, dim=-1) + x_g = self.norm(x_g) + x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2) + if self.linear is not None: + x_g = self.linear(x_g) + x_g = self.gate_activation(x_g) + return self.dropout(x_r * x_g) + + +class PhoneticXeusConvolutionalGatingMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.channel_proj1 = nn.Sequential( + nn.Linear(config.hidden_size, config.cgmlp_linear_units), + nn.GELU(), + ) + self.csgu = PhoneticXeusConvolutionalSpatialGatingUnit(config) + self.channel_proj2 = nn.Linear(config.cgmlp_linear_units // 2, config.hidden_size) + + def forward(self, hidden_states): + hidden_states = self.channel_proj1(hidden_states) + hidden_states = self.csgu(hidden_states) + hidden_states = self.channel_proj2(hidden_states) + 
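# the CSGU halves the channel width, so channel_proj2 maps cgmlp_linear_units // 2 back to hidden_size +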
return hidden_states + + +class PhoneticXeusEBranchformerEncoderLayer(GradientCheckpointingLayer): + """E-Branchformer layer: parallel self-attention and cgMLP branches merged via depthwise conv.""" + + def __init__(self, config): + super().__init__() + self.self_attn = PhoneticXeusSelfAttention(config) + self.cgmlp = PhoneticXeusConvolutionalGatingMLP(config) + + self.norm_mha = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm_mlp = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.depthwise_conv_fusion = nn.Conv1d( + 2 * config.hidden_size, + 2 * config.hidden_size, + kernel_size=config.merge_conv_kernel, + stride=1, + padding=(config.merge_conv_kernel - 1) // 2, + groups=2 * config.hidden_size, + bias=True, + ) + self.merge_proj = nn.Linear(2 * config.hidden_size, config.hidden_size) + + self.ff_scale = 1.0 + if config.use_ffn and config.macaron_ffn: + self.feed_forward_macaron = PhoneticXeusFeedForward(config) + self.norm_ff_macaron = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ff_scale = 0.5 + else: + self.feed_forward_macaron = None + + if config.use_ffn: + self.feed_forward = PhoneticXeusFeedForward(config) + self.norm_ff = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + else: + self.feed_forward = None + + self.norm_final = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if self.feed_forward_macaron is not None: + residual = hidden_states + hidden_states = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(self.norm_ff_macaron(hidden_states)) + ) + + x1 = self.norm_mha(hidden_states) + x1, attn_weights = self.self_attn(x1, attention_mask=attention_mask, output_attentions=output_attentions) + x1 = self.dropout(x1) + + x2 = self.norm_mlp(hidden_states) + x2 = self.dropout(self.cgmlp(x2)) + + x_concat = torch.cat([x1, x2], dim=-1) + x_tmp = self.depthwise_conv_fusion(x_concat.transpose(1, 2)).transpose(1, 2) + hidden_states = hidden_states + self.dropout(self.merge_proj(x_concat + x_tmp)) + + if self.feed_forward is not None: + residual = hidden_states + hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward(self.norm_ff(hidden_states))) + + hidden_states = self.norm_final(hidden_states) + return hidden_states, attn_weights + + +class PhoneticXeusEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = PhoneticXeusPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList( + [PhoneticXeusEBranchformerEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + self.interctc_layer_idx = list(config.interctc_layer_idx) if config.interctc_layer_idx else [] + self.interctc_use_conditioning = config.interctc_use_conditioning + if self.interctc_use_conditioning and self.interctc_layer_idx: + self.conditioning_layer = nn.Linear(config.vocab_size, config.hidden_size) + else: + self.conditioning_layer = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ctc_head: nn.Linear | 
None = None, + ) -> tuple | BaseModelOutput: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0.0 + + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.pos_conv_embed(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + dropout_probability = torch.rand([]) + skip_the_layer = self.training and dropout_probability < self.config.layerdrop + + if not skip_the_layer or synced_gpus: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if ( + self.interctc_use_conditioning + and self.conditioning_layer is not None + and ctc_head is not None + and (i + 1) in self.interctc_layer_idx + ): + ctc_out = torch.softmax(ctc_head(hidden_states), dim=-1) + hidden_states = hidden_states + self.conditioning_layer(ctc_out) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class PhoneticXeusPreTrainedModel(PreTrainedModel): + config: PhoneticXeusConfig + base_model_prefix = "phonetic_xeus" + main_input_name = "input_values" + supports_gradient_checkpointing = True + + @torch.no_grad() + def _init_weights(self, module): + if isinstance(module, PhoneticXeusPositionalConvEmbedding): + std = math.sqrt(4.0 / (module.conv.kernel_size[0] * module.conv.in_channels)) + init.normal_(module.conv.weight, mean=0, std=std) + init.constant_(module.conv.bias, 0) + elif isinstance(module, PhoneticXeusFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + init.zeros_(module.bias) + init.ones_(module.weight) + elif isinstance(module, nn.Conv1d): + init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int): + def _conv_out_length(input_length, kernel_size, stride): + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, 
stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + output_lengths = self._get_feat_extract_output_lengths(attention_mask.cumsum(dim=-1)[:, -1]).to(torch.long) + batch_size = attention_mask.shape[0] + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + attention_mask[(torch.arange(batch_size, device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +class PhoneticXeusModel(PhoneticXeusPreTrainedModel): + def __init__(self, config: PhoneticXeusConfig): + super().__init__(config) + self.feature_extractor = PhoneticXeusFeatureEncoder(config) + self.feature_projection = PhoneticXeusFeatureProjection(config) + self.encoder = PhoneticXeusEncoder(config) + self.post_init() + + def freeze_feature_encoder(self): + self.feature_extractor._freeze_parameters() + + @auto_docstring + def forward( + self, + input_values: torch.Tensor | None, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutput: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if self.config.normalize_audio: + input_values = torch.nn.functional.layer_norm(input_values, input_values.shape) + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + hidden_states = self.feature_projection(extract_features) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ctc_head=kwargs.get("ctc_head"), + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Index of `hidden_states` in the tuple returned by `PhoneticXeusModel` when +# `return_dict=False`: (last_hidden_state, hidden_states, attentions). Unlike +# Wav2Vec2, the model output has no `extract_features` slot, so this is 1, not 2. +_HIDDEN_STATES_START_POSITION = 1 + + +class PhoneticXeusForCTC(PhoneticXeusPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.phonetic_xeus = PhoneticXeusModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `PhoneticXeusForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration."
+ ) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + self.post_init() + + def freeze_feature_encoder(self): + self.phonetic_xeus.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + for param in self.phonetic_xeus.parameters(): + param.requires_grad = False + + @auto_docstring + def forward( + self, + input_values: torch.Tensor | None, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: torch.Tensor | None = None, + **kwargs, + ) -> tuple | CausalLMOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + outputs = self.phonetic_xeus( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ctc_head=self.lm_head, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +__all__ = ["PhoneticXeusForCTC", "PhoneticXeusModel", "PhoneticXeusPreTrainedModel"] diff --git a/src/transformers/models/phoneticxeus/modular_phoneticxeus.py b/src/transformers/models/phoneticxeus/modular_phoneticxeus.py new file mode 100644 index 000000000000..1f823c093ce7 --- /dev/null +++ b/src/transformers/models/phoneticxeus/modular_phoneticxeus.py @@ -0,0 +1,553 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PhoneticXeus model: E-Branchformer encoder for multilingual phone recognition.""" + +import math + +import torch +from torch import nn + +from ... import initialization as init +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring +from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureEncoder +from .configuration_phoneticxeus import PhoneticXeusConfig + + +# Index of `hidden_states` in the tuple returned by `PhoneticXeusModel` when +# `return_dict=False`: (last_hidden_state, hidden_states, attentions). Unlike +# Wav2Vec2, the model output has no `extract_features` slot, so this is 1, not 2. +_HIDDEN_STATES_START_POSITION = 1 + + +class PhoneticXeusFeatureEncoder(Wav2Vec2FeatureEncoder): + pass + + +class PhoneticXeusFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + hidden_states = self.projection(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class PhoneticXeusPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + if config.conv_pos_weight_norm: + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.num_pad_remove = 1 if config.num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.conv(hidden_states) + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + hidden_states = nn.functional.gelu(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class PhoneticXeusSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_heads = config.num_attention_heads + self.head_size = config.hidden_size // config.num_attention_heads + + self.linear_q = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) +
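# linear_out projects the re-concatenated attention heads back to hidden_size +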
self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.attention_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + batch_size, seq_len, _ = hidden_states.size() + + query = self.linear_q(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2) + key = self.linear_k(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2) + value = self.linear_v(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + if attention_mask is not None: + scores = scores + attention_mask + + attn_weights = torch.softmax(scores, dim=-1) + attn_weights = self.dropout(attn_weights) + + hidden_states = torch.matmul(attn_weights, value) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, seq_len, -1) + hidden_states = self.linear_out(hidden_states) + + return hidden_states, attn_weights if output_attentions else None + + +class PhoneticXeusFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.w_1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.activation = ACT2FN[config.hidden_act] + self.dropout = nn.Dropout(config.hidden_dropout) + self.w_2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states): + hidden_states = self.w_1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.w_2(hidden_states) + return hidden_states + + +class PhoneticXeusConvolutionalSpatialGatingUnit(nn.Module): + """CSGU: splits input, applies depthwise conv on gate half, then element-wise multiply.""" + + def __init__(self, config): + super().__init__() + n_channels = config.cgmlp_linear_units // 2 + self.norm = nn.LayerNorm(n_channels) + self.conv = nn.Conv1d( + n_channels, n_channels, config.cgmlp_conv_kernel, + stride=1, padding=(config.cgmlp_conv_kernel - 1) // 2, groups=n_channels, + ) + self.linear = nn.Linear(n_channels, n_channels) if config.use_linear_after_conv else None + + if config.gate_activation == "identity": + self.gate_activation = nn.Identity() + else: + self.gate_activation = ACT2FN[config.gate_activation] + + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + x_r, x_g = hidden_states.chunk(2, dim=-1) + x_g = self.norm(x_g) + x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2) + if self.linear is not None: + x_g = self.linear(x_g) + x_g = self.gate_activation(x_g) + return self.dropout(x_r * x_g) + + +class PhoneticXeusConvolutionalGatingMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.channel_proj1 = nn.Sequential( + nn.Linear(config.hidden_size, config.cgmlp_linear_units), + nn.GELU(), + ) + self.csgu = PhoneticXeusConvolutionalSpatialGatingUnit(config) + self.channel_proj2 = nn.Linear(config.cgmlp_linear_units // 2, config.hidden_size) + + def forward(self, hidden_states): + hidden_states = self.channel_proj1(hidden_states) + hidden_states = self.csgu(hidden_states) + hidden_states = self.channel_proj2(hidden_states) + return hidden_states + + +class PhoneticXeusEBranchformerEncoderLayer(GradientCheckpointingLayer): + """E-Branchformer layer: parallel self-attention and cgMLP branches merged via depthwise 
conv.""" + + def __init__(self, config): + super().__init__() + self.self_attn = PhoneticXeusSelfAttention(config) + self.cgmlp = PhoneticXeusConvolutionalGatingMLP(config) + + self.norm_mha = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm_mlp = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.depthwise_conv_fusion = nn.Conv1d( + 2 * config.hidden_size, 2 * config.hidden_size, + kernel_size=config.merge_conv_kernel, stride=1, + padding=(config.merge_conv_kernel - 1) // 2, + groups=2 * config.hidden_size, bias=True, + ) + self.merge_proj = nn.Linear(2 * config.hidden_size, config.hidden_size) + + self.ff_scale = 1.0 + if config.use_ffn and config.macaron_ffn: + self.feed_forward_macaron = PhoneticXeusFeedForward(config) + self.norm_ff_macaron = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ff_scale = 0.5 + else: + self.feed_forward_macaron = None + + if config.use_ffn: + self.feed_forward = PhoneticXeusFeedForward(config) + self.norm_ff = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + else: + self.feed_forward = None + + self.norm_final = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if self.feed_forward_macaron is not None: + residual = hidden_states + hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(hidden_states))) + + x1 = self.norm_mha(hidden_states) + x1, attn_weights = self.self_attn(x1, attention_mask=attention_mask, output_attentions=output_attentions) + x1 = self.dropout(x1) + + x2 = self.norm_mlp(hidden_states) + x2 = self.dropout(self.cgmlp(x2)) + + x_concat = torch.cat([x1, x2], dim=-1) + x_tmp = self.depthwise_conv_fusion(x_concat.transpose(1, 2)).transpose(1, 2) + hidden_states = hidden_states + self.dropout(self.merge_proj(x_concat + x_tmp)) + + if self.feed_forward is not None: + residual = hidden_states + hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward(self.norm_ff(hidden_states))) + + hidden_states = self.norm_final(hidden_states) + return hidden_states, attn_weights + + +class PhoneticXeusEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = PhoneticXeusPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList( + [PhoneticXeusEBranchformerEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + self.interctc_layer_idx = list(config.interctc_layer_idx) if config.interctc_layer_idx else [] + self.interctc_use_conditioning = config.interctc_use_conditioning + if self.interctc_use_conditioning and self.interctc_layer_idx: + self.conditioning_layer = nn.Linear(config.vocab_size, config.hidden_size) + else: + self.conditioning_layer = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ctc_head: nn.Linear | None = None, + ) -> tuple | BaseModelOutput: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + 
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0.0 + + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.pos_conv_embed(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + dropout_probability = torch.rand([]) + skip_the_layer = self.training and dropout_probability < self.config.layerdrop + + if not skip_the_layer or synced_gpus: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if ( + self.interctc_use_conditioning + and self.conditioning_layer is not None + and ctc_head is not None + and (i + 1) in self.interctc_layer_idx + ): + ctc_out = torch.softmax(ctc_head(hidden_states), dim=-1) + hidden_states = hidden_states + self.conditioning_layer(ctc_out) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class PhoneticXeusPreTrainedModel(PreTrainedModel): + config: PhoneticXeusConfig + base_model_prefix = "phonetic_xeus" + main_input_name = "input_values" + supports_gradient_checkpointing = True + + @torch.no_grad() + def _init_weights(self, module): + if isinstance(module, PhoneticXeusPositionalConvEmbedding): + std = math.sqrt(4.0 / (module.conv.kernel_size[0] * module.conv.in_channels)) + init.normal_(module.conv.weight, mean=0, std=std) + init.constant_(module.conv.bias, 0) + elif isinstance(module, PhoneticXeusFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + init.zeros_(module.bias) + init.ones_(module.weight) + elif isinstance(module, nn.Conv1d): + init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int): + def _conv_out_length(input_length, kernel_size, stride): + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + return input_lengths + + def 
_get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + output_lengths = self._get_feat_extract_output_lengths(attention_mask.cumsum(dim=-1)[:, -1]).to(torch.long) + batch_size = attention_mask.shape[0] + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + attention_mask[(torch.arange(batch_size, device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +class PhoneticXeusModel(PhoneticXeusPreTrainedModel): + def __init__(self, config: PhoneticXeusConfig): + super().__init__(config) + self.feature_extractor = PhoneticXeusFeatureEncoder(config) + self.feature_projection = PhoneticXeusFeatureProjection(config) + self.encoder = PhoneticXeusEncoder(config) + self.post_init() + + def freeze_feature_encoder(self): + self.feature_extractor._freeze_parameters() + + @auto_docstring + def forward( + self, + input_values: torch.Tensor | None, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutput: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if self.config.normalize_audio: + input_values = torch.nn.functional.layer_norm(input_values, input_values.shape) + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + hidden_states = self.feature_projection(extract_features) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ctc_head=kwargs.get("ctc_head"), + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class PhoneticXeusForCTC(PhoneticXeusPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.phonetic_xeus = PhoneticXeusModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `PhoneticXeusForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." 
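+                # For example (hypothetical checkpoint name; 428 is the IPA vocabulary size):
+                #   PhoneticXeusForCTC.from_pretrained("changelinglab/PhoneticXeus-hf", vocab_size=428)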
+ ) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + self.post_init() + + def freeze_feature_encoder(self): + self.phonetic_xeus.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + for param in self.phonetic_xeus.parameters(): + param.requires_grad = False + + @auto_docstring + def forward( + self, + input_values: torch.Tensor | None, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: torch.Tensor | None = None, + **kwargs, + ) -> tuple | CausalLMOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + outputs = self.phonetic_xeus( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ctc_head=self.lm_head, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +__all__ = [ + "PhoneticXeusForCTC", + "PhoneticXeusModel", + "PhoneticXeusPreTrainedModel", +] diff --git a/src/transformers/models/phoneticxeus/processing_phoneticxeus.py b/src/transformers/models/phoneticxeus/processing_phoneticxeus.py new file mode 100644 index 000000000000..df965531bcd3 --- /dev/null +++ b/src/transformers/models/phoneticxeus/processing_phoneticxeus.py @@ -0,0 +1,92 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Processor for PhoneticXeus.""" + +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput +from ...utils import auto_docstring + + +class PhoneticXeusProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + +@auto_docstring +class PhoneticXeusProcessor(ProcessorMixin): + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + @auto_docstring + def __call__( + self, + audio: AudioInput | None = None, + text: str | list[str] | TextInput | PreTokenizedInput | None = None, + **kwargs: Unpack[PhoneticXeusProcessorKwargs], + ): + r""" + Returns: + This method returns the audio features and/or tokenized text as needed. + """ + if audio is None and text is None: + raise ValueError("You need to specify either an `audio` or `text` input to process.") + + output_kwargs = self._merge_kwargs( + PhoneticXeusProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if audio is not None: + inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + if text is not None: + encodings = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + if text is None: + return inputs + elif audio is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + return inputs + + def pad(self, *args, **kwargs): + """ + Pads extracted audio features and/or tokenized text. + Forwards to [`Wav2Vec2FeatureExtractor.pad`] and/or [`PreTrainedTokenizer.pad`]. + """ + input_features = kwargs.pop("input_features", None) + labels = kwargs.pop("labels", None) + if len(args) > 0: + input_features = args[0] + args = args[1:] + + if input_features is not None: + input_features = self.feature_extractor.pad(input_features, *args, **kwargs) + if labels is not None: + labels = self.tokenizer.pad(labels, **kwargs) + + if labels is None: + return input_features + elif input_features is None: + return labels + else: + input_features["labels"] = labels["input_ids"] + return input_features + + @property + def model_input_names(self): + return self.feature_extractor.model_input_names + ["labels"] + + +__all__ = ["PhoneticXeusProcessor"] diff --git a/src/transformers/models/phoneticxeus/tokenization_phoneticxeus.py b/src/transformers/models/phoneticxeus/tokenization_phoneticxeus.py new file mode 100644 index 000000000000..dd40f76db40e --- /dev/null +++ b/src/transformers/models/phoneticxeus/tokenization_phoneticxeus.py @@ -0,0 +1,63 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenizer for PhoneticXeus (IPA CTC tokenizer)."""
+
+from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
+
+
+class PhoneticXeusTokenizer(Wav2Vec2CTCTokenizer):
+    """CTC tokenizer for IPA phone sequences.
+
+    Thin wrapper around [`Wav2Vec2CTCTokenizer`] with defaults matching the PhoneticXeus IPA vocabulary
+    (428 tokens). Handles CTC blank collapsing and ID-to-IPA conversion.
+
+    Args:
+        vocab_file (`str`):
+            Path to `vocab.json` mapping IPA phone strings to integer IDs.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            Beginning of sentence token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            End of sentence token.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            Unknown token.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            Padding / CTC blank token.
+        word_delimiter_token (`str`, *optional*, defaults to `" "`):
+            Token used as word delimiter. Set to `" "` since IPA transcriptions use space between words.
+        **kwargs:
+            Additional keyword arguments passed to [`Wav2Vec2CTCTokenizer`].
+    """
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        word_delimiter_token=" ",
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            word_delimiter_token=word_delimiter_token,
+            **kwargs,
+        )
+
+
+__all__ = ["PhoneticXeusTokenizer"]

From d1232590c990b7e8e389a024ee4119e6417e03ea Mon Sep 17 00:00:00 2001
From: Shikhar 
Date: Thu, 9 Apr 2026 23:22:12 -0500
Subject: [PATCH 2/6] fix issue with return

---
 .../convert_phoneticxeus_checkpoint.py        |  9 ++++---
 .../phoneticxeus/modeling_phoneticxeus.py     |  2 +-
 .../phoneticxeus/modular_phoneticxeus.py      | 27 +++++++++++++------
 3 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py b/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py
index 9da6fb579e5c..b025e499d344 100644
--- a/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py
+++ b/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py
@@ -20,7 +20,6 @@
 """
 
 import argparse
-import json
 import re
 import sys
 
@@ -172,7 +171,9 @@ def main():
 
     # Create config
     config = PhoneticXeusConfig()
-    print(f"Config: hidden_size={config.hidden_size}, num_layers={config.num_hidden_layers}, vocab={config.vocab_size}")
+    print(
+        f"Config: hidden_size={config.hidden_size}, num_layers={config.num_hidden_layers}, vocab={config.vocab_size}"
+    )
 
     # Create model and load weights
     print("Creating HuggingFace model...")
@@ -189,7 +190,9 @@ def main():
         output = model(dummy_input)
     logits = output.logits
     print(f"Output shape: {logits.shape}")  # Should be (1, T, 428)
-    assert logits.shape[-1] == config.vocab_size, f"Expected vocab_size={config.vocab_size}, got {logits.shape[-1]}"
+    assert logits.shape[-1] == config.vocab_size, (
+        f"Expected vocab_size={config.vocab_size}, got {logits.shape[-1]}"
+    )
 
     print("Verification passed!")
 
     # Save
diff --git a/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py b/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py
index e82b7d3289ff..34940f074ab4 100644
--- a/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py
+++ b/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py
@@ -575,7 +575,7 @@ def forward(
         )
 
 
-_HIDDEN_STATES_START_POSITION = 2
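+# The base model returns (last_hidden_state, hidden_states?, attentions?) with no separate
+# extract_features entry (unlike wav2vec2), so the optional outputs sliced in
+# `PhoneticXeusForCTC.forward` start at index 1 rather than 2.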
+_HIDDEN_STATES_START_POSITION = 1 class PhoneticXeusForCTC(PhoneticXeusPreTrainedModel): diff --git a/src/transformers/models/phoneticxeus/modular_phoneticxeus.py b/src/transformers/models/phoneticxeus/modular_phoneticxeus.py index 1f823c093ce7..3535b3a20f0a 100644 --- a/src/transformers/models/phoneticxeus/modular_phoneticxeus.py +++ b/src/transformers/models/phoneticxeus/modular_phoneticxeus.py @@ -30,7 +30,7 @@ from .configuration_phoneticxeus import PhoneticXeusConfig -_HIDDEN_STATES_START_POSITION = 2 +_HIDDEN_STATES_START_POSITION = 1 class PhoneticXeusFeatureEncoder(Wav2Vec2FeatureEncoder): @@ -155,8 +155,12 @@ def __init__(self, config): n_channels = config.cgmlp_linear_units // 2 self.norm = nn.LayerNorm(n_channels) self.conv = nn.Conv1d( - n_channels, n_channels, config.cgmlp_conv_kernel, - stride=1, padding=(config.cgmlp_conv_kernel - 1) // 2, groups=n_channels, + n_channels, + n_channels, + config.cgmlp_conv_kernel, + stride=1, + padding=(config.cgmlp_conv_kernel - 1) // 2, + groups=n_channels, ) self.linear = nn.Linear(n_channels, n_channels) if config.use_linear_after_conv else None @@ -206,10 +210,13 @@ def __init__(self, config): self.norm_mlp = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.depthwise_conv_fusion = nn.Conv1d( - 2 * config.hidden_size, 2 * config.hidden_size, - kernel_size=config.merge_conv_kernel, stride=1, + 2 * config.hidden_size, + 2 * config.hidden_size, + kernel_size=config.merge_conv_kernel, + stride=1, padding=(config.merge_conv_kernel - 1) // 2, - groups=2 * config.hidden_size, bias=True, + groups=2 * config.hidden_size, + bias=True, ) self.merge_proj = nn.Linear(2 * config.hidden_size, config.hidden_size) @@ -238,7 +245,9 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: if self.feed_forward_macaron is not None: residual = hidden_states - hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(hidden_states))) + hidden_states = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(self.norm_ff_macaron(hidden_states)) + ) x1 = self.norm_mha(hidden_states) x1, attn_weights = self.self_attn(x1, attention_mask=attention_mask, output_attentions=output_attentions) @@ -313,7 +322,9 @@ def forward( if not skip_the_layer or synced_gpus: layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] From 98ea62c99842f2c2567b184a30b712d57608c548 Mon Sep 17 00:00:00 2001 From: Shikhar Date: Thu, 9 Apr 2026 23:24:05 -0500 Subject: [PATCH 3/6] add documentation --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/phoneticxeus.md | 67 ++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 docs/source/en/model_doc/phoneticxeus.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b9f66011de80..761e9907e034 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1075,6 +1075,8 @@ title: Parakeet - local: model_doc/pe_audio title: PE Audio + - local: model_doc/phoneticxeus + title: PhoneticXeus - local: model_doc/pop2piano title: Pop2Piano - local: model_doc/seamless_m4t diff --git a/docs/source/en/model_doc/phoneticxeus.md b/docs/source/en/model_doc/phoneticxeus.md new file mode 100644 index 000000000000..41a517ef9019 --- /dev/null +++ b/docs/source/en/model_doc/phoneticxeus.md @@ -0,0 +1,67 @@ + + +# PhoneticXeus + +
+<div class="flex flex-wrap space-x-1">
+<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
+</div>
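+
+The snippet below is a minimal sketch of greedy IPA transcription with the CTC head. It assumes the
+converted `changelinglab/PhoneticXeus-hf` checkpoint, together with its feature extractor and tokenizer
+files, is available on the Hub; the silent waveform is only a stand-in for real 16 kHz speech.
+
+```python
+import numpy as np
+import torch
+
+from transformers import PhoneticXeusForCTC, PhoneticXeusTokenizer, Wav2Vec2FeatureExtractor
+
+feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("changelinglab/PhoneticXeus-hf")
+tokenizer = PhoneticXeusTokenizer.from_pretrained("changelinglab/PhoneticXeus-hf")
+model = PhoneticXeusForCTC.from_pretrained("changelinglab/PhoneticXeus-hf").eval()
+
+waveform = np.zeros(16000, dtype=np.float32)  # one second of 16 kHz audio
+inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
+
+with torch.no_grad():
+    logits = model(inputs.input_values).logits  # (batch, frames, vocab_size=428)
+
+# Greedy CTC decoding: argmax per frame; the tokenizer collapses repeats and drops blanks
+predicted_ids = torch.argmax(logits, dim=-1)
+ipa = tokenizer.batch_decode(predicted_ids)
+```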
+
+## Overview
+
+PhoneticXeus is a multilingual IPA phone recognition model that uses a CNN feature encoder followed by an
+E-Branchformer encoder with a CTC head. The E-Branchformer architecture runs self-attention and a Convolutional
+Gating MLP (cgMLP) as parallel branches merged via depthwise convolution, rather than the sequential
+attention-convolution pattern used in Conformer models. The model employs intermediate CTC (interCTC)
+self-conditioning at specified encoder layers.
+
+The model was released by the [Changeling Lab](https://huggingface.co/changelinglab) and can be found at
+[changelinglab/PhoneticXeus](https://huggingface.co/changelinglab/PhoneticXeus).
+
+## Usage tips
+
+- PhoneticXeus outputs IPA (International Phonetic Alphabet) phone sequences rather than text transcriptions.
+- The vocabulary consists of 428 IPA phoneme tokens.
+- It uses `Wav2Vec2FeatureExtractor` for audio preprocessing and `PhoneticXeusTokenizer` for CTC decoding.
+- The E-Branchformer encoder uses parallel self-attention and cgMLP branches instead of the sequential
+  attention and convolution used in Conformer models.
+- InterCTC self-conditioning feeds intermediate CTC predictions back into the encoder at layers 4, 8, and 12.
+
+## Resources
+
+- [Automatic speech recognition task guide](../tasks/asr)
+
+## PhoneticXeusConfig
+
+[[autodoc]] PhoneticXeusConfig
+
+## PhoneticXeusModel
+
+[[autodoc]] PhoneticXeusModel
+    - forward
+
+## PhoneticXeusForCTC
+
+[[autodoc]] PhoneticXeusForCTC
+    - forward
+
+## PhoneticXeusTokenizer
+
+[[autodoc]] PhoneticXeusTokenizer
+
+## PhoneticXeusProcessor
+
+[[autodoc]] PhoneticXeusProcessor

From 37651ede868d3b3f7d60cb59b061b9fb50b05ed9 Mon Sep 17 00:00:00 2001
From: Shikhar 
Date: Thu, 9 Apr 2026 23:24:40 -0500
Subject: [PATCH 4/6] add integration tests

---
 tests/models/phoneticxeus/__init__.py         |   0
 .../test_modeling_phoneticxeus.py             | 427 ++++++++++++++++++
 2 files changed, 427 insertions(+)
 create mode 100644 tests/models/phoneticxeus/__init__.py
 create mode 100644 tests/models/phoneticxeus/test_modeling_phoneticxeus.py

diff --git a/tests/models/phoneticxeus/__init__.py b/tests/models/phoneticxeus/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/phoneticxeus/test_modeling_phoneticxeus.py b/tests/models/phoneticxeus/test_modeling_phoneticxeus.py
new file mode 100644
index 000000000000..e8971b85d142
--- /dev/null
+++ b/tests/models/phoneticxeus/test_modeling_phoneticxeus.py
@@ -0,0 +1,427 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch PhoneticXeus model.""" + +import math +import tempfile +import unittest + +from transformers import PhoneticXeusConfig, is_torch_available +from transformers.testing_utils import ( + require_torch, + require_torch_accelerator, + require_torch_fp16, + slow, + torch_device, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + floats_tensor, + ids_tensor, + random_attention_mask, +) +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + PhoneticXeusForCTC, + PhoneticXeusModel, + Wav2Vec2FeatureExtractor, + ) + + +class PhoneticXeusModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=1024, + is_training=False, + hidden_size=16, + feat_extract_norm="layer", + feat_extract_dropout=0.0, + feat_extract_activation="gelu", + conv_dim=(32, 32, 32), + conv_stride=(4, 4, 4), + conv_kernel=(8, 8, 8), + conv_bias=True, + num_conv_pos_embeddings=16, + num_conv_pos_embedding_groups=2, + num_hidden_layers=2, + num_attention_heads=2, + hidden_dropout=0.1, + attention_dropout=0.1, + intermediate_size=20, + layer_norm_eps=1e-5, + hidden_act="swish", + initializer_range=0.02, + vocab_size=32, + cgmlp_linear_units=20, + cgmlp_conv_kernel=3, + merge_conv_kernel=3, + use_ffn=True, + macaron_ffn=True, + interctc_layer_idx=(1,), + interctc_use_conditioning=True, + normalize_audio=True, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.hidden_size = hidden_size + self.feat_extract_norm = feat_extract_norm + self.feat_extract_dropout = feat_extract_dropout + self.feat_extract_activation = feat_extract_activation + self.conv_dim = conv_dim + self.conv_stride = conv_stride + self.conv_kernel = conv_kernel + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.intermediate_size = intermediate_size + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.vocab_size = vocab_size + self.cgmlp_linear_units = cgmlp_linear_units + self.cgmlp_conv_kernel = cgmlp_conv_kernel + self.merge_conv_kernel = merge_conv_kernel + self.use_ffn = use_ffn + self.macaron_ffn = macaron_ffn + self.interctc_layer_idx = interctc_layer_idx + self.interctc_use_conditioning = interctc_use_conditioning + self.normalize_audio = normalize_audio + self.scope = scope + + output_seq_length = self.seq_length + for kernel, stride in zip(self.conv_kernel, self.conv_stride): + output_seq_length = (output_seq_length - (kernel - 1)) / stride + self.output_seq_length = int(math.ceil(output_seq_length)) + self.encoder_seq_length = self.output_seq_length + + def prepare_config_and_inputs(self): + input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = random_attention_mask([self.batch_size, self.seq_length]) + + config = self.get_config() + + return config, input_values, attention_mask + + def get_config(self): + return PhoneticXeusConfig( + hidden_size=self.hidden_size, + feat_extract_norm=self.feat_extract_norm, + feat_extract_activation=self.feat_extract_activation, + 
conv_dim=self.conv_dim, + conv_stride=self.conv_stride, + conv_kernel=self.conv_kernel, + conv_bias=self.conv_bias, + num_conv_pos_embeddings=self.num_conv_pos_embeddings, + num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + hidden_dropout=self.hidden_dropout, + attention_dropout=self.attention_dropout, + intermediate_size=self.intermediate_size, + layer_norm_eps=self.layer_norm_eps, + hidden_act=self.hidden_act, + initializer_range=self.initializer_range, + vocab_size=self.vocab_size, + cgmlp_linear_units=self.cgmlp_linear_units, + cgmlp_conv_kernel=self.cgmlp_conv_kernel, + merge_conv_kernel=self.merge_conv_kernel, + use_ffn=self.use_ffn, + macaron_ffn=self.macaron_ffn, + interctc_layer_idx=self.interctc_layer_idx, + interctc_use_conditioning=self.interctc_use_conditioning, + normalize_audio=self.normalize_audio, + ) + + def create_and_check_model(self, config, input_values, attention_mask): + model = PhoneticXeusModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_values, attention_mask=attention_mask) + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size) + ) + + def create_and_check_model_float16(self, config, input_values, attention_mask): + model = PhoneticXeusModel(config=config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = PhoneticXeusModel.from_pretrained(tmpdirname, dtype=torch.float16) + + model.to(torch_device) + model.eval() + + with torch.no_grad(): + result = model(input_values.type(dtype=torch.float16), attention_mask=attention_mask) + + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size) + ) + + def check_ctc_loss(self, config, input_values, *args): + model = PhoneticXeusForCTC(config=config) + model.to(torch_device) + + # make sure that dropout is disabled + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0 + + model.config.ctc_loss_reduction = "sum" + sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() + + model.config.ctc_loss_reduction = "mean" + mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() + + self.parent.assertTrue(isinstance(sum_loss, float)) + self.parent.assertTrue(isinstance(mean_loss, float)) + + def check_ctc_training(self, config, input_values, *args): + config.ctc_zero_infinity = True + model = PhoneticXeusForCTC(config=config) + model.to(torch_device) + model.train() + + # freeze feature encoder + model.freeze_feature_encoder() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size) + + # pad input + for i in range(len(input_lengths)): + 
input_values[i, input_lengths[i] :] = 0.0 + + if max_length_labels[i] < labels.shape[-1]: + # it's important that we make sure that target lengths are at least + # one shorter than logit lengths to prevent -inf + labels[i, max_length_labels[i] - 1 :] = -100 + + loss = model(input_values, labels=labels).loss + self.parent.assertFalse(torch.isinf(loss).item()) + + loss.backward() + + def check_labels_out_of_vocab(self, config, input_values, *args): + model = PhoneticXeusForCTC(config) + model.to(torch_device) + model.train() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100) + + with self.parent.assertRaises(ValueError): + model(input_values, labels=labels) + + def prepare_config_and_inputs_for_common(self): + config, input_values, attention_mask = self.prepare_config_and_inputs() + inputs_dict = {"input_values": input_values, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_torch +class PhoneticXeusModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + PhoneticXeusForCTC, + PhoneticXeusModel, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "automatic-speech-recognition": PhoneticXeusForCTC, + "feature-extraction": PhoneticXeusModel, + } + if is_torch_available() + else {} + ) + + test_resize_embeddings = False + + def test_batching_equivalence(self, atol=1e-3, rtol=1e-3): + super().test_batching_equivalence(atol=atol, rtol=rtol) + + def setUp(self): + self.model_tester = PhoneticXeusModelTester(self) + self.config_tester = ConfigTester(self, config_class=PhoneticXeusConfig, hidden_size=32) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @require_torch_accelerator + @require_torch_fp16 + def test_model_float16(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_float16(*config_and_inputs) + + def test_ctc_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_ctc_loss(*config_and_inputs) + + def test_ctc_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_ctc_training(*config_and_inputs) + + def test_labels_out_of_vocab(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_labels_out_of_vocab(*config_and_inputs) + + @unittest.skip(reason="PhoneticXeus has no inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="PhoneticXeus has input_values instead of input_ids") + def test_forward_signature(self): + pass + + @unittest.skip(reason="PhoneticXeus has no token embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Feed forward chunking is not implemented") + def test_feed_forward_chunking(self): + pass + + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + # no need to test all models as different heads yield the same 
functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + # set layer drop to 0 + model.config.layerdrop = 0.0 + + input_values = inputs_dict["input_values"] + + input_lengths = torch.tensor( + [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device + ) + output_lengths = model._get_feat_extract_output_lengths(input_lengths) + + labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size) + inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"]) + inputs_dict["labels"] = labels + + outputs = model(**inputs_dict) + + output = outputs[0] + + # Encoder-/Decoder-only models + hidden_states = outputs.hidden_states[0] + attentions = outputs.attentions[0] + + hidden_states.retain_grad() + attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(hidden_states.grad) + self.assertIsNotNone(attentions.grad) + + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.fill_(3) + if hasattr(module, "weight_g") and module.weight_g is not None: + module.weight_g.data.fill_(3) + if hasattr(module, "weight_v") and module.weight_v is not None: + module.weight_v.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.fill_(3) + + @slow + def test_model_from_pretrained(self): + model = PhoneticXeusModel.from_pretrained("changelinglab/PhoneticXeus-hf") + self.assertIsNotNone(model) + + +@require_torch +@slow +class PhoneticXeusModelIntegrationTest(unittest.TestCase): + def _load_datasamples(self, num_samples): + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + speech_samples = ds.sort("id").filter(lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)]) + speech_samples = speech_samples[:num_samples]["audio"] + + return [x["array"] for x in speech_samples] + + def test_inference_ctc_batched(self): + model = PhoneticXeusForCTC.from_pretrained("changelinglab/PhoneticXeus-hf") + model.to(torch_device) + model.eval() + + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("changelinglab/PhoneticXeus-hf") + + input_speech = self._load_datasamples(2) + + inputs = feature_extractor(input_speech, return_tensors="pt", padding=True, sampling_rate=16000) + + input_values = inputs.input_values.to(torch_device) + + with torch.no_grad(): + logits = model(input_values).logits + + predicted_ids = torch.argmax(logits, dim=-1) + + # Verify output shape: batch=2, time steps, vocab=428 + self.assertEqual(logits.shape[0], 2) + self.assertEqual(logits.shape[-1], 428) + + # Verify predictions are valid token ids + self.assertTrue((predicted_ids >= 0).all()) + self.assertTrue((predicted_ids < 428).all()) From ec4313f09f809c8ef9e90542bd874d518b8a40e3 Mon Sep 17 00:00:00 2001 From: Shikhar Bharadwaj Date: Fri, 10 Apr 2026 20:47:24 -0500 Subject: [PATCH 5/6] simplify ckpt loading --- .../convert_phoneticxeus_checkpoint.py | 53 ++----------------- 1 file changed, 4 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py b/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py index b025e499d344..5ddb604ca0d9 100644 --- a/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py +++ 
b/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py @@ -21,57 +21,12 @@ import argparse import re -import sys import torch from transformers import PhoneticXeusConfig, PhoneticXeusForCTC, PhoneticXeusTokenizer, Wav2Vec2FeatureExtractor -def load_checkpoint(checkpoint_path: str) -> dict: - """Load ESPnet/Lightning checkpoint, handling pickled module references.""" - import types - - # Stub modules that may be pickled in the checkpoint - for mod_name in ["src", "lightning"]: - if mod_name not in sys.modules: - m = types.ModuleType(mod_name) - m.__path__ = [] - m.__file__ = "" - sys.modules[mod_name] = m - - # Recursively stub submodules - class _StubFinder: - def find_module(self, fullname, path=None): - if fullname.startswith(("src.", "lightning.")): - return self - return None - - def load_module(self, fullname): - if fullname in sys.modules: - return sys.modules[fullname] - m = types.ModuleType(fullname) - m.__path__ = [] - m.__file__ = "" - m.__loader__ = self - sys.modules[fullname] = m - return m - - sys.meta_path.insert(0, _StubFinder()) - - state = torch.load(checkpoint_path, map_location="cpu", weights_only=False) - - # Extract state_dict from Lightning checkpoint - if "state_dict" in state: - sd = state["state_dict"] - # Strip "net." prefix from Lightning module wrapping - sd = {k.replace("net.", "", 1): v for k, v in sd.items() if k.startswith("net.")} - else: - sd = state - - return sd - - _PREFIX = PhoneticXeusForCTC.base_model_prefix @@ -153,16 +108,16 @@ def main(): # Resolve checkpoint path ckpt_path = args.checkpoint_path - if not ckpt_path.endswith(".ckpt") and not ckpt_path.endswith(".pth"): + if not ckpt_path.endswith((".pt", ".bin", ".ckpt", ".pth")): # Assume HuggingFace repo from huggingface_hub import hf_hub_download - ckpt_path = hf_hub_download(ckpt_path, "checkpoint-22000.ckpt") + ckpt_path = hf_hub_download(ckpt_path, "phoneticxeus_state_dict.pt") print(f"Downloaded checkpoint to: {ckpt_path}") - # Load and convert + # Load state dict (expects a plain dict of tensors, not a Lightning checkpoint) print("Loading checkpoint...") - sd = load_checkpoint(ckpt_path) + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True) print(f"Loaded {len(sd)} keys from checkpoint") print("Converting state dict...") From 1788c0b8ff47ef5de270e150be4f586c4053dff0 Mon Sep 17 00:00:00 2001 From: Shikhar Bharadwaj Date: Fri, 10 Apr 2026 21:07:14 -0500 Subject: [PATCH 6/6] register phoneticxeus for auto loading --- src/transformers/models/auto/configuration_auto.py | 2 ++ src/transformers/models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 ++ src/transformers/models/auto/processing_auto.py | 1 + src/transformers/models/auto/tokenization_auto.py | 1 + 5 files changed, 7 insertions(+) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 2c0fe88d0e74..41535c319cba 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -360,6 +360,7 @@ ("phi3", "Phi3Config"), ("phi4_multimodal", "Phi4MultimodalConfig"), ("phimoe", "PhimoeConfig"), + ("phoneticxeus", "PhoneticXeusConfig"), ("pi0", "PI0Config"), ("pix2struct", "Pix2StructConfig"), ("pixio", "PixioConfig"), @@ -893,6 +894,7 @@ ("phi4_multimodal", "Phi4Multimodal"), ("phimoe", "Phimoe"), ("phobert", "PhoBERT"), + ("phoneticxeus", "PhoneticXeus"), ("pi0", "PI0"), ("pix2struct", "Pix2Struct"), ("pixio", "Pixio"), diff --git 
a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 111c56efb436..acc120ae2023 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -64,6 +64,7 @@
     ("parakeet_encoder", "ParakeetFeatureExtractor"),
     ("pe_audio", "PeAudioFeatureExtractor"),
     ("pe_audio_video", "PeAudioFeatureExtractor"),
     ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
+    ("phoneticxeus", "Wav2Vec2FeatureExtractor"),
     ("pop2piano", "Pop2PianoFeatureExtractor"),
     ("qwen2_5_omni", "WhisperFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index d4cb17cddfa6..d6b6a336b925 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -352,6 +352,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     ("phi3", "Phi3Model"),
     ("phi4_multimodal", "Phi4MultimodalModel"),
     ("phimoe", "PhimoeModel"),
+    ("phoneticxeus", "PhoneticXeusModel"),
     ("pi0", "PI0Model"),
     ("pixio", "PixioModel"),
     ("pixtral", "PixtralVisionModel"),
@@ -1622,6 +1623,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     ("hubert", "HubertForCTC"),
     ("lasr_ctc", "LasrForCTC"),
     ("parakeet_ctc", "ParakeetForCTC"),
+    ("phoneticxeus", "PhoneticXeusForCTC"),
     ("sew", "SEWForCTC"),
     ("sew-d", "SEWDForCTC"),
     ("unispeech", "UniSpeechForCTC"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 262480b71485..ec42a263ba17 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -135,6 +135,7 @@
     ("paligemma", "PaliGemmaProcessor"),
     ("perception_lm", "PerceptionLMProcessor"),
     ("phi4_multimodal", "Phi4MultimodalProcessor"),
+    ("phoneticxeus", "PhoneticXeusProcessor"),
     ("pi0", "PI0Processor"),
     ("pix2struct", "Pix2StructProcessor"),
     ("pixtral", "PixtralProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 1b38f2e7a3f1..0ed10b9dac18 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -250,6 +250,7 @@
     ("perceiver", "PerceiverTokenizer"),
     ("phi", "GPT2Tokenizer" if is_tokenizers_available() else None),
     ("phobert", "PhobertTokenizer"),
+    ("phoneticxeus", "PhoneticXeusTokenizer"),
     ("pix2struct", "T5Tokenizer" if is_tokenizers_available() else None),
     (
         "pixtral",