diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index ba69db1c5e78..9e3e144c241f 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1075,6 +1075,8 @@
title: Parakeet
- local: model_doc/pe_audio
title: PE Audio
+ - local: model_doc/phoneticxeus
+ title: PhoneticXeus
- local: model_doc/pop2piano
title: Pop2Piano
- local: model_doc/seamless_m4t
diff --git a/docs/source/en/model_doc/phoneticxeus.md b/docs/source/en/model_doc/phoneticxeus.md
new file mode 100644
index 000000000000..41a517ef9019
--- /dev/null
+++ b/docs/source/en/model_doc/phoneticxeus.md
@@ -0,0 +1,67 @@
+# PhoneticXeus
+
+## Overview
+
+PhoneticXeus is a multilingual IPA phone recognition model that uses a CNN feature encoder followed by an
+E-Branchformer encoder with a CTC head. The E-Branchformer architecture runs self-attention and a Convolutional
+Gating MLP (cgMLP) as parallel branches merged via depthwise convolution, rather than the sequential
+attention-convolution pattern used in Conformer models. The model employs intermediate CTC (interCTC)
+self-conditioning at specified encoder layers.
+
+The model was released by the [Changeling Lab](https://huggingface.co/changelinglab) and can be found at
+[changelinglab/PhoneticXeus](https://huggingface.co/changelinglab/PhoneticXeus).
+
+## Usage tips
+
+- PhoneticXeus outputs IPA (International Phonetic Alphabet) phone sequences rather than text transcriptions.
+- The vocabulary consists of 428 IPA phoneme tokens.
+- It uses `Wav2Vec2FeatureExtractor` for audio preprocessing and `PhoneticXeusTokenizer` for CTC decoding, as shown in the example below.
+- The E-Branchformer encoder uses parallel self-attention and cgMLP branches instead of the sequential
+ attention and convolution used in Conformer models.
+- InterCTC self-conditioning feeds intermediate CTC predictions back into the encoder at layers 4, 8, and 12.
+
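+The snippet below is a minimal inference sketch. It assumes a converted, Hugging Face-format checkpoint (with
+processor files) is available under `changelinglab/PhoneticXeus`; if you converted the weights locally with the
+conversion script, point `from_pretrained` at its output directory instead.
+
+```python
+import torch
+from datasets import load_dataset
+
+from transformers import AutoProcessor, PhoneticXeusForCTC
+
+# repo id is an assumption -- replace with a local path produced by the conversion script if needed
+processor = AutoProcessor.from_pretrained("changelinglab/PhoneticXeus")
+model = PhoneticXeusForCTC.from_pretrained("changelinglab/PhoneticXeus")
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+inputs = processor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
+
+with torch.no_grad():
+    logits = model(**inputs).logits  # (batch, frames, vocab_size)
+
+# greedy CTC decoding: argmax per frame, then collapse repeats and blanks in the tokenizer
+predicted_ids = torch.argmax(logits, dim=-1)
+print(processor.batch_decode(predicted_ids))
+```
+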
+## Resources
+
+- [Automatic speech recognition task guide](../tasks/asr)
+
+## PhoneticXeusConfig
+
+[[autodoc]] PhoneticXeusConfig
+
+## PhoneticXeusModel
+
+[[autodoc]] PhoneticXeusModel
+ - forward
+
+## PhoneticXeusForCTC
+
+[[autodoc]] PhoneticXeusForCTC
+ - forward
+
+## PhoneticXeusTokenizer
+
+[[autodoc]] PhoneticXeusTokenizer
+
+## PhoneticXeusProcessor
+
+[[autodoc]] PhoneticXeusProcessor
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 2c0fe88d0e74..41535c319cba 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -360,6 +360,7 @@
("phi3", "Phi3Config"),
("phi4_multimodal", "Phi4MultimodalConfig"),
("phimoe", "PhimoeConfig"),
+ ("phoneticxeus", "PhoneticXeusConfig"),
("pi0", "PI0Config"),
("pix2struct", "Pix2StructConfig"),
("pixio", "PixioConfig"),
@@ -893,6 +894,7 @@
("phi4_multimodal", "Phi4Multimodal"),
("phimoe", "Phimoe"),
("phobert", "PhoBERT"),
+ ("phoneticxeus", "PhoneticXeus"),
("pi0", "PI0"),
("pix2struct", "Pix2Struct"),
("pixio", "Pixio"),
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 111c56efb436..acc120ae2023 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -64,6 +64,7 @@
("parakeet_encoder", "ParakeetFeatureExtractor"),
("pe_audio", "PeAudioFeatureExtractor"),
("pe_audio_video", "PeAudioFeatureExtractor"),
+ ("phoneticxeus", "Wav2Vec2FeatureExtractor"),
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
("pop2piano", "Pop2PianoFeatureExtractor"),
("qwen2_5_omni", "WhisperFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index d4cb17cddfa6..d6b6a336b925 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -352,6 +352,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("phi3", "Phi3Model"),
("phi4_multimodal", "Phi4MultimodalModel"),
("phimoe", "PhimoeModel"),
+ ("phoneticxeus", "PhoneticXeusModel"),
("pi0", "PI0Model"),
("pixio", "PixioModel"),
("pixtral", "PixtralVisionModel"),
@@ -1622,6 +1623,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("hubert", "HubertForCTC"),
("lasr_ctc", "LasrForCTC"),
("parakeet_ctc", "ParakeetForCTC"),
+ ("phoneticxeus", "PhoneticXeusForCTC"),
("sew", "SEWForCTC"),
("sew-d", "SEWDForCTC"),
("unispeech", "UniSpeechForCTC"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 262480b71485..ec42a263ba17 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -135,6 +135,7 @@
("paligemma", "PaliGemmaProcessor"),
("perception_lm", "PerceptionLMProcessor"),
("phi4_multimodal", "Phi4MultimodalProcessor"),
+ ("phoneticxeus", "PhoneticXeusProcessor"),
("pi0", "PI0Processor"),
("pix2struct", "Pix2StructProcessor"),
("pixtral", "PixtralProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 1b38f2e7a3f1..0ed10b9dac18 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -250,6 +250,7 @@
("perceiver", "PerceiverTokenizer"),
("phi", "GPT2Tokenizer" if is_tokenizers_available() else None),
("phobert", "PhobertTokenizer"),
+ ("phoneticxeus", "PhoneticXeusTokenizer"),
("pix2struct", "T5Tokenizer" if is_tokenizers_available() else None),
(
"pixtral",
diff --git a/src/transformers/models/phoneticxeus/__init__.py b/src/transformers/models/phoneticxeus/__init__.py
new file mode 100644
index 000000000000..f8e947bed548
--- /dev/null
+++ b/src/transformers/models/phoneticxeus/__init__.py
@@ -0,0 +1,34 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+
+
+if TYPE_CHECKING:
+ from .configuration_phoneticxeus import *
+ from .modeling_phoneticxeus import *
+ from .processing_phoneticxeus import *
+ from .tokenization_phoneticxeus import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ modules = {
+ "configuration_phoneticxeus": ["PhoneticXeusConfig"],
+ "modeling_phoneticxeus": ["PhoneticXeusForCTC", "PhoneticXeusModel", "PhoneticXeusPreTrainedModel"],
+ "processing_phoneticxeus": ["PhoneticXeusProcessor"],
+ "tokenization_phoneticxeus": ["PhoneticXeusTokenizer"],
+ }
+ sys.modules[__name__] = _LazyModule(__name__, _file, modules, module_spec=__spec__)
diff --git a/src/transformers/models/phoneticxeus/configuration_phoneticxeus.py b/src/transformers/models/phoneticxeus/configuration_phoneticxeus.py
new file mode 100644
index 000000000000..3b641459af99
--- /dev/null
+++ b/src/transformers/models/phoneticxeus/configuration_phoneticxeus.py
@@ -0,0 +1,177 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PhoneticXeus model configuration"""
+
+import functools
+import operator
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import auto_docstring
+
+
+@auto_docstring(checkpoint="changelinglab/PhoneticXeus")
+class PhoneticXeusConfig(PretrainedConfig):
+ r"""
+ vocab_size (`int`, *optional*, defaults to 428):
+ Vocabulary size of the PhoneticXeus model (IPA phoneme inventory).
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the encoder layers.
+ num_hidden_layers (`int`, *optional*, defaults to 19):
+ Number of E-Branchformer encoder layers.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the encoder.
+ intermediate_size (`int`, *optional*, defaults to 4096):
+ Dimensionality of the feed-forward layers in the encoder.
+ hidden_act (`str`, *optional*, defaults to `"swish"`):
+ The non-linear activation function in the feed-forward layers.
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for fully connected layers in the encoder.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for the feature projection layer.
+ final_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for the final projection layer of [`PhoneticXeusForCTC`].
+ layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability during training.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ Standard deviation of the truncated normal initializer for weight initialization.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ Epsilon for layer normalization.
+ normalize_audio (`bool`, *optional*, defaults to `True`):
+ Whether to apply layer normalization to the raw audio waveform before the CNN feature encoder.
+ feat_extract_norm (`str`, *optional*, defaults to `"layer"`):
+ The norm to be applied to 1D convolutional layers in the feature encoder. One of `"group"` for group
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+ convolutional layers.
+ feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the 1D convolutional layers of the feature extractor.
+ conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ Number of input and output channels of each 1D convolutional layer in the feature encoder.
+ conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ Stride of each 1D convolutional layer in the feature encoder.
+ conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
+ Kernel size of each 1D convolutional layer in the feature encoder.
+ conv_bias (`bool`, *optional*, defaults to `True`):
+ Whether the 1D convolutional layers have a bias.
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+ Kernel size of the convolutional positional embeddings layer.
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+ Number of groups of the convolutional positional embeddings layer.
+ conv_pos_weight_norm (`bool`, *optional*, defaults to `True`):
+ Whether to apply weight normalization to the convolutional positional embedding layer.
+ cgmlp_linear_units (`int`, *optional*, defaults to 4096):
+ Hidden dimensionality of the ConvolutionalGatingMLP in each E-Branchformer layer.
+ cgmlp_conv_kernel (`int`, *optional*, defaults to 31):
+ Kernel size of the depthwise convolution in the Convolutional Spatial Gating Unit (CSGU).
+ use_linear_after_conv (`bool`, *optional*, defaults to `False`):
+ Whether to apply a linear layer after the depthwise convolution in CSGU.
+ gate_activation (`str`, *optional*, defaults to `"identity"`):
+ Activation function for gating in CSGU.
+ merge_conv_kernel (`int`, *optional*, defaults to 31):
+ Kernel size of the depthwise convolution used to merge the two branches in each E-Branchformer layer.
+ use_ffn (`bool`, *optional*, defaults to `True`):
+ Whether to use feed-forward layers in each E-Branchformer layer.
+ macaron_ffn (`bool`, *optional*, defaults to `True`):
+ Whether to use macaron-style pre-branch feed-forward layer (half-step residual).
+ interctc_layer_idx (`tuple[int]` or `list[int]`, *optional*, defaults to `(4, 8, 12)`):
+ Layer indices (1-based) at which intermediate CTC self-conditioning is applied. At each specified layer,
+ the encoder output is projected through the CTC head, softmaxed, and fed back via a conditioning layer.
+ interctc_use_conditioning (`bool`, *optional*, defaults to `True`):
+ Whether to enable intermediate CTC self-conditioning in the encoder.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`.
+ ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`.
+
+ Example:
+
+ ```python
+ >>> from transformers import PhoneticXeusConfig, PhoneticXeusModel
+
+ >>> configuration = PhoneticXeusConfig()
+ >>> model = PhoneticXeusModel(configuration)
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "phoneticxeus"
+
+ vocab_size: int = 428
+ hidden_size: int = 1024
+ num_hidden_layers: int = 19
+ num_attention_heads: int = 8
+ intermediate_size: int = 4096
+
+ hidden_act: str = "swish"
+ hidden_dropout: float | int = 0.1
+ attention_dropout: float | int = 0.1
+ feat_proj_dropout: float | int = 0.0
+ final_dropout: float | int = 0.1
+ layerdrop: float | int = 0.0
+ initializer_range: float = 0.02
+ layer_norm_eps: float = 1e-5
+
+ normalize_audio: bool = True
+ feat_extract_norm: str = "layer"
+ feat_extract_activation: str = "gelu"
+ conv_dim: list[int] | tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512)
+ conv_stride: list[int] | tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2)
+ conv_kernel: list[int] | tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2)
+ conv_bias: bool = True
+
+ num_conv_pos_embeddings: int = 128
+ num_conv_pos_embedding_groups: int = 16
+ conv_pos_weight_norm: bool = True
+
+ cgmlp_linear_units: int = 4096
+ cgmlp_conv_kernel: int = 31
+ use_linear_after_conv: bool = False
+ gate_activation: str = "identity"
+ merge_conv_kernel: int = 31
+ use_ffn: bool = True
+ macaron_ffn: bool = True
+
+ interctc_layer_idx: list[int] | tuple[int, ...] = (4, 8, 12)
+ interctc_use_conditioning: bool = True
+
+ ctc_loss_reduction: str = "sum"
+ ctc_zero_infinity: bool = True
+
+ pad_token_id: int | None = 0
+ bos_token_id: int | None = 1
+ eos_token_id: int | list[int] | None = 2
+
+ def __post_init__(self, **kwargs):
+ self.num_feat_extract_layers = len(self.conv_dim)
+ super().__post_init__(**kwargs)
+
+ def validate_architecture(self):
+ if (
+ (len(self.conv_stride) != self.num_feat_extract_layers)
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
+ ):
+ raise ValueError(
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ )
+
+ @property
+ def inputs_to_logits_ratio(self):
+ return functools.reduce(operator.mul, self.conv_stride, 1)
+
+
+__all__ = ["PhoneticXeusConfig"]
diff --git a/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py b/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py
new file mode 100644
index 000000000000..5ddb604ca0d9
--- /dev/null
+++ b/src/transformers/models/phoneticxeus/convert_phoneticxeus_checkpoint.py
@@ -0,0 +1,191 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert PhoneticXeus (ESPnet/Lightning) checkpoint to HuggingFace format.
+
+Usage:
+ python convert_phoneticxeus_checkpoint.py \
+ --checkpoint_path changelinglab/PhoneticXeus \
+ --output_dir ./phoneticxeus_hf
+"""
+
+import argparse
+import re
+
+import torch
+
+from transformers import PhoneticXeusConfig, PhoneticXeusForCTC, PhoneticXeusTokenizer, Wav2Vec2FeatureExtractor
+
+
+_PREFIX = PhoneticXeusForCTC.base_model_prefix
+
+
+def convert_key(key: str) -> str | None:
+ """Map a single ESPnet state_dict key to the HuggingFace equivalent."""
+
+ if key.startswith("frontend.layers."):
+ return key.replace("frontend.layers.", f"{_PREFIX}.feature_extractor.conv_layers.")
+
+ if key.startswith("preencoder.linear_out."):
+ return key.replace("preencoder.linear_out.", f"{_PREFIX}.feature_projection.projection.")
+
+ if key.startswith("encoder.embed.0.convs.0."):
+ return key.replace("encoder.embed.0.convs.0.", f"{_PREFIX}.encoder.pos_conv_embed.conv.")
+
+ if key.startswith("encoder.after_norm."):
+ return key.replace("encoder.after_norm.", f"{_PREFIX}.encoder.layer_norm.")
+
+ if key.startswith("encoder.conditioning_layer."):
+ return key.replace("encoder.conditioning_layer.", f"{_PREFIX}.encoder.conditioning_layer.")
+
+ m = re.match(r"encoder\.encoders\.(\d+)\.(.*)", key)
+ if m:
+ layer_idx, rest = m.group(1), m.group(2)
+ if rest.startswith("attn."):
+ rest = rest.replace("attn.", "self_attn.", 1)
+ return f"{_PREFIX}.encoder.layers.{layer_idx}.{rest}"
+
+ if key.startswith("ctc.ctc_lo."):
+ return key.replace("ctc.ctc_lo.", "lm_head.")
+
+ return None
+
+
+def convert_state_dict(sd: dict) -> dict:
+ """Convert full ESPnet state dict to HuggingFace format."""
+ new_sd = {}
+ skipped = []
+ for key, value in sd.items():
+ new_key = convert_key(key)
+ if new_key is not None:
+ new_sd[new_key] = value
+ else:
+ skipped.append(key)
+
+ if skipped:
+ print(f"Skipped {len(skipped)} keys: {skipped[:10]}{'...' if len(skipped) > 10 else ''}")
+
+ return new_sd
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Convert PhoneticXeus checkpoint to HuggingFace format")
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ required=True,
+ help="Path to .ckpt file or HuggingFace repo (e.g., 'changelinglab/PhoneticXeus')",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="Output directory for HuggingFace model files",
+ )
+ parser.add_argument(
+ "--vocab_file",
+ type=str,
+ default=None,
+ help="Path to ipa_vocab.json. If not provided, downloads from HF repo.",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ type=str,
+ default=None,
+ help="If set, push model to this HuggingFace Hub repo",
+ )
+ args = parser.parse_args()
+
+ # Resolve checkpoint path
+ ckpt_path = args.checkpoint_path
+ if not ckpt_path.endswith((".pt", ".bin", ".ckpt", ".pth")):
+ # Assume HuggingFace repo
+ from huggingface_hub import hf_hub_download
+
+ ckpt_path = hf_hub_download(ckpt_path, "phoneticxeus_state_dict.pt")
+ print(f"Downloaded checkpoint to: {ckpt_path}")
+
+ # Load state dict (expects a plain dict of tensors, not a Lightning checkpoint)
+ print("Loading checkpoint...")
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+ print(f"Loaded {len(sd)} keys from checkpoint")
+
+ print("Converting state dict...")
+ hf_sd = convert_state_dict(sd)
+ print(f"Converted to {len(hf_sd)} HuggingFace keys")
+
+ # Create config
+ config = PhoneticXeusConfig()
+ print(
+ f"Config: hidden_size={config.hidden_size}, num_layers={config.num_hidden_layers}, vocab={config.vocab_size}"
+ )
+
+ # Create model and load weights
+ print("Creating HuggingFace model...")
+ model = PhoneticXeusForCTC(config)
+ load_info = model.load_state_dict(hf_sd, strict=False)
+ print(f"Missing keys: {load_info.missing_keys}")
+ print(f"Unexpected keys: {load_info.unexpected_keys}")
+
+ # Verify with dummy forward pass
+ print("Running verification forward pass...")
+ model.eval()
+ with torch.no_grad():
+ dummy_input = torch.randn(1, 16000) # 1 second of audio
+ output = model(dummy_input)
+ logits = output.logits
+ print(f"Output shape: {logits.shape}") # Should be (1, T, 428)
+ assert logits.shape[-1] == config.vocab_size, (
+ f"Expected vocab_size={config.vocab_size}, got {logits.shape[-1]}"
+ )
+ print("Verification passed!")
+
+ # Save
+ print(f"Saving model to {args.output_dir}...")
+ model.save_pretrained(args.output_dir)
+ config.save_pretrained(args.output_dir)
+
+ # Save feature extractor
+ feature_extractor = Wav2Vec2FeatureExtractor(
+ feature_size=1,
+ sampling_rate=16000,
+ padding_value=0.0,
+ do_normalize=True,
+ return_attention_mask=True,
+ )
+ feature_extractor.save_pretrained(args.output_dir)
+
+ # Save tokenizer + processor if vocab provided
+ if args.vocab_file:
+ import shutil
+
+ from transformers import PhoneticXeusProcessor
+
+ shutil.copy(args.vocab_file, f"{args.output_dir}/vocab.json")
+ tokenizer = PhoneticXeusTokenizer(f"{args.output_dir}/vocab.json")
+ tokenizer.save_pretrained(args.output_dir)
+ processor = PhoneticXeusProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+ processor.save_pretrained(args.output_dir)
+ print(f"Saved tokenizer + processor (vocab_size={tokenizer.vocab_size})")
+
+ if args.push_to_hub:
+ print(f"Pushing to hub: {args.push_to_hub}")
+ model.push_to_hub(args.push_to_hub)
+ config.push_to_hub(args.push_to_hub)
+ feature_extractor.push_to_hub(args.push_to_hub)
+
+ print("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py b/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py
new file mode 100644
index 000000000000..34940f074ab4
--- /dev/null
+++ b/src/transformers/models/phoneticxeus/modeling_phoneticxeus.py
@@ -0,0 +1,673 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/phoneticxeus/modular_phoneticxeus.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_phoneticxeus.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import torch
+from torch import nn
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring
+from .configuration_phoneticxeus import PhoneticXeusConfig
+
+
+class PhoneticXeusNoLayerNormConvLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class PhoneticXeusLayerNormConvLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class PhoneticXeusGroupNormConvLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class PhoneticXeusFeatureEncoder(nn.Module):
+ """Construct the features from raw audio waveform"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ if config.feat_extract_norm == "group":
+ conv_layers = [PhoneticXeusGroupNormConvLayer(config, layer_id=0)] + [
+ PhoneticXeusNoLayerNormConvLayer(config, layer_id=i + 1)
+ for i in range(config.num_feat_extract_layers - 1)
+ ]
+ elif config.feat_extract_norm == "layer":
+ conv_layers = [
+ PhoneticXeusLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
+ ]
+ else:
+ raise ValueError(
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ )
+ self.conv_layers = nn.ModuleList(conv_layers)
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def forward(self, input_values):
+ hidden_states = input_values[:, None]
+
+ # make sure hidden_states require grad for gradient_checkpointing
+ if self._requires_grad and self.training:
+ hidden_states.requires_grad = True
+
+ for conv_layer in self.conv_layers:
+ hidden_states = conv_layer(hidden_states)
+
+ return hidden_states
+
+
+class PhoneticXeusFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.projection(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class PhoneticXeusPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ if config.conv_pos_weight_norm:
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+ if hasattr(self.conv, "parametrizations"):
+ weight_g = self.conv.parametrizations.weight.original0
+ weight_v = self.conv.parametrizations.weight.original1
+ else:
+ weight_g = self.conv.weight_g
+ weight_v = self.conv.weight_v
+ deepspeed.zero.register_external_parameter(self, weight_v)
+ deepspeed.zero.register_external_parameter(self, weight_g)
+ else:
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+
+ self.num_pad_remove = 1 if config.num_conv_pos_embeddings % 2 == 0 else 0
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.conv(hidden_states)
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ hidden_states = nn.functional.gelu(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class PhoneticXeusSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.head_size = config.hidden_size // config.num_attention_heads
+
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.attention_dropout)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ batch_size, seq_len, _ = hidden_states.size()
+
+ query = self.linear_q(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
+ key = self.linear_k(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
+ value = self.linear_v(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
+
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
+ if attention_mask is not None:
+ scores = scores + attention_mask
+
+ attn_weights = torch.softmax(scores, dim=-1)
+ attn_weights = self.dropout(attn_weights)
+
+ hidden_states = torch.matmul(attn_weights, value)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, seq_len, -1)
+ hidden_states = self.linear_out(hidden_states)
+
+ return hidden_states, attn_weights if output_attentions else None
+
+
+class PhoneticXeusFeedForward(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.w_1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.w_2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states):
+ hidden_states = self.w_1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.w_2(hidden_states)
+ return hidden_states
+
+
+class PhoneticXeusConvolutionalSpatialGatingUnit(nn.Module):
+ """CSGU: splits input, applies depthwise conv on gate half, then element-wise multiply."""
+
+ def __init__(self, config):
+ super().__init__()
+ n_channels = config.cgmlp_linear_units // 2
+ self.norm = nn.LayerNorm(n_channels)
+ self.conv = nn.Conv1d(
+ n_channels,
+ n_channels,
+ config.cgmlp_conv_kernel,
+ stride=1,
+ padding=(config.cgmlp_conv_kernel - 1) // 2,
+ groups=n_channels,
+ )
+ self.linear = nn.Linear(n_channels, n_channels) if config.use_linear_after_conv else None
+
+ if config.gate_activation == "identity":
+ self.gate_activation = nn.Identity()
+ else:
+ self.gate_activation = ACT2FN[config.gate_activation]
+
+ self.dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, hidden_states):
+ x_r, x_g = hidden_states.chunk(2, dim=-1)
+ x_g = self.norm(x_g)
+ x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)
+ if self.linear is not None:
+ x_g = self.linear(x_g)
+ x_g = self.gate_activation(x_g)
+ return self.dropout(x_r * x_g)
+
+
+class PhoneticXeusConvolutionalGatingMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.channel_proj1 = nn.Sequential(
+ nn.Linear(config.hidden_size, config.cgmlp_linear_units),
+ nn.GELU(),
+ )
+ self.csgu = PhoneticXeusConvolutionalSpatialGatingUnit(config)
+ self.channel_proj2 = nn.Linear(config.cgmlp_linear_units // 2, config.hidden_size)
+
+ def forward(self, hidden_states):
+ hidden_states = self.channel_proj1(hidden_states)
+ hidden_states = self.csgu(hidden_states)
+ hidden_states = self.channel_proj2(hidden_states)
+ return hidden_states
+
+
+class PhoneticXeusEBranchformerEncoderLayer(GradientCheckpointingLayer):
+ """E-Branchformer layer: parallel self-attention and cgMLP branches merged via depthwise conv."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.self_attn = PhoneticXeusSelfAttention(config)
+ self.cgmlp = PhoneticXeusConvolutionalGatingMLP(config)
+
+ self.norm_mha = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.norm_mlp = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.depthwise_conv_fusion = nn.Conv1d(
+ 2 * config.hidden_size,
+ 2 * config.hidden_size,
+ kernel_size=config.merge_conv_kernel,
+ stride=1,
+ padding=(config.merge_conv_kernel - 1) // 2,
+ groups=2 * config.hidden_size,
+ bias=True,
+ )
+ self.merge_proj = nn.Linear(2 * config.hidden_size, config.hidden_size)
+
+ self.ff_scale = 1.0
+ if config.use_ffn and config.macaron_ffn:
+ self.feed_forward_macaron = PhoneticXeusFeedForward(config)
+ self.norm_ff_macaron = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.ff_scale = 0.5
+ else:
+ self.feed_forward_macaron = None
+
+ if config.use_ffn:
+ self.feed_forward = PhoneticXeusFeedForward(config)
+ self.norm_ff = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ else:
+ self.feed_forward = None
+
+ self.norm_final = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ if self.feed_forward_macaron is not None:
+ residual = hidden_states
+ hidden_states = residual + self.ff_scale * self.dropout(
+ self.feed_forward_macaron(self.norm_ff_macaron(hidden_states))
+ )
+
+ x1 = self.norm_mha(hidden_states)
+ x1, attn_weights = self.self_attn(x1, attention_mask=attention_mask, output_attentions=output_attentions)
+ x1 = self.dropout(x1)
+
+ x2 = self.norm_mlp(hidden_states)
+ x2 = self.dropout(self.cgmlp(x2))
+
+ x_concat = torch.cat([x1, x2], dim=-1)
+ x_tmp = self.depthwise_conv_fusion(x_concat.transpose(1, 2)).transpose(1, 2)
+ hidden_states = hidden_states + self.dropout(self.merge_proj(x_concat + x_tmp))
+
+ if self.feed_forward is not None:
+ residual = hidden_states
+ hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward(self.norm_ff(hidden_states)))
+
+ hidden_states = self.norm_final(hidden_states)
+ return hidden_states, attn_weights
+
+
+class PhoneticXeusEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pos_conv_embed = PhoneticXeusPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList(
+ [PhoneticXeusEBranchformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+ )
+
+ self.interctc_layer_idx = list(config.interctc_layer_idx) if config.interctc_layer_idx else []
+ self.interctc_use_conditioning = config.interctc_use_conditioning
+ if self.interctc_use_conditioning and self.interctc_layer_idx:
+ self.conditioning_layer = nn.Linear(config.vocab_size, config.hidden_size)
+ else:
+ self.conditioning_layer = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ctc_head: nn.Linear | None = None,
+ ) -> tuple | BaseModelOutput:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
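+        # zero out features at padded positions, then build an additive (batch, 1, seq, seq) mask with large negatives on padding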
+ if attention_mask is not None:
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0.0
+
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
+
+ hidden_states = self.pos_conv_embed(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ dropout_probability = torch.rand([])
+ skip_the_layer = self.training and dropout_probability < self.config.layerdrop
+
+ if not skip_the_layer or synced_gpus:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
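+            # interCTC self-conditioning: add a projection of the softmaxed CTC posterior back into the hidden states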
+ if (
+ self.interctc_use_conditioning
+ and self.conditioning_layer is not None
+ and ctc_head is not None
+ and (i + 1) in self.interctc_layer_idx
+ ):
+ ctc_out = torch.softmax(ctc_head(hidden_states), dim=-1)
+ hidden_states = hidden_states + self.conditioning_layer(ctc_out)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@auto_docstring
+class PhoneticXeusPreTrainedModel(PreTrainedModel):
+ config: PhoneticXeusConfig
+ base_model_prefix = "phonetic_xeus"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ if isinstance(module, PhoneticXeusPositionalConvEmbedding):
+ std = math.sqrt(4.0 / (module.conv.kernel_size[0] * module.conv.in_channels))
+ init.normal_(module.conv.weight, mean=0, std=std)
+ init.constant_(module.conv.bias, 0)
+ elif isinstance(module, PhoneticXeusFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ init.uniform_(module.projection.weight, a=-k, b=k)
+ init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ init.zeros_(module.bias)
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ init.zeros_(module.bias)
+ init.ones_(module.weight)
+ elif isinstance(module, nn.Conv1d):
+ init.kaiming_normal_(module.weight)
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
+ def _conv_out_length(input_length, kernel_size, stride):
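+            # 1D convolution output length formula: floor((input_length - kernel_size) / stride) + 1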
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.cumsum(dim=-1)[:, -1]).to(torch.long)
+ batch_size = attention_mask.shape[0]
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
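+        # mark the last valid output frame with 1, then flip/cumsum/flip so all earlier frames are attended to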
+ attention_mask[(torch.arange(batch_size, device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+
+class PhoneticXeusModel(PhoneticXeusPreTrainedModel):
+ def __init__(self, config: PhoneticXeusConfig):
+ super().__init__(config)
+ self.feature_extractor = PhoneticXeusFeatureEncoder(config)
+ self.feature_projection = PhoneticXeusFeatureProjection(config)
+ self.encoder = PhoneticXeusEncoder(config)
+ self.post_init()
+
+ def freeze_feature_encoder(self):
+ self.feature_extractor._freeze_parameters()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: torch.Tensor | None,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ **kwargs,
+ ) -> tuple | BaseModelOutput:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if self.config.normalize_audio:
+ input_values = torch.nn.functional.layer_norm(input_values, input_values.shape)
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+ hidden_states = self.feature_projection(extract_features)
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ ctc_head=kwargs.get("ctc_head"),
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+_HIDDEN_STATES_START_POSITION = 1
+
+
+class PhoneticXeusForCTC(PhoneticXeusPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.phonetic_xeus = PhoneticXeusModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+ "instantiate the model as follows: `PhoneticXeusForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+ "or define `vocab_size` of your model's configuration."
+ )
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+ self.post_init()
+
+ def freeze_feature_encoder(self):
+ self.phonetic_xeus.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ for param in self.phonetic_xeus.parameters():
+ param.requires_grad = False
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: torch.Tensor | None,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: torch.Tensor | None = None,
+ **kwargs,
+ ) -> tuple | CausalLMOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if labels is not None and labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ outputs = self.phonetic_xeus(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ ctc_head=self.lm_head,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ attention_mask = (
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+__all__ = ["PhoneticXeusForCTC", "PhoneticXeusModel", "PhoneticXeusPreTrainedModel"]
diff --git a/src/transformers/models/phoneticxeus/modular_phoneticxeus.py b/src/transformers/models/phoneticxeus/modular_phoneticxeus.py
new file mode 100644
index 000000000000..3535b3a20f0a
--- /dev/null
+++ b/src/transformers/models/phoneticxeus/modular_phoneticxeus.py
@@ -0,0 +1,564 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PhoneticXeus model: E-Branchformer encoder for multilingual phone recognition."""
+
+import math
+
+import torch
+from torch import nn
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring
+from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureEncoder
+from .configuration_phoneticxeus import PhoneticXeusConfig
+
+
+_HIDDEN_STATES_START_POSITION = 1
+
+
+class PhoneticXeusFeatureEncoder(Wav2Vec2FeatureEncoder):
+ pass
+
+
+class PhoneticXeusFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.projection(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class PhoneticXeusPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ if config.conv_pos_weight_norm:
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+ if hasattr(self.conv, "parametrizations"):
+ weight_g = self.conv.parametrizations.weight.original0
+ weight_v = self.conv.parametrizations.weight.original1
+ else:
+ weight_g = self.conv.weight_g
+ weight_v = self.conv.weight_v
+ deepspeed.zero.register_external_parameter(self, weight_v)
+ deepspeed.zero.register_external_parameter(self, weight_g)
+ else:
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+
+ self.num_pad_remove = 1 if config.num_conv_pos_embeddings % 2 == 0 else 0
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.conv(hidden_states)
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ hidden_states = nn.functional.gelu(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class PhoneticXeusSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.head_size = config.hidden_size // config.num_attention_heads
+
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.attention_dropout)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ batch_size, seq_len, _ = hidden_states.size()
+
+ query = self.linear_q(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
+ key = self.linear_k(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
+ value = self.linear_v(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
+
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
+ if attention_mask is not None:
+ scores = scores + attention_mask
+
+ attn_weights = torch.softmax(scores, dim=-1)
+ attn_weights = self.dropout(attn_weights)
+
+ hidden_states = torch.matmul(attn_weights, value)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, seq_len, -1)
+ hidden_states = self.linear_out(hidden_states)
+
+ return hidden_states, attn_weights if output_attentions else None
+
+
+class PhoneticXeusFeedForward(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.w_1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.w_2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states):
+ hidden_states = self.w_1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.w_2(hidden_states)
+ return hidden_states
+
+
+class PhoneticXeusConvolutionalSpatialGatingUnit(nn.Module):
+ """CSGU: splits input, applies depthwise conv on gate half, then element-wise multiply."""
+
+ def __init__(self, config):
+ super().__init__()
+ n_channels = config.cgmlp_linear_units // 2
+ self.norm = nn.LayerNorm(n_channels)
+ self.conv = nn.Conv1d(
+ n_channels,
+ n_channels,
+ config.cgmlp_conv_kernel,
+ stride=1,
+ padding=(config.cgmlp_conv_kernel - 1) // 2,
+ groups=n_channels,
+ )
+ self.linear = nn.Linear(n_channels, n_channels) if config.use_linear_after_conv else None
+
+ if config.gate_activation == "identity":
+ self.gate_activation = nn.Identity()
+ else:
+ self.gate_activation = ACT2FN[config.gate_activation]
+
+ self.dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, hidden_states):
+ x_r, x_g = hidden_states.chunk(2, dim=-1)
+ x_g = self.norm(x_g)
+ x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)
+ if self.linear is not None:
+ x_g = self.linear(x_g)
+ x_g = self.gate_activation(x_g)
+ return self.dropout(x_r * x_g)
+
+
+class PhoneticXeusConvolutionalGatingMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.channel_proj1 = nn.Sequential(
+ nn.Linear(config.hidden_size, config.cgmlp_linear_units),
+ nn.GELU(),
+ )
+ self.csgu = PhoneticXeusConvolutionalSpatialGatingUnit(config)
+ self.channel_proj2 = nn.Linear(config.cgmlp_linear_units // 2, config.hidden_size)
+
+ def forward(self, hidden_states):
+ hidden_states = self.channel_proj1(hidden_states)
+ hidden_states = self.csgu(hidden_states)
+ hidden_states = self.channel_proj2(hidden_states)
+ return hidden_states
+
+
+class PhoneticXeusEBranchformerEncoderLayer(GradientCheckpointingLayer):
+ """E-Branchformer layer: parallel self-attention and cgMLP branches merged via depthwise conv."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.self_attn = PhoneticXeusSelfAttention(config)
+ self.cgmlp = PhoneticXeusConvolutionalGatingMLP(config)
+
+ self.norm_mha = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.norm_mlp = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.depthwise_conv_fusion = nn.Conv1d(
+ 2 * config.hidden_size,
+ 2 * config.hidden_size,
+ kernel_size=config.merge_conv_kernel,
+ stride=1,
+ padding=(config.merge_conv_kernel - 1) // 2,
+ groups=2 * config.hidden_size,
+ bias=True,
+ )
+ self.merge_proj = nn.Linear(2 * config.hidden_size, config.hidden_size)
+
+ self.ff_scale = 1.0
+ if config.use_ffn and config.macaron_ffn:
+ self.feed_forward_macaron = PhoneticXeusFeedForward(config)
+ self.norm_ff_macaron = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.ff_scale = 0.5
+ else:
+ self.feed_forward_macaron = None
+
+ if config.use_ffn:
+ self.feed_forward = PhoneticXeusFeedForward(config)
+ self.norm_ff = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ else:
+ self.feed_forward = None
+
+ self.norm_final = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ if self.feed_forward_macaron is not None:
+ residual = hidden_states
+ hidden_states = residual + self.ff_scale * self.dropout(
+ self.feed_forward_macaron(self.norm_ff_macaron(hidden_states))
+ )
+
+ x1 = self.norm_mha(hidden_states)
+ x1, attn_weights = self.self_attn(x1, attention_mask=attention_mask, output_attentions=output_attentions)
+ x1 = self.dropout(x1)
+
+ x2 = self.norm_mlp(hidden_states)
+ x2 = self.dropout(self.cgmlp(x2))
+
+ x_concat = torch.cat([x1, x2], dim=-1)
+ x_tmp = self.depthwise_conv_fusion(x_concat.transpose(1, 2)).transpose(1, 2)
+ hidden_states = hidden_states + self.dropout(self.merge_proj(x_concat + x_tmp))
+
+ if self.feed_forward is not None:
+ residual = hidden_states
+ hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward(self.norm_ff(hidden_states)))
+
+ hidden_states = self.norm_final(hidden_states)
+ return hidden_states, attn_weights
+
+
+class PhoneticXeusEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pos_conv_embed = PhoneticXeusPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList(
+ [PhoneticXeusEBranchformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+ )
+
+ self.interctc_layer_idx = list(config.interctc_layer_idx) if config.interctc_layer_idx else []
+ self.interctc_use_conditioning = config.interctc_use_conditioning
+ if self.interctc_use_conditioning and self.interctc_layer_idx:
+ self.conditioning_layer = nn.Linear(config.vocab_size, config.hidden_size)
+ else:
+ self.conditioning_layer = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ctc_head: nn.Linear | None = None,
+ ) -> tuple | BaseModelOutput:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
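+ # zero out padded frames and turn the padding mask into an additive attention bias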
+ if attention_mask is not None:
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0.0
+
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
+
+ hidden_states = self.pos_conv_embed(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ dropout_probability = torch.rand([])
+ skip_the_layer = self.training and dropout_probability < self.config.layerdrop
+
+ if not skip_the_layer or synced_gpus:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
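+ # interCTC self-conditioning: add a projection of the intermediate CTC posteriors back into the encoder stream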
+ if (
+ self.interctc_use_conditioning
+ and self.conditioning_layer is not None
+ and ctc_head is not None
+ and (i + 1) in self.interctc_layer_idx
+ ):
+ ctc_out = torch.softmax(ctc_head(hidden_states), dim=-1)
+ hidden_states = hidden_states + self.conditioning_layer(ctc_out)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@auto_docstring
+class PhoneticXeusPreTrainedModel(PreTrainedModel):
+ config: PhoneticXeusConfig
+ base_model_prefix = "phonetic_xeus"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ if isinstance(module, PhoneticXeusPositionalConvEmbedding):
+ std = math.sqrt(4.0 / (module.conv.kernel_size[0] * module.conv.in_channels))
+ init.normal_(module.conv.weight, mean=0, std=std)
+ init.constant_(module.conv.bias, 0)
+ elif isinstance(module, PhoneticXeusFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ init.uniform_(module.projection.weight, a=-k, b=k)
+ init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ init.zeros_(module.bias)
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ init.zeros_(module.bias)
+ init.ones_(module.weight)
+ elif isinstance(module, nn.Conv1d):
+ init.kaiming_normal_(module.weight)
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
+ def _conv_out_length(input_length, kernel_size, stride):
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.cumsum(dim=-1)[:, -1]).to(torch.long)
+ batch_size = attention_mask.shape[0]
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ attention_mask[(torch.arange(batch_size, device=attention_mask.device), output_lengths - 1)] = 1
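+ # these two operations make sure that all positions up to and including the last valid output frame are attended to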
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+
+class PhoneticXeusModel(PhoneticXeusPreTrainedModel):
+ def __init__(self, config: PhoneticXeusConfig):
+ super().__init__(config)
+ self.feature_extractor = PhoneticXeusFeatureEncoder(config)
+ self.feature_projection = PhoneticXeusFeatureProjection(config)
+ self.encoder = PhoneticXeusEncoder(config)
+ self.post_init()
+
+ def freeze_feature_encoder(self):
+ self.feature_extractor._freeze_parameters()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: torch.Tensor | None,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ **kwargs,
+ ) -> tuple | BaseModelOutput:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
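+ # optionally layer-normalize the raw waveform before the CNN feature encoder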
+ if self.config.normalize_audio:
+ input_values = torch.nn.functional.layer_norm(input_values, input_values.shape)
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+ hidden_states = self.feature_projection(extract_features)
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ ctc_head=kwargs.get("ctc_head"),
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class PhoneticXeusForCTC(PhoneticXeusPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.phonetic_xeus = PhoneticXeusModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+ "instantiate the model as follows: `PhoneticXeusForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+ "or define `vocab_size` of your model's configuration."
+ )
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+ self.post_init()
+
+ def freeze_feature_encoder(self):
+ self.phonetic_xeus.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ for param in self.phonetic_xeus.parameters():
+ param.requires_grad = False
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: torch.Tensor | None,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: torch.Tensor | None = None,
+ **kwargs,
+ ) -> tuple | CausalLMOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or
+ equal to the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked); the loss is only computed for labels in
+ `[0, ..., config.vocab_size - 1]`.
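+
+ Example (a minimal greedy-decoding sketch; the checkpoint name and the 16 kHz `waveform` array are
+ assumptions for illustration):
+
+ ```python
+ import torch
+ from transformers import PhoneticXeusForCTC, PhoneticXeusProcessor
+
+ # assumed checkpoint; adjust to whichever repo hosts the converted weights
+ processor = PhoneticXeusProcessor.from_pretrained("changelinglab/PhoneticXeus-hf")
+ model = PhoneticXeusForCTC.from_pretrained("changelinglab/PhoneticXeus-hf").eval()
+
+ # `waveform` is a hypothetical 1-D float array sampled at 16 kHz
+ inputs = processor(audio=waveform, sampling_rate=16000, return_tensors="pt")
+ logits = model(input_values=inputs.input_values).logits
+ predicted_ids = torch.argmax(logits, dim=-1)
+ ipa_phones = processor.tokenizer.batch_decode(predicted_ids)
+ ```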
+ """
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if labels is not None and labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ outputs = self.phonetic_xeus(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ ctc_head=self.lm_head,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ attention_mask = (
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
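+ # padded label positions are assumed to be filled with -100 so they do not count as targets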
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+__all__ = [
+ "PhoneticXeusForCTC",
+ "PhoneticXeusModel",
+ "PhoneticXeusPreTrainedModel",
+]
diff --git a/src/transformers/models/phoneticxeus/processing_phoneticxeus.py b/src/transformers/models/phoneticxeus/processing_phoneticxeus.py
new file mode 100644
index 000000000000..df965531bcd3
--- /dev/null
+++ b/src/transformers/models/phoneticxeus/processing_phoneticxeus.py
@@ -0,0 +1,92 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Processor for PhoneticXeus."""
+
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
+from ...utils import auto_docstring
+
+
+class PhoneticXeusProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {}
+
+
+@auto_docstring
+class PhoneticXeusProcessor(ProcessorMixin):
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+
+ @auto_docstring
+ def __call__(
+ self,
+ audio: AudioInput | None = None,
+ text: str | list[str] | TextInput | PreTokenizedInput | None = None,
+ **kwargs: Unpack[PhoneticXeusProcessorKwargs],
+ ):
+ r"""
+ Returns:
+ When `audio` is provided, the feature extractor output (a [`BatchFeature`] containing `input_values`);
+ if `text` is also provided, the tokenized phone ids are added under `labels`. When only `text` is
+ passed, the tokenizer output is returned.
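+
+ Example (a minimal sketch; the checkpoint name, `waveform`, and `ipa_text` are assumptions for
+ illustration):
+
+ ```python
+ from transformers import PhoneticXeusProcessor
+
+ processor = PhoneticXeusProcessor.from_pretrained("changelinglab/PhoneticXeus-hf")
+
+ # audio only -> feature-extractor output with `input_values`
+ inputs = processor(audio=waveform, sampling_rate=16000, return_tensors="pt")
+
+ # audio + IPA transcription -> same output plus tokenized `labels` for CTC training
+ batch = processor(audio=waveform, text=ipa_text, sampling_rate=16000, return_tensors="pt")
+ ```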
+ """
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ output_kwargs = self._merge_kwargs(
+ PhoneticXeusProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
+ if text is not None:
+ encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ if text is None:
+ return inputs
+ elif audio is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
+
+ def pad(self, *args, **kwargs):
+ """
+ Pads extracted audio features and/or tokenized text.
+ Forwards to [`Wav2Vec2FeatureExtractor.pad`] and/or [`PreTrainedTokenizer.pad`].
+ """
+ input_features = kwargs.pop("input_features", None)
+ labels = kwargs.pop("labels", None)
+ if len(args) > 0:
+ input_features = args[0]
+ args = args[1:]
+
+ if input_features is not None:
+ input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
+ if labels is not None:
+ labels = self.tokenizer.pad(labels, **kwargs)
+
+ if labels is None:
+ return input_features
+ elif input_features is None:
+ return labels
+ else:
+ input_features["labels"] = labels["input_ids"]
+ return input_features
+
+ @property
+ def model_input_names(self):
+ return self.feature_extractor.model_input_names + ["labels"]
+
+
+__all__ = ["PhoneticXeusProcessor"]
diff --git a/src/transformers/models/phoneticxeus/tokenization_phoneticxeus.py b/src/transformers/models/phoneticxeus/tokenization_phoneticxeus.py
new file mode 100644
index 000000000000..dd40f76db40e
--- /dev/null
+++ b/src/transformers/models/phoneticxeus/tokenization_phoneticxeus.py
@@ -0,0 +1,63 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenizer for PhoneticXeus (IPA CTC tokenizer)."""
+
+from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
+
+
+class PhoneticXeusTokenizer(Wav2Vec2CTCTokenizer):
+ """CTC tokenizer for IPA phone sequences.
+
+ Thin wrapper around [`Wav2Vec2CTCTokenizer`] with defaults matching the PhoneticXeus IPA vocabulary
+ (428 tokens). Handles CTC blank collapsing and ID-to-IPA conversion.
+
+ Args:
+ vocab_file (`str`):
+ Path to `vocab.json` mapping IPA phone strings to integer IDs.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ Beginning of sentence token.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ End of sentence token.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ Unknown token.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ Padding / CTC blank token.
+ word_delimiter_token (`str`, *optional*, defaults to `" "`):
+ Token used as word delimiter. Set to `" "` since IPA transcriptions use spaces between words.
+ **kwargs:
+ Additional keyword arguments passed to [`Wav2Vec2CTCTokenizer`].
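+
+ Example (a minimal decoding sketch; the checkpoint name and `predicted_ids` are assumptions for
+ illustration):
+
+ ```python
+ from transformers import PhoneticXeusTokenizer
+
+ tokenizer = PhoneticXeusTokenizer.from_pretrained("changelinglab/PhoneticXeus-hf")
+
+ # `predicted_ids` would typically be `logits.argmax(dim=-1)` from `PhoneticXeusForCTC`;
+ # decoding collapses repeated ids and removes the pad (CTC blank) token
+ ipa_strings = tokenizer.batch_decode(predicted_ids)
+ ```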
+ """
+
+ def __init__(
+ self,
+ vocab_file,
+ bos_token="",
+ eos_token="",
+ unk_token="",
+ pad_token="",
+ word_delimiter_token=" ",
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ word_delimiter_token=word_delimiter_token,
+ **kwargs,
+ )
+
+
+__all__ = ["PhoneticXeusTokenizer"]
diff --git a/tests/models/phoneticxeus/__init__.py b/tests/models/phoneticxeus/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/phoneticxeus/test_modeling_phoneticxeus.py b/tests/models/phoneticxeus/test_modeling_phoneticxeus.py
new file mode 100644
index 000000000000..e8971b85d142
--- /dev/null
+++ b/tests/models/phoneticxeus/test_modeling_phoneticxeus.py
@@ -0,0 +1,427 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch PhoneticXeus model."""
+
+import math
+import tempfile
+import unittest
+
+from transformers import PhoneticXeusConfig, is_torch_available
+from transformers.testing_utils import (
+ require_torch,
+ require_torch_accelerator,
+ require_torch_fp16,
+ slow,
+ torch_device,
+)
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ PhoneticXeusForCTC,
+ PhoneticXeusModel,
+ Wav2Vec2FeatureExtractor,
+ )
+
+
+class PhoneticXeusModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=1024,
+ is_training=False,
+ hidden_size=16,
+ feat_extract_norm="layer",
+ feat_extract_dropout=0.0,
+ feat_extract_activation="gelu",
+ conv_dim=(32, 32, 32),
+ conv_stride=(4, 4, 4),
+ conv_kernel=(8, 8, 8),
+ conv_bias=True,
+ num_conv_pos_embeddings=16,
+ num_conv_pos_embedding_groups=2,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ hidden_dropout=0.1,
+ attention_dropout=0.1,
+ intermediate_size=20,
+ layer_norm_eps=1e-5,
+ hidden_act="swish",
+ initializer_range=0.02,
+ vocab_size=32,
+ cgmlp_linear_units=20,
+ cgmlp_conv_kernel=3,
+ merge_conv_kernel=3,
+ use_ffn=True,
+ macaron_ffn=True,
+ interctc_layer_idx=(1,),
+ interctc_use_conditioning=True,
+ normalize_audio=True,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.hidden_size = hidden_size
+ self.feat_extract_norm = feat_extract_norm
+ self.feat_extract_dropout = feat_extract_dropout
+ self.feat_extract_activation = feat_extract_activation
+ self.conv_dim = conv_dim
+ self.conv_stride = conv_stride
+ self.conv_kernel = conv_kernel
+ self.conv_bias = conv_bias
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.intermediate_size = intermediate_size
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.vocab_size = vocab_size
+ self.cgmlp_linear_units = cgmlp_linear_units
+ self.cgmlp_conv_kernel = cgmlp_conv_kernel
+ self.merge_conv_kernel = merge_conv_kernel
+ self.use_ffn = use_ffn
+ self.macaron_ffn = macaron_ffn
+ self.interctc_layer_idx = interctc_layer_idx
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.normalize_audio = normalize_audio
+ self.scope = scope
+
+ output_seq_length = self.seq_length
+ for kernel, stride in zip(self.conv_kernel, self.conv_stride):
+ output_seq_length = (output_seq_length - (kernel - 1)) / stride
+ self.output_seq_length = int(math.ceil(output_seq_length))
+ self.encoder_seq_length = self.output_seq_length
+
+ def prepare_config_and_inputs(self):
+ input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ config = self.get_config()
+
+ return config, input_values, attention_mask
+
+ def get_config(self):
+ return PhoneticXeusConfig(
+ hidden_size=self.hidden_size,
+ feat_extract_norm=self.feat_extract_norm,
+ feat_extract_activation=self.feat_extract_activation,
+ conv_dim=self.conv_dim,
+ conv_stride=self.conv_stride,
+ conv_kernel=self.conv_kernel,
+ conv_bias=self.conv_bias,
+ num_conv_pos_embeddings=self.num_conv_pos_embeddings,
+ num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ hidden_dropout=self.hidden_dropout,
+ attention_dropout=self.attention_dropout,
+ intermediate_size=self.intermediate_size,
+ layer_norm_eps=self.layer_norm_eps,
+ hidden_act=self.hidden_act,
+ initializer_range=self.initializer_range,
+ vocab_size=self.vocab_size,
+ cgmlp_linear_units=self.cgmlp_linear_units,
+ cgmlp_conv_kernel=self.cgmlp_conv_kernel,
+ merge_conv_kernel=self.merge_conv_kernel,
+ use_ffn=self.use_ffn,
+ macaron_ffn=self.macaron_ffn,
+ interctc_layer_idx=self.interctc_layer_idx,
+ interctc_use_conditioning=self.interctc_use_conditioning,
+ normalize_audio=self.normalize_audio,
+ )
+
+ def create_and_check_model(self, config, input_values, attention_mask):
+ model = PhoneticXeusModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_values, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_model_float16(self, config, input_values, attention_mask):
+ model = PhoneticXeusModel(config=config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = PhoneticXeusModel.from_pretrained(tmpdirname, dtype=torch.float16)
+
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ result = model(input_values.type(dtype=torch.float16), attention_mask=attention_mask)
+
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
+ )
+
+ def check_ctc_loss(self, config, input_values, *args):
+ model = PhoneticXeusForCTC(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_values = input_values[:3]
+ attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ model.config.ctc_loss_reduction = "sum"
+ sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+
+ model.config.ctc_loss_reduction = "mean"
+ mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(sum_loss, float))
+ self.parent.assertTrue(isinstance(mean_loss, float))
+
+ def check_ctc_training(self, config, input_values, *args):
+ config.ctc_zero_infinity = True
+ model = PhoneticXeusForCTC(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # freeze feature encoder
+ model.freeze_feature_encoder()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+
+ if max_length_labels[i] < labels.shape[-1]:
+ # it's important that we make sure that target lengths are at least
+ # one shorter than logit lengths to prevent -inf
+ labels[i, max_length_labels[i] - 1 :] = -100
+
+ loss = model(input_values, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_labels_out_of_vocab(self, config, input_values, *args):
+ model = PhoneticXeusForCTC(config)
+ model.to(torch_device)
+ model.train()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
+
+ with self.parent.assertRaises(ValueError):
+ model(input_values, labels=labels)
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_values, attention_mask = self.prepare_config_and_inputs()
+ inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class PhoneticXeusModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (
+ PhoneticXeusForCTC,
+ PhoneticXeusModel,
+ )
+ if is_torch_available()
+ else ()
+ )
+ pipeline_model_mapping = (
+ {
+ "automatic-speech-recognition": PhoneticXeusForCTC,
+ "feature-extraction": PhoneticXeusModel,
+ }
+ if is_torch_available()
+ else {}
+ )
+
+ test_resize_embeddings = False
+
+ def test_batching_equivalence(self, atol=1e-3, rtol=1e-3):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)
+
+ def setUp(self):
+ self.model_tester = PhoneticXeusModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=PhoneticXeusConfig, hidden_size=32)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @require_torch_accelerator
+ @require_torch_fp16
+ def test_model_float16(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_float16(*config_and_inputs)
+
+ def test_ctc_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_loss(*config_and_inputs)
+
+ def test_ctc_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_training(*config_and_inputs)
+
+ def test_labels_out_of_vocab(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
+
+ @unittest.skip(reason="PhoneticXeus has no inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="PhoneticXeus has input_values instead of input_ids")
+ def test_forward_signature(self):
+ pass
+
+ @unittest.skip(reason="PhoneticXeus has no token embeddings")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip(reason="Feed forward chunking is not implemented")
+ def test_feed_forward_chunking(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ # set layer drop to 0
+ model.config.layerdrop = 0.0
+
+ input_values = inputs_dict["input_values"]
+
+ input_lengths = torch.tensor(
+ [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
+ )
+ output_lengths = model._get_feat_extract_output_lengths(input_lengths)
+
+ labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
+ inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
+ inputs_dict["labels"] = labels
+
+ outputs = model(**inputs_dict)
+
+ output = outputs[0]
+
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ # overwrite from test_modeling_common
+ def _mock_init_weights(self, module):
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.fill_(3)
+ if hasattr(module, "weight_g") and module.weight_g is not None:
+ module.weight_g.data.fill_(3)
+ if hasattr(module, "weight_v") and module.weight_v is not None:
+ module.weight_v.data.fill_(3)
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.fill_(3)
+
+ @slow
+ def test_model_from_pretrained(self):
+ model = PhoneticXeusModel.from_pretrained("changelinglab/PhoneticXeus-hf")
+ self.assertIsNotNone(model)
+
+
+@require_torch
+@slow
+class PhoneticXeusModelIntegrationTest(unittest.TestCase):
+ def _load_datasamples(self, num_samples):
+ from datasets import load_dataset
+
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ speech_samples = ds.sort("id").filter(lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)])
+ speech_samples = speech_samples[:num_samples]["audio"]
+
+ return [x["array"] for x in speech_samples]
+
+ def test_inference_ctc_batched(self):
+ model = PhoneticXeusForCTC.from_pretrained("changelinglab/PhoneticXeus-hf")
+ model.to(torch_device)
+ model.eval()
+
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("changelinglab/PhoneticXeus-hf")
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = feature_extractor(input_speech, return_tensors="pt", padding=True, sampling_rate=16000)
+
+ input_values = inputs.input_values.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_values).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+
+ # Verify output shape: batch=2, time steps, vocab=428
+ self.assertEqual(logits.shape[0], 2)
+ self.assertEqual(logits.shape[-1], 428)
+
+ # Verify predictions are valid token ids
+ self.assertTrue((predicted_ids >= 0).all())
+ self.assertTrue((predicted_ids < 428).all())