diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md
index 3003e5c49edd..848517280f4b 100644
--- a/docs/source/en/model_doc/auto.md
+++ b/docs/source/en/model_doc/auto.md
@@ -225,6 +225,10 @@ The following auto classes are available for the following audio tasks.
[[autodoc]] AutoModelForCTC
+### AutoModelForTDT
+
+[[autodoc]] AutoModelForTDT
+
### AutoModelForSpeechSeq2Seq
[[autodoc]] AutoModelForSpeechSeq2Seq
diff --git a/docs/source/en/model_doc/parakeet.md b/docs/source/en/model_doc/parakeet.md
index b075e6d5ccf7..cca7d395f2d2 100644
--- a/docs/source/en/model_doc/parakeet.md
+++ b/docs/source/en/model_doc/parakeet.md
@@ -34,15 +34,20 @@ Parakeet models, [introduced by NVIDIA NeMo](https://developer.nvidia.com/blog/p
- 1D convolution projection from encoder hidden size to vocabulary size (for optimal NeMo compatibility).
- CTC loss computation for training.
- Greedy CTC decoding for inference.
+- [**ParakeetForTDT**](#parakeetfortdt): a Fast Conformer Encoder + a TDT (Token-and-Duration Transducer) decoder
+ - **TDT Decoder**: Jointly predicts tokens and their durations, enabling efficient decoding:
+ - LSTM prediction network maintains language context across token predictions.
+ - Joint network combines encoder and decoder outputs.
+    - Duration head predicts how many frames to skip, enabling fast inference (see the sketch below).
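+
+For intuition, here is a deliberately simplified, self-contained sketch of duration-based frame skipping (random stand-in scores replace the real joint network, and `max_symbols_per_step` and the LSTM state are ignored):
+
+```py
+import torch
+
+torch.manual_seed(0)
+num_frames, vocab_size, blank_id = 10, 5, 5
+durations = [0, 1, 2, 3, 4]
+token_scores = torch.randn(num_frames, vocab_size + 1)     # stand-in for joint network token logits
+duration_scores = torch.randn(num_frames, len(durations))  # stand-in for duration head logits
+
+t, hypothesis = 0, []
+while t < num_frames:
+    token = token_scores[t].argmax().item()
+    skip = durations[duration_scores[t].argmax().item()]
+    if token != blank_id:
+        hypothesis.append(token)
+    t += max(skip, 1)  # the predicted duration tells the decoder how many frames to skip
+
+print(hypothesis)
+```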
The original implementation can be found in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).
Model checkpoints are to be found under [the NVIDIA organization](https://huggingface.co/nvidia/models?search=parakeet).
-This model was contributed by [Nithin Rao Koluguri](https://huggingface.co/nithinraok), [Eustache Le Bihan](https://huggingface.co/eustlb) and [Eric Bezzam](https://huggingface.co/bezzam).
+This model was contributed by [Nithin Rao Koluguri](https://huggingface.co/nithinraok), [Eustache Le Bihan](https://huggingface.co/eustlb), [Eric Bezzam](https://huggingface.co/bezzam), [Maksym Lypivskyi](https://huggingface.co/MaksL), and [Hainan Xu](https://huggingface.co/hainanx).
## Usage
-### Basic usage
+### `ParakeetForCTC` usage
@@ -53,6 +58,7 @@ from transformers import pipeline
pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-ctc-1.1b")
out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
print(out)
+# {'text': 'yesterday it was thirty five degrees in barcelona but today the temperature will go down to minus twenty degrees'}
```
@@ -61,12 +67,10 @@ print(out)
```py
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
-import torch
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
-model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)
+model_id = "nvidia/parakeet-ctc-1.1b"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForCTC.from_pretrained(model_id, dtype="auto", device_map="auto")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
@@ -75,7 +79,80 @@ speech_samples = [el['array'] for el in ds["audio"][:5]]
inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
-print(processor.batch_decode(outputs))
+print(processor.decode(outputs))
+```
+
+### `ParakeetForTDT` usage
+
+Unlike the CTC checkpoints, Parakeet TDT transcripts include casing and punctuation, and the model can also return token-level timestamps.
+
+```py
+from transformers import pipeline
+
+pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-tdt-0.6b-v3")
+out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
+print(out)
+# {'text': 'Yesterday it was 35 degrees in Barcelona, but today the temperature will go down to minus 20 degrees.'}
+```
+
+```py
+from transformers import AutoModelForTDT, AutoProcessor
+from datasets import load_dataset, Audio
+
+model_id = "nvidia/parakeet-tdt-0.6b-v3"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto")
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+speech_samples = [el['array'] for el in ds["audio"][:5]]
+
+inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
+inputs.to(model.device, dtype=model.dtype)
+output = model.generate(**inputs, return_dict_in_generate=True)
+print(processor.decode(output.sequences, skip_special_tokens=True))
+```
+
+```py
+from datasets import Audio, load_dataset
+from transformers import AutoModelForTDT, AutoProcessor
+
+model_id = "nvidia/parakeet-tdt-0.6b-v3"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto")
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+speech_samples = [el['array'] for el in ds["audio"][:1]]
+
+inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
+inputs.to(model.device, dtype=model.dtype)
+output = model.generate(**inputs, return_dict_in_generate=True)
+decoded_output, decoded_timestamps = processor.decode(
+ output.sequences,
+ durations=output.durations,
+ skip_special_tokens=True,
+)
+print("Transcription:", decoded_output)
+print("\nTimestamped tokens:", decoded_timestamps)
+
+"""
+Transcription: ['mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.']
+
+Timestamped tokens: [[{'token': 'm', 'start': 0.24, 'end': 0.48}, {'token': 'ister', 'start': 0.48, 'end': 0.64}, {'token': 'Qu', 'start': 0.64, 'end': 0.88}, {'token': 'il', 'start': 0.88, 'end': 1.12}, {'token': 'ter', 'start': 1.12, 'end': 1.36}, {'token': 'is', 'start': 1.36, 'end': 1.44}, {'token': 'the', 'start': 1.44, 'end': 1.6}, {'token': 'ap', 'start': 1.6, 'end': 1.76}, {'token': 'ost', 'start': 1.76, 'end': 1.92}, {'token': 'le', 'start': 2.0, 'end': 2.16}, {'token': 'of', 'start': 2.16, 'end': 2.24}, {'token': 'the', 'start': 2.24, 'end': 2.4}, {'token': 'mid', 'start': 2.4, 'end': 2.48}, {'token': 'd', 'start': 2.48, 'end': 2.56}, {'token': 'le', 'start': 2.56, 'end': 2.64}, {'token': 'clas', 'start': 2.72, 'end': 2.88}, {'token': 's', 'start': 2.88, 'end': 3.04}, {'token': 'es', 'start': 3.04, 'end': 3.12}, {'token': ',', 'start': 3.12, 'end': 3.12}, {'token': 'and', 'start': 3.2800000000000002, 'end': 3.44}, {'token': 'we', 'start': 3.44, 'end': 3.6}, {'token': 'are', 'start': 3.6, 'end': 3.7600000000000002}, {'token': 'gl', 'start': 3.7600000000000002, 'end': 3.92}, {'token': 'ad', 'start': 3.92, 'end': 4.08}, {'token': 'to', 'start': 4.08, 'end': 4.24}, {'token': 'wel', 'start': 4.24, 'end': 4.4}, {'token': 'c', 'start': 4.4, 'end': 4.48}, {'token': 'ome', 'start': 4.48, 'end': 4.72}, {'token': 'his', 'start': 4.72, 'end': 4.96}, {'token': 'gos', 'start': 4.96, 'end': 5.12}, {'token': 'pel', 'start': 5.36, 'end': 5.6000000000000005}, {'token': '.', 'start': 5.6000000000000005, 'end': 5.6000000000000005}]]
+"""
```
@@ -136,7 +213,7 @@ print("First generation - compiling...")
# Generate with the compiled model
with TimerContext("First generation"):
outputs = model.generate(**inputs)
-print(processor.batch_decode(outputs))
+print(processor.decode(outputs))
inputs = processor(speech_samples[1], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
@@ -144,7 +221,7 @@ print("\n" + "="*50)
print("Second generation - recording CUDA graphs...")
with TimerContext("Second generation"):
outputs = model.generate(**inputs)
-print(processor.batch_decode(outputs))
+print(processor.decode(outputs))
inputs = processor(speech_samples[2], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
@@ -152,7 +229,7 @@ print("\n" + "="*50)
print("Third generation - fast !!!")
with TimerContext("Third generation"):
outputs = model.generate(**inputs)
-print(processor.batch_decode(outputs))
+print(processor.decode(outputs))
inputs = processor(speech_samples[3], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
@@ -160,34 +237,66 @@ print("\n" + "="*50)
print("Fourth generation - still fast !!!")
with TimerContext("Fourth generation"):
outputs = model.generate(**inputs)
-print(processor.batch_decode(outputs))
+print(processor.decode(outputs))
```
-### Training
+### CTC Training
```python
+import torch
+from datasets import Audio, load_dataset
from transformers import AutoModelForCTC, AutoProcessor
-from datasets import load_dataset, Audio
+
+model_id = "nvidia/parakeet-ctc-1.1b"
+NUM_SAMPLES = 5
+
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForCTC.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
+model.train()
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]]
+text_samples = ds["text"][:NUM_SAMPLES]
+
+# passing `text` to the processor will prepare the inputs' `labels` key
+inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate)
+inputs.to(device=model.device, dtype=model.dtype)
+
+outputs = model(**inputs)
+print("Loss:", outputs.loss.item())
+outputs.loss.backward()
+```
+
+### TDT Training
+
+```py
+from datasets import Audio, load_dataset
import torch
+from transformers import AutoModelForTDT, AutoProcessor
-device = "cuda" if torch.cuda.is_available() else "cpu"
+model_id = "nvidia/parakeet-tdt-0.6b-v3"
+NUM_SAMPLES = 4
-processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
-model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
+model.train()
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
-speech_samples = [el['array'] for el in ds["audio"][:5]]
-text_samples = [el for el in ds["text"][:5]]
+speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]]
+text_samples = ds["text"][:NUM_SAMPLES]
# passing `text` to the processor will prepare inputs' `labels` key
inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate)
-inputs.to(device, dtype=model.dtype)
+inputs.to(device=model.device, dtype=model.dtype)
outputs = model(**inputs)
+print("Loss:", outputs.loss.item())
outputs.loss.backward()
```
+
## ParakeetTokenizer
[[autodoc]] ParakeetTokenizer
@@ -201,7 +310,6 @@ outputs.loss.backward()
[[autodoc]] ParakeetProcessor
- __call__
- - batch_decode
- decode
## ParakeetEncoderConfig
@@ -212,6 +320,10 @@ outputs.loss.backward()
[[autodoc]] ParakeetCTCConfig
+## ParakeetTDTConfig
+
+[[autodoc]] ParakeetTDTConfig
+
## ParakeetEncoder
[[autodoc]] ParakeetEncoder
@@ -219,3 +331,7 @@ outputs.loss.backward()
## ParakeetForCTC
[[autodoc]] ParakeetForCTC
+
+## ParakeetForTDT
+
+[[autodoc]] ParakeetForTDT
diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index 62fac901ae84..3720efa281c3 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -735,7 +735,8 @@ def tokenizer(self, proto):
)
elif model_type == 2:
- _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
+ result = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(None)
+ merges = result["merges"]
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(
@@ -1842,7 +1843,8 @@ def __init__(self, vocab_file=None, *args):
def tokenizer(self, proto):
vocab_scores = self.vocab(proto)
- _, merges = self.SpmExtractor(self.vocab_file).extract(vocab_scores)
+ result = self.SpmExtractor(self.vocab_file).extract(None)
+ merges = result["merges"]
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 7439722c60b9..af5e325992ec 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1459,6 +1459,13 @@ def compute_transition_scores(
def _validate_generation_mode(
self: "GenerativePreTrainedModel", generation_mode, generation_config, generation_mode_kwargs
):
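+        # Models can restrict which generation modes they support by defining a `_supported_generation_modes` class attribute.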
+ supported_modes = getattr(self, "_supported_generation_modes", None)
+ if supported_modes is not None and generation_mode not in supported_modes:
+ raise ValueError(
+ f"{self.__class__.__name__} only supports {supported_modes}, but got "
+ f"generation mode '{generation_mode}'."
+ )
+
if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs:
raise ValueError(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index 88aff578fdc6..c0db0822b962 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -286,6 +286,7 @@ def register_kernel_mapping_transformers(*args, **kwargs):
"falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1},
"finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1},
"deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1},
+ "tdt-loss": {"repo_id": "eustlb/tdt-loss", "revision": "v1"},
}
_KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {}
@@ -372,10 +373,12 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)
- kernel = get_kernel(repo_id, revision=revision, version=version)
+ # Since we only read from `_HUB_KERNEL_MAPPING`, we can allow all kernels
+ kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True)
mapping[kernel_name] = kernel
- except FileNotFoundError:
+ except FileNotFoundError as e:
mapping[kernel_name] = None
+ logger.warning_once(f"Failed to load kernel {kernel_name}: {e}")
except AssertionError:
# Happens when torch is built without an accelerator backend; fall back to slow path.
mapping[kernel_name] = None
diff --git a/src/transformers/loss/loss_tdt.py b/src/transformers/loss/loss_tdt.py
new file mode 100644
index 000000000000..6a128f18583c
--- /dev/null
+++ b/src/transformers/loss/loss_tdt.py
@@ -0,0 +1,217 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def _load_tdt_kernel():
+ """Try to load the TDT loss CUDA kernel from the Hub. Returns None on failure."""
+ try:
+ from ..integrations.hub_kernels import lazy_load_kernel
+
+ kernel = lazy_load_kernel("tdt-loss")
+ if kernel is None or not hasattr(kernel, "tdt_loss"):
+            logger.warning_once("TDT loss kernel unavailable or missing the `tdt_loss` entry point; falling back to the pure PyTorch implementation.")
+ return None
+ return kernel
+ except (ImportError, ModuleNotFoundError):
+ return None
+ except Exception as e:
+ logger.warning_once(f"Failed to load TDT CUDA kernel: {e}. Falling back to pure PyTorch implementation.")
+ return None
+
+
+def tdt_loss(
+ token_logits: torch.Tensor,
+ duration_logits: torch.Tensor,
+ targets: torch.Tensor,
+ logit_lengths: torch.Tensor,
+ target_lengths: torch.Tensor,
+ blank_token_id: int,
+ durations: list[int],
+ sigma: float = 0.0,
+ reduction: str = "mean",
+) -> torch.Tensor:
+ """
+ Compute TDT (Token-and-Duration Transducer) loss (https://arxiv.org/abs/2304.06795).
+
+ Ported from NeMo's `TDTLossPytorch` with anti-diagonal processing. Unlike standard RNNT loss, this loss trains both
+ the token prediction head and the duration prediction head. It uses vectorized anti-diagonal processing for
+ efficiency: all (t, u) pairs on each anti-diagonal t+u=n are computed in parallel as batched tensor operations.
+
+    When the ``tdt-loss`` CUDA kernel is available from the Hub, it is used automatically for GPU tensors;
+    otherwise the loss falls back to the pure PyTorch implementation.
+
+ Args:
+ token_logits: Token logits of shape `(batch, T, U+1, vocab_size+1)`.
+ duration_logits: Duration logits of shape `(batch, T, U+1, num_durations)`.
+ targets: Target labels of shape `(batch, U)`.
+ logit_lengths: Encoder output lengths of shape `(batch,)`.
+ target_lengths: Target lengths of shape `(batch,)`.
+ blank_token_id: Blank token id.
+ durations: List of duration values (e.g., `[0, 1, 2, 3, 4]`).
+ sigma: Logit undernormalization constant (see TDT paper). Defaults to `0.0`.
+ reduction: Loss reduction method. One of `"mean"`, `"sum"`, or `"none"`. Defaults to `"mean"`.
+
+ Returns:
+ Scalar loss tensor (or per-example losses if `reduction="none"`).
+
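+    Example (a minimal illustrative sketch with random inputs, exercising the pure PyTorch path on CPU):
+
+    ```python
+    >>> import torch
+    >>> from transformers.loss.loss_tdt import tdt_loss
+
+    >>> batch_size, num_frames, num_targets, vocab_size = 2, 6, 3, 4
+    >>> durations = [0, 1, 2]
+    >>> token_logits = torch.randn(batch_size, num_frames, num_targets + 1, vocab_size + 1)
+    >>> duration_logits = torch.randn(batch_size, num_frames, num_targets + 1, len(durations))
+    >>> targets = torch.randint(0, vocab_size, (batch_size, num_targets))
+    >>> logit_lengths = torch.tensor([num_frames, num_frames - 1])
+    >>> target_lengths = torch.tensor([num_targets, num_targets - 1])
+    >>> loss = tdt_loss(
+    ...     token_logits, duration_logits, targets, logit_lengths, target_lengths,
+    ...     blank_token_id=vocab_size, durations=durations,
+    ... )
+    >>> loss.ndim
+    0
+    ```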
+ """
+ kernel = _load_tdt_kernel() if token_logits.is_cuda else None
+ if kernel is not None and hasattr(kernel, "tdt_loss"):
+ durations_t = torch.tensor(durations, dtype=torch.int32, device=token_logits.device)
+ return kernel.tdt_loss(
+ token_logits,
+ duration_logits,
+ targets,
+ logit_lengths,
+ target_lengths,
+ durations_t,
+ blank_token_id,
+ sigma,
+ reduction,
+ )
+
+ if reduction not in ("mean", "sum", "none"):
+ raise ValueError(f'Invalid reduction mode "{reduction}". Expected one of "mean", "sum", or "none".')
+
+ device = token_logits.device
+ batch_size, max_t, max_u, _ = token_logits.shape
+
+ token_logits = token_logits.float()
+ duration_logits = duration_logits.float()
+
+ # Apply log-softmax to get log probabilities
+ # sigma only applies to token logits (undernormalization constant from the TDT paper)
+ token_log_probs = torch.log_softmax(token_logits, dim=-1) - sigma
+ duration_log_probs = torch.log_softmax(duration_logits, dim=-1)
+
+ log_alpha = torch.full((batch_size, max_t, max_u), float("-inf"), device=device)
+ log_alpha[:, 0, 0] = 0.0
+
+ # Precompute blank and label log-probs for vectorized access
+ blank_log_probs = token_log_probs[:, :, :, blank_token_id]
+
+ if max_u > 1:
+ targets_expanded = targets.unsqueeze(1).expand(-1, max_t, -1) # (batch, T, U_labels)
+ label_log_probs = torch.gather(
+ token_log_probs[:, :, : max_u - 1, :], # (batch, T, U-1, vocab)
+ dim=3,
+ index=targets_expanded.unsqueeze(-1),
+ ).squeeze(-1) # (batch, T, U-1)
+
+ neg_inf = torch.tensor(float("-inf"), device=device)
+
+ # Process anti-diagonals: all (t, u) with t + u = n have no mutual dependencies
+ for n in range(1, max_t + max_u - 1):
+ u_start = max(0, n - max_t + 1)
+ u_end = min(n + 1, max_u)
+ u_indices = torch.arange(u_start, u_end, device=device)
+
+ t_indices = n - u_indices
+ all_candidates = []
+ for i, dur in enumerate(durations):
+ t_prev = t_indices - dur
+ valid_t = t_prev >= 0
+ if not valid_t.any():
+ continue
+ t_src = t_prev.clamp(min=0)
+
+ # Blank arcs (dur > 0): from (t-dur, u) to (t, u)
+ if dur > 0:
+ contrib = (
+ log_alpha[:, t_src, u_indices]
+ + blank_log_probs[:, t_src, u_indices]
+ + duration_log_probs[:, t_src, u_indices, i]
+ )
+ contrib = torch.where(valid_t.unsqueeze(0), contrib, neg_inf)
+ all_candidates.append(contrib)
+
+ # Label arcs: from (t-dur, u-1) to (t, u), only if u > 0
+ valid_u = u_indices > 0
+ valid_both = valid_t & valid_u
+ if valid_both.any():
+ u_src = (u_indices - 1).clamp(min=0)
+ u_src_label = u_src.clamp(max=max_u - 2) if max_u > 1 else u_src
+
+ contrib = (
+ log_alpha[:, t_src, u_src]
+ + label_log_probs[:, t_src, u_src_label]
+ + duration_log_probs[:, t_src, u_src, i]
+ )
+ contrib = torch.where(valid_both.unsqueeze(0), contrib, neg_inf)
+ all_candidates.append(contrib)
+
+ if all_candidates:
+ stacked = torch.stack(all_candidates, dim=0)
+ log_alpha[:, t_indices, u_indices] = torch.logsumexp(stacked, dim=0)
+
+ # Terminal probability: sum over blank arcs that reach (T, U) from (T-dur, U)
+ batch_idx = torch.arange(batch_size, device=device)
+ log_probs = torch.full((batch_size,), float("-inf"), device=device)
+ for i, dur in enumerate(durations):
+ if dur == 0:
+ continue
+ t_final = logit_lengths - dur
+ valid = t_final >= 0
+ if not valid.any():
+ continue
+
+ t_clamped = t_final.clamp(min=0)
+ terminal = (
+ log_alpha[batch_idx, t_clamped, target_lengths]
+ + token_log_probs[batch_idx, t_clamped, target_lengths, blank_token_id]
+ + duration_log_probs[batch_idx, t_clamped, target_lengths, i]
+ )
+ combined = torch.stack([log_probs, terminal], dim=0)
+ log_probs = torch.where(valid, torch.logsumexp(combined, dim=0), log_probs)
+
+ losses = -log_probs
+
+ if reduction == "mean":
+ return (losses / target_lengths.float()).mean()
+ elif reduction == "sum":
+ return losses.sum()
+ return losses
+
+
+def ParakeetForTDTLoss(
+ token_logits,
+ duration_logits,
+ labels,
+ logit_lengths,
+ label_lengths,
+ blank_token_id,
+ durations,
+ sigma=0.0,
+ reduction="mean",
+ **kwargs,
+):
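+    # Thin adapter for LOSS_MAPPING: maps the generic `labels`/`label_lengths` arguments onto `tdt_loss`'s `targets`/`target_lengths` and moves them to the logits device.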
+ device = token_logits.device
+ return tdt_loss(
+ token_logits=token_logits,
+ duration_logits=duration_logits,
+ targets=labels.to(device).int(),
+ logit_lengths=logit_lengths.to(device).int(),
+ target_lengths=label_lengths.to(device).int(),
+ blank_token_id=blank_token_id,
+ durations=durations,
+ sigma=sigma,
+ reduction=reduction,
+ )
diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index df269477e9ec..e0aa92b50808 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -23,6 +23,7 @@
from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
from .loss_lw_detr import LwDetrForObjectDetectionLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss
+from .loss_tdt import ParakeetForTDTLoss
def fixed_cross_entropy(
@@ -165,4 +166,5 @@ def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
"DFineForObjectDetection": DFineForObjectDetectionLoss,
"CsmForConditionalGeneration": ForCausalLMLoss,
"LwDetrForObjectDetection": LwDetrForObjectDetectionLoss,
+ "ParakeetForTDT": ParakeetForTDTLoss,
}
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2e98863a762d..fd1be3aea007 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2303,6 +2303,12 @@ def _init_weights(self, module):
init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
init.zeros_(module.bias)
+ elif isinstance(module, nn.LSTM):
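+            # Xavier-uniform weights and zero biases (hoisted from EnCodec so other LSTM modules, e.g. the TDT prediction network, share the same init)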
+ for name, param in module.named_parameters():
+ if "weight" in name:
+ init.xavier_uniform_(param)
+ elif "bias" in name:
+ init.constant_(param, 0.0)
elif isinstance(module, nn.Embedding):
init.normal_(module.weight, mean=0.0, std=std)
# Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py
index 10e376b65956..24db9a947411 100644
--- a/src/transformers/models/auto/auto_mappings.py
+++ b/src/transformers/models/auto/auto_mappings.py
@@ -393,6 +393,7 @@
("paligemma", "PaliGemmaConfig"),
("parakeet_ctc", "ParakeetCTCConfig"),
("parakeet_encoder", "ParakeetEncoderConfig"),
+ ("parakeet_tdt", "ParakeetTDTConfig"),
("patchtsmixer", "PatchTSMixerConfig"),
("patchtst", "PatchTSTConfig"),
("pe_audio", "PeAudioConfig"),
@@ -755,6 +756,7 @@
("paddleocr_vl_vision", "paddleocr_vl"),
("parakeet_ctc", "parakeet"),
("parakeet_encoder", "parakeet"),
+ ("parakeet_tdt", "parakeet"),
("pe_audio_encoder", "pe_audio"),
("pe_audio_video_encoder", "pe_audio_video"),
("pe_video_encoder", "pe_video"),
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 111c56efb436..2150741c082b 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -62,6 +62,7 @@
("musicgen_melody", "MusicgenMelodyFeatureExtractor"),
("parakeet_ctc", "ParakeetFeatureExtractor"),
("parakeet_encoder", "ParakeetFeatureExtractor"),
+ ("parakeet_tdt", "ParakeetFeatureExtractor"),
("pe_audio", "PeAudioFeatureExtractor"),
("pe_audio_video", "PeAudioFeatureExtractor"),
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index deb1153d335e..4d135a385938 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -333,6 +333,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("paligemma", "PaliGemmaModel"),
("parakeet_ctc", "ParakeetForCTC"),
("parakeet_encoder", "ParakeetEncoder"),
+ ("parakeet_tdt", "ParakeetForTDT"),
("patchtsmixer", "PatchTSMixerModel"),
("patchtst", "PatchTSTModel"),
("pe_audio", "PeAudioModel"),
@@ -1636,6 +1637,14 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
]
)
+MODEL_FOR_TDT_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Token-and-Duration Transducer (TDT) mapping.
+ ("parakeet_tdt", "ParakeetForTDT"),
+ ]
+)
+
+
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Classification mapping
@@ -1906,6 +1915,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
+MODEL_FOR_TDT_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TDT_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
@@ -2229,6 +2239,13 @@ class AutoModelForCTC(_BaseAutoModelClass):
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
+class AutoModelForTDT(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_TDT_MAPPING
+
+
+AutoModelForTDT = auto_class_update(AutoModelForTDT, head_doc="token-and-duration transducer")
+
+
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
@@ -2291,6 +2308,7 @@ class AutoModelForAudioTokenization(_BaseAutoModelClass):
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
+ "MODEL_FOR_TDT_MAPPING",
"MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
"MODEL_FOR_TEXT_RECOGNITION_MAPPING",
@@ -2339,6 +2357,7 @@ class AutoModelForAudioTokenization(_BaseAutoModelClass):
"AutoModelForAudioXVector",
"AutoModelForCausalLM",
"AutoModelForCTC",
+ "AutoModelForTDT",
"AutoModelForDepthEstimation",
"AutoModelForTextRecognition",
"AutoModelForTableRecognition",
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 8d7d59c1f6ab..541d7d8434af 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -133,6 +133,8 @@
("owlvit", "OwlViTProcessor"),
("paddleocr_vl", "PaddleOCRVLProcessor"),
("paligemma", "PaliGemmaProcessor"),
+ ("parakeet_ctc", "ParakeetProcessor"),
+ ("parakeet_tdt", "ParakeetProcessor"),
("perception_lm", "PerceptionLMProcessor"),
("phi4_multimodal", "Phi4MultimodalProcessor"),
("pi0", "PI0Processor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index fd93e24edee1..536f3a7d2ace 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -245,6 +245,8 @@
("ovis2", "Qwen2Tokenizer" if is_tokenizers_available() else None),
("owlv2", "CLIPTokenizer" if is_tokenizers_available() else None),
("owlvit", "CLIPTokenizer" if is_tokenizers_available() else None),
+ ("parakeet_ctc", "ParakeetTokenizer" if is_tokenizers_available() else None),
+ ("parakeet_tdt", "ParakeetTokenizer" if is_tokenizers_available() else None),
("pegasus", "PegasusTokenizer" if is_tokenizers_available() else None),
("pegasus_x", "PegasusTokenizer" if is_tokenizers_available() else None),
("perceiver", "PerceiverTokenizer"),
diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py
index 352a1e94006c..6af8e2d8c968 100644
--- a/src/transformers/models/encodec/modeling_encodec.py
+++ b/src/transformers/models/encodec/modeling_encodec.py
@@ -455,23 +455,12 @@ class EncodecPreTrainedModel(PreTrainedAudioTokenizerBase):
@torch.no_grad()
def _init_weights(self, module):
- """Initialize the weights"""
- if isinstance(module, nn.GroupNorm):
- init.zeros_(module.bias)
- init.ones_(module.weight)
- elif isinstance(module, nn.Conv1d):
+ super()._init_weights(module)
+ if isinstance(module, nn.Conv1d):
init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
init.uniform_(module.bias, a=-k, b=k)
- elif isinstance(module, nn.ConvTranspose1d):
- module.reset_parameters()
- elif isinstance(module, nn.LSTM):
- for name, param in module.named_parameters():
- if "weight" in name:
- init.xavier_uniform_(param)
- elif "bias" in name:
- init.constant_(param, 0.0)
elif isinstance(module, EncodecConv1d):
kernel_size = module.conv.kernel_size[0]
stride = torch.tensor(module.conv.stride[0], dtype=torch.int64)
diff --git a/src/transformers/models/lasr/configuration_lasr.py b/src/transformers/models/lasr/configuration_lasr.py
index d55ea6449d37..db9b187b2ea5 100644
--- a/src/transformers/models/lasr/configuration_lasr.py
+++ b/src/transformers/models/lasr/configuration_lasr.py
@@ -48,18 +48,18 @@ class LasrEncoderConfig(PreTrainedConfig):
The momentum for the batch normalization layers
Example:
- ```python
- >>> from transformers import LasrEncoderModel, LasrEncoderConfig
+ ```python
+ >>> from transformers import LasrEncoderModel, LasrEncoderConfig
- >>> # Initializing a `LasrEncoder` configuration
- >>> configuration = LasrEncoderConfig()
+ >>> # Initializing a `LasrEncoder` configuration
+ >>> configuration = LasrEncoderConfig()
- >>> # Initializing a model from the configuration
- >>> model = LasrEncoderModel(configuration)
+ >>> # Initializing a model from the configuration
+ >>> model = LasrEncoderModel(configuration)
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details
and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
@@ -111,15 +111,15 @@ class LasrCTCConfig(PreTrainedConfig):
of [`LasrForCTC`].
Example:
- ```python
- >>> from transformers import LasrForCTC, LasrCTCConfig
- >>> # Initializing a Lasr configuration
- >>> configuration = LasrCTCConfig()
- >>> # Initializing a model from the configuration
- >>> model = LasrForCTC(configuration)
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```
+ ```python
+ >>> from transformers import LasrForCTC, LasrCTCConfig
+ >>> # Initializing a Lasr configuration
+ >>> configuration = LasrCTCConfig()
+ >>> # Initializing a model from the configuration
+ >>> model = LasrForCTC(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details
and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
"""
diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py
index 7ecea9099410..19054874b1e1 100644
--- a/src/transformers/models/lasr/modeling_lasr.py
+++ b/src/transformers/models/lasr/modeling_lasr.py
@@ -26,16 +26,18 @@
from torch import nn
from ...activations import ACT2FN
+from ...generation import CompileConfig, GenerationMixin
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
+from ..auto import AutoModel
from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig
@@ -458,6 +460,17 @@ def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length
return attention_mask
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Extends [`~modeling_outputs.BaseModelOutputWithPooling`] to include the output attention mask since sequence length
+ is not preserved in the model's forward.
+ """
+)
+class LasrEncoderModelOutput(BaseModelOutputWithPooling):
+ attention_mask: torch.Tensor | None = None
+
+
@auto_docstring(
custom_intro="""
The LasrEncoder model, based on the Conformer architecture](https://arxiv.org/abs/2005.08100).
@@ -492,16 +505,20 @@ def forward(
self,
input_features: torch.Tensor,
attention_mask: torch.Tensor | None = None,
+ output_attention_mask: bool | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> BaseModelOutput:
+ ) -> LasrEncoderModelOutput:
r"""
+ output_attention_mask (`bool`, *optional*):
+ Whether to return the output attention mask.
+
Example:
```python
>>> from transformers import AutoProcessor, LasrEncoder
>>> from datasets import load_dataset, Audio
- >>> model_id = TODO
+ >>> model_id = "google/medasr"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> encoder = ParakeetEncoder.from_pretrained(model_id)
@@ -524,8 +541,10 @@ def forward(
cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)
+ output_mask = None
if attention_mask is not None:
- attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
+ output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
+ attention_mask = output_mask
attention_mask = create_bidirectional_mask(
config=self.config,
@@ -551,13 +570,16 @@ def forward(
hidden_states = self.out_norm(hidden_states)
- return BaseModelOutput(last_hidden_state=hidden_states)
+ return LasrEncoderModelOutput(
+ last_hidden_state=hidden_states,
+ attention_mask=output_mask.int() if output_attention_mask and output_mask is not None else None,
+ )
@dataclass
-class LasrGenerateOutput(ModelOutput):
+class LasrCTCGenerateOutput(ModelOutput):
"""
- Outputs of Lasr models.
+ Outputs of Lasr CTC model generation.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -586,12 +608,12 @@ class LasrGenerateOutput(ModelOutput):
Lasr Encoder with a Connectionist Temporal Classification (CTC) head.
"""
)
-class LasrForCTC(LasrPreTrainedModel):
+class LasrForCTC(LasrPreTrainedModel, GenerationMixin):
config: LasrCTCConfig
def __init__(self, config: LasrCTCConfig):
super().__init__(config)
- self.encoder = LasrEncoder(config.encoder_config)
+ self.encoder = AutoModel.from_config(config.encoder_config)
# Conv rather than linear to be consistent with NeMO decoding layer
self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
@@ -626,6 +648,8 @@ def forward(
>>> print(outputs.loss)
```"""
+ if labels is not None:
+ kwargs.setdefault("output_attention_mask", True)
encoder_outputs = self.encoder(
input_features=input_features,
attention_mask=attention_mask,
@@ -637,14 +661,9 @@ def forward(
loss = None
if labels is not None:
- # retrieve loss input_lengths from attention_mask
- attention_mask = (
- attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
- )
- input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
+ encoder_lengths = encoder_outputs.attention_mask.sum(-1)
- # assuming that padded tokens are filled with -100
- # when not being attended to
+ # assuming that padded tokens are filled with pad_token_id when not being attended to
labels_mask = labels != self.config.pad_token_id
target_lengths = labels_mask.sum(-1)
flattened_targets = labels.masked_select(labels_mask)
@@ -656,7 +675,7 @@ def forward(
loss = nn.functional.ctc_loss(
log_probs,
flattened_targets,
- input_lengths,
+ encoder_lengths,
target_lengths,
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
@@ -676,8 +695,9 @@ def generate(
input_features: torch.Tensor,
attention_mask: torch.Tensor | None = None,
return_dict_in_generate: bool = False,
+ compile_config: CompileConfig | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> LasrGenerateOutput | torch.LongTensor:
+ ) -> LasrCTCGenerateOutput | torch.LongTensor:
r"""
Example:
@@ -685,7 +705,7 @@ def generate(
>>> from transformers import AutoProcessor, LasrForCTC
>>> from datasets import load_dataset, Audio
- >>> model_id = TODO
+ >>> model_id = "google/medasr"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = LasrForCTC.from_pretrained(model_id)
@@ -699,8 +719,10 @@ def generate(
>>> print(transcription)
```
"""
+ model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__
+
kwargs["return_dict"] = True
- outputs: CausalLMOutput = self.forward(
+ outputs: CausalLMOutput = model_forward(
input_features=input_features,
attention_mask=attention_mask,
**kwargs,
@@ -715,7 +737,7 @@ def generate(
sequences[~attention_mask] = self.config.pad_token_id
if return_dict_in_generate:
- return LasrGenerateOutput(
+ return LasrCTCGenerateOutput(
sequences=sequences,
logits=outputs.logits,
attentions=outputs.attentions,
diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py
index f016f16cff45..1329c5c0a2af 100644
--- a/src/transformers/models/lasr/modular_lasr.py
+++ b/src/transformers/models/lasr/modular_lasr.py
@@ -21,12 +21,13 @@
from tokenizers.models import Unigram
from torch import nn
+from ...audio_utils import AudioInput, make_list_of_audio
from ...masking_utils import create_bidirectional_mask
-from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...tokenization_utils_tokenizers import TokenizersBackend
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward
@@ -34,13 +35,16 @@
from ..parakeet.modeling_parakeet import (
ParakeetEncoderBlock,
ParakeetEncoderConvolutionModule,
+ ParakeetEncoderModelOutput,
ParakeetForCTC,
ParakeetPreTrainedModel,
)
-from ..parakeet.processing_parakeet import ParakeetProcessor
from ..t5.tokenization_t5 import T5Tokenizer
+logger = logging.get_logger(__name__)
+
+
class LasrTokenizer(T5Tokenizer, TokenizersBackend):
def __init__(
self,
@@ -144,8 +148,74 @@ def _decode(
)
-class LasrProcessor(ParakeetProcessor):
- pass
+class LasrProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "audio_kwargs": {
+ "sampling_rate": 16000,
+ "padding": "longest",
+ "return_attention_mask": True,
+ },
+ "text_kwargs": {
+ "padding": True,
+ "padding_side": "right",
+ "add_special_tokens": False,
+ },
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+@auto_docstring
+class LasrProcessor(ProcessorMixin):
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+
+ @auto_docstring
+ def __call__(
+ self,
+ audio: AudioInput,
+ text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
+ sampling_rate: int | None = None,
+ **kwargs: Unpack[LasrProcessorKwargs],
+ ):
+ r"""
+ sampling_rate (`int`, *optional*):
+ The sampling rate of the input audio in Hz. This should match the sampling rate expected by the feature
+ extractor (defaults to 16000 Hz). If provided, it will be validated against the processor's expected
+ sampling rate, and an error will be raised if they don't match. If not provided, a warning will be
+ issued and the default sampling rate will be assumed.
+ """
+ audio = make_list_of_audio(audio)
+
+ output_kwargs = self._merge_kwargs(
+ LasrProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if sampling_rate is None:
+ logger.warning_once(
+ f"You've provided audio without specifying the sampling rate. It will be assumed to be {output_kwargs['audio_kwargs']['sampling_rate']}, which can result in silent errors."
+ )
+ elif sampling_rate != output_kwargs["audio_kwargs"]["sampling_rate"]:
+ raise ValueError(
+                f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({output_kwargs['audio_kwargs']['sampling_rate']}). Please resample the audio to the expected sampling rate."
+ )
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
+ if text is not None:
+ encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ if text is None:
+ return inputs
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
+
+ @property
+ def model_input_names(self):
+ feature_extractor_input_names = self.feature_extractor.model_input_names
+ return feature_extractor_input_names + ["labels"]
@auto_docstring(checkpoint="google/medasr")
@@ -172,18 +242,18 @@ class LasrEncoderConfig(ParakeetEncoderConfig):
The momentum for the batch normalization layers
Example:
- ```python
- >>> from transformers import LasrEncoderModel, LasrEncoderConfig
+ ```python
+ >>> from transformers import LasrEncoderModel, LasrEncoderConfig
- >>> # Initializing a `LasrEncoder` configuration
- >>> configuration = LasrEncoderConfig()
+ >>> # Initializing a `LasrEncoder` configuration
+ >>> configuration = LasrEncoderConfig()
- >>> # Initializing a model from the configuration
- >>> model = LasrEncoderModel(configuration)
+ >>> # Initializing a model from the configuration
+ >>> model = LasrEncoderModel(configuration)
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details
and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
@@ -221,15 +291,15 @@ class LasrCTCConfig(ParakeetCTCConfig):
of [`LasrForCTC`].
Example:
- ```python
- >>> from transformers import LasrForCTC, LasrCTCConfig
- >>> # Initializing a Lasr configuration
- >>> configuration = LasrCTCConfig()
- >>> # Initializing a model from the configuration
- >>> model = LasrForCTC(configuration)
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```
+ ```python
+ >>> from transformers import LasrForCTC, LasrCTCConfig
+ >>> # Initializing a Lasr configuration
+ >>> configuration = LasrCTCConfig()
+ >>> # Initializing a model from the configuration
+ >>> model = LasrForCTC(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details
and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
"""
@@ -390,6 +460,10 @@ def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
return input_lengths
+class LasrEncoderModelOutput(ParakeetEncoderModelOutput):
+ pass
+
+
@auto_docstring(
custom_intro="""
The LasrEncoder model, based on the Conformer architecture](https://arxiv.org/abs/2005.08100).
@@ -424,16 +498,20 @@ def forward(
self,
input_features: torch.Tensor,
attention_mask: torch.Tensor | None = None,
+ output_attention_mask: bool | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> BaseModelOutput:
+ ) -> LasrEncoderModelOutput:
r"""
+ output_attention_mask (`bool`, *optional*):
+ Whether to return the output attention mask.
+
Example:
```python
>>> from transformers import AutoProcessor, LasrEncoder
>>> from datasets import load_dataset, Audio
- >>> model_id = TODO
+ >>> model_id = "google/medasr"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> encoder = ParakeetEncoder.from_pretrained(model_id)
@@ -456,8 +534,10 @@ def forward(
cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)
+ output_mask = None
if attention_mask is not None:
- attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
+ output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
+ attention_mask = output_mask
attention_mask = create_bidirectional_mask(
config=self.config,
@@ -483,7 +563,10 @@ def forward(
hidden_states = self.out_norm(hidden_states)
- return BaseModelOutput(last_hidden_state=hidden_states)
+ return LasrEncoderModelOutput(
+ last_hidden_state=hidden_states,
+ attention_mask=output_mask.int() if output_attention_mask and output_mask is not None else None,
+ )
class LasrForCTC(ParakeetForCTC):
@@ -495,7 +578,7 @@ def generate(**super_kwargs):
>>> from transformers import AutoProcessor, LasrForCTC
>>> from datasets import load_dataset, Audio
- >>> model_id = TODO
+ >>> model_id = "google/medasr"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = LasrForCTC.from_pretrained(model_id)
diff --git a/src/transformers/models/lasr/processing_lasr.py b/src/transformers/models/lasr/processing_lasr.py
index c1acaebaae07..9eb093a49c7a 100644
--- a/src/transformers/models/lasr/processing_lasr.py
+++ b/src/transformers/models/lasr/processing_lasr.py
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from ...audio_utils import AudioInput, make_list_of_audio
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
diff --git a/src/transformers/models/parakeet/__init__.py b/src/transformers/models/parakeet/__init__.py
index 5c54b2e2eadb..e8bbfe7faf45 100644
--- a/src/transformers/models/parakeet/__init__.py
+++ b/src/transformers/models/parakeet/__init__.py
@@ -21,7 +21,8 @@
from .configuration_parakeet import *
from .feature_extraction_parakeet import *
from .modeling_parakeet import *
- from .tokenization_parakeet_fast import *
+ from .processing_parakeet import *
+ from .tokenization_parakeet import *
else:
import sys
diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py
index 6f4622ea3b2f..4b7c5b0fb526 100644
--- a/src/transformers/models/parakeet/configuration_parakeet.py
+++ b/src/transformers/models/parakeet/configuration_parakeet.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Parakeet model configuration."""
from huggingface_hub.dataclasses import strict
@@ -43,21 +42,18 @@ class ParakeetEncoderConfig(PreTrainedConfig):
Whether to scale the input embeddings.
Example:
- ```python
- >>> from transformers import ParakeetEncoderModel, ParakeetEncoderConfig
-
- >>> # Initializing a `ParakeetEncoder` configuration
- >>> configuration = ParakeetEncoderConfig()
+ ```python
+ >>> from transformers import ParakeetEncoderModel, ParakeetEncoderConfig
- >>> # Initializing a model from the configuration
- >>> model = ParakeetEncoderModel(configuration)
+ >>> # Initializing a `ParakeetEncoder` configuration
+ >>> configuration = ParakeetEncoderConfig()
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```
+ >>> # Initializing a model from the configuration
+ >>> model = ParakeetEncoderModel(configuration)
- This configuration class is based on the ParakeetEncoder architecture from NVIDIA NeMo. You can find more details
- and pre-trained models at [nvidia/parakeet-ctc-1.1b](https://huggingface.co/nvidia/parakeet-ctc-1.1b).
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
"""
model_type = "parakeet_encoder"
@@ -135,4 +131,60 @@ def __post_init__(self, **kwargs):
super().__post_init__(**kwargs)
-__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig"]
+@auto_docstring(checkpoint="nvidia/parakeet-tdt-0.6b-v3")
+@strict
+class ParakeetTDTConfig(PreTrainedConfig):
+ r"""
+ decoder_hidden_size (`int`, *optional*, defaults to 640):
+ Hidden size of the LSTM prediction network and joint network.
+ num_decoder_layers (`int`, *optional*, defaults to 2):
+ Number of LSTM layers in the prediction network.
+ max_symbols_per_step (`int`, *optional*, defaults to 10):
+ Maximum number of symbols to emit per encoder time step during greedy decoding.
+ durations (`list[int]`, *optional*, defaults to `[0, 1, 2, 3, 4]`):
+ Token duration values that can be predicted. Each value represents how many frames a token or blank
+ emission spans.
+ encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*):
+ The config object or dictionary of the encoder.
+ blank_token_id (`int`, *optional*, defaults to 8192):
+ Blank token id. Different from `pad_token_id` for TDT.
+
+ Example:
+ ```python
+ >>> from transformers import ParakeetForTDT, ParakeetTDTConfig
+
+ >>> # Initializing a Parakeet TDT configuration
+ >>> configuration = ParakeetTDTConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = ParakeetForTDT(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "parakeet_tdt"
+ sub_configs = {"encoder_config": ParakeetEncoderConfig}
+
+ vocab_size: int = 8193
+ decoder_hidden_size: int = 640
+ num_decoder_layers: int = 2
+ hidden_act: str = "relu"
+ max_symbols_per_step: int = 10
+ durations: list[int] | tuple[int, ...] = (0, 1, 2, 3, 4)
+ encoder_config: dict | PreTrainedConfig | None = None
+ pad_token_id: int = 2
+ blank_token_id: int = 8192
+ is_encoder_decoder: bool = True
+
+ def __post_init__(self, **kwargs):
+ if isinstance(self.encoder_config, dict):
+ self.encoder_config = ParakeetEncoderConfig(**self.encoder_config)
+ elif self.encoder_config is None:
+ self.encoder_config = ParakeetEncoderConfig()
+ self.initializer_range = self.encoder_config.initializer_range
+ super().__post_init__(**kwargs)
+
+
+__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig", "ParakeetTDTConfig"]
diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py
index 2d4085e6d340..b1be27fe5dcf 100644
--- a/src/transformers/models/parakeet/convert_nemo_to_hf.py
+++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py
@@ -24,11 +24,12 @@
from transformers import (
ParakeetCTCConfig,
- ParakeetEncoder,
ParakeetEncoderConfig,
ParakeetFeatureExtractor,
ParakeetForCTC,
+ ParakeetForTDT,
ParakeetProcessor,
+ ParakeetTDTConfig,
ParakeetTokenizer,
)
from transformers.convert_slow_tokenizer import ParakeetConverter
@@ -48,6 +49,15 @@
r"linear_pos": r"relative_k_proj",
}
+# Additional mappings for TDT decoder and joint network
+NEMO_TDT_WEIGHT_MAPPING = {
+ r"decoder\.prediction\.embed\.": r"decoder.embedding.",
+ r"decoder\.prediction\.dec_rnn\.lstm\.": r"decoder.lstm.",
+ r"joint\.enc\.": r"encoder_projector.",
+ r"joint\.pred\.": r"decoder.decoder_projector.",
+ r"joint\.joint_net\.2\.": r"joint.head.",
+}
+
def convert_key(key, mapping):
for pattern, replacement in mapping.items():
@@ -56,22 +66,12 @@ def convert_key(key, mapping):
def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str]:
- """
- Extract .nemo file (tar archive) and return paths to important files.
-
- Args:
- nemo_file_path: Path to .nemo file
- extract_dir: Directory to extract to
-
- Returns:
- Dictionary with paths to model.pt, model_config.yaml, etc.
- """
+ """Extract .nemo file (tar archive) and return paths to important files."""
print(f"Extracting NeMo archive: {nemo_file_path}")
with tarfile.open(nemo_file_path, "r", encoding="utf-8") as tar:
tar.extractall(extract_dir)
- # Log all extracted files for debugging
all_files = []
for root, dirs, files in os.walk(extract_dir):
for file in files:
@@ -80,14 +80,12 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str
print(f"All extracted files: {[os.path.basename(f) for f in all_files]}")
- # Find important files with more robust detection
model_files = {}
for root, dirs, files in os.walk(extract_dir):
for file in files:
file_path = os.path.join(root, file)
file_lower = file.lower()
- # Look for model weights with various common names
if (
file.endswith(".pt")
or file.endswith(".pth")
@@ -102,26 +100,23 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str
model_files["model_weights"] = file_path
print(f"Found model weights: {file}")
- # Look for config files
elif (
file == "model_config.yaml"
or file == "config.yaml"
or (file.endswith(".yaml") and "config" in file_lower)
):
- if "model_config" not in model_files: # Prefer model_config.yaml
+ if "model_config" not in model_files:
model_files["model_config"] = file_path
print(f"Found config file: {file}")
if file == "model_config.yaml":
- model_files["model_config"] = file_path # Override with preferred name
+ model_files["model_config"] = file_path
- # Look for vocabulary files
elif (
file.endswith(".vocab")
or file.endswith(".model")
or file.endswith(".txt")
or ("tokenizer" in file_lower and (file.endswith(".vocab") or file.endswith(".model")))
):
- # Prefer .vocab files over others
if "tokenizer_model_file" not in model_files or file.endswith(".model"):
model_files["tokenizer_model_file"] = file_path
print(f"Found tokenizer model file: {file}")
@@ -130,7 +125,6 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str
print(f"Found model files: {list(model_files.keys())}")
- # Validate that we found the required files
if "model_weights" not in model_files:
raise FileNotFoundError(
f"Could not find model weights file in {nemo_file_path}. "
@@ -148,15 +142,27 @@ def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str
return model_files
-def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id=None):
+def write_processor(
+ nemo_config: dict, model_files, output_dir, model_type, push_to_repo_id=None, create_pr=True, revision=None
+):
tokenizer_converted = ParakeetConverter(model_files["tokenizer_model_file"]).converted()
tokenizer_converted_fast = ParakeetTokenizer(
tokenizer_object=tokenizer_converted,
clean_up_tokenization_spaces=False,
)
-    tokenizer_converted_fast.add_tokens(
-        [AddedToken("<unk>", normalized=False, special=True), AddedToken("<pad>", normalized=False, special=True)]
-    )
+
+    if tokenizer_converted_fast.convert_tokens_to_ids("<pad>") is None:
+        # Normally CTC and TDT already have <pad>
+        tokenizer_converted_fast.add_tokens([AddedToken("<pad>", normalized=False, special=True)])
+        print(f"Added <pad> token at ID: {tokenizer_converted_fast.convert_tokens_to_ids('<pad>')}")
+    if tokenizer_converted_fast.convert_tokens_to_ids("<unk>") is None:
+        # Normally CTC doesn't have <unk> while TDT has <unk> at token id = 2
+        tokenizer_converted_fast.add_tokens([AddedToken("<unk>", normalized=False, special=True)])
+        print(f"Added <unk> token at ID: {tokenizer_converted_fast.convert_tokens_to_ids('<unk>')}")
+    if model_type == "tdt":
+        # TDT needs a separate <blank> token
+        tokenizer_converted_fast.add_tokens([AddedToken("<blank>", normalized=False, special=True)])
+        print(f"Added <blank> token at ID: {tokenizer_converted_fast.convert_tokens_to_ids('<blank>')}")
tokenizer_converted_fast.add_special_tokens(
{
"pad_token": AddedToken("<pad>", normalized=False, special=True),
@@ -193,7 +199,6 @@ def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id=
raise ValueError(f"Key {key} not found in feature_extractor_keys_mapping")
feature_extractor = ParakeetFeatureExtractor(**converted_feature_extractor_config)
-
processor = ParakeetProcessor(
feature_extractor=feature_extractor,
tokenizer=tokenizer_converted_fast,
@@ -201,7 +206,12 @@ def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id=
processor.save_pretrained(output_dir)
if push_to_repo_id:
- processor.push_to_hub(push_to_repo_id)
+ commit_info = processor.push_to_hub(push_to_repo_id, create_pr=create_pr, revision=revision)
+ if create_pr and hasattr(commit_info, "pr_url") and commit_info.pr_url:
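+            # e.g. pr_url "https://huggingface.co/USERNAME/repo/discussions/3" -> "refs/pr/3"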
+ pr_num = commit_info.pr_url.rstrip("/").split("/")[-1]
+ return f"refs/pr/{pr_num}"
+
+ return revision
def convert_encoder_config(nemo_config):
@@ -248,7 +258,6 @@ def convert_encoder_config(nemo_config):
continue
if key in encoder_config_keys_mapping:
converted_encoder_config[encoder_config_keys_mapping[key]] = value
- # NeMo uses 'use_bias' for both attention and convolution bias, but HF separates them
if key == "use_bias":
converted_encoder_config["convolution_bias"] = value
else:
@@ -262,7 +271,6 @@ def load_and_convert_state_dict(model_files):
state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True)
converted_state_dict = {}
for key, value in state_dict.items():
- # Skip preprocessing weights (featurizer components)
if key.endswith("featurizer.window") or key.endswith("featurizer.fb"):
print(f"Skipping preprocessing weight: {key}")
continue
@@ -272,7 +280,7 @@ def load_and_convert_state_dict(model_files):
return converted_state_dict
-def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None):
+def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None, revision=None):
"""Write CTC model using encoder config and converted state dict."""
model_config = ParakeetCTCConfig.from_encoder_config(encoder_config)
@@ -287,62 +295,117 @@ def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_re
model.save_pretrained(output_dir)
if push_to_repo_id:
- model.push_to_hub(push_to_repo_id)
+ model.push_to_hub(push_to_repo_id, revision=revision)
del model
- # Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")
-def write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None):
- """Write encoder model using encoder config and converted state dict."""
- # Filter to only encoder weights (exclude CTC head if present)
- encoder_state_dict = {
- k.replace("encoder.", "", 1) if k.startswith("encoder.") else k: v
- for k, v in converted_state_dict.items()
- if k.startswith("encoder.")
- }
+def convert_tdt_config(nemo_config, encoder_config):
+ """Convert NeMo TDT config to HF TDT config."""
+ decoder_config = nemo_config["decoder"]
+ decoding_config = nemo_config["decoding"]
+ labels = nemo_config["labels"]
+ blank_token_id = len(labels)
+ vocab_size = len(labels) + 1 # +1 for blank token, which is added to tokenizer
+
+ prednet = decoder_config.get("prednet", {})
+ decoder_hidden_size = prednet.get("pred_hidden", 640)
+ num_decoder_layers = prednet.get("pred_rnn_layers", 2)
+ durations = decoding_config.get("durations", [0, 1, 2, 3, 4])
+ print(
+ f"TDT config: vocab_size={vocab_size} (including blank token), "
+ f"decoder_hidden={decoder_hidden_size}, "
+        f"decoder_layers={num_decoder_layers}, durations={durations}"
+ )
+
+ return ParakeetTDTConfig(
+ vocab_size=vocab_size,
+ decoder_hidden_size=decoder_hidden_size,
+ num_decoder_layers=num_decoder_layers,
+ durations=durations,
+ hidden_act="relu",
+ max_symbols_per_step=10,
+ encoder_config=encoder_config.to_dict(),
+        pad_token_id=labels.index("<pad>"),
+ blank_token_id=blank_token_id, # blank token is different from pad token for TDT
+ )
+
+
+def load_and_convert_tdt_state_dict(model_files, vocab_size):
+ """Load NeMo TDT state dict and convert keys to HF format, splitting combined head."""
+ state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True)
+ converted_state_dict = {}
+
+ all_mappings = {**NEMO_TO_HF_WEIGHT_MAPPING, **NEMO_TDT_WEIGHT_MAPPING}
+
+ for key, value in state_dict.items():
+ if key.endswith("featurizer.window") or key.endswith("featurizer.fb"):
+ print(f"Skipping preprocessing weight: {key}")
+ continue
- print("Loading the checkpoint in a Parakeet Encoder model (for TDT).")
+ converted_key = convert_key(key, all_mappings)
+ converted_state_dict[converted_key] = value
+
+ return converted_state_dict
+
+
+def write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id=None, revision=None):
+ """Write TDT model using encoder config, TDT config, and converted state dict."""
+ model_config = convert_tdt_config(nemo_config, encoder_config)
+ print(f"Converted TDT config: {model_config}")
+
+ converted_state_dict = load_and_convert_tdt_state_dict(model_files, model_config.vocab_size)
+
+ print("Loading the checkpoint in a Parakeet TDT model.")
with torch.device("meta"):
- model = ParakeetEncoder(encoder_config)
+ model = ParakeetForTDT(model_config)
+
+ missing_keys, unexpected_keys = model.load_state_dict(converted_state_dict, strict=False, assign=True)
+
+ if missing_keys:
+ print(f"Warning: Missing keys: {missing_keys}")
+ if unexpected_keys:
+ print(f"Warning: Unexpected keys: {unexpected_keys}")
+
+ if not missing_keys and not unexpected_keys:
+ print("All weights loaded successfully!")
- model.load_state_dict(encoder_state_dict, strict=True, assign=True)
- print("Checkpoint loaded successfully.")
del model.config._name_or_path
+ model.generation_config.decoder_start_token_id = model.config.blank_token_id
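+    # Duration logits occupy ids [vocab_size, vocab_size + len(durations));
+    # suppressing them keeps the sampled token ids inside the real vocabulary.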
+ model.generation_config.suppress_tokens = list(
+ range(model.config.vocab_size, model.config.vocab_size + len(model.config.durations))
+ )
+
print("Saving the model.")
model.save_pretrained(output_dir)
if push_to_repo_id:
- model.push_to_hub(push_to_repo_id)
+ model.push_to_hub(push_to_repo_id, revision=revision)
+
del model
- # Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
- ParakeetEncoder.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
+ ParakeetForTDT.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")
-def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None):
+def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None, revision=None):
"""Main model conversion function."""
- # Step 1: Convert encoder config (shared across all model types)
encoder_config = convert_encoder_config(nemo_config)
print(f"Converted encoder config: {encoder_config}")
- # Step 2: Load and convert state dict (shared across all model types)
- converted_state_dict = load_and_convert_state_dict(model_files)
-
- # Step 3: Write model based on type
- if model_type == "encoder":
- write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id)
- elif model_type == "ctc":
- write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id)
+ if model_type == "ctc":
+ converted_state_dict = load_and_convert_state_dict(model_files)
+ write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id, revision)
+ elif model_type == "tdt":
+ write_tdt_model(nemo_config, encoder_config, model_files, output_dir, push_to_repo_id, revision)
else:
raise ValueError(f"Model type {model_type} not supported.")
@@ -352,6 +415,8 @@ def main(
output_dir,
model_type,
push_to_repo_id=None,
+ create_pr=True,
+ revision=None,
):
nemo_filename = f"{hf_repo_id.split('/')[-1]}.nemo"
filepath = cached_file(hf_repo_id, nemo_filename)
@@ -359,22 +424,62 @@ def main(
model_files = extract_nemo_archive(filepath, os.path.dirname(filepath))
nemo_config = yaml.load(open(model_files["model_config"], "r"), Loader=yaml.FullLoader)
- write_processor(nemo_config, model_files, output_dir, push_to_repo_id)
- write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id)
-
-
+ # When revision is given (e.g. "refs/pr/3"), both pushes target that existing PR branch.
+ # Otherwise, write_processor creates a new PR and returns its revision for write_model.
+ pr_revision = write_processor(
+ nemo_config,
+ model_files,
+ output_dir,
+ model_type,
+ push_to_repo_id,
+ create_pr=create_pr if revision is None else False,
+ revision=revision,
+ )
+ write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id, pr_revision)
+
+
+"""
+CTC conversion example:
+```bash
+python src/transformers/models/parakeet/convert_nemo_to_hf.py \
+ --hf_repo_id nvidia/parakeet-ctc-1.1b \
+ --model_type ctc \
+ --output_dir OUTPUT_DIR \
+ --push_to_repo_id USERNAME/parakeet-ctc-1.1b
+```
+
+TDT conversion example:
+```bash
+python src/transformers/models/parakeet/convert_nemo_to_hf.py \
+ --hf_repo_id nvidia/parakeet-tdt-0.6b-v3 \
+ --model_type tdt \
+ --output_dir OUTPUT_DIR \
+ --push_to_repo_id USERNAME/parakeet-tdt-0.6b-v3-hf
+```
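+
+To push to an existing Hub PR branch instead of opening a new one (the revision number below is illustrative):
+```bash
+python src/transformers/models/parakeet/convert_nemo_to_hf.py \
+    --hf_repo_id nvidia/parakeet-tdt-0.6b-v3 \
+    --model_type tdt \
+    --output_dir OUTPUT_DIR \
+    --push_to_repo_id USERNAME/parakeet-tdt-0.6b-v3-hf \
+    --revision refs/pr/3
+```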
+"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co")
- parser.add_argument(
- "--model_type", required=True, choices=["encoder", "ctc"], help="Model type (`encoder`, `ctc`)"
- )
+ parser.add_argument("--model_type", required=True, choices=["ctc", "tdt"], help="Model type (`ctc`, `tdt`)")
parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model")
parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub")
+ parser.add_argument(
+ "--create_pr",
+ default=True,
+ action=argparse.BooleanOptionalAction,
+ help="Create a PR when pushing to the Hub (default: True). Use --no-create_pr to push directly.",
+ )
+ parser.add_argument(
+ "--revision",
+ default=None,
+ help='Push to an existing Hub PR branch (e.g. "refs/pr/3"). Overrides --create_pr.',
+ )
args = parser.parse_args()
main(
args.hf_repo_id,
args.output_dir,
args.model_type,
args.push_to_repo_id,
+ args.create_pr,
+ args.revision,
)
diff --git a/src/transformers/models/parakeet/generation_parakeet.py b/src/transformers/models/parakeet/generation_parakeet.py
new file mode 100644
index 000000000000..fe422f3dd3a8
--- /dev/null
+++ b/src/transformers/models/parakeet/generation_parakeet.py
@@ -0,0 +1,185 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import torch
+
+from ...generation import GenerationMixin, StoppingCriteria
+from ...utils import ModelOutput
+
+
+@dataclass
+class ParakeetTDTGenerateOutput(ModelOutput):
+ """
+ Outputs of Parakeet TDT generation.
+
+ Args:
+ sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Generated token sequences (including blank tokens).
+ durations (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Per-step durations in frames. Combined with `sequences`, this is sufficient
+ to reconstruct full timestamp information (frame indices are the cumulative sum
+ of durations).
+ attentions (`tuple(tuple(torch.FloatTensor))`, *optional*):
+ Encoder attention weights per layer.
+ hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*):
+ Encoder hidden states per layer.
+ """
+
+ sequences: torch.LongTensor
+ durations: torch.LongTensor | None = None
+ attentions: tuple[tuple[torch.FloatTensor]] | None = None
+ hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
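+
+    # Note: per-step frame indices are the exclusive cumulative sum of `durations`,
+    # i.e. frame_idxs = durations.cumsum(-1) - durations, which is how
+    # `ParakeetProcessor.decode` derives timestamps.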
+
+
+class EncoderExhaustedCriteria(StoppingCriteria):
+ """Stops generation when all batch elements have walked past their encoder output length."""
+
+ def __init__(self, model):
+ self.model = model
+
+ def __call__(self, input_ids, scores, **kwargs):
+ if self.model._encoder_finished is None:
+ return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
+ return self.model._encoder_finished
+
+
+class ParakeetTDTGenerationMixin(GenerationMixin):
+ """Generation mixin for Parakeet TDT models.
+
+ Handles transducer-specific generation logic: encoder frame tracking,
+ duration accumulation, and encoder-exhaustion stopping.
+ """
+
+ def _get_stopping_criteria(self, *args, **kwargs):
+ criteria = super()._get_stopping_criteria(*args, **kwargs)
+ criteria.append(EncoderExhaustedCriteria(self))
+ return criteria
+
+ def _update_model_kwargs_for_generation(self, outputs, *args, **kwargs):
+ model_kwargs = super()._update_model_kwargs_for_generation(outputs, *args, **kwargs)
+
+ # Advance encoder frame pointer by the predicted duration
+ logits = outputs.logits[:, -1, :]
+ tokens = logits[:, : self.config.vocab_size].argmax(dim=-1)
+ durations = logits[:, self.config.vocab_size :].argmax(dim=-1)
+
+        # Only force forward progress (duration >= 1) for blank predictions;
+        # non-blank tokens may keep duration 0 so that multiple symbols can be
+        # emitted on the same encoder frame.
+ blank_mask = tokens == self.config.blank_token_id
+ durations = torch.where(blank_mask & (durations == 0), torch.ones_like(durations), durations)
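+        # e.g. with config durations [0, 1, 2, 3, 4]: a blank predicted with
+        # duration 0 is bumped to 1, guaranteeing the frame pointer advances.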
+ model_kwargs["encoder_frame_idxs"] = model_kwargs["encoder_frame_idxs"] + durations
+ self._step_durations.append(durations)
+
+ # Track which batch elements have exhausted their encoder frames.
+ self._encoder_finished = model_kwargs["encoder_frame_idxs"] >= model_kwargs["encoder_valid_lengths"]
+
+ return model_kwargs
+
+ def _prepare_generated_length(
+ self,
+ generation_config,
+ has_default_max_length,
+ has_default_min_length,
+ model_input_name,
+ input_ids_length,
+ inputs_tensor,
+ ):
+ # When the user hasn't explicitly set max_length/max_new_tokens, derive an upper
+ # bound from the encoder capacity. The actual stopping is handled by the
+ # encoder-exhaustion stopping criteria; this just sizes the output buffer.
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ encoder_seq_len = self.encoder._get_subsampling_output_length(
+ torch.tensor([inputs_tensor.shape[1]], device=inputs_tensor.device)
+ ).item()
+ generation_config.max_length = self.max_symbols_per_step * encoder_seq_len
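+            # e.g. with the converted configs' max_symbols_per_step=10 and 500
+            # encoder frames, this sizes the buffer to max_length = 5000.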
+ has_default_max_length = False # prevent super() from overwriting
+ return super()._prepare_generated_length(
+ generation_config,
+ has_default_max_length,
+ has_default_min_length,
+ model_input_name,
+ input_ids_length,
+ inputs_tensor,
+ )
+
+ def _prepare_model_inputs(self, *args, **kwargs):
+ inputs, input_name, model_kwargs = super()._prepare_model_inputs(*args, **kwargs)
+
+ encoder_outputs = self.get_audio_features(
+ input_features=inputs,
+ attention_mask=model_kwargs.get("attention_mask", None),
+ output_attention_mask=True,
+ )
+ model_kwargs["encoder_outputs"] = encoder_outputs
+
+ if encoder_outputs.attention_mask is not None:
+ encoder_valid_lengths = encoder_outputs.attention_mask.sum(-1)
+ else:
+ batch_size = encoder_outputs.last_hidden_state.shape[0]
+ encoder_valid_lengths = torch.full(
+ (batch_size,),
+ encoder_outputs.last_hidden_state.shape[1],
+ dtype=torch.long,
+ device=encoder_outputs.last_hidden_state.device,
+ )
+ model_kwargs["encoder_valid_lengths"] = encoder_valid_lengths
+
+ model_kwargs["encoder_frame_idxs"] = torch.zeros(
+ inputs.shape[0],
+ device=inputs.device,
+ dtype=torch.long,
+ )
+
+ return inputs, input_name, model_kwargs
+
+ def _prepare_cache_for_generation(self, generation_config, model_kwargs, *args, **kwargs):
+ from .modeling_parakeet import ParakeetTDTDecoderCache
+
+ model_kwargs["decoder_cache"] = ParakeetTDTDecoderCache()
+
+ def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
+ from .modeling_parakeet import ParakeetEncoderModelOutput
+
+ model_inputs = super().prepare_inputs_for_generation(input_ids, *args, **kwargs)
+ encoder_frame_idxs = model_inputs.pop("encoder_frame_idxs").to(
+ model_inputs["encoder_outputs"].pooler_output.device
+ )
+
+ pooler_output = model_inputs["encoder_outputs"].pooler_output
+ batch_size, max_encoder_len = pooler_output.shape[0], pooler_output.shape[1]
+ encoder_frame_idxs = encoder_frame_idxs.clamp(max=max_encoder_len - 1)
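+        # Gather the single encoder frame each batch element currently points at;
+        # the `None` index keeps a singleton time dim, giving (batch, 1, hidden).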
+ model_inputs["encoder_outputs"] = ParakeetEncoderModelOutput(
+ pooler_output=pooler_output[torch.arange(batch_size), encoder_frame_idxs, None],
+ )
+
+ return model_inputs
+
+ def generate(self, inputs=None, generation_config=None, **kwargs):
+ # TODO @eustlb: this is temporary — we're going to modularize generate to allow doing this cleanly.
+ self._step_durations = []
+ self._encoder_finished = None
+
+ outputs = super().generate(inputs=inputs, generation_config=generation_config, **kwargs)
+ durations = torch.stack(self._step_durations, dim=1) # (batch, steps)
+ # Prepend a zero duration for the decoder_start_token_id that super().generate() prepends to sequences
+ durations = torch.cat(
+ [torch.zeros(durations.shape[0], 1, dtype=durations.dtype, device=durations.device), durations], dim=1
+ )
+ del self._step_durations, self._encoder_finished
+
+ return ParakeetTDTGenerateOutput(
+ sequences=outputs.sequences if isinstance(outputs, ModelOutput) else outputs,
+ durations=durations,
+ )
diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py
index 501a573f8494..4672dcab0cb2 100644
--- a/src/transformers/models/parakeet/modeling_parakeet.py
+++ b/src/transformers/models/parakeet/modeling_parakeet.py
@@ -27,35 +27,56 @@
from ... import initialization as init
from ...activations import ACT2FN
+from ...generation import CompileConfig, GenerationMixin, GenerationMode
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import (
+ ModelOutput,
+ TransformersKwargs,
+ auto_docstring,
+ can_return_tuple,
+ is_torchdynamo_compiling,
+ logging,
+)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
-from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
+from ..auto import AutoModel
+from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig
+from .generation_parakeet import ParakeetTDTGenerationMixin
+
+
+logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
custom_intro="""
- Extends [`~modeling_outputs.BaseModelOutput`] to include the output attention mask since sequence length is not preserved in the model's forward.
+ Extends [`~modeling_outputs.BaseModelOutputWithPooling`] to include the output attention mask since sequence length
+ is not preserved in the model's forward.
"""
)
-class ParakeetEncoderModelOutput(BaseModelOutput):
+class ParakeetEncoderModelOutput(BaseModelOutputWithPooling):
attention_mask: torch.Tensor | None = None
class ParakeetEncoderRelPositionalEncoding(nn.Module):
- """Relative positional encoding for Parakeet."""
-
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: ParakeetEncoderConfig, device=None):
super().__init__()
self.max_position_embeddings = config.max_position_embeddings
+ self.config = config
+ inv_freq = self.compute_default_relative_positional_parameters(config, device=device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ @staticmethod
+ def compute_default_relative_positional_parameters(
+ config: ParakeetEncoderConfig | None = None,
+ device=None,
+ ) -> torch.Tensor:
base = 10000.0
inv_freq = 1.0 / (
base
@@ -64,18 +85,11 @@ def __init__(self, config: ParakeetEncoderConfig, device=None):
/ config.hidden_size
)
)
-
- self.register_buffer("inv_freq", inv_freq, persistent=False)
+ return inv_freq
@torch.no_grad()
def forward(self, hidden_states: torch.Tensor):
seq_length = hidden_states.shape[1]
- if seq_length > self.max_position_embeddings:
- raise ValueError(
- f"Sequence Length: {seq_length} has to be less or equal than "
- f"config.max_position_embeddings {self.max_position_embeddings}."
- )
-
position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
@@ -495,25 +509,17 @@ class ParakeetPreTrainedModel(PreTrainedModel):
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
-
- if hasattr(self.config, "initializer_range"):
- std = self.config.initializer_range
- else:
- # 0.02 is the standard default value across the library
- std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
+ std = getattr(self.config, "initializer_range", 0.02)
if isinstance(module, ParakeetEncoderAttention):
- # Initialize positional bias parameters
init.normal_(module.bias_u, mean=0.0, std=std)
init.normal_(module.bias_v, mean=0.0, std=std)
elif isinstance(module, ParakeetEncoderRelPositionalEncoding):
- inv_freq = 1.0 / (
- 10000.0 ** (torch.arange(0, self.config.hidden_size, 2, dtype=torch.int64) / self.config.hidden_size)
- )
- init.copy_(module.inv_freq, inv_freq)
+ buffer_value = module.compute_default_relative_positional_parameters(module.config)
+ init.copy_(module.inv_freq, buffer_value)
def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
- encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config
+ encoder_config = getattr(self.config, "encoder_config", self.config)
kernel_size = encoder_config.subsampling_conv_kernel_size
stride = encoder_config.subsampling_conv_stride
@@ -613,6 +619,7 @@ def forward(
position_embeddings, p=self.dropout_positions, training=self.training
)
+ output_mask = None
if attention_mask is not None:
output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
attention_mask = output_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
@@ -642,9 +649,9 @@ def forward(
@dataclass
-class ParakeetGenerateOutput(ModelOutput):
+class ParakeetCTCGenerateOutput(ModelOutput):
"""
- Outputs of Parakeet models.
+ Outputs of Parakeet CTC model generation.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -668,17 +675,30 @@ class ParakeetGenerateOutput(ModelOutput):
hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
+@dataclass
+class ParakeetGenerateOutput(ParakeetCTCGenerateOutput):
+ """
+ Deprecated alias for ParakeetCTCGenerateOutput. Use ParakeetCTCGenerateOutput instead.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+        logger.warning_once(
+            "`ParakeetGenerateOutput` is deprecated and will be removed in version 5.5.0; please use `ParakeetCTCGenerateOutput` instead.",
+        )
+
+
@auto_docstring(
custom_intro="""
Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
"""
)
-class ParakeetForCTC(ParakeetPreTrainedModel):
+class ParakeetForCTC(ParakeetPreTrainedModel, GenerationMixin):
config: ParakeetCTCConfig
def __init__(self, config: ParakeetCTCConfig):
super().__init__(config)
- self.encoder = ParakeetEncoder(config.encoder_config)
+ self.encoder = AutoModel.from_config(config.encoder_config)
# Conv rather than linear to be consistent with NeMO decoding layer
self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
@@ -713,6 +733,8 @@ def forward(
>>> print(outputs.loss)
```"""
+ if labels is not None:
+ kwargs.setdefault("output_attention_mask", True)
encoder_outputs = self.encoder(
input_features=input_features,
attention_mask=attention_mask,
@@ -724,14 +746,9 @@ def forward(
loss = None
if labels is not None:
- # retrieve loss input_lengths from attention_mask
- attention_mask = (
- attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
- )
- input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
+ encoder_lengths = encoder_outputs.attention_mask.sum(-1)
- # assuming that padded tokens are filled with -100
- # when not being attended to
+ # assuming that padded tokens are filled with pad_token_id when not being attended to
labels_mask = labels != self.config.pad_token_id
target_lengths = labels_mask.sum(-1)
flattened_targets = labels.masked_select(labels_mask)
@@ -743,7 +760,7 @@ def forward(
loss = nn.functional.ctc_loss(
log_probs,
flattened_targets,
- input_lengths,
+ encoder_lengths,
target_lengths,
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
@@ -763,9 +780,13 @@ def generate(
input_features: torch.Tensor,
attention_mask: torch.Tensor | None = None,
return_dict_in_generate: bool = False,
+ compile_config: CompileConfig | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> ParakeetGenerateOutput | torch.LongTensor:
+ ) -> ParakeetCTCGenerateOutput | torch.LongTensor:
r"""
+ compile_config ([`~generation.CompileConfig`], *optional*):
+ If provided, `torch.compile` will be applied to the forward calls in the decoding loop.
+
Example:
```python
@@ -786,8 +807,10 @@ def generate(
>>> print(transcription)
```
"""
+ model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__
+
kwargs["return_dict"] = True
- outputs: CausalLMOutput = self.forward(
+ outputs: CausalLMOutput = model_forward(
input_features=input_features,
attention_mask=attention_mask,
**kwargs,
@@ -802,7 +825,7 @@ def generate(
sequences[~attention_mask] = self.config.pad_token_id
if return_dict_in_generate:
- return ParakeetGenerateOutput(
+ return ParakeetCTCGenerateOutput(
sequences=sequences,
logits=outputs.logits,
attentions=outputs.attentions,
@@ -812,4 +835,272 @@ def generate(
return sequences
-__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]
+class ParakeetTDTDecoderCache:
+ def __init__(self):
+ self.cache: torch.Tensor | None = None
+ self.hidden_state: torch.Tensor | None = None
+ self.cell_state: torch.Tensor | None = None
+ self.is_initialized: bool = False
+
+ def lazy_initialization(self, hidden_states, lstm_module):
+ self.cache = torch.zeros(
+ hidden_states.shape[0], 1, lstm_module.hidden_size, device=hidden_states.device, dtype=hidden_states.dtype
+ )
+ self.hidden_state = torch.zeros(
+ lstm_module.num_layers,
+ hidden_states.shape[0],
+ lstm_module.hidden_size,
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+ self.cell_state = torch.zeros(
+ lstm_module.num_layers,
+ hidden_states.shape[0],
+ lstm_module.hidden_size,
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+
+ if not is_torchdynamo_compiling():
+ torch._dynamo.mark_static_address(self.cache)
+ torch._dynamo.mark_static_address(self.hidden_state)
+ torch._dynamo.mark_static_address(self.cell_state)
+
+ self.is_initialized = True
+
+ def update(
+ self,
+ decoder_output,
+ hidden_state,
+ cell_state,
+ lstm_module=None,
+ mask=None,
+ ):
+ if not self.is_initialized and lstm_module is not None:
+ self.lazy_initialization(decoder_output, lstm_module)
+ elif not self.is_initialized:
+ raise ValueError(
+ "ParakeetTDTDecoderCache is not initialized. Make sure to provide lstm_module to the update method."
+ )
+
+ if mask is None:
+ self.hidden_state.copy_(hidden_state)
+ self.cell_state.copy_(cell_state)
+ self.cache.copy_(decoder_output)
+ else:
+ # Mask to update specific batch elements
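+            # (during generation, mask = ~blank_mask: only batch elements whose
+            # last token was non-blank get fresh decoder state; blank steps keep
+            # the cached state)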
+ mask = mask.to(decoder_output.device)
+ batch_size = decoder_output.shape[0]
+ mask_h = mask.view(1, batch_size, 1)
+ mask_d = mask.view(batch_size, 1, 1)
+ self.cache = torch.where(mask_d, decoder_output, self.cache)
+ self.hidden_state = torch.where(mask_h, hidden_state, self.hidden_state)
+ self.cell_state = torch.where(mask_h, cell_state, self.cell_state)
+
+
+class ParakeetTDTDecoder(nn.Module):
+ """LSTM-based prediction network for TDT."""
+
+ def __init__(self, config: ParakeetTDTConfig):
+ super().__init__()
+ self.blank_token_id = config.blank_token_id
+ self.embedding = nn.Embedding(config.vocab_size, config.decoder_hidden_size)
+ self.lstm = nn.LSTM(
+ input_size=config.decoder_hidden_size,
+ hidden_size=config.decoder_hidden_size,
+ num_layers=config.num_decoder_layers,
+ batch_first=True,
+ )
+ self.decoder_projector = nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ cache: ParakeetTDTDecoderCache | None = None,
+ ) -> torch.Tensor:
+ # All-blank fast path
+ if cache is not None and cache.is_initialized:
+ blank_mask = input_ids[:, -1] == self.blank_token_id
+ if blank_mask.all():
+ return cache.cache
+
+ hidden_cell_states = (
+ (cache.hidden_state, cache.cell_state) if cache is not None and cache.is_initialized else None
+ )
+ embeddings = self.embedding(input_ids)
+ lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states)
+ decoder_output = self.decoder_projector(lstm_output)
+
+ if cache is not None:
+ mask = ~blank_mask if cache.is_initialized else None
+ cache.update(decoder_output, hidden_state, cell_state, lstm_module=self.lstm, mask=mask)
+ return cache.cache
+
+ return decoder_output
+
+
+class ParakeetTDTJointNetwork(nn.Module):
+ """Joint network that combines encoder and decoder outputs to predict tokens and durations."""
+
+ def __init__(self, config: ParakeetTDTConfig):
+ super().__init__()
+ self.activation = ACT2FN[config.hidden_act]
+ self.head = nn.Linear(config.decoder_hidden_size, config.vocab_size + len(config.durations))
+ self.vocab_size = config.vocab_size
+
+ def forward(
+ self,
+ decoder_hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+ joint_output = self.activation(encoder_hidden_states + decoder_hidden_states)
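+        # The head output packs [token logits | duration logits] along the last
+        # dim; callers split it at `vocab_size` (see ParakeetForTDT.forward and
+        # the generation mixin).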
+ return self.head(joint_output)
+
+
+@dataclass
+class ParakeetTDTOutput(BaseModelOutputWithPooling):
+ """
+ Output of the Parakeet TDT forward pass.
+
+ Args:
+ loss (`torch.FloatTensor`, *optional*):
+ TDT loss, returned when `labels` are provided.
+ logits (`torch.FloatTensor`):
+            Joint token and duration logits. Shape is `(batch, T, U+1, vocab+durations)` for training
+            or `(batch, 1, vocab+durations)` for single-step inference (the singleton decoder-step dim is squeezed).
+ decoder_cache (`ParakeetTDTDecoderCache`, *optional*):
+ Decoder LSTM cache containing hidden state, cell state, and last output.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ decoder_cache: ParakeetTDTDecoderCache | None = None
+
+
+@auto_docstring(
+ custom_intro="""
+ Parakeet Encoder with a TDT (Token Duration Transducer) head.
+ """
+)
+class ParakeetForTDT(ParakeetPreTrainedModel, ParakeetTDTGenerationMixin):
+ config: ParakeetTDTConfig
+ _no_split_modules = ["ParakeetTDTDecoder"]
+ _supported_generation_modes = [GenerationMode.GREEDY_SEARCH]
+
+ def __init__(self, config: ParakeetTDTConfig):
+ super().__init__(config)
+ self.encoder = AutoModel.from_config(config.encoder_config)
+ self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size)
+ self.decoder = ParakeetTDTDecoder(config)
+ self.joint = ParakeetTDTJointNetwork(config)
+ self.max_symbols_per_step = config.max_symbols_per_step # used in generation
+
+ self.post_init()
+
+ @can_return_tuple
+ def get_audio_features(
+ self,
+ input_features: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ParakeetEncoderModelOutput:
+ encoder_outputs = self.encoder(
+ input_features=input_features,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ encoder_outputs.pooler_output = self.encoder_projector(encoder_outputs.last_hidden_state)
+ return encoder_outputs
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_features: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ decoder_input_ids: torch.LongTensor | None = None,
+ decoder_cache: ParakeetTDTDecoderCache | None = None,
+ use_decoder_cache: bool | None = None,
+ encoder_outputs: ParakeetEncoderModelOutput | tuple[torch.FloatTensor] | None = None,
+ labels: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ParakeetTDTOutput:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
+ Decoder input token ids for single-step inference.
+ decoder_cache (`ParakeetTDTDecoderCache`, *optional*):
+ Decoder LSTM cache. When provided and initialized, the cached `decoder_output` is reused
+ (e.g. during blank-skipping) instead of running the decoder. When `input_ids` is provided,
+ the decoder runs and the cache is updated in-place.
+ use_decoder_cache (`bool`, *optional*):
+ Whether to use a decoder cache. When `True` and `decoder_cache` is `None`, a new cache
+ is created automatically during the forward pass.
+ encoder_outputs (`tuple(torch.FloatTensor)`, *optional*):
+ Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask).
+ Can be a tuple or `ParakeetEncoderModelOutput`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, ParakeetForTDT
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = "nvidia/parakeet-tdt-0.6b-v3"
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> model = ParakeetForTDT.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+ >>> inputs = processor(ds[0]["audio"]["array"])
+ >>> outputs = model(**inputs)
+ ```
+ """
+ if encoder_outputs is None:
+ encoder_outputs = self.get_audio_features(
+ input_features=input_features,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ elif not isinstance(encoder_outputs, ParakeetEncoderModelOutput):
+ encoder_outputs = ParakeetEncoderModelOutput(
+ last_hidden_state=encoder_outputs[0] if len(encoder_outputs) > 0 else None,
+ pooler_output=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ hidden_states=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
+ attention_mask=encoder_outputs[4] if len(encoder_outputs) > 4 else None,
+ )
+
+ if use_decoder_cache and decoder_cache is None:
+ decoder_cache = ParakeetTDTDecoderCache()
+
+ decoder_hidden_states = self.decoder(decoder_input_ids, cache=decoder_cache)
+ logits = self.joint(
+ encoder_hidden_states=encoder_outputs.pooler_output[:, :, None, :],
+ decoder_hidden_states=decoder_hidden_states[:, None, :, :],
+ ).squeeze(2)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ token_logits=logits[..., : self.config.vocab_size],
+ duration_logits=logits[..., self.config.vocab_size :],
+ labels=labels,
+ logit_lengths=encoder_outputs.attention_mask.sum(-1),
+ label_lengths=(labels != self.config.pad_token_id).sum(-1),
+ blank_token_id=self.config.blank_token_id,
+ durations=self.config.durations,
+ )
+
+ return ParakeetTDTOutput(
+ loss=loss,
+ logits=logits,
+ last_hidden_state=encoder_outputs.last_hidden_state,
+ pooler_output=encoder_outputs.pooler_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ decoder_cache=decoder_cache,
+ )
+
+
+__all__ = ["ParakeetForCTC", "ParakeetForTDT", "ParakeetEncoder", "ParakeetPreTrainedModel"]
diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py
index b53d61a0c22d..22fce9362648 100644
--- a/src/transformers/models/parakeet/modular_parakeet.py
+++ b/src/transformers/models/parakeet/modular_parakeet.py
@@ -22,36 +22,57 @@
from ... import initialization as init
from ...activations import ACT2FN
+from ...generation import CompileConfig, GenerationMixin, GenerationMode
from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import (
+ ModelOutput,
+ TransformersKwargs,
+ auto_docstring,
+ can_return_tuple,
+ is_torchdynamo_compiling,
+ logging,
+)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
+from ..auto import AutoModel
from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule
from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
-from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
+from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig, ParakeetTDTConfig
+from .generation_parakeet import ParakeetTDTGenerationMixin
+
+
+logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
custom_intro="""
- Extends [`~modeling_outputs.BaseModelOutput`] to include the output attention mask since sequence length is not preserved in the model's forward.
+ Extends [`~modeling_outputs.BaseModelOutputWithPooling`] to include the output attention mask since sequence length
+ is not preserved in the model's forward.
"""
)
-class ParakeetEncoderModelOutput(BaseModelOutput):
+class ParakeetEncoderModelOutput(BaseModelOutputWithPooling):
attention_mask: torch.Tensor | None = None
class ParakeetEncoderRelPositionalEncoding(nn.Module):
- """Relative positional encoding for Parakeet."""
-
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: ParakeetEncoderConfig, device=None):
super().__init__()
self.max_position_embeddings = config.max_position_embeddings
+ self.config = config
+ inv_freq = self.compute_default_relative_positional_parameters(config, device=device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ @staticmethod
+ def compute_default_relative_positional_parameters(
+ config: ParakeetEncoderConfig | None = None,
+ device=None,
+ ) -> torch.Tensor:
base = 10000.0
inv_freq = 1.0 / (
base
@@ -60,18 +81,11 @@ def __init__(self, config: ParakeetEncoderConfig, device=None):
/ config.hidden_size
)
)
-
- self.register_buffer("inv_freq", inv_freq, persistent=False)
+ return inv_freq
@torch.no_grad()
def forward(self, hidden_states: torch.Tensor):
seq_length = hidden_states.shape[1]
- if seq_length > self.max_position_embeddings:
- raise ValueError(
- f"Sequence Length: {seq_length} has to be less or equal than "
- f"config.max_position_embeddings {self.max_position_embeddings}."
- )
-
position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
@@ -334,25 +348,17 @@ class ParakeetPreTrainedModel(PreTrainedModel):
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
-
- if hasattr(self.config, "initializer_range"):
- std = self.config.initializer_range
- else:
- # 0.02 is the standard default value across the library
- std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
+ std = getattr(self.config, "initializer_range", 0.02)
if isinstance(module, ParakeetEncoderAttention):
- # Initialize positional bias parameters
init.normal_(module.bias_u, mean=0.0, std=std)
init.normal_(module.bias_v, mean=0.0, std=std)
elif isinstance(module, ParakeetEncoderRelPositionalEncoding):
- inv_freq = 1.0 / (
- 10000.0 ** (torch.arange(0, self.config.hidden_size, 2, dtype=torch.int64) / self.config.hidden_size)
- )
- init.copy_(module.inv_freq, inv_freq)
+ buffer_value = module.compute_default_relative_positional_parameters(module.config)
+ init.copy_(module.inv_freq, buffer_value)
def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
- encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config
+ encoder_config = getattr(self.config, "encoder_config", self.config)
kernel_size = encoder_config.subsampling_conv_kernel_size
stride = encoder_config.subsampling_conv_stride
@@ -452,6 +458,7 @@ def forward(
position_embeddings, p=self.dropout_positions, training=self.training
)
+ output_mask = None
if attention_mask is not None:
output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
attention_mask = output_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
@@ -481,9 +488,9 @@ def forward(
@dataclass
-class ParakeetGenerateOutput(ModelOutput):
+class ParakeetCTCGenerateOutput(ModelOutput):
"""
- Outputs of Parakeet models.
+ Outputs of Parakeet CTC model generation.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -507,17 +514,30 @@ class ParakeetGenerateOutput(ModelOutput):
hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
+@dataclass
+class ParakeetGenerateOutput(ParakeetCTCGenerateOutput):
+ """
+ Deprecated alias for ParakeetCTCGenerateOutput. Use ParakeetCTCGenerateOutput instead.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+        logger.warning_once(
+            "`ParakeetGenerateOutput` is deprecated and will be removed in version 5.5.0; please use `ParakeetCTCGenerateOutput` instead.",
+        )
+
+
@auto_docstring(
custom_intro="""
Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
"""
)
-class ParakeetForCTC(ParakeetPreTrainedModel):
+class ParakeetForCTC(ParakeetPreTrainedModel, GenerationMixin):
config: ParakeetCTCConfig
def __init__(self, config: ParakeetCTCConfig):
super().__init__(config)
- self.encoder = ParakeetEncoder(config.encoder_config)
+ self.encoder = AutoModel.from_config(config.encoder_config)
# Conv rather than linear to be consistent with NeMO decoding layer
self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
@@ -552,6 +572,8 @@ def forward(
>>> print(outputs.loss)
```"""
+ if labels is not None:
+ kwargs.setdefault("output_attention_mask", True)
encoder_outputs = self.encoder(
input_features=input_features,
attention_mask=attention_mask,
@@ -563,14 +585,9 @@ def forward(
loss = None
if labels is not None:
- # retrieve loss input_lengths from attention_mask
- attention_mask = (
- attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
- )
- input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
+ encoder_lengths = encoder_outputs.attention_mask.sum(-1)
- # assuming that padded tokens are filled with -100
- # when not being attended to
+ # assuming that padded tokens are filled with pad_token_id when not being attended to
labels_mask = labels != self.config.pad_token_id
target_lengths = labels_mask.sum(-1)
flattened_targets = labels.masked_select(labels_mask)
@@ -582,7 +599,7 @@ def forward(
loss = nn.functional.ctc_loss(
log_probs,
flattened_targets,
- input_lengths,
+ encoder_lengths,
target_lengths,
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
@@ -602,9 +619,13 @@ def generate(
input_features: torch.Tensor,
attention_mask: torch.Tensor | None = None,
return_dict_in_generate: bool = False,
+ compile_config: CompileConfig | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> ParakeetGenerateOutput | torch.LongTensor:
+ ) -> ParakeetCTCGenerateOutput | torch.LongTensor:
r"""
+ compile_config ([`~generation.CompileConfig`], *optional*):
+ If provided, `torch.compile` will be applied to the forward calls in the decoding loop.
+
Example:
```python
@@ -625,8 +646,10 @@ def generate(
>>> print(transcription)
```
"""
+ model_forward = self.get_compiled_call(compile_config) if compile_config is not None else self.__call__
+
kwargs["return_dict"] = True
- outputs: CausalLMOutput = self.forward(
+ outputs: CausalLMOutput = model_forward(
input_features=input_features,
attention_mask=attention_mask,
**kwargs,
@@ -641,7 +664,7 @@ def generate(
sequences[~attention_mask] = self.config.pad_token_id
if return_dict_in_generate:
- return ParakeetGenerateOutput(
+ return ParakeetCTCGenerateOutput(
sequences=sequences,
logits=outputs.logits,
attentions=outputs.attentions,
@@ -651,4 +674,272 @@ def generate(
return sequences
-__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]
+class ParakeetTDTDecoderCache:
+ def __init__(self):
+ self.cache: torch.Tensor | None = None
+ self.hidden_state: torch.Tensor | None = None
+ self.cell_state: torch.Tensor | None = None
+ self.is_initialized: bool = False
+
+ def lazy_initialization(self, hidden_states, lstm_module):
+ self.cache = torch.zeros(
+ hidden_states.shape[0], 1, lstm_module.hidden_size, device=hidden_states.device, dtype=hidden_states.dtype
+ )
+ self.hidden_state = torch.zeros(
+ lstm_module.num_layers,
+ hidden_states.shape[0],
+ lstm_module.hidden_size,
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+ self.cell_state = torch.zeros(
+ lstm_module.num_layers,
+ hidden_states.shape[0],
+ lstm_module.hidden_size,
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+
+ if not is_torchdynamo_compiling():
+ torch._dynamo.mark_static_address(self.cache)
+ torch._dynamo.mark_static_address(self.hidden_state)
+ torch._dynamo.mark_static_address(self.cell_state)
+
+ self.is_initialized = True
+
+ def update(
+ self,
+ decoder_output,
+ hidden_state,
+ cell_state,
+ lstm_module=None,
+ mask=None,
+ ):
+ if not self.is_initialized and lstm_module is not None:
+ self.lazy_initialization(decoder_output, lstm_module)
+ elif not self.is_initialized:
+ raise ValueError(
+ "ParakeetTDTDecoderCache is not initialized. Make sure to provide lstm_module to the update method."
+ )
+
+ if mask is None:
+ self.hidden_state.copy_(hidden_state)
+ self.cell_state.copy_(cell_state)
+ self.cache.copy_(decoder_output)
+ else:
+ # Mask to update specific batch elements
+ mask = mask.to(decoder_output.device)
+ batch_size = decoder_output.shape[0]
+ mask_h = mask.view(1, batch_size, 1)
+ mask_d = mask.view(batch_size, 1, 1)
+ self.cache = torch.where(mask_d, decoder_output, self.cache)
+ self.hidden_state = torch.where(mask_h, hidden_state, self.hidden_state)
+ self.cell_state = torch.where(mask_h, cell_state, self.cell_state)
+
+
+class ParakeetTDTDecoder(nn.Module):
+ """LSTM-based prediction network for TDT."""
+
+ def __init__(self, config: ParakeetTDTConfig):
+ super().__init__()
+ self.blank_token_id = config.blank_token_id
+ self.embedding = nn.Embedding(config.vocab_size, config.decoder_hidden_size)
+ self.lstm = nn.LSTM(
+ input_size=config.decoder_hidden_size,
+ hidden_size=config.decoder_hidden_size,
+ num_layers=config.num_decoder_layers,
+ batch_first=True,
+ )
+ self.decoder_projector = nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ cache: ParakeetTDTDecoderCache | None = None,
+ ) -> torch.Tensor:
+ # All-blank fast path
+ if cache is not None and cache.is_initialized:
+ blank_mask = input_ids[:, -1] == self.blank_token_id
+ if blank_mask.all():
+ return cache.cache
+
+ hidden_cell_states = (
+ (cache.hidden_state, cache.cell_state) if cache is not None and cache.is_initialized else None
+ )
+ embeddings = self.embedding(input_ids)
+ lstm_output, (hidden_state, cell_state) = self.lstm(embeddings, hidden_cell_states)
+ decoder_output = self.decoder_projector(lstm_output)
+
+ if cache is not None:
+ mask = ~blank_mask if cache.is_initialized else None
+ cache.update(decoder_output, hidden_state, cell_state, lstm_module=self.lstm, mask=mask)
+ return cache.cache
+
+ return decoder_output
+
+
+class ParakeetTDTJointNetwork(nn.Module):
+ """Joint network that combines encoder and decoder outputs to predict tokens and durations."""
+
+ def __init__(self, config: ParakeetTDTConfig):
+ super().__init__()
+ self.activation = ACT2FN[config.hidden_act]
+ self.head = nn.Linear(config.decoder_hidden_size, config.vocab_size + len(config.durations))
+ self.vocab_size = config.vocab_size
+
+ def forward(
+ self,
+ decoder_hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+ joint_output = self.activation(encoder_hidden_states + decoder_hidden_states)
+ return self.head(joint_output)
+
+
+@dataclass
+class ParakeetTDTOutput(BaseModelOutputWithPooling):
+ """
+ Output of the Parakeet TDT forward pass.
+
+ Args:
+ loss (`torch.FloatTensor`, *optional*):
+ TDT loss, returned when `labels` are provided.
+ logits (`torch.FloatTensor`):
+            Joint token and duration logits. Shape is `(batch, T, U+1, vocab+durations)` for training
+            or `(batch, 1, vocab+durations)` for single-step inference (the singleton decoder-step dim is squeezed).
+ decoder_cache (`ParakeetTDTDecoderCache`, *optional*):
+ Decoder LSTM cache containing hidden state, cell state, and last output.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ decoder_cache: ParakeetTDTDecoderCache | None = None
+
+
+@auto_docstring(
+ custom_intro="""
+ Parakeet Encoder with a TDT (Token Duration Transducer) head.
+ """
+)
+class ParakeetForTDT(ParakeetPreTrainedModel, ParakeetTDTGenerationMixin):
+ config: ParakeetTDTConfig
+ _no_split_modules = ["ParakeetTDTDecoder"]
+ _supported_generation_modes = [GenerationMode.GREEDY_SEARCH]
+
+ def __init__(self, config: ParakeetTDTConfig):
+ super().__init__(config)
+ self.encoder = AutoModel.from_config(config.encoder_config)
+ self.encoder_projector = nn.Linear(config.encoder_config.hidden_size, config.decoder_hidden_size)
+ self.decoder = ParakeetTDTDecoder(config)
+ self.joint = ParakeetTDTJointNetwork(config)
+ self.max_symbols_per_step = config.max_symbols_per_step # used in generation
+
+ self.post_init()
+
+ @can_return_tuple
+ def get_audio_features(
+ self,
+ input_features: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ParakeetEncoderModelOutput:
+ encoder_outputs = self.encoder(
+ input_features=input_features,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ encoder_outputs.pooler_output = self.encoder_projector(encoder_outputs.last_hidden_state)
+ return encoder_outputs
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_features: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ decoder_input_ids: torch.LongTensor | None = None,
+ decoder_cache: ParakeetTDTDecoderCache | None = None,
+ use_decoder_cache: bool | None = None,
+ encoder_outputs: ParakeetEncoderModelOutput | tuple[torch.FloatTensor] | None = None,
+ labels: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ParakeetTDTOutput:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
+ Decoder input token ids for single-step inference.
+ decoder_cache (`ParakeetTDTDecoderCache`, *optional*):
+ Decoder LSTM cache. When provided and initialized, the cached `decoder_output` is reused
+ (e.g. during blank-skipping) instead of running the decoder. When `input_ids` is provided,
+ the decoder runs and the cache is updated in-place.
+ use_decoder_cache (`bool`, *optional*):
+ Whether to use a decoder cache. When `True` and `decoder_cache` is `None`, a new cache
+ is created automatically during the forward pass.
+ encoder_outputs (`tuple(torch.FloatTensor)`, *optional*):
+ Pre-computed encoder outputs (last_hidden_state, pooler_output, hidden_states, attentions, attention_mask).
+ Can be a tuple or `ParakeetEncoderModelOutput`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, ParakeetForTDT
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = "nvidia/parakeet-tdt-0.6b-v3"
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> model = ParakeetForTDT.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+ >>> inputs = processor(ds[0]["audio"]["array"])
+ >>> outputs = model(**inputs)
+ ```
+ """
+ if encoder_outputs is None:
+ encoder_outputs = self.get_audio_features(
+ input_features=input_features,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ elif not isinstance(encoder_outputs, ParakeetEncoderModelOutput):
+ encoder_outputs = ParakeetEncoderModelOutput(
+ last_hidden_state=encoder_outputs[0] if len(encoder_outputs) > 0 else None,
+ pooler_output=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ hidden_states=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
+ attention_mask=encoder_outputs[4] if len(encoder_outputs) > 4 else None,
+ )
+
+ if use_decoder_cache and decoder_cache is None:
+ decoder_cache = ParakeetTDTDecoderCache()
+
+ decoder_hidden_states = self.decoder(decoder_input_ids, cache=decoder_cache)
+ logits = self.joint(
+ encoder_hidden_states=encoder_outputs.pooler_output[:, :, None, :],
+ decoder_hidden_states=decoder_hidden_states[:, None, :, :],
+ ).squeeze(2)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ token_logits=logits[..., : self.config.vocab_size],
+ duration_logits=logits[..., self.config.vocab_size :],
+ labels=labels,
+ logit_lengths=encoder_outputs.attention_mask.sum(-1),
+ label_lengths=(labels != self.config.pad_token_id).sum(-1),
+ blank_token_id=self.config.blank_token_id,
+ durations=self.config.durations,
+ )
+
+ return ParakeetTDTOutput(
+ loss=loss,
+ logits=logits,
+ last_hidden_state=encoder_outputs.last_hidden_state,
+ pooler_output=encoder_outputs.pooler_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ decoder_cache=decoder_cache,
+ )
+
+
+__all__ = ["ParakeetForCTC", "ParakeetForTDT", "ParakeetEncoder", "ParakeetPreTrainedModel"]
diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py
index 69734fb055af..85b63f396765 100644
--- a/src/transformers/models/parakeet/processing_parakeet.py
+++ b/src/transformers/models/parakeet/processing_parakeet.py
@@ -27,6 +27,7 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False):
"sampling_rate": 16000,
"padding": "longest",
"return_attention_mask": True,
+ "subsampling_factor": 8,
},
"text_kwargs": {
"padding": True,
@@ -39,7 +40,13 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False):
@auto_docstring
class ParakeetProcessor(ProcessorMixin):
- def __init__(self, feature_extractor, tokenizer):
+    def __init__(self, feature_extractor, tokenizer, blank_token="<blank>"):
+ """
+ blank_token (`str`, *optional*, defaults to `""`):
+ Blank token for TDT decoding.
+ """
+ self.blank_token = blank_token
+ self.blank_token_id = tokenizer.convert_tokens_to_ids(blank_token)
super().__init__(feature_extractor, tokenizer)
@auto_docstring
@@ -83,12 +90,78 @@ def __call__(
return inputs
else:
inputs["labels"] = encodings["input_ids"]
+            # Prepend the blank token to the labels to form decoder_input_ids:
+            # the TDT decoder expects [blank, label_0, ..., label_{U-1}] as input,
+            # so each prediction is conditioned on the previous target token.
+ if isinstance(text, str):
+ text = [text]
+ decoder_text = [self.blank_token + t for t in text]
+ decoder_encodings = self.tokenizer(decoder_text, **output_kwargs["text_kwargs"])
+ inputs["decoder_input_ids"] = decoder_encodings["input_ids"]
return inputs
@property
def model_input_names(self):
feature_extractor_input_names = self.feature_extractor.model_input_names
- return feature_extractor_input_names + ["labels"]
+ return feature_extractor_input_names + ["labels", "decoder_input_ids"]
+
+ def decode(self, *args, durations=None, **kwargs):
+ """
+        Forward arguments to [`~PreTrainedTokenizer.decode`]. When `durations` is provided (TDT models),
+        additionally post-process per-token timestamps as in the NeMo library.
+ """
+ decoded = self.tokenizer.decode(*args, **kwargs)
+
+ if durations is not None:
+ token_ids = args[0]
+            # Exclusive cumulative sum: the start frame index of each decoding step.
+ timestamps = durations.cumsum(dim=-1) - durations
+
+ output_kwargs = self._merge_kwargs(
+ ParakeetProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ )
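+            # Seconds of audio per encoder frame; e.g. with hop_length=160 at the 16 kHz and
+            # subsampling_factor=8 defaults above, this is 160 / 16000 * 8 = 0.08 s.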
+ frame_rate = (
+ self.feature_extractor.hop_length
+ / self.feature_extractor.sampling_rate
+ * output_kwargs["audio_kwargs"]["subsampling_factor"]
+ )
+ proc_timestamps = []
+ for batch_ids, batch_timestamps, batch_durations in zip(token_ids, timestamps, durations):
+ # See `compute_rnnt_timestamps` in NeMo: https://github.com/NVIDIA-NeMo/NeMo/blob/1692a8fb97e1aadc883cfadd2a57c4e8a1b793aa/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L993
+ # Filter padding and blank tokens
+ skip_ids = {self.tokenizer.pad_token_id, self.blank_token_id}
+ non_blank_indices = [i for i, token_id in enumerate(batch_ids) if int(token_id) not in skip_ids]
+ non_blank_ids = [batch_ids[i] for i in non_blank_indices]
+ decoded_tokens = [self.tokenizer.decode([token_id]) for token_id in non_blank_ids]
+ timestamp_dict = [
+ {
+ "token": token_str,
+ "start": int(batch_timestamps[i]),
+ "end": int(batch_timestamps[i] + batch_durations[i]),
+ }
+ for token_str, i in zip(decoded_tokens, non_blank_indices)
+ ]
+ timestamp_dict = self._refine_timestamps_tdt(timestamp_dict)
+
+ # Convert to seconds
+ for offset in timestamp_dict:
+ offset["start"] = offset["start"] * frame_rate
+ offset["end"] = offset["end"] * frame_rate
+ proc_timestamps.append(timestamp_dict)
+
+ return decoded, proc_timestamps
+ return decoded
+
+ def _refine_timestamps_tdt(
+ self, char_offsets, supported_punctuation=["?", "'", "¡", "¿", "-", ":", ",", "%", "/", ".", "!"]
+ ):
+ for i, offset in enumerate(char_offsets):
+            # Collapse punctuation marks to a zero-length span at the end of the previous token
+ if offset["token"] in supported_punctuation and i > 0:
+ offset["start"] = char_offsets[i - 1]["end"]
+ offset["end"] = offset["start"]
+
+ return char_offsets
__all__ = ["ParakeetProcessor"]
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index effc53f378b4..2fcaf8cac3cd 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -110,6 +110,7 @@
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
+ AutoModelForTDT,
AutoModelForTextToSpectrogram,
AutoModelForTextToWaveform,
AutoModelForTokenClassification,
@@ -143,7 +144,7 @@
},
"automatic-speech-recognition": {
"impl": AutomaticSpeechRecognitionPipeline,
- "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
+ "pt": (AutoModelForCTC, AutoModelForTDT, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
"default": {"model": ("facebook/wav2vec2-base-960h", "22aad52")},
"type": "multimodal",
},
diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index 58349d0b10b7..f71f4c4bd62c 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -176,6 +176,8 @@ def __init__(
self.type = "seq2seq_whisper"
elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
self.type = "seq2seq"
+ elif model.config.model_type == "parakeet_tdt":
+ self.type = "tdt"
elif decoder is not None:
self.decoder = decoder
self.type = "ctc_with_lm"
@@ -534,7 +536,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
if stride is not None:
out["stride"] = stride
- else:
+ elif self.type in {"ctc", "ctc_with_lm"}:
inputs = {
self.model.main_input_name: model_inputs.pop(self.model.main_input_name),
"attention_mask": attention_mask,
@@ -555,6 +557,17 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
out["stride"] = rescale_stride([stride], ratio)[0]
else:
out["stride"] = rescale_stride(stride, ratio)
+ elif self.type == "tdt":
+ inputs = {
+ self.model.main_input_name: model_inputs.pop(self.model.main_input_name),
+ }
+ if "attention_mask" in model_inputs:
+ inputs["attention_mask"] = model_inputs.pop("attention_mask")
+            outputs = self.model.generate(**inputs, return_dict_in_generate=True)
+ out = {"tokens": outputs.sequences}
+ else:
+ raise ValueError("Unsupported model type {self.type}.")
+
# Leftover
extra = model_inputs
return {"is_last": is_last, **out, **extra}
diff --git a/tests/fixtures/parakeet/expected_loss_tdt.json b/tests/fixtures/parakeet/expected_loss_tdt.json
new file mode 100644
index 000000000000..aee3c3f16c2b
--- /dev/null
+++ b/tests/fixtures/parakeet/expected_loss_tdt.json
@@ -0,0 +1,5 @@
+{
+ "num_samples": 2,
+ "expected_mean_loss": 0.528089,
+ "comment": "NeMo reference with sigma=0, HF-style mean reduction (per-sample / target_length, then average). Generated with https://gist.github.com/883ea42bf7d8ce2af42f3055627476a7"
+}
diff --git a/tests/fixtures/parakeet/expected_results_batch_tdt.json b/tests/fixtures/parakeet/expected_results_batch_tdt.json
new file mode 100644
index 000000000000..c6a37bad56e8
--- /dev/null
+++ b/tests/fixtures/parakeet/expected_results_batch_tdt.json
@@ -0,0 +1,9 @@
+{
+ "transcriptions": [
+ "mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
+ "Nor is mister Quilter's manner less interesting than his matter.",
+ "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.",
+ "He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and can discover in it but little of Rocky Ithaca.",
+ "Linnell's pictures are a sort of up guards an atom paintings, and Mason's exquisite idols are as national as a jingo poem. mister Burkett Foster's landscapes smile at one much in the same way that mister Carker used to flash his teeth. And mister John Collier gives his sitter a cheerful slap on the back, before he says, like a shampooer in a Turkish bath Next man"
+ ]
+}
diff --git a/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json
new file mode 100644
index 000000000000..f13d5aee8b5f
--- /dev/null
+++ b/tests/fixtures/parakeet/expected_results_batch_tdt_timestamp.json
@@ -0,0 +1,251 @@
+{
+ "transcriptions": [
+ "mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
+ "Nor is mister Quilter's manner less interesting than his matter.",
+ "He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind."
+ ],
+ "start_timestamps": [
+ [
+ 0.24,
+ 0.48,
+ 0.64,
+ 0.88,
+ 1.12,
+ 1.36,
+ 1.44,
+ 1.6,
+ 1.76,
+ 2.0,
+ 2.16,
+ 2.24,
+ 2.4,
+ 2.48,
+ 2.56,
+ 2.72,
+ 2.88,
+ 3.04,
+ 3.12,
+ 3.2800000000000002,
+ 3.44,
+ 3.6,
+ 3.7600000000000002,
+ 3.92,
+ 4.08,
+ 4.24,
+ 4.4,
+ 4.48,
+ 4.72,
+ 4.96,
+ 5.36,
+ 5.6000000000000005
+ ],
+ [
+ 0.32,
+ 0.64,
+ 0.88,
+ 1.04,
+ 1.2,
+ 1.44,
+ 1.68,
+ 1.84,
+ 1.92,
+ 2.0,
+ 2.16,
+ 2.4,
+ 2.56,
+ 2.72,
+ 2.96,
+ 3.12,
+ 3.36,
+ 3.6,
+ 3.92,
+ 4.16,
+ 4.32
+ ],
+ [
+ 0.32,
+ 0.64,
+ 0.72,
+ 0.96,
+ 1.12,
+ 1.36,
+ 1.6,
+ 1.84,
+ 2.08,
+ 2.24,
+ 2.48,
+ 2.64,
+ 2.8000000000000003,
+ 2.88,
+ 3.04,
+ 3.2,
+ 3.44,
+ 3.68,
+ 3.84,
+ 4.08,
+ 4.4,
+ 4.5600000000000005,
+ 4.72,
+ 4.96,
+ 5.12,
+ 5.36,
+ 5.5200000000000005,
+ 5.68,
+ 5.92,
+ 6.16,
+ 6.24,
+ 6.4,
+ 6.5600000000000005,
+ 6.72,
+ 6.96,
+ 7.28,
+ 7.6000000000000005,
+ 7.92,
+ 8.16,
+ 8.32,
+ 8.48,
+ 8.72,
+ 8.88,
+ 8.96,
+ 9.120000000000001,
+ 9.28,
+ 9.44,
+ 9.68,
+ 9.76,
+ 9.92,
+ 10.16,
+ 10.24,
+ 10.4,
+ 10.64,
+ 10.88,
+ 10.96,
+ 11.200000000000001,
+ 11.36,
+ 11.52,
+ 11.84,
+ 12.16
+ ]
+ ],
+ "end_timestamps": [
+ [
+ 0.48,
+ 0.64,
+ 0.88,
+ 1.12,
+ 1.36,
+ 1.44,
+ 1.6,
+ 1.76,
+ 1.92,
+ 2.16,
+ 2.24,
+ 2.4,
+ 2.48,
+ 2.56,
+ 2.64,
+ 2.88,
+ 3.04,
+ 3.12,
+ 3.12,
+ 3.44,
+ 3.6,
+ 3.7600000000000002,
+ 3.92,
+ 4.08,
+ 4.24,
+ 4.4,
+ 4.48,
+ 4.72,
+ 4.96,
+ 5.12,
+ 5.6000000000000005,
+ 5.6000000000000005
+ ],
+ [
+ 0.64,
+ 0.88,
+ 1.04,
+ 1.2,
+ 1.44,
+ 1.68,
+ 1.84,
+ 1.84,
+ 2.0,
+ 2.16,
+ 2.4,
+ 2.56,
+ 2.72,
+ 2.96,
+ 3.12,
+ 3.36,
+ 3.6,
+ 3.92,
+ 4.16,
+ 4.32,
+ 4.32
+ ],
+ [
+ 0.64,
+ 0.72,
+ 0.96,
+ 1.12,
+ 1.36,
+ 1.6,
+ 1.84,
+ 2.08,
+ 2.24,
+ 2.48,
+ 2.64,
+ 2.8000000000000003,
+ 2.88,
+ 3.04,
+ 3.2,
+ 3.44,
+ 3.68,
+ 3.84,
+ 3.84,
+ 4.4,
+ 4.5600000000000005,
+ 4.72,
+ 4.96,
+ 5.12,
+ 5.36,
+ 5.5200000000000005,
+ 5.68,
+ 5.92,
+ 6.16,
+ 6.24,
+ 6.4,
+ 6.5600000000000005,
+ 6.72,
+ 6.96,
+ 7.28,
+ 7.28,
+ 7.92,
+ 8.16,
+ 8.24,
+ 8.48,
+ 8.72,
+ 8.88,
+ 8.96,
+ 9.120000000000001,
+ 9.200000000000001,
+ 9.44,
+ 9.68,
+ 9.76,
+ 9.92,
+ 10.16,
+ 10.24,
+ 10.4,
+ 10.64,
+ 10.88,
+ 10.96,
+ 11.200000000000001,
+ 11.36,
+ 11.52,
+ 11.84,
+ 12.16,
+ 12.16
+ ]
+ ]
+}
diff --git a/tests/fixtures/parakeet/expected_results_single_tdt.json b/tests/fixtures/parakeet/expected_results_single_tdt.json
new file mode 100644
index 000000000000..a757d763b6a3
--- /dev/null
+++ b/tests/fixtures/parakeet/expected_results_single_tdt.json
@@ -0,0 +1,5 @@
+{
+ "transcriptions": [
+ "mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
+ ]
+}
diff --git a/tests/fixtures/parakeet/expected_tdt_loss.json b/tests/fixtures/parakeet/expected_tdt_loss.json
new file mode 100644
index 000000000000..7c3ff498483f
--- /dev/null
+++ b/tests/fixtures/parakeet/expected_tdt_loss.json
@@ -0,0 +1,43 @@
+{
+ "seed": 42,
+ "batch_size": 2,
+ "max_t": 8,
+ "max_u": 4,
+ "vocab_size": 5,
+ "durations": [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4
+ ],
+ "targets": [
+ [
+ 4,
+ 2,
+ 2,
+ 1
+ ],
+ [
+ 0,
+ 4,
+ 2,
+ 4
+ ]
+ ],
+ "logit_lengths": [
+ 8,
+ 7
+ ],
+ "target_lengths": [
+ 4,
+ 3
+ ],
+ "expected_loss_sum": 21.978166580200195,
+ "expected_loss_mean": 3.124553918838501,
+ "expected_loss_none": [
+ 12.923372268676758,
+ 9.054794311523438
+ ],
+ "expected_loss_mean_sigma_0p05": 3.1921849250793457
+}
\ No newline at end of file
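The reductions in this fixture are mutually consistent: "sum" adds the per-sample losses, while "mean" normalizes each per-sample loss by its target length before averaging (the HF-style reduction also noted in expected_loss_tdt.json above). A quick arithmetic check against the stored values:

```py
per_sample = [12.923372268676758, 9.054794311523438]  # "expected_loss_none"
target_lengths = [4, 3]                               # "target_lengths"

print(sum(per_sample))  # 21.978166580200195 == "expected_loss_sum"
mean = sum(l / u for l, u in zip(per_sample, target_lengths)) / len(per_sample)
print(mean)             # 3.124553918838501 == "expected_loss_mean"
```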
diff --git a/tests/models/lasr/test_modeling_lasr.py b/tests/models/lasr/test_modeling_lasr.py
index 36060eecac3b..d212730676f9 100644
--- a/tests/models/lasr/test_modeling_lasr.py
+++ b/tests/models/lasr/test_modeling_lasr.py
@@ -245,6 +245,7 @@ def test_ctc_loss_inference(self):
@require_torch
class LasrForCTCModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (LasrForCTC,) if is_torch_available() else ()
+    all_generative_model_classes = ()  # LasrForCTC has a custom generate method
pipeline_model_mapping = (
{
"feature-extraction": LasrEncoder,
diff --git a/tests/models/parakeet/test_modeling_parakeet.py b/tests/models/parakeet/test_modeling_parakeet.py
index b1de3904bba0..2c6d219797aa 100644
--- a/tests/models/parakeet/test_modeling_parakeet.py
+++ b/tests/models/parakeet/test_modeling_parakeet.py
@@ -16,7 +16,9 @@
import json
import tempfile
import unittest
+from contextlib import nullcontext
from pathlib import Path
+from unittest.mock import patch
from transformers import is_datasets_available, is_torch_available
from transformers.testing_utils import cleanup, require_torch, slow, torch_device
@@ -37,7 +39,87 @@
ParakeetEncoder,
ParakeetEncoderConfig,
ParakeetForCTC,
+ ParakeetForTDT,
+ ParakeetTDTConfig,
)
+ from transformers.loss.loss_tdt import tdt_loss
+
+
+@require_torch
+class TDTLossTest(unittest.TestCase):
+ """Test tdt_loss against reference values generated by NeMo's TDTLossPytorch.
+ reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-generate_tdt_loss_fixtures-py
+ """
+
+ FIXTURE_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_tdt_loss.json"
+
+ @classmethod
+ def setUpClass(cls):
+ with open(cls.FIXTURE_PATH) as f:
+ cls.fixture = json.load(f)
+
+ def _make_inputs(self):
+ torch.manual_seed(self.fixture["seed"])
+ batch_size = self.fixture["batch_size"]
+ max_t = self.fixture["max_t"]
+ max_u = self.fixture["max_u"]
+ vocab_size = self.fixture["vocab_size"]
+ num_durations = len(self.fixture["durations"])
+ blank_token_id = vocab_size
+
+ combined_logits = torch.randn(batch_size, max_t, max_u + 1, vocab_size + 1 + num_durations)
+ targets = torch.randint(0, vocab_size, (batch_size, max_u))
+ logit_lengths = torch.tensor(self.fixture["logit_lengths"])
+ target_lengths = torch.tensor(self.fixture["target_lengths"])
+
+ return {
+ "token_logits": combined_logits[..., : vocab_size + 1],
+ "duration_logits": combined_logits[..., vocab_size + 1 :],
+ "targets": targets,
+ "logit_lengths": logit_lengths,
+ "target_lengths": target_lengths,
+ "blank_token_id": blank_token_id,
+ "durations": self.fixture["durations"],
+ }
+
+ def test_tdt_loss_sum(self):
+ inputs = self._make_inputs()
+ loss = tdt_loss(**inputs, reduction="sum")
+ expected = torch.tensor(self.fixture["expected_loss_sum"])
+ torch.testing.assert_close(loss, expected)
+
+ def test_tdt_loss_mean(self):
+ inputs = self._make_inputs()
+ loss = tdt_loss(**inputs, reduction="mean")
+ expected = torch.tensor(self.fixture["expected_loss_mean"])
+ torch.testing.assert_close(loss, expected)
+
+ def test_tdt_loss_none(self):
+ inputs = self._make_inputs()
+ losses = tdt_loss(**inputs, reduction="none")
+ expected = torch.tensor(self.fixture["expected_loss_none"])
+ torch.testing.assert_close(losses, expected)
+
+ def test_tdt_loss_with_sigma(self):
+ inputs = self._make_inputs()
+ loss_no_sigma = tdt_loss(**inputs, sigma=0.0, reduction="mean")
+ loss_with_sigma = tdt_loss(**inputs, sigma=0.05, reduction="mean")
+ self.assertFalse(torch.allclose(loss_no_sigma, loss_with_sigma))
+ self.assertGreater(loss_with_sigma.item(), loss_no_sigma.item())
+
+ expected = torch.tensor(self.fixture["expected_loss_mean_sigma_0p05"])
+ torch.testing.assert_close(loss_with_sigma, expected)
+
+ def test_tdt_loss_gradient_flows(self):
+ inputs = self._make_inputs()
+ inputs["token_logits"] = inputs["token_logits"].requires_grad_(True)
+ inputs["duration_logits"] = inputs["duration_logits"].requires_grad_(True)
+ loss = tdt_loss(**inputs, reduction="mean")
+ loss.backward()
+ self.assertIsNotNone(inputs["token_logits"].grad)
+ self.assertIsNotNone(inputs["duration_logits"].grad)
+ self.assertFalse(torch.all(inputs["token_logits"].grad == 0))
+ self.assertFalse(torch.all(inputs["duration_logits"].grad == 0))
class ParakeetEncoderModelTester:
@@ -56,7 +138,7 @@ def __init__(
conv_kernel_size=9,
subsampling_factor=8,
subsampling_conv_channels=32,
- use_bias=True,
+ attention_bias=True,
num_mel_bins=80,
scale_input=True,
):
@@ -77,7 +159,7 @@ def __init__(
self.conv_kernel_size = conv_kernel_size
self.subsampling_factor = subsampling_factor
self.subsampling_conv_channels = subsampling_conv_channels
- self.use_bias = use_bias
+ self.attention_bias = attention_bias
self.num_mel_bins = num_mel_bins
self.scale_input = scale_input
@@ -108,7 +190,7 @@ def get_config(self):
conv_kernel_size=self.conv_kernel_size,
subsampling_factor=self.subsampling_factor,
subsampling_conv_channels=self.subsampling_conv_channels,
- use_bias=self.use_bias,
+ attention_bias=self.attention_bias,
num_mel_bins=self.num_mel_bins,
scale_input=self.scale_input,
)
@@ -167,6 +249,10 @@ class ParakeetEncoderModelTest(ModelTesterMixin, unittest.TestCase):
test_resize_embeddings = False
+ @unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup")
+ def test_sdpa_can_dispatch_on_flash(self):
+ pass
+
def setUp(self):
self.model_tester = ParakeetEncoderModelTester(self)
self.config_tester = ConfigTester(self, config_class=ParakeetEncoderConfig, has_text_modality=False)
@@ -237,6 +323,7 @@ def test_ctc_loss_inference(self):
@require_torch
class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (ParakeetForCTC,) if is_torch_available() else ()
+    all_generative_model_classes = ()  # ParakeetForCTC has a custom generate method
pipeline_model_mapping = (
{
"feature-extraction": ParakeetEncoder,
@@ -247,11 +334,13 @@ class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase):
)
test_attention_outputs = False
-
test_resize_embeddings = False
-
_is_composite = True
+ @unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup")
+ def test_sdpa_can_dispatch_on_flash(self):
+ pass
+
def setUp(self):
self.model_tester = ParakeetForCTCModelTester(self)
self.config_tester = ConfigTester(self, config_class=ParakeetCTCConfig)
@@ -303,14 +392,13 @@ class ParakeetForCTCIntegrationTest(unittest.TestCase):
def setUp(cls):
cls.checkpoint_name = "nvidia/parakeet-ctc-1.1b"
cls.dtype = torch.bfloat16
- cls.processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
+ cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name)
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@classmethod
def _load_dataset(cls):
- # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
if cls._dataset is None:
cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
cls._dataset = cls._dataset.cast_column(
@@ -326,8 +414,7 @@ def _load_datasamples(self, num_samples):
@slow
def test_1b_model_integration(self):
"""
- bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py
- eustlb reproducer: https://gist.github.com/eustlb/6e9e3aa85de3f7c340ec3c36e65f2fe6
+ reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py
"""
RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single.json"
with open(RESULTS_PATH, "r") as f:
@@ -336,25 +423,20 @@ def test_1b_model_integration(self):
EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]
samples = self._load_datasamples(1)
- model = ParakeetForCTC.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device)
- model.eval()
- model.to(torch_device)
+ model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto")
- # -- apply
inputs = self.processor(samples)
- inputs.to(torch_device, dtype=self.dtype)
+ inputs.to(model.device, dtype=self.dtype)
predicted_ids = model.generate(**inputs)
torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS)
- predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
+ predicted_transcripts = self.processor.decode(predicted_ids, skip_special_tokens=True)
self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
@slow
def test_1b_model_integration_batched(self):
"""
- bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batched-py
- eustlb reproducer: https://gist.github.com/eustlb/575b5da58de34a70116a1955b1183596
+ reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batched-py
"""
-
RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch.json"
with open(RESULTS_PATH, "r") as f:
raw_data = json.load(f)
@@ -362,14 +444,348 @@ def test_1b_model_integration_batched(self):
EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]
samples = self._load_datasamples(5)
- model = ParakeetForCTC.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device)
- model.eval()
- model.to(torch_device)
+ model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto")
- # -- apply
inputs = self.processor(samples)
- inputs.to(torch_device, dtype=self.dtype)
+ inputs.to(model.device, dtype=self.dtype)
predicted_ids = model.generate(**inputs)
torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS)
- predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
+ predicted_transcripts = self.processor.decode(predicted_ids, skip_special_tokens=True)
+ self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
+
+
+class ParakeetForTDTModelTester:
+ def __init__(
+ self,
+ parent,
+ encoder_kwargs=None,
+ is_training=True,
+ vocab_size=129,
+ decoder_hidden_size=32,
+ num_decoder_layers=1,
+ durations=[0, 1, 2, 3, 4],
+ hidden_act="relu",
+ max_symbols_per_step=5,
+ pad_token_id=2,
+ ):
+ if encoder_kwargs is None:
+ encoder_kwargs = {}
+
+ self.parent = parent
+ self.encoder_model_tester = ParakeetEncoderModelTester(parent, **encoder_kwargs)
+ self.is_training = is_training
+
+ self.batch_size = self.encoder_model_tester.batch_size
+ self.output_seq_length = self.encoder_model_tester.output_seq_length
+ self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers
+ self.hidden_size = self.encoder_model_tester.hidden_size
+ self.seq_length = self.encoder_model_tester.output_seq_length
+ self.encoder_seq_length = self.encoder_model_tester.output_seq_length
+
+ self.vocab_size = vocab_size
+ self.decoder_hidden_size = decoder_hidden_size
+ self.num_decoder_layers = num_decoder_layers
+ self.durations = durations
+ self.hidden_act = hidden_act
+ self.max_symbols_per_step = max_symbols_per_step
+ self.pad_token_id = pad_token_id
+ self.blank_token_id = vocab_size - 1
+
+ def prepare_config_and_inputs(self):
+ _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs()
+ config = self.get_config()
+ return config, input_features, attention_mask
+
+ def get_config(self):
+ return ParakeetTDTConfig(
+ vocab_size=self.vocab_size,
+ decoder_hidden_size=self.decoder_hidden_size,
+ num_decoder_layers=self.num_decoder_layers,
+ durations=self.durations,
+ hidden_act=self.hidden_act,
+ max_symbols_per_step=self.max_symbols_per_step,
+ encoder_config=self.encoder_model_tester.get_config().to_dict(),
+ pad_token_id=self.pad_token_id,
+ blank_token_id=self.blank_token_id,
+ )
+
+ def create_and_check_model(self, config, inputs_dict):
+ model = ParakeetForTDT(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(**inputs_dict)
+
+ # Check encoder last hidden state
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.output_seq_length, self.encoder_model_tester.hidden_size),
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_features, attention_mask = self.prepare_config_and_inputs()
+ decoder_input_ids = ids_tensor([self.batch_size, 1], self.vocab_size)
+ inputs_dict = {
+ "input_features": input_features,
+ "attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class ParakeetForTDTModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (ParakeetForTDT,) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": ParakeetEncoder,
+ "automatic-speech-recognition": ParakeetForTDT,
+ }
+ if is_torch_available()
+ else {}
+ )
+
+ test_attention_outputs = False
+ test_resize_embeddings = False
+ test_torch_exportable = False
+ _is_composite = True
+
+ @unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup")
+ def test_sdpa_can_dispatch_on_flash(self):
+ pass
+
+ def setUp(self):
+ self.model_tester = ParakeetForTDTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=ParakeetTDTConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="ParakeetForTDT does not use inputs_embeds")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip(
+ reason="ParakeetForTDT is a transducer, not a standard encoder-decoder: no separate text config to set"
+ )
+ def test_attn_implementation_composite_models(self):
+ pass
+
+ @unittest.skip(
+ reason="ParakeetForTDT is a transducer with an LSTM prediction network; "
+ "it does not expose encoder_hidden_states in the standard encoder-decoder sense"
+ )
+ def test_hidden_states_output(self):
+ pass
+
+ @unittest.skip(
+ reason="ParakeetForTDT is a transducer with an LSTM prediction network; "
+ "it does not expose encoder_hidden_states in the standard encoder-decoder sense"
+ )
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ @unittest.skip(
+ reason="ParakeetForTDT has a custom generate() that is not fully compatible with GenerationTesterMixin"
+ )
+ def test_generation_tester_mixin_inheritance(self):
+ pass
+
+ @unittest.skip(reason="ParakeetForTDT is a flat composite model without a separate base_model sub-module")
+ def test_model_base_model_prefix(self):
+ pass
+
+ @unittest.skip(reason="ParakeetForTDT decoder is an LSTM prediction network without attention")
+ def test_flex_attention_with_grads(self):
+ pass
+
+    # The original test assumes a vision+text model, so override it since Parakeet is audio+text
+ def test_sdpa_can_dispatch_composite_models(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ if not self._is_composite:
+ self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
+
+ for model_class in self.all_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_sdpa = model_class.from_pretrained(tmpdirname)
+ model_sdpa = model_sdpa.eval().to(torch_device)
+
+ model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
+ model_eager = model_eager.eval().to(torch_device)
+ self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+ for name, submodule in model_eager.named_modules():
+ class_name = submodule.__class__.__name__
+ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+ raise ValueError("The eager model should not have SDPA attention layers")
+
+
+@require_torch
+class ParakeetForTDTIntegrationTest(unittest.TestCase):
+ _dataset = None
+
+ @classmethod
+ def setUp(cls):
+ cls.checkpoint_name = "nvidia/parakeet-tdt-0.6b-v3"
+ cls.dtype = torch.bfloat16
+ cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name)
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ @classmethod
+ def _load_dataset(cls):
+ if cls._dataset is None:
+ cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ cls._dataset = cls._dataset.cast_column(
+ "audio", Audio(sampling_rate=cls.processor.feature_extractor.sampling_rate)
+ )
+
+ def _load_datasamples(self, num_samples):
+ self._load_dataset()
+ ds = self._dataset
+ speech_samples = ds.sort("id")[:num_samples]["audio"]
+ return [x["array"] for x in speech_samples]
+
+ @slow
+ def test_tdt_model_integration(self):
+ """
+ reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single_tdt-py
+ """
+ RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single_tdt.json"
+ with open(RESULTS_PATH, "r") as f:
+ raw_data = json.load(f)
+ EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]
+
+ samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS))
+ model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto")
+
+ inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate)
+ inputs.to(model.device, dtype=self.dtype)
+ output = model.generate(**inputs, return_dict_in_generate=True)
+ predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True)
+ self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
+
+ @slow
+ def test_tdt_model_integration_batched(self):
+ """
+ reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batch_tdt-py
+ """
+ RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch_tdt.json"
+ with open(RESULTS_PATH, "r") as f:
+ raw_data = json.load(f)
+ EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]
+
+ samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS))
+ model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto")
+
+ inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate)
+ inputs.to(model.device, dtype=self.dtype)
+ output = model.generate(**inputs, return_dict_in_generate=True)
+ predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True)
+ self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
+
+ @slow
+ def test_tdt_model_integration_timestamps(self):
+ """
+ reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batch_tdt_timestamps-py
+ """
+ RESULTS_PATH = (
+ Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch_tdt_timestamp.json"
+ )
+ with open(RESULTS_PATH, "r") as f:
+ raw_data = json.load(f)
+ EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]
+ EXPECTED_START_TIMESTAMPS = raw_data["start_timestamps"]
+ EXPECTED_END_TIMESTAMPS = raw_data["end_timestamps"]
+
+ # Use larger precision for testing token durations and timestamps
+ samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS))
+ model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto")
+
+ inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate)
+ inputs.to(model.device, dtype=model.dtype)
+ output = model.generate(**inputs, return_dict_in_generate=True)
+ predicted_transcripts, predicted_timestamps = self.processor.decode(
+ output.sequences,
+ durations=output.durations,
+ skip_special_tokens=True,
+ )
self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
+
+ # Check timestamps and durations
+ self.assertIsNotNone(output.durations, "durations should be returned")
+ predicted_start_times = [[entry["start"] for entry in el] for el in predicted_timestamps]
+ predicted_end_times = [[entry["end"] for entry in el] for el in predicted_timestamps]
+ torch.testing.assert_close(predicted_start_times, EXPECTED_START_TIMESTAMPS)
+ torch.testing.assert_close(predicted_end_times, EXPECTED_END_TIMESTAMPS)
+
+ @slow
+ def test_tdt_model_integration_loss(self):
+ """
+ Verify that ParakeetForTDT loss matches NeMo's TDT loss (sigma=0) for both
+ the CUDA kernel and the pure PyTorch implementation.
+ reproducer: https://gist.github.com/883ea42bf7d8ce2af42f3055627476a7
+ """
+ from transformers.loss.loss_tdt import _load_tdt_kernel
+
+ RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_loss_tdt.json"
+ with open(RESULTS_PATH, "r") as f:
+ raw_data = json.load(f)
+ EXPECTED_MEAN_LOSS = torch.tensor(raw_data["expected_mean_loss"])
+ num_samples = raw_data["num_samples"]
+
+ samples = self._load_datasamples(num_samples)
+ transcripts = self._dataset.sort("id")[:num_samples]["text"]
+ transcripts = [t.lower() for t in transcripts]
+
+ # Use float32 for loss precision
+ model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto")
+
+ inputs = self.processor(
+ audio=samples,
+ text=transcripts,
+ sampling_rate=self.processor.feature_extractor.sampling_rate,
+ )
+ inputs.to(model.device)
+
+ # Test both backends: kernel (if available) and pure PyTorch
+ has_kernel = _load_tdt_kernel() is not None
+ backends = [
+ ("kernel", None),
+ ("torch", patch("transformers.loss.loss_tdt._load_tdt_kernel", return_value=None)),
+ ]
+ if not has_kernel:
+ backends = backends[1:] # skip kernel test when not installed
+
+ for backend_name, ctx in backends:
+ with self.subTest(backend=backend_name):
+ ctx_manager = ctx if ctx is not None else nullcontext()
+ with ctx_manager:
+ # Forward in eval mode — check loss matches NeMo
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**inputs)
+ self.assertIsNotNone(outputs.loss, "Loss must be computed when labels are provided")
+ self.assertEqual(outputs.logits.dim(), 4, "Training logits must be 4D (B, T, U+1, V+D)")
+ torch.testing.assert_close(outputs.loss.cpu(), EXPECTED_MEAN_LOSS, rtol=1e-3, atol=1e-3)
+
+ # Backward — verify gradients flow
+ del outputs
+ torch.cuda.empty_cache()
+ model.train()
+ model.zero_grad()
+ outputs = model(**inputs)
+ outputs.loss.backward()
+ n_with_grad = sum(1 for p in model.parameters() if p.grad is not None)
+ self.assertGreater(n_with_grad, 0, "No gradients after backward")
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 24f278c24704..ce4f51969cc1 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -5366,7 +5366,10 @@ def test_get_audio_features_output(self, return_dict: bool | None):
elif hasattr(audio_config, "hidden_size"):
hidden_size = audio_config.hidden_size
elif hasattr(audio_config, "encoder_config"):
- hidden_size = audio_config.encoder_config.hidden_dim
+ if hasattr(audio_config.encoder_config, "hidden_size"):
+ hidden_size = audio_config.encoder_config.hidden_size
+ else:
+ hidden_size = audio_config.encoder_config.hidden_dim
elif hasattr(audio_config, "encoder_ffn_dim"):
hidden_size = audio_config.encoder_ffn_dim
self.assertEqual(