Merged
59 commits
db8e4ff
add Cache and test on Mamba
Cyrilvallez Mar 23, 2026
9d52598
fix
Cyrilvallez Mar 23, 2026
659beee
fix
Cyrilvallez Mar 23, 2026
29b91ab
fix
Cyrilvallez Mar 23, 2026
3e02650
fix
Cyrilvallez Mar 23, 2026
fb88345
fix
Cyrilvallez Mar 23, 2026
1aeddfa
final fix
Cyrilvallez Mar 23, 2026
35db152
test hybrid with jamba
Cyrilvallez Mar 23, 2026
a50293c
fix tests
Cyrilvallez Mar 23, 2026
1607fe2
fixes
Cyrilvallez Mar 23, 2026
ddc198a
fix
Cyrilvallez Mar 23, 2026
bae4a78
fix
Cyrilvallez Mar 24, 2026
984b578
fix
Cyrilvallez Mar 24, 2026
cac5d17
combine both types + zambas
Cyrilvallez Mar 24, 2026
bd8f9e9
add config mapping
Cyrilvallez Mar 24, 2026
b2f1bb8
adjust tests
Cyrilvallez Mar 24, 2026
7795808
fix
Cyrilvallez Mar 24, 2026
18685c6
fix
Cyrilvallez Mar 24, 2026
fcec6bc
fix
Cyrilvallez Mar 24, 2026
b1df43f
more models
Cyrilvallez Mar 24, 2026
fdb1579
final mambas
Cyrilvallez Mar 24, 2026
b156ade
config
Cyrilvallez Mar 24, 2026
330e397
finalize almost everything
Cyrilvallez Mar 24, 2026
b60c6f5
simplify tests
Cyrilvallez Mar 24, 2026
0e8ca28
simplify tests further
Cyrilvallez Mar 24, 2026
c2ddcf9
fix tests
Cyrilvallez Mar 24, 2026
b23708f
oupsi
Cyrilvallez Mar 24, 2026
18feef2
fix
Cyrilvallez Mar 24, 2026
ce92f3d
fix broken no_split_modules
Cyrilvallez Mar 24, 2026
ab4472b
fix
Cyrilvallez Mar 24, 2026
08e6265
fixes
Cyrilvallez Mar 24, 2026
66d0716
fix
Cyrilvallez Mar 24, 2026
c86f9bb
fix
Cyrilvallez Mar 24, 2026
ba1b7d6
fixes
Cyrilvallez Mar 24, 2026
1785621
add layer type
Cyrilvallez Mar 24, 2026
f684133
oupsi
Cyrilvallez Mar 24, 2026
8ca92a9
fix
Cyrilvallez Mar 24, 2026
0d991d7
style
Cyrilvallez Mar 24, 2026
670d09a
fix
Cyrilvallez Mar 24, 2026
bc99c9a
Merge branch 'main' into clean-mamba-cache
Cyrilvallez Mar 24, 2026
63e0b93
fixes
Cyrilvallez Mar 24, 2026
eb018e7
final fix
Cyrilvallez Mar 24, 2026
f8a0702
forgot those qwens
Cyrilvallez Mar 24, 2026
fc27c37
tests
Cyrilvallez Mar 24, 2026
9c616dd
offloading
Cyrilvallez Mar 25, 2026
6f85f54
much better static shape native design
Cyrilvallez Mar 25, 2026
f4fc801
oupsi
Cyrilvallez Mar 25, 2026
6aca24e
adjustments in generate
Cyrilvallez Mar 25, 2026
13781f1
allow cudagraphs
Cyrilvallez Mar 25, 2026
3df0d85
small oupsi
Cyrilvallez Mar 25, 2026
39dae28
Merge branch 'main' into clean-mamba-cache
Cyrilvallez Mar 26, 2026
eadcfa4
start renaming
Cyrilvallez Mar 31, 2026
908f0da
revert unrelated what are they doing here
Cyrilvallez Mar 31, 2026
cf87066
Merge branch 'main' into clean-mamba-cache
Cyrilvallez Mar 31, 2026
86de2bc
more renaming
Cyrilvallez Mar 31, 2026
f5dfd79
revert offloading change
Cyrilvallez Mar 31, 2026
476aaaf
add offloading skips
Cyrilvallez Mar 31, 2026
7a69287
split shapes for tests
Cyrilvallez Mar 31, 2026
3600b89
comments and renaming
Cyrilvallez Mar 31, 2026
7 changes: 0 additions & 7 deletions docs/source/en/model_doc/falcon_mamba.md
@@ -111,13 +111,6 @@ outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## FalconMambaCache

[[autodoc]] FalconMambaCache
- update_conv_state
- update_ssm_state
- reset

## FalconMambaConfig

[[autodoc]] FalconMambaConfig
7 changes: 0 additions & 7 deletions docs/source/en/model_doc/mamba.md
@@ -110,13 +110,6 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
trainer.train()
```

## MambaCache

[[autodoc]] MambaCache
- update_conv_state
- update_ssm_state
- reset

## MambaConfig

[[autodoc]] MambaConfig
1 change: 0 additions & 1 deletion src/transformers/__init__.py
@@ -634,7 +634,6 @@
from .modeling_utils import AttentionInterface as AttentionInterface
from .modeling_utils import PreTrainedModel as PreTrainedModel
from .models import *
from .models.mamba.modeling_mamba import MambaCache as MambaCache
from .models.timm_wrapper import TimmWrapperImageProcessor as TimmWrapperImageProcessor

# Optimization
295 changes: 286 additions & 9 deletions src/transformers/cache_utils.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/transformers/configuration_utils.py
@@ -68,6 +68,8 @@
"attention",
"sparse",
"dense",
"hybrid", # for layers that have both mamba and attention in zamba and zamba2
"moe", # for nemotron_h, which uses either attention, mamba or moe
)


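As a side note on the hunk above, here is a minimal, hypothetical sketch of the kind of validation this allow-list feeds into. The `allowed_layer_types` subset and the example `layer_types` pattern are illustrative only, not taken from this PR.

```python
# Hypothetical sketch: validating a hybrid model's per-layer types against the
# allow-list extended above ("hybrid" for Zamba/Zamba2, "moe" for NemotronH).
allowed_layer_types = {"attention", "sparse", "dense", "hybrid", "moe", "mamba"}  # illustrative subset

# Example pattern for a hybrid decoder: mostly mamba layers with periodic attention/moe layers.
layer_types = ["mamba", "mamba", "attention", "mamba", "mamba", "moe"]

# Reject any entry that is not part of the allow-list.
unknown = [layer for layer in layer_types if layer not in allowed_layer_types]
if unknown:
    raise ValueError(f"Unknown layer types: {unknown}")
```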
58 changes: 32 additions & 26 deletions src/transformers/generation/utils.py
@@ -1775,19 +1775,19 @@ def _prepare_static_cache(
def _supports_default_dynamic_cache(cls: type["GenerativePreTrainedModel"]) -> bool:
"""
Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
This adds exception for some models like `Mamba` models which use their own caches.
"""
# NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name
return not cls._is_stateful and all(
special_model_name not in cls.__name__.lower()
or "minimaxm2" in cls.__name__.lower() # name clash between minimax and minimax m2
for special_model_name in [
"reformer",
"minimax",
"xlnet",
"lfm2",
"lfm2_vl",
]
unsupported_model_names = (
"reformer",
"minimax",
"xlnet",
"olmohybrid", # olmo_hybrid cannot use linear attention cache for now as it uses split k,q,v conv states
"rwkv",
"xlstm",
)
# name clash between minimax and minimax m2, so we add this "or"
return "minimaxm2" in cls.__name__.lower() or all(
unsupported_name not in cls.__name__.lower() for unsupported_name in unsupported_model_names
)

def _prepare_cache_for_generation(
@@ -1849,7 +1849,12 @@ def _prepare_cache_for_generation(
generation_config.cache_implementation = "dynamic_full"

dynamic_cache_kwargs = {}
if generation_config.cache_implementation != "dynamic_full":
# linear attention models always need to pass the config, otherwise it will use an Attention cache for the LinearAttention layers
is_linear_attention = any(
x in ("mamba", "conv", "linear_attention")
Contributor
Suggested change
x in ("mamba", "conv", "linear_attention")
x in ("linear_attention_mamba", "conv", "linear_attention_minimax")

Wdyt about this naming convention? I think we will need some BC work / breaking changes, but I think it paves a clear path

Member Author
Yup, would probably be very nice in the long run to harmonize all the names for sure - once again something I wanted to follow up with haha. We have way too many different names for the same things rn (from the lack of general coverage of those caches rn)

for x in getattr(self.config.get_text_config(decoder=True), "layer_types", [])
)
if generation_config.cache_implementation != "dynamic_full" or is_linear_attention:
dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True)

if generation_config.cache_implementation == "offloaded":
@@ -1862,7 +1867,7 @@
f"and will be removed in v5.13. Please only use one of {STATIC_CACHE_IMPLEMENTATIONS}, "
"and the layer structure will be inferred automatically."
)
model_kwargs["past_key_values"] = self._prepare_static_cache(
model_kwargs[cache_name] = self._prepare_static_cache(
cache_implementation=generation_config.cache_implementation,
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
max_cache_len=max_cache_length,
@@ -1878,19 +1883,19 @@
cache_config = generation_config.cache_config if generation_config.cache_config is not None else {}
cache_config.setdefault("config", self.config.get_text_config(decoder=True))
backend = cache_config.pop("backend", "quanto")
model_kwargs["past_key_values"] = QuantizedCache(backend=backend, **cache_config)
model_kwargs[cache_name] = QuantizedCache(backend=backend, **cache_config)
# i.e. `cache_implementation` in [None, "dynamic", "offloaded", "dynamic_full"]
# TODO: prepare linear cache from a single API, instead of creating in modeling code
else:
model_kwargs["past_key_values"] = DynamicCache(**dynamic_cache_kwargs)
model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)

if (
self.config.is_encoder_decoder
and "past_key_values" in model_kwargs
and not isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
and cache_name in model_kwargs
and not isinstance(model_kwargs[cache_name], EncoderDecoderCache)
):
model_kwargs["past_key_values"] = EncoderDecoderCache(
model_kwargs["past_key_values"], # self-attention cache
model_kwargs[cache_name] = EncoderDecoderCache(
model_kwargs[cache_name], # self-attention cache
DynamicCache(**dynamic_cache_kwargs), # cross-attention cache
)

@@ -1990,13 +1995,15 @@ def _valid_auto_compile_criteria(
if generation_config.disable_compile:
return False

cache = model_kwargs.get("past_key_values", model_kwargs.get("cache_params"))

# Base logic
valid_hardware = self.device.type in ["cuda", "xpu"] or bool(
generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
)
using_compilable_cache = (
isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
)
# Note: for some models that only use linear attention (e.g. Mamba), even a DynamicCache is compileable since all
# layers are, but we don't want to ALWAYS compile when calling `generate`, so we check the type
using_compilable_cache = cache is not None and cache.is_compileable and type(cache) is not DynamicCache
can_compile = valid_hardware and using_compilable_cache

# Exception 1: Some quantization methods do not support compilation
@@ -3467,10 +3474,9 @@ def _assisted_decoding(
# The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache
if not model_kwargs["use_cache"]:
raise ValueError("assisted generate requires `use_cache=True`")
if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or (
"past_key_values" in model_kwargs
and hasattr(model_kwargs["past_key_values"], "layers")
and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers)
if (
generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]
or type(model_kwargs.get("past_key_values")) is StaticCache
):
raise ValueError("assisted generate is not supported with Static cache classes`")
# Get the candidate generator, given the parameterization
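To make the intent of the `dynamic_cache_kwargs` change above concrete, here is a hedged sketch of the same dispatch outside of `generate()`. It assumes an already-loaded `model` and that `DynamicCache` accepts a `config` keyword argument, as the diff implies; the variable names are illustrative.

```python
from transformers import DynamicCache

# Assumes `model` is a loaded hybrid / linear-attention model (illustrative).
config = model.config.get_text_config(decoder=True)
layer_types = getattr(config, "layer_types", [])

# Mirrors the check in the diff: models with mamba/conv/linear_attention layers must
# receive the config so DynamicCache allocates linear-attention states for those layers
# instead of plain attention key/value states.
is_linear_attention = any(t in ("mamba", "conv", "linear_attention") for t in layer_types)

cache_kwargs = {"config": config} if is_linear_attention else {}
past_key_values = DynamicCache(**cache_kwargs)
```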
1 change: 1 addition & 0 deletions src/transformers/models/bamba/configuration_bamba.py
@@ -43,6 +43,7 @@ class BambaConfig(PreTrainedConfig):
"""

model_type = "bamba"
attribute_map = {"layer_types": "layers_block_type"}
keys_to_ignore_at_inference = ["past_key_values"]

vocab_size: int = 128000
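A brief, hedged illustration of what the added `attribute_map` line buys: `PreTrainedConfig` resolves mapped attribute names, so generic cache/generation code can read the standard `layer_types` attribute while Bamba keeps storing it as `layers_block_type`. The explicit constructor arguments below are illustrative; a real Bamba config computes and validates its own layer pattern.

```python
from transformers import BambaConfig

# Illustrative layer pattern; real checkpoints define many more layers.
config = BambaConfig(
    num_hidden_layers=4,
    layers_block_type=["mamba", "mamba", "attention", "mamba"],
)

# The attribute_map {"layer_types": "layers_block_type"} makes both names resolve
# to the same underlying value, so model-agnostic code can stay generic.
assert config.layer_types == config.layers_block_type
```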