Changes from all commits (73 commits)
fa7d6e0
parakeet tdt integration
Oct 13, 2025
f2b4938
Add TDT decoder support for Parakeet ASR models
lmaksym Feb 20, 2026
fa36657
Add expected outputs for TDT, small fixes.
ebezzam Feb 25, 2026
05e2e34
Separate CTC and TDT generate outputs.
ebezzam Feb 25, 2026
bb5ff33
Work with auto device, better init.
ebezzam Feb 25, 2026
9ec79b0
Test timestamps and expose token duration.
ebezzam Feb 26, 2026
33f128e
Add reproducer link.
ebezzam Feb 26, 2026
760b4b6
fix: align TDT training and decoding with NeMo implementation
lmaksym Feb 27, 2026
b33002f
revert: restore lasr generated files to original state
lmaksym Feb 27, 2026
48b39dd
warn: torchaudio rnnt_loss does not train duration head
lmaksym Feb 27, 2026
e9f23ab
Relax timestamp test, and test nits.
ebezzam Mar 2, 2026
e2b97aa
feat: TDT training
lmaksym Mar 3, 2026
6b9fc73
chore: fix cuda detection and run without patching
lmaksym Mar 3, 2026
6c879bc
Equivalent timestamp processing as Nemo, and various nits/cleanup.
ebezzam Mar 3, 2026
149e17f
Merge branch 'parakeet-tdt' of github.com:lmaksym/transformers into p…
ebezzam Mar 3, 2026
36bfa63
Simplify durations config.
ebezzam Mar 3, 2026
2df0ccc
Update training examples.
ebezzam Mar 3, 2026
388c6d3
chore: enable parallelism
lmaksym Mar 3, 2026
08b2b55
chore: performance optimization
lmaksym Mar 4, 2026
0c4e05a
fix: formatting
lmaksym Mar 4, 2026
1ddd804
Doc and testing nits
ebezzam Mar 5, 2026
f512670
Use active mask from current step, and nits.
ebezzam Mar 6, 2026
07d8e35
Better pre-allocate.
ebezzam Mar 6, 2026
fab050a
TDT has separate pad token and blank token.
ebezzam Mar 6, 2026
c438565
Merge branch 'main' into parakeet-tdt
ebezzam Mar 6, 2026
86d980c
Regenerate lasr.
ebezzam Mar 6, 2026
895c4a0
Merge branch 'parakeet-tdt' of github.com:lmaksym/transformers into p…
ebezzam Mar 6, 2026
ab21380
Style checks and nits
ebezzam Mar 7, 2026
d0141d5
Nits, put back ctc loss test
ebezzam Mar 7, 2026
f7529d4
More standard model output.
ebezzam Mar 10, 2026
77b95d7
Style
ebezzam Mar 10, 2026
94eae66
Remove compute_loss flag and allow monkey patching to tdt loss
ebezzam Mar 23, 2026
f7d4067
Update src/transformers/models/parakeet/modular_parakeet.py
ebezzam Mar 23, 2026
f75c17b
Address various comments.
ebezzam Mar 23, 2026
5a49b65
More compatible with Transformers forward/generate approach
ebezzam Mar 24, 2026
881233f
compile option for generation and decoder cache
ebezzam Mar 24, 2026
b41a8ee
Cleaner, better conventions.
ebezzam Mar 24, 2026
897753a
Merge branch 'main' into parakeet-tdt
ebezzam Mar 24, 2026
6c914db
Update with main.
ebezzam Mar 24, 2026
756cee1
doc nits
ebezzam Mar 26, 2026
f30c536
Imitate whisper for encoder outputs as input
ebezzam Mar 26, 2026
fa95fc8
Address tests and nits.
ebezzam Mar 26, 2026
5df7f28
Inherit from GenerateMixIn for get_compiled_call
ebezzam Mar 26, 2026
cd706d4
Comment nit
ebezzam Mar 26, 2026
a47ed8a
forward cleanup
eustlb Apr 15, 2026
13b68ce
generate cleanup + separate generation file
eustlb Apr 15, 2026
72c1ad0
generate: add _supported_generation_modes
eustlb Apr 15, 2026
8e23b3d
automatic init of the loss
eustlb Apr 15, 2026
1cc39fd
modular cleanups
eustlb Apr 15, 2026
531f297
use is_encoder_decoder
eustlb Apr 15, 2026
2c0f23a
timestamp processing fully from tokens + durations
eustlb Apr 15, 2026
cef6639
conversion script update
eustlb Apr 15, 2026
fd3cf9b
test update
eustlb Apr 15, 2026
e63a5bf
make
eustlb Apr 15, 2026
f9d1a4f
Merge branch 'main' into parakeet-tdt
eustlb Apr 15, 2026
43ee7cd
test update
eustlb Apr 15, 2026
c2a0f78
test update
eustlb Apr 15, 2026
1fd7ed7
ensure correct loss computation
eustlb Apr 16, 2026
7cc9d2e
kernel loss
eustlb Apr 16, 2026
e753eab
test loss integration
eustlb Apr 16, 2026
ed3fa4d
push to hub pr
eustlb Apr 16, 2026
ab66b23
integration tests to rely fully on transcripts
eustlb Apr 16, 2026
a5ba0c6
update fixtures
eustlb Apr 16, 2026
48279a6
we don't need to monkey patch with numba anymore!
eustlb Apr 16, 2026
1d7680d
fix pipeline usage
eustlb Apr 16, 2026
59ddced
nit
eustlb Apr 16, 2026
31490d1
fix usage
eustlb Apr 16, 2026
d8eb1b6
Pass through tests and examples: improve kernel fallback, update with…
ebezzam Apr 17, 2026
1f1b912
Update checkpoint
ebezzam Apr 17, 2026
9ab08d1
Merge branch 'main' into parakeet-tdt
ebezzam Apr 17, 2026
fd9f8b1
Add TDT to mapping after merge.
ebezzam Apr 17, 2026
136f676
Fix lasr generate test.
ebezzam Apr 20, 2026
833d289
Output attention mask if labels provided for computing loss.
ebezzam Apr 20, 2026
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/auto.md
@@ -225,6 +225,10 @@ The following auto classes are available for the following audio tasks.

[[autodoc]] AutoModelForCTC

### AutoModelForTDT

[[autodoc]] AutoModelForTDT

### AutoModelForSpeechSeq2Seq

[[autodoc]] AutoModelForSpeechSeq2Seq
158 changes: 137 additions & 21 deletions docs/source/en/model_doc/parakeet.md
@@ -34,15 +34,20 @@ Parakeet models, [introduced by NVIDIA NeMo](https://developer.nvidia.com/blog/p
- 1D convolution projection from encoder hidden size to vocabulary size (for optimal NeMo compatibility).
- CTC loss computation for training.
- Greedy CTC decoding for inference.
- [**ParakeetForTDT**](#parakeetfortdt): a Fast Conformer Encoder + a TDT (Token-and-Duration Transducer) decoder
- **TDT Decoder**: Jointly predicts tokens and their durations, enabling efficient decoding:
- LSTM prediction network maintains language context across token predictions.
- Joint network combines encoder and decoder outputs.
- Duration head predicts how many frames to skip, enabling fast inference (see the decoding sketch below).
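
A minimal, self-contained sketch of this greedy loop with toy weights is below. The module names, sizes, duration set, and per-frame emission cap are illustrative assumptions rather than the library's API; actual inference runs through `ParakeetForTDT.generate`.

```py
import torch
import torch.nn as nn

torch.manual_seed(0)
VOCAB, HIDDEN = 8, 16
BLANK = VOCAB                                     # blank id placed after the vocabulary (assumed)
DURATIONS = [0, 1, 2, 3, 4]                       # frame skips the duration head can choose
MAX_SYMBOLS = 10                                  # cap emissions per frame so decoding always advances

embed = nn.Embedding(VOCAB + 1, HIDDEN)
pred_net = nn.LSTMCell(HIDDEN, HIDDEN)            # stand-in for the LSTM prediction network
joint = nn.Linear(2 * HIDDEN, (VOCAB + 1) + len(DURATIONS))  # token logits + duration logits

encoder_out = torch.randn(50, HIDDEN)             # (time, hidden) from the Fast Conformer encoder
h = c = dec_h = torch.zeros(1, HIDDEN)
tokens, t, per_frame = [], 0, 0
while t < encoder_out.size(0):
    logits = joint(torch.cat([encoder_out[t : t + 1], dec_h], dim=-1))
    token = logits[:, : VOCAB + 1].argmax(-1).item()
    skip = DURATIONS[logits[:, VOCAB + 1 :].argmax(-1).item()]
    if token == BLANK:
        t += max(skip, 1)                         # blank: skip ahead, never stall
        per_frame = 0
    else:
        tokens.append(token)
        h, c = pred_net(embed(torch.tensor([token])), (h, c))
        dec_h = h                                 # language context advances only on real tokens
        per_frame += 1
        if skip > 0 or per_frame >= MAX_SYMBOLS:  # duration 0 keeps emitting on the same frame
            t += max(skip, 1)
            per_frame = 0
print(tokens)
```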

The original implementation can be found in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).
Model checkpoints can be found under [the NVIDIA organization](https://huggingface.co/nvidia/models?search=parakeet).

This model was contributed by [Nithin Rao Koluguri](https://huggingface.co/nithinraok), [Eustache Le Bihan](https://huggingface.co/eustlb) and [Eric Bezzam](https://huggingface.co/bezzam).
This model was contributed by [Nithin Rao Koluguri](https://huggingface.co/nithinraok), [Eustache Le Bihan](https://huggingface.co/eustlb), [Eric Bezzam](https://huggingface.co/bezzam), [Maksym Lypivskyi](https://huggingface.co/MaksL), and [Hainan Xu](https://huggingface.co/hainanx).

## Usage

### Basic usage
### `ParakeetForCTC` usage

<hfoptions id="usage">
<hfoption id="Pipeline">
@@ -53,6 +58,7 @@ from transformers import pipeline
pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-ctc-1.1b")
out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
print(out)
# {'text': 'yesterday it was thirty five degrees in barcelona but today the temperature will go down to minus twenty degrees'}
```

</hfoption>
@@ -61,12 +67,10 @@ print(out)
```py
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)
model_id = "nvidia/parakeet-ctc-1.1b"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCTC.from_pretrained(model_id, dtype="auto", device_map="auto")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
@@ -75,7 +79,80 @@ speech_samples = [el['array'] for el in ds["audio"][:5]]
inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
print(processor.decode(outputs))
```

</hfoption>
</hfoptions>

### `ParakeetForTDT` usage

<hfoptions id="tdt-usage">
<hfoption id="Pipeline">

Parakeet TDT transcripts include casing and punctuation, and the model can also produce token-level timestamps.

```py
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-tdt-0.6b-v3")
out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
print(out)
# {'text': 'Yesterday it was 35 degrees in Barcelona, but today the temperature will go down to minus 20 degrees.'}
```

</hfoption>
<hfoption id="AutoModel">

```py
from transformers import AutoModelForTDT, AutoProcessor
from datasets import load_dataset, Audio

model_id = "nvidia/parakeet-tdt-0.6b-v3"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]

inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
output = model.generate(**inputs, return_dict_in_generate=True)
print(processor.decode(output.sequences, skip_special_tokens=True))
```

</hfoption>
<hfoption id="Timestamping">

```py
from datasets import Audio, load_dataset
from transformers import AutoModelForTDT, AutoProcessor

model_id = "nvidia/parakeet-tdt-0.6b-v3"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:1]]

inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
output = model.generate(**inputs, return_dict_in_generate=True)
decoded_output, decoded_timestamps = processor.decode(
output.sequences,
durations=output.durations,
skip_special_tokens=True,
)
print("Transcription:", decoded_output)
print("\nTimestamped tokens:", decoded_timestamps)

"""
Transcription: ['mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.']

Timestamped tokens: [[{'token': 'm', 'start': 0.24, 'end': 0.48}, {'token': 'ister', 'start': 0.48, 'end': 0.64}, {'token': 'Qu', 'start': 0.64, 'end': 0.88}, {'token': 'il', 'start': 0.88, 'end': 1.12}, {'token': 'ter', 'start': 1.12, 'end': 1.36}, {'token': 'is', 'start': 1.36, 'end': 1.44}, {'token': 'the', 'start': 1.44, 'end': 1.6}, {'token': 'ap', 'start': 1.6, 'end': 1.76}, {'token': 'ost', 'start': 1.76, 'end': 1.92}, {'token': 'le', 'start': 2.0, 'end': 2.16}, {'token': 'of', 'start': 2.16, 'end': 2.24}, {'token': 'the', 'start': 2.24, 'end': 2.4}, {'token': 'mid', 'start': 2.4, 'end': 2.48}, {'token': 'd', 'start': 2.48, 'end': 2.56}, {'token': 'le', 'start': 2.56, 'end': 2.64}, {'token': 'clas', 'start': 2.72, 'end': 2.88}, {'token': 's', 'start': 2.88, 'end': 3.04}, {'token': 'es', 'start': 3.04, 'end': 3.12}, {'token': ',', 'start': 3.12, 'end': 3.12}, {'token': 'and', 'start': 3.2800000000000002, 'end': 3.44}, {'token': 'we', 'start': 3.44, 'end': 3.6}, {'token': 'are', 'start': 3.6, 'end': 3.7600000000000002}, {'token': 'gl', 'start': 3.7600000000000002, 'end': 3.92}, {'token': 'ad', 'start': 3.92, 'end': 4.08}, {'token': 'to', 'start': 4.08, 'end': 4.24}, {'token': 'wel', 'start': 4.24, 'end': 4.4}, {'token': 'c', 'start': 4.4, 'end': 4.48}, {'token': 'ome', 'start': 4.48, 'end': 4.72}, {'token': 'his', 'start': 4.72, 'end': 4.96}, {'token': 'gos', 'start': 4.96, 'end': 5.12}, {'token': 'pel', 'start': 5.36, 'end': 5.6000000000000005}, {'token': '.', 'start': 5.6000000000000005, 'end': 5.6000000000000005}]]
"""
```
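
For intuition, the start/end values above plausibly come from accumulating the predicted durations (counted in encoder frames) and scaling by the frame stride, roughly 0.08 s for Parakeet. The sketch below reuses `output` from the example above; the blank/pad filtering and the exact stride are assumptions, and `processor.decode` remains the supported way to obtain these offsets.

```py
# Assumed arithmetic, for illustration only; prefer processor.decode(...) as shown above.
frame_seconds = 0.08                              # ~10 ms hop x 8x subsampling (assumed)
tokenizer = processor.tokenizer
offsets, t = [], 0
for token_id, dur in zip(output.sequences[0].tolist(), output.durations[0].tolist()):
    start = t * frame_seconds
    t += dur                                      # durations advance the encoder-frame cursor
    if token_id != tokenizer.pad_token_id:        # skip blank/pad decoding steps (assumed check)
        offsets.append({"token": tokenizer.decode([token_id]), "start": start, "end": t * frame_seconds})
print(offsets[:3])
```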

</hfoption>
@@ -136,58 +213,90 @@ print("First generation - compiling...")
# Generate with the compiled model
with TimerContext("First generation"):
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
print(processor.decode(outputs))

inputs = processor(speech_samples[1], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("Second generation - recording CUDA graphs...")
with TimerContext("Second generation"):
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
print(processor.decode(outputs))

inputs = processor(speech_samples[2], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("Third generation - fast !!!")
with TimerContext("Third generation"):
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
print(processor.decode(outputs))

inputs = processor(speech_samples[3], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("Fourth generation - still fast !!!")
with TimerContext("Fourth generation"):
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
print(processor.decode(outputs))
```

### Training
### CTC Training

```python
import torch
from datasets import Audio, load_dataset
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio

model_id = "nvidia/parakeet-ctc-1.1b"
NUM_SAMPLES = 5

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCTC.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
model.train()

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]]
text_samples = ds["text"][:NUM_SAMPLES]

# passing `text` to the processor will prepare inputs' `labels` key
inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(device=model.device, dtype=model.dtype)

outputs = model(**inputs)
print("Loss:", outputs.loss.item())
outputs.loss.backward()
```

### TDT Training

```py
from datasets import Audio, load_dataset
import torch
from transformers import AutoModelForTDT, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "nvidia/parakeet-tdt-0.6b-v3"
NUM_SAMPLES = 4

processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
model.train()

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]
text_samples = [el for el in ds["text"][:5]]
speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]]
text_samples = ds["text"][:NUM_SAMPLES]

# passing `text` to the processor will prepare inputs' `labels` key
inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(device, dtype=model.dtype)
inputs.to(device=model.device, dtype=model.dtype)

outputs = model(**inputs)
print("Loss:", outputs.loss.item())
outputs.loss.backward()
```


## ParakeetTokenizer

[[autodoc]] ParakeetTokenizer
@@ -201,7 +310,6 @@

[[autodoc]] ParakeetProcessor
- __call__
- batch_decode
- decode

## ParakeetEncoderConfig
@@ -212,10 +320,18 @@

[[autodoc]] ParakeetCTCConfig

## ParakeetTDTConfig

[[autodoc]] ParakeetTDTConfig

## ParakeetEncoder

[[autodoc]] ParakeetEncoder

## ParakeetForCTC

[[autodoc]] ParakeetForCTC

## ParakeetForTDT

[[autodoc]] ParakeetForTDT
6 changes: 4 additions & 2 deletions src/transformers/convert_slow_tokenizer.py
@@ -735,7 +735,8 @@ def tokenizer(self, proto):
)

elif model_type == 2:
_, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
result = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(None)
merges = result["merges"]
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(
@@ -1842,7 +1843,8 @@ def __init__(self, vocab_file=None, *args):
def tokenizer(self, proto):
vocab_scores = self.vocab(proto)

_, merges = self.SpmExtractor(self.vocab_file).extract(vocab_scores)
result = self.SpmExtractor(self.vocab_file).extract(None)
merges = result["merges"]
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(
7 changes: 7 additions & 0 deletions src/transformers/generation/utils.py
@@ -1459,6 +1459,13 @@ def compute_transition_scores(
def _validate_generation_mode(
self: "GenerativePreTrainedModel", generation_mode, generation_config, generation_mode_kwargs
):
supported_modes = getattr(self, "_supported_generation_modes", None)
if supported_modes is not None and generation_mode not in supported_modes:
raise ValueError(
f"{self.__class__.__name__} only supports {supported_modes}, but got "
f"generation mode '{generation_mode}'."
)

Comment on lines +1462 to +1468 (Contributor):
added this to be able to do

```py
class ParakeetForTDT(ParakeetPreTrainedModel, ParakeetTDTGenerationMixin):
    _supported_generation_modes = [GenerationMode.GREEDY_SEARCH]
```

if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs:
raise ValueError(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
7 changes: 5 additions & 2 deletions src/transformers/integrations/hub_kernels.py
@@ -286,6 +286,7 @@ def register_kernel_mapping_transformers(*args, **kwargs):
"falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1},
"finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1},
"deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1},
"tdt-loss": {"repo_id": "eustlb/tdt-loss", "revision": "v1"},
}
Comment on lines 288 to 290 (Contributor):

@ErikKaum pinging you here because your YouTube kernel tutorial helped a lot for this 😊 What are the next steps to move my tdt kernel from my repo to kernels-community and compile it for other environments?

(Contributor):

@eustlb thanks for creating the kernel! By the way, I changed `"version": 1` to `"revision": "v1"`, as your kernel lives in a `v1` branch; otherwise it wasn't loading as expected, since the main branch is empty.

And maybe we also need to add the source to the main branch? I was a bit confused about where the content was at first 😝

I guess @ErikKaum will have the best-practice tips!

(Contributor):

Here I just used the same convention as for the other hub kernels, where `"version": 1` corresponds to a `v1` branch, so I am not so sure about changing `"version": 1` to `"revision": "v1"`.

_KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {}
@@ -372,10 +373,12 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)
kernel = get_kernel(repo_id, revision=revision, version=version)
# Since we only read from `_HUB_KERNEL_MAPPING`, we can allow all kernels
kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True)
Comment on lines +376 to +377 (Contributor):

Can we hardcode `allow_all_kernels=True`, since we only read kernels from the library-defined `_HUB_KERNEL_MAPPING`?
mapping[kernel_name] = kernel
except FileNotFoundError:
except FileNotFoundError as e:
mapping[kernel_name] = None
logger.warning_once(f"Failed to load kernel {kernel_name}: {e}")
(Contributor):

Adding a helpful error message; otherwise the kernel may fail to load without notifying the user, e.g. due to a different Torch version. For example, it will now print:

[transformers] Failed to load kernel tdt-loss: Cannot find a build variant for this system in eustlb/tdt-loss (revision: v1). Available variants: torch211-cxx11-cu128-x86_64-linux

except AssertionError:
# Happens when torch is built without an accelerator backend; fall back to slow path.
mapping[kernel_name] = None