Merged

28 commits
4f3c38d
Added kernels from kernel hub for Bamba model
romitjain Oct 13, 2025
d68a2e7
Merge branch 'main' of github.com:romitjain/transformers into romit/f…
romitjain Nov 10, 2025
a892065
Updated kernel loading
romitjain Nov 10, 2025
5ed0fc1
Remove einops
romitjain Nov 12, 2025
0a4f79b
Removed global vars
romitjain Nov 12, 2025
db594e9
Fixed make style
romitjain Nov 12, 2025
2ef69e6
Nit
romitjain Nov 12, 2025
8ae9ce5
Added modeling files
romitjain Nov 13, 2025
dc04c10
Merge branch 'main' of github.com:huggingface/transformers into romit…
romitjain Nov 17, 2025
fa6ea6e
Fixed merge conflict
romitjain Nov 17, 2025
caa6d4a
Merge branch 'main' of github.com:huggingface/transformers into romit…
romitjain Dec 8, 2025
74c986a
fixed lint
romitjain Dec 8, 2025
4194c31
Merge branch 'main' into romit/feature-bamba-kernels-from-hub
MekkCyber Dec 8, 2025
e3ea6b8
Merge branch 'huggingface:main' into romit/feature-bamba-kernels-from…
romitjain Dec 8, 2025
d1d7820
Removed global import
romitjain Dec 8, 2025
1e5d856
Merge branch 'main' into romit/feature-bamba-kernels-from-hub
romitjain Dec 9, 2025
a3800e8
Small updates
romitjain Dec 9, 2025
8e48976
Updated
romitjain Dec 9, 2025
df3be53
Merge branch 'main' of github.com:huggingface/transformers into romit…
romitjain Dec 9, 2025
779486c
Resolved merge conflicts
romitjain Dec 9, 2025
bca8f06
Merge branch 'main' of github.com:huggingface/transformers into romit…
romitjain Dec 15, 2025
03590f1
Fixed the nested import
romitjain Dec 15, 2025
4867ef7
Moved imports inside mixer
romitjain Dec 15, 2025
8ab9724
CI CD fix
romitjain Dec 15, 2025
db7d824
Merge branch 'main' into romit/feature-bamba-kernels-from-hub
MekkCyber Dec 15, 2025
e555ef3
Merge branch 'main' into romit/feature-bamba-kernels-from-hub
MekkCyber Dec 15, 2025
0569a61
Merge branch 'main' into romit/feature-bamba-kernels-from-hub
romitjain Dec 16, 2025
65e283a
Merge branch 'main' into romit/feature-bamba-kernels-from-hub
MekkCyber Dec 16, 2025
2 changes: 1 addition & 1 deletion src/transformers/integrations/hub_kernels.py
@@ -370,7 +370,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
if callable(is_kernel_available) and is_kernel_available():
# Try to import the module "{kernel_name}" from parent package level
try:
module = importlib.import_module(f"{kernel_name}")
module = importlib.import_module(f"{new_kernel_name}")
mapping[kernel_name] = module
return module
except Exception:
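Aside: the one-line fix above makes sense if `lazy_load_kernel` derives an importable module name from the hub kernel name — dashes (as in "causal-conv1d") are not valid in Python module paths, so something like `kernel_name.replace("-", "_")` presumably exists earlier in the function. Below is a minimal sketch of the caching pattern under that assumption; the availability probe and hub-download fallback of the real function are omitted, and `_KERNEL_MAPPING` / `lazy_load_kernel_sketch` are illustrative names, not the actual implementation.

import importlib
from types import ModuleType

_KERNEL_MAPPING: dict[str, ModuleType | None] = {}

def lazy_load_kernel_sketch(kernel_name: str, mapping: dict[str, ModuleType | None] = _KERNEL_MAPPING) -> ModuleType | None:
    # Return the cached module (or cached None) if this kernel was probed before.
    if kernel_name in mapping:
        return mapping[kernel_name]
    # Hub kernel names use dashes; Python module names use underscores.
    new_kernel_name = kernel_name.replace("-", "_")
    try:
        module = importlib.import_module(new_kernel_name)
    except Exception:
        module = None  # cache the failure so later callers don't re-probe
    mapping[kernel_name] = module
    return module
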
31 changes: 15 additions & 16 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -36,6 +36,7 @@
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -44,22 +45,9 @@
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import maybe_autocast
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig


if is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
selective_state_update = None

if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None


logger = logging.get_logger(__name__)


@@ -501,9 +489,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
return hidden_states


is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))


# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
class BambaMixer(nn.Module):
"""
@@ -575,6 +560,20 @@ def __init__(self, config: BambaConfig, layer_idx: int):

self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)

global is_fast_path_available
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
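A note on the pattern these constructors introduce: `lazy_load_kernel` returns `None` when a kernel cannot be loaded, and `getattr` on `None` with a default does not raise, so every symbol quietly degrades to `None` and the `all(...)` gate flips the model onto the slow path. A self-contained illustration (not model code):

module = None  # what lazy_load_kernel returns when "causal-conv1d" is missing

# getattr on None simply yields the supplied default instead of raising.
causal_conv1d_fn = getattr(module, "causal_conv1d_fn", None)          # -> None
causal_conv1d_update = getattr(module, "causal_conv1d_update", None)  # -> None

# The fused fast path is enabled only if every required symbol resolved.
is_fast_path_available = all((causal_conv1d_fn, causal_conv1d_update))
print(is_fast_path_available)  # False -> fall back to pure PyTorch
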
30 changes: 15 additions & 15 deletions src/transformers/models/bamba/modular_bamba.py
@@ -43,6 +43,7 @@
)

from ... import initialization as init
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
@@ -52,24 +53,9 @@
can_return_tuple,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig


if is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
selective_state_update = None

if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))


logger = logging.get_logger(__name__)


@@ -276,6 +262,20 @@ def __init__(self, config: BambaConfig, layer_idx: int):

self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)

global is_fast_path_available
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -32,6 +32,7 @@
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...integrations.hub_kernels import lazy_load_kernel
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -40,22 +41,9 @@
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs, maybe_autocast
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_granitemoehybrid import GraniteMoeHybridConfig


if is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
selective_state_update = None

if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None


logger = logging.get_logger(__name__)


@@ -371,9 +359,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
return hidden_states


is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))


# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
class GraniteMoeHybridMambaLayer(nn.Module):
"""
@@ -445,6 +430,20 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):

self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)

global is_fast_path_available
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
35 changes: 17 additions & 18 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -33,29 +33,17 @@
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...integrations.hub_kernels import lazy_load_kernel
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import OutputRecorder, check_model_inputs
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
from .configuration_jamba import JambaConfig


if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None


logger = logging.get_logger(__name__)


@@ -306,11 +294,6 @@ def forward(
return attn_output, attn_weights


is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)


class JambaMambaMixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
@@ -364,6 +347,22 @@ def __init__(self, config: JambaConfig, layer_idx):
self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_inner_fn, selective_scan_fn
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)

global is_fast_path_available
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
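For context on where the flag is consumed: the mixer's forward later branches on `is_fast_path_available` to choose between the fused kernels and the eager PyTorch implementation. A hedged sketch of that dispatch — the method names `cuda_kernels_forward` and `slow_forward` are assumptions for illustration, not verified against this PR:

import torch

def forward_sketch(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Use the fused Triton/CUDA kernels only when every symbol loaded
    # and the tensors live on GPU; otherwise take the sequential path.
    if is_fast_path_available and hidden_states.is_cuda:
        return self.cuda_kernels_forward(hidden_states)
    return self.slow_forward(hidden_states)
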
34 changes: 17 additions & 17 deletions src/transformers/models/jamba/modular_jamba.py
@@ -25,36 +25,20 @@

from ... import initialization as init
from ...activations import ACT2FN
from ...integrations.hub_kernels import lazy_load_kernel
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, logging
from ...utils.generic import OutputRecorder, check_model_inputs
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, eager_attention_forward
from ..mistral.modeling_mistral import MistralMLP
from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM
from .configuration_jamba import JambaConfig


if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)


logger = logging.get_logger(__name__)


@@ -258,6 +242,22 @@ def __init__(self, config: JambaConfig, layer_idx):
self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_inner_fn, selective_scan_fn
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)

global is_fast_path_available
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
46 changes: 23 additions & 23 deletions src/transformers/models/mamba2/modeling_mamba2.py
@@ -24,42 +24,20 @@
from ... import initialization as init
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
auto_docstring,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_mamba2 import Mamba2Config


logger = logging.get_logger(__name__)


if is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None

if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all(
(
selective_state_update,
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
causal_conv1d_fn,
causal_conv1d_update,
)
)


# Helper methods for segment sum computation


@@ -286,6 +264,28 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)

global is_fast_path_available
is_fast_path_available = all(
(
selective_state_update,
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
causal_conv1d_fn,
causal_conv1d_update,
)
)

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
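One subtlety shared by all of these constructors: `global` inside `__init__` rebinds the module-level names, so free references in other functions of the same module (the forward helpers) observe the loaded kernels at call time. A self-contained illustration of that mechanism:

helper_fn = None  # module-level placeholder, like selective_state_update

def uses_helper():
    # Free variable: looked up in the module globals at call time,
    # so it sees rebindings made after this function was defined.
    return helper_fn

class Loader:
    def __init__(self):
        global helper_fn        # rebind the module-level name...
        helper_fn = lambda: 42  # ...so existing functions pick it up

Loader()
print(uses_helper()())  # 42
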
5 changes: 1 addition & 4 deletions src/transformers/models/qwen3_next/modeling_qwen3_next.py
@@ -45,10 +45,7 @@
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
from ...utils.import_utils import (
is_causal_conv1d_available,
is_flash_linear_attention_available,
)
from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
from .configuration_qwen3_next import Qwen3NextConfig


Expand Down