diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index bbbb8bbf6444..c1441ba20047 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -14,12 +14,16 @@
 import re
 from collections.abc import Callable
 from functools import partial
+from types import ModuleType
 from typing import Optional, Union
 
 from ..modeling_flash_attention_utils import lazy_import_flash_attention
+from ..utils import logging
 from .flash_attention import flash_attention_forward
 
 
+logger = logging.get_logger(__name__)
+
 try:
     from kernels import (
         Device,
@@ -158,6 +162,13 @@ def register_kernel_mapping(*args, **kwargs):
         raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
 
 
+_HUB_KERNEL_MAPPING: dict[str, str] = {
+    "causal-conv1d": "kernels-community/causal-conv1d",
+}
+
+_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}
+
+
 def is_kernel(attn_implementation: Optional[str]) -> bool:
     """Check whether `attn_implementation` matches a kernel pattern from the hub."""
     return (
@@ -220,9 +231,53 @@ def load_and_register_attn_kernel(attn_implementation: str, attention_wrapper: O
         ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
 
 
+def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] = _KERNEL_MODULE_MAPPING):
+    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
+        return mapping[kernel_name]
+    if kernel_name not in _HUB_KERNEL_MAPPING:
+        logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
+        mapping[kernel_name] = None
+        return None
+    if _kernels_available:
+        from kernels import get_kernel
+
+        try:
+            kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name])
+            mapping[kernel_name] = kernel
+        except FileNotFoundError:
+            mapping[kernel_name] = None
+
+    else:
+        # Try to import is_{kernel_name}_available from ..utils
+        import importlib
+
+        new_kernel_name = kernel_name.replace("-", "_")
+        func_name = f"is_{new_kernel_name}_available"
+
+        try:
+            utils_mod = importlib.import_module("..utils.import_utils", __package__)
+            is_kernel_available = getattr(utils_mod, func_name, None)
+        except Exception:
+            is_kernel_available = None
+
+        if callable(is_kernel_available) and is_kernel_available():
+            # Import the locally installed package (its import name uses underscores)
+            try:
+                module = importlib.import_module(new_kernel_name)
+                mapping[kernel_name] = module
+                return module
+            except Exception:
+                mapping[kernel_name] = None
+        else:
+            mapping[kernel_name] = None
+
+    return mapping[kernel_name]
+
+
 __all__ = [
     "LayerRepository",
     "use_kernel_forward_from_hub",
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
+    "lazy_load_kernel",
 ]
diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
index 06ad59e20872..ea373c973873 100644
--- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -30,12 +30,11 @@
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
 from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -162,33 +161,6 @@ def reset(self):
         self.ssm_states[layer_idx].zero_()
 
 
-def _lazy_load_causal_conv1d():
-    global _causal_conv1d_cache
-    if _causal_conv1d_cache is not None:
-        return _causal_conv1d_cache
-
-    if is_kernels_available():
-        from kernels import get_kernel
-
-        try:
-            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
-        except FileNotFoundError:
-            # no kernel binary match, fallback to slow path
-            _causal_conv1d_cache = (None, None)
-        else:
-            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
-    elif is_causal_conv1d_available():
-        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
-    else:
-        _causal_conv1d_cache = (None, None)
-    return _causal_conv1d_cache
-
-
-_causal_conv1d_cache = None
-
-
 def rms_forward(hidden_states, variance_epsilon=1e-6):
     """
     Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
@@ -268,7 +240,12 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int):
         self.rms_eps = config.mixer_rms_eps
 
     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -323,7 +300,12 @@ def cuda_kernels_forward(
             )
 
         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)
 
             if attention_mask is not None:
@@ -518,7 +500,12 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py
index 5370e5fe19da..f9af68d785bd 100644
--- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py
@@ -19,6 +19,7 @@
 import torch
 from torch import nn
 
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...utils import auto_docstring, logging
 from ...utils.import_utils import (
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -35,7 +36,6 @@
     MambaOutput,
     MambaPreTrainedModel,
     MambaRMSNorm,
-    _lazy_load_causal_conv1d,
 )
@@ -54,8 +54,6 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
 
-_causal_conv1d_cache = None
-
 
 class FalconMambaConfig(MambaConfig):
     """
@@ -258,7 +256,12 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):
 
 class FalconMambaMixer(MambaMixer):
     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -324,7 +327,12 @@ def cuda_kernels_forward(
             )
 
         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)
 
             if attention_mask is not None:
@@ -518,7 +526,12 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index ceeefbec8851..3db7d0fa717c 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -25,6 +25,7 @@
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -33,8 +34,6 @@
     logging,
 )
 from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -54,32 +53,6 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
 
-_causal_conv1d_cache = None
-
-
-def _lazy_load_causal_conv1d():
-    global _causal_conv1d_cache
-    if _causal_conv1d_cache is not None:
-        return _causal_conv1d_cache
-
-    if is_kernels_available():
-        from kernels import get_kernel
-
-        try:
-            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
-        except FileNotFoundError:
-            # no kernel binary match, fallback to slow path
-            _causal_conv1d_cache = (None, None)
-        else:
-            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
-    elif is_causal_conv1d_available():
-        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
-    else:
-        _causal_conv1d_cache = (None, None)
-    return _causal_conv1d_cache
-
 
 class MambaCache:
     """
@@ -236,7 +209,12 @@ def __init__(self, config: MambaConfig, layer_idx: int):
         self.warn_slow_implementation()
 
     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -287,7 +265,12 @@ def cuda_kernels_forward(
             )
 
         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)
 
             if attention_mask is not None:
@@ -451,7 +434,12 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
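Usage sketch (not part of the diff): a minimal illustration of how a modeling file consumes the new `lazy_load_kernel` helper in place of the removed per-file `_lazy_load_causal_conv1d` cache. It only restates the pattern introduced in the Mamba/FalconMamba mixers above; the attribute names on the returned module are the ones used there.

from transformers.integrations.hub_kernels import lazy_load_kernel

# Resolve "causal-conv1d" either from the hub via `kernels` or from the locally
# installed causal_conv1d package; the result (module or None) is cached in
# _KERNEL_MODULE_MAPPING, so repeated calls are cheap.
causal_conv1d = lazy_load_kernel("causal-conv1d")

# Fall back to the slow PyTorch path when no kernel implementation is available.
causal_conv1d_update, causal_conv1d_fn = (
    (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
    if causal_conv1d is not None
    else (None, None)
)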