🚨 [Cache] Native mamba & hybrid cache#44950
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
run-slow: mamba2 zamba2 granitemoehybrid falcon_h1 lfm2 lfm2_moe qwen3_5 bamba mamba nemotron_h qwen3_next zamba jamba qwen3_5_moe falcon_mamba |
This comment contains models: ["models/bamba", "models/falcon_h1", "models/falcon_mamba", "models/granitemoehybrid", "models/jamba", "models/lfm2", "models/lfm2_moe", "models/mamba", "models/mamba2", "models/nemotron_h", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_next", "models/zamba", "models/zamba2"] |
Model CI Report: ❌ 4 new failed tests from this PR 😭
For reviewers (@ArthurZucker @vasqu), I checked locally and the 4 failed tests above are failing in exactly the same way on main as on this PR. Once again, I don't know why
vasqu left a comment
Functionality-wise, I don't really have much to complain about. My comments are mostly about avoiding messy names and about standards:
- Mamba is super popular but it is a variation of linear attention
- Not all linear attentions (GDN like qwen) are mamba (e.g. no SSM view)
- This will get messy if we force all linear attentions to be named after mamba
Imo, we should be careful and focus on establishing a good standard here. Let's assume more linear attention flavors will pop up!
Btw, could we have more mixins, e.g. only conv (lfm) and conv x recurrent state (olmo hybrid)? Probably for the future, just a thought
| """ | ||
| mamba_mask = attention_mask | ||
| if (past_key_values is not None and past_key_values.has_previous_state) or ( | ||
| if (past_key_values is not None and past_key_values.has_previous_state()) or ( |
Other point, but maybe we should move this to our mask API - essentially all linear attns will need this and then we can interact with layer types like SWA --> this will also allow vLLM to exchange the linear attn layers
Seeing you already have some attribute maps :D yea that would go hand in hand again then
Agreed, but I'd rather do it later, as this PR already refactors quite a lot of modeling files - easier to do in a second step
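To make the idea discussed above concrete, here is a toy sketch of the kind of shared helper that could live in the mask API. The helper name and signature are made up for illustration and are not part of this PR; only `has_previous_state()` mirrors the diff above.

```python
# Hypothetical sketch only: a shared helper for linear-attention (mamba/conv)
# layers, mirroring the logic in the diff above. Name and signature are made up.
def linear_attention_mask(attention_mask, past_key_values):
    # Once a previous recurrent state exists (decoding), the padding info is
    # already folded into the cached states, so no mask is needed.
    if past_key_values is not None and past_key_values.has_previous_state():
        return None
    # During prefill, keep the 2D padding mask so padded tokens don't pollute
    # the conv / ssm states.
    return attention_mask
```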
        return self.ssm_states


class MambaAndAttentionLayer(MambaLayer, DynamicLayer):
Definitely possible to make a static version as well imo, but no rush, let's get this right first 🫡
Yes, the Static version can basically be a copy/paste of the Dynamic one, but inheriting from StaticLayer. Did not add it yet, as I don't think it's really useful for the time being indeed
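For reference, a minimal sketch of what that static variant could look like. This is not part of the PR: it assumes MambaLayer (added here) and the existing StaticLayer are both importable from transformers.cache_utils, and the class name is hypothetical.

```python
# Hypothetical sketch, not part of this PR. Assumes MambaLayer (added by this
# PR) and StaticLayer (pre-allocated attention cache) both live in cache_utils.
from transformers.cache_utils import MambaLayer, StaticLayer  # assumption

class MambaAndAttentionStaticLayer(MambaLayer, StaticLayer):
    """Fully static hybrid layer: the mamba conv/ssm states are already
    static-shaped, and the attention part uses the pre-allocated StaticLayer
    storage instead of the growing DynamicLayer one."""
```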
ArthurZucker left a comment
IMO also a good time to abstract Layer's Keys?
Would help for say FP8Indexer that can just use set_default / request for cache_keys={"indexer_kv"}
    )
    if ssm_state is not None and cache_params is not None:
-       cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+       ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx)
- ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx)
+ ssm_state = cache_params.update("ssm_states", ssm_state, self.layer_idx)
this is what I had in mind TBH, it scales with whatever naming and however many sub caches you have
imagine quantizing this:
ssm_state_scales = cache_params.update("ssm_state_scales", ssm_state_scales, self.layer_idx)
instead of creating a new class
Issue is that they don't all update the same way... So for now, I believe this is the easiest way to proceed, rather than trying a dispatch based on kwarg name (because almost all models pass those as positional args, not kwargs, so we don't have access to the name...).
I do want to explore a more general way with only update everywhere in the future though, but it would be way too many changes (unrelated to mamba caches) for this PR!
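To illustrate the dispatch problem mentioned above, here is a toy sketch of a key-based update. The class name, shapes and update rules are made up and are not the PR's implementation; the point is only that conv states and ssm states need different update semantics behind a single update() entry point.

```python
import torch

# Toy illustration of why a single key-based update() needs per-key rules:
# conv states roll a fixed-size window, ssm states are overwritten in place.
class KeyedMambaLayer:
    def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor):
        # Pre-allocated, static-shaped buffers keyed by name.
        self.states = {"conv_states": conv_state, "ssm_states": ssm_state}

    def update(self, key: str, new_state: torch.Tensor) -> torch.Tensor:
        buf = self.states[key]
        if key == "conv_states":
            # Conv cache: shift the window left and write the newest column;
            # new_state is expected to be (batch, channels).
            buf.copy_(torch.roll(buf, shifts=-1, dims=-1))
            buf[..., -1] = new_state
        else:
            # SSM cache (or e.g. quantization scales): plain overwrite.
            buf.copy_(new_state)
        return buf
```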
vasqu left a comment
LGTM, just some last comments on my side for more details but honestly we could also leave it as-is
if generation_config.cache_implementation != "dynamic_full":
    # linear attention models always need to pass the config, otherwise it will use an Attention cache for the LinearAttention layers
    is_linear_attention = any(
        x in ("mamba", "conv", "linear_attention")
- x in ("mamba", "conv", "linear_attention")
+ x in ("linear_attention_mamba", "conv", "linear_attention_minimax")
Wdyt about this naming convention? I think we will need some BC workarounds / breaking changes, but I think it paves a clear path
Yup, it would probably be very nice in the long run to harmonize all the names for sure - once again something I wanted to follow up on haha. We have way too many different names for the same things rn (from the lack of general coverage of those caches)
if use_precomputed_states:
    previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
else:
    previous_states = torch.zeros_like(states[:, :1])
I think I opened that on mamba2, but just for clarification where this comes from: Mamba2 can theoretically have an initial recurrent state (and I developed that native torch version so it carried over 😓) - it just never got established as it did not really improve anything perf-wise on tasks. Although, I could imagine this becoming a power feature and maybe necessary for CP support
So I think there was some mistake introduced, because it should only check whether the cache exists and has a prev state - not care for the seq len == 1 case
Yup, cases need harmonization haha - but not related to cache directly!
  # 2. Convolution sequence transformation
- if cache_params is not None and cache_params.has_previous_state:
-     cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)
+ is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx)
I swear it's copy-paste mistakes 😠 but yea no worries, not on you, and I don't mind moving this to another PR
[For maintainers] Suggested jobs to run (before merge) run-slow: bamba, falcon_h1, falcon_mamba, granitemoehybrid, jamba, lfm2, lfm2_moe, mamba, mamba2, musicflamingo |
* add Cache and test on Mamba * fix * fix * fix * fix * fix * final fix * test hybrid with jamba * fix tests * fixes * fix * fix * fix * combine both types + zambas * add config mapèping * adjust tests * fix * fix * fix * more models * final mambas * config * finalize almost everything * simplify tests * simplify tests further * fix tests * oupsi * fix * fix broken no_split_modules * fix * fixes * fix * fix * fixes * add layer type * oupsi * fix * style * fix * fixes * final fix * forgot those qwens * tests * offloading * much better static shape native design * oupsi * adjustments in generate * allow cudagraphs * small oupsi * start renaming * revert unrelated what are they doing here * more renaming * revert offloading change * add offloading skips * split shapes for tests * comments and renaming
What does this PR do?
As per the title. This PR finally makes mamba layer caches first-class citizens, and adds native support for them.
It supports the following layer combinations: pure mamba layers, as well as hybrid mamba + attention layers.
For this, it adds the 2 following layer classes:
- MambaLayer: by essence, it has static shapes (i.e. they do not depend on the sequence length), so it was added to both StaticCache and DynamicCache, to blend smoothly with what we already have.
- MambaAndAttentionLayer: only the mamba part is static, while the attention part is a dynamic attention layer. It would however be very easy to add the fully static equivalent if we want to in the future.
Everything integrates smoothly with the existing cache machinery in the case of hybrid attention/mamba archs, i.e. functions such as get_seq_length and get_mask_sizes (used for mask creation notably) will always look at attention layers.
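As a usage sketch (the model id below is a placeholder, not a real checkpoint), the point is that nothing cache-specific should be needed on the user side: generate() is expected to build the appropriate mamba/hybrid cache from the model config.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/some-hybrid-mamba-model"  # placeholder, not a real checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

inputs = tokenizer("The native mamba cache", return_tensors="pt").to(model.device)
# No manual cache construction: the mamba / hybrid cache layers are created internally.
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```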
Compile
Apart from the obvious benefits of having a standardized API that works seamlessly with our Cache construction, the new MambaLayer is fully compatible with compile, including cudagraphs! This means that any mamba model, or alternating mamba/attention model, can now be fully compiled with cudagraphs natively!
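A hedged sketch of what compiled decoding could look like (placeholder model id; requesting cache_implementation="static" here is an assumption on my part, the PR itself only states that the mamba states are static-shaped and cudagraph-compatible):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/some-mamba-model"  # placeholder, not a real checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# "reduce-overhead" enables cuda graphs for the compiled decode steps.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("Hello", return_tensors="pt").to("cuda")
# cache_implementation="static" is assumed here for a fully static cache.
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```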
BC-breaking
The 🚨 marker here is only because 2 classes (MambaCache and FalconMambaCache) were previously public. They no longer exist, so it's breaking in this way. They should not really have been made directly public imo, and I don't expect much direct usage, so it should be fine!