55 changes: 55 additions & 0 deletions src/transformers/integrations/hub_kernels.py
@@ -14,12 +14,16 @@
import re
from collections.abc import Callable
from functools import partial
from types import ModuleType
from typing import Optional, Union

from ..modeling_flash_attention_utils import lazy_import_flash_attention
from ..utils import logging
from .flash_attention import flash_attention_forward


logger = logging.get_logger(__name__)

try:
from kernels import (
Device,
@@ -158,6 +162,13 @@ def register_kernel_mapping(*args, **kwargs):
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")


_HUB_KERNEL_MAPPING: dict[str, str] = {
"causal-conv1d": "kernels-community/causal-conv1d",
}

_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}


def is_kernel(attn_implementation: Optional[str]) -> bool:
"""Check whether `attn_implementation` matches a kernel pattern from the hub."""
return (
@@ -220,9 +231,53 @@ def load_and_register_attn_kernel(attn_implementation: str, attention_wrapper: O
ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])


def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] = _KERNEL_MODULE_MAPPING):
Collaborator: this default should be `None`; when it is `None`, fall back to the kernel mapping inside the function. Python will cry otherwise (a mutable default argument is evaluated once and shared across every call)!
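A minimal sketch of the pattern the reviewer is suggesting (illustrative only, not part of the diff):

# Sketch of the reviewer's suggestion (hypothetical, not in the PR): default
# to None and fall back to the shared module-level cache inside the function,
# so a mutable dict never appears as a default argument value.
def lazy_load_kernel(kernel_name: str, mapping: Optional[dict[str, Optional[ModuleType]]] = None):
    if mapping is None:
        mapping = _KERNEL_MODULE_MAPPING
    ...  # body continues as in the diff below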

    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
        return mapping[kernel_name]
    if kernel_name not in _HUB_KERNEL_MAPPING:
        logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
        mapping[kernel_name] = None
        return None
    if _kernels_available:
        from kernels import get_kernel

        try:
            kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name])
            mapping[kernel_name] = kernel
        except FileNotFoundError:
            # No kernel binary matches this environment; fall back to the slow path.
            mapping[kernel_name] = None

    else:
        # Try to import is_{kernel_name}_available from ..utils
        import importlib

        new_kernel_name = kernel_name.replace("-", "_")
        func_name = f"is_{new_kernel_name}_available"

        try:
            utils_mod = importlib.import_module("..utils.import_utils", __package__)
            is_kernel_available = getattr(utils_mod, func_name, None)
        except Exception:
            is_kernel_available = None

        if callable(is_kernel_available) and is_kernel_available():
            # Import the top-level module (e.g. `causal_conv1d`); the hyphenated
            # kernel name is not a valid Python module name.
            try:
                module = importlib.import_module(new_kernel_name)
                mapping[kernel_name] = module
                return module
            except Exception:
                mapping[kernel_name] = None
        else:
            mapping[kernel_name] = None

    return mapping[kernel_name]
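For reference, a minimal usage sketch of the new helper (illustrative, mirroring how the model files below consume it):

# Usage sketch (illustrative): resolve the causal-conv1d kernel once and
# fall back to the slow path when no implementation is available.
causal_conv1d = lazy_load_kernel("causal-conv1d")
if causal_conv1d is not None:
    causal_conv1d_update = causal_conv1d.causal_conv1d_update
    causal_conv1d_fn = causal_conv1d.causal_conv1d_fn
else:
    causal_conv1d_update, causal_conv1d_fn = None, None  # slow path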


__all__ = [
"LayerRepository",
"use_kernel_forward_from_hub",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"lazy_load_kernel",
]
51 changes: 19 additions & 32 deletions src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -30,12 +30,11 @@
from ...activations import ACT2FN
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.import_utils import (
is_causal_conv1d_available,
is_kernels_available,
is_mamba_ssm_available,
is_mambapy_available,
)
@@ -162,33 +161,6 @@ def reset(self):
self.ssm_states[layer_idx].zero_()


def _lazy_load_causal_conv1d():
global _causal_conv1d_cache
if _causal_conv1d_cache is not None:
return _causal_conv1d_cache

if is_kernels_available():
from kernels import get_kernel

try:
_causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
except FileNotFoundError:
# no kernel binary match, fallback to slow path
_causal_conv1d_cache = (None, None)
else:
_causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
elif is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

_causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
else:
_causal_conv1d_cache = (None, None)
return _causal_conv1d_cache


_causal_conv1d_cache = None


def rms_forward(hidden_states, variance_epsilon=1e-6):
"""
Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
@@ -268,7 +240,12 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int):
self.rms_eps = config.mixer_rms_eps

def warn_slow_implementation(self):
causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
if causal_conv1d is not None
else (None, None)
)
Collaborator (on lines +243 to +248): nice and simple

is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
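The `lazy_load_kernel` / unpack pattern above recurs at each call site in the three files below; a small helper along these lines (hypothetical, not part of the PR) could factor it out:

# Hypothetical helper (not in the PR): centralize the repeated unpacking of
# the lazily loaded causal-conv1d module.
def _get_causal_conv1d_fns():
    causal_conv1d = lazy_load_kernel("causal-conv1d")
    if causal_conv1d is None:
        return None, None
    return causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn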
@@ -323,7 +300,12 @@ def cuda_kernels_forward(
)

else:
causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
if causal_conv1d is not None
else (None, None)
)
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
@@ -518,7 +500,12 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
if causal_conv1d is not None
else (None, None)
)
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
25 changes: 19 additions & 6 deletions src/transformers/models/falcon_mamba/modular_falcon_mamba.py
@@ -19,6 +19,7 @@
import torch
from torch import nn

from ...integrations.hub_kernels import lazy_load_kernel
from ...utils import auto_docstring, logging
from ...utils.import_utils import (
is_mamba_ssm_available,
@@ -35,7 +36,6 @@
MambaOutput,
MambaPreTrainedModel,
MambaRMSNorm,
_lazy_load_causal_conv1d,
)


@@ -54,8 +54,6 @@
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

_causal_conv1d_cache = None


class FalconMambaConfig(MambaConfig):
"""
@@ -258,7 +256,12 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):

class FalconMambaMixer(MambaMixer):
def warn_slow_implementation(self):
causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
if causal_conv1d is not None
else (None, None)
)
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
@@ -324,7 +327,12 @@ def cuda_kernels_forward(
)

else:
causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
if causal_conv1d is not None
else (None, None)
)
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
@@ -518,7 +526,12 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
if causal_conv1d is not None
else (None, None)
)
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
50 changes: 19 additions & 31 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -25,6 +25,7 @@
from ...activations import ACT2FN
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -33,8 +34,6 @@
logging,
)
from ...utils.import_utils import (
is_causal_conv1d_available,
is_kernels_available,
is_mamba_ssm_available,
is_mambapy_available,
)
@@ -54,32 +53,6 @@
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

_causal_conv1d_cache = None


def _lazy_load_causal_conv1d():
global _causal_conv1d_cache
if _causal_conv1d_cache is not None:
return _causal_conv1d_cache

if is_kernels_available():
from kernels import get_kernel

try:
_causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
except FileNotFoundError:
# no kernel binary match, fallback to slow path
_causal_conv1d_cache = (None, None)
else:
_causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
elif is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

_causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
else:
_causal_conv1d_cache = (None, None)
return _causal_conv1d_cache


class MambaCache:
"""
@@ -236,7 +209,12 @@ def __init__(self, config: MambaConfig, layer_idx: int):
self.warn_slow_implementation()

def warn_slow_implementation(self):
causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
if causal_conv1d is not None
else (None, None)
)
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
@@ -287,7 +265,12 @@ def cuda_kernels_forward(
)

else:
causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
if causal_conv1d is not None
else (None, None)
)
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
@@ -451,7 +434,12 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
(causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
if causal_conv1d is not None
else (None, None)
)
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)