From e2350753bf47a4b42b879041a3dd2a66579d383b Mon Sep 17 00:00:00 2001 From: medmekk Date: Tue, 14 Oct 2025 12:53:32 +0000 Subject: [PATCH 1/5] refactor function kernel calling --- src/transformers/integrations/hub_kernels.py | 57 +++++++++++++++++++ .../falcon_mamba/modeling_falcon_mamba.py | 51 +++++++---------- .../falcon_mamba/modular_falcon_mamba.py | 25 ++++++-- .../models/mamba/modeling_mamba.py | 50 +++++++--------- 4 files changed, 114 insertions(+), 69 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index bbbb8bbf6444..892776f8beaf 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -14,12 +14,17 @@ import re from collections.abc import Callable from functools import partial +from types import ModuleType from typing import Optional, Union from ..modeling_flash_attention_utils import lazy_import_flash_attention +from ..utils import logging +from ..utils.import_utils import is_kernels_available from .flash_attention import flash_attention_forward +logger = logging.get_logger(__name__) + try: from kernels import ( Device, @@ -158,6 +163,13 @@ def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. 
Run `pip install kernels`.") +_KERNEL_SIMPLE_MAPPING: dict[str, str] = { + "causal-conv1d": "kernels-community/causal-conv1d", +} + +_KERNEL_MAPPING_ACROSS_MODELS: dict[str, Optional[ModuleType]] = {} + + def is_kernel(attn_implementation: Optional[str]) -> bool: """Check whether `attn_implementation` matches a kernel pattern from the hub.""" return ( @@ -220,9 +232,54 @@ def load_and_register_attn_kernel(attn_implementation: str, attention_wrapper: O ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) +def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]): + if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType): + return mapping[kernel_name] + if kernel_name not in _KERNEL_SIMPLE_MAPPING: + logger.warning(f"Kernel {kernel_name} not found in _KERNEL_SIMPLE_MAPPING") + mapping[kernel_name] = None + return None + if not is_kernels_available(): + from kernels import get_kernel + + try: + kernel = get_kernel(_KERNEL_SIMPLE_MAPPING[kernel_name]) + mapping[kernel_name] = kernel + except FileNotFoundError: + mapping[kernel_name] = None + + else: + # Try to import is_{kernel_name}_available from ..utils + import importlib + + new_kernel_name = kernel_name.replace("-", "_") + func_name = f"is_{new_kernel_name}_available" + + try: + utils_mod = importlib.import_module("..utils.import_utils", __package__) + is_kernel_available = getattr(utils_mod, func_name, None) + except Exception: + is_kernel_available = None + + if callable(is_kernel_available) and is_kernel_available(): + # Try to import the module "{kernel_name}" from parent package level + try: + module = importlib.import_module(f"{kernel_name}") + mapping[kernel_name] = module + return module + except Exception: + mapping[kernel_name] = None + else: + mapping[kernel_name] = None + + return mapping[kernel_name] + + __all__ = [ "LayerRepository", "use_kernel_forward_from_hub", "register_kernel_mapping", 
"replace_kernel_forward_from_hub", + "lazy_load_kernel", + "_KERNEL_MAPPING_ACROSS_MODELS", ] diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 06ad59e20872..98750d1a4d84 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -30,12 +30,11 @@ from ...activations import ACT2FN from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin +from ...integrations.hub_kernels import _KERNEL_MAPPING_ACROSS_MODELS, lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from ...utils.import_utils import ( - is_causal_conv1d_available, - is_kernels_available, is_mamba_ssm_available, is_mambapy_available, ) @@ -162,33 +161,6 @@ def reset(self): self.ssm_states[layer_idx].zero_() -def _lazy_load_causal_conv1d(): - global _causal_conv1d_cache - if _causal_conv1d_cache is not None: - return _causal_conv1d_cache - - if is_kernels_available(): - from kernels import get_kernel - - try: - _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d") - except FileNotFoundError: - # no kernel binary match, fallback to slow path - _causal_conv1d_cache = (None, None) - else: - _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn) - elif is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update - - _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn) - else: - _causal_conv1d_cache = (None, None) - return _causal_conv1d_cache - - -_causal_conv1d_cache = None - - def rms_forward(hidden_states, variance_epsilon=1e-6): """ Calculates simple RMSNorm with no learnable weights. 
`MambaRMSNorm` will @@ -268,7 +240,12 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): self.rms_eps = config.mixer_rms_eps def warn_slow_implementation(self): - causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) @@ -323,7 +300,12 @@ def cuda_kernels_forward( ) else: - causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) hidden_states, gate = projected_states.chunk(2, dim=1) if attention_mask is not None: @@ -518,7 +500,12 @@ def forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): - causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index 5370e5fe19da..8847bf7bea8e 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ 
b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -19,6 +19,7 @@ import torch from torch import nn +from ...integrations.hub_kernels import _KERNEL_MAPPING_ACROSS_MODELS, lazy_load_kernel from ...utils import auto_docstring, logging from ...utils.import_utils import ( is_mamba_ssm_available, @@ -35,7 +36,6 @@ MambaOutput, MambaPreTrainedModel, MambaRMSNorm, - _lazy_load_causal_conv1d, ) @@ -54,8 +54,6 @@ else: selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None -_causal_conv1d_cache = None - class FalconMambaConfig(MambaConfig): """ @@ -258,7 +256,12 @@ def rms_forward(hidden_states, variance_epsilon=1e-6): class FalconMambaMixer(MambaMixer): def warn_slow_implementation(self): - causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) @@ -324,7 +327,12 @@ def cuda_kernels_forward( ) else: - causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) hidden_states, gate = projected_states.chunk(2, dim=1) if attention_mask is not None: @@ -518,7 +526,12 @@ def forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): - causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d_update, causal_conv1d_fn = ( + 
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index ceeefbec8851..5116fb286b90 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin +from ...integrations.hub_kernels import _KERNEL_MAPPING_ACROSS_MODELS, lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -33,8 +34,6 @@ logging, ) from ...utils.import_utils import ( - is_causal_conv1d_available, - is_kernels_available, is_mamba_ssm_available, is_mambapy_available, ) @@ -54,32 +53,6 @@ else: selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None -_causal_conv1d_cache = None - - -def _lazy_load_causal_conv1d(): - global _causal_conv1d_cache - if _causal_conv1d_cache is not None: - return _causal_conv1d_cache - - if is_kernels_available(): - from kernels import get_kernel - - try: - _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d") - except FileNotFoundError: - # no kernel binary match, fallback to slow path - _causal_conv1d_cache = (None, None) - else: - _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn) - elif is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update - - _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn) - else: - _causal_conv1d_cache = (None, None) - return _causal_conv1d_cache - class MambaCache: """ @@ -236,7 +209,12 @@ def 
__init__(self, config: MambaConfig, layer_idx: int): self.warn_slow_implementation() def warn_slow_implementation(self): - causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) @@ -287,7 +265,12 @@ def cuda_kernels_forward( ) else: - causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) hidden_states, gate = projected_states.chunk(2, dim=1) if attention_mask is not None: @@ -451,7 +434,12 @@ def forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): - causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) From b2f40259a0bbb5ac8a197df519411dc7bdeafa4c Mon Sep 17 00:00:00 2001 From: medmekk Date: Tue, 14 Oct 2025 13:00:39 +0000 Subject: [PATCH 2/5] nit --- src/transformers/integrations/hub_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 
892776f8beaf..fb1e7c54c0cc 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -239,7 +239,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]) logger.warning(f"Kernel {kernel_name} not found in _KERNEL_SIMPLE_MAPPING") mapping[kernel_name] = None return None - if not is_kernels_available(): + if is_kernels_available(): from kernels import get_kernel try: From 168b6ddf3a5c16d8e5bafa7c975107cd74f7246a Mon Sep 17 00:00:00 2001 From: medmekk Date: Thu, 16 Oct 2025 12:42:01 +0000 Subject: [PATCH 3/5] don't pass the mapping --- src/transformers/integrations/hub_kernels.py | 13 ++++++------- .../models/falcon_mamba/modeling_falcon_mamba.py | 8 ++++---- .../models/falcon_mamba/modular_falcon_mamba.py | 8 ++++---- src/transformers/models/mamba/modeling_mamba.py | 8 ++++---- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index fb1e7c54c0cc..c0d72ef2884b 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -163,11 +163,11 @@ def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. 
Run `pip install kernels`.") -_KERNEL_SIMPLE_MAPPING: dict[str, str] = { +_HUB_KERNEL_MAPPING: dict[str, str] = { "causal-conv1d": "kernels-community/causal-conv1d", } -_KERNEL_MAPPING_ACROSS_MODELS: dict[str, Optional[ModuleType]] = {} +_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {} def is_kernel(attn_implementation: Optional[str]) -> bool: @@ -232,18 +232,18 @@ def load_and_register_attn_kernel(attn_implementation: str, attention_wrapper: O ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) -def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]): +def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] = _KERNEL_MODULE_MAPPING): if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType): return mapping[kernel_name] - if kernel_name not in _KERNEL_SIMPLE_MAPPING: - logger.warning(f"Kernel {kernel_name} not found in _KERNEL_SIMPLE_MAPPING") + if kernel_name not in _HUB_KERNEL_MAPPING: + logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING") mapping[kernel_name] = None return None if is_kernels_available(): from kernels import get_kernel try: - kernel = get_kernel(_KERNEL_SIMPLE_MAPPING[kernel_name]) + kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name]) mapping[kernel_name] = kernel except FileNotFoundError: mapping[kernel_name] = None @@ -281,5 +281,4 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]) "register_kernel_mapping", "replace_kernel_forward_from_hub", "lazy_load_kernel", - "_KERNEL_MAPPING_ACROSS_MODELS", ] diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 98750d1a4d84..ea373c973873 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -30,7 +30,7 @@ from ...activations import 
ACT2FN from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin -from ...integrations.hub_kernels import _KERNEL_MAPPING_ACROSS_MODELS, lazy_load_kernel +from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging @@ -240,7 +240,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): self.rms_eps = config.mixer_rms_eps def warn_slow_implementation(self): - causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None @@ -300,7 +300,7 @@ def cuda_kernels_forward( ) else: - causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None @@ -500,7 +500,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): - causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index 8847bf7bea8e..f9af68d785bd 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -19,7 +19,7 @@ import torch from torch import nn -from ...integrations.hub_kernels import 
_KERNEL_MAPPING_ACROSS_MODELS, lazy_load_kernel +from ...integrations.hub_kernels import lazy_load_kernel from ...utils import auto_docstring, logging from ...utils.import_utils import ( is_mamba_ssm_available, @@ -256,7 +256,7 @@ def rms_forward(hidden_states, variance_epsilon=1e-6): class FalconMambaMixer(MambaMixer): def warn_slow_implementation(self): - causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None @@ -327,7 +327,7 @@ def cuda_kernels_forward( ) else: - causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None @@ -526,7 +526,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): - causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 5116fb286b90..3db7d0fa717c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -25,7 +25,7 @@ from ...activations import ACT2FN from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin -from ...integrations.hub_kernels import _KERNEL_MAPPING_ACROSS_MODELS, lazy_load_kernel +from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer from 
...modeling_utils import PreTrainedModel from ...utils import ( @@ -209,7 +209,7 @@ def __init__(self, config: MambaConfig, layer_idx: int): self.warn_slow_implementation() def warn_slow_implementation(self): - causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None @@ -265,7 +265,7 @@ def cuda_kernels_forward( ) else: - causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None @@ -434,7 +434,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): - causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS) + causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None From 8467792293869c5f1e5aeb7f62d73fabd79d162f Mon Sep 17 00:00:00 2001 From: medmekk Date: Thu, 16 Oct 2025 12:47:42 +0000 Subject: [PATCH 4/5] use _kernels_available --- src/transformers/integrations/hub_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index c0d72ef2884b..1de151877cc7 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -239,7 +239,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING") mapping[kernel_name] = None return None - if is_kernels_available(): + if 
_kernels_available: from kernels import get_kernel try: From 374040458947013699454f267606233374be6303 Mon Sep 17 00:00:00 2001 From: medmekk Date: Thu, 16 Oct 2025 12:48:56 +0000 Subject: [PATCH 5/5] rm import --- src/transformers/integrations/hub_kernels.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 1de151877cc7..c1441ba20047 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -19,7 +19,6 @@ from ..modeling_flash_attention_utils import lazy_import_flash_attention from ..utils import logging -from ..utils.import_utils import is_kernels_available from .flash_attention import flash_attention_forward