diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index e1907a8fab88..c2d874474750 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -370,7 +370,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
     if callable(is_kernel_available) and is_kernel_available():
         # Try to import the module "{kernel_name}" from parent package level
         try:
-            module = importlib.import_module(f"{kernel_name}")
+            module = importlib.import_module(f"{new_kernel_name}")
             mapping[kernel_name] = module
             return module
         except Exception:
diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py
index 9a5790c1fb6a..afcce31c7c62 100644
--- a/src/transformers/models/bamba/modeling_bamba.py
+++ b/src/transformers/models/bamba/modeling_bamba.py
@@ -36,6 +36,7 @@
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -44,22 +45,9 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import maybe_autocast
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig
 
 
-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)
@@ -501,9 +489,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     return hidden_states
 
 
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class BambaMixer(nn.Module):
     """
@@ -575,6 +560,20 @@ def __init__(self, config: BambaConfig, layer_idx: int):
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
 
+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py
index d29273940b8a..de8e7e77693d 100644
--- a/src/transformers/models/bamba/modular_bamba.py
+++ b/src/transformers/models/bamba/modular_bamba.py
@@ -43,6 +43,7 @@
 )
 
 from ... import initialization as init
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
@@ -52,24 +53,9 @@
 from ...processing_utils import Unpack
 from ...utils import (
     can_return_tuple,
     logging,
 )
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig
 
 
-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 logger = logging.get_logger(__name__)
@@ -276,6 +262,20 @@ def __init__(self, config: BambaConfig, layer_idx: int):
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
 
+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index 2116a9811667..8feabb5a363d 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -32,6 +32,7 @@
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -40,22 +41,9 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import check_model_inputs, maybe_autocast
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig
 
 
-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)
@@ -371,9 +359,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     return hidden_states
 
 
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class GraniteMoeHybridMambaLayer(nn.Module):
     """
@@ -445,6 +430,20 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
 
+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 9314ff7c90c6..0797dbb31dd0 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -33,6 +33,7 @@
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -40,22 +41,9 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import OutputRecorder, check_model_inputs
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
 from .configuration_jamba import JambaConfig
 
 
-if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-else:
-    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)
@@ -306,11 +294,6 @@ def forward(
     return attn_output, attn_weights
 
 
-is_fast_path_available = all(
-    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
-)
-
-
 class JambaMambaMixer(nn.Module):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
@@ -364,6 +347,22 @@ def __init__(self, config: JambaConfig, layer_idx):
 
         self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
         self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
 
+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_inner_fn, selective_scan_fn
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
+        selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all(
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+        )
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py
index 343dc19f9707..4bfdbc51c7ce 100644
--- a/src/transformers/models/jamba/modular_jamba.py
+++ b/src/transformers/models/jamba/modular_jamba.py
@@ -25,6 +25,7 @@
 from ... import initialization as init
 from ...activations import ACT2FN
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -32,29 +33,12 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, logging
 from ...utils.generic import OutputRecorder, check_model_inputs
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
 from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, eager_attention_forward
 from ..mistral.modeling_mistral import MistralMLP
 from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM
 from .configuration_jamba import JambaConfig
 
 
-if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-else:
-    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-is_fast_path_available = all(
-    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
-)
-
-
 logger = logging.get_logger(__name__)
@@ -258,6 +242,22 @@ def __init__(self, config: JambaConfig, layer_idx):
 
         self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
         self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
 
+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_inner_fn, selective_scan_fn
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
+        selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all(
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+        )
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index d2e7add0d4f3..caa370b5a309 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -24,6 +24,7 @@
 from ... import initialization as init
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -31,35 +32,12 @@
     auto_docstring,
     logging,
 )
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_mamba2 import Mamba2Config
 
 
 logger = logging.get_logger(__name__)
 
 
-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-is_fast_path_available = all(
-    (
-        selective_state_update,
-        mamba_chunk_scan_combined,
-        mamba_split_conv1d_scan_combined,
-        causal_conv1d_fn,
-        causal_conv1d_update,
-    )
-)
-
-
 # Helper methods for segment sum computation
@@ -286,6 +264,28 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
         self.use_bias = config.use_bias
 
+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all(
+            (
+                selective_state_update,
+                mamba_chunk_scan_combined,
+                mamba_split_conv1d_scan_combined,
+                causal_conv1d_fn,
+                causal_conv1d_update,
+            )
+        )
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
index f67d83282696..535ec9a8a075 100644
--- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py
+++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
@@ -45,10 +45,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
-from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_flash_linear_attention_available,
-)
+from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
 from .configuration_qwen3_next import Qwen3NextConfig
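
Every mixer `__init__` in this patch resolves its optional kernels the same way, so the behavior can be illustrated in isolation. Below is a minimal, hypothetical sketch of that pattern; `resolve_mamba_kernels` is an invented name, and only `lazy_load_kernel`, the kernel names, and the `getattr` guards come from the diff itself. It assumes `lazy_load_kernel` returns the loaded module, or `None` when the kernel cannot be loaded, which is how the mixers guard the result.

# Hypothetical standalone sketch of the lazy kernel-resolution pattern above.
# Assumes the absolute import path mirrors the relative one used in the diff.
from transformers.integrations.hub_kernels import lazy_load_kernel


def resolve_mamba_kernels() -> bool:
    # Resolve each kernel module once; the result may be None if the kernel
    # is not installed or fails to import.
    causal_conv1d = lazy_load_kernel("causal-conv1d")
    mamba_ssm = lazy_load_kernel("mamba-ssm")

    # getattr(..., None) keeps every symbol defined even when the module is
    # None, so availability reduces to a simple truthiness check.
    causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
    causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
    selective_state_update = getattr(mamba_ssm, "selective_state_update", None)

    # The fast path is usable only if every required function resolved.
    return all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

Because resolution now happens at layer construction rather than at module import, code that never instantiates a Mamba layer never touches `mamba_ssm` or `causal_conv1d`; the first constructed layer computes `is_fast_path_available` once and, via the `global` statements, publishes it at module level where the existing `forward` paths already read it.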