
🚨 [Cache] Native mamba & hybrid cache #44950

Merged
Cyrilvallez merged 59 commits into main from clean-mamba-cache on Mar 31, 2026
Conversation

@Cyrilvallez (Member) commented Mar 23, 2026

What does this PR do?

As per the title. This PR finally makes mamba layer caches first-class citizens, and adds native support for them.

It supports the following layer combinations:

  • all mamba layers
  • alternating attention layer/mamba layer
  • layers that are BOTH mamba and attention (zamba models)

For this, it adds the following 2 layer classes:

  • MambaLayer
  • MambaAndAttentionLayer (combining both)

By nature, MambaLayer has a static shape (i.e. it does not depend on the sequence length). So it was added to both StaticCache and DynamicCache, to blend smoothly with what we already have.
MambaAndAttentionLayer, on the other hand, only has a static mamba part, while its attention part is a dynamic attention layer. It would however be very easy to add a fully static equivalent if we want one in the future.
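For illustration, here is a heavily simplified sketch of the idea (not the actual implementation: the constructor arguments and buffer shapes below are assumptions, only the `update_ssm_state`/`ssm_states` names mirror the diffs discussed further down):

```python
import torch

class MambaLayer:
    # Sketch of a per-layer mamba cache: both buffers have a static shape that
    # never grows with the sequence length, unlike an attention KV cache.
    def __init__(self, batch_size, conv_dim, conv_kernel, num_heads, head_dim, state_size, dtype, device):
        self.conv_states = torch.zeros(batch_size, conv_dim, conv_kernel, dtype=dtype, device=device)
        self.ssm_states = torch.zeros(batch_size, num_heads, head_dim, state_size, dtype=dtype, device=device)

    def update_ssm_state(self, new_ssm_state):
        # In-place copy keeps the buffer address stable, which is what makes
        # the layer friendly to torch.compile with cudagraphs.
        self.ssm_states.copy_(new_ssm_state)
        return self.ssm_states
```

MambaAndAttentionLayer then simply combines such a layer with a dynamic attention layer (`class MambaAndAttentionLayer(MambaLayer, DynamicLayer)` in the diff below), which is why only the attention half of a hybrid cache grows with the sequence.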

Everything integrates smoothly with the existing cache machinery in the case of hybrid attention/mamba archs, i.e. functions such as get_seq_length, get_mask_sizes (used for mask creation notably) will always look at attention layers.

Compile

Apart from the obvious benefits of having a standardized API that works seamlessly with our Cache construction, the new MambaLayer is fully compatible with compile, including cudagraphs! This means any mamba model or alternating mamba/attention model can now be fully compiled with cudagraphs natively!
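A minimal usage sketch of what this enables (the checkpoint id is just a placeholder for any mamba or hybrid mamba/attention model, and the dtype/compile settings are only an example):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "state-spaces/mamba-2.8b-hf"  # placeholder: any mamba / hybrid mamba-attention checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# "reduce-overhead" enables cudagraphs; the static-shaped mamba cache is what makes this work
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("Mamba caches are now", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```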

BC-breaking

The 🚨 marker here is only because 2 classes (MambaCache and FalconMambaCache) were previously public. They no longer exist, so it's breaking in this way. They should not really have been made directly public imo, and I don't expect any direct usage, so it should be fine!
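If some code did rely on those classes, the migration should roughly look like this (a sketch only: the exact constructor arguments of the generic caches are an assumption, the key point being that passing the model config is what makes the cache allocate mamba layers instead of attention layers):

```python
# Before (no longer works, MambaCache was removed by this PR):
# from transformers import MambaCache
# cache = MambaCache(config=model.config, max_batch_size=1, dtype=model.dtype, device=model.device)

# After: use the generic caches and hand them the model config so that
# mamba layers get a MambaLayer entry instead of an attention one
from transformers import DynamicCache

cache = DynamicCache(config=model.config)  # assumed signature
outputs = model(**inputs, past_key_values=cache, use_cache=True)
```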

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@huggingface huggingface deleted a comment from github-actions Bot Mar 25, 2026
@Cyrilvallez (Member, Author)

run-slow: mamba2 zamba2 granitemoehybrid falcon_h1 lfm2 lfm2_moe qwen3_5 bamba mamba nemotron_h qwen3_next zamba jamba qwen3_5_moe falcon_mamba

@huggingface huggingface deleted a comment from github-actions Bot Mar 25, 2026
@github-actions (Contributor)

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/bamba", "models/falcon_h1", "models/falcon_mamba", "models/granitemoehybrid", "models/jamba", "models/lfm2", "models/lfm2_moe", "models/mamba", "models/mamba2", "models/nemotron_h", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_next", "models/zamba", "models/zamba2"]
quantizations: []

@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit | Description |
| --- | --- | --- |
| RUN | 8c3ad2af | workflow commit (merge commit) |
| PR | 3df0d85a | branch commit (from PR) |
| main | 2f624917 | base commit (on main) |

Model CI Report

4 new failed tests from this PR 😭

  • bamba:
    tests/models/bamba/test_modeling_bamba.py::BambaModelIntegrationTest::test_simple_batched_generate_with_padding (❌ ⟹ ❌)
    tests/models/bamba/test_modeling_bamba.py::BambaModelIntegrationTest::test_simple_generate (❌ ⟹ ❌)

  • mamba2:
    tests/models/mamba2/test_modeling_mamba2.py::Mamba2IntegrationTest::test_batched_equivalence_with_cache (❌ ⟹ ❌)
    tests/models/mamba2/test_modeling_mamba2.py::Mamba2IntegrationTest::test_batched_equivalence_without_cache (❌ ⟹ ❌)

@Cyrilvallez (Member, Author)

For reviewers (@ArthurZucker @vasqu), I checked locally and the 4 failed tests above are failing in exactly the same way on main and on this PR. Once again, I don't know why run-slow is flagging them, but they are all ok!
So this PR does not bring any new failures!

@vasqu (Contributor) left a comment

Functionality-wise, I don't really have much to complain about. My comments are mostly about avoiding messy names and standards:

  • Mamba is super popular but it is a variation of linear attention
  • Not all linear attentions (GDN like qwen) are mamba (e.g. no SSM view)
  • This will get messy if we force all linear attentions to be named after mamba

Imo, we should be careful and focus on establishing a good standard here. Let's assume more linear attention flavors will pop up!

Btw, could we have more mixins with e.g. only conv (lfm) and conv x recurrent state (olmo hybrid)? Probably for the future, just a thought.

Comment thread src/transformers/generation/utils.py Outdated
Comment thread src/transformers/generation/utils.py Outdated
Comment thread src/transformers/models/bamba/modular_bamba.py Outdated
"""
mamba_mask = attention_mask
if (past_key_values is not None and past_key_values.has_previous_state) or (
if (past_key_values is not None and past_key_values.has_previous_state()) or (
Contributor:

Other point, but maybe we should move this to our mask API - essentially all linear attns will need this and then we can interact with layer types like SWA --> this will also allow vLLM to exchange the linear attn layers

Contributor:

Seeing you already have some attribute maps :D yea that would go hand in hand again then

Member Author:

Agreed, but I'd rather do it later, as this PR already refactors quite a lot of modeling files - easier to do in a second step

Comment thread src/transformers/models/falcon_h1/modular_falcon_h1.py
Comment thread src/transformers/cache_utils.py Outdated
Comment thread src/transformers/cache_utils.py
Comment thread src/transformers/cache_utils.py Outdated
return self.ssm_states


class MambaAndAttentionLayer(MambaLayer, DynamicLayer):
Contributor:

Definitely possible to make a static version as well imo, but no rush, let's get this right first 🫡

Member Author:

Yes, the Static version can basically be a copy/paste of the Dynamic one, but inheriting from StaticLayer. Did not add it yet, as I don't think it's really useful for the time being indeed
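Roughly something like this, presumably (hypothetical sketch, not part of this PR):

```python
# hypothetical static counterpart of MambaAndAttentionLayer, not added here
class MambaAndStaticAttentionLayer(MambaLayer, StaticLayer):
    """Same mamba state handling, but the attention part uses a pre-allocated
    static buffer instead of a dynamically growing one."""
```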

Comment thread tests/generation/test_utils.py Outdated
Comment thread tests/test_modeling_common.py Outdated
@ArthurZucker (Collaborator) left a comment

IMO also a good time to abstract the Layer's keys?
Would help for, say, FP8Indexer that can just use set_default / request for cache_keys={"indexer_kv"}

Comment thread src/transformers/models/bamba/modeling_bamba.py
Comment thread src/transformers/models/bamba/modular_bamba.py Outdated
)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx)
Collaborator:

Suggested change
- ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx)
+ ssm_state = cache_params.update("ssm_states", ssm_state, self.layer_idx)

This is what I had in mind TBH, it scales with whatever naming, and however many sub caches you have

Collaborator:

imagine quantizing this:
ssm_state_scales = cache_params.update("ssm_state_scales", ssm_state_scales, self.layer_idx)
instead of creating a new class

Member Author:

Issue is that they don't all update the same way... So for now, I believe this is the easiest way to proceed, rather than trying a dispatch based on kwarg name (because almost all models pass those as positional args, not kwargs, so we don't have access to the name...).

I do want to explore a more general way with only update everywhere in the future though, but it would be way too many changes (unrelated to mamba caches) for this PR!

Comment thread src/transformers/cache_utils.py Outdated
@vasqu (Contributor) left a comment

LGTM, just some last comments on my side for more details but honestly we could also leave it as-is

Comment thread src/transformers/generation/utils.py Outdated
if generation_config.cache_implementation != "dynamic_full":
# linear attention models always need to pass the config, otherwise it will use an Attention cache for the LinearAttention layers
is_linear_attention = any(
x in ("mamba", "conv", "linear_attention")
Contributor:

Suggested change
- x in ("mamba", "conv", "linear_attention")
+ x in ("linear_attention_mamba", "conv", "linear_attention_minimax")

Wdyt about this naming convention? I think we will need some BC workings / breakings but I think it paves a clear path

Member Author:

Yup, it would probably be very nice in the long run to harmonize all the names for sure - once again something I wanted to follow up on haha. We have way too many different names for the same things right now (from the lack of general coverage of those caches until now)

Comment thread src/transformers/generation/utils.py Outdated
Comment thread src/transformers/models/bamba/modeling_bamba.py
Comment on lines -922 to -925
if use_precomputed_states:
    previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
else:
    previous_states = torch.zeros_like(states[:, :1])
Contributor:

I think I opened that on mamba2, but just for clarification on where this comes from: Mamba2 can theoretically have an initial recurrent state (and I developed that native torch version so it carried over 😓) - it just never got established as it did not really improve anything perf-wise on tasks. Although, I could imagine this becoming a power feature and maybe becoming necessary for CP support

Contributor:

So I think some mistake was introduced, because it should only check whether the cache exists and has a previous state - not care about the seq len == 1 case

Member Author:

Yup, cases need harmonization haha - but not related to cache directly!

# 2. Convolution sequence transformation
if cache_params is not None and cache_params.has_previous_state:
cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)
is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx)
Contributor:

I swear it's copy-paste mistakes 😠 but yea no worries, not on you, and I don't mind moving this to another PR

Comment thread src/transformers/__init__.py
Comment thread src/transformers/cache_utils.py Outdated
@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: bamba, falcon_h1, falcon_mamba, granitemoehybrid, jamba, lfm2, lfm2_moe, mamba, mamba2, musicflamingo

@Cyrilvallez changed the title from [Cache] Native mamba & hybrid cache to 🚨 [Cache] Native mamba & hybrid cache on Mar 31, 2026
@Cyrilvallez merged commit 2dba8e0 into main on Mar 31, 2026
30 checks passed
@Cyrilvallez deleted the clean-mamba-cache branch on March 31, 2026 at 13:09
sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Mar 31, 2026
* add Cache and test on Mamba

* fix

* fix

* fix

* fix

* fix

* final fix

* test hybrid with jamba

* fix tests

* fixes

* fix

* fix

* fix

* combine both types + zambas

* add config mapèping

* adjust tests

* fix

* fix

* fix

* more models

* final mambas

* config

* finalize almost everything

* simplify tests

* simplify tests further

* fix tests

* oupsi

* fix

* fix broken no_split_modules

* fix

* fixes

* fix

* fix

* fixes

* add layer type

* oupsi

* fix

* style

* fix

* fixes

* final fix

* forgot those qwens

* tests

* offloading

* much better static shape native design

* oupsi

* adjustments in generate

* allow cudagraphs

* small oupsi

* start renaming

* revert unrelated what are they doing here

* more renaming

* revert offloading change

* add offloading skips

* split shapes for tests

* comments and renaming
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Apr 4, 2026
sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026
Labels: None yet
Projects: None yet
4 participants