From 1f7395dd3c8f99166f38d36b9d0161f396ac264f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Jul 2025 19:00:38 +0200 Subject: [PATCH 1/8] fix --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c8471b60e449..54454c86ff63 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1294,6 +1294,7 @@ class DynamicCache(Cache): # Specialized constructor for DDP cache data, needed for BC def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + super().__init__(*args, **kwargs) # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the # iterable contains the key and value states for a layer gathered across replicas by torch.distributed @@ -1303,7 +1304,6 @@ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.T if ddp_cache_data is not None: for key_states, value_states in ddp_cache_data: self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) - super().__init__(*args, **kwargs) def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: """ From 9c6eb7c580a65acc5e715413bb5a05d617f6c7db Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Jul 2025 19:10:47 +0200 Subject: [PATCH 2/8] use kwargs --- src/transformers/cache_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 54454c86ff63..9546d48b3a72 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -558,10 +558,10 @@ def __init__(self, cache: "Cache", offload_device: Union[str, torch.device] = "c self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) if self.is_static: for i, layer in enumerate(cache.layers): - device = cache.layer_init_args["device"] if i == 0 else self.offload_device + device = cache.layer_init_kwargs["device"] if i == 0 else self.offload_device layer.keys = layer.keys.to(device) layer.values = layer.values.to(device) - self.original_device.append(cache.layer_init_args["device"]) + self.original_device.append(cache.layer_init_kwargs["device"]) if len(cache) != cache.num_hidden_layers: raise ValueError("If static layers are used, all cache layers must be initialized") @@ -1080,7 +1080,7 @@ def __init__( tp_size=tp_size, ) processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) - self.layer_init_args = parse_layer_args_from_model_config(config, **kwargs) + self.layer_init_kwargs = parse_layer_args_from_model_config(config, **kwargs) self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) self.append_new_layers(self.num_hidden_layers - 1) @@ -1138,10 +1138,10 @@ def append_new_layers(self, layer_idx: int) -> None: The index of the layer to append. 
""" while len(self.layers) <= layer_idx: - args = self.layer_init_args.copy() - if self.layer_init_args.get("layer_device_map", None) is not None: - args["device"] = args.pop("layer_device_map")[layer_idx] - new_layer = self.layer_classes[len(self.layers) % len(self.layer_classes)](**args) + kwargs = self.layer_init_kwargs.copy() + if self.layer_init_kwargs.get("layer_device_map", None) is not None: + kwargs["device"] = kwargs.pop("layer_device_map")[layer_idx] + new_layer = self.layer_classes[len(self.layers) % len(self.layer_classes)](**kwargs) self.layers.append(new_layer) @apply_processors From 2b4f52f85f124186868c9da3933c673c39532202 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Jul 2025 20:13:20 +0200 Subject: [PATCH 3/8] simplify --- src/transformers/cache_utils.py | 324 ++++++++++++++++---------------- 1 file changed, 161 insertions(+), 163 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 9546d48b3a72..ed55074e2f29 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -4,6 +4,7 @@ import inspect import json import os +from abc import ABC, abstractmethod from collections.abc import Iterable from dataclasses import dataclass from typing import Any, Callable, Optional, Union @@ -24,7 +25,7 @@ logger = logging.get_logger(__name__) -class CacheLayerMixin: +class CacheLayerMixin(ABC): """Base, abstract class for a single layer's cache.""" is_compileable = False @@ -32,26 +33,22 @@ class CacheLayerMixin: def __init__(self): self.keys, self.values = None, None + @abstractmethod def update( self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Updates KV cache, returns updated keys/values of the layer.""" - raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") + ) -> tuple[torch.Tensor, torch.Tensor]: ... - def get_seq_length(self, cache_position=None) -> int: - """Returns the sequence length of this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") + @abstractmethod + def get_seq_length(self, cache_position=None) -> int: ... - def get_max_cache_shape(self) -> int: - """Returns the maximum sequence length (i.e. max capacity) of this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") + @abstractmethod + def get_max_cache_shape(self) -> int: ... - def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: - """Returns mask sizes for the layer.""" - raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") + @abstractmethod + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ... def reset(self) -> None: """Resets the cache values while preserving the objects""" @@ -76,26 +73,6 @@ class DynamicLayer(CacheLayerMixin): See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. """ - @classmethod - def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": - """ - Build a `DynamicLayer` instance from pre-existing key/value tensors. - - Args: - keys (`torch.Tensor`): - Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. - values (`torch.Tensor`): - Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. 
- - Returns: - `DynamicLayer`: The newly constructed layer whose internal cache directly references - the supplied tensors. - """ - layer = cls() - layer.keys = keys - layer.values = values - return layer - def update( self, key_states: torch.Tensor, @@ -175,6 +152,26 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: kv_length = query_length + past_seen_tokens return kv_length, kv_offset + @classmethod + def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": + """ + Build a `DynamicLayer` instance from pre-existing key/value tensors. + + Args: + keys (`torch.Tensor`): + Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + values (`torch.Tensor`): + Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + + Returns: + `DynamicLayer`: The newly constructed layer whose internal cache directly references + the supplied tensors. + """ + layer = cls() + layer.keys = keys + layer.values = values + return layer + class StaticLayer(CacheLayerMixin): """ @@ -1053,9 +1050,9 @@ class Cache: def __init__( self, + layer_classes: Union[list[type[CacheLayerMixin]], type[CacheLayerMixin]], config: Optional[PretrainedConfig] = None, - cache_processor: Optional[Union[str, type["CacheProcessor"]]] = None, - layer_classes: Optional[list[type["CacheLayerMixin"]]] = None, + cache_processor: Optional[Union[str, type[CacheProcessor]]] = None, max_batch_size: Optional[int] = None, max_cache_len: Optional[int] = None, device: Union[torch.device, str, None] = None, @@ -1064,13 +1061,10 @@ def __init__( tp_size: Optional[int] = None, **kwargs, ): - self.layers: list["CacheLayerMixin"] = [] - processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor - - if layer_classes is None: - layer_classes = [DynamicLayer] - + self.layers: list[CacheLayerMixin] = [] self.layer_classes = layer_classes + + processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor kwargs.update( max_batch_size=max_batch_size, max_cache_len=max_cache_len, @@ -1080,6 +1074,7 @@ def __init__( tp_size=tp_size, ) processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) + self.layer_init_kwargs = parse_layer_args_from_model_config(config, **kwargs) self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) @@ -1141,7 +1136,11 @@ def append_new_layers(self, layer_idx: int) -> None: kwargs = self.layer_init_kwargs.copy() if self.layer_init_kwargs.get("layer_device_map", None) is not None: kwargs["device"] = kwargs.pop("layer_device_map")[layer_idx] - new_layer = self.layer_classes[len(self.layers) % len(self.layer_classes)](**kwargs) + + new_layer_class = ( + self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes + ) + new_layer = new_layer_class(**kwargs) self.layers.append(new_layer) @apply_processors @@ -1294,7 +1293,7 @@ class DynamicCache(Cache): # Specialized constructor for DDP cache data, needed for BC def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): - super().__init__(*args, **kwargs) + super().__init__(layer_classes=DynamicLayer, *args, **kwargs) # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. 
each item in the # iterable contains the key and value states for a layer gathered across replicas by torch.distributed @@ -1390,9 +1389,9 @@ class OffloadedCache(DynamicCache): ensure the eviction is scheduled after all computations on that cache are finished. """ - def __init__(self, config: Optional[PretrainedConfig] = None) -> None: + def __init__(self) -> None: # Create the underlying cache with offload processor - super().__init__(cache_processor=OffloadedCacheProcessor, config=config) + super().__init__(cache_processor=OffloadedCacheProcessor) class StaticCache(Cache): @@ -1422,44 +1421,45 @@ class StaticCache(Cache): """ def __init__(self, *args, **kwargs): - super().__init__(layer_classes=[StaticLayer], *args, **kwargs) + super().__init__(layer_classes=StaticLayer, *args, **kwargs) -class HybridCache(Cache): +class OffloadedStaticCache(StaticCache): """ - Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window - attention and global attention in every other layer (originally implemented for Gemma2). - Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention. For more information, see the documentation of those layer types. + A drop-in replacement for StaticCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. + + This cache maintains the compilation-friendly properties of StaticCache while enabling + much longer sequences by offloading inactive layers to CPU memory. See `Cache` for details on common methods that are implemented by all cache classes. Example: - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> # Prepare a cache class with offloading >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = OffloadedStaticCache( + ... config=model.config, + ... max_batch_size=1, + ... max_cache_len=max_generated_length, + ... device=model.device, + ... dtype=model.dtype + ... 
) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() + >>> outputs.past_key_values # access cache with offloaded layers + OffloadedStaticCache() ``` """ - def __init__(self, config: PretrainedConfig, *args, **kwargs): - if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None: - layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] - else: - layer_classes = [StaticLayer] - super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) class SlidingWindowCache(Cache): @@ -1502,7 +1502,99 @@ class SlidingWindowCache(Cache): """ def __init__(self, *args, **kwargs): - super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) + super().__init__(layer_classes=SlidingWindowLayer, *args, **kwargs) + + +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window + attention and global attention in every other layer (originally implemented for Gemma2). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention. For more information, see the documentation of those layer types. + + See `Cache` for details on common methods that are implemented by all cache classes. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + def __init__(self, config: PretrainedConfig, *args, **kwargs): + if hasattr(config, "layer_types"): + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + else: + # In this case, fall back to StaticCache + layer_classes = [StaticLayer] * config.num_hidden_layers + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + + +class HybridChunkedCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window + attention and global attention in every other layer, with support for prefill chunking (originally implemented + for Llama4). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention. For more information, see the documentation of each subcomponent cache class. + + See `Cache` for details on common methods that are implemented by all cache classes. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + def __init__(self, config: PretrainedConfig, *args, **kwargs): + if hasattr(config, "layer_types"): + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + else: + # In this case, fall back to StaticCache + layer_classes = [StaticLayer] * config.num_hidden_layers + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + + +class OffloadedHybridCache(HybridChunkedCache): + """ + A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. + + This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling + much longer sequences by offloading inactive layers to CPU memory. + + See `Cache` for details on common methods that are implemented by all cache classes. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) class QuantizedCache(DynamicCache): @@ -1615,100 +1707,6 @@ def __init__(self, backend="HQQ", **kwargs) -> None: Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) -class OffloadedStaticCache(StaticCache): - """ - A drop-in replacement for StaticCache that conserves accelerator memory by offloading - cache tensors to CPU when not actively being used. - - This cache maintains the compilation-friendly properties of StaticCache while enabling - much longer sequences by offloading inactive layers to CPU memory. - - See `Cache` for details on common methods that are implemented by all cache classes. - - Example: - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - - >>> # Prepare a cache class with offloading - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = OffloadedStaticCache( - ... config=model.config, - ... max_batch_size=1, - ... max_cache_len=max_generated_length, - ... device=model.device, - ... dtype=model.dtype - ... 
) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache with offloaded layers - OffloadedStaticCache() - ``` - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) - - -class HybridChunkedCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window - attention and global attention in every other layer, with support for prefill chunking (originally implemented - for Llama4). - Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention. For more information, see the documentation of each subcomponent cache class. - - See `Cache` for details on common methods that are implemented by all cache classes. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` - """ - - def __init__(self, config: PretrainedConfig, *args, **kwargs): - hybrid_map = LAYER_CLASS_MAP.copy() - hybrid_map["sliding_attention"] = ChunkedSlidingLayer - hybrid_map["chunked_attention"] = ChunkedSlidingLayer - if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None: - layer_classes = [hybrid_map[layer_type] for layer_type in config.layer_types] - else: - layer_classes = [StaticLayer] - super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) - - -class OffloadedHybridCache(HybridChunkedCache): - """ - A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading - cache tensors to CPU when not actively being used. - - This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling - much longer sequences by offloading inactive layers to CPU memory. - - See `Cache` for details on common methods that are implemented by all cache classes. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) - - class EncoderDecoderCache(Cache): """ Base, abstract class for all encoder-decoder caches. 
Can be used to hold combinations of self-attention and @@ -1998,7 +1996,7 @@ def parse_layer_args_from_model_config( LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { "full_attention": StaticLayer, "sliding_attention": SlidingWindowLayer, - "chunked_attention": SlidingWindowLayer, + "chunked_attention": ChunkedSlidingLayer, } PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = { "offloaded": OffloadedCacheProcessor, From 275d168f77efde4373f8bbe12beccf4844936304 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Jul 2025 20:16:41 +0200 Subject: [PATCH 4/8] Update cache_utils.py --- src/transformers/cache_utils.py | 39 ++------------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ed55074e2f29..935b1cf88262 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1543,43 +1543,8 @@ def __init__(self, config: PretrainedConfig, *args, **kwargs): super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) -class HybridChunkedCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window - attention and global attention in every other layer, with support for prefill chunking (originally implemented - for Llama4). - Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention. For more information, see the documentation of each subcomponent cache class. - - See `Cache` for details on common methods that are implemented by all cache classes. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` - """ - - def __init__(self, config: PretrainedConfig, *args, **kwargs): - if hasattr(config, "layer_types"): - layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] - else: - # In this case, fall back to StaticCache - layer_classes = [StaticLayer] * config.num_hidden_layers - super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) +# The mapping already handles dispatching the correct layers in Hybrid, this is only used for BC +class HybridChunkedCache(HybridCache): ... 
class OffloadedHybridCache(HybridChunkedCache): From d699b97135f948a22d9a634dba53aebb20c6a057 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Jul 2025 20:24:35 +0200 Subject: [PATCH 5/8] Update cache_utils.py --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 935b1cf88262..d4dbff2a2a26 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1704,7 +1704,7 @@ class EncoderDecoderCache(Cache): is_compileable = None def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): - super().__init__() + super().__init__(layer_classes=DynamicLayer) self.self_attention_cache = self_attention_cache self.cross_attention_cache = cross_attention_cache self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) From 30490317318398a5f8548c24c8589d20024a2527 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Jul 2025 20:32:51 +0200 Subject: [PATCH 6/8] Update test_cache_utils.py --- tests/utils/test_cache_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 8dbcc1194314..14b29344f190 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -1307,7 +1307,7 @@ def test_hybrid_chunked_cache(self): config = copy.deepcopy(self.config) config.num_hidden_layers = 2 - config.layer_types = ["full_attention", "sliding_attention"] + config.layer_types = ["full_attention", "chunked_attention"] config.sliding_window = 2 max_cache_len = 4 chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len) @@ -1387,7 +1387,7 @@ def test_hybrid_chunked_cache_extra_cases(self): config = copy.deepcopy(self.config) config.num_hidden_layers = 1 - config.layer_types = ["sliding_attention"] + config.layer_types = ["chunked_attention"] config.sliding_window = 3 cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3) From 647a426685bc8ec1f48e30fcb17b3deb25b0a3b1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Jul 2025 20:42:12 +0200 Subject: [PATCH 7/8] fix --- src/transformers/cache_utils.py | 6 +++--- src/transformers/models/bamba/modeling_bamba.py | 4 ++-- src/transformers/models/bamba/modular_bamba.py | 6 +++--- .../models/granitemoehybrid/modeling_granitemoehybrid.py | 4 ++-- src/transformers/models/jamba/modeling_jamba.py | 4 ++-- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d4dbff2a2a26..8fb19adc9807 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1027,15 +1027,15 @@ class Cache: ``` Parameters: + layer_classes (`type[CacheLayerMixin]` or `list[type[CacheLayerMixin]]`): + A list of `CacheLayerMixin` classes to instantiate for the cache. If only a `CacheLayerMixin` class is + provided, then it is used for all layers. config (`PretrainedConfig`, *optional*): Model configuration used to infer number of layers, head sizes, default device/dtype, etc. cache_processor (`CacheProcessor` or `str`, *optional*): Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") or a CacheProcessor class. - layer_classes (`list[type[CacheLayerMixin]]`, *optional*): - List of `CacheLayerMixin` classes to instantiate for the cache. When shorter than the - required number of layers the list is cycled. 
Default is [DynamicLayer]. max_batch_size (`int`, *optional*): Maximum batch size for static caches. max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`. diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index a228e35587e5..ef75f254cc20 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -31,7 +31,7 @@ from transformers.activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache, DynamicCache, DynamicLayer from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -104,7 +104,7 @@ class HybridMambaAttentionDynamicCache(Cache): is_compileable = False def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - super().__init__() + super().__init__(layer_classes=DynamicLayer) self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba conv_kernel_size = config.mamba_d_conv diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index f99faa9ed78d..0f5165878369 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -42,7 +42,7 @@ segment_sum, ) -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicLayer from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -99,7 +99,7 @@ class BambaFlashAttentionKwargs(TypedDict, total=False): # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer -class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache, Cache): +class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). 
@@ -114,7 +114,7 @@ class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache, Cache): """ def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - Cache.__init__() + HybridMambaAttentionDynamicCache.__init__(layer_classes=DynamicLayer) self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba conv_kernel_size = config.mamba_d_conv diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index c42300843d74..598673586cb4 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -27,7 +27,7 @@ from transformers.activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache, DynamicCache, DynamicLayer from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_layers import GradientCheckpointingLayer @@ -240,7 +240,7 @@ class HybridMambaAttentionDynamicCache(Cache): is_compileable = False def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None): - super().__init__() + super().__init__(layer_classes=DynamicLayer) self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba conv_kernel_size = config.mamba_d_conv diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 8f259bb0c017..817e181ec25e 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -28,7 +28,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache, DynamicCache, DynamicLayer from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available @@ -202,7 +202,7 @@ class HybridMambaAttentionDynamicCache(Cache): is_compileable = False def __init__(self, config, batch_size, dtype=torch.float16, device=None): - super().__init__() + super().__init__(layer_classes=DynamicLayer) self.dtype = dtype self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba From febc7abb5ff5a552f15e767c713d825609c0abca Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Jul 2025 20:43:19 +0200 Subject: [PATCH 8/8] style --- src/transformers/models/bamba/modular_bamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 0f5165878369..be58fd3abd42 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -42,7 +42,7 @@ segment_sum, ) -from ...cache_utils import Cache, DynamicLayer +from ...cache_utils import DynamicLayer from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel