-
Notifications
You must be signed in to change notification settings - Fork 33.1k
🚨 [Cache] Native mamba & hybrid cache #44950
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
db8e4ff
9d52598
659beee
29b91ab
3e02650
fb88345
1aeddfa
35db152
a50293c
1607fe2
ddc198a
bae4a78
984b578
cac5d17
bd8f9e9
b2f1bb8
7795808
18685c6
fcec6bc
b1df43f
fdb1579
b156ade
330e397
b60c6f5
0e8ca28
c2ddcf9
b23708f
18feef2
ce92f3d
ab4472b
08e6265
66d0716
c86f9bb
ba1b7d6
1785621
f684133
8ca92a9
0d991d7
670d09a
bc99c9a
63e0b93
eb018e7
f8a0702
fc27c37
9c616dd
6f85f54
f4fc801
6aca24e
13781f1
3df0d85
39dae28
eadcfa4
908f0da
cf87066
86de2bc
f5dfd79
476aaaf
7a69287
3600b89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1775,19 +1775,19 @@ def _prepare_static_cache( | |||||
| def _supports_default_dynamic_cache(cls: type["GenerativePreTrainedModel"]) -> bool: | ||||||
| """ | ||||||
| Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. | ||||||
| This adds exception for some models like `Mamba` models which use their own caches. | ||||||
| """ | ||||||
| # NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name | ||||||
| return not cls._is_stateful and all( | ||||||
| special_model_name not in cls.__name__.lower() | ||||||
| or "minimaxm2" in cls.__name__.lower() # name clash between minimax and minimax m2 | ||||||
| for special_model_name in [ | ||||||
| "reformer", | ||||||
| "minimax", | ||||||
| "xlnet", | ||||||
| "lfm2", | ||||||
| "lfm2_vl", | ||||||
| ] | ||||||
| unsupported_model_names = ( | ||||||
| "reformer", | ||||||
| "minimax", | ||||||
| "xlnet", | ||||||
| "olmohybrid", # olmo_hybrid cannot use linear attention cache for now as it uses split k,q,v conv states | ||||||
| "rwkv", | ||||||
| "xlstm", | ||||||
| ) | ||||||
| # name clash between minimax and minimax m2, so we add this "or" | ||||||
| return "minimaxm2" in cls.__name__.lower() or all( | ||||||
| unsupported_name not in cls.__name__.lower() for unsupported_name in unsupported_model_names | ||||||
| ) | ||||||
|
|
||||||
| def _prepare_cache_for_generation( | ||||||
|
|
@@ -1849,7 +1849,12 @@ def _prepare_cache_for_generation( | |||||
| generation_config.cache_implementation = "dynamic_full" | ||||||
|
|
||||||
| dynamic_cache_kwargs = {} | ||||||
| if generation_config.cache_implementation != "dynamic_full": | ||||||
| # linear attention models always need to pass the config, otherwise it will use an Attention cache for the LinearAttention layers | ||||||
| is_linear_attention = any( | ||||||
| x in ("mamba", "conv", "linear_attention") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Wdyt about this naming convention? I think we will need some BC workings / breakings but I think it paves a clear path
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup, would probably be very nice in the long run to harmonize all the names for sure - once again something I wanted to follow up with haha. We have way too many different names for the same things rn (from the lack of general coverage of those caches rn) |
||||||
| for x in getattr(self.config.get_text_config(decoder=True), "layer_types", []) | ||||||
| ) | ||||||
| if generation_config.cache_implementation != "dynamic_full" or is_linear_attention: | ||||||
| dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True) | ||||||
|
|
||||||
| if generation_config.cache_implementation == "offloaded": | ||||||
|
|
@@ -1862,7 +1867,7 @@ def _prepare_cache_for_generation( | |||||
| f"and will be removed in v5.13. Please only use one of {STATIC_CACHE_IMPLEMENTATIONS}, " | ||||||
| "and the layer structure will be inferred automatically." | ||||||
| ) | ||||||
| model_kwargs["past_key_values"] = self._prepare_static_cache( | ||||||
| model_kwargs[cache_name] = self._prepare_static_cache( | ||||||
| cache_implementation=generation_config.cache_implementation, | ||||||
| batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size, | ||||||
| max_cache_len=max_cache_length, | ||||||
|
|
@@ -1878,19 +1883,19 @@ def _prepare_cache_for_generation( | |||||
| cache_config = generation_config.cache_config if generation_config.cache_config is not None else {} | ||||||
| cache_config.setdefault("config", self.config.get_text_config(decoder=True)) | ||||||
| backend = cache_config.pop("backend", "quanto") | ||||||
| model_kwargs["past_key_values"] = QuantizedCache(backend=backend, **cache_config) | ||||||
| model_kwargs[cache_name] = QuantizedCache(backend=backend, **cache_config) | ||||||
| # i.e. `cache_implementation` in [None, "dynamic", "offloaded", "dynamic_full"] | ||||||
| # TODO: prepare linear cache from a single API, instead of creating in modeling code | ||||||
| else: | ||||||
| model_kwargs["past_key_values"] = DynamicCache(**dynamic_cache_kwargs) | ||||||
| model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs) | ||||||
|
|
||||||
| if ( | ||||||
| self.config.is_encoder_decoder | ||||||
| and "past_key_values" in model_kwargs | ||||||
| and not isinstance(model_kwargs["past_key_values"], EncoderDecoderCache) | ||||||
| and cache_name in model_kwargs | ||||||
| and not isinstance(model_kwargs[cache_name], EncoderDecoderCache) | ||||||
| ): | ||||||
| model_kwargs["past_key_values"] = EncoderDecoderCache( | ||||||
| model_kwargs["past_key_values"], # self-attention cache | ||||||
| model_kwargs[cache_name] = EncoderDecoderCache( | ||||||
| model_kwargs[cache_name], # self-attention cache | ||||||
| DynamicCache(**dynamic_cache_kwargs), # cross-attention cache | ||||||
| ) | ||||||
|
|
||||||
|
|
@@ -1990,13 +1995,15 @@ def _valid_auto_compile_criteria( | |||||
| if generation_config.disable_compile: | ||||||
| return False | ||||||
|
|
||||||
| cache = model_kwargs.get("past_key_values", model_kwargs.get("cache_params")) | ||||||
|
|
||||||
| # Base logic | ||||||
| valid_hardware = self.device.type in ["cuda", "xpu"] or bool( | ||||||
| generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices | ||||||
| ) | ||||||
| using_compilable_cache = ( | ||||||
| isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable | ||||||
| ) | ||||||
| # Note: for some models that only use linear attention (e.g. Mamba), even a DynamicCache is compileable since all | ||||||
| # layers are, but we don't want to ALWAYS compile when calling `generate`, so we check the type | ||||||
| using_compilable_cache = cache is not None and cache.is_compileable and type(cache) is not DynamicCache | ||||||
| can_compile = valid_hardware and using_compilable_cache | ||||||
|
|
||||||
| # Exception 1: Some quantization methods do not support compilation | ||||||
|
|
@@ -3467,10 +3474,9 @@ def _assisted_decoding( | |||||
| # The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache | ||||||
| if not model_kwargs["use_cache"]: | ||||||
| raise ValueError("assisted generate requires `use_cache=True`") | ||||||
| if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or ( | ||||||
| "past_key_values" in model_kwargs | ||||||
| and hasattr(model_kwargs["past_key_values"], "layers") | ||||||
| and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers) | ||||||
| if ( | ||||||
| generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] | ||||||
| or type(model_kwargs.get("past_key_values")) is StaticCache | ||||||
| ): | ||||||
| raise ValueError("assisted generate is not supported with Static cache classes`") | ||||||
| # Get the candidate generator, given the parameterization | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.