From 4f3c38dc86419e78285c2fe26dee58123c569c11 Mon Sep 17 00:00:00 2001 From: romit Date: Mon, 13 Oct 2025 08:56:21 +0000 Subject: [PATCH 01/16] Added kernels from kernel hub for Bamba model --- .../models/bamba/modular_bamba.py | 55 +++++++++++++++---- src/transformers/utils/import_utils.py | 4 ++ 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 85b6fed82efb..baa5119ef53f 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -50,20 +50,47 @@ can_return_tuple, logging, ) -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_einops_available, + is_kernels_available, + is_mamba_2_ssm_available, +) from .configuration_bamba import BambaConfig -if is_mamba_2_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -else: - selective_state_update = None +selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None +causal_conv1d_update, causal_conv1d_fn = None, None + + +def _lazy_load_mamba2_ssm(): + global selective_state_update + + if is_kernels_available() and is_einops_available(): + from kernels import get_kernel + + mamba_ssm = get_kernel("kernels-community/mamba-ssm") + + selective_state_update = mamba_ssm.ops.triton.selective_state_update.selective_state_update + mamba_chunk_scan_combined = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined + mamba_split_conv1d_scan_combined = mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined + + elif is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None + +def _lazy_load_causal_conv1d(): + global causal_conv1d_update, causal_conv1d_fn + + if is_kernels_available() and is_einops_available(): + from kernels import get_kernel + + causal_conv1d = get_kernel("kernels-community/causal-conv1d") + causal_conv1d_fn = causal_conv1d.causal_conv1d_fn + causal_conv1d_update = causal_conv1d.causal_conv1d_update + + elif is_causal_conv1d_available: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) @@ -285,11 +312,17 @@ def __init__(self, config: BambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + _lazy_load_causal_conv1d() + _lazy_load_mamba2_ssm() + + global is_fast_path_available + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" - " https://github.com/Dao-AILab/causal-conv1d" + " https://github.com/Dao-AILab/causal-conv1d or install the kernels library using `pip install kernels`" ) else: logger.warning_once("The fast path for Bamba will be used when running the model on a GPU") diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index a956efc97fdb..45a9141e6249 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1166,6 +1166,10 @@ def is_matplotlib_available() -> bool: def is_mistral_common_available() -> bool: return _is_package_available("mistral_common") +@lru_cache +def is_einops_available() -> bool: + return _is_package_available("einops") + def check_torch_load_is_safe() -> None: if not is_torch_greater_or_equal("2.6"): From a8920655018ccfa1a9768cac0ee64a494541f7f6 Mon Sep 17 00:00:00 2001 From: romit Date: Mon, 10 Nov 2025 17:06:07 +0000 Subject: [PATCH 02/16] Updated kernel loading Signed-off-by: romit --- src/transformers/integrations/hub_kernels.py | 6 +- .../models/bamba/modular_bamba.py | 63 ++++++++----------- .../models/jamba/modular_jamba.py | 41 ++++++++---- .../models/mamba2/modeling_mamba2.py | 61 ++++++++++++------ src/transformers/utils/import_utils.py | 1 + 5 files changed, 99 insertions(+), 73 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 112ac670e9a0..a5505fb76ab5 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -185,6 +185,7 @@ def register_kernel_mapping(*args, **kwargs): _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = { "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"}, + "mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "revision": "clean-mamba-ssm"}, } _KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {} @@ -264,8 +265,9 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] try: repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] + revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) - kernel = get_kernel(repo_id, version=version) + kernel = get_kernel(repo_id, revision=revision, version=version) mapping[kernel_name] = kernel except FileNotFoundError: mapping[kernel_name] = None @@ -286,7 +288,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] if callable(is_kernel_available) and is_kernel_available(): # Try to import the module "{kernel_name}" from parent package level try: - module = importlib.import_module(f"{kernel_name}") + module = importlib.import_module(f"{new_kernel_name}") mapping[kernel_name] = module return module except Exception: diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index b4bfe0109bb5..c211eea303ac 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -42,6 +42,7 @@ segment_sum, ) +from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -51,50 +52,37 @@ can_return_tuple, logging, ) -from ...utils.import_utils import ( - is_causal_conv1d_available, - is_einops_available, - is_kernels_available, - is_mamba_2_ssm_available, -) from .configuration_bamba import BambaConfig -selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None -causal_conv1d_update, causal_conv1d_fn = None, None - - -def _lazy_load_mamba2_ssm(): - global selective_state_update - - if is_kernels_available() and is_einops_available(): - from kernels import get_kernel - - mamba_ssm = get_kernel("kernels-community/mamba-ssm") - - selective_state_update = mamba_ssm.ops.triton.selective_state_update.selective_state_update - mamba_chunk_scan_combined = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined - mamba_split_conv1d_scan_combined = mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined - - elif is_mamba_2_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - - -def _lazy_load_causal_conv1d(): +def _lazy_load_kernels(): global causal_conv1d_update, causal_conv1d_fn - if is_kernels_available() and is_einops_available(): - from kernels import get_kernel - - causal_conv1d = get_kernel("kernels-community/causal-conv1d") - causal_conv1d_fn = causal_conv1d.causal_conv1d_fn - causal_conv1d_update = causal_conv1d.causal_conv1d_update + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) + + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = ( + ( + mamba_ssm.ops.triton.selective_state_update.selective_state_update, + mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined, + mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined, + ) + if mamba_ssm is not None + else (None, None, None) + ) - elif is_causal_conv1d_available: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) +selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None +causal_conv1d_update, causal_conv1d_fn = None, None +is_fast_path_available = False logger = logging.get_logger(__name__) @@ -302,8 +290,7 @@ def __init__(self, config: BambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) - _lazy_load_causal_conv1d() - _lazy_load_mamba2_ssm() + _lazy_load_kernels() global is_fast_path_available is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index c6cfe339fabb..f3f92cb2c14e 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -24,6 +24,7 @@ from torch import nn from ...activations import ACT2FN +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -31,27 +32,35 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import OutputRecorder, check_model_inputs -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, eager_attention_forward from ..mistral.modeling_mistral import MistralMLP from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM from .configuration_jamba import JambaConfig -if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None +def _lazy_load_kernels(): + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) + ) + + global selective_state_update, mamba_inner_fn, selective_scan_fn + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update, mamba_inner_fn, selective_scan_fn = ( + ( + mamba_ssm.ops.triton.selective_state_update.selective_state_update, + mamba_ssm.ops.selective_scan_interface.mamba_inner_fn, + mamba_ssm.ops.selective_scan_interface.selective_scan_fn, + ) + if mamba_ssm is not None + else (None, None, None) + ) -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) +is_fast_path_available = False logger = logging.get_logger(__name__) @@ -257,6 +266,12 @@ def __init__(self, config: JambaConfig, layer_idx): self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + _lazy_load_kernels() + + global is_fast_path_available + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) if not is_fast_path_available: logger.warning_once( "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 6f1f31b9002c..90bd7f2ea85e 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin +from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -30,33 +31,40 @@ auto_docstring, logging, ) -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_mamba2 import Mamba2Config logger = logging.get_logger(__name__) -if is_mamba_2_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -else: - mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - -is_fast_path_available = all( - ( - selective_state_update, - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - causal_conv1d_fn, - causal_conv1d_update, +def _lazy_load_kernels(): + global causal_conv1d_update, causal_conv1d_fn + + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) ) -) + + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = ( + ( + mamba_ssm.ops.triton.selective_state_update.selective_state_update, + mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined, + mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined, + ) + if mamba_ssm is not None + else (None, None, None) + ) + + +selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None +causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = False # Helper methods for segment sum computation @@ -285,6 +293,19 @@ def __init__(self, config: Mamba2Config, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + _lazy_load_kernels() + global is_fast_path_available + + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 5c15601fecb6..3c2c4797ff47 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1175,6 +1175,7 @@ def is_matplotlib_available() -> bool: def is_mistral_common_available() -> bool: return _is_package_available("mistral_common") + @lru_cache def is_einops_available() -> bool: return _is_package_available("einops") From 5ed0fc1951bb4c9b9a57b5abf89291a7996aa55e Mon Sep 17 00:00:00 2001 From: romit Date: Wed, 12 Nov 2025 15:41:52 +0000 Subject: [PATCH 03/16] Remove einops Signed-off-by: romit --- src/transformers/utils/import_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 3c2c4797ff47..b38ea64cc4ff 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1176,11 +1176,6 @@ def is_mistral_common_available() -> bool: return _is_package_available("mistral_common") -@lru_cache -def is_einops_available() -> bool: - return _is_package_available("einops") - - @lru_cache def is_opentelemetry_available() -> bool: return _is_package_available("opentelemetry") and version.parse( From 0a4f79b5f37896835bbdae87eb5ff2a7ed9bd046 Mon Sep 17 00:00:00 2001 From: romit Date: Wed, 12 Nov 2025 22:11:06 +0530 Subject: [PATCH 04/16] Removed global vars Signed-off-by: romit --- .../models/bamba/modular_bamba.py | 45 +++++--------- .../models/jamba/modular_jamba.py | 40 ++++-------- .../models/mamba2/modeling_mamba2.py | 62 ++++++++----------- 3 files changed, 52 insertions(+), 95 deletions(-) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index c211eea303ac..b200521d93b8 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -55,34 +55,22 @@ from .configuration_bamba import BambaConfig -def _lazy_load_kernels(): - global causal_conv1d_update, causal_conv1d_fn - - causal_conv1d = lazy_load_kernel("causal-conv1d") - causal_conv1d_update, causal_conv1d_fn = ( - (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) - if causal_conv1d is not None - else (None, None) - ) - - global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined - - mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = ( - ( - mamba_ssm.ops.triton.selective_state_update.selective_state_update, - mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined, - mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined, - ) - if mamba_ssm is not None - else (None, None, None) - ) - +causal_conv1d = lazy_load_kernel("causal-conv1d") +causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None +causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None -selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None -causal_conv1d_update, causal_conv1d_fn = None, None +mamba_ssm = lazy_load_kernel("mamba-ssm") +selective_state_update = ( + mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None +) +mamba_chunk_scan_combined = ( + mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined if mamba_ssm is not None else None +) +mamba_split_conv1d_scan_combined = ( + mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined if mamba_ssm is not None else None +) -is_fast_path_available = False +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) logger = logging.get_logger(__name__) @@ -290,11 +278,6 @@ def __init__(self, config: BambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) - _lazy_load_kernels() - - global is_fast_path_available - is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index f3f92cb2c14e..74976eb42992 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -38,29 +38,21 @@ from .configuration_jamba import JambaConfig -def _lazy_load_kernels(): - global causal_conv1d_update, causal_conv1d_fn - causal_conv1d = lazy_load_kernel("causal-conv1d") - causal_conv1d_update, causal_conv1d_fn = ( - (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) - if causal_conv1d is not None - else (None, None) - ) - - global selective_state_update, mamba_inner_fn, selective_scan_fn - mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update, mamba_inner_fn, selective_scan_fn = ( - ( - mamba_ssm.ops.triton.selective_state_update.selective_state_update, - mamba_ssm.ops.selective_scan_interface.mamba_inner_fn, - mamba_ssm.ops.selective_scan_interface.selective_scan_fn, - ) - if mamba_ssm is not None - else (None, None, None) - ) +causal_conv1d = lazy_load_kernel("causal-conv1d") +causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None +causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None + +mamba_ssm = lazy_load_kernel("mamba-ssm") +selective_state_update = ( + mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None +) +mamba_inner_fn = mamba_ssm.ops.selective_scan_interface.mamba_inner_fn if mamba_ssm is not None else None +selective_scan_fn = mamba_ssm.ops.selective_scan_interface.selective_scan_fn if mamba_ssm is not None else None -is_fast_path_available = False +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) logger = logging.get_logger(__name__) @@ -266,12 +258,6 @@ def __init__(self, config: JambaConfig, layer_idx): self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) - _lazy_load_kernels() - - global is_fast_path_available - is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) - ) if not is_fast_path_available: logger.warning_once( "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 90bd7f2ea85e..2b19f6c4f430 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -37,34 +37,35 @@ logger = logging.get_logger(__name__) -def _lazy_load_kernels(): - global causal_conv1d_update, causal_conv1d_fn - - causal_conv1d = lazy_load_kernel("causal-conv1d") - causal_conv1d_update, causal_conv1d_fn = ( - (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) - if causal_conv1d is not None - else (None, None) - ) +causal_conv1d = lazy_load_kernel("causal-conv1d") +causal_conv1d_update, causal_conv1d_fn = ( + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) + if causal_conv1d is not None + else (None, None) +) - global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined - mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = ( - ( - mamba_ssm.ops.triton.selective_state_update.selective_state_update, - mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined, - mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined, - ) - if mamba_ssm is not None - else (None, None, None) +mamba_ssm = lazy_load_kernel("mamba-ssm") +selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = ( + ( + mamba_ssm.ops.triton.selective_state_update.selective_state_update, + mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined, + mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined, ) + if mamba_ssm is not None + else (None, None, None) +) - -selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None -causal_conv1d_update, causal_conv1d_fn = None, None - -is_fast_path_available = False +is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) +) # Helper methods for segment sum computation @@ -293,19 +294,6 @@ def __init__(self, config: Mamba2Config, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias - _lazy_load_kernels() - global is_fast_path_available - - is_fast_path_available = all( - ( - selective_state_update, - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - causal_conv1d_fn, - causal_conv1d_update, - ) - ) - if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" From db594e9e512482ba1bbb7fa2ca5e2530b2de54a5 Mon Sep 17 00:00:00 2001 From: romit Date: Wed, 12 Nov 2025 17:18:59 +0000 Subject: [PATCH 05/16] Fixed make style Signed-off-by: romit --- src/transformers/models/mamba2/modeling_mamba2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 2b19f6c4f430..dce4cb2f4a39 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -39,9 +39,7 @@ causal_conv1d = lazy_load_kernel("causal-conv1d") causal_conv1d_update, causal_conv1d_fn = ( - (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) - if causal_conv1d is not None - else (None, None) + (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None else (None, None) ) global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined From 2ef69e6b7069bb5573b7c9ab363277d9238a2230 Mon Sep 17 00:00:00 2001 From: romit Date: Wed, 12 Nov 2025 17:21:02 +0000 Subject: [PATCH 06/16] Nit Signed-off-by: romit --- src/transformers/models/bamba/modular_bamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index b200521d93b8..cdd1dffa4eba 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -282,7 +282,7 @@ def __init__(self, config: BambaConfig, layer_idx: int): logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" - " https://github.com/Dao-AILab/causal-conv1d or install the kernels library using `pip install kernels`" + " https://github.com/Dao-AILab/causal-conv1d" ) else: logger.warning_once("The fast path for Bamba will be used when running the model on a GPU") From 8ae9ce5712955beb63b083d680bbd43f3a60af89 Mon Sep 17 00:00:00 2001 From: romit Date: Thu, 13 Nov 2025 16:41:57 +0000 Subject: [PATCH 07/16] Added modeling files Signed-off-by: romit --- .../models/bamba/modeling_bamba.py | 29 ++++++++++--------- .../modeling_granitemoehybrid.py | 29 ++++++++++--------- .../models/jamba/modeling_jamba.py | 26 ++++++++--------- .../models/qwen3_next/modeling_qwen3_next.py | 5 +--- 4 files changed, 46 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 9285068292ad..1f857abe4589 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -35,6 +35,7 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -42,22 +43,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_bamba import BambaConfig -if is_mamba_2_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -else: - selective_state_update = None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - - logger = logging.get_logger(__name__) @@ -498,6 +486,21 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return hidden_states +causal_conv1d = lazy_load_kernel("causal-conv1d") +causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None +causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None + +mamba_ssm = lazy_load_kernel("mamba-ssm") +selective_state_update = ( + mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None +) +mamba_chunk_scan_combined = ( + mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined if mamba_ssm is not None else None +) +mamba_split_conv1d_scan_combined = ( + mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined if mamba_ssm is not None else None +) + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 947d250cd134..dccb71f52b13 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -31,6 +31,7 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -39,22 +40,9 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import check_model_inputs -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_granitemoehybrid import GraniteMoeHybridConfig -if is_mamba_2_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -else: - selective_state_update = None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - - logger = logging.get_logger(__name__) @@ -334,6 +322,21 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return hidden_states +causal_conv1d = lazy_load_kernel("causal-conv1d") +causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None +causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None + +mamba_ssm = lazy_load_kernel("mamba-ssm") +selective_state_update = ( + mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None +) +mamba_chunk_scan_combined = ( + mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined if mamba_ssm is not None else None +) +mamba_split_conv1d_scan_combined = ( + mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined if mamba_ssm is not None else None +) + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 94d8cdc3f7be..cebd0eabfefe 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -32,6 +32,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -39,22 +40,9 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import OutputRecorder, check_model_inputs -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_jamba import JambaConfig -if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - - logger = logging.get_logger(__name__) @@ -269,6 +257,18 @@ def forward( return attn_output, attn_weights +causal_conv1d = lazy_load_kernel("causal-conv1d") +causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None +causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None + +mamba_ssm = lazy_load_kernel("mamba-ssm") +selective_state_update = ( + mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None +) +mamba_inner_fn = mamba_ssm.ops.selective_scan_interface.mamba_inner_fn if mamba_ssm is not None else None +selective_scan_fn = mamba_ssm.ops.selective_scan_interface.selective_scan_fn if mamba_ssm is not None else None + + is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 3847c43117a3..8c8c8f231211 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -43,10 +43,7 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import OutputRecorder, check_model_inputs -from ...utils.import_utils import ( - is_causal_conv1d_available, - is_flash_linear_attention_available, -) +from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from .configuration_qwen3_next import Qwen3NextConfig From fa6ea6e95f2838067d2f64f7987f70c611bd66b1 Mon Sep 17 00:00:00 2001 From: romit Date: Mon, 17 Nov 2025 06:10:01 +0000 Subject: [PATCH 08/16] Fixed merge conflict Signed-off-by: romit --- src/transformers/models/bamba/modular_bamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 5ea700839a68..440593f04679 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -42,8 +42,8 @@ segment_sum, ) -from ...integrations.hub_kernels import lazy_load_kernel from ... import initialization as init +from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel From 74c986a447b6ca3596cf3685feeaf7fb11d7677a Mon Sep 17 00:00:00 2001 From: romitjain Date: Mon, 8 Dec 2025 06:34:35 +0000 Subject: [PATCH 09/16] fixed lint Signed-off-by: romitjain --- src/transformers/models/bamba/modular_bamba.py | 1 + src/transformers/models/jamba/modeling_jamba.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 440593f04679..81905dd56f3e 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -73,6 +73,7 @@ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + logger = logging.get_logger(__name__) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 9c7d2107a155..0d7d2d22fed0 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -32,8 +32,8 @@ from ... import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin -from ...integrations.hub_kernels import lazy_load_kernel from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast From d1d78201bc42213c94691ff5c76bf1d875a8c4ab Mon Sep 17 00:00:00 2001 From: romit <11757603+romitjain@users.noreply.github.com> Date: Mon, 8 Dec 2025 18:31:18 +0530 Subject: [PATCH 10/16] Removed global import --- src/transformers/models/mamba2/modeling_mamba2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index d16e053ebd61..20f6500617a9 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -43,8 +43,6 @@ (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None else (None, None) ) -global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined - mamba_ssm = lazy_load_kernel("mamba-ssm") selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = ( ( From a3800e87515ca398bf7905a345447bcb35ef65f3 Mon Sep 17 00:00:00 2001 From: romit Date: Tue, 9 Dec 2025 19:42:26 +0530 Subject: [PATCH 11/16] Small updates --- .../models/bamba/modular_bamba.py | 18 ++++++++---------- .../models/jamba/modular_jamba.py | 15 +++++++++------ .../models/mamba2/modeling_mamba2.py | 19 ++++++++----------- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 81905dd56f3e..ec1373562949 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -57,19 +57,17 @@ causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None -causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None +causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) +causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = ( - mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None -) -mamba_chunk_scan_combined = ( - mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined if mamba_ssm is not None else None -) -mamba_split_conv1d_scan_combined = ( - mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined if mamba_ssm is not None else None +mamba_ssm_triton = getattr(getattr(mamba_ssm, "ops", None), "triton", None) +selective_state_update = getattr( + getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None ) +ssd_combined = getattr(mamba_ssm_triton, "ssd_combined", None) +mamba_chunk_scan_combined = getattr(ssd_combined, "mamba_chunk_scan_combined", None) +mamba_split_conv1d_scan_combined = getattr(ssd_combined, "mamba_split_conv1d_scan_combined", None) is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index 7f549e82f387..65c0214f4d67 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -40,15 +40,18 @@ causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None -causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None +causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) +causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = ( - mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None +mamba_ssm_ops = getattr(mamba_ssm, "ops", None) +mamba_ssm_triton = getattr(mamba_ssm_ops, "triton", None) +selective_state_update = getattr( + getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None ) -mamba_inner_fn = mamba_ssm.ops.selective_scan_interface.mamba_inner_fn if mamba_ssm is not None else None -selective_scan_fn = mamba_ssm.ops.selective_scan_interface.selective_scan_fn if mamba_ssm is not None else None +selective_scan_interface = getattr(mamba_ssm_ops, "selective_scan_interface", None) +mamba_inner_fn = getattr(selective_scan_interface, "mamba_inner_fn", None) +selective_scan_fn = getattr(selective_scan_interface, "selective_scan_fn", None) is_fast_path_available = all( diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 20f6500617a9..f95631b39d16 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -39,20 +39,17 @@ causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update, causal_conv1d_fn = ( - (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) if causal_conv1d is not None else (None, None) -) +causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) +causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = ( - ( - mamba_ssm.ops.triton.selective_state_update.selective_state_update, - mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined, - mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined, - ) - if mamba_ssm is not None - else (None, None, None) +mamba_ssm_triton = getattr(getattr(mamba_ssm, "ops", None), "triton", None) +selective_state_update = getattr( + getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None ) +ssd_combined = getattr(mamba_ssm_triton, "ssd_combined", None) +mamba_chunk_scan_combined = getattr(ssd_combined, "mamba_chunk_scan_combined", None) +mamba_split_conv1d_scan_combined = getattr(ssd_combined, "mamba_split_conv1d_scan_combined", None) is_fast_path_available = all( ( From 8e489761210750c79cf06a75de5602a0e1984360 Mon Sep 17 00:00:00 2001 From: romitjain Date: Tue, 9 Dec 2025 14:23:27 +0000 Subject: [PATCH 12/16] Updated --- .../models/bamba/modeling_bamba.py | 18 ++++++++---------- .../modeling_granitemoehybrid.py | 18 ++++++++---------- .../models/jamba/modeling_jamba.py | 15 +++++++++------ 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index da63a47102f3..7be086098eea 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -489,19 +489,17 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None -causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None +causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) +causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = ( - mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None -) -mamba_chunk_scan_combined = ( - mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined if mamba_ssm is not None else None -) -mamba_split_conv1d_scan_combined = ( - mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined if mamba_ssm is not None else None +mamba_ssm_triton = getattr(getattr(mamba_ssm, "ops", None), "triton", None) +selective_state_update = getattr( + getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None ) +ssd_combined = getattr(mamba_ssm_triton, "ssd_combined", None) +mamba_chunk_scan_combined = getattr(ssd_combined, "mamba_chunk_scan_combined", None) +mamba_split_conv1d_scan_combined = getattr(ssd_combined, "mamba_split_conv1d_scan_combined", None) is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 92f7535312d1..50a72db73c95 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -360,19 +360,17 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None -causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None +causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) +causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = ( - mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None -) -mamba_chunk_scan_combined = ( - mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined if mamba_ssm is not None else None -) -mamba_split_conv1d_scan_combined = ( - mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined if mamba_ssm is not None else None +mamba_ssm_triton = getattr(getattr(mamba_ssm, "ops", None), "triton", None) +selective_state_update = getattr( + getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None ) +ssd_combined = getattr(mamba_ssm_triton, "ssd_combined", None) +mamba_chunk_scan_combined = getattr(ssd_combined, "mamba_chunk_scan_combined", None) +mamba_split_conv1d_scan_combined = getattr(ssd_combined, "mamba_split_conv1d_scan_combined", None) is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 0d7d2d22fed0..00b8bb410ab0 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -295,15 +295,18 @@ def forward( causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = causal_conv1d.causal_conv1d_update if causal_conv1d is not None else None -causal_conv1d_fn = causal_conv1d.causal_conv1d_fn if causal_conv1d is not None else None +causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) +causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = ( - mamba_ssm.ops.triton.selective_state_update.selective_state_update if mamba_ssm is not None else None +mamba_ssm_ops = getattr(mamba_ssm, "ops", None) +mamba_ssm_triton = getattr(mamba_ssm_ops, "triton", None) +selective_state_update = getattr( + getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None ) -mamba_inner_fn = mamba_ssm.ops.selective_scan_interface.mamba_inner_fn if mamba_ssm is not None else None -selective_scan_fn = mamba_ssm.ops.selective_scan_interface.selective_scan_fn if mamba_ssm is not None else None +selective_scan_interface = getattr(mamba_ssm_ops, "selective_scan_interface", None) +mamba_inner_fn = getattr(selective_scan_interface, "mamba_inner_fn", None) +selective_scan_fn = getattr(selective_scan_interface, "selective_scan_fn", None) is_fast_path_available = all( From 779486c074c9d22fe2a3e16e1ddcc4fb2b48e04e Mon Sep 17 00:00:00 2001 From: romitjain Date: Tue, 9 Dec 2025 14:38:03 +0000 Subject: [PATCH 13/16] Resolved merge conflicts --- src/transformers/models/bamba/modeling_bamba.py | 2 +- .../models/granitemoehybrid/modeling_granitemoehybrid.py | 2 +- src/transformers/models/jamba/modeling_jamba.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index beda5f06472a..bb51326db02b 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -35,8 +35,8 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations.hub_kernels import lazy_load_kernel from ...integrations import use_kernel_forward_from_hub, use_kernelized_func +from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 73f672dff981..57af59fb6f83 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -31,8 +31,8 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations.hub_kernels import lazy_load_kernel from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index aef0be011b3a..2ac0bda68ed6 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -32,8 +32,8 @@ from ... import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin -from ...integrations.hub_kernels import lazy_load_kernel from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast From 03590f18cf59be7a9e215b946bd7ade3d8b12a7a Mon Sep 17 00:00:00 2001 From: romit Date: Mon, 15 Dec 2025 08:31:13 +0000 Subject: [PATCH 14/16] Fixed the nested import Signed-off-by: romit --- src/transformers/models/bamba/modeling_bamba.py | 10 +++------- src/transformers/models/bamba/modular_bamba.py | 10 +++------- .../granitemoehybrid/modeling_granitemoehybrid.py | 12 ++++-------- src/transformers/models/jamba/modeling_jamba.py | 11 +++-------- src/transformers/models/jamba/modular_jamba.py | 11 +++-------- src/transformers/models/mamba2/modeling_mamba2.py | 9 +++------ .../models/qwen3_next/modeling_qwen3_next.py | 5 +---- 7 files changed, 20 insertions(+), 48 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 2c6a7f30163c..1e5631c514d6 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -494,13 +494,9 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -mamba_ssm_triton = getattr(getattr(mamba_ssm, "ops", None), "triton", None) -selective_state_update = getattr( - getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None -) -ssd_combined = getattr(mamba_ssm_triton, "ssd_combined", None) -mamba_chunk_scan_combined = getattr(ssd_combined, "mamba_chunk_scan_combined", None) -mamba_split_conv1d_scan_combined = getattr(ssd_combined, "mamba_split_conv1d_scan_combined", None) +selective_state_update = getattr(mamba_ssm, "selective_state_update", None) +mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) +mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index ec1373562949..6b149570fc96 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -61,13 +61,9 @@ causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -mamba_ssm_triton = getattr(getattr(mamba_ssm, "ops", None), "triton", None) -selective_state_update = getattr( - getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None -) -ssd_combined = getattr(mamba_ssm_triton, "ssd_combined", None) -mamba_chunk_scan_combined = getattr(ssd_combined, "mamba_chunk_scan_combined", None) -mamba_split_conv1d_scan_combined = getattr(ssd_combined, "mamba_split_conv1d_scan_combined", None) +selective_state_update = getattr(mamba_ssm, "selective_state_update", None) +mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) +mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 3f93e6ce88eb..8e7b3753926a 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -40,7 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_granitemoehybrid import GraniteMoeHybridConfig @@ -364,13 +364,9 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -mamba_ssm_triton = getattr(getattr(mamba_ssm, "ops", None), "triton", None) -selective_state_update = getattr( - getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None -) -ssd_combined = getattr(mamba_ssm_triton, "ssd_combined", None) -mamba_chunk_scan_combined = getattr(ssd_combined, "mamba_chunk_scan_combined", None) -mamba_split_conv1d_scan_combined = getattr(ssd_combined, "mamba_split_conv1d_scan_combined", None) +selective_state_update = getattr(mamba_ssm, "selective_state_update", None) +mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) +mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 2ac0bda68ed6..d07e962ce71f 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -299,14 +299,9 @@ def forward( causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -mamba_ssm_ops = getattr(mamba_ssm, "ops", None) -mamba_ssm_triton = getattr(mamba_ssm_ops, "triton", None) -selective_state_update = getattr( - getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None -) -selective_scan_interface = getattr(mamba_ssm_ops, "selective_scan_interface", None) -mamba_inner_fn = getattr(selective_scan_interface, "mamba_inner_fn", None) -selective_scan_fn = getattr(selective_scan_interface, "selective_scan_fn", None) +selective_state_update = getattr(mamba_ssm, "selective_state_update", None) +mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) +selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) is_fast_path_available = all( diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index 65c0214f4d67..9c90272067a6 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -44,14 +44,9 @@ causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) mamba_ssm = lazy_load_kernel("mamba-ssm") -mamba_ssm_ops = getattr(mamba_ssm, "ops", None) -mamba_ssm_triton = getattr(mamba_ssm_ops, "triton", None) -selective_state_update = getattr( - getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None -) -selective_scan_interface = getattr(mamba_ssm_ops, "selective_scan_interface", None) -mamba_inner_fn = getattr(selective_scan_interface, "mamba_inner_fn", None) -selective_scan_fn = getattr(selective_scan_interface, "selective_scan_fn", None) +selective_state_update = getattr(mamba_ssm, "selective_state_update", None) +mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) +selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) is_fast_path_available = all( diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index f95631b39d16..f7c4f86cc371 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -44,12 +44,9 @@ mamba_ssm = lazy_load_kernel("mamba-ssm") mamba_ssm_triton = getattr(getattr(mamba_ssm, "ops", None), "triton", None) -selective_state_update = getattr( - getattr(mamba_ssm_triton, "selective_state_update", None), "selective_state_update", None -) -ssd_combined = getattr(mamba_ssm_triton, "ssd_combined", None) -mamba_chunk_scan_combined = getattr(ssd_combined, "mamba_chunk_scan_combined", None) -mamba_split_conv1d_scan_combined = getattr(ssd_combined, "mamba_split_conv1d_scan_combined", None) +selective_state_update = getattr(mamba_ssm, "selective_state_update", None) +mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) +mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) is_fast_path_available = all( ( diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index f67d83282696..535ec9a8a075 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -45,10 +45,7 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast -from ...utils.import_utils import ( - is_causal_conv1d_available, - is_flash_linear_attention_available, -) +from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from .configuration_qwen3_next import Qwen3NextConfig From 4867ef7fc6cb939a3e9c60af57066d3ae87f548c Mon Sep 17 00:00:00 2001 From: romit Date: Mon, 15 Dec 2025 09:43:04 +0000 Subject: [PATCH 15/16] Moved imports inside mixer Signed-off-by: romit --- .../models/bamba/modeling_bamba.py | 26 ++++++----- .../models/bamba/modular_bamba.py | 26 ++++++----- .../modeling_granitemoehybrid.py | 26 ++++++----- .../models/jamba/modeling_jamba.py | 31 ++++++------- .../models/jamba/modular_jamba.py | 32 +++++++------- .../models/mamba2/modeling_mamba2.py | 44 ++++++++++--------- 6 files changed, 98 insertions(+), 87 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 1e5631c514d6..afcce31c7c62 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -489,18 +489,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return hidden_states -causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) -causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) - -mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = getattr(mamba_ssm, "selective_state_update", None) -mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) -mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - - # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer class BambaMixer(nn.Module): """ @@ -572,6 +560,20 @@ def __init__(self, config: BambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) + mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) + mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) + + global is_fast_path_available + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 6b149570fc96..de8e7e77693d 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -56,18 +56,6 @@ from .configuration_bamba import BambaConfig -causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) -causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) - -mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = getattr(mamba_ssm, "selective_state_update", None) -mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) -mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - - logger = logging.get_logger(__name__) @@ -274,6 +262,20 @@ def __init__(self, config: BambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) + mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) + mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) + + global is_fast_path_available + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 8e7b3753926a..8feabb5a363d 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -359,18 +359,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return hidden_states -causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) -causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) - -mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = getattr(mamba_ssm, "selective_state_update", None) -mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) -mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - - # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer class GraniteMoeHybridMambaLayer(nn.Module): """ @@ -442,6 +430,20 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) + mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) + mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) + + global is_fast_path_available + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index d07e962ce71f..0797dbb31dd0 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -294,21 +294,6 @@ def forward( return attn_output, attn_weights -causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) -causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) - -mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = getattr(mamba_ssm, "selective_state_update", None) -mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) -selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) - - -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) - - class JambaMambaMixer(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. @@ -362,6 +347,22 @@ def __init__(self, config: JambaConfig, layer_idx): self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_inner_fn, selective_scan_fn + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) + mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) + selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) + + global is_fast_path_available + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index 9c90272067a6..39b933df7c13 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -39,21 +39,6 @@ from .configuration_jamba import JambaConfig -causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) -causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) - -mamba_ssm = lazy_load_kernel("mamba-ssm") -selective_state_update = getattr(mamba_ssm, "selective_state_update", None) -mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) -selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) - - -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) - - logger = logging.get_logger(__name__) @@ -257,6 +242,23 @@ def __init__(self, config: JambaConfig, layer_idx): self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_inner_fn, selective_scan_fn + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) + mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) + selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) + + global is_fast_path_available + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + + if not is_fast_path_available: logger.warning_once( "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index f7c4f86cc371..eedb11ded61c 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -38,27 +38,6 @@ logger = logging.get_logger(__name__) -causal_conv1d = lazy_load_kernel("causal-conv1d") -causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) -causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) - -mamba_ssm = lazy_load_kernel("mamba-ssm") -mamba_ssm_triton = getattr(getattr(mamba_ssm, "ops", None), "triton", None) -selective_state_update = getattr(mamba_ssm, "selective_state_update", None) -mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) -mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) - -is_fast_path_available = all( - ( - selective_state_update, - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - causal_conv1d_fn, - causal_conv1d_update, - ) -) - - # Helper methods for segment sum computation @@ -285,6 +264,29 @@ def __init__(self, config: Mamba2Config, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) + mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) + mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) + + global is_fast_path_available + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) + + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" From 8ab9724e1b2091d046357c8899b76a28b29ac81c Mon Sep 17 00:00:00 2001 From: romit Date: Mon, 15 Dec 2025 13:01:51 +0000 Subject: [PATCH 16/16] CI CD fix Signed-off-by: romit --- src/transformers/models/jamba/modular_jamba.py | 1 - src/transformers/models/mamba2/modeling_mamba2.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index 39b933df7c13..4bfdbc51c7ce 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -258,7 +258,6 @@ def __init__(self, config: JambaConfig, layer_idx): (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) - if not is_fast_path_available: logger.warning_once( "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index eedb11ded61c..caa370b5a309 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -286,7 +286,6 @@ def __init__(self, config: Mamba2Config, layer_idx: int): ) ) - if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"