diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index c1441ba20047..95ca49a74915 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -162,8 +162,8 @@ def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") -_HUB_KERNEL_MAPPING: dict[str, str] = { - "causal-conv1d": "kernels-community/causal-conv1d", +_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = { + "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"}, } _KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {} @@ -242,7 +242,9 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] from kernels import get_kernel try: - kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name]) + repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] + version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) + kernel = get_kernel(repo_id, version=version) mapping[kernel_name] = kernel except FileNotFoundError: mapping[kernel_name] = None