From 1c3cbcccc89c2af50e649c51430b846bedf2c157 Mon Sep 17 00:00:00 2001
From: Manuel de Prada Corral
Date: Sun, 29 Jun 2025 11:46:42 +0200
Subject: [PATCH 01/35] Squash for refactor: Replace monolithic cache classes
 with modular LayeredCache (#38077)

- Introduces CacheLayer and Cache base classes
- Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
- Implements method/attr dispatch across layers to reduce boilerplate
- Adds CacheProcessor hooks for offloading, quantization, etc.
- Updates and passes tests
---
 docs/source/en/internal/generation_utils.md   |   21 -
 docs/source/en/model_doc/falcon_mamba.md      |    7 +
 docs/source/en/model_doc/mamba.md             |    7 +
 src/transformers/__init__.py                  |    2 -
 src/transformers/cache_utils.py               | 2530 ++++++++---------
 .../generation/configuration_utils.py         |    9 +-
 src/transformers/generation/utils.py          |  117 +-
 .../models/falcon_h1/modeling_falcon_h1.py    |    6 +-
 .../configuration_falcon_mamba.py             |   26 +-
 .../falcon_mamba/modeling_falcon_mamba.py     |  235 +-
 .../falcon_mamba/modular_falcon_mamba.py      |  537 ++++
 .../models/jamba/modeling_jamba.py            |    6 +-
 .../models/mamba/modeling_mamba.py            |  162 +-
 .../models/mamba2/modeling_mamba2.py          |    7 +-
 .../models/zamba/modeling_zamba.py            |    7 +-
 .../models/zamba2/modeling_zamba2.py          |    8 +-
 .../models/zamba2/modular_zamba2.py           |    2 +-
 src/transformers/utils/dummy_pt_objects.py    |    7 -
 .../test_modeling_falcon_mamba.py             |   43 +-
 tests/models/mamba/test_modeling_mamba.py     |   37 +-
 tests/utils/test_cache_utils.py               |   43 +-
 21 files changed, 2246 insertions(+), 1573 deletions(-)
 create mode 100644 src/transformers/models/falcon_mamba/modular_falcon_mamba.py

diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md
index 1c17e99d5da3..b956828fdcff 100644
--- a/docs/source/en/internal/generation_utils.md
+++ b/docs/source/en/internal/generation_utils.md
@@ -366,43 +366,22 @@ A [`Constraint`] can be used to force the generation to include specific tokens
     - validate
 
 [[autodoc]] DynamicCache
-    - update
-    - get_seq_length
-    - reorder_cache
-    - to_legacy_cache
-    - from_legacy_cache
 
 [[autodoc]] QuantizedCache
-    - update
-    - get_seq_length
 
 [[autodoc]] QuantoQuantizedCache
 
 [[autodoc]] HQQQuantizedCache
 
 [[autodoc]] OffloadedCache
-    - update
-    - prefetch_layer
-    - evict_previous_layer
 
 [[autodoc]] StaticCache
-    - update
-    - get_seq_length
-    - reset
 
 [[autodoc]] OffloadedStaticCache
-    - update
-    - get_seq_length
-    - reset
 
 [[autodoc]] HybridCache
-    - update
-    - get_seq_length
-    - reset
 
 [[autodoc]] SlidingWindowCache
-    - update
-    - reset
 
 [[autodoc]] EncoderDecoderCache
     - get_seq_length
diff --git a/docs/source/en/model_doc/falcon_mamba.md b/docs/source/en/model_doc/falcon_mamba.md
index a8d7886894b2..0b797c7c7829 100644
--- a/docs/source/en/model_doc/falcon_mamba.md
+++ b/docs/source/en/model_doc/falcon_mamba.md
@@ -110,6 +110,13 @@ outputs = model.generate(**inputs, max_new_tokens=100)
 print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 ```
 
+## FalconMambaCache
+
+[[autodoc]] FalconMambaCache
+    - update_conv_state
+    - update_ssm_state
+    - reset
+
 ## FalconMambaConfig
 
 [[autodoc]] FalconMambaConfig
diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md
index 9ce98d8516a8..2d5543cdcf84 100644
--- a/docs/source/en/model_doc/mamba.md
+++ b/docs/source/en/model_doc/mamba.md
@@ -115,6 +115,13 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
 trainer.train()
 ```
 
+## MambaCache
+
+[[autodoc]] MambaCache
+    - update_conv_state
+    - update_ssm_state
+    - reset
+
 ## MambaConfig
 
 [[autodoc]] MambaConfig
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 892acd32ead7..b07b0e849a9e 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -358,7 +358,6 @@
         "EncoderDecoderCache",
         "HQQQuantizedCache",
         "HybridCache",
-        "MambaCache",
         "OffloadedCache",
         "OffloadedStaticCache",
         "QuantizedCache",
@@ -839,7 +838,6 @@
         EncoderDecoderCache,
         HQQQuantizedCache,
         HybridCache,
-        MambaCache,
         OffloadedCache,
         OffloadedStaticCache,
         QuantizedCache,
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 04ccc6f7efcf..d05220cb62a5 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -2,6 +2,7 @@
 import importlib.metadata
 import json
 import os
+import warnings
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, Optional, Union
@@ -18,116 +19,295 @@ if is_hqq_available():
     from hqq.core.quantize import Quantizer as HQQQuantizer
 
+
 logger = logging.get_logger(__name__)
 
 
-# Utility functions for static/sliding cache update logic
-def _static_cache_update(
-    k_cache: torch.Tensor,
-    v_cache: torch.Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    cache_position: Optional[torch.LongTensor],
-) -> tuple[torch.Tensor, torch.Tensor]:
+class CacheProcessor:
     """
-    Updates the static cache tensors in place.
-
-    Args:
-        k_cache (`torch.Tensor`): The key cache tensor to update.
-        v_cache (`torch.Tensor`): The value cache tensor to update.
-        key_states (`torch.Tensor`): The new key states to add.
-        value_states (`torch.Tensor`): The new value states to add.
-        cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted.
-            If None, the entire cache is overwritten (prefill).
-
-    Returns:
-        tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place).
+    Base class for cache processors that can be applied to modify cache behavior.
+    This class should be subclassed.
     """
-    if cache_position is None:
-        # Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
-        k_cache.copy_(key_states)
-        v_cache.copy_(value_states)
-    else:
-        # Generation phase. Update specific positions.
-        # Use index_copy_ for in-place update (compile-friendly).
-        try:
-            k_cache.index_copy_(2, cache_position, key_states)
-            v_cache.index_copy_(2, cache_position, value_states)
-        except NotImplementedError:
-            # Fallback for devices like MPS where index_copy_ might not be supported.
-            k_cache[:, :, cache_position] = key_states
-            v_cache[:, :, cache_position] = value_states
-    return k_cache, v_cache
-
-
-def _sliding_cache_update(
-    k_cache: torch.Tensor,
-    v_cache: torch.Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    cache_position: torch.LongTensor,
-    max_cache_len: int,
-) -> tuple[torch.Tensor, torch.Tensor]:
+
+    def init(self, cache: "Cache", **kwargs) -> None:
+        """
+        Initialize the processor and perform compatibility checks with the cache.
+
+        Args:
+            cache (`Cache`): The cache instance this processor will be applied to.
+            **kwargs: Additional arguments that may be needed for initialization.
+        """
+        raise NotImplementedError("Make sure to implement `init` in a subclass.")
+
+    def pre_update(
+        self,
+        cache: "Cache",
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Hook called before the cache update.
Can modify the key/value states. + + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + layer_idx (`int`): The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. + """ + return key_states, value_states + + def post_update( + self, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Hook called after the cache update. Can process the cached data. + + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The key states that were cached. + value_states (`torch.Tensor`): The value states that were cached. + layer_idx (`int`): The index of the layer that was updated. + cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return. + """ + return key_tensors, value_tensors + + +class CacheProcessorList(list): """ - Updates the sliding window cache tensors, returning the potentially modified tensors. - - Args: - k_cache (`torch.Tensor`): The key cache tensor to update. - v_cache (`torch.Tensor`): The value cache tensor to update. - key_states (`torch.Tensor`): The new key states to add. - value_states (`torch.Tensor`): The new value states to add. - cache_position (`torch.LongTensor`): The position indices where the new states should be inserted. - max_cache_len (`int`): The maximum length of the sliding window cache. - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update. - For prefill > window, these are the full input states. - Otherwise, they are the updated cache tensors. + list of cache processors that can be applied to a cache. 
""" - # Handle prefill phase when prompt length > sliding_window_size - if cache_position.shape[0] > max_cache_len: - new_k = key_states[:, :, -max_cache_len:, :] - new_v = value_states[:, :, -max_cache_len:, :] - k_cache.copy_(new_k) - v_cache.copy_(new_v) + + def init(self, cache: "Cache", **kwargs) -> None: + """Initialize all processors in the list.""" + for processor in self: + processor.init(cache, **kwargs) + + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply pre_update hook for all processors.""" + for processor in self: + key_states, value_states = processor.pre_update(cache, key_states, value_states, layer_idx, cache_kwargs) return key_states, value_states - # Sliding window logic for generation phase or prefill < window - slicing = torch.arange(max_cache_len, device=value_states.device) - current_seq_len = cache_position[-1] + 1 # Use last position to determine current length - to_shift = current_seq_len > max_cache_len - indices = (slicing + to_shift.sum()) % max_cache_len + def post_update( + self, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply post_update hook for all processors.""" + for processor in self: + key_tensors, value_tensors = processor.post_update( + cache, key_tensors, value_tensors, layer_idx, cache_kwargs + ) + return key_tensors, value_tensors + + +class KVList: + """Efficiently simulates layer-indexed key or value lists from a layered cache. + This allows for BC access, e.g., cache.key_cache[idx] or cache.value_cache[idx].""" + + def __init__(self, layers, cache_type="key"): + self.layers = layers + self.cache_type = cache_type + + def __getitem__(self, idx): + if isinstance(idx, slice): + return [getattr(layer, f"{self.cache_type}_cache") for layer in self.layers[idx]] + return getattr(self.layers[idx], f"{self.cache_type}_cache") + + def __setitem__(self, idx, value): + if isinstance(idx, slice): + for layer, val in zip(self.layers[idx], value): + setattr(layer, f"{self.cache_type}_cache", val) + else: + setattr(self.layers[idx], f"{self.cache_type}_cache", value) + + def __len__(self): + return len(self.layers) + + def __iter__(self): + for layer in self.layers: + yield getattr(layer, f"{self.cache_type}_cache") + + def __bool__(self): + return bool(self.layers) + + +class CacheLayer: + """Base, abstract class for a single layer's cache.""" + + is_compileable = False + + def __init__( + self, + config: Optional["CacheConfig"] = None, + ): + self.key_cache = None + self.value_cache = None + + @classmethod + def from_kv(cls, key_cache: torch.Tensor, value_cache: torch.Tensor) -> None: + cache = cls() + cache.key_cache = key_cache + cache.value_cache = value_cache + return cache + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Updates KV cache, returns updated K and V for this layer.""" + raise NotImplementedError("Make sure to implement `update` in a subclass.") - k_out_shifted = k_cache[:, :, indices] - v_out_shifted = v_cache[:, :, indices] + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the 
cache. + Early stops since first layer is enough to compute sequence length""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length, _ = self.get_seq_length(layer_idx) + if max_length != -1 and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length - # Clamp cache_position to determine the *target index* within the shifted cache view - update_position = cache_position.clamp(min=0, max=max_cache_len - 1) + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length (i.e. max capacity) of this layer's cache.""" + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") - try: - k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) - v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) - except NotImplementedError: - # Fallback for MPS: clone and modify the clone - k_out_updated = k_out_shifted.clone() - v_out_updated = v_out_shifted.clone() - k_out_updated[:, :, update_position] = key_states - v_out_updated[:, :, update_position] = value_states + def reset(self) -> None: + """Resets this layer's cache.""" + raise NotImplementedError("Make sure to implement `reset` in a subclass.") + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorders this layer's cache for beam search.""" + if self.key_cache.numel(): + device = self.key_cache.device + self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) + if self.value_cache.numel(): + device = self.value_cache.device + self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) - k_cache.copy_(k_out_updated) - v_cache.copy_(v_out_updated) - return k_out_updated, v_out_updated + def __repr__(self): + return f"{self.__class__.__name__}(K={self.key_cache}, V={self.value_cache})" class Cache: """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. + Base, abstract class for all caches. The actual data structure is specific to the layers. + This class handles propagation of operations across layers. + + Parameters: + config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): + Model configuration for shape/device info, or DDP-distributed cache data for compatibility. + If DDP-distributed cache data, must be an iterable of (key_states, value_states) tuples for each layer. + processors (`CacheProcessorList`, *optional*): + List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): + Pattern of cache layer types to use. Defaults to `(DynamicLayer,)`. Must be a tuple whose length divides + the total number of layers. The pattern repeats to fill all layers. Examples: `(StaticLayer,)` for a + uniform cache, `(StaticLayer, StaticLayer, SlidingWindowLayer)` for a hybrid cache with repeating pattern, + or specify the full structure like `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)`. + Additional arguments for cache configuration: + - `max_batch_size`/`batch_size` (`int`): Maximum batch size for static caches + - `max_cache_len` (`int`): Maximum sequence length. 
For hybrid caches: + * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` + * StaticLayers: uses full `max_cache_len` + - `device` (`torch.device`): Device for cache tensors + - `dtype` (`torch.dtype`): Data type for cache tensors + - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping + + Note for hybrid caches (blocks of (StaticLayer, ..., SlidingWindowLayer) repeated across layers): + - Requires `model_config.sliding_window` to be set + - Uses `sliding_window_pattern` (default: 2) to determine layer alternation if pattern not specified + - SlidingWindow layers are limited to sliding window size, Static layers use full max_cache_len """ - is_compileable = False + layers = [] + pattern_block = () # Subclasses can define their layer pattern statically - def __init__(self): - super().__init__() + def __init__( + self, + config_or_ddp_cache_data: Optional[ + Union[PretrainedConfig, Iterable[tuple[torch.Tensor, torch.Tensor]]] + ] = None, + processors: Optional[CacheProcessorList] = None, + pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, + *args, + **kwargs, + ): + self.layers: list[CacheLayer] = [] + self.processors = processors if processors is not None else CacheProcessorList() + pattern_block = pattern_block or self.pattern_block or (DynamicLayer,) + + if isinstance(config_or_ddp_cache_data, PretrainedConfig): + model_config = config_or_ddp_cache_data + elif isinstance(config_or_ddp_cache_data, Iterable): + _distributed_cache_data = config_or_ddp_cache_data + # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 + # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the + # iterable contains the key and value states for a layer gathered across replicas by torch.distributed + # (shape=[global batch size, num_heads, seq_len, head_dim]). + # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break + # compatibility. The name of the argument doesn't matter. + assert pattern_block == (DynamicLayer,), "torch DDP is only supported for DynamicCache" + for key_states, value_states in _distributed_cache_data: + self.layers.append(DynamicLayer.from_kv(key_states, value_states)) + self.processors.init(self, **kwargs) + return + else: + model_config = kwargs.pop("config", None) + + self.config, self.pattern_block = CacheConfig.from_model_config(model_config, pattern_block, *args, **kwargs) + self.layer_types = [self.pattern_block[i % len(self.pattern_block)] for i in range(self.config.num_layers)] + + for idx, layer_type in enumerate(self.layer_types): + layer = layer_type(self.config.to_layer(idx)) + self.layers.append(layer) + + self.processors.init(self, **kwargs) + + def grow_layers_to(self, layer_idx): + while len(self.layers) <= layer_idx: + next_type_idx = len(self.layer_types) % len(self.pattern_block) + next_layer_type = self.pattern_block[next_type_idx] + self.layer_types.append(next_layer_type) + self.layers.append(next_layer_type()) + + @property + def key_cache(self) -> KVList: + """Returns a list-like object of key cache tensors indexed by layer.""" + return KVList(self.layers, "key") + + @property + def value_cache(self) -> KVList: + """Returns a list-like object of value cache tensors indexed by layer.""" + return KVList(self.layers, "value") def update( self, @@ -153,48 +333,99 @@ def update( Return: A tuple containing the updated key and value states. 
""" - raise NotImplementedError("Make sure to implement `update` in a subclass.") + key_states, value_states = self.processors.pre_update(self, key_states, value_states, layer_idx, cache_kwargs) + self.grow_layers_to(layer_idx) + key_tensors, value_tensors = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + key_tensors, value_tensors = self.processors.post_update( + self, key_tensors, value_tensors, layer_idx, cache_kwargs + ) + return key_tensors, value_tensors - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self.layers): + return self.layers[layer_idx].key_cache, self.layers[layer_idx].value_cache + else: + raise KeyError( + f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" + ) - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length (i.e. max capacity) of the cache object""" - raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.layers[layer_idx].key_cache, self.layers[layer_idx].value_cache) - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_cache_shape() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. 
+ """ + # Best effort BC support for subclasses + if self.layers is None: + if getattr(self, "key_cache", None) is not None: + return len(self.key_cache) + return 0 + dynamic_empty = ( + len(self.layers) == 1 and isinstance(self.layers[0], DynamicLayer) and self.layers[0].key_cache is None + ) + return len(self.layers) if not dynamic_empty else 0 - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - if self.value_cache[layer_idx].numel(): - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + def __getattr__(self, name): + """ + Dynamically handle any method call / attribute access by forwarding to layers. + If it doesn't exist on all layers, raises AttributeError. + For attributes, checks all different layer types and reduces to an set of unique values (or to a single value) + For methods, returns a function that propagates the call to all layers with carry-over state (e.g. update, reset) + """ + if name in ("__getstate__", "__setstate__"): + raise AttributeError(name) + + # Check if the attribute/method exists and gather values if it is an attribute + attribute_values = [] + for i, layer in enumerate(self.layers[: len(self.pattern_block)]): + if not hasattr(layer, name): + raise AttributeError( + f"Layer {i} ({layer.__class__.__name__}) of {self.__class__.__name__} does not support `{name}`" + ) + if not callable(getattr(layer, name)): + attribute_values.append(getattr(layer, name)) - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." - ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None + if attribute_values: + assert len(attribute_values) == len(self.pattern_block), ( + f"Cache {self.__class__.__name__} gathered {len(attribute_values)} values for {name}, but there are {len(self.pattern_block)} layers." + ) + values = set(attribute_values) + if len(values) == 1: + return values.pop() + else: + if all(isinstance(value, bool) for value in values): + return all(values) + else: + raise ValueError( + f"Cache {self.__class__.__name__}:{self.pattern_block} has multiple values for {name}: {attribute_values}. This is not supported." + ) + + # If the attribute is a method, we propagate it to all layers + def propagate_to_layers(*args, **kwargs): + for layer in self.layers: + return_value = getattr(layer, name)(*args, **kwargs) + if return_value is not None: + break + return return_value + + return propagate_to_layers + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" + if layer_idx >= len(self.layers): + return 0 + return self.layers[layer_idx].get_seq_length() def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ @@ -203,10 +434,49 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), for each layer. 
""" - query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, 0 + if isinstance(self.layers[layer_idx], SlidingWindowLayer): + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + + local_mask_kv_offset = torch.clamp(first_cache_position - self.config.sliding_window + 1, min=0) + # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns + local_mask_kv_length = max(query_length, self.config.sliding_window) + return local_mask_kv_length, local_mask_kv_offset + + full_mask_kv_offset = 0 + if isinstance(self.layers[layer_idx], StaticLayer): + full_mask_kv_length = self.get_max_cache_shape() + return full_mask_kv_length, full_mask_kv_offset + else: + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, full_mask_kv_offset + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: + """Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer in self.layers: + if layer is not None: + legacy_cache += ((layer.key_cache, layer.value_cache),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None + ) -> "Cache": + """Converts a cache in the legacy cache format into an equivalent `Cache`. Used for + backward compatibility.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def __repr__(self): + return f"{self.__class__.__name__}(layers={self.layers})" @dataclass @@ -215,7 +485,65 @@ class CacheConfig: Base class for cache configs """ - cache_implementation: None + def __init__(self, num_layers: Optional[int] = None, cache_implementation: Optional[str] = None): + self.num_layers = num_layers + self.cache_implementation = cache_implementation + + @classmethod + def from_model_config( + cls, + config: Optional[PretrainedConfig], + pattern_block: tuple[type["CacheLayer"], ...], + batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: Optional[torch.dtype] = None, + layer_device_map=None, + max_batch_size: Optional[int] = None, + ) -> "CacheConfig": + num_layers = getattr(config, "num_hidden_layers", len(pattern_block)) + # No model config -> must be a dynamic cache, return bare CacheConfig + if config is None: + return cls(num_layers=num_layers), pattern_block + # Build a StaticCacheConfig for any kind of static: hybrid, sliding or static + else: + # Rename max_batch_size to batch_size + if max_batch_size is not None: + batch_size = max_batch_size + # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) + if StaticLayer in pattern_block and SlidingWindowLayer in pattern_block: + if getattr(config, "sliding_window", None) is None: + raise ValueError( + "Setting up a hybrid or sliding window KVCache requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." 
+ ) + # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) + max_cache_len = max_cache_len or config.max_position_embeddings + sliding_window_len = min(getattr(config, "sliding_window", max_cache_len) or max_cache_len, max_cache_len) + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: + head_dim = ( + config.head_dim + if getattr(config, "head_dim", None) is not None + else config.hidden_size // config.num_attention_heads + ) + num_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_config = StaticCacheConfig( + batch_size=batch_size, + max_cache_len=max_cache_len, + device=torch.device(device) if device is not None else None, + dtype=dtype, + layer_device_map=layer_device_map, + head_dim=head_dim, + num_heads=num_heads, + sliding_window=sliding_window_len, + num_layers=num_layers, + ) + return cache_config, pattern_block @classmethod def from_dict(cls, config_dict, **kwargs): @@ -238,6 +566,9 @@ def from_dict(cls, config_dict, **kwargs): kwargs.pop(key, None) return config + def to_layer(self, layer_idx: int) -> "CacheLayer": + return self + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file def to_json_file(self, json_file_path: Union[str, os.PathLike]): """ @@ -408,15 +739,61 @@ def validate(self): @dataclass class StaticCacheConfig(CacheConfig): """ - Configuration class for static cache settings. + Configuration class for static and sliding window cache settings. """ - cache_implementation = "static" + batch_size: Optional[int] = None + max_cache_len: Optional[int] = None + device: Union[str, torch.device] = None + dtype: Optional[torch.dtype] = None + layer_device_map: Optional[dict[int, Union[str, torch.device]]] = None + head_dim: Optional[int] = None + num_heads: Optional[int] = None + sliding_window: Optional[int] = None + num_layers: Optional[int] = None + cache_implementation: Optional[str] = None + + def __post_init__(self): + self.cache_implementation = "static" + if self.batch_size is None: + raise ValueError("`batch_size` is required for static cache") + if self.max_cache_len is None: + raise ValueError("`max_cache_len` is required for static cache") + if self.device is None: + self.device = "cpu" + logger.warning_once("`device` not set in cache initialization, using default `cpu`") + if self.dtype is None: + self.dtype = torch.float32 + logger.warning_once("`dtype` not set in cache initialization, using default `float32`") + + def for_layer(self, layer_idx: int): + """ + Returns a StaticCacheConfig for a given layer index. 
+ """ + device = self.layer_device_map[layer_idx] if self.layer_device_map is not None else self.device + return StaticCacheConfig( + self.batch_size, + self.max_cache_len, + device, + self.dtype, + None, + self.head_dim, + self.num_heads, + self.sliding_window, + ) - def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): - self.batch_size = batch_size - self.max_cache_len = max_cache_len - self.device = device + @property + def dtype(self): # noqa: F811 + return getattr(torch, self._dtype) if self._dtype is not None else None + + @dtype.setter + def dtype(self, value): + if isinstance(value, torch.dtype): + self._dtype = str(value).split(".")[-1] + elif isinstance(value, str): + self._dtype = value + else: + self._dtype = None def validate(self): """Validates if the arguments passed are correct""" @@ -445,6 +822,90 @@ def validate(self): ) +class DynamicLayer(CacheLayer): + """ + A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. + It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. + """ + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + cache_kwargs (`dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`. + + Return: + A tuple containing the updated key and value states. + """ + if self.key_cache is None: + self.key_cache = key_states + self.value_cache = value_states + else: + self.key_cache = torch.cat([self.key_cache, key_states], dim=-2) + self.value_cache = torch.cat([self.value_cache, value_states], dim=-2) + + return self.key_cache, self.value_cache + + def get_seq_length(self, cache_position: Optional[torch.LongTensor] = None) -> int: + """Returns the sequence length of the cached states.""" + # TODO: deprecate this function in favor of `cache_position` + if self is None or self.key_cache is None: + return 0 + return self.key_cache.shape[-2] + + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" + return -1 + + def reset(self) -> None: + """Resets the cache values while preserving the objects""" + self.key_cache = torch.tensor([], dtype=self.key_cache.dtype, device=self.key_cache.device) + self.value_cache = torch.tensor([], dtype=self.value_cache.dtype, device=self.value_cache.device) + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorders the cache for beam search, given the selected beam indices.""" + if self.key_cache is not None and self.key_cache.numel(): + self.key_cache = self.key_cache.index_select(0, beam_idx.to(self.key_cache.device)) + if self.value_cache is not None and self.value_cache.numel(): + self.value_cache = self.value_cache.index_select(0, beam_idx.to(self.value_cache.device)) + + def crop(self, max_length: int) -> None: + """Crop the past key values up to a new `max_length` in terms of tokens. 
`max_length` can also be + negative to remove `max_length` tokens.""" + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + if self.key_cache.numel(): + self.key_cache = self.key_cache[..., :max_length, :] + self.value_cache = self.value_cache[..., :max_length, :] + + def batch_repeat_interleave(self, repeats: int) -> None: + """Repeat the cache `repeats` times in the batch dimension.""" + if self.key_cache.numel(): + self.key_cache = self.key_cache.repeat_interleave(repeats, dim=0) + self.value_cache = self.value_cache.repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor) -> None: + """Only keep the `indices` in the batch dimension of the cache.""" + if self.key_cache.numel(): + self.key_cache = self.key_cache[indices, ...] + self.value_cache = self.value_cache[indices, ...] + + class DynamicCache(Cache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. @@ -470,184 +931,7 @@ class DynamicCache(Cache): ``` """ - def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: - super().__init__() - self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121 - # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the - # iterable contains the key and value states for a layer gathered across replicas by torch.distributed - # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break - # compatibility. The name of the argument doesn't matter. - if _distributed_cache_data is not None: - for key_states, value_states in _distributed_cache_data: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return (self.key_cache[layer_idx], self.value_cache[layer_idx]) - else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. 
- cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache - if key_states is not None: - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif ( - not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model - ): # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or not self.key_cache[layer_idx].numel() # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" - return None - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None - ) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for - backward compatibility.""" - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def crop(self, max_length: int): - """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens. 
This is used in assisted decoding and contrastive search.""" - # In case it is negative - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - self._seen_tokens = max_length - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def batch_split(self, full_batch_size: int, split_size: int) -> list["DynamicCache"]: - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" - out = [] - for i in range(0, full_batch_size, split_size): - current_split = DynamicCache() - current_split._seen_tokens = self._seen_tokens - current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] - current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] - out.append(current_split) - return out - - @classmethod - def from_batch_splits(cls, splits: list["DynamicCache"]) -> "DynamicCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - cache = cls() - for idx in range(len(splits[0])): - key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()] - value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()] - if key_cache != []: - layer_keys = torch.cat(key_cache, dim=0) - layer_values = torch.cat(value_cache, dim=0) - cache.update(layer_keys, layer_values, idx) - return cache - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + pattern_block = (DynamicLayer,) # Utilities for `DynamicCache` <> torch.export support @@ -663,18 +947,17 @@ def _flatten_dynamic_cache( "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." 
) - # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), + "key_cache": [layer.key_cache for layer in dynamic_cache.layers if layer.key_cache is not None], + "value_cache": [layer.value_cache for layer in dynamic_cache.layers if layer.value_cache is not None], } return torch.utils._pytree._dict_flatten(dictionary) def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), + "key_cache": [layer.key_cache for layer in dynamic_cache.layers if layer.key_cache is not None], + "value_cache": [layer.value_cache for layer in dynamic_cache.layers if layer.value_cache is not None], } return torch.utils._pytree._dict_flatten_with_keys(dictionary) @@ -685,15 +968,20 @@ def _unflatten_dynamic_cache( ): dictionary = torch.utils._pytree._dict_unflatten(values, context) cache = DynamicCache() - for k, v in dictionary.items(): - setattr(cache, k, v) + # Reconstruct layers from key_cache and value_cache lists + key_list = dictionary.get("key_cache", []) + value_list = dictionary.get("value_cache", []) + for idx in range(max(len(key_list), len(value_list))): + key = key_list[idx] if idx < len(key_list) else None + value = value_list[idx] if idx < len(value_list) else None + cache.update(key, value, idx) return cache def _flatten_dynamic_cache_for_fx(cache, spec): dictionary = { - "key_cache": getattr(cache, "key_cache"), - "value_cache": getattr(cache, "value_cache"), + "key_cache": [layer.key_cache for layer in cache.layers if layer.key_cache is not None], + "value_cache": [layer.value_cache for layer in cache.layers if layer.value_cache is not None], } return torch.fx._pytree._dict_flatten_spec(dictionary, spec) @@ -723,125 +1011,13 @@ class OffloadedCache(DynamicCache): ensure the eviction is scheduled after all computations on that cache are finished. 
""" - def __init__(self) -> None: - if not ( - torch.cuda.is_available() - or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) - ): - raise RuntimeError( - "OffloadedCache can only be used with a GPU" - + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") - ) - - super().__init__() - self.original_device = [] - self.prefetch_stream = None - self.prefetch_stream = ( - torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() - ) - self.beam_idx = None # used to delay beam search operations - - def prefetch_layer(self, layer_idx: int): - "Starts prefetching the next layer cache" - if layer_idx < len(self): - with ( - self.prefetch_stream - if is_torch_greater_or_equal("2.7", accept_dev=True) - else torch.cuda.stream(self.prefetch_stream) - ): - # Prefetch next layer tensors to GPU - device = self.original_device[layer_idx] - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True) - - def evict_previous_layer(self, layer_idx: int): - "Moves the previous layer cache to the CPU" - if len(self) > 2: - # We do it on the default stream so it occurs after all earlier computations on these tensors are done - prev_layer_idx = (layer_idx - 1) % len(self) - self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True) - self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True) - - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: - "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." - if layer_idx < len(self): - # Evict the previous layer if necessary - if is_torch_greater_or_equal("2.7", accept_dev=True): - torch.accelerator.current_stream().synchronize() - else: - torch.cuda.current_stream().synchronize() - self.evict_previous_layer(layer_idx) - # Load current layer cache to its original device if not already there - original_device = self.original_device[layer_idx] - self.prefetch_stream.synchronize() - key_tensor = self.key_cache[layer_idx] - value_tensor = self.value_cache[layer_idx] - # Now deal with beam search ops which were delayed - if self.beam_idx is not None: - self.beam_idx = self.beam_idx.to(original_device) - key_tensor = key_tensor.index_select(0, self.beam_idx) - value_tensor = value_tensor.index_select(0, self.beam_idx) - # Prefetch the next layer - self.prefetch_layer((layer_idx + 1) % len(self)) - return (key_tensor, value_tensor) - else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Saves the beam indices and reorders the cache when the tensor is back to its device.""" - # We delay this operation until the tensors are back to their original - # device because performing torch.index_select on the CPU is very slow - del self.beam_idx - self.beam_idx = beam_idx.clone() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. 
- layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] + def __init__(self, config: Optional[CacheConfig] = None) -> None: + # Create the underlying cache with offload processor + processors = CacheProcessorList([OffloadedCacheProcessor()]) + super().__init__(processors=processors, config=config) - # Update the cache - if len(self.key_cache) < layer_idx: - raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") - elif len(self.key_cache) == layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - self.original_device.append(key_states.device) - self.evict_previous_layer(layer_idx) - else: - key_tensor, value_tensor = self[layer_idx] - self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError - # if a method is not supposed to be supported in a subclass we should set it to None - from_legacy_cache = None - - to_legacy_cache = None - - -class QuantizedCache(DynamicCache): +class QuantoQuantizedCache(DynamicCache): """ A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. @@ -853,87 +1029,8 @@ class QuantizedCache(DynamicCache): It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and Value in original precision states as a list of tensors, one for each layer. The size of each tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - """ - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - super().__init__() - self._quantized_key_cache: list[torch.Tensor] = [] - self._quantized_value_cache: list[torch.Tensor] = [] - - self.nbits = cache_config.nbits - self.residual_length = cache_config.residual_length - self.q_group_size = cache_config.q_group_size - self.axis_key = cache_config.axis_key - self.axis_value = cache_config.axis_value - self.compute_dtype = cache_config.compute_dtype - self.device = cache_config.device - - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - if len(self.key_cache) < layer_idx: - raise ValueError("QuantizedCache does not support model usage where layers are skipped. 
Use DynamicCache.") - elif len(self.key_cache) == layer_idx: - self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key)) - self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value)) - self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) - self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) - keys_to_return, values_to_return = key_states, value_states - else: - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) - keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] - values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] - - keys_to_return = torch.cat(keys_to_return, dim=-2) - values_to_return = torch.cat(values_to_return, dim=-2) - if ( - self.key_cache[layer_idx].dim() == 4 - and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length - ): - self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) - self._quantized_value_cache[layer_idx] = self._quantize( - values_to_return.contiguous(), axis=self.axis_value - ) - self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) - self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is - # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def _quantize(self, tensor, axis): - """Quantizes a key/value using a defined quantization method.""" - raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") - - def _dequantize(self, q_tensor): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") - - -class QuantoQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. Parameters: cache_config (`QuantizedCacheConfig`): @@ -959,47 +1056,25 @@ class QuantoQuantizedCache(QuantizedCache): ``` """ - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - - if is_optimum_quanto_available(): - optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) - if optimum_quanto_version <= version.parse("0.2.5"): - raise ImportError( - f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}." 
- ) - from optimum.quanto import MaxOptimizer, qint2, qint4 - - if self.nbits not in [2, 4]: - raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") - - if self.axis_key not in [0, -1]: - raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") - - if self.axis_value not in [0, -1]: - raise ValueError( - f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" - ) - - self.qtype = qint4 if self.nbits == 4 else qint2 - self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) + super(DynamicCache, self).__init__(processors=processors) - def _quantize(self, tensor, axis): - # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore - if is_optimum_quanto_available(): - from optimum.quanto import quantize_weight - scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) - qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) - return qtensor +class HQQQuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - def _dequantize(self, qtensor): - return qtensor.dequantize() + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` -class HQQQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. 
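# --- Editorial sketch, not part of this patch ---------------------------------------
# Rough equivalence introduced by the refactor: the quantized caches are now thin
# wrappers that attach a quantization CacheProcessor to a plain DynamicCache.
# Assumes the names defined in this file (CacheProcessorList,
# QuantoQuantizedCacheProcessor) and that DynamicCache forwards a `processors`
# argument, as the QuantizedCache wrapper further below does. Requires
# `optimum-quanto` to be installed for the processor's init checks to pass.
import torch
from transformers import QuantizedCacheConfig
from transformers.cache_utils import CacheProcessorList, DynamicCache, QuantoQuantizedCacheProcessor

cfg = QuantizedCacheConfig(nbits=4, q_group_size=64, residual_length=128, compute_dtype=torch.float16)
composed = DynamicCache(processors=CacheProcessorList([QuantoQuantizedCacheProcessor(cfg)]))
# ... which is what `QuantoQuantizedCache(cfg)` above builds internally.
# -------------------------------------------------------------------------------------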
Parameters: cache_config (`QuantizedCacheConfig`): @@ -1025,55 +1100,99 @@ class HQQQuantizedCache(QuantizedCache): ``` """ - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" - ) - - if self.axis_key not in [0, 1]: - raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)]) + super(DynamicCache, self).__init__(processors=processors) - if self.axis_value not in [0, 1]: - raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") - self.quantizer = HQQQuantizer +class StaticLayer(CacheLayer): + is_compileable = True - def _quantize(self, tensor, axis): - qtensor, meta = self.quantizer.quantize( - tensor, - axis=axis, - device=self.device, - compute_dtype=self.compute_dtype, - nbits=self.nbits, - group_size=self.q_group_size, + def __init__( + self, + config: StaticCacheConfig, + max_len: Optional[int] = None, + ): + self.max_cache_len = max_len or config.max_cache_len + self.max_batch_size = config.batch_size + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + self.key_cache = torch.zeros( + (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), + dtype=config.dtype, + device=config.device, ) - meta["compute_dtype"] = self.compute_dtype - self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype - meta["scale"] = meta["scale"].to(qtensor.device) - meta["zero"] = meta["zero"].to(qtensor.device) - return qtensor, meta + self.value_cache = torch.zeros( + (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), + dtype=config.dtype, + device=config.device, + ) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(self.key_cache) + torch._dynamo.mark_static_address(self.value_cache) - def _dequantize(self, qtensor): - quant_tensor, meta = qtensor - tensor = self.quantizer.dequantize(quant_tensor, meta) - return tensor + def get_max_cache_shape(self) -> int: + return self.max_cache_len + + def _static_update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_position: Optional[torch.LongTensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Utility functions for static/sliding cache update logic + """ + Updates the static cache tensors in place. + Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. + key_states (`torch.Tensor`): The new key states to add. + value_states (`torch.Tensor`): The new value states to add. + cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted. + If None, the entire cache is overwritten (prefill). -class SinkCache(Cache): - """ - Is its now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache. - See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for - general `custom_generate`usage. 
- """ + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place). + """ + if cache_position is None: + # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. + self.key_cache.copy_(key_states) + self.value_cache.copy_(value_states) + else: + # Generation phase. Update specific positions. + # Use index_copy_ for in-place update (compile-friendly). + try: + self.key_cache.index_copy_(2, cache_position, key_states) + self.value_cache.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # Fallback for devices like MPS where index_copy_ might not be supported. + self.key_cache[:, :, cache_position] = key_states + self.value_cache[:, :, cache_position] = value_states + return self.key_cache, self.value_cache + + def update(self, key_states, value_states, cache_kwargs=None) -> tuple[torch.Tensor, torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + key_states = key_states.to(self.key_cache.dtype) + value_states = value_states.to(self.value_cache.dtype) + return self._static_update(key_states, value_states, cache_position) + + def get_seq_length(self, cache_position=None) -> int: + if cache_position is not None: + return int(cache_position[-1] + 1) + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + return (self.key_cache[0, 0].any(dim=-1)).sum() - # TODO (joao, manuel): Remove this class in v4.59.0 - def __init__(self, **kwargs) -> None: - raise NotImplementedError( - "`SinkCache` has been moved as a `custom_generate` repository on the Hub: " - "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples." - ) + def reset(self): + self.key_cache.zero_() + self.value_cache.zero_() + + def reorder_cache(self, beam_idx): + dev = self.key_cache.device + beam_idx_dev = beam_idx.to(dev) + self.key_cache = self.key_cache.index_select(0, beam_idx_dev) + self.value_cache = self.value_cache.index_select(0, beam_idx_dev) class StaticCache(Cache): @@ -1081,23 +1200,9 @@ class StaticCache(Cache): Static Cache class to be used with `torch.compile(model)` and `torch.export()`. Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. If you are manually setting the batch size, make sure to take into account the - number of beams if you are running beam search - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. 
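# --- Editorial sketch, not part of this patch ---------------------------------------
# The essence of StaticLayer.update above, on bare tensors: a pre-allocated
# [batch, heads, max_cache_len, head_dim] buffer is written in place at
# `cache_position` with `index_copy_` (compile-friendly), with plain indexing as a
# fallback for backends that lack the op (e.g. MPS). The occupied length is then
# recovered the same way as `get_seq_length` above.
import torch

batch, heads, max_cache_len, head_dim = 1, 2, 8, 4
k_cache = torch.zeros(batch, heads, max_cache_len, head_dim)

# Prefill: three tokens at positions 0..2, then one decode step at position 3.
k_cache.index_copy_(2, torch.arange(3), torch.randn(batch, heads, 3, head_dim))
cache_position = torch.tensor([3])
new_keys = torch.randn(batch, heads, 1, head_dim)
try:
    k_cache.index_copy_(2, cache_position, new_keys)
except NotImplementedError:
    k_cache[:, :, cache_position] = new_keys

seq_len = (k_cache[0, 0].any(dim=-1)).sum()  # -> tensor(4): non-zero slots so far
# -------------------------------------------------------------------------------------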
+ config_or_ddp_cache_data (`Union`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility. + processors (`Optional`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + pattern_block (`Optional`, *optional*): Pattern of cache layer types to use. Defaults to `(StaticLayer,)` for backward compatibility. Example: @@ -1120,117 +1225,80 @@ class StaticCache(Cache): ``` """ - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.float32, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - super().__init__() - self.max_batch_size = max_batch_size - self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + pattern_block = (StaticLayer,) - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - self._dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) +class SlidingWindowLayer(StaticLayer): + """ + A static cache layer that implements sliding window attention caching. + Inherits from StaticLayer but uses sliding window update logic. + """ - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - # Note: There will be significant perf decrease if switching to use 5D tensors instead. - cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - device = torch.device(device) if device is not None else None - for idx in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[idx] - else: - layer_device = device - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) + def __init__(self, config: CacheConfig): + super().__init__(config, max_len=config.sliding_window) - def update( + def _static_update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, + cache_position: torch.LongTensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + Updates the sliding window cache tensors, returning the potentially modified tensors. - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. 
The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. + Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. + key_states (`torch.Tensor`): The new key states to add. + value_states (`torch.Tensor`): The new value states to add. + cache_position (`torch.LongTensor`): The position indices where the new states should be inserted. + max_cache_len (`int`): The maximum length of the sliding window cache. - Return: - A tuple containing the updated key and value states. + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update. + For prefill > window, these are the full input states. + Otherwise, they are the updated cache tensors. """ - if cache_kwargs is None: - cache_kwargs = {} - key_states = key_states.to(self.key_cache[layer_idx].dtype) - value_states = value_states.to(self.value_cache[layer_idx].dtype) - return _static_cache_update( - self.key_cache[layer_idx], - self.value_cache[layer_idx], - key_states, - value_states, - cache_kwargs.get("cache_position"), - ) + if cache_position is None: + raise ValueError("`cache_position` must be provided for SlidingWindowLayer.") - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + # Handle prefill phase when prompt length > sliding_window_size + if cache_position.shape[0] > self.max_cache_len: + new_k = key_states[:, :, -self.max_cache_len :, :] + new_v = value_states[:, :, -self.max_cache_len :, :] + self.key_cache.copy_(new_k) + self.value_cache.copy_(new_v) + return self.key_cache, self.value_cache - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len + # Sliding window logic for generation phase or prefill < window + slicing = torch.arange(self.max_cache_len, device=value_states.device) + current_seq_len = cache_position[-1] + 1 # Use last position to determine current length + to_shift = current_seq_len > self.max_cache_len + indices = (slicing + to_shift.sum()) % self.max_cache_len - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + k_out_shifted = self.key_cache[:, :, indices] + v_out_shifted = self.value_cache[:, :, indices] - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
- """ - kv_length = self.get_max_cache_shape() - return kv_length, 0 + # Clamp cache_position to determine the *target index* within the shifted cache view + update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1) + + try: + k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) + v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) + except NotImplementedError: + # Fallback for MPS: clone and modify the clone + k_out_updated = k_out_shifted.clone() + v_out_updated = v_out_shifted.clone() + k_out_updated[:, :, update_position] = key_states + v_out_updated[:, :, update_position] = value_states + + self.key_cache.copy_(k_out_updated) + self.value_cache.copy_(v_out_updated) + return self.key_cache, self.value_cache -class SlidingWindowCache(StaticCache): +class SlidingWindowCache(Cache): """ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, @@ -1246,25 +1314,6 @@ class SlidingWindowCache(StaticCache): 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) - - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - Example: ```python @@ -1285,83 +1334,7 @@ class SlidingWindowCache(StaticCache): ``` """ - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.float32, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
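# --- Editorial sketch, not part of this patch ---------------------------------------
# The index arithmetic behind the sliding-window update above: once the window is
# full, the cached slots are rotated by one so the oldest entry falls off, and the
# new token is written at the last (clamped) slot.
import torch

window = 8
cache_position = torch.tensor([10])           # absolute position of the new token
slicing = torch.arange(window)
to_shift = (cache_position[-1] + 1) > window  # True once the window is exceeded
indices = (slicing + to_shift.sum()) % window
# indices -> tensor([1, 2, 3, 4, 5, 6, 7, 0]): slot 0 (the oldest) rotates to the end
write_at = cache_position.clamp(min=0, max=window - 1)  # -> tensor([7])
# -------------------------------------------------------------------------------------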
- ) - max_cache_len = min(config.sliding_window, max_cache_len) - self.sliding_window = config.sliding_window - super().__init__( - config=config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - layer_device_map=layer_device_map, - ) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - - if cache_position is None: - raise ValueError("`cache_position` must be provided for SlidingWindowCache.") - - key_states = key_states.to(self.key_cache[layer_idx].dtype) - value_states = value_states.to(self.value_cache[layer_idx].dtype) - - return _sliding_cache_update( - self.key_cache[layer_idx], - self.value_cache[layer_idx], - key_states, - value_states, - cache_position, - self.max_cache_len, - ) - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - def reset(self): - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. - """ - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - # torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor - kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) - # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - kv_length = max(query_length, self.get_max_cache_shape()) - return kv_length, kv_offset + pattern_block = (SlidingWindowLayer,) class EncoderDecoderCache(Cache): @@ -1436,7 +1409,7 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]: @classmethod def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None + cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None ) -> "EncoderDecoderCache": """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" cache = cls( @@ -1507,22 +1480,6 @@ def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDec out.append(EncoderDecoderCache(self_attn, cross_attn)) return out - @classmethod - def from_batch_splits(cls, splits: list["EncoderDecoderCache"]) -> "EncoderDecoderCache": - """This is the opposite of the above `batch_split()` method. 
This will be used by `stack_model_outputs` in - `generation.utils`""" - self_attention_cache = DynamicCache() - cross_attention_cache = DynamicCache() - for idx in range(len(splits[0])): - layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0) - layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0) - self_attention_cache.update(layer_keys, layer_values, idx) - - layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0) - layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0) - cross_attention_cache.update(layer_keys, layer_values, idx) - return cls(self_attention_cache, cross_attention_cache) - def batch_repeat_interleave(self, repeats: int): """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" self.check_dynamic_cache(self.batch_repeat_interleave.__name__) @@ -1535,7 +1492,7 @@ def batch_select_indices(self, indices: torch.Tensor): self.self_attention_cache.batch_select_indices(indices) self.cross_attention_cache.batch_select_indices(indices) - def get_max_cache_shape(self) -> Optional[int]: + def get_max_cache_shape(self) -> int: """Returns the maximum sequence length (i.e. max capacity) of the cache object""" return self.self_attention_cache.get_max_cache_shape() @@ -1557,23 +1514,10 @@ class HybridCache(Cache): for global attention.For more information, see the documentation of each subcomponent cache class. Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - + config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info. No DDP-distributed cache data is supported. + processors (`CacheProcessorList`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)` + for backward compatibility. 
Example: ```python @@ -1594,146 +1538,24 @@ class HybridCache(Cache): ``` """ - is_compileable = True - def __init__( self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.float32, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'hybrid' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." - ) - self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings - # Sliding layers can't be larger than the overall max cache len - self.sliding_window_len = min(config.sliding_window, self.max_cache_len) - self.max_batch_size = max_batch_size - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - - self._dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - + config_or_ddp_cache_data=None, + processors: Optional[CacheProcessorList] = None, + pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, + *args, + **kwargs, + ): + model_config = config_or_ddp_cache_data or kwargs.get("config", None) + assert model_config is not None, "HybridCache requires a model config" # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] - else: - self.is_sliding = [False] * config.num_hidden_layers - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim) - self.sliding_window = min(config.sliding_window, max_cache_len) - device = torch.device(device) if device is not None else None - for i in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[i] - else: - layer_device = device - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. 
- cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - if cache_position is None: - raise ValueError("`cache_position` must be provided for HybridCache.") - - is_sliding_layer = self.is_sliding[layer_idx] - - # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used - # when the cache is initialized in the forward pass (e.g. Gemma2) - if self.key_cache[layer_idx].device != key_states.device: - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) - if self.value_cache[layer_idx].device != value_states.device: - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) - - k_cache = self.key_cache[layer_idx] - v_cache = self.value_cache[layer_idx] - key_states = key_states.to(k_cache.dtype) - value_states = value_states.to(v_cache.dtype) - - if is_sliding_layer: - return _sliding_cache_update( - k_cache, - v_cache, - key_states, - value_states, - cache_position, - k_cache.shape[2], # Use actual cache dim as max cache len - ) - else: - return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position) - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - def get_seq_length(self, layer_idx: Optional[int] = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." - ) - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
- """ - if self.is_sliding[layer_idx]: - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - - local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) - # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - local_mask_kv_length = max(query_length, self.sliding_window) - return local_mask_kv_length, local_mask_kv_offset + if hasattr(model_config, "layer_types"): + self.is_sliding = [layer_type != "full_attention" for layer_type in model_config.layer_types] + else: + self.is_sliding = [False] * model_config.num_hidden_layers - full_mask_kv_offset = 0 - full_mask_kv_length = self.get_max_cache_shape() - return full_mask_kv_length, full_mask_kv_offset + pattern_block = tuple(SlidingWindowLayer if sl else StaticLayer for sl in self.is_sliding) + super().__init__(config_or_ddp_cache_data, processors, pattern_block, *args, **kwargs) class HybridChunkedCache(Cache): @@ -1783,6 +1605,9 @@ class HybridChunkedCache(Cache): """ is_compileable = True + # Override @property since HybridChunked does its own thing + key_cache = None + value_cache = None def __init__( self, @@ -1857,8 +1682,13 @@ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2) else: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + try: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + except NotImplementedError: + # MPS does not support index_copy_ + self.key_cache[layer_idx][:, :, cache_position] = key_states + self.value_cache[layer_idx][:, :, cache_position] = value_states return self.key_cache[layer_idx], self.value_cache[layer_idx] self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) @@ -1903,7 +1733,7 @@ def update( k_out.shape[2], ) - def get_max_cache_shape(self) -> Optional[int]: + def get_max_cache_shape(self) -> int: return self.max_cache_len def get_seq_length(self, layer_idx: Optional[int] = 0): @@ -1927,6 +1757,16 @@ def reset(self): self.value_cache[layer_idx].zero_() self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx].numel(): + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for @@ -2073,133 +1913,24 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None: self.device_value_cache[self.active_device_layer].fill_(0.0) -class MambaCache: - """ - Cache for mamba model which does not have attention mechanism and key value states. 
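# --- Editorial sketch, not part of this patch ---------------------------------------
# How the HybridCache __init__ above derives its per-layer pattern: every layer whose
# type is not "full_attention" gets a sliding-window slot. The `layer_types` list
# below is a made-up example, not taken from any real model config, and the strings
# stand in for the SlidingWindowLayer / StaticLayer classes.
layer_types = ["sliding_attention", "full_attention", "sliding_attention", "full_attention"]
is_sliding = [layer_type != "full_attention" for layer_type in layer_types]
pattern_block = tuple("SlidingWindowLayer" if sliding else "StaticLayer" for sliding in is_sliding)
# pattern_block -> ('SlidingWindowLayer', 'StaticLayer', 'SlidingWindowLayer', 'StaticLayer')
# -------------------------------------------------------------------------------------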
- - Arguments: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - - Example: - - ```python - >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache - - >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") - - >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values - MambaCache() - ``` - """ - - is_compileable = True - - # TODO (joao): add layer_device_map arg and update code in `generate` accordingly - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: Union[torch.device, str, None] = None, - ): - self.max_batch_size = max_batch_size - self._dtype = dtype - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - - self.conv_states: list[torch.Tensor] = [] - self.ssm_states: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - conv_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=self._dtype, - ) - ssm_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=self._dtype, - ) - - torch._dynamo.mark_static_address(conv_state) - torch._dynamo.mark_static_address(ssm_state) - self.conv_states.append(conv_state) - self.ssm_states.append(ssm_state) - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor - ) -> torch.Tensor: - # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used - # when the cache is initialized in the forward pass (e.g. 
Mamba) - if self.conv_states[layer_idx].device != new_conv_state.device: - self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) - return self.ssm_states[layer_idx] - - def reset(self): - for layer_idx in range(len(self.conv_states)): - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() - - class OffloadedStaticCache(StaticCache): """ - Static cache class to be used with `torch.compile(model)` that offloads to the CPU or - another device. + A drop-in replacement for StaticCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. - Args: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize - the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - device (`Union[str, torch.device]`): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*): - The default `dtype` to use when initializing the cache. - offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): - The device to offload to. Defaults to CPU. - layer_device_map (`dict[int, Union[str, torch.device, int]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. + This cache maintains the compilation-friendly properties of StaticCache while enabling + much longer sequences by offloading inactive layers to CPU memory. - Example: + Parameters: + config (`PretrainedConfig`): Model configuration for shape/device info. + max_batch_size (`int`): Maximum batch size for static caches. + max_cache_len (`int`, *optional*): Maximum sequence length. + device (`torch.device` or `str`, *optional*): Device for cache tensors. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type for cache tensors. + offload_device (`Union`, *optional*, defaults to `"cpu"`): Device to offload cache tensors to. + layer_device_map (`dict[int, Union[str, torch.device, int]]`, *optional*): Per-layer device mapping. 
+ Example: ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache @@ -2208,240 +1939,493 @@ class OffloadedStaticCache(StaticCache): >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> # Prepare a cache class with offloading >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = OffloadedStaticCache( + ... config=model.config, + ... max_batch_size=1, + ... max_cache_len=max_generated_length, + ... device=model.device, + ... dtype=model.dtype + ... ) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache with offloaded layers + OffloadedStaticCache() ``` """ - is_compileable = True - def __init__( self, config: PretrainedConfig, max_batch_size: int, - max_cache_len: Optional[int], - device: Union[str, torch.device], + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, dtype: Optional[torch.dtype] = None, - offload_device: Union[str, torch.device] = torch.device("cpu"), + offload_device: Union[str, torch.device] = "cpu", layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, ) -> None: - super(Cache, self).__init__() + # Create offload processor + processors = CacheProcessorList([OffloadedCacheProcessor(offload_device)]) - # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps - # track of the original device of each layer - unique_devices = set(layer_device_map.values()) if layer_device_map else set() - if len(unique_devices) > 1: - raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}") + # Initialize the base StaticCache with the processor + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + processors=processors, + ) - self.max_batch_size = max_batch_size - self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len - self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) + +class OffloadedCacheProcessor(CacheProcessor): + """ + A cache processor that offloads cache tensors to conserve accelerator memory. + + This processor manages moving cache tensors between accelerator and CPU memory, + using asynchronous prefetching to minimize performance impact. Works with both + dynamic and static layers. 
+ """ + + def __init__(self, offload_device: Union[str, torch.device] = "cpu"): self.offload_device = torch.device(offload_device) - self._dtype = dtype if dtype is not None else torch.float32 + self.original_device = [] + self.prefetch_stream = None + self.beam_idx = None + + def init(self, cache: "Cache", **kwargs) -> None: + """Initialize the offload processor and check device compatibility.""" + if not ( + torch.cuda.is_available() + or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) + ): + raise RuntimeError( + "OffloadedCacheProcessor can only be used with a GPU" + + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") + ) - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) + if self.is_static: + for i, layer in enumerate(cache.layers): + device = cache.config.device if i == 0 else self.offload_device + layer.key_cache = layer.key_cache.to(device) + layer.value_cache = layer.value_cache.to(device) + self.original_device.append(cache.config.device) + if len(cache) != cache.config.num_layers: + raise ValueError("If static layers are used, all cache layers must be initialized") - num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads + self.prefetch_stream = ( + torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() ) - cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Handle prefetching and eviction before cache update.""" + # Update the cache + if len(cache) < layer_idx: + raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") + elif len(cache) == layer_idx: + self.original_device.append(key_states.device) + self._evict_previous_layer(cache, layer_idx) + else: + # Wait for the previous layer to be evicted (on default stream) + if is_torch_greater_or_equal("2.7", accept_dev=True): + torch.accelerator.current_stream().synchronize() + else: + torch.cuda.current_stream().synchronize() + self._evict_previous_layer(cache, layer_idx) + self._ensure_layer_on_device(cache, layer_idx) + + # Prefetch the next layer + self._prefetch_layer(cache, (layer_idx + 1) % len(cache)) + return key_states, value_states - # Create offloaded CPU tensors. - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] + def _prefetch_layer(self, cache: "Cache", layer_idx: int): + """Starts prefetching the next layer cache.""" + if layer_idx < len(cache): + with ( + self.prefetch_stream + if is_torch_greater_or_equal("2.7", accept_dev=True) + else torch.cuda.stream(self.prefetch_stream) + ): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + cache.key_cache[layer_idx] = cache.key_cache[layer_idx].to(device, non_blocking=True) + cache.value_cache[layer_idx] = cache.value_cache[layer_idx].to(device, non_blocking=True) - for i in range(config.num_hidden_layers): - # First layer is always on-device. 
- device = self.device if i == 0 else self.offload_device + def _evict_previous_layer(self, cache: "Cache", layer_idx: int): + """Moves the previous layer cache to the CPU.""" + if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + prev_layer_idx = (layer_idx - 1) % len(cache) + cache.key_cache[prev_layer_idx] = cache.key_cache[prev_layer_idx].to( + self.offload_device, non_blocking=True + ) + cache.value_cache[prev_layer_idx] = cache.value_cache[prev_layer_idx].to( + self.offload_device, non_blocking=True + ) - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device) + def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): + """Ensures the current layer is on the original device.""" + if layer_idx < len(cache): + # Wait for the previous prefetch to be done + self.prefetch_stream.synchronize() - self.key_cache.append(key_cache) - self.value_cache.append(value_cache) + # Handle delayed beam search operations + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) + cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, self.beam_idx) + cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, self.beam_idx) - # Create device tensors. - self._device_key_cache: list[torch.Tensor] = [] - self._device_value_cache: list[torch.Tensor] = [] - for i in range(2): - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device) +class QuantizedCacheProcessor(CacheProcessor): + """ + A cache processor that applies quantization to cache tensors to reduce memory usage. - self._device_key_cache.append(key_cache) - self._device_value_cache.append(value_cache) + This processor quantizes cache tensors after they are stored, maintaining a residual + length in original precision and quantizing older tokens. + """ - # For backwards compatibility. - # TODO(gante): Remove this. + def __init__(self, cache_config: QuantizedCacheConfig): + self.config = cache_config + self._quantized_key_cache: list[torch.Tensor] = [] + self._quantized_value_cache: list[torch.Tensor] = [] self._seen_tokens = 0 - # Create new CUDA stream for parallel prefetching. - self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None + def init(self, cache: "Cache", **kwargs) -> None: + """Initialize the quantized processor and validate configuration.""" + self.config.validate() - def update( + # Only compatible with DynamicCache + if not isinstance(cache, DynamicCache): + raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache") + + def post_update( self, - key_states: torch.Tensor, - value_states: torch.Tensor, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + """Apply quantization after cache update.""" + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_tensors.shape[-2] - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. 
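# --- Editorial sketch, not part of this patch ---------------------------------------
# The evict/prefetch pattern used by OffloadedCacheProcessor above, reduced to two
# bare tensors. Assumes a CUDA device (the real processor also supports XPU via
# torch.accelerator on torch >= 2.7): the previous layer is moved to CPU on the
# default stream while the next layer is copied back on a side stream, and the
# consumer synchronizes on that stream before reading.
import torch

if torch.cuda.is_available():
    prefetch_stream = torch.cuda.Stream()
    prev_layer = torch.randn(1, 8, 128, 64, device="cuda")
    next_layer = torch.randn(1, 8, 128, 64, device="cuda").to("cpu", non_blocking=True)

    # Evict the previous layer once its attention computation is done.
    prev_layer = prev_layer.to("cpu", non_blocking=True)

    # Prefetch the next layer on the side stream so the copy overlaps with compute.
    with torch.cuda.stream(prefetch_stream):
        next_layer = next_layer.to("cuda", non_blocking=True)

    prefetch_stream.synchronize()  # wait before the next layer's tensors are used
# -------------------------------------------------------------------------------------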
- layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, *optional*): - Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the - `cache_position` input to know how where to write in the cache. + # Extend quantized cache if needed + while len(self._quantized_key_cache) <= layer_idx: + self._quantized_key_cache.append(torch.empty(0)) + self._quantized_value_cache.append(torch.empty(0)) - Return: - A tuple containing the updated key and value states. - """ + # Check if we need to quantize + if layer_idx < len(cache.key_cache): + current_key = cache.key_cache[layer_idx] + current_value = cache.value_cache[layer_idx] - key_states = key_states.to(self.key_cache[layer_idx].dtype) - value_states = value_states.to(self.value_cache[layer_idx].dtype) + if ( + current_key.dim() == 4 + and current_key.shape[-2] >= self.config.residual_length + and current_key.shape[-2] > self._get_quantized_length(layer_idx) + ): + # Quantize the older part, keep recent tokens in original precision + split_idx = current_key.shape[-2] - self.config.residual_length + + # Get the part to quantize + key_to_quantize = current_key[:, :, :split_idx, :].contiguous() + value_to_quantize = current_value[:, :, :split_idx, :].contiguous() + + # Quantize and store + self._quantized_key_cache[layer_idx] = self._quantize(key_to_quantize, axis=self.config.axis_key) + self._quantized_value_cache[layer_idx] = self._quantize(value_to_quantize, axis=self.config.axis_value) + + # Keep only the recent tokens in original precision + cache.key_cache[layer_idx] = current_key[:, :, split_idx:, :] + cache.value_cache[layer_idx] = current_value[:, :, split_idx:, :] + + # Return the full tensors for this update + if self._quantized_key_cache[layer_idx].numel() > 0: + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + full_key = torch.cat([dequant_key, cache.key_cache[layer_idx]], dim=-2) + full_value = torch.cat([dequant_value, cache.value_cache[layer_idx]], dim=-2) + return full_key, full_value + + return key_tensors, value_tensors + + def _get_quantized_length(self, layer_idx: int) -> int: + """Get the length of quantized cache for a layer.""" + if layer_idx < len(self._quantized_key_cache) and self._quantized_key_cache[layer_idx].numel() > 0: + # This would depend on the specific quantization implementation + return ( + self._quantized_key_cache[layer_idx].shape[-2] + if hasattr(self._quantized_key_cache[layer_idx], "shape") + else 0 + ) + return 0 - if layer_idx == 0: - # Update seen tokens. - # TODO(gante): Remove this. - self._seen_tokens += key_states.shape[-2] + def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: + """Quantize a tensor - to be implemented by specific quantization backends.""" + raise NotImplementedError("Quantization backend must implement _quantize method") - # Always there. - k_out = self.key_cache[0] - v_out = self.value_cache[0] - else: - # Wait for prefetch stream. 
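# --- Editorial sketch, not part of this patch ---------------------------------------
# The bookkeeping behind QuantizedCacheProcessor.post_update above: once the stored
# keys grow past `residual_length`, the older part is handed to the backend quantizer
# and only the most recent tokens stay in original precision; attention still sees
# the full concatenated sequence. `fake_quantize` / `fake_dequantize` are identity
# stand-ins, since the real backends (quanto / HQQ) are optional dependencies.
import torch

residual_length = 4
keys = torch.randn(1, 2, 10, 8)  # [batch, heads, seq_len, head_dim]

def fake_quantize(t):  # identity stand-in for the real backend
    return t

def fake_dequantize(q):
    return q

split_idx = keys.shape[-2] - residual_length          # 6 older tokens get quantized
quantized_part = fake_quantize(keys[:, :, :split_idx, :].contiguous())
residual_part = keys[:, :, split_idx:, :]             # kept in original precision

full_view = torch.cat([fake_dequantize(quantized_part), residual_part], dim=-2)
assert full_view.shape[-2] == keys.shape[-2]          # the layer sees all 10 tokens
# -------------------------------------------------------------------------------------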
- if self._prefetch_stream is not None: - torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream) + def _dequantize(self, tensor: torch.Tensor) -> torch.Tensor: + """Dequantize a tensor - to be implemented by specific quantization backends.""" + raise NotImplementedError("Quantization backend must implement _dequantize method") - k_out = self._device_key_cache[layer_idx & 1] - v_out = self._device_value_cache[layer_idx & 1] - self._prefetch_layer(layer_idx + 1) +class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor): + """ + Quantized cache processor that uses `quanto` as a backend to perform quantization. + Current implementation supports `int2` and `int4` dtypes only. + """ - cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) + def init(self, cache: "Cache", **kwargs) -> None: + """Initialize the quanto quantization processor.""" + super().init(cache, **kwargs) - # Copy the values to the offloaded device as well. - if layer_idx == 0: - self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) - self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does - # explicitly an in-place operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # Copy the values to the offloaded device as well. - if layer_idx != 0: - cache_position = cache_position.to(self.offload_device) - key_states = key_states.to(self.offload_device) - value_states = value_states.to(self.offload_device) - - try: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. - self.key_cache[layer_idx][:, :, cache_position] = key_states - self.value_cache[layer_idx][:, :, cache_position] = value_states + if is_optimum_quanto_available(): + optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) + if optimum_quanto_version <= version.parse("0.2.5"): + raise ImportError( + f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCacheProcessor`. Detected version {optimum_quanto_version}." + ) + from optimum.quanto import MaxOptimizer, qint2, qint4 - return k_out, v_out + if self.config.nbits not in [2, 4]: + raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.config.nbits}") - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" + if self.config.axis_key not in [0, -1]: + raise ValueError( + f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.config.axis_key}" + ) - # TODO(gante): Remove this. 
- return self._seen_tokens + if self.config.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.config.axis_value}" + ) - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states.""" + self.qtype = qint4 if self.config.nbits == 4 else qint2 + self.optimizer = MaxOptimizer() - return self.max_cache_len + def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: + """Quantize tensor using quanto backend.""" + if is_optimum_quanto_available(): + from optimum.quanto import quantize_weight - def reset(self) -> None: - """Resets the cache values while preserving the objects.""" + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.config.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.config.q_group_size) + return qtensor - # For backwards compatibility. - # TODO(gante): Remove this. - self._seen_tokens = 0 + def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: + """Dequantize tensor using quanto backend.""" + return qtensor.dequantize() - # Zero out cache. - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address. - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - @property - def seen_tokens(self) -> int: - # For backwards compatibility. - # TODO(gante): Remove this. - return self._seen_tokens +class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): + """ + Quantized cache processor that uses `HQQ` as a backend to perform quantization. + Current implementation supports `int2`, `int4`, `int8` dtypes. + """ - def _create_key_value_cache_tensors( - self, shape: tuple[int, ...], device: torch.device - ) -> tuple[torch.Tensor, torch.Tensor]: - """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static - addresses for non-CPU tensors. + def init(self, cache: "Cache", **kwargs) -> None: + """Initialize the HQQ quantization processor.""" + super().init(cache, **kwargs) - Args: - shape (`tuple[int, ...]`): Shape. - device (`torch.device`): Device. + if self.config.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.config.nbits}" + ) - Returns: - Key and value cache tensors as a tuple. - """ + if self.config.axis_key not in [0, 1]: + raise ValueError( + f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.config.axis_key}" + ) - is_cpu_device = device == torch.device("cpu") + if self.config.axis_value not in [0, 1]: + raise ValueError( + f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.config.axis_value}" + ) - key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) - value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) + self.quantizer = HQQQuantizer - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. 
- torch._dynamo.mark_static_address(key_cache) - torch._dynamo.mark_static_address(value_cache) + def _quantize(self, tensor: torch.Tensor, axis: int) -> tuple[torch.Tensor, dict]: + """Quantize tensor using HQQ backend.""" + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.config.device, + compute_dtype=self.config.compute_dtype, + nbits=self.config.nbits, + group_size=self.config.q_group_size, + ) + meta["compute_dtype"] = self.config.compute_dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.config.device) # Move to device and cast to dtype + meta["scale"] = meta["scale"].to(qtensor.device) + meta["zero"] = meta["zero"].to(qtensor.device) + return qtensor, meta - return key_cache, value_cache + def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tensor: + """Dequantize tensor using HQQ backend.""" + quant_tensor, meta = qtensor_and_meta + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor - def _prefetch_layer(self, layer_idx: int) -> None: - """Prefetch a layer to the device. Needs to be called in order of layer indices.""" - # Don't fetch layers that do not exist. - if layer_idx >= len(self.key_cache): - return +class QuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - # Alternate between two on-device caches. - if self._prefetch_stream is not None: - with torch.cuda.stream(self._prefetch_stream): - self._prefetch_layer_in_context(layer_idx) - else: - self._prefetch_layer_in_context(layer_idx) + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - def _prefetch_layer_in_context(self, layer_idx: int) -> None: - """Performs the actual copy of the layer to device cache.""" + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) + super().__init__(processors=processors) + + +class SinkCache(Cache): + """ + Is its now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache. + See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for + general `custom_generate`usage. + """ + + # TODO (joao, manuel): Remove this class in v4.59.0 + def __init__(self, **kwargs) -> None: + raise NotImplementedError( + "`SinkCache` has been moved as a `custom_generate` repository on the Hub: " + "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples." 
+ ) - self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) - self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True) + +# TODO (manuel, joao): remove this class, it is here only for backwards compatibility +# PEP 562: Lazy loading for deprecated location of MambaCache +def __getattr__(name: str) -> Any: + if name == "MambaCache": + warnings.warn( + ( + "Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed " + "in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead." + ), + FutureWarning, + stacklevel=2, + ) + + class MambaCache: + """ + Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed + in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead. + + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if 
`layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. Mamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() + + return MambaCache + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 76587659371b..baae6690a94d 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -54,7 +54,6 @@ HQQQuantizedCache, HybridCache, HybridChunkedCache, - MambaCache, OffloadedHybridCache, OffloadedStaticCache, QuantizedCacheConfig, @@ -67,6 +66,8 @@ CACHE_CONFIG_MAPPING["quantized"] = QuantizedCacheConfig CACHE_CONFIG_MAPPING["static"] = StaticCacheConfig + CACHE_CONFIG_MAPPING["sliding_window"] = StaticCacheConfig + CACHE_CONFIG_MAPPING["hybrid"] = StaticCacheConfig NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache, "offloaded_static": OffloadedStaticCache, @@ -75,12 +76,9 @@ "hybrid_chunked": HybridChunkedCache, "offloaded_hybrid": OffloadedHybridCache, "offloaded_hybrid_chunked": OffloadedHybridCache, - "mamba": MambaCache, } QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} - ALL_CACHE_IMPLEMENTATIONS = ( - list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded", "dynamic"] - ) + ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + ["offloaded", "dynamic", "quantized"] class GenerationMode(ExplicitEnum): @@ -186,7 +184,6 @@ class GenerationConfig(PushToHubMixin): - `"offloaded_static"`: [`OffloadedStaticCache`] - `"sliding_window"`: [`SlidingWindowCache`] - `"hybrid"`: [`HybridCache`] - - `"mamba"`: [`MambaCache`] - `"quantized"`: [`QuantizedCache`] If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). 
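For downstream code, the deprecation shim above means the old import path keeps working but now emits a `FutureWarning`. A minimal sketch of the expected behaviour, assuming only standard PEP 562 module `__getattr__` semantics and the warning text shown in the hunk above:

```python
import warnings

# Importing from the deprecated location resolves through the module-level
# __getattr__ defined above and should emit a FutureWarning pointing at the
# supported import locations named in the warning message.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from transformers.cache_utils import MambaCache  # deprecated path

assert any(issubclass(w.category, FutureWarning) for w in caught)
```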
See diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bb15454c7f5b..74cdbbc9d982 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1929,9 +1929,8 @@ def _get_cache( or isinstance( cache_to_check, (HybridChunkedCache, OffloadedHybridCache) ) # due to internal slicing, we always re-init + or cache_to_check.max_cache_len < max_cache_len ) - if cache_implementation != "mamba": - need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len if requires_cross_attention_cache and hasattr(self, "_cache"): need_new_cache = ( @@ -1966,13 +1965,14 @@ def _get_cache( def _supports_default_dynamic_cache(self) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. - This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which - uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in + This is mostly the same as `_supports_cache_class` attribute, but add exception for `Mamba` models which + use their own caches and do not need to initialize the Cache in advance in order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). + for mamba-based models). """ return ( self._supports_cache_class + and "mamba" not in self.__class__.__name__.lower() and "jamba" not in self.__class__.__name__.lower() and "zamba" not in self.__class__.__name__.lower() and "bamba" not in self.__class__.__name__.lower() @@ -2023,7 +2023,7 @@ def _prepare_cache_for_generation( if generation_config.use_cache is False: return - # Quick escape route 3: model that only supports legacy caches = nothing to prepare + # Quick escape route 3: model that only supports legacy caches or models that supply it in `prepare_inputs_for_generation` (mamba, zamba, ...) if not self._supports_default_dynamic_cache(): if generation_config.cache_implementation is not None: warnings.warn( @@ -5199,106 +5199,6 @@ def _ranking_fast( return selected_idx -def _split(data, full_batch_size: int, split_size: int): - """ - Takes care of three cases: - 1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim - 2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and - return a list of tuples - 3. data is a tuple of tuples, e.g. past_key_values. 
Keep the tuple as it is and split each tuple in it and - return a list of tuples of tuples - (see documentation of ModelOutput) - """ - if data is None: - return [None] * (full_batch_size // split_size) - if isinstance(data, torch.Tensor): - return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] - # New cache format - elif isinstance(data, DynamicCache) or ( - isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) - ): - return data.batch_split(full_batch_size, split_size) - elif isinstance(data, tuple): - # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) - if isinstance(data[0], tuple): - return [ - tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) - for i in range(0, full_batch_size, split_size) - ] - - else: - return [ - tuple(sub_tensor[i : i + split_size] for sub_tensor in data) - for i in range(0, full_batch_size, split_size) - ] - else: - raise TypeError(f"Unexpected attribute type: {type(data)}") - - -def _split_model_inputs( - model_input: Union[ModelOutput, dict], split_size: int, full_batch_size: int, config: PretrainedConfig -) -> list[Union[ModelOutput, dict]]: - """ - Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split - size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from - previous forward pass. - """ - # Edge case: if model_input is None, return a list of Nones - # this happens with Whisper where encoder_outputs is None - if model_input is None: - return [model_input] * (full_batch_size // split_size) - # Infer the class from the object - model_output_cls = type(model_input) - if (full_batch_size % split_size) != 0: - raise ValueError("`full_batch_size` must be divisible by `split_size`") - - if split_size > full_batch_size: - raise ValueError("`split_size` must be smaller or equal to `full_batch_size`") - - # Helper function to split tensors or tuples of tensors - - # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them - keys = ( - model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys() - ) - # We only keep keys that are in the model_input - keys = [k for k in keys if k in model_input] - # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a - # ModelOutput object. 
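For reference while reviewing the deletion, the plain-tensor case of the removed `_split` helper reduces to batch-dimension slicing. A self-contained sketch of that behaviour; the helper name `split_tensor_batch` is illustrative only and not part of the codebase:

```python
import torch

def split_tensor_batch(data: torch.Tensor, full_batch_size: int, split_size: int) -> list[torch.Tensor]:
    # Same slicing the removed `_split` applied to plain tensors:
    # chunks of `split_size` along the batch dimension.
    return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]

hidden = torch.randn(4, 7, 16)  # (batch, seq_len, hidden_size)
chunks = split_tensor_batch(hidden, full_batch_size=4, split_size=2)
assert len(chunks) == 2 and chunks[0].shape == (2, 7, 16)
```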
- # bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] - - # we split the tensors and tuples of tensors - data_split_list = [ - {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys} - for i in range(full_batch_size // split_size) - ] - # bool values are the same and replicated for each split - bool_data = {k: model_input[k] for k in bool_keys} - # encoder_outputs is a ModelOutput object and should be split by its own - if "encoder_outputs" in model_input: - encoder_outputs_split = _split_model_inputs( - model_input["encoder_outputs"], split_size, full_batch_size, config.get_text_config() - ) - data_split_list = [ - {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) - ] - # logits_to_keep should be replicated for each split, similar to bool values - if "logits_to_keep" in model_input: - data_split_list = [ - {**data_split, "logits_to_keep": model_input["logits_to_keep"]} for data_split in data_split_list - ] - - # Convert each dictionary in the list to an object of the inferred class - split_model_inputs: list[Union[ModelOutput, dict]] = [ - model_output_cls(**data_split, **bool_data) for data_split in data_split_list - ] - - return split_model_inputs - - def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput: """ Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the @@ -5323,11 +5223,6 @@ def _concat(data): return None if isinstance(data[0], torch.Tensor): return torch.cat(data, dim=0) - # New cache format - elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data) - elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index eb75d8f2b80a..fe066898d5ce 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -65,7 +65,7 @@ logger = logging.get_logger(__name__) -class FalconHybridMambaAttentionDynamicCache(DynamicCache): +class FalconHybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -79,6 +79,10 @@ class FalconHybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" + key_cache = None + value_cache = None + is_compileable = False + def __init__( self, config: FalconH1Config, diff --git a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py index 4099920f4028..f7e7719b37f9 100644 --- a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py @@ -1,5 +1,11 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_falcon_mamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,15 +18,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""FALCONMAMBA configuration""" import math from ...configuration_utils import PretrainedConfig -from ...utils import logging - - -logger = logging.get_logger(__name__) class FalconMambaConfig(PretrainedConfig): @@ -28,7 +29,7 @@ class FalconMambaConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the FALCON_MAMBA - [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture. + [state-spaces/falcon_mamba-2.8b](https://huggingface.co/state-spaces/falcon_mamba-2.8b) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -79,10 +80,12 @@ class FalconMambaConfig(PretrainedConfig): Whether or not to rescale `out_proj` weights when initializing. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. - use_mambapy (`bool`, *optional*, defaults to `False`): + use_falcon_mambapy (`bool`, *optional*, defaults to `False`): Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited. mixer_rms_eps (`float`, *optional*, defaults to 1e-06): The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states. 
+ + Example: ```python @@ -125,10 +128,11 @@ def __init__( time_step_floor=1e-4, rescale_prenorm_residual=False, use_cache=True, - use_mambapy=False, + use_falcon_mambapy=False, mixer_rms_eps=1e-6, **kwargs, ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.state_size = state_size @@ -153,10 +157,8 @@ def __init__( self.rescale_prenorm_residual = rescale_prenorm_residual self.residual_in_fp32 = residual_in_fp32 self.use_cache = use_cache - self.use_mambapy = use_mambapy + self.use_falcon_mambapy = use_falcon_mambapy self.mixer_rms_eps = mixer_rms_eps - super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) - __all__ = ["FalconMambaConfig"] diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 426e557d9d3c..07fb3c1d849c 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_falcon_mamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. # @@ -12,19 +18,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch FALCONMAMBA model.""" import math from dataclasses import dataclass from typing import Any, Optional, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import MambaCache +from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel @@ -35,27 +39,127 @@ logger = logging.get_logger(__name__) -if is_mambapy_available(): - from mambapy.pscan import pscan -else: - pscan = None -if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn - from mamba_ssm.ops.triton.selective_state_update import selective_state_update +class FalconMambaCache: + """ + Cache for falcon_mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache + + >>> model = FalconMambaForCausalLM.from_pretrained("state-spaces/falcon_mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/falcon_mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + FalconMambaCache() + ``` + """ - from ...kernels.falcon_mamba import mamba_inner_fn -else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + is_compileable = True -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. 
FalconMamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() + + +def is_fast_path_available(): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + else: + causal_conv1d_update, causal_conv1d_fn = None, None + + return ( + all((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)), + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, + ) def rms_forward(hidden_states, variance_epsilon=1e-6): @@ -107,7 +211,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): self.activation = config.hidden_act self.act = ACT2FN[config.hidden_act] - self.use_mambapy = config.use_mambapy + self.use_falcon_mambapy = config.use_falcon_mambapy # projection of the input hidden states self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias) @@ -126,6 +230,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + self.warn_slow_implementation() # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here self.register_buffer( "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False @@ -135,6 +240,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): ) self.rms_eps = config.mixer_rms_eps + def warn_slow_implementation(self): if not is_fast_path_available: if self.use_mambapy: if is_mambapy_available(): @@ -157,10 +263,21 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, 
causal_conv1d_update + else: + causal_conv1d_update, causal_conv1d_fn = None, None + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -269,13 +386,17 @@ def cuda_kernels_forward( contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states - def slow_forward( - self, + # fmt: off + def slow_forward(self, input_states, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): + if is_mambapy_available(): + from mambapy.pscan import pscan + else: + pscan = None batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated MLP's linear projection @@ -344,7 +465,7 @@ def slow_forward( deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) - if self.use_mambapy and self.training and cache_params is None: + if self.use_falcon_mambapy and self.training and cache_params is None: hs = pscan( discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2) ) # [batch, seq_len, intermediate_size, ssm_state_size] @@ -371,12 +492,12 @@ def slow_forward( # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] return contextualized_states + # fmt: on - # Copied from transformers.models.mamba.modeling_mamba.MambaMixer.forward def forward( self, hidden_states, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): @@ -385,7 +506,6 @@ def forward( return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) -# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba class FalconMambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -395,17 +515,15 @@ def __init__(self, hidden_size, eps=1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def extra_repr(self): - return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" - - # Ignore copy def forward(self, hidden_states): return self.weight.to(hidden_states.device) * rms_forward( hidden_states, variance_epsilon=self.variance_epsilon ) + def extra_repr(self): + return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" + -# Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache class FalconMambaBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() @@ -418,7 +536,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): @@ -435,7 +553,6 @@ def forward( @auto_docstring -# Copied from transformers.models.mamba.modeling_mamba.MambaPreTrainedModel with Mamba->FalconMamba class FalconMambaPreTrainedModel(PreTrainedModel): config_class = FalconMambaConfig base_model_prefix = "backbone" @@ -494,13 +611,12 @@ def _init_weights(self, module): @dataclass @auto_docstring( custom_intro=""" - Class for the FALCONMAMBA model outputs. + Class for the FALCON_MAMBA model outputs. 
""" ) -# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->FALCONMAMBA,Mamba->FalconMamba,FalconMambaCache->MambaCache class FalconMambaOutput(ModelOutput): r""" - cache_params (`MambaCache`): + cache_params (`FalconMambaCache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -508,7 +624,7 @@ class FalconMambaOutput(ModelOutput): """ last_hidden_state: Optional[torch.FloatTensor] = None - cache_params: Optional[MambaCache] = None + cache_params: Optional[FalconMambaCache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None @@ -518,14 +634,13 @@ class FalconMambaOutput(ModelOutput): Base class for causal language model (or autoregressive) outputs. """ ) -# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->FalconMamba,FalconMambaCache->MambaCache class FalconMambaCausalLMOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cache_params (`MambaCache`): + cache_params (`FalconMambaCache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -534,7 +649,7 @@ class FalconMambaCausalLMOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None - cache_params: Optional[MambaCache] = None + cache_params: Optional[FalconMambaCache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None @@ -551,8 +666,15 @@ def __init__(self, config): self.gradient_checkpointing = False self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) self.post_init() + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + def get_input_embeddings(self): return self.embeddings @@ -564,7 +686,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -572,7 +694,7 @@ def forward( attention_mask: Optional[torch.LongTensor] = None, ) -> Union[tuple, FalconMambaOutput]: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`FalconMambaCache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
use_cache (`bool`, *optional*): @@ -595,7 +717,7 @@ def forward( if use_cache: if cache_params is None: - cache_params = MambaCache( + cache_params = FalconMambaCache( self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype ) cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) @@ -610,6 +732,7 @@ def forward( ) else: cache_params = None + hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: @@ -640,11 +763,10 @@ def forward( @auto_docstring( custom_intro=""" - The FALCONMAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input + The FALCON_MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). """ ) -# Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -691,14 +813,21 @@ def prepare_inputs_for_generation( input_ids, inputs_embeds=None, use_cache=None, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, **kwargs, ): # Overwritten -- uses `cache_params` as opposed to `past_key_values` + cache_params_not_initialized = cache_params is None if use_cache: + if cache_params_not_initialized: + max_batch_size = inputs_embeds.size(0) if inputs_embeds is not None else input_ids.size(0) + cache_params = FalconMambaCache( + self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype + ) + cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) # `cache_position` should have been initialized in `generate` if cache_position is None: raise ValueError( @@ -719,7 +848,7 @@ def prepare_inputs_for_generation( # the length of `cache_params.conv_states`, which is `config.conv_kernel` cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) - if inputs_embeds is not None and cache_params is None: + if inputs_embeds is not None and cache_params_not_initialized: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} @@ -740,7 +869,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, labels: Optional[torch.LongTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -749,7 +878,7 @@ def forward( **kwargs, # for now we need this for generation ) -> Union[tuple, FalconMambaCausalLMOutput]: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`FalconMambaCache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -798,4 +927,4 @@ def forward( ) -__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel"] +__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"] diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py new file mode 100644 index 000000000000..7001cf982d0e --- /dev/null +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -0,0 +1,537 @@ +# coding=utf-8 +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch FALCONMAMBA model.""" + +from typing import Optional + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...utils import auto_docstring, logging +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available +from ..mamba.configuration_mamba import MambaConfig +from ..mamba.modeling_mamba import ( + MambaBlock, + MambaCache, + MambaCausalLMOutput, + MambaForCausalLM, + MambaMixer, + MambaModel, + MambaOutput, + MambaPreTrainedModel, + MambaRMSNorm, +) + + +logger = logging.get_logger(__name__) + + +def is_fast_path_available(): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + else: + causal_conv1d_update, causal_conv1d_fn = None, None + + return ( + all((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)), + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, + ) + + +( + is_fast_path_available, + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, +) = is_fast_path_available() + + +class FalconMambaConfig(MambaConfig): + """ + This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the FALCON_MAMBA + [state-spaces/falcon_mamba-2.8b](https://huggingface.co/state-spaces/falcon_mamba-2.8b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50280): + Vocabulary size of the FALCON_MAMBA model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`FalconMambaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 16): shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the model. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 0): + The id of the end of sentence token in the vocabulary. + expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.1): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + residual_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model + time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + time_step_scale (`float`, *optional*, defaults to 1.0): + Scale used used to scale `dt_proj.bias`. + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_init_scheme (`float`, *optional*, defaults to `"random"`): + Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]` + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): + Whether or not to rescale `out_proj` weights when initializing. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the cache should be used. + use_falcon_mambapy (`bool`, *optional*, defaults to `False`): + Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited. + mixer_rms_eps (`float`, *optional*, defaults to 1e-06): + The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states. 
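`mixer_rms_eps` feeds the weight-free RMS normalization applied to the B, C and dt projections (the `rms_forward` helper defined further down in this modular file). A toy check of what that normalization does, with the helper body reproduced here only so the snippet runs standalone:

```python
import torch

def rms_forward(hidden_states: torch.Tensor, variance_epsilon: float = 1e-6) -> torch.Tensor:
    # Weight-free RMS norm: rescale each feature vector to (approximately) unit RMS.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return (hidden_states * torch.rsqrt(variance + variance_epsilon)).to(input_dtype)

x = torch.randn(2, 5, 16)
y = rms_forward(x, variance_epsilon=1e-6)
print(y.pow(2).mean(-1).sqrt())  # values close to 1.0 after normalization
```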
+ + + Example: + + ```python + >>> from transformers import FalconMambaConfig, FalconMambaModel + + >>> # Initializing a FalconMamba configuration + >>> configuration = FalconMambaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = FalconMambaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + vocab_size=50280, + hidden_size=768, + state_size=16, + num_hidden_layers=32, + layer_norm_epsilon=1e-5, + pad_token_id=0, + bos_token_id=0, + eos_token_id=0, + expand=2, + conv_kernel=4, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=True, + time_step_rank="auto", + time_step_scale=1.0, + time_step_min=0.001, + time_step_max=0.1, + time_step_init_scheme="random", + time_step_floor=1e-4, + rescale_prenorm_residual=False, + use_cache=True, + use_falcon_mambapy=False, + mixer_rms_eps=1e-6, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + state_size=state_size, + num_hidden_layers=num_hidden_layers, + layer_norm_epsilon=layer_norm_epsilon, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + expand=expand, + conv_kernel=conv_kernel, + use_bias=use_bias, + use_conv_bias=use_conv_bias, + hidden_act=hidden_act, + initializer_range=initializer_range, + residual_in_fp32=residual_in_fp32, + time_step_rank=time_step_rank, + time_step_scale=time_step_scale, + time_step_min=time_step_min, + time_step_max=time_step_max, + time_step_init_scheme=time_step_init_scheme, + time_step_floor=time_step_floor, + rescale_prenorm_residual=rescale_prenorm_residual, + use_cache=use_cache, + use_falcon_mambapy=use_falcon_mambapy, + **kwargs, + ) + self.mixer_rms_eps = mixer_rms_eps + + +class FalconMambaCache(MambaCache): + pass + + +def rms_forward(hidden_states, variance_epsilon=1e-6): + """ + Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will + leverage this in order to multiply the final result with the RMSNorm weight + + Args: + hidden_states (`torch.Tensor`): + Hidden states to normalize + variance_epsilon (`float`): + The eps value to add in the square root scaling factor + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) + return hidden_states.to(input_dtype) + + +class FalconMambaMixer(MambaMixer): + def warn_slow_implementation(self): + if not is_fast_path_available: + if self.use_mambapy: + if is_mambapy_available(): + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + raise ImportError( + "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py." + ) + else: + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. 
To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py." + ) + + def __init__(self, config: FalconMambaConfig, layer_idx: int): + super().__init__(config, layer_idx) + # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here + self.register_buffer( + "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False + ) + self.register_buffer( + "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False + ) + self.rms_eps = config.mixer_rms_eps + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + else: + causal_conv1d_update, causal_conv1d_fn = None, None + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias if self.use_conv_bias else None, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias.float() if self.use_bias else None, + -torch.exp(self.A_log.float()), + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + b_rms_weight=self.b_c_rms, + c_rms_weight=self.b_c_rms, + dt_rms_weight=self.dt_rms, + b_c_dt_rms_eps=self.rms_eps, + ) + + else: + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if cache_params is not None and cache_position[0] > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. 
input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module + # at the price of a small overhead. + if hasattr(self.config, "_pre_quantization_dtype"): + discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2) + else: + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if cache_params is not None and cache_position[0] > 0: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + def slow_forward( + self, + input_states, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + if is_mambapy_available(): + from mambapy.pscan import pscan + else: + pscan = None + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. 
Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + # use `cache_position.shape[0]` to check whether we are in prefill + # stage, it's equivalent to check `cache_position[0] == 0`, which + # breaks dynamo fullgraph constraints + if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size: + conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + + cache_params.update_conv_state(self.layer_idx, conv_state, cache_position) + hidden_states = self.act( + self.conv1d(hidden_states)[..., :seq_len] + ) # [batch, intermediate_size, seq_len] + else: + conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position) + conv_state = conv_state.to(self.conv1d.weight.device) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = ( + self.act(hidden_states).to(dtype).unsqueeze(-1) + ) # [batch, intermediate_size, 1] : decoding + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose( + 1, 2 + ) # [batch, intermediate_size, seq_len] + + # 3.b. 
Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] + discrete_A = torch.exp( + A[None, :, None, :] * discrete_time_step[:, :, :, None] + ) # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_B = ( + discrete_time_step[:, :, :, None] * B[:, None, :, :].float() + ) # [batch, intermediate_size, seq_len, ssm_state_size] + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + if self.use_falcon_mambapy and self.training and cache_params is None: + hs = pscan( + discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2) + ) # [batch, seq_len, intermediate_size, ssm_state_size] + scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] + scan_output = scan_output + hidden_states * self.D[None, :, None] + scan_output = scan_output * self.act(gate) + else: + scan_outputs = [] + for i in range(seq_len): + ssm_state = ( + discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] + ) # [batch, intermediate_size, ssm_state] + scan_output = torch.matmul( + ssm_state.to(dtype), C[:, i, :].unsqueeze(-1) + ) # [batch, intermediate_size, 1] + scan_outputs.append(scan_output[:, :, 0]) + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len] + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = scan_output * self.act(gate) + + if cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + return contextualized_states + + +class FalconMambaRMSNorm(MambaRMSNorm): + def forward(self, hidden_states): + return self.weight.to(hidden_states.device) * rms_forward( + hidden_states, variance_epsilon=self.variance_epsilon + ) + + +class FalconMambaBlock(MambaBlock): + pass + + +@auto_docstring +class FalconMambaPreTrainedModel(MambaPreTrainedModel): + pass + + +class FalconMambaOutput(MambaOutput): + pass + + +class FalconMambaCausalLMOutput(MambaCausalLMOutput): + pass + + +class FalconMambaModel(MambaModel): + pass + + +class FalconMambaForCausalLM(MambaForCausalLM): + pass + + +__all__ = [ + "FalconMambaForCausalLM", + "FalconMambaModel", + "FalconMambaPreTrainedModel", + "FalconMambaCache", + "FalconMambaConfig", +] diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 0b93d4484c9f..44e29c4f0f91 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -182,7 +182,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionDynamicCache(DynamicCache): +class HybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -196,6 +196,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config, batch_size, dtype=torch.float16, device=None): super().__init__() self.dtype = dtype diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index f2347833db6c..5d63c166372b 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -24,7 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import MambaCache +from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel @@ -39,25 +39,137 @@ logger = logging.get_logger(__name__) -if is_mambapy_available(): - from mambapy.pscan import pscan -else: - pscan = None -if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None +def is_fast_path_available(): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + else: + causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) + return ( + all((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)), + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, + ) + + +( + is_fast_path_available, + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, +) = is_fast_path_available() + + +class MambaCache: + """ + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. 
Mamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() class MambaMixer(nn.Module): @@ -109,6 +221,9 @@ def __init__(self, config: MambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + self.warn_slow_implementation() + + def warn_slow_implementation(self): if not is_fast_path_available: if self.use_mambapy: if is_mambapy_available(): @@ -231,6 +346,10 @@ def cuda_kernels_forward( # fmt: off def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None): + if is_mambapy_available(): + from mambapy.pscan import pscan + else: + pscan = None batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated MLP's linear projection @@ -638,7 +757,12 @@ def prepare_inputs_for_generation( ): # Overwritten -- uses `cache_params` as opposed to `past_key_values` + cache_params_not_initialized = cache_params is None if use_cache: + if cache_params_not_initialized: + max_batch_size = inputs_embeds.size(0) if inputs_embeds is not None else input_ids.size(0) + cache_params = MambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype) + cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) # `cache_position` should have been initialized in `generate` if cache_position is None: raise ValueError( @@ -659,7 +783,7 @@ def prepare_inputs_for_generation( # the length of `cache_params.conv_states`, which is `config.conv_kernel` cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) - if inputs_embeds is not None and cache_params is None: + if inputs_embeds is not None and cache_params_not_initialized: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} @@ -738,4 +862,4 @@ def forward( ) -__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel"] +__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel", "MambaCache"] diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 1f663462d5e1..69737bb7c4b5 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -959,7 +959,12 @@ def prepare_inputs_for_generation( ): # Overwritten -- uses `cache_params` as opposed to `past_key_values` + cache_params_not_initialized = cache_params is None if use_cache: + if cache_params_not_initialized: + max_batch_size = inputs_embeds.size(0) if 
inputs_embeds is not None else input_ids.size(0) + cache_params = Mamba2Cache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype) + cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) # `cache_position` should have been initialized in `generate` if cache_position is None: raise ValueError( @@ -979,7 +984,7 @@ def prepare_inputs_for_generation( # the length of `cache_params.conv_states`, which is `config.conv_kernel` cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) - if inputs_embeds is not None and cache_params is None: + if inputs_embeds is not None and cache_params_not_initialized: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 6d77801d4ef6..b64d0107395e 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -93,7 +93,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class ZambaHybridDynamicCache(DynamicCache): +class ZambaHybridDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -107,8 +107,13 @@ class ZambaHybridDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.dtype = dtype + self.is_compileable = False self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba self.intermediate_size = config.mamba_expand * config.hidden_size diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index ecd0abcb0263..1c7a489784f3 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -97,7 +97,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Zamba2HybridDynamicCache(DynamicCache): +class Zamba2HybridDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -111,6 +111,10 @@ class Zamba2HybridDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" + key_cache = None + value_cache = None + is_compileable = False + def __init__( self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None ): @@ -1387,7 +1391,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if past_key_values and not past_key_values.has_previous_state: + if past_key_values is not None and not past_key_values.has_previous_state: past_key_values.has_previous_state = True output = BaseModelOutputWithPast( diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index a89ab2729f21..990d725cd0f0 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1142,7 +1142,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if past_key_values and not past_key_values.has_previous_state: + if past_key_values is not None and not past_key_values.has_previous_state: past_key_values.has_previous_state = True output = BaseModelOutputWithPast( diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e865237485a6..7026bf1697c8 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -44,13 +44,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class MambaCache(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class OffloadedCache(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index e59787fb8c65..ad63326ac0d4 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -26,7 +26,6 @@ require_torch_accelerator, require_torch_large_accelerator, require_torch_multi_accelerator, - require_torch_multi_gpu, slow, torch_device, ) @@ -41,10 +40,10 @@ import torch from transformers import ( + FalconMambaCache, FalconMambaForCausalLM, FalconMambaModel, ) - from transformers.cache_utils import MambaCache # Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba @@ -312,31 +311,6 @@ def assertInterval(self, member, container, msg=None): def test_config(self): self.config_tester.run_common_tests() - @require_torch_multi_gpu - def test_multi_gpu_data_parallel_forward(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # some params shouldn't be scattered by nn.DataParallel - # so just remove them if they are present. 
- blacklist_non_batched_params = ["cache_params"] - for k in blacklist_non_batched_params: - inputs_dict.pop(k, None) - - # move input tensors to cuda:O - for k, v in inputs_dict.items(): - if torch.is_tensor(v): - inputs_dict[k] = v.to(0) - - for model_class in self.all_model_classes: - model = model_class(config=config) - model.to(0) - model.eval() - - # Wrap model in nn.DataParallel - model = torch.nn.DataParallel(model) - with torch.no_grad(): - _ = model(**self._prepare_for_class(inputs_dict, model_class)) - def test_falcon_mamba_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_falcon_mamba_model(*config_and_inputs) @@ -396,7 +370,7 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, MambaCache): # MODIFIED PART START + if isinstance(tuple_object, FalconMambaCache): # MODIFIED PART START recursive_check(tuple_object.conv_states, dict_object.conv_states) recursive_check(tuple_object.ssm_states, dict_object.ssm_states) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END @@ -443,6 +417,10 @@ def recursive_check(tuple_object, dict_object): dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + @unittest.skip("Mamba models do not support DDP.") + def test_multi_gpu_data_parallel_forward(self): + pass + @require_torch @require_torch_accelerator @@ -482,7 +460,9 @@ def test_generation_fp16(self): @require_bitsandbytes def test_generation_4bit(self): quantization_config = BitsAndBytesConfig(load_in_4bit=True) - model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config) + model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config).to( + torch_device + ) inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device) out = model.generate(**inputs, max_new_tokens=20, do_sample=False) @@ -498,6 +478,7 @@ def test_generation_torch_compile(self): inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device) out = model.generate(**inputs, max_new_tokens=20, do_sample=False) + print(self.tokenizer.batch_decode(out, skip_special_tokens=False)[0]) self.assertEqual( self.tokenizer.batch_decode(out, skip_special_tokens=False)[0], @@ -528,7 +509,7 @@ def test_batched_generation(self): inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device) model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.float16) - out = model.generate(**inputs, max_new_tokens=20) + out = model.generate(**inputs, max_new_tokens=20, do_sample=False) out = tok.batch_decode(out, skip_special_tokens=True) self.assertListEqual(out, EXPECTED_OUTPUT) @@ -538,7 +519,7 @@ def test_batched_generation(self): inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids")) inputs["inputs_embeds"] = inputs_embeds - out = model.generate(**inputs, max_new_tokens=20) + out = model.generate(**inputs, max_new_tokens=20, do_sample=False) out = tok.batch_decode(out, skip_special_tokens=True) EXPECTED_OUTPUTS = Expectations( diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 840493648ffc..3f87481b0703 100644 --- 
a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -20,7 +20,7 @@ from parameterized import parameterized from transformers import AutoTokenizer, MambaConfig, is_torch_available -from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device +from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -32,10 +32,10 @@ import torch from transformers import ( + MambaCache, MambaForCausalLM, MambaModel, ) - from transformers.models.mamba.modeling_mamba import MambaCache class MambaModelTester: @@ -279,31 +279,6 @@ def assertInterval(self, member, container, msg=None): def test_config(self): self.config_tester.run_common_tests() - @require_torch_multi_gpu - def test_multi_gpu_data_parallel_forward(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # some params shouldn't be scattered by nn.DataParallel - # so just remove them if they are present. - blacklist_non_batched_params = ["cache_params"] - for k in blacklist_non_batched_params: - inputs_dict.pop(k, None) - - # move input tensors to cuda:O - for k, v in inputs_dict.items(): - if torch.is_tensor(v): - inputs_dict[k] = v.to(0) - - for model_class in self.all_model_classes: - model = model_class(config=config) - model.to(0) - model.eval() - - # Wrap model in nn.DataParallel - model = torch.nn.DataParallel(model) - with torch.no_grad(): - _ = model(**self._prepare_for_class(inputs_dict, model_class)) - def test_mamba_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_mamba_model(*config_and_inputs) @@ -437,6 +412,10 @@ def test_dtype_mismatch_handled_in_cache(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size), ) + @unittest.skip("Mamba models do not support DDP.") + def test_multi_gpu_data_parallel_forward(self): + pass + @require_torch class MambaIntegrationTests(unittest.TestCase): @@ -532,11 +511,11 @@ def test_compile_mamba_cache(self): torch_device ) - output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba") + output = model.generate(input_ids, max_new_tokens=20) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") - output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba") + output = model.generate(input_ids, max_new_tokens=20) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 8c864f9b64f1..06bf5a381a40 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -63,8 +63,6 @@ TEST_CACHE_IMPLEMENTATIONS = [ cache_name for cache_name in ALL_CACHE_IMPLEMENTATIONS - # TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`? 
- if cache_name != "mamba" # TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them if cache_name != "offloaded_hybrid" ] @@ -1119,3 +1117,44 @@ def test_hybrid_cache_sliding_mode(self): [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 4 failed", ) + + def test_dynamic_cache(self): + """Test DynamicCache with manually prefilled states and hardcoded assertions. + Scenario 1: prefill and update for one layer + prefill: [1.0, 2.0] + update pos 2: [1.0, 2.0, 3.0] + Scenario 2: prefill and update for two layers independently + """ + prefill = torch.tensor([1.0, 2.0])[None, None, :, None] + update3 = torch.tensor(3.0)[None, None, None, None] + update4 = torch.tensor(4.0)[None, None, None, None] + + # Scenario 1: prefill and update for one layer + cache = DynamicCache() + cache.update(prefill, prefill, 0) + cache.update(update3, update3, 0) + self.assertEqual(cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") + cache.update(update4, update4, 0) + self.assertEqual( + cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" + ) + + # Scenario 2: prefill and update for two layers independently + prefill1 = torch.tensor([10.0, 20.0])[None, None, :, None] + update3_1 = torch.tensor(30.0)[None, None, None, None] + update4_1 = torch.tensor(40.0)[None, None, None, None] + + cache = DynamicCache() + cache.update(prefill, prefill, 0) + cache.update(prefill1, prefill1, 1) + + cache.update(update3, update3, 0) + cache.update(update3_1, update3_1, 1) + cache.update(update4, update4, 0) + cache.update(update4_1, update4_1, 1) + self.assertEqual( + cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" + ) + self.assertEqual( + cache.key_cache[1][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0], "DynamicCache Scenario 2 layer 1 failed" + ) From 04d7a0b12a597bb0a4d9a80c23193b5ae8357404 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 30 Jun 2025 15:51:22 +0200 Subject: [PATCH 02/35] fix quantized, add tests --- src/transformers/cache_utils.py | 302 +++++++++++++++------------ src/transformers/generation/utils.py | 7 +- tests/utils/test_cache_utils.py | 59 +++++- 3 files changed, 232 insertions(+), 136 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d05220cb62a5..97c2c765dc13 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -214,7 +214,9 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None: self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) def __repr__(self): - return f"{self.__class__.__name__}(K={self.key_cache}, V={self.value_cache})" + key_repr = "None" if self.key_cache is None else f"t({tuple(self.key_cache.shape)})" + value_repr = "None" if self.value_cache is None else f"t({tuple(self.value_cache.shape)})" + return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})" class Cache: @@ -860,7 +862,7 @@ def update( def get_seq_length(self, cache_position: Optional[torch.LongTensor] = None) -> int: """Returns the sequence length of the cached states.""" # TODO: deprecate this function in favor of `cache_position` - if self is None or self.key_cache is None: + if self is None or self.key_cache is None or self.key_cache.numel() == 0: return 0 return self.key_cache.shape[-2] @@ -1017,94 +1019,6 @@ def __init__(self, config: Optional[CacheConfig] = None) -> None: super().__init__(processors=processors, 
config=config) -class QuantoQuantizedCache(DynamicCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - - Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. - - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. - - Example: - - ```python - >>> # Run pip install quanto first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4) - >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - QuantoQuantizedCache() - ``` - """ - - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) - super(DynamicCache, self).__init__(processors=processors) - - -class HQQQuantizedCache(DynamicCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. 
The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - - Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. - - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. - - Example: - - ```python - >>> # Run pip install hqq first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) - >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HQQQuantizedCache() - ``` - """ - - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)]) - super(DynamicCache, self).__init__(processors=processors) - - class StaticLayer(CacheLayer): is_compileable = True @@ -2120,56 +2034,60 @@ def post_update( if layer_idx == 0: self._seen_tokens += key_tensors.shape[-2] - # Extend quantized cache if needed - while len(self._quantized_key_cache) <= layer_idx: - self._quantized_key_cache.append(torch.empty(0)) - self._quantized_value_cache.append(torch.empty(0)) + if len(cache.key_cache) < layer_idx: + raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.") - # Check if we need to quantize - if layer_idx < len(cache.key_cache): - current_key = cache.key_cache[layer_idx] - current_value = cache.value_cache[layer_idx] + # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer + # On the first forward pass, we quantize the whole prompt. + # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full. 
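+        # An empty quantized cache for this layer means nothing has been quantized yet, i.e. this is the prefill pass.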
+ is_prefill = self._get_quantized_length(layer_idx) == 0 - if ( - current_key.dim() == 4 - and current_key.shape[-2] >= self.config.residual_length - and current_key.shape[-2] > self._get_quantized_length(layer_idx) - ): - # Quantize the older part, keep recent tokens in original precision - split_idx = current_key.shape[-2] - self.config.residual_length + if is_prefill: + self._quantized_key_cache.append(self._quantize(key_tensors.contiguous(), axis=self.config.axis_key)) + self._quantized_value_cache.append(self._quantize(value_tensors.contiguous(), axis=self.config.axis_value)) - # Get the part to quantize - key_to_quantize = current_key[:, :, :split_idx, :].contiguous() - value_to_quantize = current_value[:, :, :split_idx, :].contiguous() + # Clear the residual cache + cache.key_cache[layer_idx] = torch.zeros( + 0, + dtype=key_tensors.dtype, + device=key_tensors.device, + ) + cache.value_cache[layer_idx] = torch.zeros( + 0, + dtype=value_tensors.dtype, + device=value_tensors.device, + ) + # On prefill, we return the original prompt + keys_to_return, values_to_return = key_tensors, value_tensors + else: + # Prepend the previously quantized cache + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2) + values_to_return = torch.cat([dequant_value, value_tensors], dim=-2) + if key_tensors.shape[-2] >= self.config.residual_length: # Quantize and store - self._quantized_key_cache[layer_idx] = self._quantize(key_to_quantize, axis=self.config.axis_key) - self._quantized_value_cache[layer_idx] = self._quantize(value_to_quantize, axis=self.config.axis_value) - - # Keep only the recent tokens in original precision - cache.key_cache[layer_idx] = current_key[:, :, split_idx:, :] - cache.value_cache[layer_idx] = current_value[:, :, split_idx:, :] - - # Return the full tensors for this update - if self._quantized_key_cache[layer_idx].numel() > 0: - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) - full_key = torch.cat([dequant_key, cache.key_cache[layer_idx]], dim=-2) - full_value = torch.cat([dequant_value, cache.value_cache[layer_idx]], dim=-2) - return full_key, full_value + self._quantized_key_cache[layer_idx] = self._quantize( + keys_to_return.contiguous(), axis=self.config.axis_key + ) + self._quantized_value_cache[layer_idx] = self._quantize( + values_to_return.contiguous(), axis=self.config.axis_value + ) - return key_tensors, value_tensors + # Clear the residual cache + cache.key_cache[layer_idx] = torch.zeros( + 0, + dtype=key_tensors.dtype, + device=key_tensors.device, + ) + cache.value_cache[layer_idx] = torch.zeros( + 0, + dtype=value_tensors.dtype, + device=value_tensors.device, + ) - def _get_quantized_length(self, layer_idx: int) -> int: - """Get the length of quantized cache for a layer.""" - if layer_idx < len(self._quantized_key_cache) and self._quantized_key_cache[layer_idx].numel() > 0: - # This would depend on the specific quantization implementation - return ( - self._quantized_key_cache[layer_idx].shape[-2] - if hasattr(self._quantized_key_cache[layer_idx], "shape") - else 0 - ) - return 0 + return keys_to_return, values_to_return def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: """Quantize a tensor - to be implemented by specific quantization backends.""" @@ -2227,6 +2145,12 @@ def _dequantize(self, 
qtensor: torch.Tensor) -> torch.Tensor: """Dequantize tensor using quanto backend.""" return qtensor.dequantize() + def _get_quantized_length(self, layer_idx: int) -> int: + """Get the length of quantized cache for a layer.""" + if layer_idx < len(self._quantized_key_cache): + return self._quantized_key_cache[layer_idx].shape[-2] + return 0 + class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): """ @@ -2277,6 +2201,12 @@ def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tens tensor = self.quantizer.dequantize(quant_tensor, meta) return tensor + def _get_quantized_length(self, layer_idx: int) -> int: + """Get the length of quantized cache for a layer.""" + if layer_idx < len(self._quantized_key_cache): + return self._quantized_key_cache[layer_idx][0].shape[-2] + return 0 + class QuantizedCache(DynamicCache): """ @@ -2293,9 +2223,113 @@ class QuantizedCache(DynamicCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) + if cache_config.backend == "quanto": + processor = QuantoQuantizedCacheProcessor(cache_config) + elif cache_config.backend == "hqq": + processor = HQQQuantizedCacheProcessor(cache_config) + else: + raise ValueError(f"Unknown quantization backend `{cache_config.backend}`") + + processors = CacheProcessorList([processor]) super().__init__(processors=processors) + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is + # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx + # this part of code otherwise fails when used to verify attn_weight shape in some models + return self.processors[0]._seen_tokens if layer_idx == 0 else self.processors[0]._seen_tokens - 1 + + +class QuantoQuantizedCache(QuantizedCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + + Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. 
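+
+    Constructing this class is the same as constructing a `QuantizedCache` with a `cache_config` whose `backend` is set to `"quanto"`.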
+ + Example: + + ```python + >>> # Run pip install quanto first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4) + >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + QuantoQuantizedCache() + ``` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) + Cache.__init__(self, processors=processors) + + +class HQQQuantizedCache(QuantizedCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + + Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. 
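+
+    Constructing this class is the same as constructing a `QuantizedCache` with a `cache_config` whose `backend` is set to `"hqq"`.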
+ + Example: + + ```python + >>> # Run pip install hqq first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) + >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HQQQuantizedCache() + ``` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)]) + Cache.__init__(self, processors=processors) + class SinkCache(Cache): """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 74cdbbc9d982..242b16195460 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -574,7 +574,12 @@ def prepare_inputs_for_generation( # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly # (this alternative is not as robust as calling `generate` and letting it create `cache_position`) elif cache_position is None: - past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_length = 0 + if past_key_values is not None: + if not isinstance(past_key_values, Cache): + past_length = past_key_values[0][0].shape[2] + elif hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() is not None: + past_length = past_key_values.get_seq_length() cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) # 2. 
Generic cache-dependent input preparation diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 06bf5a381a40..de110b872a7e 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -36,7 +36,7 @@ slow, torch_device, ) -from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal +from transformers.utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal if is_torch_available(): @@ -52,11 +52,13 @@ GenerationConfig, HybridCache, LlamaConfig, + QuantizedCache, SlidingWindowCache, StaticCache, convert_and_export_with_cache, pipeline, ) + from transformers.cache_utils import HQQQuantizedCacheProcessor, QuantoQuantizedCacheProcessor from transformers.integrations.executorch import export_with_dynamic_cache @@ -283,6 +285,61 @@ def test_cache_beam_search(self, cache_implementation): decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) self.assertListEqual(decoded, EXPECTED_GENERATION) + @parameterized.expand([("quanto"), ("HQQ")]) + def test_quantized_cache_generation(self, backend): + """Tests that QuantizedCache works as expected for both `quanto` and `hqq` backends.""" + if backend == "quanto": + if not is_optimum_quanto_available(): + self.skipTest("Quanto is not available") + axis_key, axis_value = 0, 0 + # This output is taken from a run with the same parameters, and is known to be correct + expected_generation = ["The cat's whiskers are also a sign of anxiety."] + elif backend == "HQQ": + if not is_hqq_available(): + self.skipTest("HQQ is not available") + axis_key, axis_value = 1, 1 + # HQQ has slightly different numerics + expected_generation = ["The cat's whiskers are also a sign of anxiety."] + else: + return + + inputs = self.tokenizer(["The cat"], return_tensors="pt").to(self.model.device) + + gen_out = self.model.generate( + **inputs, + do_sample=False, + max_new_tokens=10, + return_dict_in_generate=True, + cache_implementation="quantized", + cache_config={ + "backend": backend, + "nbits": 4, + "q_group_size": 16, + "residual_length": 4, + "axis_key": axis_key, + "axis_value": axis_value, + }, + disable_compile=True, + ) + + self.assertIsInstance(gen_out.past_key_values, QuantizedCache) + processor = gen_out.past_key_values.processors[0] + if backend == "quanto": + self.assertIsInstance(processor, QuantoQuantizedCacheProcessor) + elif backend == "hqq": + self.assertIsInstance(processor, HQQQuantizedCacheProcessor) + + decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) + self.assertListEqual(decoded, expected_generation) + + self.assertTrue(len(processor._quantized_key_cache) > 0) + + # Check that something is actually quantized + has_been_quantized = any( + (q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_key_cache + ) + self.assertTrue(has_been_quantized) + @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_extra_left_padding(self, cache_implementation): """Tests that adding extra left-padding does not affect the generation with the cache""" From 26c28af61632dca73ce0f784e7596928e773768a Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 30 Jun 2025 18:24:18 +0200 Subject: [PATCH 03/35] remove CacheProcessorList --- src/transformers/cache_utils.py | 103 ++++++++++---------------------- tests/utils/test_cache_utils.py | 2 +- 2 files changed, 32 insertions(+), 73 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py 
index 97c2c765dc13..233b72bcbfff 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -86,46 +86,7 @@ def post_update( return key_tensors, value_tensors -class CacheProcessorList(list): - """ - list of cache processors that can be applied to a cache. - """ - - def init(self, cache: "Cache", **kwargs) -> None: - """Initialize all processors in the list.""" - for processor in self: - processor.init(cache, **kwargs) - - def pre_update( - self, - cache: "Cache", - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Apply pre_update hook for all processors.""" - for processor in self: - key_states, value_states = processor.pre_update(cache, key_states, value_states, layer_idx, cache_kwargs) - return key_states, value_states - - def post_update( - self, - cache: "Cache", - key_tensors: torch.Tensor, - value_tensors: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Apply post_update hook for all processors.""" - for processor in self: - key_tensors, value_tensors = processor.post_update( - cache, key_tensors, value_tensors, layer_idx, cache_kwargs - ) - return key_tensors, value_tensors - - -class KVList: +class KVProxy: """Efficiently simulates layer-indexed key or value lists from a layered cache. This allows for BC access, e.g., cache.key_cache[idx] or cache.value_cache[idx].""" @@ -228,8 +189,8 @@ class Cache: config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility. If DDP-distributed cache data, must be an iterable of (key_states, value_states) tuples for each layer. - processors (`CacheProcessorList`, *optional*): - List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + processor (`CacheProcessor`, *optional*): + Cache processor to apply (e.g., quantization, offloading). pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(DynamicLayer,)`. Must be a tuple whose length divides the total number of layers. The pattern repeats to fill all layers. 
Examples: `(StaticLayer,)` for a @@ -258,13 +219,13 @@ def __init__( config_or_ddp_cache_data: Optional[ Union[PretrainedConfig, Iterable[tuple[torch.Tensor, torch.Tensor]]] ] = None, - processors: Optional[CacheProcessorList] = None, + processor: Optional[CacheProcessor] = None, pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, *args, **kwargs, ): self.layers: list[CacheLayer] = [] - self.processors = processors if processors is not None else CacheProcessorList() + self.processor = processor pattern_block = pattern_block or self.pattern_block or (DynamicLayer,) if isinstance(config_or_ddp_cache_data, PretrainedConfig): @@ -280,7 +241,8 @@ def __init__( assert pattern_block == (DynamicLayer,), "torch DDP is only supported for DynamicCache" for key_states, value_states in _distributed_cache_data: self.layers.append(DynamicLayer.from_kv(key_states, value_states)) - self.processors.init(self, **kwargs) + if self.processor is not None: + self.processor.init(self, **kwargs) return else: model_config = kwargs.pop("config", None) @@ -292,7 +254,8 @@ def __init__( layer = layer_type(self.config.to_layer(idx)) self.layers.append(layer) - self.processors.init(self, **kwargs) + if self.processor is not None: + self.processor.init(self, **kwargs) def grow_layers_to(self, layer_idx): while len(self.layers) <= layer_idx: @@ -302,14 +265,14 @@ def grow_layers_to(self, layer_idx): self.layers.append(next_layer_type()) @property - def key_cache(self) -> KVList: + def key_cache(self) -> KVProxy: """Returns a list-like object of key cache tensors indexed by layer.""" - return KVList(self.layers, "key") + return KVProxy(self.layers, "key") @property - def value_cache(self) -> KVList: + def value_cache(self) -> KVProxy: """Returns a list-like object of value cache tensors indexed by layer.""" - return KVList(self.layers, "value") + return KVProxy(self.layers, "value") def update( self, @@ -335,12 +298,16 @@ def update( Return: A tuple containing the updated key and value states. """ - key_states, value_states = self.processors.pre_update(self, key_states, value_states, layer_idx, cache_kwargs) + if self.processor is not None: + key_states, value_states = self.processor.pre_update( + self, key_states, value_states, layer_idx, cache_kwargs + ) self.grow_layers_to(layer_idx) key_tensors, value_tensors = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) - key_tensors, value_tensors = self.processors.post_update( - self, key_tensors, value_tensors, layer_idx, cache_kwargs - ) + if self.processor is not None: + key_tensors, value_tensors = self.processor.post_update( + self, key_tensors, value_tensors, layer_idx, cache_kwargs + ) return key_tensors, value_tensors def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: @@ -1015,8 +982,7 @@ class OffloadedCache(DynamicCache): def __init__(self, config: Optional[CacheConfig] = None) -> None: # Create the underlying cache with offload processor - processors = CacheProcessorList([OffloadedCacheProcessor()]) - super().__init__(processors=processors, config=config) + super().__init__(processor=OffloadedCacheProcessor(), config=config) class StaticLayer(CacheLayer): @@ -1115,7 +1081,7 @@ class StaticCache(Cache): Parameters: config_or_ddp_cache_data (`Union`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility. - processors (`Optional`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. 
+ processor (`Optional`, *optional*): Cache processor to apply (e.g., quantization, offloading). pattern_block (`Optional`, *optional*): Pattern of cache layer types to use. Defaults to `(StaticLayer,)` for backward compatibility. @@ -1429,7 +1395,7 @@ class HybridCache(Cache): Parameters: config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info. No DDP-distributed cache data is supported. - processors (`CacheProcessorList`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + processor (`CacheProcessor`, *optional*): Cache processor to apply (e.g., quantization, offloading). pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)` for backward compatibility. Example: @@ -1455,7 +1421,7 @@ class HybridCache(Cache): def __init__( self, config_or_ddp_cache_data=None, - processors: Optional[CacheProcessorList] = None, + processor: Optional[CacheProcessor] = None, pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, *args, **kwargs, @@ -1469,7 +1435,7 @@ def __init__( self.is_sliding = [False] * model_config.num_hidden_layers pattern_block = tuple(SlidingWindowLayer if sl else StaticLayer for sl in self.is_sliding) - super().__init__(config_or_ddp_cache_data, processors, pattern_block, *args, **kwargs) + super().__init__(config_or_ddp_cache_data, processor, pattern_block, *args, **kwargs) class HybridChunkedCache(Cache): @@ -1878,10 +1844,6 @@ def __init__( offload_device: Union[str, torch.device] = "cpu", layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, ) -> None: - # Create offload processor - processors = CacheProcessorList([OffloadedCacheProcessor(offload_device)]) - - # Initialize the base StaticCache with the processor super().__init__( config=config, max_batch_size=max_batch_size, @@ -1889,7 +1851,7 @@ def __init__( device=device, dtype=dtype, layer_device_map=layer_device_map, - processors=processors, + processor=OffloadedCacheProcessor(offload_device), ) @@ -2230,8 +2192,7 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None: else: raise ValueError(f"Unknown quantization backend `{cache_config.backend}`") - processors = CacheProcessorList([processor]) - super().__init__(processors=processors) + super().__init__(processor=processor) def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" @@ -2240,7 +2201,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx # this part of code otherwise fails when used to verify attn_weight shape in some models - return self.processors[0]._seen_tokens if layer_idx == 0 else self.processors[0]._seen_tokens - 1 + return self.processor._seen_tokens if layer_idx == 0 else self.processor._seen_tokens - 1 class QuantoQuantizedCache(QuantizedCache): @@ -2283,8 +2244,7 @@ class QuantoQuantizedCache(QuantizedCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) - Cache.__init__(self, processors=processors) + Cache.__init__(self, processor=QuantoQuantizedCacheProcessor(cache_config)) class HQQQuantizedCache(QuantizedCache): @@ -2327,8 +2287,7 @@ class HQQQuantizedCache(QuantizedCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)]) - Cache.__init__(self, processors=processors) + Cache.__init__(self, processor=HQQQuantizedCacheProcessor(cache_config)) class SinkCache(Cache): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index de110b872a7e..8580316d26bd 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -323,7 +323,7 @@ def test_quantized_cache_generation(self, backend): ) self.assertIsInstance(gen_out.past_key_values, QuantizedCache) - processor = gen_out.past_key_values.processors[0] + processor = gen_out.past_key_values.processor if backend == "quanto": self.assertIsInstance(processor, QuantoQuantizedCacheProcessor) elif backend == "hqq": From 16a662408777d3bb66124192d92a53a5464cb871 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 2 Jul 2025 17:15:16 +0200 Subject: [PATCH 04/35] raushan review, arthur review --- docs/source/en/cache_explanation.md | 20 +- src/transformers/cache_utils.py | 710 ++++++++---------- src/transformers/generation/utils.py | 2 +- src/transformers/integrations/executorch.py | 24 +- src/transformers/masking_utils.py | 24 +- src/transformers/models/bart/modeling_bart.py | 4 +- .../modeling_bigbird_pegasus.py | 4 +- .../models/biogpt/modeling_biogpt.py | 4 +- .../models/blenderbot/modeling_blenderbot.py | 4 +- .../modeling_blenderbot_small.py | 4 +- src/transformers/models/dia/modeling_dia.py | 4 +- src/transformers/models/dia/modular_dia.py | 4 +- .../models/gemma3n/modeling_gemma3n.py | 4 +- .../models/gemma3n/modular_gemma3n.py | 4 +- src/transformers/models/gptj/modeling_gptj.py | 5 +- .../models/informer/modeling_informer.py | 8 +- .../models/informer/modular_informer.py | 4 +- .../models/longt5/modeling_longt5.py | 4 +- .../models/m2m_100/modeling_m2m_100.py | 4 +- .../models/marian/modeling_marian.py | 4 +- .../models/mbart/modeling_mbart.py | 4 +- .../models/minimax/modeling_minimax.py | 6 +- .../models/minimax/modular_minimax.py | 6 +- .../models/mllama/modeling_mllama.py | 4 +- .../models/moonshine/modeling_moonshine.py | 4 +- .../models/moonshine/modular_moonshine.py | 4 +- src/transformers/models/mt5/modeling_mt5.py | 4 +- .../models/pegasus/modeling_pegasus.py | 4 +- .../models/pegasus_x/modeling_pegasus_x.py | 4 +- .../models/pix2struct/modeling_pix2struct.py | 4 +- 
.../models/plbart/modeling_plbart.py | 4 +- .../models/pop2piano/modeling_pop2piano.py | 4 +- .../modeling_switch_transformers.py | 4 +- src/transformers/models/t5/modeling_t5.py | 4 +- .../models/t5gemma/modeling_t5gemma.py | 4 +- .../models/t5gemma/modular_t5gemma.py | 4 +- .../modeling_time_series_transformer.py | 4 +- src/transformers/models/udop/modeling_udop.py | 4 +- src/transformers/models/umt5/modeling_umt5.py | 4 +- .../models/whisper/generation_whisper.py | 4 +- .../models/whisper/modeling_whisper.py | 4 +- tests/generation/test_utils.py | 82 +- .../deepseek_v3/test_modeling_deepseek_v3.py | 8 +- tests/models/dia/test_modeling_dia.py | 8 +- .../models/gpt_neox/test_modeling_gpt_neox.py | 4 +- tests/models/t5gemma/test_modeling_t5gemma.py | 44 +- tests/utils/test_cache_utils.py | 102 +-- utils/check_docstrings.py | 5 +- 48 files changed, 535 insertions(+), 651 deletions(-) diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 4190cefdb8a1..2adcc0c78012 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -82,24 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s ## Cache storage implementation -The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`]. +Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`. +Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`), which mostly changes how sequence length is handled and how the cache is updated. -In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`. -- `key_cache`: A list of tensors, one for each layer. -- `value_cache`: A list of tensors, one for each layer. +The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token: -When new tokens are processed: - -1. For each layer, the new key and value states are concatenated with the existing cache. ```py -self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) -self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) +cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2) +cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2) ``` -2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token. - -3. The cache maintains a count of seen tokens through `self._seen_tokens`. This is updated when the first layer processes a new token. +Other layers like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added. The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token. 
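+
+A minimal sketch of such a loop is shown here for reference; it assumes a generic decoder-only checkpoint (the model name and `max_new_tokens` below are placeholders) and uses greedy decoding.
+
+```py
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
+
+model_id = "openai-community/gpt2"  # placeholder checkpoint
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id)
+
+inputs = tokenizer("Hello, my name is", return_tensors="pt")
+input_ids = inputs.input_ids
+attention_mask = inputs.attention_mask
+
+past_key_values = DynamicCache()
+cache_position = torch.arange(input_ids.shape[1])
+generated_ids = input_ids
+max_new_tokens = 10  # placeholder budget
+
+for _ in range(max_new_tokens):
+    outputs = model(
+        input_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        cache_position=cache_position,
+        use_cache=True,
+    )
+    # Greedy pick of the next token from the logits of the last position.
+    next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
+    generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+    # The cache now holds keys/values for every processed token, so only the new
+    # token is fed on the next step; the attention mask is extended by one and the
+    # cache position advances by one, as described above.
+    input_ids = next_token
+    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
+    cache_position = cache_position[-1:] + 1
+
+print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+```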
@@ -143,7 +137,7 @@ The legacy format is essentially the same data structure but organized different - The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`. - The format is less flexible and doesn't support features like quantization or offloading. -If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format. +If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~Cache.from_legacy_cache`] and [`Cache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format. ```py import torch diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 233b72bcbfff..9219365e76ad 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -48,7 +48,7 @@ def pre_update( cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Hook called before the cache update. Can modify the key/value states. + Function called before the cache update. Can modify the key/value states. Args: cache (`Cache`): The cache instance. @@ -71,7 +71,7 @@ def post_update( cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Hook called after the cache update. Can process the cached data. + Function called after the cache update. Can process the cached data. Args: cache (`Cache`): The cache instance. @@ -86,37 +86,6 @@ def post_update( return key_tensors, value_tensors -class KVProxy: - """Efficiently simulates layer-indexed key or value lists from a layered cache. 
- This allows for BC access, e.g., cache.key_cache[idx] or cache.value_cache[idx].""" - - def __init__(self, layers, cache_type="key"): - self.layers = layers - self.cache_type = cache_type - - def __getitem__(self, idx): - if isinstance(idx, slice): - return [getattr(layer, f"{self.cache_type}_cache") for layer in self.layers[idx]] - return getattr(self.layers[idx], f"{self.cache_type}_cache") - - def __setitem__(self, idx, value): - if isinstance(idx, slice): - for layer, val in zip(self.layers[idx], value): - setattr(layer, f"{self.cache_type}_cache", val) - else: - setattr(self.layers[idx], f"{self.cache_type}_cache", value) - - def __len__(self): - return len(self.layers) - - def __iter__(self): - for layer in self.layers: - yield getattr(layer, f"{self.cache_type}_cache") - - def __bool__(self): - return bool(self.layers) - - class CacheLayer: """Base, abstract class for a single layer's cache.""" @@ -126,14 +95,14 @@ def __init__( self, config: Optional["CacheConfig"] = None, ): - self.key_cache = None - self.value_cache = None + self.keys = None + self.values = None @classmethod - def from_kv(cls, key_cache: torch.Tensor, value_cache: torch.Tensor) -> None: + def from_kv(cls, keys: torch.Tensor, values: torch.Tensor) -> None: cache = cls() - cache.key_cache = key_cache - cache.value_cache = value_cache + cache.keys = keys + cache.values = values return cache def update( @@ -167,44 +136,44 @@ def reset(self) -> None: def reorder_cache(self, beam_idx: torch.LongTensor) -> None: """Reorders this layer's cache for beam search.""" - if self.key_cache.numel(): - device = self.key_cache.device - self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) - if self.value_cache.numel(): - device = self.value_cache.device - self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) + if self.keys.numel(): + device = self.keys.device + self.keys = self.keys.index_select(0, beam_idx.to(device)) + if self.values.numel(): + device = self.values.device + self.values = self.values.index_select(0, beam_idx.to(device)) def __repr__(self): - key_repr = "None" if self.key_cache is None else f"t({tuple(self.key_cache.shape)})" - value_repr = "None" if self.value_cache is None else f"t({tuple(self.value_cache.shape)})" + key_repr = "None" if self.keys is None else f"t({tuple(self.keys.shape)})" + value_repr = "None" if self.values is None else f"t({tuple(self.values.shape)})" return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})" -class Cache: +class CacheBase: + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.cache_processor is not None: + key_states, value_states = self.cache_processor.pre_update( + self, key_states, value_states, layer_idx, cache_kwargs + ) + key_tensors, value_tensors = self._update(key_states, value_states, layer_idx, cache_kwargs) + if self.cache_processor is not None: + key_tensors, value_tensors = self.cache_processor.post_update( + self, key_tensors, value_tensors, layer_idx, cache_kwargs + ) + return key_tensors, value_tensors + + +class Cache(CacheBase): """ Base, abstract class for all caches. The actual data structure is specific to the layers. This class handles propagation of operations across layers. - Parameters: - config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): - Model configuration for shape/device info, or DDP-distributed cache data for compatibility. 
- If DDP-distributed cache data, must be an iterable of (key_states, value_states) tuples for each layer. - processor (`CacheProcessor`, *optional*): - Cache processor to apply (e.g., quantization, offloading). - pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): - Pattern of cache layer types to use. Defaults to `(DynamicLayer,)`. Must be a tuple whose length divides - the total number of layers. The pattern repeats to fill all layers. Examples: `(StaticLayer,)` for a - uniform cache, `(StaticLayer, StaticLayer, SlidingWindowLayer)` for a hybrid cache with repeating pattern, - or specify the full structure like `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)`. - Additional arguments for cache configuration: - - `max_batch_size`/`batch_size` (`int`): Maximum batch size for static caches - - `max_cache_len` (`int`): Maximum sequence length. For hybrid caches: - * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` - * StaticLayers: uses full `max_cache_len` - - `device` (`torch.device`): Device for cache tensors - - `dtype` (`torch.dtype`): Data type for cache tensors - - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping - Note for hybrid caches (blocks of (StaticLayer, ..., SlidingWindowLayer) repeated across layers): - Requires `model_config.sliding_window` to be set - Uses `sliding_window_pattern` (default: 2) to determine layer alternation if pattern not specified @@ -212,69 +181,61 @@ class Cache: """ layers = [] - pattern_block = () # Subclasses can define their layer pattern statically def __init__( self, - config_or_ddp_cache_data: Optional[ - Union[PretrainedConfig, Iterable[tuple[torch.Tensor, torch.Tensor]]] - ] = None, - processor: Optional[CacheProcessor] = None, - pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, + model_config: Optional[PretrainedConfig] = None, + cache_processor: Optional[CacheProcessor] = None, + layer_classes: Optional[list[type[CacheLayer]]] = None, *args, **kwargs, ): + """ + Parameters: + model_config (`PretrainedConfig`): + Model configuration for shape/device info. + cache_processor (`CacheProcessor`, *optional*): + Cache processor to apply (e.g., quantization, offloading). + layer_classes (`list[type[CacheLayer]]`, *optional*): + List of layer classes to use for the cache. + Additional arguments for cache configuration: + - `max_batch_size`/`batch_size` (`int`): Maximum batch size for static caches + - `max_cache_len` (`int`): Maximum sequence length. For hybrid caches: + * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` + * StaticLayers: uses full `max_cache_len` + - `device` (`torch.device`): Device for cache tensors + - `dtype` (`torch.dtype`): Data type for cache tensors + - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping + """ self.layers: list[CacheLayer] = [] - self.processor = processor - pattern_block = pattern_block or self.pattern_block or (DynamicLayer,) - - if isinstance(config_or_ddp_cache_data, PretrainedConfig): - model_config = config_or_ddp_cache_data - elif isinstance(config_or_ddp_cache_data, Iterable): - _distributed_cache_data = config_or_ddp_cache_data - # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 - # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. 
each item in the - # iterable contains the key and value states for a layer gathered across replicas by torch.distributed - # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break - # compatibility. The name of the argument doesn't matter. - assert pattern_block == (DynamicLayer,), "torch DDP is only supported for DynamicCache" - for key_states, value_states in _distributed_cache_data: - self.layers.append(DynamicLayer.from_kv(key_states, value_states)) - if self.processor is not None: - self.processor.init(self, **kwargs) - return - else: - model_config = kwargs.pop("config", None) + self.cache_processor = cache_processor - self.config, self.pattern_block = CacheConfig.from_model_config(model_config, pattern_block, *args, **kwargs) - self.layer_types = [self.pattern_block[i % len(self.pattern_block)] for i in range(self.config.num_layers)] + if ( + layer_classes is None # setting layer_classes takes precedence + and model_config is not None + and getattr(model_config, "layer_types", None) is not None + ): + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in model_config.layer_types] + self.layer_classes = layer_classes or [DynamicLayer] - for idx, layer_type in enumerate(self.layer_types): - layer = layer_type(self.config.to_layer(idx)) - self.layers.append(layer) + self.config = CacheConfig.from_model_config(model_config, *args, **kwargs) - if self.processor is not None: - self.processor.init(self, **kwargs) + self.append_new_layers(self.config.num_layers - 1) - def grow_layers_to(self, layer_idx): - while len(self.layers) <= layer_idx: - next_type_idx = len(self.layer_types) % len(self.pattern_block) - next_layer_type = self.pattern_block[next_type_idx] - self.layer_types.append(next_layer_type) - self.layers.append(next_layer_type()) + if self.cache_processor is not None: + self.cache_processor.init(self, **kwargs) - @property - def key_cache(self) -> KVProxy: - """Returns a list-like object of key cache tensors indexed by layer.""" - return KVProxy(self.layers, "key") - - @property - def value_cache(self) -> KVProxy: - """Returns a list-like object of value cache tensors indexed by layer.""" - return KVProxy(self.layers, "value") + def append_new_layers(self, layer_idx): + """ + Appends layers to the cache until the layer `layer_idx` is reached. + Used in prefill and for skipped layers. + """ + while len(self.layers) <= layer_idx: + self.layers.append( + self.layer_classes[layer_idx % len(self.layer_classes)](self.config.for_layer(layer_idx)) + ) - def update( + def _update( self, key_states: torch.Tensor, value_states: torch.Tensor, @@ -298,17 +259,8 @@ def update( Return: A tuple containing the updated key and value states. 
""" - if self.processor is not None: - key_states, value_states = self.processor.pre_update( - self, key_states, value_states, layer_idx, cache_kwargs - ) - self.grow_layers_to(layer_idx) - key_tensors, value_tensors = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) - if self.processor is not None: - key_tensors, value_tensors = self.processor.post_update( - self, key_tensors, value_tensors, layer_idx, cache_kwargs - ) - return key_tensors, value_tensors + self.append_new_layers(layer_idx) + return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -316,7 +268,7 @@ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: sequence length. """ if layer_idx < len(self.layers): - return self.layers[layer_idx].key_cache, self.layers[layer_idx].value_cache + return self.layers[layer_idx].keys, self.layers[layer_idx].values else: raise KeyError( f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" @@ -328,20 +280,20 @@ def __iter__(self): keys and values """ for layer_idx in range(len(self)): - yield (self.layers[layer_idx].key_cache, self.layers[layer_idx].value_cache) + yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) def __len__(self): """ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds to the number of layers in the model. """ - # Best effort BC support for subclasses + # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked. if self.layers is None: if getattr(self, "key_cache", None) is not None: return len(self.key_cache) return 0 dynamic_empty = ( - len(self.layers) == 1 and isinstance(self.layers[0], DynamicLayer) and self.layers[0].key_cache is None + len(self.layers) == 1 and isinstance(self.layers[0], DynamicLayer) and self.layers[0].keys is None ) return len(self.layers) if not dynamic_empty else 0 @@ -355,45 +307,42 @@ def __getattr__(self, name): if name in ("__getstate__", "__setstate__"): raise AttributeError(name) - # Check if the attribute/method exists and gather values if it is an attribute - attribute_values = [] - for i, layer in enumerate(self.layers[: len(self.pattern_block)]): - if not hasattr(layer, name): - raise AttributeError( - f"Layer {i} ({layer.__class__.__name__}) of {self.__class__.__name__} does not support `{name}`" - ) - if not callable(getattr(layer, name)): - attribute_values.append(getattr(layer, name)) - - if attribute_values: - assert len(attribute_values) == len(self.pattern_block), ( - f"Cache {self.__class__.__name__} gathered {len(attribute_values)} values for {name}, but there are {len(self.pattern_block)} layers." 
- ) - values = set(attribute_values) - if len(values) == 1: - return values.pop() + is_attribute = ( + all(hasattr(layer, name) and not callable(getattr(layer, name)) for layer in self.layers) + and len(self.layers) > 0 + ) + is_method = all(callable(getattr(layer, name)) for layer in self.layers) and len(self.layers) > 0 + + if is_attribute: + attribute_values = [getattr(layer, name) for layer in self.layers] + if all(isinstance(value, bool) for value in attribute_values): + return all(attribute_values) + elif len(set(attribute_values)) == 1: + return attribute_values[0] else: - if all(isinstance(value, bool) for value in values): - return all(values) - else: - raise ValueError( - f"Cache {self.__class__.__name__}:{self.pattern_block} has multiple values for {name}: {attribute_values}. This is not supported." - ) + raise ValueError( + f"{self.__class__.__name__}: layers have multiple values for layer.{name}: {attribute_values}. This is not supported." + ) + elif is_method: - # If the attribute is a method, we propagate it to all layers - def propagate_to_layers(*args, **kwargs): - for layer in self.layers: - return_value = getattr(layer, name)(*args, **kwargs) - if return_value is not None: - break - return return_value + def propagate_to_layers(*args, **kwargs): + for layer in self.layers: + return_value = getattr(layer, name)(*args, **kwargs) + if return_value is not None: + break + return return_value - return propagate_to_layers + return propagate_to_layers + else: + raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") def get_seq_length(self, layer_idx: int = 0) -> int: """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" if layer_idx >= len(self.layers): return 0 + # Hack since QuantizedCache messes with keys shape as it becomes the residual cache + if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): + return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length() return self.layers[layer_idx].get_seq_length() def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: @@ -403,24 +352,7 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), for each layer. """ - if isinstance(self.layers[layer_idx], SlidingWindowLayer): - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - - local_mask_kv_offset = torch.clamp(first_cache_position - self.config.sliding_window + 1, min=0) - # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - local_mask_kv_length = max(query_length, self.config.sliding_window) - return local_mask_kv_length, local_mask_kv_offset - - full_mask_kv_offset = 0 - if isinstance(self.layers[layer_idx], StaticLayer): - full_mask_kv_length = self.get_max_cache_shape() - return full_mask_kv_length, full_mask_kv_offset - else: - query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset + return self.layers[layer_idx].get_mask_sizes(cache_position) def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: """Converts the `Cache` instance into the its equivalent in the legacy cache format. 
Used for @@ -428,7 +360,7 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: legacy_cache = () for layer in self.layers: if layer is not None: - legacy_cache += ((layer.key_cache, layer.value_cache),) + legacy_cache += ((layer.keys, layer.values),) return legacy_cache @classmethod @@ -461,8 +393,7 @@ def __init__(self, num_layers: Optional[int] = None, cache_implementation: Optio @classmethod def from_model_config( cls, - config: Optional[PretrainedConfig], - pattern_block: tuple[type["CacheLayer"], ...], + model_config: Optional[PretrainedConfig], batch_size: Optional[int] = None, max_cache_len: Optional[int] = None, device: Union[torch.device, str, None] = None, @@ -470,39 +401,41 @@ def from_model_config( layer_device_map=None, max_batch_size: Optional[int] = None, ) -> "CacheConfig": - num_layers = getattr(config, "num_hidden_layers", len(pattern_block)) # No model config -> must be a dynamic cache, return bare CacheConfig - if config is None: - return cls(num_layers=num_layers), pattern_block - # Build a StaticCacheConfig for any kind of static: hybrid, sliding or static + if model_config is None: + return cls(num_layers=getattr(model_config, "num_hidden_layers", 1)) + # Build a StaticCacheConfig for hybrid, sliding or static else: - # Rename max_batch_size to batch_size - if max_batch_size is not None: - batch_size = max_batch_size # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) - if StaticLayer in pattern_block and SlidingWindowLayer in pattern_block: - if getattr(config, "sliding_window", None) is None: + if ( + getattr(model_config, "layer_types", None) is not None + and "sliding_attention" in model_config.layer_types + and "full_attention" in model_config.layer_types + ): + if getattr(model_config, "sliding_window", None) is None: raise ValueError( "Setting up a hybrid or sliding window KVCache requires the model config supporting " "sliding window attention, please check if there is a `sliding_window` field in the model " "config and it's not set to None." 
) # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) - max_cache_len = max_cache_len or config.max_position_embeddings - sliding_window_len = min(getattr(config, "sliding_window", max_cache_len) or max_cache_len, max_cache_len) + max_cache_len = max_cache_len or model_config.max_position_embeddings + sliding_window_len = min( + getattr(model_config, "sliding_window", max_cache_len) or max_cache_len, max_cache_len + ) # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: head_dim = ( - config.head_dim - if getattr(config, "head_dim", None) is not None - else config.hidden_size // config.num_attention_heads + model_config.head_dim + if getattr(model_config, "head_dim", None) is not None + else model_config.hidden_size // model_config.num_attention_heads ) num_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads + model_config.num_attention_heads + if getattr(model_config, "num_key_value_heads", None) is None + else model_config.num_key_value_heads ) cache_config = StaticCacheConfig( - batch_size=batch_size, + batch_size=max_batch_size if max_batch_size is not None else batch_size, max_cache_len=max_cache_len, device=torch.device(device) if device is not None else None, dtype=dtype, @@ -510,9 +443,9 @@ def from_model_config( head_dim=head_dim, num_heads=num_heads, sliding_window=sliding_window_len, - num_layers=num_layers, + num_layers=model_config.num_hidden_layers, ) - return cache_config, pattern_block + return cache_config @classmethod def from_dict(cls, config_dict, **kwargs): @@ -535,7 +468,7 @@ def from_dict(cls, config_dict, **kwargs): kwargs.pop(key, None) return config - def to_layer(self, layer_idx: int) -> "CacheLayer": + def for_layer(self, layer_idx: int) -> "CacheConfig": return self # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file @@ -817,21 +750,20 @@ def update( Return: A tuple containing the updated key and value states. """ - if self.key_cache is None: - self.key_cache = key_states - self.value_cache = value_states + if self.keys is None: + self.keys = key_states + self.values = value_states else: - self.key_cache = torch.cat([self.key_cache, key_states], dim=-2) - self.value_cache = torch.cat([self.value_cache, value_states], dim=-2) - - return self.key_cache, self.value_cache + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) + return self.keys, self.values def get_seq_length(self, cache_position: Optional[torch.LongTensor] = None) -> int: """Returns the sequence length of the cached states.""" # TODO: deprecate this function in favor of `cache_position` - if self is None or self.key_cache is None or self.key_cache.numel() == 0: + if self is None or self.keys is None or self.keys.numel() == 0: return 0 - return self.key_cache.shape[-2] + return self.keys.shape[-2] def get_max_cache_shape(self) -> int: """Returns the maximum sequence length of the cache object. 
DynamicLayer does not have a maximum length.""" @@ -839,15 +771,15 @@ def get_max_cache_shape(self) -> int: def reset(self) -> None: """Resets the cache values while preserving the objects""" - self.key_cache = torch.tensor([], dtype=self.key_cache.dtype, device=self.key_cache.device) - self.value_cache = torch.tensor([], dtype=self.value_cache.dtype, device=self.value_cache.device) + self.keys = torch.tensor([], dtype=self.keys.dtype, device=self.keys.device) + self.values = torch.tensor([], dtype=self.values.dtype, device=self.values.device) def reorder_cache(self, beam_idx: torch.LongTensor) -> None: """Reorders the cache for beam search, given the selected beam indices.""" - if self.key_cache is not None and self.key_cache.numel(): - self.key_cache = self.key_cache.index_select(0, beam_idx.to(self.key_cache.device)) - if self.value_cache is not None and self.value_cache.numel(): - self.value_cache = self.value_cache.index_select(0, beam_idx.to(self.value_cache.device)) + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) + if self.values is not None and self.values.numel(): + self.values = self.values.index_select(0, beam_idx.to(self.values.device)) def crop(self, max_length: int) -> None: """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be @@ -858,21 +790,28 @@ def crop(self, max_length: int) -> None: if self.get_seq_length() <= max_length: return - if self.key_cache.numel(): - self.key_cache = self.key_cache[..., :max_length, :] - self.value_cache = self.value_cache[..., :max_length, :] + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[..., :max_length, :] + self.values = self.values[..., :max_length, :] def batch_repeat_interleave(self, repeats: int) -> None: """Repeat the cache `repeats` times in the batch dimension.""" - if self.key_cache.numel(): - self.key_cache = self.key_cache.repeat_interleave(repeats, dim=0) - self.value_cache = self.value_cache.repeat_interleave(repeats, dim=0) + if self.keys.numel(): + self.keys = self.keys.repeat_interleave(repeats, dim=0) + self.values = self.values.repeat_interleave(repeats, dim=0) def batch_select_indices(self, indices: torch.Tensor) -> None: """Only keep the `indices` in the batch dimension of the cache.""" - if self.key_cache.numel(): - self.key_cache = self.key_cache[indices, ...] - self.value_cache = self.value_cache[indices, ...] + if self.keys.numel(): + self.keys = self.keys[indices, ...] + self.values = self.values[indices, ...] + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + full_mask_kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length(cache_position) + kv_length = query_length + past_seen_tokens + return kv_length, full_mask_kv_offset class DynamicCache(Cache): @@ -900,7 +839,18 @@ class DynamicCache(Cache): ``` """ - pattern_block = (DynamicLayer,) + # Specialized constructor for DDP cache data, needed for BC + def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 + # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the + # iterable contains the key and value states for a layer gathered across replicas by torch.distributed + # (shape=[global batch size, num_heads, seq_len, head_dim]). 
+ # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break + # compatibility. The name of the argument doesn't matter. + if ddp_cache_data is not None: + for key_states, value_states in ddp_cache_data: + self.layers.append(DynamicLayer.from_kv(key_states, value_states)) + super().__init__(*args, **kwargs) # Utilities for `DynamicCache` <> torch.export support @@ -917,16 +867,16 @@ def _flatten_dynamic_cache( ) dictionary = { - "key_cache": [layer.key_cache for layer in dynamic_cache.layers if layer.key_cache is not None], - "value_cache": [layer.value_cache for layer in dynamic_cache.layers if layer.value_cache is not None], + "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], } return torch.utils._pytree._dict_flatten(dictionary) def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): dictionary = { - "key_cache": [layer.key_cache for layer in dynamic_cache.layers if layer.key_cache is not None], - "value_cache": [layer.value_cache for layer in dynamic_cache.layers if layer.value_cache is not None], + "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], } return torch.utils._pytree._dict_flatten_with_keys(dictionary) @@ -937,7 +887,7 @@ def _unflatten_dynamic_cache( ): dictionary = torch.utils._pytree._dict_unflatten(values, context) cache = DynamicCache() - # Reconstruct layers from key_cache and value_cache lists + # Reconstruct layers from keys and values lists key_list = dictionary.get("key_cache", []) value_list = dictionary.get("value_cache", []) for idx in range(max(len(key_list), len(value_list))): @@ -949,8 +899,8 @@ def _unflatten_dynamic_cache( def _flatten_dynamic_cache_for_fx(cache, spec): dictionary = { - "key_cache": [layer.key_cache for layer in cache.layers if layer.key_cache is not None], - "value_cache": [layer.value_cache for layer in cache.layers if layer.value_cache is not None], + "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in cache.layers if layer.values is not None], } return torch.fx._pytree._dict_flatten_spec(dictionary, spec) @@ -980,13 +930,14 @@ class OffloadedCache(DynamicCache): ensure the eviction is scheduled after all computations on that cache are finished. """ - def __init__(self, config: Optional[CacheConfig] = None) -> None: + def __init__(self, model_config: Optional[PretrainedConfig] = None) -> None: # Create the underlying cache with offload processor - super().__init__(processor=OffloadedCacheProcessor(), config=config) + super().__init__(cache_processor=OffloadedCacheProcessor(), model_config=model_config) class StaticLayer(CacheLayer): is_compileable = True + is_sliding = False def __init__( self, @@ -996,20 +947,20 @@ def __init__( self.max_cache_len = max_len or config.max_cache_len self.max_batch_size = config.batch_size # Note: There will be significant perf decrease if switching to use 5D tensors instead. 
- self.key_cache = torch.zeros( + self.keys = torch.zeros( (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), dtype=config.dtype, device=config.device, ) - self.value_cache = torch.zeros( + self.values = torch.zeros( (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), dtype=config.dtype, device=config.device, ) # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(self.key_cache) - torch._dynamo.mark_static_address(self.value_cache) + torch._dynamo.mark_static_address(self.keys) + torch._dynamo.mark_static_address(self.values) def get_max_cache_shape(self) -> int: return self.max_cache_len @@ -1037,24 +988,24 @@ def _static_update( """ if cache_position is None: # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. - self.key_cache.copy_(key_states) - self.value_cache.copy_(value_states) + self.keys.copy_(key_states) + self.values.copy_(value_states) else: # Generation phase. Update specific positions. # Use index_copy_ for in-place update (compile-friendly). try: - self.key_cache.index_copy_(2, cache_position, key_states) - self.value_cache.index_copy_(2, cache_position, value_states) + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) except NotImplementedError: # Fallback for devices like MPS where index_copy_ might not be supported. - self.key_cache[:, :, cache_position] = key_states - self.value_cache[:, :, cache_position] = value_states - return self.key_cache, self.value_cache + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + return self.keys, self.values def update(self, key_states, value_states, cache_kwargs=None) -> tuple[torch.Tensor, torch.Tensor]: cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - key_states = key_states.to(self.key_cache.dtype) - value_states = value_states.to(self.value_cache.dtype) + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) return self._static_update(key_states, value_states, cache_position) def get_seq_length(self, cache_position=None) -> int: @@ -1062,29 +1013,28 @@ def get_seq_length(self, cache_position=None) -> int: return int(cache_position[-1] + 1) # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. - return (self.key_cache[0, 0].any(dim=-1)).sum() + return (self.keys[0, 0].any(dim=-1)).sum() def reset(self): - self.key_cache.zero_() - self.value_cache.zero_() + self.keys.zero_() + self.values.zero_() def reorder_cache(self, beam_idx): - dev = self.key_cache.device + dev = self.keys.device beam_idx_dev = beam_idx.to(dev) - self.key_cache = self.key_cache.index_select(0, beam_idx_dev) - self.value_cache = self.value_cache.index_select(0, beam_idx_dev) + self.keys = self.keys.index_select(0, beam_idx_dev) + self.values = self.values.index_select(0, beam_idx_dev) + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + full_mask_kv_offset = 0 + full_mask_kv_length = self.max_cache_len + return full_mask_kv_length, full_mask_kv_offset class StaticCache(Cache): """ Static Cache class to be used with `torch.compile(model)` and `torch.export()`. 
- Parameters: - config_or_ddp_cache_data (`Union`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility. - processor (`Optional`, *optional*): Cache processor to apply (e.g., quantization, offloading). - pattern_block (`Optional`, *optional*): Pattern of cache layer types to use. Defaults to `(StaticLayer,)` for backward compatibility. - - Example: ```python @@ -1105,7 +1055,8 @@ class StaticCache(Cache): ``` """ - pattern_block = (StaticLayer,) + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=[StaticLayer], *args, **kwargs) class SlidingWindowLayer(StaticLayer): @@ -1147,9 +1098,9 @@ def _static_update( if cache_position.shape[0] > self.max_cache_len: new_k = key_states[:, :, -self.max_cache_len :, :] new_v = value_states[:, :, -self.max_cache_len :, :] - self.key_cache.copy_(new_k) - self.value_cache.copy_(new_v) - return self.key_cache, self.value_cache + self.keys.copy_(new_k) + self.values.copy_(new_v) + return self.keys, self.values # Sliding window logic for generation phase or prefill < window slicing = torch.arange(self.max_cache_len, device=value_states.device) @@ -1157,8 +1108,8 @@ def _static_update( to_shift = current_seq_len > self.max_cache_len indices = (slicing + to_shift.sum()) % self.max_cache_len - k_out_shifted = self.key_cache[:, :, indices] - v_out_shifted = self.value_cache[:, :, indices] + k_out_shifted = self.keys[:, :, indices] + v_out_shifted = self.values[:, :, indices] # Clamp cache_position to determine the *target index* within the shifted cache view update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1) @@ -1173,9 +1124,18 @@ def _static_update( k_out_updated[:, :, update_position] = key_states v_out_updated[:, :, update_position] = value_states - self.key_cache.copy_(k_out_updated) - self.value_cache.copy_(v_out_updated) - return self.key_cache, self.value_cache + self.keys.copy_(k_out_updated) + self.values.copy_(v_out_updated) + return self.keys, self.values + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + + local_mask_kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0) + # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns + local_mask_kv_length = max(query_length, self.max_cache_len) + return local_mask_kv_length, local_mask_kv_offset class SlidingWindowCache(Cache): @@ -1214,7 +1174,8 @@ class SlidingWindowCache(Cache): ``` """ - pattern_block = (SlidingWindowLayer,) + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) class EncoderDecoderCache(Cache): @@ -1250,7 +1211,7 @@ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) self.is_updated = {} - for layer_idx in range(len(cross_attention_cache.key_cache)): + for layer_idx in range(len(cross_attention_cache)): self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -1260,10 +1221,10 @@ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch """ if layer_idx < len(self): return ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - 
self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, ) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") @@ -1377,12 +1338,6 @@ def get_max_cache_shape(self) -> int: return self.self_attention_cache.get_max_cache_shape() def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. - """ return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) @@ -1393,11 +1348,6 @@ class HybridCache(Cache): Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] for global attention.For more information, see the documentation of each subcomponent cache class. - Parameters: - config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info. No DDP-distributed cache data is supported. - processor (`CacheProcessor`, *optional*): Cache processor to apply (e.g., quantization, offloading). - pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)` - for backward compatibility. Example: ```python @@ -1418,24 +1368,10 @@ class HybridCache(Cache): ``` """ - def __init__( - self, - config_or_ddp_cache_data=None, - processor: Optional[CacheProcessor] = None, - pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, - *args, - **kwargs, - ): - model_config = config_or_ddp_cache_data or kwargs.get("config", None) - assert model_config is not None, "HybridCache requires a model config" - # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(model_config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in model_config.layer_types] - else: - self.is_sliding = [False] * model_config.num_hidden_layers - - pattern_block = tuple(SlidingWindowLayer if sl else StaticLayer for sl in self.is_sliding) - super().__init__(config_or_ddp_cache_data, processor, pattern_block, *args, **kwargs) + def __init__(self, model_config: PretrainedConfig, *args, **kwargs): + # Ugly but needed for BC + layer_classes = [StaticLayer] if not hasattr(model_config, "layer_types") else None + super().__init__(model_config=model_config, layer_classes=layer_classes, *args, **kwargs) class HybridChunkedCache(Cache): @@ -1447,7 +1383,7 @@ class HybridChunkedCache(Cache): for global attention. For more information, see the documentation of each subcomponent cache class. Parameters: - config (`PretrainedConfig): + model_config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): The maximum batch size with which the model will be used. 
Note that a new instance must be instantiated if a @@ -1485,40 +1421,39 @@ class HybridChunkedCache(Cache): """ is_compileable = True - # Override @property since HybridChunked does its own thing + # Override @property since HybridChunked does not conform to layered caches yet key_cache = None value_cache = None def __init__( self, - config: PretrainedConfig, + model_config: PretrainedConfig, max_batch_size: int, max_cache_len: Optional[int] = None, device: Union[torch.device, str, None] = None, dtype: torch.dtype = torch.bfloat16, layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) + if not hasattr(model_config, "sliding_window") or model_config.sliding_window is None: + self.sliding_window = getattr(model_config.get_text_config(), "attention_chunk_size", 8192) else: - self.sliding_window = config.sliding_window + self.sliding_window = model_config.sliding_window self.max_cache_len = max_cache_len # Sliding layers can't be larger than the overall max cache len self.sliding_window = min(self.sliding_window, self.max_cache_len) self.max_batch_size = max_batch_size - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.head_dim = getattr(model_config, "head_dim", model_config.hidden_size // model_config.num_attention_heads) self._dtype = dtype # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] + if hasattr(model_config, "layer_types"): + self.is_sliding = [layer_type != "full_attention" for layer_type in model_config.layer_types] else: - self.is_sliding = [False] * config.num_hidden_layers + self.is_sliding = [False] * model_config.num_hidden_layers self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] - self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] + self.cumulative_length = [0 for _ in range(model_config.num_hidden_layers)] def initialise_cache_layer(self, layer_idx, key_states): if len(self.key_cache) > layer_idx: @@ -1648,12 +1583,6 @@ def reorder_cache(self, beam_idx: torch.LongTensor): self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
- """ if self.is_sliding[layer_idx]: query_length = cache_position.shape[0] first_cache_position = cache_position[0] @@ -1682,7 +1611,7 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ class OffloadedHybridCache(HybridChunkedCache): def __init__( self, - config: PretrainedConfig, + model_config: PretrainedConfig, max_batch_size: int, max_cache_len: Optional[int] = None, device: Union[torch.device, str, None] = None, @@ -1690,7 +1619,7 @@ def __init__( offload_device: Union[str, torch.device] = torch.device("cpu"), layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, ): - super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map) + super().__init__(model_config, max_batch_size, max_cache_len, device, dtype, layer_device_map) # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps # track of the original device of each layer @@ -1801,15 +1730,6 @@ class OffloadedStaticCache(StaticCache): This cache maintains the compilation-friendly properties of StaticCache while enabling much longer sequences by offloading inactive layers to CPU memory. - Parameters: - config (`PretrainedConfig`): Model configuration for shape/device info. - max_batch_size (`int`): Maximum batch size for static caches. - max_cache_len (`int`, *optional*): Maximum sequence length. - device (`torch.device` or `str`, *optional*): Device for cache tensors. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type for cache tensors. - offload_device (`Union`, *optional*, defaults to `"cpu"`): Device to offload cache tensors to. - layer_device_map (`dict[int, Union[str, torch.device, int]]`, *optional*): Per-layer device mapping. - Example: ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache @@ -1834,25 +1754,8 @@ class OffloadedStaticCache(StaticCache): ``` """ - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: Optional[torch.dtype] = None, - offload_device: Union[str, torch.device] = "cpu", - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - super().__init__( - config=config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - layer_device_map=layer_device_map, - processor=OffloadedCacheProcessor(offload_device), - ) + def __init__(self, *args, offload_device: Union[str, torch.device] = "cpu", **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor(offload_device), **kwargs) class OffloadedCacheProcessor(CacheProcessor): @@ -1885,8 +1788,8 @@ def init(self, cache: "Cache", **kwargs) -> None: if self.is_static: for i, layer in enumerate(cache.layers): device = cache.config.device if i == 0 else self.offload_device - layer.key_cache = layer.key_cache.to(device) - layer.value_cache = layer.value_cache.to(device) + layer.keys = layer.keys.to(device) + layer.values = layer.values.to(device) self.original_device.append(cache.config.device) if len(cache) != cache.config.num_layers: raise ValueError("If static layers are used, all cache layers must be initialized") @@ -1933,18 +1836,18 @@ def _prefetch_layer(self, cache: "Cache", layer_idx: int): ): # Prefetch next layer tensors to GPU device = self.original_device[layer_idx] - cache.key_cache[layer_idx] = cache.key_cache[layer_idx].to(device, 
non_blocking=True) - cache.value_cache[layer_idx] = cache.value_cache[layer_idx].to(device, non_blocking=True) + cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.to(device, non_blocking=True) + cache.layers[layer_idx].values = cache.layers[layer_idx].values.to(device, non_blocking=True) def _evict_previous_layer(self, cache: "Cache", layer_idx: int): """Moves the previous layer cache to the CPU.""" if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created # We do it on the default stream so it occurs after all earlier computations on these tensors are done prev_layer_idx = (layer_idx - 1) % len(cache) - cache.key_cache[prev_layer_idx] = cache.key_cache[prev_layer_idx].to( + cache.layers[prev_layer_idx].keys = cache.layers[prev_layer_idx].keys.to( self.offload_device, non_blocking=True ) - cache.value_cache[prev_layer_idx] = cache.value_cache[prev_layer_idx].to( + cache.layers[prev_layer_idx].values = cache.layers[prev_layer_idx].values.to( self.offload_device, non_blocking=True ) @@ -1957,8 +1860,8 @@ def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): # Handle delayed beam search operations if self.beam_idx is not None: self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) - cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, self.beam_idx) - cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, self.beam_idx) + cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.index_select(0, self.beam_idx) + cache.layers[layer_idx].values = cache.layers[layer_idx].values.index_select(0, self.beam_idx) class QuantizedCacheProcessor(CacheProcessor): @@ -1971,16 +1874,16 @@ class QuantizedCacheProcessor(CacheProcessor): def __init__(self, cache_config: QuantizedCacheConfig): self.config = cache_config - self._quantized_key_cache: list[torch.Tensor] = [] - self._quantized_value_cache: list[torch.Tensor] = [] - self._seen_tokens = 0 + self._quantized_keys: list[torch.Tensor] = [] + self._quantized_values: list[torch.Tensor] = [] def init(self, cache: "Cache", **kwargs) -> None: """Initialize the quantized processor and validate configuration.""" self.config.validate() + self.erased_length = 0 # Only compatible with DynamicCache - if not isinstance(cache, DynamicCache): + if not isinstance(cache.layers[0], DynamicLayer): raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache") def post_update( @@ -1992,29 +1895,25 @@ def post_update( cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Apply quantization after cache update.""" - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_tensors.shape[-2] - if len(cache.key_cache) < layer_idx: + if len(cache) < layer_idx: raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.") # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer - # On the first forward pass, we quantize the whole prompt. + # On the first forward pass, we quantize the whole prompt (prefill, quantize_length=0) # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full. 
- is_prefill = self._get_quantized_length(layer_idx) == 0 - - if is_prefill: - self._quantized_key_cache.append(self._quantize(key_tensors.contiguous(), axis=self.config.axis_key)) - self._quantized_value_cache.append(self._quantize(value_tensors.contiguous(), axis=self.config.axis_value)) + if self._is_quantized_length_zero(layer_idx): + self._quantized_keys.append(self._quantize(key_tensors.contiguous(), axis=self.config.axis_key)) + self._quantized_values.append(self._quantize(value_tensors.contiguous(), axis=self.config.axis_value)) # Clear the residual cache - cache.key_cache[layer_idx] = torch.zeros( + self.erased_length = key_tensors.shape[-2] + cache.layers[layer_idx].keys = torch.zeros( 0, dtype=key_tensors.dtype, device=key_tensors.device, ) - cache.value_cache[layer_idx] = torch.zeros( + cache.layers[layer_idx].values = torch.zeros( 0, dtype=value_tensors.dtype, device=value_tensors.device, @@ -2024,26 +1923,27 @@ def post_update( else: # Prepend the previously quantized cache - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + dequant_key = self._dequantize(self._quantized_keys[layer_idx]) + dequant_value = self._dequantize(self._quantized_values[layer_idx]) keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2) values_to_return = torch.cat([dequant_value, value_tensors], dim=-2) if key_tensors.shape[-2] >= self.config.residual_length: # Quantize and store - self._quantized_key_cache[layer_idx] = self._quantize( + self._quantized_keys[layer_idx] = self._quantize( keys_to_return.contiguous(), axis=self.config.axis_key ) - self._quantized_value_cache[layer_idx] = self._quantize( + self._quantized_values[layer_idx] = self._quantize( values_to_return.contiguous(), axis=self.config.axis_value ) # Clear the residual cache - cache.key_cache[layer_idx] = torch.zeros( + self.erased_length += key_tensors.shape[-2] + cache.layers[layer_idx].keys = torch.zeros( 0, dtype=key_tensors.dtype, device=key_tensors.device, ) - cache.value_cache[layer_idx] = torch.zeros( + cache.layers[layer_idx].values = torch.zeros( 0, dtype=value_tensors.dtype, device=value_tensors.device, @@ -2059,6 +1959,10 @@ def _dequantize(self, tensor: torch.Tensor) -> torch.Tensor: """Dequantize a tensor - to be implemented by specific quantization backends.""" raise NotImplementedError("Quantization backend must implement _dequantize method") + def _is_quantized_length_zero(self, layer_idx: int) -> bool: + """Check if quantized cache is empty for layer. 
Note: shape[-2] is unreliable since quantized tensors are bit-packed and flattened.""" + return layer_idx >= len(self._quantized_keys) + class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor): """ @@ -2107,12 +2011,6 @@ def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: """Dequantize tensor using quanto backend.""" return qtensor.dequantize() - def _get_quantized_length(self, layer_idx: int) -> int: - """Get the length of quantized cache for a layer.""" - if layer_idx < len(self._quantized_key_cache): - return self._quantized_key_cache[layer_idx].shape[-2] - return 0 - class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): """ @@ -2163,12 +2061,6 @@ def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tens tensor = self.quantizer.dequantize(quant_tensor, meta) return tensor - def _get_quantized_length(self, layer_idx: int) -> int: - """Get the length of quantized cache for a layer.""" - if layer_idx < len(self._quantized_key_cache): - return self._quantized_key_cache[layer_idx][0].shape[-2] - return 0 - class QuantizedCache(DynamicCache): """ @@ -2192,16 +2084,7 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None: else: raise ValueError(f"Unknown quantization backend `{cache_config.backend}`") - super().__init__(processor=processor) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is - # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to verify attn_weight shape in some models - return self.processor._seen_tokens if layer_idx == 0 else self.processor._seen_tokens - 1 + super().__init__(cache_processor=processor) class QuantoQuantizedCache(QuantizedCache): @@ -2244,7 +2127,7 @@ class QuantoQuantizedCache(QuantizedCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - Cache.__init__(self, processor=QuantoQuantizedCacheProcessor(cache_config)) + Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor(cache_config)) class HQQQuantizedCache(QuantizedCache): @@ -2287,7 +2170,7 @@ class HQQQuantizedCache(QuantizedCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - Cache.__init__(self, processor=HQQQuantizedCacheProcessor(cache_config)) + Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor(cache_config)) class SinkCache(Cache): @@ -2422,3 +2305,10 @@ def reset(self): return MambaCache raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +LAYER_CLASS_MAP = { + "full_attention": StaticLayer, + "sliding_attention": SlidingWindowLayer, + # "chunked_attention": ChunkedLayer, +} diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 242b16195460..53095c19121a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1951,7 +1951,7 @@ def _get_cache( layer_device_map = self._get_layer_device_map_for_cache_init() cache_kwargs = { - "config": self.config.get_text_config(), + "model_config": self.config.get_text_config(), "max_batch_size": batch_size, "max_cache_len": max_cache_len, "dtype": cache_dtype, diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 
0df283c83b71..36aab8699a81 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -275,15 +275,15 @@ def __init__(self, model: PreTrainedModel): self.model = model self.static_cache = StaticCache( - config=self.model.config, + model_config=self.model.config, max_batch_size=self.model.generation_config.cache_config.batch_size, max_cache_len=self.model.generation_config.cache_config.max_cache_len, device=self.model.generation_config.cache_config.device, dtype=self.model.dtype, ) - for i in range(len(self.static_cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + for i in range(len(self.static_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): """ @@ -404,7 +404,7 @@ def __init__( # Initialize the HybridCache self.cache = HybridCache( - config=self.model.config, + model_config=self.model.config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=self.model.device, @@ -412,9 +412,9 @@ def __init__( ) # Register all key and value cache tensors as buffers - for i in range(len(self.cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False) + for i in range(len(self.cache)): + self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False) def forward( self, @@ -550,7 +550,7 @@ def __init__(self, model, max_static_cache_length, batch_size): # Initialize static cache self.static_cache = StaticCache( - config=self.config, + model_config=self.config, max_batch_size=batch_size, max_cache_len=max_static_cache_length, device="cpu", @@ -558,9 +558,9 @@ def __init__(self, model, max_static_cache_length, batch_size): ) # Register cache buffers to make them exportable - for i in range(len(self.static_cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + for i in range(len(self.static_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): # Get outputs from decoder diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 128abd56ffac..e06056d7c0be 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -692,10 +692,10 @@ def create_causal_mask( useful to easily overlay another mask on top of the causal one, for example for image tokens handling. 
""" # If we have an HybridCache structure, here we want to create the mask for the full layers - if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding: - layer_idx = past_key_values.is_sliding.index(False) - else: - layer_idx = 0 + is_sliding = [] + if past_key_values is not None: + is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] + layer_idx = is_sliding.index(True) if True in is_sliding else 0 early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx @@ -774,10 +774,10 @@ def create_sliding_window_causal_mask( useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the sliding layers - if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding: - layer_idx = past_key_values.is_sliding.index(True) - else: - layer_idx = 0 + is_sliding = [] + if past_key_values is not None: + is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] + layer_idx = is_sliding.index(True) if True in is_sliding else 0 early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx @@ -861,10 +861,10 @@ def create_chunked_causal_mask( useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the sliding layers - if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding: - layer_idx = past_key_values.is_sliding.index(True) - else: - layer_idx = 0 + is_sliding = [] + if past_key_values is not None: + is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] + layer_idx = is_sliding.index(True) if True in is_sliding else 0 early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 994bf9d85dca..57a1eff22c65 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -230,8 +230,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 465b94e13bee..bb4660ae0998 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1293,8 +1293,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, 
cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 8a0c43eafd3f..87cd60b2ca66 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -207,8 +207,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 7821e1c7b4fb..974a3fb7e9ec 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 550e51221929..3f08d6804f9c 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -213,8 +213,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 19cac3e8c3ac..12677705002c 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -356,8 +356,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] - value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx] + 
key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index fe437fde84ed..dfe345968da6 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -182,8 +182,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] - value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx] + key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 0817e16451ac..57c8c4900e07 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1332,8 +1332,8 @@ def forward( else: indices = cache_position - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices] + value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices] else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index a3ffa710d842..7b8bcc6d37ec 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1774,8 +1774,8 @@ def forward( else: indices = cache_position - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices] + value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices] else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a8504db42c82..0526a067b020 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -728,8 +728,9 @@ def forward( # Ensure layer_past is on same device as hidden_states (might not be correct) if past_key_values is not None: - past_key_values.key_cache = past_key_values.key_cache.to(hidden_states.device) - past_key_values.value_cache = past_key_values.value_cache.to(hidden_states.device) + for layer in past_key_values.layers: + layer.keys = layer.keys.to(hidden_states.device) + layer.values = 
layer.values.to(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states if causal_mask is not None: diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 9718e8fb736e..c988a874c20f 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -484,8 +484,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) @@ -601,8 +601,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 3d46275bdc81..606627823514 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -290,8 +290,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 081869ec8fc5..1a15f11b3850 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -478,8 +478,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 7d5a73667ee4..05368453e9e3 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -294,8 +294,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states 
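The hunks above and below repeat the same mechanical change across the model files: per-layer tensors previously reached as `cache.key_cache[i]` / `cache.value_cache[i]` are now reached through the layer objects, `cache.layers[i].keys` / `cache.layers[i].values`. A short usage sketch, assuming the layered `DynamicCache` API introduced by this patch (tensor shapes are arbitrary, for illustration only):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
keys = torch.zeros(1, 2, 3, 8)    # (batch, kv_heads, seq_len, head_dim)
values = torch.zeros(1, 2, 3, 8)
cache.update(keys, values, layer_idx=0)

# Access path being replaced in these hunks: cache.key_cache[0], cache.value_cache[0]
# Access path after the refactor:
assert cache.layers[0].keys.shape == (1, 2, 3, 8)
assert cache.layers[0].values.shape == (1, 2, 3, 8)
assert len(cache) == 1  # number of layers holding cached states
```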
if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 7319671b485e..f199b3edfe19 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 2585d91a3e3e..26f7b53caa67 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -239,8 +239,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 66ed4adcea4c..e3e438c0714b 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -105,16 +105,14 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.layers[layer_idx].batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] 
+ self.layers[layer_idx].batch_select_indices(indices) def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 9b6fc12ae3de..0477a942a695 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -215,16 +215,14 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.layers[layer_idx].batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + self.layers[layer_idx].batch_select_indices(indices) def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index d33edcb3dd00..985a35448e99 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -496,8 +496,8 @@ def forward( ) elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], + past_key_value.layers[self.layer_idx].keys, + past_key_value.layers[self.layer_idx].values, ) else: raise ValueError( diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 2909fb386fb5..307bbe9ae90e 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -235,8 +235,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 500231f3b48b..df864b0c1ff2 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -331,8 +331,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) diff --git 
a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 5584b2ee8255..f467467e7fba 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -376,8 +376,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 2ffb53ee9e01..5ade9ee41d8a 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -228,8 +228,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 13f0ea27a6e4..396b49dfb6e4 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -249,8 +249,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 6b90ae80d7c7..a01f88b5443b 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -770,8 +770,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.key(current_states) value_states = self.value(current_states) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 327b70b5ec73..04384b8265a6 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ 
-425,8 +425,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5c4285afe728..13741a20ac18 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -320,8 +320,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index b0273c8a4a33..e5bdc624feb9 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -513,8 +513,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 2a1a84b81523..de7dbfa3e740 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -501,8 +501,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index feccf6d7d9fd..acf3bac94bcb 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -352,8 +352,8 @@ def forward( past_key_value.is_updated[self.layer_idx] = True # cross-attention: reuse cached states else: - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states 
= curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index ae69ae991009..522b60ebc83b 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -308,8 +308,8 @@ def forward( past_key_value.is_updated[self.layer_idx] = True # cross-attention: reuse cached states else: - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 778a0485b4e8..4c37cd42ef63 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -394,8 +394,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 8d4e368e945b..74bbce0259d8 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -599,8 +599,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 2b1f650c6789..bd2f3dd10ea6 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -285,8 +285,8 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) 
value_states = self.v(current_states) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 248d17cac404..c5ce00016e6a 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1140,8 +1140,8 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None for layer_idx in range(self.config.decoder_layers): layer_past_key_values = [] for cache_cls in [values.self_attention_cache, values.cross_attention_cache]: - for v in [cache_cls.key_cache, cache_cls.value_cache]: - layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu()) + for v in [cache_cls.layers[layer_idx].keys, cache_cls.layers[layer_idx].values]: + layer_past_key_values.append(v[batch_idx][None].cpu()) all_past_key_values.append(tuple(layer_past_key_values)) return tuple(all_past_key_values) else: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d3e9c8e03a2b..e5dc5d59e7f3 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -329,8 +329,8 @@ def forward( current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 746fd2179d20..533c72be199f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1595,9 +1595,7 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): # 3. Check cache shapes # 3.1. 
Encoder-Decoder checks if config.is_encoder_decoder: - num_cache_decoder_layers = ( - len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache) - ) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1605,30 +1603,30 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] + self_attention_layer_keys = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys ) - self_attention_layer_value_cache = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] + self_attention_layer_values = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values ) - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_key_cache = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + cross_attention_layer_keys = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys ) - cross_attention_layer_value_cache = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] + cross_attention_layer_values = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values ) - cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] - cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] - self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) + cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] + cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] + self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) # 3.2. 
Decoder-only checks else: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1636,10 +1634,10 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] - self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys + self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) @pytest.mark.generate @parameterized.expand([("greedy", 1), ("beam search", 2)]) @@ -1798,8 +1796,8 @@ def test_generate_from_inputs_embeds_with_static_cache(self): max_length = max_new_tokens + inputs_embeds.shape[1] - 1 cache_shape = [batch_size, num_key_value_heads, max_length, head_dim] self.assertIsInstance(outputs.past_key_values, StaticCache) - self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers) - self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape) + self.assertEqual(len(outputs.past_key_values), num_hidden_layers) + self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape) @pytest.mark.generate def test_generate_continue_from_past_key_values(self): @@ -2029,8 +2027,8 @@ def test_generate_with_static_cache(self): num_hidden_layers = text_config.num_hidden_layers cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) - self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape) + self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers) + self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape) # Check 2: The outputs must be similar to the case with dynamic cache dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) @@ -2612,12 +2610,12 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), + [layer.keys.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), + [layer.values.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) # Legacy cache format checks. 
This branch should be removed when all models use `Cache` by default @@ -3976,13 +3974,13 @@ def test_generate_with_static_cache_multi_accelerator(self): self.assertTrue(isinstance(results.past_key_values, StaticCache)) # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + keys_0 = results.past_key_values.layers[0].keys + values_0 = results.past_key_values.layers[0].values + self.assertTrue(keys_0.device == values_0.device == torch.device(0)) - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + keys_1 = results.past_key_values.layers[1].keys + values_1 = results.past_key_values.layers[1].values + self.assertTrue(keys_1.device == values_1.device == torch.device(1)) @pytest.mark.generate @require_torch_multi_accelerator @@ -4054,13 +4052,13 @@ def test_init_static_cache_multi_accelerator(self): results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + keys_0 = results.past_key_values.layers[0].keys + values_0 = results.past_key_values.layers[0].values + self.assertTrue(keys_0.device == values_0.device == torch.device(0)) - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + keys_1 = results.past_key_values.layers[1].keys + values_1 = results.past_key_values.layers[1].values + self.assertTrue(keys_1.device == values_1.device == torch.device(1)) @slow def test_padding_input_contrastive_search_gpt2(self): diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py index 6c0c3a19d067..87f7b2abb0e9 100644 --- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py +++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py @@ -440,13 +440,11 @@ def test_past_key_values_format(self): # difference: last dim k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim v_embed_dim = config.v_head_dim - self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) - self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) + self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) + self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) # build the full cache shapes num_hidden_layers = config.num_hidden_layers - all_cache_shapes = [ - [self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers) - ] + all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)] super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes) @require_torch_large_accelerator diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py index f9427160c254..34a7a9884728 100644 --- a/tests/models/dia/test_modeling_dia.py +++ b/tests/models/dia/test_modeling_dia.py @@ -399,12 +399,12 @@ 
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), + [layer.keys.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), + [layer.values.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) def _check_scores(self, batch_size, scores, generated_length, config): diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index b0a0a6a3ccb4..ecd2af9fdc6c 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -235,8 +235,8 @@ def copy_cache(cache: DynamicCache): """Deep copy a DynamicCache to reuse the same one multiple times.""" new_cache = cache for i in range(len(cache)): - new_cache.key_cache[i] = cache.key_cache[i].clone() - new_cache.value_cache[i] = cache.value_cache[i].clone() + new_cache.layers[i].keys = cache.layers[i].keys.clone() + new_cache.layers[i].values = cache.layers[i].values.clone() # Cached forward once with the attention mask provided and the other time without it (which should assume full attention) # We need to run both on a copy of the cache, otherwise it is modified in-place diff --git a/tests/models/t5gemma/test_modeling_t5gemma.py b/tests/models/t5gemma/test_modeling_t5gemma.py index fd61e5e5c5db..269fee53d165 100644 --- a/tests/models/t5gemma/test_modeling_t5gemma.py +++ b/tests/models/t5gemma/test_modeling_t5gemma.py @@ -272,7 +272,7 @@ def create_and_check_model( self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertIsNotNone(decoder_past) self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers) - self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers) + self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers) def check_prepare_lm_labels_via_shift_left( self, @@ -1069,9 +1069,7 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): # 3. Check cache shapes # 3.1. 
Encoder-Decoder checks if config.is_encoder_decoder: - num_cache_decoder_layers = ( - len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache) - ) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1079,30 +1077,30 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] + self_attention_layer_keys = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys ) - self_attention_layer_value_cache = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] + self_attention_layer_values = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values ) - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_key_cache = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + cross_attention_layer_keys = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys ) - cross_attention_layer_value_cache = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] + cross_attention_layer_values = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values ) - cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] - cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] - self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) + cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] + cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] + self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) # 3.2. 
Decoder-only checks else: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1110,10 +1108,10 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] - self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys + self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) @unittest.skip("Mismatch issue doesn't exist in T5Gemma.") def test_load_with_mismatched_shapes(self): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 8580316d26bd..d5c1463cf618 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -171,7 +171,9 @@ def _random_kvs(config): return random_keys, random_values mha_config = LlamaConfig(num_attention_heads=32) - mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) + mha_static_cache = StaticCache( + model_config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device + ) cached_keys, cached_values = mha_static_cache.update( *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -179,7 +181,9 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 32, 10, 128)) gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) - gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + gqa_static_cache = StaticCache( + model_config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device + ) cached_keys, cached_values = gqa_static_cache.update( *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -187,7 +191,9 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 4, 10, 128)) mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) - mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + mqa_static_cache = StaticCache( + model_config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device + ) cached_keys, cached_values = mqa_static_cache.update( *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -323,7 +329,7 @@ def test_quantized_cache_generation(self, backend): ) self.assertIsInstance(gen_out.past_key_values, QuantizedCache) - processor = gen_out.past_key_values.processor + processor = gen_out.past_key_values.cache_processor if backend == "quanto": self.assertIsInstance(processor, QuantoQuantizedCacheProcessor) elif backend == "hqq": @@ -332,12 +338,10 @@ def test_quantized_cache_generation(self, backend): decoded = 
self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) self.assertListEqual(decoded, expected_generation) - self.assertTrue(len(processor._quantized_key_cache) > 0) + self.assertTrue(len(processor._quantized_keys) > 0) # Check that something is actually quantized - has_been_quantized = any( - (q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_key_cache - ) + has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys) self.assertTrue(has_been_quantized) @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) @@ -654,7 +658,7 @@ def test_dynamic_cache_exportability(self): past_key_values=DynamicCache(), use_cache=True, ) - self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers) + self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers) self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs)) self.assertEqual( 3, @@ -675,11 +679,9 @@ def test_dynamic_cache_exportability(self): use_cache=True, ) self.assertTrue(torch.allclose(res.logits, res_eager.logits)) - for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache): - self.assertTrue(torch.allclose(k1, k2)) - - for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache): - self.assertTrue(torch.allclose(v1, v2)) + for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers): + self.assertTrue(torch.allclose(l1.keys, l2.keys)) + self.assertTrue(torch.allclose(l1.values, l2.values)) def test_dynamic_cache_exportability_multiple_run(self): # When exporting with DynamicCache, you should export two graphs: @@ -703,7 +705,7 @@ def test_dynamic_cache_exportability_multiple_run(self): past_key_values=DynamicCache(), use_cache=True, ) - self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers) + self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers) self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs)) self.assertEqual( 3, @@ -728,9 +730,9 @@ def test_dynamic_cache_exportability_multiple_run(self): shapes = torch.export.ShapesCollection() dyn = torch.export.Dim("seq", max=512) - for ix in range(len(past_key_values.key_cache)): - shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None) - shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None) + for ix in range(len(past_key_values)): + shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None) + shapes[past_key_values.layers[ix].values] = (None, None, dyn, None) ep_second = torch.export.export( model, @@ -771,11 +773,9 @@ def test_dynamic_cache_exportability_multiple_run(self): use_cache=True, ) - for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache): - self.assertTrue(torch.allclose(k1, k2)) - - for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache): - self.assertTrue(torch.allclose(v1, v2)) + for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers): + self.assertTrue(torch.allclose(l1.keys, l2.keys)) + self.assertTrue(torch.allclose(l1.values, l2.values)) def test_static_cache_exportability(self): """ @@ -922,7 +922,7 @@ def setUp(self): def test_static_cache_out_of_bounds(self): """Test StaticCache raises IndexError for out-of-bounds positions.""" - static_cache = StaticCache(config=self.config, max_batch_size=1, 
max_cache_len=self.max_cache_len) + static_cache = StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len with self.assertRaises(IndexError): @@ -944,7 +944,7 @@ def test_static_cache(self): update pos 3: [1.0, 2.0, 3.0, 4.0] """ # Scenario 1: Fill up to near capacity - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + static_cache = StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None) static_cache.update( @@ -954,7 +954,7 @@ def test_static_cache(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" + static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" ) # Scenario 2: Fill to capacity @@ -965,7 +965,7 @@ def test_static_cache(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" + static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" ) def test_sliding_window_cache(self): @@ -984,7 +984,9 @@ def test_sliding_window_cache(self): result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ # Scenario 1: Update within window, no slide yet - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache( + model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len + ) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -999,13 +1001,15 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed", ) # Scenario 2: Update causing slide - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache( + model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len + ) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -1020,13 +1024,15 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed", ) # Scenario 3: Long prompt handling - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache( + model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len + ) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] sliding_cache.update( key_states=long_prefill, @@ -1035,7 +1041,7 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": 
torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "SlidingWindowCache Scenario 3 failed", ) @@ -1054,7 +1060,7 @@ def test_hybrid_cache_static_mode(self): config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) # Scenario 1 - hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache_static_mode = HybridCache(model_config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] hybrid_cache_static_mode.update( key_states=prefill, @@ -1069,7 +1075,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Static Scenario 1 failed", ) @@ -1082,7 +1088,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "HybridCache Static Scenario 2 failed", ) @@ -1106,7 +1112,7 @@ def test_hybrid_cache_sliding_mode(self): result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ # Scenario 1: Update within window, no slide yet - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1121,13 +1127,13 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed", ) # Scenario 2: Update causing first slide - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1142,7 +1148,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed", ) @@ -1155,13 +1161,13 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed", ) # Scenario 4: Long prompt handling - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, 
None, :, None] hybrid_cache.update( key_states=long_prefill, @@ -1170,7 +1176,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 4 failed", ) @@ -1190,10 +1196,10 @@ def test_dynamic_cache(self): cache = DynamicCache() cache.update(prefill, prefill, 0) cache.update(update3, update3, 0) - self.assertEqual(cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") + self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") cache.update(update4, update4, 0) self.assertEqual( - cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" + cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" ) # Scenario 2: prefill and update for two layers independently @@ -1210,8 +1216,10 @@ def test_dynamic_cache(self): cache.update(update4, update4, 0) cache.update(update4_1, update4_1, 1) self.assertEqual( - cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" + cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" ) self.assertEqual( - cache.key_cache[1][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0], "DynamicCache Scenario 2 layer 1 failed" + cache.layers[1].keys[0, 0, :, 0].tolist(), + [10.0, 20.0, 30.0, 40.0], + "DynamicCache Scenario 2 layer 1 failed", ) diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index bc247b2b6011..eb101ab566aa 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -956,8 +956,9 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str): idx += 1 if "".join(source[start_idx:idx])[:-1] != old_doc_args: - # Args are not fully defined in the docstring of this object - return + raise ValueError( + f"Expected\n{old_doc_args}\nbut got\n{''.join(source[start_idx:idx])[:-1]}\n in {find_source_file(obj)}: {obj.__name__}" + ) obj_file = find_source_file(obj) with open(obj_file, "r", encoding="utf-8") as f: From aec9ccd6ea5be33cac992bc93eff7dd777cd49c5 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 2 Jul 2025 20:52:22 +0200 Subject: [PATCH 05/35] joao review: minor things --- src/transformers/cache_utils.py | 253 +++++++++--------- .../models/phimoe/modeling_phimoe.py | 4 - .../models/zamba/modeling_zamba.py | 3 + .../models/zamba2/modeling_zamba2.py | 3 + tests/generation/test_utils.py | 14 +- 5 files changed, 137 insertions(+), 140 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 9219365e76ad..b392bcd7fc6c 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -37,7 +37,7 @@ def init(self, cache: "Cache", **kwargs) -> None: cache (`Cache`): The cache instance this processor will be applied to. **kwargs: Additional arguments that may be needed for initialization. """ - raise NotImplementedError("Make sure to implement `init` in a subclass.") + raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") def pre_update( self, @@ -55,7 +55,7 @@ def pre_update( key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. 
layer_idx (`int`): The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. @@ -78,7 +78,7 @@ def post_update( key_states (`torch.Tensor`): The key states that were cached. value_states (`torch.Tensor`): The value states that were cached. layer_idx (`int`): The index of the layer that was updated. - cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return. @@ -86,24 +86,15 @@ def post_update( return key_tensors, value_tensors -class CacheLayer: +class CacheLayerMixin: """Base, abstract class for a single layer's cache.""" is_compileable = False - def __init__( - self, - config: Optional["CacheConfig"] = None, - ): - self.keys = None - self.values = None - - @classmethod - def from_kv(cls, keys: torch.Tensor, values: torch.Tensor) -> None: - cache = cls() - cache.keys = keys - cache.values = values - return cache + def __repr__(self): + key_repr = "None" if self.keys is None else f"t({tuple(self.keys.shape)})" + value_repr = "None" if self.values is None else f"t({tuple(self.values.shape)})" + return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})" def update( self, @@ -111,28 +102,16 @@ def update( value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Updates KV cache, returns updated K and V for this layer.""" - raise NotImplementedError("Make sure to implement `update` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache. - Early stops since first layer is enough to compute sequence length""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_cache_shape() - previous_seq_length, _ = self.get_seq_length(layer_idx) - if max_length != -1 and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length + """Updates KV cache, returns updated keys/values for this layer.""" + raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") def get_max_cache_shape(self) -> int: """Returns the maximum sequence length (i.e. 
max capacity) of this layer's cache.""" - raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") def reset(self) -> None: """Resets this layer's cache.""" - raise NotImplementedError("Make sure to implement `reset` in a subclass.") + raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.") def reorder_cache(self, beam_idx: torch.LongTensor) -> None: """Reorders this layer's cache for beam search.""" @@ -143,13 +122,11 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None: device = self.values.device self.values = self.values.index_select(0, beam_idx.to(device)) - def __repr__(self): - key_repr = "None" if self.keys is None else f"t({tuple(self.keys.shape)})" - value_repr = "None" if self.values is None else f"t({tuple(self.values.shape)})" - return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})" - class CacheBase: + + layers = None + def update( self, key_states: torch.Tensor, @@ -180,13 +157,11 @@ class Cache(CacheBase): - SlidingWindow layers are limited to sliding window size, Static layers use full max_cache_len """ - layers = [] - def __init__( self, model_config: Optional[PretrainedConfig] = None, cache_processor: Optional[CacheProcessor] = None, - layer_classes: Optional[list[type[CacheLayer]]] = None, + layer_classes: Optional[list[type[CacheLayerMixin]]] = None, *args, **kwargs, ): @@ -207,7 +182,7 @@ def __init__( - `dtype` (`torch.dtype`): Data type for cache tensors - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping """ - self.layers: list[CacheLayer] = [] + self.layers: list[CacheLayerMixin] = [] self.cache_processor = cache_processor if ( @@ -225,43 +200,6 @@ def __init__( if self.cache_processor is not None: self.cache_processor.init(self, **kwargs) - def append_new_layers(self, layer_idx): - """ - Appends layers to the cache until the layer `layer_idx` is reached. - Used in prefill and for skipped layers. - """ - while len(self.layers) <= layer_idx: - self.layers.append( - self.layer_classes[layer_idx % len(self.layer_classes)](self.config.for_layer(layer_idx)) - ) - - def _update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. - - Return: - A tuple containing the updated key and value states. - """ - self.append_new_layers(layer_idx) - return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """ Support for backwards-compatible `past_key_value` indexing, e.g. 
`past_key_value[0][0].shape[2]` to get the @@ -336,6 +274,46 @@ def propagate_to_layers(*args, **kwargs): else: raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") + def __repr__(self): + return f"{self.__class__.__name__}(layers={self.layers})" + + def append_new_layers(self, layer_idx): + """ + Appends layers to the cache until the layer `layer_idx` is reached. + Used in prefill and for skipped layers. + """ + while len(self.layers) <= layer_idx: + self.layers.append( + self.layer_classes[layer_idx % len(self.layer_classes)](self.config.for_layer(layer_idx)) + ) + + def _update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + self.append_new_layers(layer_idx) + return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + def get_seq_length(self, layer_idx: int = 0) -> int: """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" if layer_idx >= len(self.layers): @@ -354,42 +332,25 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ """ return self.layers[layer_idx].get_mask_sizes(cache_position) - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: - """Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer in self.layers: - if layer is not None: - legacy_cache += ((layer.keys, layer.values),) - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None - ) -> "Cache": - """Converts a cache in the legacy cache format into an equivalent `Cache`. 
Used for - backward compatibility.""" - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def __repr__(self): - return f"{self.__class__.__name__}(layers={self.layers})" - @dataclass class CacheConfig: - """ - Base class for cache configs - """ + """Base class for cache configs""" def __init__(self, num_layers: Optional[int] = None, cache_implementation: Optional[str] = None): self.num_layers = num_layers self.cache_implementation = cache_implementation + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + @classmethod def from_model_config( cls, @@ -497,16 +458,6 @@ def to_dict(self) -> dict[str, Any]: """ return copy.deepcopy(self.__dict__) - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ - def __iter__(self): - """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" - for attr, value in copy.deepcopy(self.__dict__).items(): - yield attr, value - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - def to_json_string(self): """ Serializes this instance to a JSON formatted string. @@ -640,9 +591,7 @@ def validate(self): @dataclass class StaticCacheConfig(CacheConfig): - """ - Configuration class for static and sliding window cache settings. - """ + """Configuration class for static and sliding window cache settings.""" batch_size: Optional[int] = None max_cache_len: Optional[int] = None @@ -669,9 +618,7 @@ def __post_init__(self): logger.warning_once("`dtype` not set in cache initialization, using default `float32`") def for_layer(self, layer_idx: int): - """ - Returns a StaticCacheConfig for a given layer index. - """ + """Returns a StaticCacheConfig for a given layer index.""" device = self.layer_device_map[layer_idx] if self.layer_device_map is not None else self.device return StaticCacheConfig( self.batch_size, @@ -724,11 +671,19 @@ def validate(self): ) -class DynamicLayer(CacheLayer): +class DynamicLayer(CacheLayerMixin): """ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. """ + keys, values = None, None + + @classmethod + def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> None: + cache = cls() + cache.keys = keys + cache.values = values + return cache def update( self, @@ -744,7 +699,7 @@ def update( The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. - cache_kwargs (`dict[str, Any]`, `optional`): + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`. 
Return: @@ -782,8 +737,10 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None: self.values = self.values.index_select(0, beam_idx.to(self.values.device)) def crop(self, max_length: int) -> None: - """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens.""" + """ + Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. + """ if max_length < 0: max_length = self.get_seq_length() - abs(max_length) @@ -849,9 +806,35 @@ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.T # compatibility. The name of the argument doesn't matter. if ddp_cache_data is not None: for key_states, value_states in ddp_cache_data: - self.layers.append(DynamicLayer.from_kv(key_states, value_states)) + self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) super().__init__(*args, **kwargs) + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: + """ + Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility. + """ + legacy_cache = () + for layer in self.layers: + if layer is not None: + legacy_cache += ((layer.keys, layer.values),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None + ) -> "Cache": + """ + Converts a cache in the legacy cache format into an equivalent `Cache`. Used for + backward compatibility. + """ + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + # Utilities for `DynamicCache` <> torch.export support def _flatten_dynamic_cache( @@ -935,7 +918,7 @@ def __init__(self, model_config: Optional[PretrainedConfig] = None) -> None: super().__init__(cache_processor=OffloadedCacheProcessor(), model_config=model_config) -class StaticLayer(CacheLayer): +class StaticLayer(CacheLayerMixin): is_compileable = True is_sliding = False @@ -1304,14 +1287,18 @@ def check_dynamic_cache(self, method: str): # TODO(gante, sanchit-gandhi): move following functionality into `.generate` def crop(self, maximum_length: int): - """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be - negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + """ + Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search. + """ self.check_dynamic_cache(self.crop.__name__) self.self_attention_cache.crop(maximum_length) def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" + """ + Split the current instance into a list of `DynamicCache` by the batch size. 
This will be used by + `_split_model_inputs()` in `generation.utils` + """ self.check_dynamic_cache(self.batch_split.__name__) self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index af4e823a98d8..51902e059cda 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -350,10 +350,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index b64d0107395e..96b7eed4777c 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -143,6 +143,9 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + def __len__(self): + return len(self.key_cache) + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update def update( self, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 1c7a489784f3..0ef3bd37b161 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -147,6 +147,9 @@ def __init__( self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + def __len__(self): + return len(self.key_cache) + def update( self, key_states: torch.Tensor, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 533c72be199f..cee456aa8f2c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1626,7 +1626,7 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): # 3.2. 
Decoder-only checks else: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv) + num_cache_decoder_layers = len(past_kv) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1634,8 +1634,16 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys - self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values + if is_legacy_cache: + self_attention_layer_keys = past_kv[i][0] + self_attention_layer_values = past_kv[i][1] + elif past_kv.layers is None: + # Cache is lot layered (i.e, Mamba derivatives) + self_attention_layer_keys = past_kv.key_cache[i] + self_attention_layer_values = past_kv.value_cache[i] + else: + self_attention_layer_keys = past_kv.layers[i].keys + self_attention_layer_values = past_kv.layers[i].values self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) From e80c68a61181d9ca0fc550847afc464de90a9d55 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 4 Jul 2025 18:59:04 +0200 Subject: [PATCH 06/35] remove cache configs, make CacheLayer a mixin (joaos review) --- docs/source/en/kv_cache.md | 7 +- docs/source/ko/internal/generation_utils.md | 6 - src/transformers/cache_utils.py | 906 ++++++++++-------- .../generation/configuration_utils.py | 28 +- src/transformers/generation/utils.py | 11 +- src/transformers/integrations/executorch.py | 6 +- src/transformers/masking_utils.py | 4 +- .../models/zamba/modeling_zamba.py | 3 + .../models/zamba2/modeling_zamba2.py | 3 + .../falcon_h1/test_modeling_falcon_h1.py | 39 +- tests/utils/test_cache_utils.py | 4 +- 11 files changed, 542 insertions(+), 475 deletions(-) diff --git a/docs/source/en/kv_cache.md b/docs/source/en/kv_cache.md index 14a0d4901d70..8139b08f5d1d 100644 --- a/docs/source/en/kv_cache.md +++ b/docs/source/en/kv_cache.md @@ -134,7 +134,7 @@ The [`QuantizedCache`] reduces memory requirements by quantizing the KV values t > [!WARNING] > Quantizing the cache can harm latency if the context length is short and there is enough GPU memory available for generation without enabling cache quantization. Try to find a balance between memory efficiency and latency. -Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`], and indicate the quantization backend in [`QuantizedCacheConfig`]. Any additional quantization related parameters should also be passed either as a dict or an instance of [`QuantizedCacheConfig`]. You should use the default values for these additional parameters unless you're running out-of-memory. In that case, consider decreasing the residual length. +Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`], and the quantization backend, as well as any additional quantization related parameters should also be passed either as a dict. You should use the default values for these additional parameters unless you're running out-of-memory. In that case, consider decreasing the residual length. @@ -142,7 +142,7 @@ Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [ For [`HQQQuantizedCache`], we recommend setting the `axis-key` and `axis-value` parameters to `1`. 
```py -from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig +from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0") @@ -159,7 +159,7 @@ I like rock music because it's loud and energetic. It's a great way to express m For [`QuantoQuantizedCache`], we recommend setting the `axis-key` and `axis-value` parameters to `0`. ```py -from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig +from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0") @@ -273,7 +273,6 @@ from transformers.cache_utils import ( StaticCache, SlidingWindowCache, QuantoQuantizedCache, - QuantizedCacheConfig, ) model_id = "meta-llama/Llama-2-7b-chat-hf" diff --git a/docs/source/ko/internal/generation_utils.md b/docs/source/ko/internal/generation_utils.md index e4841f0c626a..1a08a79368d3 100644 --- a/docs/source/ko/internal/generation_utils.md +++ b/docs/source/ko/internal/generation_utils.md @@ -345,12 +345,6 @@ generation_output[:2] [[autodoc]] Cache - update -[[autodoc]] CacheConfig - - update - -[[autodoc]] QuantizedCacheConfig - - validate - [[autodoc]] DynamicCache - update - get_seq_length diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index b392bcd7fc6c..08f1a70550f7 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -105,6 +105,10 @@ def update( """Updates KV cache, returns updated keys/values for this layer.""" raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") + def get_seq_length(self) -> int: + """Returns the sequence length of this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") + def get_max_cache_shape(self) -> int: """Returns the maximum sequence length (i.e. 
max capacity) of this layer's cache.""" raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") @@ -123,8 +127,63 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None: self.values = self.values.index_select(0, beam_idx.to(device)) +def parse_layer_args_from_model_config( + model_config: Optional[PretrainedConfig], + batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: Optional[torch.dtype] = None, + layer_device_map=None, + max_batch_size: Optional[int] = None, +) -> dict: + # No model config -> must be a dynamic cache, return bare dict + if model_config is None: + return {} + # Build the args dict for hybrid, sliding or static + else: + # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) + if ( + getattr(model_config, "layer_types", None) is not None + and "sliding_attention" in model_config.layer_types + and "full_attention" in model_config.layer_types + ): + if getattr(model_config, "sliding_window", None) is None: + raise ValueError( + "Setting up a hybrid or sliding window KVCache requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) + max_cache_len = max_cache_len or model_config.max_position_embeddings + if getattr(model_config, "sliding_window", None) is not None: + sliding_window_len = min(model_config.sliding_window, max_cache_len) + else: + sliding_window_len = None + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: + head_dim = ( + model_config.head_dim + if getattr(model_config, "head_dim", None) is not None + else model_config.hidden_size // model_config.num_attention_heads + ) + num_heads = ( + model_config.num_attention_heads + if getattr(model_config, "num_key_value_heads", None) is None + else model_config.num_key_value_heads + ) + layer_args = { + "batch_size": max_batch_size if max_batch_size is not None else batch_size, + "max_cache_len": max_cache_len, + "device": torch.device(device) if device is not None else None, + "dtype": dtype, + "layer_device_map": layer_device_map, + "head_dim": head_dim, + "num_heads": num_heads, + "sliding_window": sliding_window_len, + } + return {k: v for k, v in layer_args.items() if v is not None} + + class CacheBase: - layers = None def update( @@ -193,9 +252,10 @@ def __init__( layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in model_config.layer_types] self.layer_classes = layer_classes or [DynamicLayer] - self.config = CacheConfig.from_model_config(model_config, *args, **kwargs) + self.layer_args = parse_layer_args_from_model_config(model_config, *args, **kwargs) + self.model_num_layers = getattr(model_config, "num_hidden_layers", 1) - self.append_new_layers(self.config.num_layers - 1) + self.append_new_layers(self.model_num_layers - 1) if self.cache_processor is not None: self.cache_processor.init(self, **kwargs) @@ -283,9 +343,10 @@ def append_new_layers(self, layer_idx): Used in prefill and for skipped layers. 
""" while len(self.layers) <= layer_idx: - self.layers.append( - self.layer_classes[layer_idx % len(self.layer_classes)](self.config.for_layer(layer_idx)) - ) + args = self.layer_args.copy() + if self.layer_args.get("layer_device_map", None) is not None: + args["device"] = args.pop("layer_device_map")[layer_idx] + self.layers.append(self.layer_classes[layer_idx % len(self.layer_classes)](**args)) def _update( self, @@ -333,353 +394,16 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ return self.layers[layer_idx].get_mask_sizes(cache_position) -@dataclass -class CacheConfig: - """Base class for cache configs""" - - def __init__(self, num_layers: Optional[int] = None, cache_implementation: Optional[str] = None): - self.num_layers = num_layers - self.cache_implementation = cache_implementation - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ - def __iter__(self): - """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" - for attr, value in copy.deepcopy(self.__dict__).items(): - yield attr, value - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - @classmethod - def from_model_config( - cls, - model_config: Optional[PretrainedConfig], - batch_size: Optional[int] = None, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: Optional[torch.dtype] = None, - layer_device_map=None, - max_batch_size: Optional[int] = None, - ) -> "CacheConfig": - # No model config -> must be a dynamic cache, return bare CacheConfig - if model_config is None: - return cls(num_layers=getattr(model_config, "num_hidden_layers", 1)) - # Build a StaticCacheConfig for hybrid, sliding or static - else: - # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) - if ( - getattr(model_config, "layer_types", None) is not None - and "sliding_attention" in model_config.layer_types - and "full_attention" in model_config.layer_types - ): - if getattr(model_config, "sliding_window", None) is None: - raise ValueError( - "Setting up a hybrid or sliding window KVCache requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) - # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) - max_cache_len = max_cache_len or model_config.max_position_embeddings - sliding_window_len = min( - getattr(model_config, "sliding_window", max_cache_len) or max_cache_len, max_cache_len - ) - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: - head_dim = ( - model_config.head_dim - if getattr(model_config, "head_dim", None) is not None - else model_config.hidden_size // model_config.num_attention_heads - ) - num_heads = ( - model_config.num_attention_heads - if getattr(model_config, "num_key_value_heads", None) is None - else model_config.num_key_value_heads - ) - cache_config = StaticCacheConfig( - batch_size=max_batch_size if max_batch_size is not None else batch_size, - max_cache_len=max_cache_len, - device=torch.device(device) if device is not None else None, - dtype=dtype, - layer_device_map=layer_device_map, - head_dim=head_dim, - num_heads=num_heads, - sliding_window=sliding_window_len, - num_layers=model_config.num_hidden_layers, - ) - return cache_config - - @classmethod - def from_dict(cls, config_dict, **kwargs): - """ - Constructs a CacheConfig instance from a dictionary of parameters. - Args: - config_dict (dict[str, Any]): Dictionary containing configuration parameters. - **kwargs: Additional keyword arguments to override dictionary values. - - Returns: - CacheConfig: Instance of CacheConfig constructed from the dictionary. - """ - config = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - return config - - def for_layer(self, layer_idx: int) -> "CacheConfig": - return self - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default - `QuantizationConfig()` is serialized to JSON file. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - writer.write(json_string) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict - def to_dict(self) -> dict[str, Any]: - """ - Serializes this instance to a Python dictionary. Returns: - `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. - """ - return copy.deepcopy(self.__dict__) - - def to_json_string(self): - """ - Serializes this instance to a JSON formatted string. - Returns: - str: JSON formatted string representing the configuration instance. - """ - return json.dumps(self.__dict__, indent=2) + "\n" - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update - def update(self, **kwargs): - """ - Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, - returning all the unused kwargs. - - Args: - kwargs (`dict[str, Any]`): - Dictionary of attributes to tentatively update this class. 
- - Returns: - `dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. - """ - to_remove = [] - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - to_remove.append(key) - - # Remove all the attributes that were updated, without modifying the input dict - unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} - return unused_kwargs - - -@dataclass -class QuantizedCacheConfig(CacheConfig): - """ - Configuration class for quantized cache settings. - - Attributes: - backend (`str`, *optional*, defaults to `"quanto"`): - Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] - nbits (`Optional[int]`, *optional*, defaults to 4): - Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. - axis_key (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - axis_value (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - q_group_size (`Optional[int]`, *optional*, defaults to 64): - Size of the quantization group, should be a divisor of the model's hidden dimension. - Defaults to 64. - residual_length (`Optional[int]`, *optional*, defaults to 128): - Length of the residual cache which will always be stored in original precision. - Defaults to 128. - compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. - device (`str`, *optional*, defaults to `"cpu"`): - Device on which to perform computations, should be same as the model's device. - """ - - def __init__( - self, - backend: str = "quanto", - nbits: Optional[int] = 4, - axis_key: Optional[int] = 0, - axis_value: Optional[int] = 0, - q_group_size: Optional[int] = 64, - residual_length: Optional[int] = 128, - compute_dtype: Optional[torch.dtype] = torch.float16, - device: Optional[str] = "cpu", - ): - self.backend = backend - self.nbits = nbits - self.axis_key = axis_key - self.axis_value = axis_value - self.q_group_size = q_group_size - self.residual_length = residual_length - self.compute_dtype = compute_dtype - self.device = device - - def validate(self): - """Validates if the arguments passed are correct""" - - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. 
`{key}` should be {correct_value}` " - "but found {found_value}" - ) - # Check that the values are reasonable in general (nbits, axis) - # Later in QuantizedCache init we check if they are supported for that particular backend - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - incorrect_arg_msg.format( - key="nbits", - correct_value="2 or 4 or 8", - found_value=self.nbits, - ), - ) - if self.q_group_size <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="q_group_size", - correct_value="a positive integer", - found_value=self.q_group_size, - ), - ) - if self.residual_length < 0: - raise ValueError( - incorrect_arg_msg.format( - key="residual_length", - correct_value="a positive integer", - found_value=self.residual_length, - ), - ) - - if self.axis_key not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_key", - correct_value="`1` or `0`, `-1`", - found_value=self.axis_key, - ), - ) - - if self.axis_value not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_value", - correct_value="`1` or `0` or `-1`", - found_value=self.axis_value, - ), - ) - - -@dataclass -class StaticCacheConfig(CacheConfig): - """Configuration class for static and sliding window cache settings.""" - - batch_size: Optional[int] = None - max_cache_len: Optional[int] = None - device: Union[str, torch.device] = None - dtype: Optional[torch.dtype] = None - layer_device_map: Optional[dict[int, Union[str, torch.device]]] = None - head_dim: Optional[int] = None - num_heads: Optional[int] = None - sliding_window: Optional[int] = None - num_layers: Optional[int] = None - cache_implementation: Optional[str] = None - - def __post_init__(self): - self.cache_implementation = "static" - if self.batch_size is None: - raise ValueError("`batch_size` is required for static cache") - if self.max_cache_len is None: - raise ValueError("`max_cache_len` is required for static cache") - if self.device is None: - self.device = "cpu" - logger.warning_once("`device` not set in cache initialization, using default `cpu`") - if self.dtype is None: - self.dtype = torch.float32 - logger.warning_once("`dtype` not set in cache initialization, using default `float32`") - - def for_layer(self, layer_idx: int): - """Returns a StaticCacheConfig for a given layer index.""" - device = self.layer_device_map[layer_idx] if self.layer_device_map is not None else self.device - return StaticCacheConfig( - self.batch_size, - self.max_cache_len, - device, - self.dtype, - None, - self.head_dim, - self.num_heads, - self.sliding_window, - ) - - @property - def dtype(self): # noqa: F811 - return getattr(torch, self._dtype) if self._dtype is not None else None - - @dtype.setter - def dtype(self, value): - if isinstance(value, torch.dtype): - self._dtype = str(value).split(".")[-1] - elif isinstance(value, str): - self._dtype = value - else: - self._dtype = None - - def validate(self): - """Validates if the arguments passed are correct""" - - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. 
`{key}` should be {correct_value}` " - "but found {found_value}" - ) - - if self.batch_size <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="batch_size", - correct_value="> 0", - found_value=self.batch_size, - ), - ) - - if self.max_cache_len <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="max_cache_len", - correct_value="> 0", - found_value=self.max_cache_len, - ), - ) - - class DynamicLayer(CacheLayerMixin): """ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. """ + keys, values = None, None @classmethod - def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> None: + def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": cache = cls() cache.keys = keys cache.values = values @@ -809,7 +533,7 @@ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.T self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) super().__init__(*args, **kwargs) - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: """ Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for backward compatibility. @@ -924,21 +648,30 @@ class StaticLayer(CacheLayerMixin): def __init__( self, - config: StaticCacheConfig, - max_len: Optional[int] = None, + max_cache_len: int, + batch_size: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype = torch.float32, + device: str = "cpu", + sliding_window: Optional[int] = None, ): - self.max_cache_len = max_len or config.max_cache_len - self.max_batch_size = config.batch_size + self.max_cache_len = max_cache_len + self.max_batch_size = batch_size + self.num_heads = num_heads + self.head_dim = head_dim + self.dtype = dtype + self.device = device # Note: There will be significant perf decrease if switching to use 5D tensors instead. self.keys = torch.zeros( - (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), - dtype=config.dtype, - device=config.device, + (batch_size, num_heads, self.max_cache_len, head_dim), + dtype=dtype, + device=device, ) self.values = torch.zeros( - (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), - dtype=config.dtype, - device=config.device, + (batch_size, num_heads, self.max_cache_len, head_dim), + dtype=dtype, + device=device, ) # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, # preventing compiled graph breaks when updating the cache. @@ -1048,8 +781,8 @@ class SlidingWindowLayer(StaticLayer): Inherits from StaticLayer but uses sliding window update logic. """ - def __init__(self, config: CacheConfig): - super().__init__(config, max_len=config.sliding_window) + def __init__(self, sliding_window, max_cache_len=None, *args, **kwargs): + super().__init__(*args, max_cache_len=sliding_window, *args, **kwargs) def _static_update( self, @@ -1124,13 +857,13 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: class SlidingWindowCache(Cache): """ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. 
- Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, + Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window - 1`, if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: - indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window + indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, @@ -1774,11 +1507,11 @@ def init(self, cache: "Cache", **kwargs) -> None: self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) if self.is_static: for i, layer in enumerate(cache.layers): - device = cache.config.device if i == 0 else self.offload_device + device = cache.layer_args["device"] if i == 0 else self.offload_device layer.keys = layer.keys.to(device) layer.values = layer.values.to(device) - self.original_device.append(cache.config.device) - if len(cache) != cache.config.num_layers: + self.original_device.append(cache.layer_args["device"]) + if len(cache) != cache.model_num_layers: raise ValueError("If static layers are used, all cache layers must be initialized") self.prefetch_stream = ( @@ -1859,20 +1592,109 @@ class QuantizedCacheProcessor(CacheProcessor): length in original precision and quantizing older tokens. """ - def __init__(self, cache_config: QuantizedCacheConfig): - self.config = cache_config - self._quantized_keys: list[torch.Tensor] = [] - self._quantized_values: list[torch.Tensor] = [] - - def init(self, cache: "Cache", **kwargs) -> None: - """Initialize the quantized processor and validate configuration.""" - self.config.validate() + def init( + self, + cache: "Cache", + backend: str = "quanto", + nbits: Optional[int] = 4, + axis_key: Optional[int] = 0, + axis_value: Optional[int] = 0, + q_group_size: Optional[int] = 64, + residual_length: Optional[int] = 128, + compute_dtype: Optional[torch.dtype] = torch.float16, + device: Optional[str] = "cpu", + ): + """ + Parameters: + backend (`str`, *optional*, defaults to `"quanto"`): + Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] + nbits (`Optional[int]`, *optional*, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. + axis_key (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + axis_value (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + q_group_size (`Optional[int]`, *optional*, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`Optional[int]`, *optional*, defaults to 128): + Length of the residual cache which will always be stored in original precision. + Defaults to 128. 
+ compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + device (`str`, *optional*, defaults to `"cpu"`): + Device on which to perform computations, should be same as the model's device. + """ + self.backend = backend + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.compute_dtype = compute_dtype + self.device = device + self._quantized_keys: list[torch.Tensor] = [] + self._quantized_values: list[torch.Tensor] = [] + + self.validate() self.erased_length = 0 # Only compatible with DynamicCache if not isinstance(cache.layers[0], DynamicLayer): raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache") + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( + incorrect_arg_msg.format( + key="residual_length", + correct_value="a positive integer", + found_value=self.residual_length, + ), + ) + + if self.axis_key not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_key", + correct_value="`1` or `0`, `-1`", + found_value=self.axis_key, + ), + ) + + if self.axis_value not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_value", + correct_value="`1` or `0` or `-1`", + found_value=self.axis_value, + ), + ) + def post_update( self, cache: "Cache", @@ -1890,8 +1712,8 @@ def post_update( # On the first forward pass, we quantize the whole prompt (prefill, quantize_length=0) # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full. 
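The two comments above summarize the flow that the rest of this hunk implements below: new tokens accumulate in a full-precision residual, and once `residual_length` is reached everything seen so far is re-quantized and the residual is cleared. A standalone sketch of that bookkeeping, with a toy int8 round-trip standing in for the quanto/HQQ quantizers (the real processor also quantizes the whole prefill immediately; names and shapes here are illustrative):

```python
import torch

def toy_quantize(x):
    scale = x.abs().amax().clamp(min=1e-8) / 127
    return (x / scale).round().clamp(-127, 127).to(torch.int8), scale

def toy_dequantize(q, scale):
    return q.float() * scale

residual_length = 4
quantized = None                      # older tokens, stored quantized
residual = torch.empty(1, 1, 0, 8)    # recent tokens, kept in full precision

def append(new_kv):
    global quantized, residual
    residual = torch.cat([residual, new_kv], dim=-2)
    history = toy_dequantize(*quantized) if quantized is not None else torch.empty(1, 1, 0, 8)
    full = torch.cat([history, residual], dim=-2)   # what attention would see this step
    if residual.shape[-2] >= residual_length:
        quantized = toy_quantize(full)              # fold the residual into the quantized store
        residual = torch.empty(1, 1, 0, 8)          # and clear it
    return full

for _ in range(6):
    kv_seen_by_attention = append(torch.randn(1, 1, 1, 8))
```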
if self._is_quantized_length_zero(layer_idx): - self._quantized_keys.append(self._quantize(key_tensors.contiguous(), axis=self.config.axis_key)) - self._quantized_values.append(self._quantize(value_tensors.contiguous(), axis=self.config.axis_value)) + self._quantized_keys.append(self._quantize(key_tensors.contiguous(), axis=self.axis_key)) + self._quantized_values.append(self._quantize(value_tensors.contiguous(), axis=self.axis_value)) # Clear the residual cache self.erased_length = key_tensors.shape[-2] @@ -1914,14 +1736,10 @@ def post_update( dequant_value = self._dequantize(self._quantized_values[layer_idx]) keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2) values_to_return = torch.cat([dequant_value, value_tensors], dim=-2) - if key_tensors.shape[-2] >= self.config.residual_length: + if key_tensors.shape[-2] >= self.residual_length: # Quantize and store - self._quantized_keys[layer_idx] = self._quantize( - keys_to_return.contiguous(), axis=self.config.axis_key - ) - self._quantized_values[layer_idx] = self._quantize( - values_to_return.contiguous(), axis=self.config.axis_value - ) + self._quantized_keys[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) + self._quantized_values[layer_idx] = self._quantize(values_to_return.contiguous(), axis=self.axis_value) # Clear the residual cache self.erased_length += key_tensors.shape[-2] @@ -1969,20 +1787,18 @@ def init(self, cache: "Cache", **kwargs) -> None: ) from optimum.quanto import MaxOptimizer, qint2, qint4 - if self.config.nbits not in [2, 4]: - raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.config.nbits}") + if self.nbits not in [2, 4]: + raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") - if self.config.axis_key not in [0, -1]: - raise ValueError( - f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.config.axis_key}" - ) + if self.axis_key not in [0, -1]: + raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") - if self.config.axis_value not in [0, -1]: + if self.axis_value not in [0, -1]: raise ValueError( - f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.config.axis_value}" + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" ) - self.qtype = qint4 if self.config.nbits == 4 else qint2 + self.qtype = qint4 if self.nbits == 4 else qint2 self.optimizer = MaxOptimizer() def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: @@ -1990,8 +1806,8 @@ def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: if is_optimum_quanto_available(): from optimum.quanto import quantize_weight - scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.config.q_group_size) - qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.config.q_group_size) + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) return qtensor def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: @@ -2009,20 +1825,16 @@ def init(self, cache: "Cache", **kwargs) -> None: """Initialize the HQQ quantization processor.""" super().init(cache, **kwargs) - if self.config.nbits not in [1, 2, 3, 4, 8]: + if self.nbits not in [1, 2, 3, 4, 8]: raise ValueError( - f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but 
got {self.config.nbits}" + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" ) - if self.config.axis_key not in [0, 1]: - raise ValueError( - f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.config.axis_key}" - ) + if self.axis_key not in [0, 1]: + raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") - if self.config.axis_value not in [0, 1]: - raise ValueError( - f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.config.axis_value}" - ) + if self.axis_value not in [0, 1]: + raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") self.quantizer = HQQQuantizer @@ -2031,13 +1843,13 @@ def _quantize(self, tensor: torch.Tensor, axis: int) -> tuple[torch.Tensor, dict qtensor, meta = self.quantizer.quantize( tensor, axis=axis, - device=self.config.device, - compute_dtype=self.config.compute_dtype, - nbits=self.config.nbits, - group_size=self.config.q_group_size, + device=self.device, + compute_dtype=self.compute_dtype, + nbits=self.nbits, + group_size=self.q_group_size, ) - meta["compute_dtype"] = self.config.compute_dtype - self.quantizer.cuda(qtensor, meta=meta, device=self.config.device) # Move to device and cast to dtype + meta["compute_dtype"] = self.compute_dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype meta["scale"] = meta["scale"].to(qtensor.device) meta["zero"] = meta["zero"].to(qtensor.device) return qtensor, meta @@ -2063,13 +1875,13 @@ class QuantizedCache(DynamicCache): is `[batch_size, num_heads, seq_len - residual_length, head_dim]` """ - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - if cache_config.backend == "quanto": - processor = QuantoQuantizedCacheProcessor(cache_config) - elif cache_config.backend == "hqq": - processor = HQQQuantizedCacheProcessor(cache_config) + def __init__(self, backend, *args, **kwargs) -> None: + if backend == "quanto": + processor = QuantoQuantizedCacheProcessor() + elif backend == "hqq": + processor = HQQQuantizedCacheProcessor() else: - raise ValueError(f"Unknown quantization backend `{cache_config.backend}`") + raise ValueError(f"Unknown quantization backend `{backend}`") super().__init__(cache_processor=processor) @@ -2089,10 +1901,6 @@ class QuantoQuantizedCache(QuantizedCache): Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. - Example: ```python @@ -2113,8 +1921,8 @@ class QuantoQuantizedCache(QuantizedCache): ``` """ - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor(cache_config)) + def __init__(self, *args, **kwargs) -> None: + Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor()) class HQQQuantizedCache(QuantizedCache): @@ -2132,10 +1940,6 @@ class HQQQuantizedCache(QuantizedCache): Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. 
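Since the `cache_config` object parameter is dropped here in favor of plain keyword arguments, a hedged sketch of how the quantized cache is exercised end to end after this refactor. The checkpoint name is illustrative, the dict-based `cache_config` follows the generation-side hunks further down, and `optimum-quanto` must be installed for the `quanto` backend:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The KV cache can be quantized to save memory.", return_tensors="pt")
# cache_config is now a plain dict rather than a QuantizedCacheConfig instance
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```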
- Example: ```python @@ -2156,8 +1960,9 @@ class HQQQuantizedCache(QuantizedCache): ``` """ - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor(cache_config)) + def __init__(self, backend="HQQ", *args, **kwargs) -> None: + assert backend == "HQQ" + Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor()) class SinkCache(Cache): @@ -2175,6 +1980,257 @@ def __init__(self, **kwargs) -> None: ) +@dataclass +class CacheConfig: + """ + Base class for cache configs + """ + + cache_implementation: None + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. + """ + warnings.warn( + ("CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."), + FutureWarning, + stacklevel=2, + ) + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`dict[str, Any]`): + Dictionary of attributes to tentatively update this class. 
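The `update` helper in this deprecated config class follows a common pattern: apply the kwargs that match existing attributes and hand the rest back to the caller. A minimal self-contained sketch of that behaviour (illustrative toy class, not the transformers code):

```python
class ToyConfig:
    def __init__(self):
        self.nbits = 4
        self.backend = "quanto"

    def update(self, **kwargs):
        unused = {}
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)   # known attribute: overwrite it
            else:
                unused[key] = value         # unknown attribute: report it back
        return unused

cfg = ToyConfig()
leftover = cfg.update(nbits=2, watermark=True)
assert cfg.nbits == 2 and leftover == {"watermark": True}
```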
+ + Returns: + `dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class QuantizedCacheConfig(CacheConfig): + """ + Configuration class for quantized cache settings. + + Attributes: + backend (`str`, *optional*, defaults to `"quanto"`): + Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] + nbits (`Optional[int]`, *optional*, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. + axis_key (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + axis_value (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + q_group_size (`Optional[int]`, *optional*, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`Optional[int]`, *optional*, defaults to 128): + Length of the residual cache which will always be stored in original precision. + Defaults to 128. + compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + device (`str`, *optional*, defaults to `"cpu"`): + Device on which to perform computations, should be same as the model's device. + """ + + def __init__( + self, + backend: str = "quanto", + nbits: Optional[int] = 4, + axis_key: Optional[int] = 0, + axis_value: Optional[int] = 0, + q_group_size: Optional[int] = 64, + residual_length: Optional[int] = 128, + compute_dtype: Optional[torch.dtype] = torch.float16, + device: Optional[str] = "cpu", + ): + warnings.warn( + ("CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."), + FutureWarning, + stacklevel=2, + ) + self.backend = backend + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.compute_dtype = compute_dtype + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. 
`{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( + incorrect_arg_msg.format( + key="residual_length", + correct_value="a positive integer", + found_value=self.residual_length, + ), + ) + + if self.axis_key not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_key", + correct_value="`1` or `0`, `-1`", + found_value=self.axis_key, + ), + ) + + if self.axis_value not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_value", + correct_value="`1` or `0` or `-1`", + found_value=self.axis_value, + ), + ) + + +@dataclass +class StaticCacheConfig(CacheConfig): + """ + Configuration class for static cache settings. + """ + + cache_implementation = "static" + + def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): + warnings.warn( + ("CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."), + FutureWarning, + stacklevel=2, + ) + self.batch_size = batch_size + self.max_cache_len = max_cache_len + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + + if self.batch_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="batch_size", + correct_value="> 0", + found_value=self.batch_size, + ), + ) + + if self.max_cache_len <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="max_cache_len", + correct_value="> 0", + found_value=self.max_cache_len, + ), + ) + + # TODO (manuel, joao): remove this class, it is here only for backwards compatibility # PEP 562: Lazy loading for deprecated location of MambaCache def __getattr__(name: str) -> Any: diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index baae6690a94d..6453d7c6d287 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -44,7 +44,6 @@ logger = logging.get_logger(__name__) METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version") -CACHE_CONFIG_MAPPING = {} NEED_SETUP_CACHE_CLASSES_MAPPING = {} QUANT_BACKEND_CLASSES_MAPPING = {} ALL_CACHE_IMPLEMENTATIONS = [] @@ -56,18 +55,12 @@ HybridChunkedCache, OffloadedHybridCache, OffloadedStaticCache, - QuantizedCacheConfig, QuantoQuantizedCache, SlidingWindowCache, StaticCache, - StaticCacheConfig, ) from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor - CACHE_CONFIG_MAPPING["quantized"] = QuantizedCacheConfig - CACHE_CONFIG_MAPPING["static"] = StaticCacheConfig - CACHE_CONFIG_MAPPING["sliding_window"] = StaticCacheConfig - CACHE_CONFIG_MAPPING["hybrid"] = StaticCacheConfig NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache, "offloaded_static": OffloadedStaticCache, @@ -188,10 +181,8 @@ class 
GenerationConfig(PushToHubMixin): If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See our [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information. - cache_config (`CacheConfig` or `dict`, *optional*, default to `None`): - Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and - it will be converted to its respective `CacheConfig` internally. - Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. + cache_config (`dict`, *optional*, default to `None`): + Arguments used in the key-value cache class can be passed in `cache_config`. return_legacy_cache (`bool`, *optional*, default to `True`): Whether to return the legacy or new format of the cache when `DynamicCache` is used by default. @@ -406,10 +397,6 @@ def __init__(self, **kwargs): self.use_cache = kwargs.pop("use_cache", True) self.cache_implementation = kwargs.pop("cache_implementation", None) self.cache_config = kwargs.pop("cache_config", None) - if self.cache_implementation is not None and self.cache_implementation in CACHE_CONFIG_MAPPING: - cache_config_class = CACHE_CONFIG_MAPPING[self.cache_implementation] - if isinstance(self.cache_config, dict): - self.cache_config = cache_config_class.from_dict(self.cache_config) self.return_legacy_cache = kwargs.pop("return_legacy_cache", None) self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None) @@ -611,17 +598,6 @@ def validate(self, strict=False): f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: " f"{ALL_CACHE_IMPLEMENTATIONS}" ) - if self.cache_config is not None: - cache_class = CACHE_CONFIG_MAPPING.get(self.cache_implementation) - if cache_class is None: - raise ValueError( - "You provided a `cache_config` but the cache implementation you are using " - f"({self.cache_implementation}) does not require any config. Make sure to use the " - "correct cache implementation matching your cache config." - ) - if not isinstance(self.cache_config, cache_class): - self.cache_config = cache_class.from_dict(self.cache_config) - self.cache_config.validate() # 1.3. Performance attributes if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig): raise ValueError( diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 53095c19121a..8be987f3dd9f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -35,7 +35,6 @@ HybridChunkedCache, OffloadedCache, OffloadedHybridCache, - QuantizedCacheConfig, ) from ..configuration_utils import PretrainedConfig from ..dynamic_module_utils import ( @@ -2077,22 +2076,22 @@ def _prepare_cache_for_generation( cache_config = ( generation_config.cache_config if generation_config.cache_config is not None - else QuantizedCacheConfig() + else {"backend": "quanto"} ) - cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config["backend"]] - if cache_config.backend == "quanto" and not is_optimum_quanto_available(): + if cache_config["backend"] == "quanto" and not is_optimum_quanto_available(): raise ImportError( "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. 
" "Please install it via with `pip install optimum-quanto`" ) - elif cache_config.backend == "HQQ" and not is_hqq_available(): + elif cache_config["backend"] == "HQQ" and not is_hqq_available(): raise ImportError( "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " "Please install it via with `pip install hqq`" ) - model_kwargs[cache_name] = cache_class(cache_config) + model_kwargs[cache_name] = cache_class(**cache_config) elif generation_config.cache_implementation == "offloaded": model_kwargs[cache_name] = OffloadedCache() elif generation_config.cache_implementation == "dynamic": diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 36aab8699a81..0331654fc1d4 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -276,9 +276,9 @@ def __init__(self, model: PreTrainedModel): self.model = model self.static_cache = StaticCache( model_config=self.model.config, - max_batch_size=self.model.generation_config.cache_config.batch_size, - max_cache_len=self.model.generation_config.cache_config.max_cache_len, - device=self.model.generation_config.cache_config.device, + max_batch_size=self.model.generation_config.cache_config.get("batch_size"), + max_cache_len=self.model.generation_config.cache_config.get("max_cache_len"), + device=self.model.generation_config.cache_config.get("device"), dtype=self.model.dtype, ) for i in range(len(self.static_cache)): diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index e06056d7c0be..f32218a15b18 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -693,7 +693,7 @@ def create_causal_mask( """ # If we have an HybridCache structure, here we want to create the mask for the full layers is_sliding = [] - if past_key_values is not None: + if past_key_values is not None and past_key_values.layers is not None: is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] layer_idx = is_sliding.index(True) if True in is_sliding else 0 @@ -775,7 +775,7 @@ def create_sliding_window_causal_mask( """ # If we have an HybridCache structure, here we want to create the mask for the sliding layers is_sliding = [] - if past_key_values is not None: + if past_key_values is not None and past_key_values.layers is not None: is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] layer_idx = is_sliding.index(True) if True in is_sliding else 0 diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 96b7eed4777c..a27a03368e19 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -146,6 +146,9 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None): def __len__(self): return len(self.key_cache) + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update def update( self, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 0ef3bd37b161..81479ad302f6 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -150,6 +150,9 @@ def __init__( def __len__(self): return 
len(self.key_cache) + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + def update( self, key_states: torch.Tensor, diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index 1fb85f7de82a..fab9a7f83002 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -36,7 +36,7 @@ if is_torch_available(): import torch - from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model + from transformers import AutoTokenizer, Cache, FalconH1ForCausalLM, FalconH1Model from transformers.models.falcon_h1.modeling_falcon_h1 import ( FalconHybridMambaAttentionDynamicCache, ) @@ -270,6 +270,43 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM {"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {} ) + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, (tuple, Cache)) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + if isinstance(decoder_past_key_values, Cache): + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + + # Legacy cache format checks. 
This branch should be removed when all models use `Cache` by default + else: + self.assertListEqual( + [isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values], + [True] * len(decoder_past_key_values), + ) + # check shape key, value + self.assertListEqual( + [layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values], + [expected_shape] * len(decoder_past_key_values), + ) + self.assertListEqual( + [layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values], + [expected_shape] * len(decoder_past_key_values), + ) + def setUp(self): self.model_tester = FalconH1ModelTester(self) self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index d5c1463cf618..1e5da8f542e6 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -813,8 +813,8 @@ def test_static_cache_exportability(self): self.assertEqual(model.generation_config.cache_implementation, cache_implementation) self.assertEqual(model.generation_config.max_length, max_cache_len) self.assertTrue(model.generation_config.cache_config is not None) - self.assertEqual(model.generation_config.cache_config.batch_size, batch_size) - self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len) + self.assertEqual(model.generation_config.cache_config.get("batch_size"), batch_size) + self.assertEqual(model.generation_config.cache_config.get("max_cache_len"), max_cache_len) exported_program = convert_and_export_with_cache(model) From 27916bc2737806bf849ce2148cb1e66d59573913 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 9 Jul 2025 20:56:42 +0200 Subject: [PATCH 07/35] back to storage inside Cache() --- docs/source/en/cache_explanation.md | 4 +- src/transformers/cache_utils.py | 489 +++++++++++------- src/transformers/integrations/executorch.py | 12 +- src/transformers/models/bart/modeling_bart.py | 4 +- .../modeling_bigbird_pegasus.py | 4 +- .../models/biogpt/modeling_biogpt.py | 4 +- .../models/blenderbot/modeling_blenderbot.py | 4 +- .../modeling_blenderbot_small.py | 4 +- src/transformers/models/dia/modeling_dia.py | 4 +- src/transformers/models/dia/modular_dia.py | 4 +- .../models/gemma3n/modeling_gemma3n.py | 4 +- .../models/gemma3n/modular_gemma3n.py | 4 +- src/transformers/models/gptj/modeling_gptj.py | 4 +- .../models/informer/modeling_informer.py | 8 +- .../models/informer/modular_informer.py | 4 +- .../models/longt5/modeling_longt5.py | 4 +- .../models/m2m_100/modeling_m2m_100.py | 4 +- .../models/marian/modeling_marian.py | 4 +- .../models/mbart/modeling_mbart.py | 4 +- .../models/minimax/modeling_minimax.py | 6 +- .../models/minimax/modular_minimax.py | 6 +- .../models/mllama/modeling_mllama.py | 4 +- .../models/moonshine/modeling_moonshine.py | 4 +- .../models/moonshine/modular_moonshine.py | 4 +- src/transformers/models/mt5/modeling_mt5.py | 4 +- .../models/pegasus/modeling_pegasus.py | 4 +- .../models/pegasus_x/modeling_pegasus_x.py | 4 +- .../models/pix2struct/modeling_pix2struct.py | 4 +- .../models/plbart/modeling_plbart.py | 4 +- .../models/pop2piano/modeling_pop2piano.py | 4 +- .../modeling_switch_transformers.py | 4 +- src/transformers/models/t5/modeling_t5.py | 4 +- .../models/t5gemma/modeling_t5gemma.py | 4 +- .../models/t5gemma/modular_t5gemma.py | 4 +- .../modeling_time_series_transformer.py | 4 +- src/transformers/models/udop/modeling_udop.py | 4 +- 
src/transformers/models/umt5/modeling_umt5.py | 4 +- .../models/whisper/generation_whisper.py | 4 +- .../models/whisper/modeling_whisper.py | 4 +- tests/generation/test_utils.py | 80 ++- .../deepseek_v3/test_modeling_deepseek_v3.py | 8 +- tests/models/dia/test_modeling_dia.py | 8 +- .../models/gpt_neox/test_modeling_gpt_neox.py | 4 +- tests/models/t5gemma/test_modeling_t5gemma.py | 36 +- tests/utils/test_cache_utils.py | 50 +- utils/check_docstrings.py | 5 +- 46 files changed, 473 insertions(+), 375 deletions(-) diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 2adcc0c78012..1c82330a2433 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -89,8 +89,8 @@ Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWi The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token: ```py -cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2) -cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2) +cache.key_cache[idx] = torch.cat([cache.key_cache[idx], key_states], dim=-2) +cache.value_cache[idx] = torch.cat([cache.value_cache[idx], value_states], dim=-2) ``` Other layers like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added. diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 08f1a70550f7..5500b9247ebc 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -91,21 +91,20 @@ class CacheLayerMixin: is_compileable = False - def __repr__(self): - key_repr = "None" if self.keys is None else f"t({tuple(self.keys.shape)})" - value_repr = "None" if self.values is None else f"t({tuple(self.values.shape)})" - return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})" - def update( self, key_states: torch.Tensor, value_states: torch.Tensor, + key_cache: Optional[torch.Tensor] = None, + value_cache: Optional[torch.Tensor] = None, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Updates KV cache, returns updated keys/values for this layer.""" raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") - def get_seq_length(self) -> int: + def get_seq_length( + self, key_cache: Optional[torch.Tensor] = None, value_cache: Optional[torch.Tensor] = None + ) -> int: """Returns the sequence length of this layer's cache.""" raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") @@ -113,18 +112,37 @@ def get_max_cache_shape(self) -> int: """Returns the maximum sequence length (i.e. 
max capacity) of this layer's cache.""" raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") - def reset(self) -> None: + def reset(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Resets this layer's cache.""" raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.") - def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + def get_mask_sizes( + self, + cache_position: torch.Tensor, + key_cache: Optional[torch.Tensor] = None, + value_cache: Optional[torch.Tensor] = None, + ) -> tuple[int, int]: + """Returns mask sizes for this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") + + def new_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: + """Returns a new key and value tensor for this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `new_tensors` in {self.__class__.__name__}.") + + def reorder_cache( + self, + beam_idx: torch.LongTensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """Reorders this layer's cache for beam search.""" - if self.keys.numel(): - device = self.keys.device - self.keys = self.keys.index_select(0, beam_idx.to(device)) - if self.values.numel(): - device = self.values.device - self.values = self.values.index_select(0, beam_idx.to(device)) + if key_cache.numel(): + device = key_cache.device + key_cache = key_cache.index_select(0, beam_idx.to(device)) + if value_cache.numel(): + device = value_cache.device + value_cache = value_cache.index_select(0, beam_idx.to(device)) + return key_cache, value_cache, None def parse_layer_args_from_model_config( @@ -184,7 +202,9 @@ def parse_layer_args_from_model_config( class CacheBase: - layers = None + layers: list[CacheLayerMixin] = None + key_cache: list[torch.Tensor] = None + value_cache: list[torch.Tensor] = None def update( self, @@ -242,6 +262,8 @@ def __init__( - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping """ self.layers: list[CacheLayerMixin] = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] self.cache_processor = cache_processor if ( @@ -252,7 +274,7 @@ def __init__( layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in model_config.layer_types] self.layer_classes = layer_classes or [DynamicLayer] - self.layer_args = parse_layer_args_from_model_config(model_config, *args, **kwargs) + self.layer_init_args = parse_layer_args_from_model_config(model_config, *args, **kwargs) self.model_num_layers = getattr(model_config, "num_hidden_layers", 1) self.append_new_layers(self.model_num_layers - 1) @@ -266,11 +288,9 @@ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: sequence length. 
""" if layer_idx < len(self.layers): - return self.layers[layer_idx].keys, self.layers[layer_idx].values + return self.key_cache[layer_idx], self.value_cache[layer_idx] else: - raise KeyError( - f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" - ) + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") def __iter__(self): """ @@ -278,61 +298,63 @@ def __iter__(self): keys and values """ for layer_idx in range(len(self)): - yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) def __len__(self): """ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds to the number of layers in the model. """ - # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked. - if self.layers is None: - if getattr(self, "key_cache", None) is not None: - return len(self.key_cache) - return 0 dynamic_empty = ( - len(self.layers) == 1 and isinstance(self.layers[0], DynamicLayer) and self.layers[0].keys is None + self.layers is not None + and len(self.layers) == 1 + and isinstance(self.layers[0], DynamicLayer) + and self.key_cache[0] is None ) - return len(self.layers) if not dynamic_empty else 0 + return len(self.key_cache) if not dynamic_empty else 0 def __getattr__(self, name): """ - Dynamically handle any method call / attribute access by forwarding to layers. - If it doesn't exist on all layers, raises AttributeError. - For attributes, checks all different layer types and reduces to an set of unique values (or to a single value) - For methods, returns a function that propagates the call to all layers with carry-over state (e.g. update, reset) + Forward a *shared* operation or property to every layer. + + - If every layer defines a callable `name`, we return a proxy that + calls each layer in turn and updates `key_cache` / `value_cache`. + - If every layer defines a boolean `name`, we return `all(values)`. + - If every layer defines a constant property `name`, we return it. + - Otherwise we raise AttributeError. """ - if name in ("__getstate__", "__setstate__"): - raise AttributeError(name) + if not self.layers: + return object.__getattribute__(self, name) - is_attribute = ( - all(hasattr(layer, name) and not callable(getattr(layer, name)) for layer in self.layers) - and len(self.layers) > 0 - ) - is_method = all(callable(getattr(layer, name)) for layer in self.layers) and len(self.layers) > 0 - - if is_attribute: - attribute_values = [getattr(layer, name) for layer in self.layers] - if all(isinstance(value, bool) for value in attribute_values): - return all(attribute_values) - elif len(set(attribute_values)) == 1: - return attribute_values[0] - else: - raise ValueError( - f"{self.__class__.__name__}: layers have multiple values for layer.{name}: {attribute_values}. This is not supported." 
- ) - elif is_method: + # 1) Gather the attribute from every layer + values = [] + for layer in self.layers: + values.append(getattr(layer, name, None)) - def propagate_to_layers(*args, **kwargs): - for layer in self.layers: - return_value = getattr(layer, name)(*args, **kwargs) - if return_value is not None: + # 2) All callables → make a forwarding function + if all(callable(v) for v in values): + + def _proxy(*args, **kwargs): + for i in range(len(self.layers)): + self.key_cache[i], self.value_cache[i], ret = values[i]( + key_cache=self.key_cache[i], value_cache=self.value_cache[i], *args, **kwargs + ) + if ret is not None: break - return return_value + return ret - return propagate_to_layers - else: - raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") + return _proxy + + # 3) All booleans → reduce with `all` + if all(isinstance(v, bool) for v in values): + return all(values) + + # 4) All identical → return first + if all(v == values[0] for v in values): + return values[0] + + # 5) Anything else → unsupported mixed attribute + raise AttributeError(f"{self.__class__.__name__}: layers disagree on attribute {name!r}: {values}") def __repr__(self): return f"{self.__class__.__name__}(layers={self.layers})" @@ -343,10 +365,14 @@ def append_new_layers(self, layer_idx): Used in prefill and for skipped layers. """ while len(self.layers) <= layer_idx: - args = self.layer_args.copy() - if self.layer_args.get("layer_device_map", None) is not None: + args = self.layer_init_args.copy() + if self.layer_init_args.get("layer_device_map", None) is not None: args["device"] = args.pop("layer_device_map")[layer_idx] - self.layers.append(self.layer_classes[layer_idx % len(self.layer_classes)](**args)) + new_layer = self.layer_classes[layer_idx % len(self.layer_classes)](**args) + new_key, new_value = new_layer.new_tensors() + self.layers.append(new_layer) + self.key_cache.append(new_key) + self.value_cache.append(new_value) def _update( self, @@ -373,16 +399,29 @@ def _update( A tuple containing the updated key and value states. """ self.append_new_layers(layer_idx) - return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].update( + key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], cache_kwargs + ) + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def to(self, device: torch.device) -> "Cache": + """Moves the cache to the given device.""" + for idx in range(len(self.key_cache)): + self.key_cache[idx] = self.key_cache[idx].to(device) + self.value_cache[idx] = self.value_cache[idx].to(device) + return self def get_seq_length(self, layer_idx: int = 0) -> int: """Returns the sequence length of the cache for the given layer. 
TODO: deprecate in favor of cache_position""" if layer_idx >= len(self.layers): return 0 + _, _, seq_length = self.layers[layer_idx].get_seq_length( + self.key_cache[layer_idx], self.value_cache[layer_idx] + ) # Hack since QuantizedCache messes with keys shape as it becomes the residual cache if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): - return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length() - return self.layers[layer_idx].get_seq_length() + return self.cache_processor.erased_length + seq_length + return seq_length def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ @@ -391,7 +430,10 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), for each layer. """ - return self.layers[layer_idx].get_mask_sizes(cache_position) + _, _, (kv_length, kv_offset) = self.layers[layer_idx].get_mask_sizes( + cache_position, self.key_cache[layer_idx], self.value_cache[layer_idx] + ) + return kv_length, kv_offset class DynamicLayer(CacheLayerMixin): @@ -400,19 +442,12 @@ class DynamicLayer(CacheLayerMixin): It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. """ - keys, values = None, None - - @classmethod - def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": - cache = cls() - cache.keys = keys - cache.values = values - return cache - def update( self, key_states: torch.Tensor, value_states: torch.Tensor, + key_cache: Optional[torch.Tensor] = None, + value_cache: Optional[torch.Tensor] = None, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -429,70 +464,110 @@ def update( Return: A tuple containing the updated key and value states. """ - if self.keys is None: - self.keys = key_states - self.values = value_states + if key_cache is None: + key_cache = key_states + value_cache = value_states else: - self.keys = torch.cat([self.keys, key_states], dim=-2) - self.values = torch.cat([self.values, value_states], dim=-2) - return self.keys, self.values + key_cache = torch.cat([key_cache, key_states], dim=-2) + value_cache = torch.cat([value_cache, value_states], dim=-2) + return key_cache, value_cache + + def new_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: + """Returns a new key and value tensor for this layer's cache.""" + return None, None # They get initialized in the update() - def get_seq_length(self, cache_position: Optional[torch.LongTensor] = None) -> int: + def get_seq_length( + self, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, int]: """Returns the sequence length of the cached states.""" # TODO: deprecate this function in favor of `cache_position` - if self is None or self.keys is None or self.keys.numel() == 0: - return 0 - return self.keys.shape[-2] + if key_cache is None or key_cache.numel() == 0: + return key_cache, value_cache, 0 + return key_cache, value_cache, key_cache.shape[-2] - def get_max_cache_shape(self) -> int: + def get_max_cache_shape( + self, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + ) -> int: """Returns the maximum sequence length of the cache object. 
DynamicLayer does not have a maximum length.""" - return -1 + return key_cache, value_cache, -1 - def reset(self) -> None: + def reset( + self, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """Resets the cache values while preserving the objects""" - self.keys = torch.tensor([], dtype=self.keys.dtype, device=self.keys.device) - self.values = torch.tensor([], dtype=self.values.dtype, device=self.values.device) + key_cache.zero_() + value_cache.zero_() + return key_cache, value_cache, None - def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + def reorder_cache( + self, + beam_idx: torch.LongTensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """Reorders the cache for beam search, given the selected beam indices.""" - if self.keys is not None and self.keys.numel(): - self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) - if self.values is not None and self.values.numel(): - self.values = self.values.index_select(0, beam_idx.to(self.values.device)) + if key_cache is not None and key_cache.numel(): + key_cache = key_cache.index_select(0, beam_idx.to(key_cache.device)) + value_cache = value_cache.index_select(0, beam_idx.to(value_cache.device)) + return key_cache, value_cache, None - def crop(self, max_length: int) -> None: + def crop( + self, + max_length: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative to remove `max_length` tokens. """ if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) + max_length = self.get_seq_length(key_cache, value_cache)[-1] - abs(max_length) - if self.get_seq_length() <= max_length: - return + if self.get_seq_length(key_cache, value_cache)[-1] <= max_length: + return key_cache, value_cache, None - if self.keys is not None and self.keys.numel(): - self.keys = self.keys[..., :max_length, :] - self.values = self.values[..., :max_length, :] + if key_cache is not None and key_cache.numel(): + key_cache = key_cache[..., :max_length, :] + value_cache = value_cache[..., :max_length, :] + return key_cache, value_cache, None - def batch_repeat_interleave(self, repeats: int) -> None: + def batch_repeat_interleave( + self, repeats: int, key_cache: torch.Tensor, value_cache: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: """Repeat the cache `repeats` times in the batch dimension.""" - if self.keys.numel(): - self.keys = self.keys.repeat_interleave(repeats, dim=0) - self.values = self.values.repeat_interleave(repeats, dim=0) + if key_cache.numel(): + key_cache = key_cache.repeat_interleave(repeats, dim=0) + value_cache = value_cache.repeat_interleave(repeats, dim=0) + return key_cache, value_cache, None - def batch_select_indices(self, indices: torch.Tensor) -> None: + def batch_select_indices( + self, indices: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: """Only keep the `indices` in the batch dimension of the cache.""" - if self.keys.numel(): - self.keys = self.keys[indices, ...] - self.values = self.values[indices, ...] + if key_cache.numel(): + key_cache = key_cache[indices, ...] + value_cache = value_cache[indices, ...] 
+ return key_cache, value_cache, None - def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + def get_mask_sizes( + self, + cache_position: torch.Tensor, + key_cache: Optional[torch.Tensor] = None, + value_cache: Optional[torch.Tensor] = None, + ) -> tuple[int, int]: full_mask_kv_offset = 0 query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length(cache_position) + _, _, past_seen_tokens = self.get_seq_length(key_cache, value_cache, cache_position) kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset + return key_cache, value_cache, (kv_length, full_mask_kv_offset) class DynamicCache(Cache): @@ -530,7 +605,9 @@ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.T # compatibility. The name of the argument doesn't matter. if ddp_cache_data is not None: for key_states, value_states in ddp_cache_data: - self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.layers.append(DynamicLayer()) super().__init__(*args, **kwargs) def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: @@ -539,9 +616,8 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: backward compatibility. """ legacy_cache = () - for layer in self.layers: - if layer is not None: - legacy_cache += ((layer.keys, layer.values),) + for keys, values in zip(self.key_cache, self.value_cache): + legacy_cache += ((keys, values),) return legacy_cache @classmethod @@ -574,16 +650,16 @@ def _flatten_dynamic_cache( ) dictionary = { - "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], - "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], + "key_cache": dynamic_cache.key_cache if dynamic_cache.key_cache[0] is not None else [], + "value_cache": dynamic_cache.value_cache if dynamic_cache.value_cache[0] is not None else [], } return torch.utils._pytree._dict_flatten(dictionary) def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): dictionary = { - "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], - "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], + "key_cache": dynamic_cache.key_cache if dynamic_cache.key_cache[0] is not None else [], + "value_cache": dynamic_cache.value_cache if dynamic_cache.value_cache[0] is not None else [], } return torch.utils._pytree._dict_flatten_with_keys(dictionary) @@ -606,8 +682,8 @@ def _unflatten_dynamic_cache( def _flatten_dynamic_cache_for_fx(cache, spec): dictionary = { - "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], - "value_cache": [layer.values for layer in cache.layers if layer.values is not None], + "key_cache": cache.key_cache if cache.key_cache[0] is not None else [], + "value_cache": cache.value_cache if cache.value_cache[0] is not None else [], } return torch.fx._pytree._dict_flatten_spec(dictionary, spec) @@ -662,89 +738,109 @@ def __init__( self.head_dim = head_dim self.dtype = dtype self.device = device - # Note: There will be significant perf decrease if switching to use 5D tensors instead. 
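As an aside on the `to_legacy_cache` hunk above: with keys and values stored back in per-layer lists, the legacy format is just a round trip to a tuple of `(key, value)` pairs. A standalone sketch with illustrative shapes:

```python
import torch

key_cache = [torch.randn(1, 2, 5, 4) for _ in range(3)]    # one tensor per layer
value_cache = [torch.randn(1, 2, 5, 4) for _ in range(3)]

# Cache -> legacy format: ((k_0, v_0), (k_1, v_1), ...)
legacy = tuple((k, v) for k, v in zip(key_cache, value_cache))

# legacy format -> per-layer lists again
keys_back = [layer_kv[0] for layer_kv in legacy]
values_back = [layer_kv[1] for layer_kv in legacy]
assert all(torch.equal(a, b) for a, b in zip(key_cache, keys_back))
```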
- self.keys = torch.zeros( - (batch_size, num_heads, self.max_cache_len, head_dim), - dtype=dtype, - device=device, - ) - self.values = torch.zeros( - (batch_size, num_heads, self.max_cache_len, head_dim), - dtype=dtype, - device=device, - ) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(self.keys) - torch._dynamo.mark_static_address(self.values) - def get_max_cache_shape(self) -> int: - return self.max_cache_len + def get_max_cache_shape(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> int: + return key_cache, value_cache, self.max_cache_len def _static_update( self, key_states: torch.Tensor, value_states: torch.Tensor, cache_position: Optional[torch.LongTensor], + key_cache: torch.Tensor, + value_cache: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: # Utility functions for static/sliding cache update logic """ Updates the static cache tensors in place. Args: - k_cache (`torch.Tensor`): The key cache tensor to update. - v_cache (`torch.Tensor`): The value cache tensor to update. key_states (`torch.Tensor`): The new key states to add. value_states (`torch.Tensor`): The new value states to add. cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted. If None, the entire cache is overwritten (prefill). + key_cache (`torch.Tensor`): The key cache tensor to update. + value_cache (`torch.Tensor`): The value cache tensor to update. Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place). """ if cache_position is None: # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. - self.keys.copy_(key_states) - self.values.copy_(value_states) + key_cache.copy_(key_states) + value_cache.copy_(value_states) else: # Generation phase. Update specific positions. # Use index_copy_ for in-place update (compile-friendly). try: - self.keys.index_copy_(2, cache_position, key_states) - self.values.index_copy_(2, cache_position, value_states) + key_cache.index_copy_(2, cache_position, key_states) + value_cache.index_copy_(2, cache_position, value_states) except NotImplementedError: # Fallback for devices like MPS where index_copy_ might not be supported. 
- self.keys[:, :, cache_position] = key_states - self.values[:, :, cache_position] = value_states - return self.keys, self.values + key_cache[:, :, cache_position] = key_states + value_cache[:, :, cache_position] = value_states + return key_cache, value_cache - def update(self, key_states, value_states, cache_kwargs=None) -> tuple[torch.Tensor, torch.Tensor]: + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + key_cache: Optional[torch.Tensor] = None, + value_cache: Optional[torch.Tensor] = None, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - key_states = key_states.to(self.keys.dtype) - value_states = value_states.to(self.values.dtype) - return self._static_update(key_states, value_states, cache_position) + return self._static_update( + key_states.to(key_cache.dtype), value_states.to(value_cache.dtype), cache_position, key_cache, value_cache + ) + + def new_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: + """Returns a new key and value tensor for this layer's cache.""" + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + keys = torch.zeros( + (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), + dtype=self.dtype, + device=self.device, + ) + values = torch.zeros( + (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), + dtype=self.dtype, + device=self.device, + ) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(keys) + torch._dynamo.mark_static_address(values) + return keys, values - def get_seq_length(self, cache_position=None) -> int: + def get_seq_length(self, key_cache: torch.Tensor, value_cache: torch.Tensor, cache_position=None) -> int: if cache_position is not None: return int(cache_position[-1] + 1) # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. 
- return (self.keys[0, 0].any(dim=-1)).sum() + seq_length = (key_cache[0, 0].any(dim=-1)).sum() if key_cache is not None else 0 + return key_cache, value_cache, seq_length - def reset(self): - self.keys.zero_() - self.values.zero_() + def reset(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + key_cache.zero_() + value_cache.zero_() + return key_cache, value_cache, None - def reorder_cache(self, beam_idx): - dev = self.keys.device + def reorder_cache( + self, beam_idx: torch.LongTensor, key_cache: torch.Tensor, value_cache: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + dev = key_cache.device beam_idx_dev = beam_idx.to(dev) - self.keys = self.keys.index_select(0, beam_idx_dev) - self.values = self.values.index_select(0, beam_idx_dev) + key_cache = key_cache.index_select(0, beam_idx_dev) + value_cache = value_cache.index_select(0, beam_idx_dev) + return key_cache, value_cache, None - def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + def get_mask_sizes( + self, cache_position: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor + ) -> tuple[int, int]: full_mask_kv_offset = 0 full_mask_kv_length = self.max_cache_len - return full_mask_kv_length, full_mask_kv_offset + return key_cache, value_cache, (full_mask_kv_length, full_mask_kv_offset) class StaticCache(Cache): @@ -789,17 +885,18 @@ def _static_update( key_states: torch.Tensor, value_states: torch.Tensor, cache_position: torch.LongTensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ Updates the sliding window cache tensors, returning the potentially modified tensors. Args: - k_cache (`torch.Tensor`): The key cache tensor to update. - v_cache (`torch.Tensor`): The value cache tensor to update. key_states (`torch.Tensor`): The new key states to add. value_states (`torch.Tensor`): The new value states to add. cache_position (`torch.LongTensor`): The position indices where the new states should be inserted. - max_cache_len (`int`): The maximum length of the sliding window cache. + key_cache (`torch.Tensor`): The key cache tensor to update. + value_cache (`torch.Tensor`): The value cache tensor to update. Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update. 
@@ -814,9 +911,9 @@ def _static_update( if cache_position.shape[0] > self.max_cache_len: new_k = key_states[:, :, -self.max_cache_len :, :] new_v = value_states[:, :, -self.max_cache_len :, :] - self.keys.copy_(new_k) - self.values.copy_(new_v) - return self.keys, self.values + key_cache.copy_(new_k) + value_cache.copy_(new_v) + return key_cache, value_cache # Sliding window logic for generation phase or prefill < window slicing = torch.arange(self.max_cache_len, device=value_states.device) @@ -824,8 +921,8 @@ def _static_update( to_shift = current_seq_len > self.max_cache_len indices = (slicing + to_shift.sum()) % self.max_cache_len - k_out_shifted = self.keys[:, :, indices] - v_out_shifted = self.values[:, :, indices] + k_out_shifted = key_cache[:, :, indices] + v_out_shifted = value_cache[:, :, indices] # Clamp cache_position to determine the *target index* within the shifted cache view update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1) @@ -840,18 +937,20 @@ def _static_update( k_out_updated[:, :, update_position] = key_states v_out_updated[:, :, update_position] = value_states - self.keys.copy_(k_out_updated) - self.values.copy_(v_out_updated) - return self.keys, self.values + key_cache.copy_(k_out_updated) + value_cache.copy_(v_out_updated) + return key_cache, value_cache - def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + def get_mask_sizes( + self, cache_position: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor + ) -> tuple[int, int]: query_length = cache_position.shape[0] first_cache_position = cache_position[0] local_mask_kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0) # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns local_mask_kv_length = max(query_length, self.max_cache_len) - return local_mask_kv_length, local_mask_kv_offset + return key_cache, value_cache, (local_mask_kv_length, local_mask_kv_offset) class SlidingWindowCache(Cache): @@ -937,10 +1036,10 @@ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch """ if layer_idx < len(self): return ( - self.self_attention_cache.layers[layer_idx].keys, - self.self_attention_cache.layers[layer_idx].values, - self.cross_attention_cache.layers[layer_idx].keys, - self.cross_attention_cache.layers[layer_idx].values, + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], ) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") @@ -1507,10 +1606,10 @@ def init(self, cache: "Cache", **kwargs) -> None: self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) if self.is_static: for i, layer in enumerate(cache.layers): - device = cache.layer_args["device"] if i == 0 else self.offload_device - layer.keys = layer.keys.to(device) - layer.values = layer.values.to(device) - self.original_device.append(cache.layer_args["device"]) + device = cache.layer_init_args["device"] if i == 0 else self.offload_device + cache.key_cache[i] = cache.key_cache[i].to(device) + cache.value_cache[i] = cache.value_cache[i].to(device) + self.original_device.append(cache.layer_init_args["device"]) if len(cache) != cache.model_num_layers: raise ValueError("If static layers are used, all cache layers must be initialized") @@ -1556,18 +1655,18 @@ def _prefetch_layer(self, 
cache: "Cache", layer_idx: int): ): # Prefetch next layer tensors to GPU device = self.original_device[layer_idx] - cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.to(device, non_blocking=True) - cache.layers[layer_idx].values = cache.layers[layer_idx].values.to(device, non_blocking=True) + cache.key_cache[layer_idx] = cache.key_cache[layer_idx].to(device, non_blocking=True) + cache.value_cache[layer_idx] = cache.value_cache[layer_idx].to(device, non_blocking=True) def _evict_previous_layer(self, cache: "Cache", layer_idx: int): """Moves the previous layer cache to the CPU.""" if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created # We do it on the default stream so it occurs after all earlier computations on these tensors are done prev_layer_idx = (layer_idx - 1) % len(cache) - cache.layers[prev_layer_idx].keys = cache.layers[prev_layer_idx].keys.to( + cache.key_cache[prev_layer_idx] = cache.key_cache[prev_layer_idx].to( self.offload_device, non_blocking=True ) - cache.layers[prev_layer_idx].values = cache.layers[prev_layer_idx].values.to( + cache.value_cache[prev_layer_idx] = cache.value_cache[prev_layer_idx].to( self.offload_device, non_blocking=True ) @@ -1580,8 +1679,8 @@ def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): # Handle delayed beam search operations if self.beam_idx is not None: self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) - cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.index_select(0, self.beam_idx) - cache.layers[layer_idx].values = cache.layers[layer_idx].values.index_select(0, self.beam_idx) + cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, self.beam_idx) + cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, self.beam_idx) class QuantizedCacheProcessor(CacheProcessor): @@ -1717,12 +1816,12 @@ def post_update( # Clear the residual cache self.erased_length = key_tensors.shape[-2] - cache.layers[layer_idx].keys = torch.zeros( + cache.key_cache[layer_idx] = torch.zeros( 0, dtype=key_tensors.dtype, device=key_tensors.device, ) - cache.layers[layer_idx].values = torch.zeros( + cache.value_cache[layer_idx] = torch.zeros( 0, dtype=value_tensors.dtype, device=value_tensors.device, @@ -1743,12 +1842,12 @@ def post_update( # Clear the residual cache self.erased_length += key_tensors.shape[-2] - cache.layers[layer_idx].keys = torch.zeros( + cache.key_cache[layer_idx] = torch.zeros( 0, dtype=key_tensors.dtype, device=key_tensors.device, ) - cache.layers[layer_idx].values = torch.zeros( + cache.value_cache[layer_idx] = torch.zeros( 0, dtype=value_tensors.dtype, device=value_tensors.device, diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 0331654fc1d4..81ddff503443 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -282,8 +282,8 @@ def __init__(self, model: PreTrainedModel): dtype=self.model.dtype, ) for i in range(len(self.static_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) + self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): """ @@ -413,8 +413,8 @@ def __init__( # 
Register all key and value cache tensors as buffers for i in range(len(self.cache)): - self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False) - self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False) + self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False) + self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False) def forward( self, @@ -559,8 +559,8 @@ def __init__(self, model, max_static_cache_length, batch_size): # Register cache buffers to make them exportable for i in range(len(self.static_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) + self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): # Get outputs from decoder diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 57a1eff22c65..994bf9d85dca 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -230,8 +230,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index bb4660ae0998..465b94e13bee 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1293,8 +1293,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 87cd60b2ca66..8a0c43eafd3f 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -207,8 +207,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = 
self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 974a3fb7e9ec..7821e1c7b4fb 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 3f08d6804f9c..550e51221929 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -213,8 +213,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 12677705002c..19cac3e8c3ac 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -356,8 +356,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys - value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values + key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] + value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx] else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index dfe345968da6..fe437fde84ed 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -182,8 +182,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys - value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values + key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] + value_states = 
past_key_values.cross_attention_cache.value_cache[self.layer_idx] else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 57c8c4900e07..0817e16451ac 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1332,8 +1332,8 @@ def forward( else: indices = cache_position - key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices] - value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices] + key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] + value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index 7b8bcc6d37ec..a3ffa710d842 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1774,8 +1774,8 @@ def forward( else: indices = cache_position - key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices] - value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices] + key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] + value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 0526a067b020..4b8c17bd3b1c 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -728,9 +728,7 @@ def forward( # Ensure layer_past is on same device as hidden_states (might not be correct) if past_key_values is not None: - for layer in past_key_values.layers: - layer.keys = layer.keys.to(hidden_states.device) - layer.values = layer.values.to(hidden_states.device) + past_key_values = past_key_values.to(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states if causal_mask is not None: diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index c988a874c20f..9718e8fb736e 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -484,8 +484,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) @@ -601,8 +601,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, 
cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 606627823514..3d46275bdc81 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -290,8 +290,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 1a15f11b3850..081869ec8fc5 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -478,8 +478,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 05368453e9e3..7d5a73667ee4 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -294,8 +294,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index f199b3edfe19..7319671b485e 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = 
curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 26f7b53caa67..2585d91a3e3e 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -239,8 +239,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index e3e438c0714b..66ed4adcea4c 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -105,14 +105,16 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.layers[layer_idx].batch_repeat_interleave(repeats) + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.layers[layer_idx].batch_select_indices(indices) + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 0477a942a695..9b6fc12ae3de 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -215,14 +215,16 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.layers[layer_idx].batch_repeat_interleave(repeats) + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.layers[layer_idx].batch_select_indices(indices) + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] 
def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 985a35448e99..d33edcb3dd00 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -496,8 +496,8 @@ def forward( ) elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.layers[self.layer_idx].keys, - past_key_value.layers[self.layer_idx].values, + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], ) else: raise ValueError( diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 307bbe9ae90e..2909fb386fb5 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -235,8 +235,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.layers[self.layer_idx].keys - value_states = past_key_value.layers[self.layer_idx].values + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: key_states = ( self.k_proj(current_states) diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index df864b0c1ff2..500231f3b48b 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -331,8 +331,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.layers[self.layer_idx].keys - value_states = past_key_value.layers[self.layer_idx].values + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: key_states = ( self.k_proj(current_states) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index f467467e7fba..5584b2ee8255 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -376,8 +376,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 5ade9ee41d8a..2ffb53ee9e01 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -228,8 +228,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = 
curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 396b49dfb6e4..13f0ea27a6e4 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -249,8 +249,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index a01f88b5443b..6b90ae80d7c7 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -770,8 +770,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.key(current_states) value_states = self.value(current_states) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 04384b8265a6..327b70b5ec73 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -425,8 +425,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 13741a20ac18..5c4285afe728 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -320,8 +320,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = 
curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index e5bdc624feb9..b0273c8a4a33 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -513,8 +513,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index de7dbfa3e740..2a1a84b81523 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -501,8 +501,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index acf3bac94bcb..feccf6d7d9fd 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -352,8 +352,8 @@ def forward( past_key_value.is_updated[self.layer_idx] = True # cross-attention: reuse cached states else: - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index 522b60ebc83b..ae69ae991009 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -308,8 +308,8 @@ def forward( past_key_value.is_updated[self.layer_idx] = True # cross-attention: reuse cached states else: - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 
4c37cd42ef63..778a0485b4e8 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -394,8 +394,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 74bbce0259d8..8d4e368e945b 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -599,8 +599,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index bd2f3dd10ea6..2b1f650c6789 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -285,8 +285,8 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.layers[self.layer_idx].keys - value_states = curr_past_key_value.layers[self.layer_idx].values + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index c5ce00016e6a..248d17cac404 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1140,8 +1140,8 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None for layer_idx in range(self.config.decoder_layers): layer_past_key_values = [] for cache_cls in [values.self_attention_cache, values.cross_attention_cache]: - for v in [cache_cls.layers[layer_idx].keys, cache_cls.layers[layer_idx].values]: - layer_past_key_values.append(v[batch_idx][None].cpu()) + for v in [cache_cls.key_cache, cache_cls.value_cache]: + layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu()) all_past_key_values.append(tuple(layer_past_key_values)) return tuple(all_past_key_values) else: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e5dc5d59e7f3..d3e9c8e03a2b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -329,8 
+329,8 @@ def forward( current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.layers[self.layer_idx].keys - value_states = past_key_value.layers[self.layer_idx].values + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cee456aa8f2c..a04ea4e5cf59 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1603,26 +1603,26 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_keys = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys + self_attention_layer_key_cache = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] ) - self_attention_layer_values = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values + self_attention_layer_value_cache = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] ) - self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_keys = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys + cross_attention_layer_key_cache = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] ) - cross_attention_layer_values = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values + cross_attention_layer_value_cache = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] ) - cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] - cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] - self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) + cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] + cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] + self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) # 3.2. 
Decoder-only checks else: @@ -1634,18 +1634,10 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - if is_legacy_cache: - self_attention_layer_keys = past_kv[i][0] - self_attention_layer_values = past_kv[i][1] - elif past_kv.layers is None: - # Cache is lot layered (i.e, Mamba derivatives) - self_attention_layer_keys = past_kv.key_cache[i] - self_attention_layer_values = past_kv.value_cache[i] - else: - self_attention_layer_keys = past_kv.layers[i].keys - self_attention_layer_values = past_kv.layers[i].values - self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) + self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] + self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] + self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) @pytest.mark.generate @parameterized.expand([("greedy", 1), ("beam search", 2)]) @@ -1805,7 +1797,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self): cache_shape = [batch_size, num_key_value_heads, max_length, head_dim] self.assertIsInstance(outputs.past_key_values, StaticCache) self.assertEqual(len(outputs.past_key_values), num_hidden_layers) - self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape) + self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape) @pytest.mark.generate def test_generate_continue_from_past_key_values(self): @@ -2036,7 +2028,7 @@ def test_generate_with_static_cache(self): cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers) - self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape) + self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape) # Check 2: The outputs must be similar to the case with dynamic cache dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) @@ -2618,12 +2610,12 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [layer.keys.shape for layer in decoder_past_key_values.layers], - [expected_shape] * len(decoder_past_key_values.layers), + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), ) self.assertListEqual( - [layer.values.shape for layer in decoder_past_key_values.layers], - [expected_shape] * len(decoder_past_key_values.layers), + [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), ) # Legacy cache format checks. 
This branch should be removed when all models use `Cache` by default @@ -3982,13 +3974,13 @@ def test_generate_with_static_cache_multi_accelerator(self): self.assertTrue(isinstance(results.past_key_values, StaticCache)) # check device of each layer - keys_0 = results.past_key_values.layers[0].keys - values_0 = results.past_key_values.layers[0].values - self.assertTrue(keys_0.device == values_0.device == torch.device(0)) + key_cache_0 = results.past_key_values.key_cache[0] + value_cache_0 = results.past_key_values.value_cache[0] + self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) - keys_1 = results.past_key_values.layers[1].keys - values_1 = results.past_key_values.layers[1].values - self.assertTrue(keys_1.device == values_1.device == torch.device(1)) + key_cache_1 = results.past_key_values.key_cache[1] + value_cache_1 = results.past_key_values.value_cache[1] + self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) @pytest.mark.generate @require_torch_multi_accelerator @@ -4060,13 +4052,13 @@ def test_init_static_cache_multi_accelerator(self): results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) # check device of each layer - keys_0 = results.past_key_values.layers[0].keys - values_0 = results.past_key_values.layers[0].values - self.assertTrue(keys_0.device == values_0.device == torch.device(0)) + key_cache_0 = results.past_key_values.key_cache[0] + value_cache_0 = results.past_key_values.value_cache[0] + self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) - keys_1 = results.past_key_values.layers[1].keys - values_1 = results.past_key_values.layers[1].values - self.assertTrue(keys_1.device == values_1.device == torch.device(1)) + key_cache_1 = results.past_key_values.key_cache[1] + value_cache_1 = results.past_key_values.value_cache[1] + self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) @slow def test_padding_input_contrastive_search_gpt2(self): diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py index 87f7b2abb0e9..6c0c3a19d067 100644 --- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py +++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py @@ -440,11 +440,13 @@ def test_past_key_values_format(self): # difference: last dim k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim v_embed_dim = config.v_head_dim - self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) - self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) + self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) + self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) # build the full cache shapes num_hidden_layers = config.num_hidden_layers - all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)] + all_cache_shapes = [ + [self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers) + ] super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes) @require_torch_large_accelerator diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py index 34a7a9884728..f9427160c254 100644 --- a/tests/models/dia/test_modeling_dia.py +++ b/tests/models/dia/test_modeling_dia.py @@ -399,12 +399,12 @@ 
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [layer.keys.shape for layer in decoder_past_key_values.layers], - [expected_shape] * len(decoder_past_key_values.layers), + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), ) self.assertListEqual( - [layer.values.shape for layer in decoder_past_key_values.layers], - [expected_shape] * len(decoder_past_key_values.layers), + [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), ) def _check_scores(self, batch_size, scores, generated_length, config): diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index ecd2af9fdc6c..b0a0a6a3ccb4 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -235,8 +235,8 @@ def copy_cache(cache: DynamicCache): """Deep copy a DynamicCache to reuse the same one multiple times.""" new_cache = cache for i in range(len(cache)): - new_cache.layers[i].keys = cache.layers[i].keys.clone() - new_cache.layers[i].values = cache.layers[i].values.clone() + new_cache.key_cache[i] = cache.key_cache[i].clone() + new_cache.value_cache[i] = cache.value_cache[i].clone() # Cached forward once with the attention mask provided and the other time without it (which should assume full attention) # We need to run both on a copy of the cache, otherwise it is modified in-place diff --git a/tests/models/t5gemma/test_modeling_t5gemma.py b/tests/models/t5gemma/test_modeling_t5gemma.py index 269fee53d165..002fb09009d5 100644 --- a/tests/models/t5gemma/test_modeling_t5gemma.py +++ b/tests/models/t5gemma/test_modeling_t5gemma.py @@ -1077,26 +1077,26 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_keys = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys + self_attention_layer_key_cache = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] ) - self_attention_layer_values = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values + self_attention_layer_value_cache = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] ) - self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_keys = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys + cross_attention_layer_key_cache = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] ) - cross_attention_layer_values = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values + cross_attention_layer_value_cache = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] ) - cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] - cross_attention_layer_values = 
cross_attention_layer_values[:, :, 0, :] - self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) + cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] + cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] + self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) # 3.2. Decoder-only checks else: @@ -1108,10 +1108,10 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys - self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values - self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) + self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] + self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] + self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) @unittest.skip("Mismatch issue doesn't exist in T5Gemma.") def test_load_with_mismatched_shapes(self): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 1e5da8f542e6..077f16a95af8 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -679,9 +679,11 @@ def test_dynamic_cache_exportability(self): use_cache=True, ) self.assertTrue(torch.allclose(res.logits, res_eager.logits)) - for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers): - self.assertTrue(torch.allclose(l1.keys, l2.keys)) - self.assertTrue(torch.allclose(l1.values, l2.values)) + for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache): + self.assertTrue(torch.allclose(k1, k2)) + + for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache): + self.assertTrue(torch.allclose(v1, v2)) def test_dynamic_cache_exportability_multiple_run(self): # When exporting with DynamicCache, you should export two graphs: @@ -731,8 +733,8 @@ def test_dynamic_cache_exportability_multiple_run(self): dyn = torch.export.Dim("seq", max=512) for ix in range(len(past_key_values)): - shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None) - shapes[past_key_values.layers[ix].values] = (None, None, dyn, None) + shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None) + shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None) ep_second = torch.export.export( model, @@ -773,9 +775,11 @@ def test_dynamic_cache_exportability_multiple_run(self): use_cache=True, ) - for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers): - self.assertTrue(torch.allclose(l1.keys, l2.keys)) - self.assertTrue(torch.allclose(l1.values, l2.values)) + for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache): + self.assertTrue(torch.allclose(k1, k2)) + + for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache): + self.assertTrue(torch.allclose(v1, v2)) def 
test_static_cache_exportability(self): """ @@ -954,7 +958,7 @@ def test_static_cache(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" + static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" ) # Scenario 2: Fill to capacity @@ -965,7 +969,7 @@ def test_static_cache(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" + static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" ) def test_sliding_window_cache(self): @@ -1001,7 +1005,7 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed", ) @@ -1024,7 +1028,7 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed", ) @@ -1041,7 +1045,7 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "SlidingWindowCache Scenario 3 failed", ) @@ -1075,7 +1079,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), + hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Static Scenario 1 failed", ) @@ -1088,7 +1092,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), + hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "HybridCache Static Scenario 2 failed", ) @@ -1127,7 +1131,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed", ) @@ -1148,7 +1152,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed", ) @@ -1161,7 +1165,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed", ) @@ -1176,7 +1180,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": 
torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 4 failed", ) @@ -1196,10 +1200,10 @@ def test_dynamic_cache(self): cache = DynamicCache() cache.update(prefill, prefill, 0) cache.update(update3, update3, 0) - self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") + self.assertEqual(cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") cache.update(update4, update4, 0) self.assertEqual( - cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" + cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" ) # Scenario 2: prefill and update for two layers independently @@ -1216,10 +1220,10 @@ def test_dynamic_cache(self): cache.update(update4, update4, 0) cache.update(update4_1, update4_1, 1) self.assertEqual( - cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" + cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" ) self.assertEqual( - cache.layers[1].keys[0, 0, :, 0].tolist(), + cache.key_cache[1][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0], "DynamicCache Scenario 2 layer 1 failed", ) diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index eb101ab566aa..bc247b2b6011 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -956,9 +956,8 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str): idx += 1 if "".join(source[start_idx:idx])[:-1] != old_doc_args: - raise ValueError( - f"Expected\n{old_doc_args}\nbut got\n{''.join(source[start_idx:idx])[:-1]}\n in {find_source_file(obj)}: {obj.__name__}" - ) + # Args are not fully defined in the docstring of this object + return obj_file = find_source_file(obj) with open(obj_file, "r", encoding="utf-8") as f: From fd83e14bb81cdb549955932d04a6821a319a87d0 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 10 Jul 2025 10:59:08 +0200 Subject: [PATCH 08/35] remove cachebase for decorator --- src/transformers/cache_utils.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 5500b9247ebc..71f95d997a36 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -201,31 +201,26 @@ def parse_layer_args_from_model_config( return {k: v for k, v in layer_args.items() if v is not None} -class CacheBase: - layers: list[CacheLayerMixin] = None - key_cache: list[torch.Tensor] = None - value_cache: list[torch.Tensor] = None - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: +def apply_processors(fn): + def _wrapped_update(self, key_states, value_states, layer_idx, cache_kwargs=None): if self.cache_processor is not None: key_states, value_states = self.cache_processor.pre_update( self, key_states, value_states, layer_idx, cache_kwargs ) - key_tensors, value_tensors = self._update(key_states, value_states, layer_idx, cache_kwargs) + + key_tensors, value_tensors = fn(self, key_states, value_states, layer_idx, cache_kwargs) + if self.cache_processor is not None: 
key_tensors, value_tensors = self.cache_processor.post_update( self, key_tensors, value_tensors, layer_idx, cache_kwargs ) + return key_tensors, value_tensors + return _wrapped_update + -class Cache(CacheBase): +class Cache: """ Base, abstract class for all caches. The actual data structure is specific to the layers. This class handles propagation of operations across layers. @@ -236,6 +231,10 @@ class Cache(CacheBase): - SlidingWindow layers are limited to sliding window size, Static layers use full max_cache_len """ + layers: list[CacheLayerMixin] = None + key_cache: list[torch.Tensor] = None + value_cache: list[torch.Tensor] = None + def __init__( self, model_config: Optional[PretrainedConfig] = None, @@ -374,7 +373,8 @@ def append_new_layers(self, layer_idx): self.key_cache.append(new_key) self.value_cache.append(new_value) - def _update( + @apply_processors + def update( self, key_states: torch.Tensor, value_states: torch.Tensor, From c2004471a42cfbb74ce0f9ce2e3830ad5b5bdf1c Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 10 Jul 2025 11:37:25 +0200 Subject: [PATCH 09/35] no more __getattr__ --- src/transformers/cache_utils.py | 159 ++++++++++++++++-------------- src/transformers/masking_utils.py | 4 +- 2 files changed, 88 insertions(+), 75 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 71f95d997a36..79e38bfadab1 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -231,10 +231,6 @@ class Cache: - SlidingWindow layers are limited to sliding window size, Static layers use full max_cache_len """ - layers: list[CacheLayerMixin] = None - key_cache: list[torch.Tensor] = None - value_cache: list[torch.Tensor] = None - def __init__( self, model_config: Optional[PretrainedConfig] = None, @@ -312,49 +308,6 @@ def __len__(self): ) return len(self.key_cache) if not dynamic_empty else 0 - def __getattr__(self, name): - """ - Forward a *shared* operation or property to every layer. - - - If every layer defines a callable `name`, we return a proxy that - calls each layer in turn and updates `key_cache` / `value_cache`. - - If every layer defines a boolean `name`, we return `all(values)`. - - If every layer defines a constant property `name`, we return it. - - Otherwise we raise AttributeError. - """ - if not self.layers: - return object.__getattribute__(self, name) - - # 1) Gather the attribute from every layer - values = [] - for layer in self.layers: - values.append(getattr(layer, name, None)) - - # 2) All callables → make a forwarding function - if all(callable(v) for v in values): - - def _proxy(*args, **kwargs): - for i in range(len(self.layers)): - self.key_cache[i], self.value_cache[i], ret = values[i]( - key_cache=self.key_cache[i], value_cache=self.value_cache[i], *args, **kwargs - ) - if ret is not None: - break - return ret - - return _proxy - - # 3) All booleans → reduce with `all` - if all(isinstance(v, bool) for v in values): - return all(values) - - # 4) All identical → return first - if all(v == values[0] for v in values): - return values[0] - - # 5) Anything else → unsupported mixed attribute - raise AttributeError(f"{self.__class__.__name__}: layers disagree on attribute {name!r}: {values}") - def __repr__(self): return f"{self.__class__.__name__}(layers={self.layers})" @@ -415,9 +368,7 @@ def get_seq_length(self, layer_idx: int = 0) -> int: """Returns the sequence length of the cache for the given layer. 
TODO: deprecate in favor of cache_position""" if layer_idx >= len(self.layers): return 0 - _, _, seq_length = self.layers[layer_idx].get_seq_length( - self.key_cache[layer_idx], self.value_cache[layer_idx] - ) + seq_length = self.layers[layer_idx].get_seq_length(self.key_cache[layer_idx], self.value_cache[layer_idx]) # Hack since QuantizedCache messes with keys shape as it becomes the residual cache if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): return self.cache_processor.erased_length + seq_length @@ -430,11 +381,69 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), for each layer. """ - _, _, (kv_length, kv_offset) = self.layers[layer_idx].get_mask_sizes( + kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes( cache_position, self.key_cache[layer_idx], self.value_cache[layer_idx] ) return kv_length, kv_offset + ### Wrappers for layer operations and properties ### + + def get_max_cache_shape(self, layer_idx: int = 0) -> int: + """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" + return self.layers[layer_idx].get_max_cache_shape(self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def reset(self): + for layer_idx in range(len(self.layers)): + self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].reset( + self.key_cache[layer_idx], self.value_cache[layer_idx] + ) + + def reorder_cache(self, beam_idx: torch.LongTensor): + for layer_idx in range(len(self.layers)): + self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].reorder_cache( + beam_idx, self.key_cache[layer_idx], self.value_cache[layer_idx] + ) + + def crop(self, max_length: int): + for layer_idx in range(len(self.layers)): + self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].crop( + max_length, self.key_cache[layer_idx], self.value_cache[layer_idx] + ) + + def batch_repeat_interleave(self, repeats: int): + for layer_idx in range(len(self.layers)): + self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].batch_repeat_interleave( + repeats, self.key_cache[layer_idx], self.value_cache[layer_idx] + ) + + def batch_select_indices(self, indices: torch.Tensor): + for layer_idx in range(len(self.layers)): + self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].batch_select_indices( + indices, self.key_cache[layer_idx], self.value_cache[layer_idx] + ) + + @property + def max_batch_size(self) -> int: + values = [layer.max_batch_size for layer in self.layers] + if len(set(values)) > 1: + raise ValueError(f"Max batch size is not consistent across layers: {values}") + return values[0] + + @property + def max_cache_len(self) -> int: + values = [layer.max_cache_len for layer in self.layers] + if len(set(values)) > 1: + raise ValueError(f"Max cache length is not consistent across layers: {values}") + return values[0] + + @property + def is_compileable(self) -> bool: + return all(layer.is_compileable for layer in self.layers) + + @property + def is_sliding(self) -> bool: + return all(layer.is_sliding for layer in self.layers) + class DynamicLayer(CacheLayerMixin): """ @@ -474,7 +483,7 @@ def update( def new_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: """Returns a new key and value tensor for this layer's cache.""" - return None, 
None # They get initialized in the update() + return None, None # They are initialized in update() def get_seq_length( self, @@ -485,8 +494,8 @@ def get_seq_length( """Returns the sequence length of the cached states.""" # TODO: deprecate this function in favor of `cache_position` if key_cache is None or key_cache.numel() == 0: - return key_cache, value_cache, 0 - return key_cache, value_cache, key_cache.shape[-2] + return 0 + return key_cache.shape[-2] def get_max_cache_shape( self, @@ -494,7 +503,7 @@ def get_max_cache_shape( value_cache: torch.Tensor, ) -> int: """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" - return key_cache, value_cache, -1 + return -1 def reset( self, @@ -504,7 +513,7 @@ def reset( """Resets the cache values while preserving the objects""" key_cache.zero_() value_cache.zero_() - return key_cache, value_cache, None + return key_cache, value_cache def reorder_cache( self, @@ -516,7 +525,7 @@ def reorder_cache( if key_cache is not None and key_cache.numel(): key_cache = key_cache.index_select(0, beam_idx.to(key_cache.device)) value_cache = value_cache.index_select(0, beam_idx.to(value_cache.device)) - return key_cache, value_cache, None + return key_cache, value_cache def crop( self, @@ -529,15 +538,15 @@ def crop( negative to remove `max_length` tokens. """ if max_length < 0: - max_length = self.get_seq_length(key_cache, value_cache)[-1] - abs(max_length) + max_length = self.get_seq_length(key_cache, value_cache) - abs(max_length) - if self.get_seq_length(key_cache, value_cache)[-1] <= max_length: - return key_cache, value_cache, None + if self.get_seq_length(key_cache, value_cache) <= max_length: + return key_cache, value_cache if key_cache is not None and key_cache.numel(): key_cache = key_cache[..., :max_length, :] value_cache = value_cache[..., :max_length, :] - return key_cache, value_cache, None + return key_cache, value_cache def batch_repeat_interleave( self, repeats: int, key_cache: torch.Tensor, value_cache: torch.Tensor @@ -546,7 +555,7 @@ def batch_repeat_interleave( if key_cache.numel(): key_cache = key_cache.repeat_interleave(repeats, dim=0) value_cache = value_cache.repeat_interleave(repeats, dim=0) - return key_cache, value_cache, None + return key_cache, value_cache def batch_select_indices( self, indices: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor @@ -555,7 +564,7 @@ def batch_select_indices( if key_cache.numel(): key_cache = key_cache[indices, ...] value_cache = value_cache[indices, ...] 
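As a quick, hedged illustration of the per-layer batch operations wrapped above (crop, repeat-interleave, index selection), here is a minimal sketch using only the public `DynamicCache` API; shapes and values are invented for the example.

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
k = v = torch.randn(2, 4, 6, 8)                    # [batch, heads, seq, head_dim]
cache.update(k, v, layer_idx=0)

cache.crop(4)                                      # keep only the first 4 cached positions
print(cache.get_seq_length())                      # 4

cache.batch_repeat_interleave(3)                   # expand the batch, e.g. for assisted decoding
print(cache[0][0].shape[0])                        # 6

cache.batch_select_indices(torch.tensor([0, 3]))   # keep a subset of the batch
print(cache[0][0].shape[0])                        # 2
```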
- return key_cache, value_cache, None + return key_cache, value_cache def get_mask_sizes( self, @@ -565,9 +574,9 @@ def get_mask_sizes( ) -> tuple[int, int]: full_mask_kv_offset = 0 query_length = cache_position.shape[0] - _, _, past_seen_tokens = self.get_seq_length(key_cache, value_cache, cache_position) + past_seen_tokens = self.get_seq_length(key_cache, value_cache, cache_position) kv_length = query_length + past_seen_tokens - return key_cache, value_cache, (kv_length, full_mask_kv_offset) + return kv_length, full_mask_kv_offset class DynamicCache(Cache): @@ -740,7 +749,7 @@ def __init__( self.device = device def get_max_cache_shape(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> int: - return key_cache, value_cache, self.max_cache_len + return self.max_cache_len def _static_update( self, @@ -819,12 +828,12 @@ def get_seq_length(self, key_cache: torch.Tensor, value_cache: torch.Tensor, cac # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. seq_length = (key_cache[0, 0].any(dim=-1)).sum() if key_cache is not None else 0 - return key_cache, value_cache, seq_length + return seq_length def reset(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: key_cache.zero_() value_cache.zero_() - return key_cache, value_cache, None + return key_cache, value_cache def reorder_cache( self, beam_idx: torch.LongTensor, key_cache: torch.Tensor, value_cache: torch.Tensor @@ -833,14 +842,14 @@ def reorder_cache( beam_idx_dev = beam_idx.to(dev) key_cache = key_cache.index_select(0, beam_idx_dev) value_cache = value_cache.index_select(0, beam_idx_dev) - return key_cache, value_cache, None + return key_cache, value_cache def get_mask_sizes( self, cache_position: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor ) -> tuple[int, int]: full_mask_kv_offset = 0 full_mask_kv_length = self.max_cache_len - return key_cache, value_cache, (full_mask_kv_length, full_mask_kv_offset) + return full_mask_kv_length, full_mask_kv_offset class StaticCache(Cache): @@ -950,7 +959,7 @@ def get_mask_sizes( local_mask_kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0) # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns local_mask_kv_length = max(query_length, self.max_cache_len) - return key_cache, value_cache, (local_mask_kv_length, local_mask_kv_offset) + return local_mask_kv_length, local_mask_kv_offset class SlidingWindowCache(Cache): @@ -1019,6 +1028,9 @@ class EncoderDecoderCache(Cache): """ + # Override @property from Cache + is_compileable = None + def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): super().__init__() self.self_attention_cache = self_attention_cache @@ -1241,8 +1253,9 @@ class HybridChunkedCache(Cache): is_compileable = True # Override @property since HybridChunked does not conform to layered caches yet - key_cache = None - value_cache = None + is_sliding = None + max_batch_size = None + max_cache_len = None def __init__( self, diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index f32218a15b18..468c4154f479 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -693,7 +693,7 @@ def create_causal_mask( """ # If we have an HybridCache structure, here we want to create the mask for the full layers is_sliding = [] - if past_key_values is not 
None and past_key_values.layers is not None: + if past_key_values is not None and getattr(past_key_values, "layers", None) is not None: is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] layer_idx = is_sliding.index(True) if True in is_sliding else 0 @@ -775,7 +775,7 @@ def create_sliding_window_causal_mask( """ # If we have an HybridCache structure, here we want to create the mask for the sliding layers is_sliding = [] - if past_key_values is not None and past_key_values.layers is not None: + if past_key_values is not None and getattr(past_key_values, "layers", None) is not None: is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] layer_idx = is_sliding.index(True) if True in is_sliding else 0 From 5b1b1f17d2765e6e8ecf4b9dba90d447962ed4b6 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 10 Jul 2025 12:30:55 +0200 Subject: [PATCH 10/35] fix tests --- src/transformers/cache_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c2518ff44622..db0a86b3aec9 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -309,8 +309,9 @@ def __len__(self): Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds to the number of layers in the model. """ + # Empty dynamic caches initialize an empty layer to be ready for first update dynamic_empty = ( - self.layers is not None + getattr(self, "layers", None) is not None and len(self.layers) == 1 and isinstance(self.layers[0], DynamicLayer) and self.key_cache[0] is None From 58dbcfe25a2fae38aa8bd0c8deb93cdeb14cca3c Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 11 Jul 2025 14:31:06 +0200 Subject: [PATCH 11/35] joaos review except docs --- src/transformers/cache_utils.py | 2729 +++++++++-------- .../generation/configuration_utils.py | 10 + src/transformers/generation/utils.py | 2 +- src/transformers/integrations/executorch.py | 6 +- src/transformers/masking_utils.py | 24 +- tests/utils/test_cache_utils.py | 36 +- 6 files changed, 1456 insertions(+), 1351 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index db0a86b3aec9..a3d24bebac75 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1,11 +1,13 @@ import copy +import functools import importlib.metadata +import inspect import json import os import warnings from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch from packaging import version @@ -23,194 +25,20 @@ logger = logging.get_logger(__name__) -class CacheProcessor: - """ - Base class for cache processors that can be applied to modify cache behavior. - This class should be subclassed. - """ - - def init(self, cache: "Cache", **kwargs) -> None: - """ - Initialize the processor and perform compatibility checks with the cache. - - Args: - cache (`Cache`): The cache instance this processor will be applied to. - **kwargs: Additional arguments that may be needed for initialization. 
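To make the `CacheProcessor` hook mechanism being reworked in this hunk more concrete, here is a self-contained sketch of a processor following the `pre_update`/`post_update` interface shown in the diff; the class name and the dtype-casting behavior are illustrative assumptions, not part of the patch.

```python
import torch


class CastingProcessor:
    """Illustrative processor: casts incoming key/value states before they are cached."""

    def __init__(self, dtype: torch.dtype = torch.float16):
        self.dtype = dtype

    def pre_update(self, cache, key_states, value_states, layer_idx, cache_kwargs=None):
        # Runs before the layer update: transform the states that will be stored.
        return key_states.to(self.dtype), value_states.to(self.dtype)

    def post_update(self, cache, key_tensors, value_tensors, layer_idx, cache_kwargs=None):
        # Runs after the layer update: here the cached tensors pass through unchanged.
        return key_tensors, value_tensors
```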
- """ - raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") - - def pre_update( +def apply_processors( + fn: Callable[..., tuple[torch.Tensor, torch.Tensor]], +) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: + @functools.wraps(fn) + def _wrapped_update( self, - cache: "Cache", key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Function called before the cache update. Can modify the key/value states. - - Args: - cache (`Cache`): The cache instance. - key_states (`torch.Tensor`): The new key states to cache. - value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. - """ - return key_states, value_states - - def post_update( - self, - cache: "Cache", - key_tensors: torch.Tensor, - value_tensors: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Function called after the cache update. Can process the cached data. - - Args: - cache (`Cache`): The cache instance. - key_states (`torch.Tensor`): The key states that were cached. - value_states (`torch.Tensor`): The value states that were cached. - layer_idx (`int`): The index of the layer that was updated. - cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return. + Wrapper around the update method to apply cache processors. """ - return key_tensors, value_tensors - - -class CacheLayerMixin: - """Base, abstract class for a single layer's cache.""" - - is_compileable = False - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - key_cache: Optional[torch.Tensor] = None, - value_cache: Optional[torch.Tensor] = None, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Updates KV cache, returns updated keys/values for this layer.""" - raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") - - def get_seq_length( - self, key_cache: Optional[torch.Tensor] = None, value_cache: Optional[torch.Tensor] = None - ) -> int: - """Returns the sequence length of this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") - - def get_max_cache_shape(self) -> int: - """Returns the maximum sequence length (i.e. 
max capacity) of this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") - - def reset(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Resets this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.") - - def get_mask_sizes( - self, - cache_position: torch.Tensor, - key_cache: Optional[torch.Tensor] = None, - value_cache: Optional[torch.Tensor] = None, - ) -> tuple[int, int]: - """Returns mask sizes for this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") - - def new_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: - """Returns a new key and value tensor for this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `new_tensors` in {self.__class__.__name__}.") - - def reorder_cache( - self, - beam_idx: torch.LongTensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Reorders this layer's cache for beam search.""" - if key_cache.numel(): - device = key_cache.device - key_cache = key_cache.index_select(0, beam_idx.to(device)) - if value_cache.numel(): - device = value_cache.device - value_cache = value_cache.index_select(0, beam_idx.to(device)) - return key_cache, value_cache, None - - -def parse_layer_args_from_model_config( - model_config: Optional[PretrainedConfig], - batch_size: Optional[int] = None, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: Optional[torch.dtype] = None, - layer_device_map=None, - tp_size: Optional[int] = None, - max_batch_size: Optional[int] = None, -) -> dict: - # No model config -> must be a dynamic cache, return bare dict - if model_config is None: - return {} - # Build the args dict for hybrid, sliding or static - else: - # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) - if ( - getattr(model_config, "layer_types", None) is not None - and "sliding_attention" in model_config.layer_types - and "full_attention" in model_config.layer_types - ): - if getattr(model_config, "sliding_window", None) is None: - raise ValueError( - "Setting up a hybrid or sliding window KVCache requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." - ) - # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) - max_cache_len = max_cache_len or model_config.max_position_embeddings - if getattr(model_config, "sliding_window", None) is not None: - sliding_window_len = min(model_config.sliding_window, max_cache_len) - else: - sliding_window_len = None - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: - head_dim = ( - model_config.head_dim - if getattr(model_config, "head_dim", None) is not None - else model_config.hidden_size // model_config.num_attention_heads - ) - num_heads = ( - model_config.num_attention_heads - if getattr(model_config, "num_key_value_heads", None) is None - else model_config.num_key_value_heads - ) - if tp_size is not None and tp_size > 1: - if num_heads % tp_size != 0: - raise ValueError( - f"Number of key value heads {num_heads} must be divisible by tensor parallel size {tp_size}." 
- ) - # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - num_heads //= tp_size - layer_args = { - "batch_size": max_batch_size if max_batch_size is not None else batch_size, - "max_cache_len": max_cache_len, - "device": torch.device(device) if device is not None else None, - "dtype": dtype, - "layer_device_map": layer_device_map, - "head_dim": head_dim, - "num_heads": num_heads, - "sliding_window": sliding_window_len, - } - return {k: v for k, v in layer_args.items() if v is not None} - - -def apply_processors(fn): - def _wrapped_update(self, key_states, value_states, layer_idx, cache_kwargs=None): if self.cache_processor is not None: key_states, value_states = self.cache_processor.pre_update( self, key_states, value_states, layer_idx, cache_kwargs @@ -230,61 +58,60 @@ def _wrapped_update(self, key_states, value_states, layer_idx, cache_kwargs=None class Cache: """ - Base, abstract class for all caches. The actual data structure is specific to the layers. + Base class for all caches. + The actual data structure is specific to the layers. This class handles propagation of operations across layers. - Note for hybrid caches (blocks of (StaticLayer, ..., SlidingWindowLayer) repeated across layers): - - Requires `model_config.sliding_window` to be set - - Uses `sliding_window_pattern` (default: 2) to determine layer alternation if pattern not specified - - SlidingWindow layers are limited to sliding window size, Static layers use full max_cache_len + Parameters: + config (`PretrainedConfig`): + Model configuration for shape/device info. + cache_processor (`CacheProcessor` or `str`, *optional*): + Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") or + a CacheProcessor class. + layer_classes (`list[type[CacheLayer]]`, *optional*): + List of layer classes to use for the cache. + Additional arguments for cache configuration: + max_batch_size/batch_size (`int`): Maximum batch size for static caches + max_cache_len (`int`): Maximum sequence length. For hybrid caches: + * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` + * StaticLayers: uses full `max_cache_len` + device (`torch.device`): Device for cache tensors + dtype (`torch.dtype`): Data type for cache tensors + layer_device_map (`dict[int, Union[str, torch.device]]`): Per-layer device mapping + tp_size (`int`): Tensor parallel size to adjust the number of key/value heads + device (`torch.device`): Device for cache tensors + dtype (`torch.dtype`): Data type for cache tensors + layer_device_map (`dict[int, Union[str, torch.device]]`): Per-layer device mapping + tp_size (`int`): Tensor parallel size to adjust the number of key/value heads """ def __init__( self, - model_config: Optional[PretrainedConfig] = None, - cache_processor: Optional[CacheProcessor] = None, - layer_classes: Optional[list[type[CacheLayerMixin]]] = None, + config: Optional[PretrainedConfig] = None, + cache_processor: Optional[Union[str, type["CacheProcessor"]]] = None, + layer_classes: Optional[list[type["CacheLayerMixin"]]] = None, *args, **kwargs, ): - """ - Parameters: - model_config (`PretrainedConfig`): - Model configuration for shape/device info. - cache_processor (`CacheProcessor`, *optional*): - Cache processor to apply (e.g., quantization, offloading). - layer_classes (`list[type[CacheLayer]]`, *optional*): - List of layer classes to use for the cache. 
- Additional arguments for cache configuration: - - `max_batch_size`/`batch_size` (`int`): Maximum batch size for static caches - - `max_cache_len` (`int`): Maximum sequence length. For hybrid caches: - * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` - * StaticLayers: uses full `max_cache_len` - - `device` (`torch.device`): Device for cache tensors - - `dtype` (`torch.dtype`): Data type for cache tensors - - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping - - `tp_size` (`int`): Tensor parallel size to adjust the number of key/value heads - """ - self.layers: list[CacheLayerMixin] = [] + self.layers: list["CacheLayerMixin"] = [] self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] - self.cache_processor = cache_processor + processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor if ( layer_classes is None # setting layer_classes takes precedence - and model_config is not None - and getattr(model_config, "layer_types", None) is not None + and config is not None + and getattr(config, "layer_types", None) is not None ): - layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in model_config.layer_types] + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] self.layer_classes = layer_classes or [DynamicLayer] - self.layer_init_args = parse_layer_args_from_model_config(model_config, *args, **kwargs) - self.model_num_layers = getattr(model_config, "num_hidden_layers", 1) + processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) + self.layer_init_args = parse_layer_args_from_model_config(config, *args, **kwargs) + self.model_num_layers = getattr(config, "num_hidden_layers", 1) self.append_new_layers(self.model_num_layers - 1) - - if self.cache_processor is not None: - self.cache_processor.init(self, **kwargs) + self.cache_processor = processor_class(self, **processor_kwargs) if processor_class is not None else None def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -321,18 +148,22 @@ def __len__(self): def __repr__(self): return f"{self.__class__.__name__}(layers={self.layers})" - def append_new_layers(self, layer_idx): + def append_new_layers(self, layer_idx: int) -> None: """ Appends layers to the cache until the layer `layer_idx` is reached. - Used in prefill and for skipped layers. + Used for preallocation in static caches and on the fly in dynamic caches. + + Args: + layer_idx (`int`): + The index of the layer to append. 
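The lazy growth implemented by `append_new_layers` is what lets a `DynamicCache` start empty and add layers as the model runs. A minimal sketch with toy shapes, assuming the public `DynamicCache` API:

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
k = v = torch.randn(1, 2, 5, 8)             # [batch, heads, seq, head_dim]
cache.update(k, v, layer_idx=0)             # first update creates layer 0
cache.update(k, v, layer_idx=1)             # further layers are appended on demand
print(len(cache), cache.get_seq_length())   # 2 5
```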
""" while len(self.layers) <= layer_idx: args = self.layer_init_args.copy() if self.layer_init_args.get("layer_device_map", None) is not None: args["device"] = args.pop("layer_device_map")[layer_idx] new_layer = self.layer_classes[layer_idx % len(self.layer_classes)](**args) - new_key, new_value = new_layer.new_tensors() self.layers.append(new_layer) + new_key, new_value = new_layer.new_tensors() self.key_cache.append(new_key) self.value_cache.append(new_value) @@ -403,30 +234,35 @@ def get_max_cache_shape(self, layer_idx: int = 0) -> int: return self.layers[layer_idx].get_max_cache_shape(self.key_cache[layer_idx], self.value_cache[layer_idx]) def reset(self): + """Recursively reset all layers tensors""" for layer_idx in range(len(self.layers)): self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].reset( self.key_cache[layer_idx], self.value_cache[layer_idx] ) def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorder the cache for beam search""" for layer_idx in range(len(self.layers)): self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].reorder_cache( beam_idx, self.key_cache[layer_idx], self.value_cache[layer_idx] ) def crop(self, max_length: int): + """Crop the cache to the given length""" for layer_idx in range(len(self.layers)): self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].crop( max_length, self.key_cache[layer_idx], self.value_cache[layer_idx] ) def batch_repeat_interleave(self, repeats: int): + """Repeat and interleave the cache""" for layer_idx in range(len(self.layers)): self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].batch_repeat_interleave( repeats, self.key_cache[layer_idx], self.value_cache[layer_idx] ) def batch_select_indices(self, indices: torch.Tensor): + """Select indices from the cache""" for layer_idx in range(len(self.layers)): self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].batch_select_indices( indices, self.key_cache[layer_idx], self.value_cache[layer_idx] @@ -434,6 +270,7 @@ def batch_select_indices(self, indices: torch.Tensor): @property def max_batch_size(self) -> int: + """Return the maximum batch size of the cache""" values = [layer.max_batch_size for layer in self.layers] if len(set(values)) > 1: raise ValueError(f"Max batch size is not consistent across layers: {values}") @@ -441,6 +278,7 @@ def max_batch_size(self) -> int: @property def max_cache_len(self) -> int: + """Return the maximum cache length of the cache""" values = [layer.max_cache_len for layer in self.layers] if len(set(values)) > 1: raise ValueError(f"Max cache length is not consistent across layers: {values}") @@ -448,11 +286,72 @@ def max_cache_len(self) -> int: @property def is_compileable(self) -> bool: + """Return whether the cache is compileable""" return all(layer.is_compileable for layer in self.layers) @property - def is_sliding(self) -> bool: - return all(layer.is_sliding for layer in self.layers) + def is_sliding(self) -> list[bool]: + """Return whether the layers of the cache are sliding window""" + return [getattr(layer, "is_sliding", False) for layer in self.layers] + + +class CacheLayerMixin: + """Base, abstract class for a single layer's cache.""" + + is_compileable = False + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Updates KV cache, 
returns updated keys/values for this layer.""" + raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") + + def get_seq_length( + self, key_cache: Optional[torch.Tensor] = None, value_cache: Optional[torch.Tensor] = None + ) -> int: + """Returns the sequence length of this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") + + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length (i.e. max capacity) of this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") + + def reset(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Resets this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.") + + def get_mask_sizes( + self, + cache_position: torch.Tensor, + key_cache: Optional[torch.Tensor] = None, + value_cache: Optional[torch.Tensor] = None, + ) -> tuple[int, int]: + """Returns mask sizes for this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") + + def new_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: + """Returns a new key and value tensor for this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `new_tensors` in {self.__class__.__name__}.") + + def reorder_cache( + self, + beam_idx: torch.LongTensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Reorders this layer's cache for beam search.""" + if key_cache.numel(): + device = key_cache.device + key_cache = key_cache.index_select(0, beam_idx.to(device)) + if value_cache.numel(): + device = value_cache.device + value_cache = value_cache.index_select(0, beam_idx.to(device)) + return key_cache, value_cache class DynamicLayer(CacheLayerMixin): @@ -465,8 +364,8 @@ def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - key_cache: Optional[torch.Tensor] = None, - value_cache: Optional[torch.Tensor] = None, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -484,16 +383,15 @@ def update( A tuple containing the updated key and value states. 
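Seen through the public `DynamicCache` API, the concatenation performed by `DynamicLayer.update` behaves as in this small sketch (toy shapes, a single layer):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
prefill = torch.randn(1, 2, 4, 8)           # prompt: 4 positions
step = torch.randn(1, 2, 1, 8)              # one generated token

cache.update(prefill, prefill, layer_idx=0)
keys, values = cache.update(step, step, layer_idx=0)
print(keys.shape)                           # torch.Size([1, 2, 5, 8]): new states concatenated on the seq axis
```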
""" if key_cache is None: - key_cache = key_states - value_cache = value_states + key_cache, value_cache = key_states, value_states else: key_cache = torch.cat([key_cache, key_states], dim=-2) value_cache = torch.cat([value_cache, value_states], dim=-2) return key_cache, value_cache - def new_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: + def new_tensors(self) -> tuple[None, None]: """Returns a new key and value tensor for this layer's cache.""" - return None, None # They are initialized in update() + return None, None def get_seq_length( self, @@ -582,6 +480,7 @@ def get_mask_sizes( key_cache: Optional[torch.Tensor] = None, value_cache: Optional[torch.Tensor] = None, ) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the mask""" full_mask_kv_offset = 0 query_length = cache_position.shape[0] past_seen_tokens = self.get_seq_length(key_cache, value_cache, cache_position) @@ -589,152 +488,67 @@ def get_mask_sizes( return kv_length, full_mask_kv_offset -class DynamicCache(Cache): +class CacheProcessor: + """ + Base class for cache processors that can be applied to modify cache behavior. + This class should be subclassed. """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Example: - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + def __init__(self, cache: "Cache", **kwargs) -> None: + """ + Initialize the processor and perform compatibility checks with the cache. - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + Args: + cache (`Cache`): The cache instance this processor will be applied to. + **kwargs: Additional arguments that may be needed for initialization. + """ + raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = DynamicCache() - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - DynamicCache() - ``` - """ + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Function called before the cache update. Can modify the key/value states. - # Specialized constructor for DDP cache data, needed for BC - def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): - # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 - # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the - # iterable contains the key and value states for a layer gathered across replicas by torch.distributed - # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break - # compatibility. The name of the argument doesn't matter. 
- if ddp_cache_data is not None: - for key_states, value_states in ddp_cache_data: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - self.layers.append(DynamicLayer()) - super().__init__(*args, **kwargs) + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + layer_idx (`int`): The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: - """ - Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility. + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. """ - legacy_cache = () - for keys, values in zip(self.key_cache, self.value_cache): - legacy_cache += ((keys, values),) - return legacy_cache + return key_states, value_states - @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None - ) -> "Cache": - """ - Converts a cache in the legacy cache format into an equivalent `Cache`. Used for - backward compatibility. + def post_update( + self, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - -# Utilities for `DynamicCache` <> torch.export support -def _flatten_dynamic_cache( - dynamic_cache: DynamicCache, -): - """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" - if not isinstance(dynamic_cache, DynamicCache): - raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") - - if not is_torch_greater_or_equal_than_2_6: - logger.warning_once( - "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." 
- ) - - dictionary = { - "key_cache": dynamic_cache.key_cache if dynamic_cache.key_cache[0] is not None else [], - "value_cache": dynamic_cache.value_cache if dynamic_cache.value_cache[0] is not None else [], - } - return torch.utils._pytree._dict_flatten(dictionary) - - -def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): - dictionary = { - "key_cache": dynamic_cache.key_cache if dynamic_cache.key_cache[0] is not None else [], - "value_cache": dynamic_cache.value_cache if dynamic_cache.value_cache[0] is not None else [], - } - return torch.utils._pytree._dict_flatten_with_keys(dictionary) - - -def _unflatten_dynamic_cache( - values, - context: torch.utils._pytree.Context, -): - dictionary = torch.utils._pytree._dict_unflatten(values, context) - cache = DynamicCache() - # Reconstruct layers from keys and values lists - key_list = dictionary.get("key_cache", []) - value_list = dictionary.get("value_cache", []) - for idx in range(max(len(key_list), len(value_list))): - key = key_list[idx] if idx < len(key_list) else None - value = value_list[idx] if idx < len(value_list) else None - cache.update(key, value, idx) - return cache - - -def _flatten_dynamic_cache_for_fx(cache, spec): - dictionary = { - "key_cache": cache.key_cache if cache.key_cache[0] is not None else [], - "value_cache": cache.value_cache if cache.value_cache[0] is not None else [], - } - return torch.fx._pytree._dict_flatten_spec(dictionary, spec) - - -if is_torch_greater_or_equal("2.3"): - torch.utils._pytree.register_pytree_node( - DynamicCache, - _flatten_dynamic_cache, - _unflatten_dynamic_cache, - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, - ) - # TODO (tmanlaibaatar) This won't be needed in torch 2.7. - torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) - - -class OffloadedCache(DynamicCache): - """ - A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. - Useful for generating from models with very long context. + Function called after the cache update. Can process the cached data. - In addition to the default accelerator stream, where all forward() computations happen, - this class uses another stream, the prefetch stream, which it creates itself. - Since scheduling of operations on separate streams happens independently, this class uses - the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. - The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to - ensure the eviction is scheduled after all computations on that cache are finished. - """ + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The key states that were cached. + value_states (`torch.Tensor`): The value states that were cached. + layer_idx (`int`): The index of the layer that was updated. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - def __init__(self, model_config: Optional[PretrainedConfig] = None) -> None: - # Create the underlying cache with offload processor - super().__init__(cache_processor=OffloadedCacheProcessor(), model_config=model_config) + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return. 
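Returning to the `OffloadedCache` described a bit earlier in this hunk: in practice it is usually requested through `generate()`. A hedged usage sketch, assuming the public `cache_implementation="offloaded"` option and an example model id:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative only: "offloaded" selects the CPU-offloaded KV cache described above.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

inputs = tokenizer("My name is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="offloaded")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```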
+ """ + return key_tensors, value_tensors class StaticLayer(CacheLayerMixin): @@ -759,6 +573,7 @@ def __init__( self.device = device def get_max_cache_shape(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> int: + """Return the maximum cache shape of the cache""" return self.max_cache_len def _static_update( @@ -769,7 +584,6 @@ def _static_update( key_cache: torch.Tensor, value_cache: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - # Utility functions for static/sliding cache update logic """ Updates the static cache tensors in place. @@ -804,10 +618,11 @@ def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - key_cache: Optional[torch.Tensor] = None, - value_cache: Optional[torch.Tensor] = None, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + """Update the static cache tensors in place""" cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None return self._static_update( key_states.to(key_cache.dtype), value_states.to(value_cache.dtype), cache_position, key_cache, value_cache @@ -862,34 +677,6 @@ def get_mask_sizes( return full_mask_kv_length, full_mask_kv_offset -class StaticCache(Cache): - """ - Static Cache class to be used with `torch.compile(model)` and `torch.export()`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache - - >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - - >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - StaticCache() - ``` - """ - - def __init__(self, *args, **kwargs): - super().__init__(layer_classes=[StaticLayer], *args, **kwargs) - - class SlidingWindowLayer(StaticLayer): """ A static cache layer that implements sliding window attention caching. @@ -972,1132 +759,1452 @@ def get_mask_sizes( return local_mask_kv_length, local_mask_kv_offset -class SlidingWindowCache(Cache): +class DynamicCache(Cache): """ - Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. - Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window - 1`, - if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), - we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - - The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: + A cache that grows dynamically as more tokens are generated. This is the default for generative models. 
- indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window - tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. - We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) Example: ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache - >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = DynamicCache() >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation - SlidingWindowCache() + DynamicCache() ``` """ - def __init__(self, *args, **kwargs): - super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) - - -class EncoderDecoderCache(Cache): - """ - Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and - cross-attention caches. + # Specialized constructor for DDP cache data, needed for BC + def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 + # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the + # iterable contains the key and value states for a layer gathered across replicas by torch.distributed + # (shape=[global batch size, num_heads, seq_len, head_dim]). + # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break + # compatibility. The name of the argument doesn't matter. 
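A small worked example of the cyclic-shift indices quoted in the sliding-window docstring above, using a toy window of 4 instead of 64; the variable names are local to this sketch and the arithmetic mirrors the formula in the diff:

```python
import torch

sliding_window = 4
cache_position = torch.tensor([5])                                 # writing a token past the window
slicing = torch.ones(sliding_window, dtype=torch.long).cumsum(0)   # tensor([1, 2, 3, 4])
to_shift = cache_position >= sliding_window - 1                    # tensor([True])
indices = (slicing + to_shift[-1].sum() - 1) % sliding_window
print(indices)  # tensor([1, 2, 3, 0]): roll the cache left by one; the new token then fills the freed last slot
```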
+ if ddp_cache_data is not None: + for key_states, value_states in ddp_cache_data: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.layers.append(DynamicLayer()) + super().__init__(*args, **kwargs) - Example: - - ```python - >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") - >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") - - >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") - - >>> # Prepare cache classes for encoder and decoder and pass it to model's forward - >>> self_attention_cache = DynamicCache() - >>> cross_attention_cache = DynamicCache() - >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - EncoderDecoderCache() - ``` - - """ - - # Override @property from Cache - is_compileable = None - - def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): - super().__init__() - self.self_attention_cache = self_attention_cache - self.cross_attention_cache = cross_attention_cache - self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) - - self.is_updated = {} - for layer_idx in range(len(cross_attention_cache)): - self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) - - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], - ) - - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], - ) - else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.self_attention_cache) - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]: - """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" - legacy_cache = () - if len(self.cross_attention_cache) > 0: - for self_attn, cross_attn in zip( - self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache() - ): - legacy_cache += (self_attn + cross_attn,) - else: - legacy_cache = self.self_attention_cache.to_legacy_cache() - return legacy_cache + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: + """ + Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility. 
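+        The result is a tuple with one `(key, value)` pair per layer, i.e. `((key_0, value_0), (key_1, value_1), ...)`,
+        where each tensor has shape `[batch_size, num_heads, seq_len, head_dim]`.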
+ """ + legacy_cache = () + for keys, values in zip(self.key_cache, self.value_cache): + legacy_cache += ((keys, values),) + return legacy_cache @classmethod def from_legacy_cache( cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None - ) -> "EncoderDecoderCache": - """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" - cache = cls( - self_attention_cache=DynamicCache(), - cross_attention_cache=DynamicCache(), - ) + ) -> "Cache": + """ + Converts a cache in the legacy cache format into an equivalent `Cache`. Used for + backward compatibility. + """ + cache = cls() if past_key_values is not None: for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx][:2] - cache.self_attention_cache.update(key_states, value_states, layer_idx) - if len(past_key_values[layer_idx]) > 2: - key_states, value_states = past_key_values[layer_idx][2:] - cache.cross_attention_cache.update(key_states, value_states, layer_idx) - cache.is_updated[layer_idx] = True + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) return cache - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` - return self.self_attention_cache.get_seq_length(layer_idx) - def reset(self): - if hasattr(self.self_attention_cache, "reset"): - self.self_attention_cache.reset() - if hasattr(self.cross_attention_cache, "reset"): - self.cross_attention_cache.reset() - elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"): - raise ValueError( - "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " - "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " - f"Got {self.self_attention_cache.__str__()} for the self attention cache and " - f"{self.cross_attention_cache.__str__()} for the cross attention cache." - ) - for layer_idx in self.is_updated: - self.is_updated[layer_idx] = False +# Utilities for `DynamicCache` <> torch.export support +def _flatten_dynamic_cache( + dynamic_cache: DynamicCache, +): + """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" + if not isinstance(dynamic_cache, DynamicCache): + raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - self.self_attention_cache.reorder_cache(beam_idx) - self.cross_attention_cache.reorder_cache(beam_idx) + if not is_torch_greater_or_equal_than_2_6: + logger.warning_once( + "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." + ) - def check_dynamic_cache(self, method: str): - if not ( - isinstance(self.self_attention_cache, DynamicCache) - and isinstance(self.cross_attention_cache, DynamicCache) - ): - raise ValueError( - f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " - f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." 
- ) + dictionary = { + "key_cache": dynamic_cache.key_cache if dynamic_cache.key_cache[0] is not None else [], + "value_cache": dynamic_cache.value_cache if dynamic_cache.value_cache[0] is not None else [], + } + return torch.utils._pytree._dict_flatten(dictionary) - # TODO(gante, sanchit-gandhi): move following functionality into `.generate` - def crop(self, maximum_length: int): - """ - Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be - negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search. - """ - self.check_dynamic_cache(self.crop.__name__) - self.self_attention_cache.crop(maximum_length) - def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": - """ - Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils` - """ - self.check_dynamic_cache(self.batch_split.__name__) - self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) - cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) +def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): + dictionary = { + "key_cache": dynamic_cache.key_cache if dynamic_cache.key_cache[0] is not None else [], + "value_cache": dynamic_cache.value_cache if dynamic_cache.value_cache[0] is not None else [], + } + return torch.utils._pytree._dict_flatten_with_keys(dictionary) - out = [] - for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): - out.append(EncoderDecoderCache(self_attn, cross_attn)) - return out - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_repeat_interleave.__name__) - self.self_attention_cache.batch_repeat_interleave(repeats) - self.cross_attention_cache.batch_repeat_interleave(repeats) +def _unflatten_dynamic_cache( + values, + context: torch.utils._pytree.Context, +): + dictionary = torch.utils._pytree._dict_unflatten(values, context) + cache = DynamicCache() + # Reconstruct layers from keys and values lists + key_list = dictionary.get("key_cache", []) + value_list = dictionary.get("value_cache", []) + for idx in range(max(len(key_list), len(value_list))): + key = key_list[idx] if idx < len(key_list) else None + value = value_list[idx] if idx < len(value_list) else None + cache.update(key, value, idx) + return cache - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_select_indices.__name__) - self.self_attention_cache.batch_select_indices(indices) - self.cross_attention_cache.batch_select_indices(indices) - def get_max_cache_shape(self) -> int: - """Returns the maximum sequence length (i.e. 
max capacity) of the cache object""" - return self.self_attention_cache.get_max_cache_shape() +def _flatten_dynamic_cache_for_fx(cache, spec): + dictionary = { + "key_cache": cache.key_cache if cache.key_cache[0] is not None else [], + "value_cache": cache.value_cache if cache.value_cache[0] is not None else [], + } + return torch.fx._pytree._dict_flatten_spec(dictionary, spec) - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) + +if is_torch_greater_or_equal("2.3"): + torch.utils._pytree.register_pytree_node( + DynamicCache, + _flatten_dynamic_cache, + _unflatten_dynamic_cache, + serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", + flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, + ) + # TODO (tmanlaibaatar) This won't be needed in torch 2.7. + torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) -class HybridCache(Cache): +class OffloadedCache(DynamicCache): """ - Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window - attention and global attention in every other layer (originally implemented for Gemma2). - Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention.For more information, see the documentation of each subcomponent cache class. + A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. + Useful for generating from models with very long context. + + In addition to the default accelerator stream, where all forward() computations happen, + this class uses another stream, the prefetch stream, which it creates itself. + Since scheduling of operations on separate streams happens independently, this class uses + the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. + The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to + ensure the eviction is scheduled after all computations on that cache are finished. + """ + + def __init__(self, config: Optional[PretrainedConfig] = None) -> None: + # Create the underlying cache with offload processor + super().__init__(cache_processor=OffloadedCacheProcessor, config=config) + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. 
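+    Each layer pre-allocates its key/value tensors up to `max_cache_len`, so tensor shapes stay constant across
+    decoding steps, which is what makes the cache compatible with CUDA graph capture and `torch.export`.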
     Example:
 
     ```python
-    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
 
-    >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
-    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
+    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 
-    >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
+    >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")
 
     >>> # Prepare a cache class and pass it to model's forward
     >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
     >>> max_generated_length = inputs.input_ids.shape[1] + 10
-    >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+    >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
     >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
     >>> outputs.past_key_values # access cache filled with key/values from generation
-    HybridCache()
+    StaticCache()
     ```
     """
 
-    def __init__(self, model_config: PretrainedConfig, *args, **kwargs):
-        # Ugly but needed for BC
-        layer_classes = [StaticLayer] if not hasattr(model_config, "layer_types") else None
-        super().__init__(model_config=model_config, layer_classes=layer_classes, *args, **kwargs)
+    def __init__(self, *args, **kwargs):
+        super().__init__(layer_classes=[StaticLayer], *args, **kwargs)
 
 
-class HybridChunkedCache(Cache):
+class SlidingWindowCache(Cache):
     """
-    Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
-    attention and global attention in every other layer, with support for chunked attention (originally implemented
-    for Llama4).
-    Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
-    for global attention. For more information, see the documentation of each subcomponent cache class.
+    Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
+    Every time we try to update the cache, we compute `indices` based on `cache_position >= self.sliding_window - 1`.
+    When this condition is true (i.e. the cache cannot hold all the old and new key/value states together because of the
+    sliding window constraint), we perform a cyclic shift based on `indices` to replace the oldest states with the new ones.
 
-    Parameters:
-        model_config (`PretrainedConfig`):
-            The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
-            smaller batch size is used.
-        max_cache_len (`int`, *optional*):
-            The maximum sequence length with which the model will be used.
-        device (`torch.device` or `str`, *optional*):
-            The device on which the cache should be initialized. If you're using more than 1 computation device, you
-            should pass the `layer_device_map` argument instead.
-        dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`):
-            The default `dtype` to use when initializing the layer. 
- layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. + The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: + + indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) Example: ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() + SlidingWindowCache() ``` """ - is_compileable = True - # Override @property since HybridChunked does not conform to layered caches yet - is_sliding = None - max_batch_size = None - max_cache_len = None - - def __init__( - self, - model_config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.bfloat16, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - if not hasattr(model_config, "sliding_window") or model_config.sliding_window is None: - self.sliding_window = getattr(model_config.get_text_config(), "attention_chunk_size", 8192) - else: - self.sliding_window = model_config.sliding_window - self.max_cache_len = max_cache_len - # Sliding layers can't be larger than the overall max cache len - self.sliding_window = min(self.sliding_window, self.max_cache_len) - self.max_batch_size = max_batch_size - self.head_dim = getattr(model_config, "head_dim", model_config.hidden_size // model_config.num_attention_heads) - self._dtype = dtype - - # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(model_config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in 
model_config.layer_types] - else: - self.is_sliding = [False] * model_config.num_hidden_layers + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - self.cumulative_length = [0 for _ in range(model_config.num_hidden_layers)] - def initialise_cache_layer(self, layer_idx, key_states): - if len(self.key_cache) > layer_idx: - return +class OffloadedCacheProcessor(CacheProcessor): + """ + A cache processor that offloads cache tensors to conserve accelerator memory. - num_key_value_heads = key_states.shape[1] - device = key_states.device - global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) + This processor manages moving cache tensors between accelerator and CPU memory, + using asynchronous prefetching to minimize performance impact. Works with both + dynamic and static layers. + """ - def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - cumulative_length = self.cumulative_length[layer_idx] - # Update it now that we saved the value above - self.cumulative_length[layer_idx] += key_states.shape[-2] - is_full = cumulative_length >= max_cache_len - if is_full: - full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) - # Fast decoding path -> here as the effective size is still sliding window, it is extremely important - # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address - # in memory (the values are the same as the full states, but not the address!!) 
- if key_states.shape[-2] == 1: - self.key_cache[layer_idx].copy_(full_key_states) - self.value_cache[layer_idx].copy_(full_value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: - # Fast prefill path, no need to cat() in this case (which creates a copy even if cating from 0 dim) - if cumulative_length == 0: - full_key_states = key_states - full_value_states = value_states - else: - full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2) - else: - try: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - except NotImplementedError: - # MPS does not support index_copy_ - self.key_cache[layer_idx][:, :, cache_position] = key_states - self.value_cache[layer_idx][:, :, cache_position] = value_states - return self.key_cache[layer_idx], self.value_cache[layer_idx] + def __init__(self, cache: "Cache", offload_device: Union[str, torch.device] = "cpu", **kwargs): + """Initialize the offload processor and check device compatibility.""" + self.offload_device = torch.device(offload_device) + self.original_device = [] + self.prefetch_stream = None + self.beam_idx = None - self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) - self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return full_key_states, full_value_states + if not ( + torch.cuda.is_available() + or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) + ): + raise RuntimeError( + "OffloadedCacheProcessor can only be used with a GPU" + + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") + ) - def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states + self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) + if self.is_static: + for i, layer in enumerate(cache.layers): + device = cache.layer_init_args["device"] if i == 0 else self.offload_device + cache.key_cache[i] = cache.key_cache[i].to(device) + cache.value_cache[i] = cache.value_cache[i].to(device) + self.original_device.append(cache.layer_init_args["device"]) + if len(cache) != cache.model_num_layers: + raise ValueError("If static layers are used, all cache layers must be initialized") - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out + self.prefetch_stream = ( + torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() + ) - def update( + def pre_update( self, + cache: "Cache", key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - self.initialise_cache_layer(layer_idx, key_states) - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) + 
"""Handle prefetching and eviction before cache update.""" + # Update the cache + if len(cache) < layer_idx: + raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") + elif len(cache) == layer_idx: + self.original_device.append(key_states.device) + self._evict_previous_layer(cache, layer_idx) + else: + # Wait for the previous layer to be evicted (on default stream) + if is_torch_greater_or_equal("2.7", accept_dev=True): + torch.accelerator.current_stream().synchronize() + else: + torch.cuda.current_stream().synchronize() + self._evict_previous_layer(cache, layer_idx) + self._ensure_layer_on_device(cache, layer_idx) - update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) + # Prefetch the next layer + self._prefetch_layer(cache, (layer_idx + 1) % len(cache)) + return key_states, value_states - def get_max_cache_shape(self) -> int: - return self.max_cache_len + def _prefetch_layer(self, cache: "Cache", layer_idx: int): + """Starts prefetching the next layer cache.""" + if layer_idx < len(cache): + with ( + self.prefetch_stream + if is_torch_greater_or_equal("2.7", accept_dev=True) + else torch.cuda.stream(self.prefetch_stream) + ): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + cache.key_cache[layer_idx] = cache.key_cache[layer_idx].to(device, non_blocking=True) + cache.value_cache[layer_idx] = cache.value_cache[layer_idx].to(device, non_blocking=True) - def get_seq_length(self, layer_idx: Optional[int] = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." 
+ def _evict_previous_layer(self, cache: "Cache", layer_idx: int): + """Moves the previous layer cache to the CPU.""" + if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + prev_layer_idx = (layer_idx - 1) % len(cache) + cache.key_cache[prev_layer_idx] = cache.key_cache[prev_layer_idx].to( + self.offload_device, non_blocking=True + ) + cache.value_cache[prev_layer_idx] = cache.value_cache[prev_layer_idx].to( + self.offload_device, non_blocking=True ) - if len(self.key_cache) == 0: - return 0 - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] + def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): + """Ensures the current layer is on the original device.""" + if layer_idx < len(cache): + # Wait for the previous prefetch to be done + self.prefetch_stream.synchronize() - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - if self.value_cache[layer_idx].numel(): - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - if self.is_sliding[layer_idx]: - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] + # Handle delayed beam search operations + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) + cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, self.beam_idx) + cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, self.beam_idx) - local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) - # This is the true general case for any Cache using local attention (sliding or chunked) - if first_cache_position >= self.sliding_window: - # Here the Cache is already full - local_mask_kv_length = self.sliding_window + query_length - 1 - elif ( - first_cache_position < self.sliding_window - and first_cache_position + query_length > self.sliding_window - ): - # Here the Cache becomes full with the new input - local_mask_kv_length = first_cache_position + query_length - else: - # Here the Cache is still smaller than the local size, but we return the local size as it's static - local_mask_kv_length = self.sliding_window - return local_mask_kv_length, local_mask_kv_offset - full_mask_kv_offset = 0 - full_mask_kv_length = self.get_max_cache_shape() - return full_mask_kv_length, full_mask_kv_offset +class QuantizedCacheProcessor(CacheProcessor): + """ + A cache processor that applies quantization to cache tensors to reduce memory usage. + This processor quantizes cache tensors after they are stored, maintaining a residual + length in original precision and quantizing older tokens. 
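+
+    Concretely, on every update the newest tokens are kept in a small full-precision residual cache; once the residual
+    reaches `residual_length` tokens, the dequantized past and the residual are concatenated, re-quantized in one go,
+    and the residual is cleared again (see `post_update`).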
+    """
 
-class OffloadedHybridCache(HybridChunkedCache):
     def __init__(
         self,
-        model_config: PretrainedConfig,
-        max_batch_size: int,
-        max_cache_len: Optional[int] = None,
-        device: Union[torch.device, str, None] = None,
-        dtype: torch.dtype = torch.bfloat16,
-        offload_device: Union[str, torch.device] = torch.device("cpu"),
-        layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
+        cache: "Cache",
+        backend: str = "quanto",
+        nbits: Optional[int] = 4,
+        axis_key: Optional[int] = 0,
+        axis_value: Optional[int] = 0,
+        q_group_size: Optional[int] = 64,
+        residual_length: Optional[int] = 128,
+        compute_dtype: Optional[torch.dtype] = torch.float16,
+        device: Optional[str] = "cpu",
     ):
-        super().__init__(model_config, max_batch_size, max_cache_len, device, dtype, layer_device_map)
+        """
+        Parameters:
+            backend (`str`, *optional*, defaults to `"quanto"`):
+                Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
+            nbits (`Optional[int]`, *optional*, defaults to 4):
+                Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 4.
+            axis_key (`int`, *optional*, defaults to 0):
+                Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
+            axis_value (`int`, *optional*, defaults to 0):
+                Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
+            q_group_size (`Optional[int]`, *optional*, defaults to 64):
+                Size of the quantization group, should be a divisor of the model's hidden dimension.
+                Defaults to 64.
+            residual_length (`Optional[int]`, *optional*, defaults to 128):
+                Length of the residual cache which will always be stored in original precision.
+                Defaults to 128.
+            compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+                The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
+            device (`str`, *optional*, defaults to `"cpu"`):
+                Device on which to perform computations, should be same as the model's device.
+        """
+        self.backend = backend
+        self.nbits = nbits
+        self.axis_key = axis_key
+        self.axis_value = axis_value
+        self.q_group_size = q_group_size
+        self.residual_length = residual_length
+        self.compute_dtype = compute_dtype
+        self.device = device
+        self._quantized_keys: list[torch.Tensor] = []
+        self._quantized_values: list[torch.Tensor] = []
 
-        # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
-        # track of the original device of each layer
-        unique_devices = set(layer_device_map.values()) if layer_device_map else set()
-        if len(unique_devices) > 1:
-            raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")
+        self.validate()
+        self.erased_length = 0
 
+        # Only compatible with DynamicCache
+        if not isinstance(cache.layers[0], DynamicLayer):
+            raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache")
 
-        self.offload_device = torch.device(offload_device)
-        # Create new CUDA stream for parallel prefetching. 
- self._prefetch_stream = torch.cuda.Stream() if torch._C._get_accelerator().type == "cuda" else None - # Those will be dynamically created as the other layers (for TP) - self.device_key_cache = None - self.device_value_cache = None - # This gives the index of which on-device full layer to use (we need 2 to avoid race conditions when prefetching) - self.active_device_layer = 0 + # Only compatible with DynamicCache + if not isinstance(cache.layers[0], DynamicLayer): + raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache") - def initialise_cache_layer(self, layer_idx, key_states): - """Overridden to use the correct device if offloaded layer (and pin memory).""" - if len(self.key_cache) > layer_idx: - return + def validate(self): + """Validates if the arguments passed are correct""" - num_key_value_heads = key_states.shape[1] - device = key_states.device if self.is_sliding[layer_idx] else self.offload_device - pin_memory = not self.is_sliding[layer_idx] - global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. 
`{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( + incorrect_arg_msg.format( + key="residual_length", + correct_value="a positive integer", + found_value=self.residual_length, + ), + ) - # Make sure to initialize the on-device layer if it does not already exist - if self.device_key_cache is None and not self.is_sliding[layer_idx]: - self.device_key_cache = [] - self.device_value_cache = [] - # We need 2 layers to avoid race conditions when prefetching the next one - for _ in range(2): - device_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device) - device_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.device_key_cache.append(device_layer_key_cache) - self.device_value_cache.append(device_layer_value_cache) + if self.axis_key not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_key", + correct_value="`1` or `0`, `-1`", + found_value=self.axis_key, + ), + ) - def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - # Wait for prefetch stream if needed - if self._prefetch_stream is not None: - torch.cuda.default_stream(key_states.device).wait_stream(self._prefetch_stream) + if self.axis_value not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_value", + correct_value="`1` or `0` or `-1`", + found_value=self.axis_value, + ), + ) - # Get correct on-device layer - k_out = self.device_key_cache[self.active_device_layer] - v_out = self.device_value_cache[self.active_device_layer] + def post_update( + self, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply quantization after cache update.""" - # Let's prefetch the next layer as soon as possible - self._prefetch_next_layer(layer_idx) + if len(cache) < layer_idx: + raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.") - # Copy to on-device layer - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states + # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer + # On the first forward pass, we quantize the whole prompt (prefill, quantize_length=0) + # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full. 
+ if self._is_quantized_length_zero(layer_idx): + self._quantized_keys.append(self._quantize(key_tensors.contiguous(), axis=self.axis_key)) + self._quantized_values.append(self._quantize(value_tensors.contiguous(), axis=self.axis_value)) - # Copy to offloaded device - self.key_cache[layer_idx][:, :, cache_position] = key_states.to(self.offload_device) - self.value_cache[layer_idx][:, :, cache_position] = value_states.to(self.offload_device) + # Clear the residual cache + self.erased_length = key_tensors.shape[-2] + cache.key_cache[layer_idx] = torch.zeros( + 0, + dtype=key_tensors.dtype, + device=key_tensors.device, + ) + cache.value_cache[layer_idx] = torch.zeros( + 0, + dtype=value_tensors.dtype, + device=value_tensors.device, + ) + # On prefill, we return the original prompt + keys_to_return, values_to_return = key_tensors, value_tensors - return k_out, v_out + else: + # Prepend the previously quantized cache + dequant_key = self._dequantize(self._quantized_keys[layer_idx]) + dequant_value = self._dequantize(self._quantized_values[layer_idx]) + keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2) + values_to_return = torch.cat([dequant_value, value_tensors], dim=-2) + if key_tensors.shape[-2] >= self.residual_length: + # Quantize and store + self._quantized_keys[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) + self._quantized_values[layer_idx] = self._quantize(values_to_return.contiguous(), axis=self.axis_value) - def _prefetch_next_layer(self, layer_idx: int) -> None: - """Based on current layer_idx, prefetch next full layer to the device.""" + # Clear the residual cache + self.erased_length += key_tensors.shape[-2] + cache.key_cache[layer_idx] = torch.zeros( + 0, + dtype=key_tensors.dtype, + device=key_tensors.device, + ) + cache.value_cache[layer_idx] = torch.zeros( + 0, + dtype=value_tensors.dtype, + device=value_tensors.device, + ) - # Switch the active layer - self.active_device_layer = 0 if self.active_device_layer == 1 else 1 + return keys_to_return, values_to_return - # Find the next non-sliding layer - try: - next_layer = layer_idx + 1 + self.is_sliding[layer_idx + 1 :].index(False) - # In this case, we are at the last layer, and we go back to prefect the first one - except ValueError: - next_layer = self.is_sliding.index(False) + def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: + """Quantize a tensor - to be implemented by specific quantization backends.""" + raise NotImplementedError("Quantization backend must implement _quantize method") + + def _dequantize(self, tensor: torch.Tensor) -> torch.Tensor: + """Dequantize a tensor - to be implemented by specific quantization backends.""" + raise NotImplementedError("Quantization backend must implement _dequantize method") + + def _is_quantized_length_zero(self, layer_idx: int) -> bool: + """Check if quantized cache is empty for layer. Note: shape[-2] is unreliable since quantized tensors are bit-packed and flattened.""" + return layer_idx >= len(self._quantized_keys) + + +class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor): + """ + Quantized cache processor that uses `quanto` as a backend to perform quantization. + Current implementation supports `int2` and `int4` dtypes only. 
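+
+    Example (a minimal sketch, assuming the processor can be attached to a `DynamicCache` through the `cache_processor`
+    argument, the same way `OffloadedCache` passes `OffloadedCacheProcessor`; requires `optimum-quanto` to be installed):
+
+    ```python
+    >>> # hypothetical wiring, shown for illustration only
+    >>> past_key_values = DynamicCache(cache_processor=QuantoQuantizedCacheProcessor)
+    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+    ```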
+ """ + + def __init__( + self, + cache: "Cache", + backend: str = "quanto", + nbits: Optional[int] = 4, + axis_key: Optional[int] = 0, + axis_value: Optional[int] = 0, + q_group_size: Optional[int] = 64, + residual_length: Optional[int] = 128, + compute_dtype: Optional[torch.dtype] = torch.float16, + device: Optional[str] = "cpu", + ) -> None: + """Initialize the quanto quantization processor.""" + super().__init__( + cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device + ) + + if backend != "quanto": + raise ValueError(f"QuantoQuantizedCacheProcessor only supports `quanto` backend, but got {backend}") + + if is_optimum_quanto_available(): + optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) + if optimum_quanto_version <= version.parse("0.2.5"): + raise ImportError( + f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCacheProcessor`. Detected version {optimum_quanto_version}." + ) + from optimum.quanto import MaxOptimizer, qint2, qint4 + + if self.nbits not in [2, 4]: + raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") + + if self.axis_key not in [0, -1]: + raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") + + if self.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" + ) + + self.qtype = qint4 if self.nbits == 4 else qint2 + self.optimizer = MaxOptimizer() + + def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: + """Quantize tensor using quanto backend.""" + if is_optimum_quanto_available(): + from optimum.quanto import quantize_weight + + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) + return qtensor + + def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: + """Dequantize tensor using quanto backend.""" + return qtensor.dequantize() + + +class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): + """ + Quantized cache processor that uses `HQQ` as a backend to perform quantization. + Current implementation supports `int2`, `int4`, `int8` dtypes. 
+ """ + + def __init__( + self, + cache: "Cache", + backend: str = "quanto", + nbits: Optional[int] = 4, + axis_key: Optional[int] = 0, + axis_value: Optional[int] = 0, + q_group_size: Optional[int] = 64, + residual_length: Optional[int] = 128, + compute_dtype: Optional[torch.dtype] = torch.float16, + device: Optional[str] = "cpu", + ) -> None: + """Initialize the HQQ quantization processor.""" + super().__init__( + cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device + ) + + if backend != "quanto": + raise ValueError(f"HQQQuantizedCacheProcessor only supports `quanto` backend, but got {backend}") + + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" + ) + + if self.axis_key not in [0, 1]: + raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") + + if self.axis_value not in [0, 1]: + raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") + + self.quantizer = HQQQuantizer + + def _quantize(self, tensor: torch.Tensor, axis: int) -> tuple[torch.Tensor, dict]: + """Quantize tensor using HQQ backend.""" + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.device, + compute_dtype=self.compute_dtype, + nbits=self.nbits, + group_size=self.q_group_size, + ) + meta["compute_dtype"] = self.compute_dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype + meta["scale"] = meta["scale"].to(qtensor.device) + meta["zero"] = meta["zero"].to(qtensor.device) + return qtensor, meta + + def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tensor: + """Dequantize tensor using HQQ backend.""" + quant_tensor, meta = qtensor_and_meta + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor + + +def parse_processor_args(processor_class: Optional[type["CacheProcessor"]], kwargs: dict) -> tuple[dict, dict]: + """ + Parse processor arguments from kwargs based on the processor class init signature. + + Args: + processor_class: The processor class to inspect, or None + kwargs: Dictionary of keyword arguments + + Returns: + tuple: (processor_kwargs, remaining_kwargs) + """ + try: + params = list(inspect.signature(processor_class.__init__).parameters)[2:] + except Exception: + return {}, kwargs + + processor_kwargs = {k: kwargs[k] for k in params if k in kwargs} + remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs} + return processor_kwargs, remaining_kwargs + + +def parse_layer_args_from_model_config( + config: Optional[PretrainedConfig], + batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: Optional[torch.dtype] = None, + layer_device_map: Optional[dict[int, torch.device]] = None, + tp_size: Optional[int] = None, + max_batch_size: Optional[int] = None, +) -> dict: + """ + Parse layer arguments from model configuration for cache initialization. + + Args: + config (`Optional[PretrainedConfig]`): Model configuration containing shape/device info. + batch_size (`Optional[int]`): Batch size for cache initialization. + max_cache_len (`Optional[int]`): Maximum sequence length for cache. + device (`Union[torch.device, str, None]`): Device for cache tensors. + dtype (`Optional[torch.dtype]`): Data type for cache tensors. + layer_device_map: Per-layer device mapping. 
+ tp_size (`Optional[int]`): Tensor parallel size to adjust number of key/value heads. + max_batch_size (`Optional[int]`): Maximum batch size for cache initialization. + + Returns: + `dict`: Dictionary containing parsed layer arguments for cache initialization. + """ + # No model config -> must be a dynamic cache, return bare dict + if config is None: + return {} + # Build the args dict for hybrid, sliding or static + else: + # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) + if ( + getattr(config, "layer_types", None) is not None + and "sliding_attention" in config.layer_types + and "full_attention" in config.layer_types + ): + if getattr(config, "sliding_window", None) is None: + raise ValueError( + "Setting up a hybrid or sliding window KVCache requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) + max_cache_len = max_cache_len or config.max_position_embeddings + if getattr(config, "sliding_window", None) is not None: + sliding_window_len = min(config.sliding_window, max_cache_len) + else: + sliding_window_len = None + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: + head_dim = ( + config.head_dim + if getattr(config, "head_dim", None) is not None + else config.hidden_size // config.num_attention_heads + ) + num_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + if tp_size is not None and tp_size > 1: + if num_heads % tp_size != 0: + raise ValueError( + f"Number of key value heads {num_heads} must be divisible by tensor parallel size {tp_size}." + ) + # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. + num_heads //= tp_size + layer_args = { + "batch_size": max_batch_size if max_batch_size is not None else batch_size, + "max_cache_len": max_cache_len, + "device": torch.device(device) if device is not None else None, + "dtype": dtype, + "layer_device_map": layer_device_map, + "head_dim": head_dim, + "num_heads": num_heads, + "sliding_window": sliding_window_len, + } + return {k: v for k, v in layer_args.items() if v is not None} - # Alternate between two on-device caches. 
- if self._prefetch_stream is not None: - with torch.cuda.stream(self._prefetch_stream): - self._prefetch_layer_in_context(next_layer) - else: - self._prefetch_layer_in_context(next_layer) - def _prefetch_layer_in_context(self, layer_idx: int) -> None: - """Performs the actual copy of the layer to device cache.""" - if len(self.key_cache) > layer_idx: - self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True) - self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True) - # The layer was not yet initialized - else: - self.device_key_cache[self.active_device_layer].fill_(0.0) - self.device_value_cache[self.active_device_layer].fill_(0.0) +LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { + "full_attention": StaticLayer, + "sliding_attention": SlidingWindowLayer, + # "chunked_attention": ChunkedLayer, +} +PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = { + "offloaded": OffloadedCacheProcessor, + "quanto_quantized": QuantizedCacheProcessor, + "hqq_quantized": HQQQuantizedCacheProcessor, +} -class OffloadedStaticCache(StaticCache): +class EncoderDecoderCache(Cache): """ - A drop-in replacement for StaticCache that conserves accelerator memory by offloading - cache tensors to CPU when not actively being used. - - This cache maintains the compilation-friendly properties of StaticCache while enabling - much longer sequences by offloading inactive layers to CPU memory. + Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and + cross-attention caches. Example: + ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache + >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") + >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") - >>> # Prepare a cache class with offloading - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = OffloadedStaticCache( - ... config=model.config, - ... max_batch_size=1, - ... max_cache_len=max_generated_length, - ... device=model.device, - ... dtype=model.dtype - ... 
) + >>> # Prepare cache classes for encoder and decoder and pass it to model's forward + >>> self_attention_cache = DynamicCache() + >>> cross_attention_cache = DynamicCache() + >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache with offloaded layers - OffloadedStaticCache() + >>> outputs.past_key_values # access cache filled with key/values from generation + EncoderDecoderCache() ``` + """ - def __init__(self, *args, offload_device: Union[str, torch.device] = "cpu", **kwargs) -> None: - super().__init__(*args, cache_processor=OffloadedCacheProcessor(offload_device), **kwargs) + # Override @property from Cache + is_compileable = None + def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): + super().__init__() + self.self_attention_cache = self_attention_cache + self.cross_attention_cache = cross_attention_cache + self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) -class OffloadedCacheProcessor(CacheProcessor): - """ - A cache processor that offloads cache tensors to conserve accelerator memory. + self.is_updated = {} + for layer_idx in range(len(cross_attention_cache)): + self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) - This processor manages moving cache tensors between accelerator and CPU memory, - using asynchronous prefetching to minimize performance impact. Works with both - dynamic and static layers. - """ + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield ( + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], + ) - def __init__(self, offload_device: Union[str, torch.device] = "cpu"): - self.offload_device = torch.device(offload_device) - self.original_device = [] - self.prefetch_stream = None - self.beam_idx = None + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return ( + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], + ) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - def init(self, cache: "Cache", **kwargs) -> None: - """Initialize the offload processor and check device compatibility.""" + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. 
+ """ + return len(self.self_attention_cache) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + legacy_cache = () + if len(self.cross_attention_cache) > 0: + for self_attn, cross_attn in zip( + self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache() + ): + legacy_cache += (self_attn + cross_attn,) + else: + legacy_cache = self.self_attention_cache.to_legacy_cache() + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None + ) -> "EncoderDecoderCache": + """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + cache = cls( + self_attention_cache=DynamicCache(), + cross_attention_cache=DynamicCache(), + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True + return cache + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` + return self.self_attention_cache.get_seq_length(layer_idx) + + def reset(self): + if hasattr(self.self_attention_cache, "reset"): + self.self_attention_cache.reset() + if hasattr(self.cross_attention_cache, "reset"): + self.cross_attention_cache.reset() + elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"): + raise ValueError( + "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " + "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " + f"Got {self.self_attention_cache.__str__()} for the self attention cache and " + f"{self.cross_attention_cache.__str__()} for the cross attention cache." + ) + for layer_idx in self.is_updated: + self.is_updated[layer_idx] = False + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + self.self_attention_cache.reorder_cache(beam_idx) + self.cross_attention_cache.reorder_cache(beam_idx) + + def check_dynamic_cache(self, method: str): if not ( - torch.cuda.is_available() - or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) + isinstance(self.self_attention_cache, DynamicCache) + and isinstance(self.cross_attention_cache, DynamicCache) ): - raise RuntimeError( - "OffloadedCacheProcessor can only be used with a GPU" - + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") + raise ValueError( + f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " + f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." 
) - self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) - if self.is_static: - for i, layer in enumerate(cache.layers): - device = cache.layer_init_args["device"] if i == 0 else self.offload_device - cache.key_cache[i] = cache.key_cache[i].to(device) - cache.value_cache[i] = cache.value_cache[i].to(device) - self.original_device.append(cache.layer_init_args["device"]) - if len(cache) != cache.model_num_layers: - raise ValueError("If static layers are used, all cache layers must be initialized") + # TODO(gante, sanchit-gandhi): move following functionality into `.generate` + def crop(self, maximum_length: int): + """ + Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search. + """ + self.check_dynamic_cache(self.crop.__name__) + self.self_attention_cache.crop(maximum_length) - self.prefetch_stream = ( - torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() - ) + def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": + """ + Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils` + """ + self.check_dynamic_cache(self.batch_split.__name__) + self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) + cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) - def pre_update( - self, - cache: "Cache", - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Handle prefetching and eviction before cache update.""" - # Update the cache - if len(cache) < layer_idx: - raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") - elif len(cache) == layer_idx: - self.original_device.append(key_states.device) - self._evict_previous_layer(cache, layer_idx) - else: - # Wait for the previous layer to be evicted (on default stream) - if is_torch_greater_or_equal("2.7", accept_dev=True): - torch.accelerator.current_stream().synchronize() - else: - torch.cuda.current_stream().synchronize() - self._evict_previous_layer(cache, layer_idx) - self._ensure_layer_on_device(cache, layer_idx) + out = [] + for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): + out.append(EncoderDecoderCache(self_attn, cross_attn)) + return out + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_repeat_interleave.__name__) + self.self_attention_cache.batch_repeat_interleave(repeats) + self.cross_attention_cache.batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_select_indices.__name__) + self.self_attention_cache.batch_select_indices(indices) + self.cross_attention_cache.batch_select_indices(indices) + + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length (i.e. 
max capacity) of the cache object""" + return self.self_attention_cache.get_max_cache_shape() - # Prefetch the next layer - self._prefetch_layer(cache, (layer_idx + 1) % len(cache)) - return key_states, value_states + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) - def _prefetch_layer(self, cache: "Cache", layer_idx: int): - """Starts prefetching the next layer cache.""" - if layer_idx < len(cache): - with ( - self.prefetch_stream - if is_torch_greater_or_equal("2.7", accept_dev=True) - else torch.cuda.stream(self.prefetch_stream) - ): - # Prefetch next layer tensors to GPU - device = self.original_device[layer_idx] - cache.key_cache[layer_idx] = cache.key_cache[layer_idx].to(device, non_blocking=True) - cache.value_cache[layer_idx] = cache.value_cache[layer_idx].to(device, non_blocking=True) - def _evict_previous_layer(self, cache: "Cache", layer_idx: int): - """Moves the previous layer cache to the CPU.""" - if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created - # We do it on the default stream so it occurs after all earlier computations on these tensors are done - prev_layer_idx = (layer_idx - 1) % len(cache) - cache.key_cache[prev_layer_idx] = cache.key_cache[prev_layer_idx].to( - self.offload_device, non_blocking=True - ) - cache.value_cache[prev_layer_idx] = cache.value_cache[prev_layer_idx].to( - self.offload_device, non_blocking=True - ) +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window + attention and global attention in every other layer (originally implemented for Gemma2). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention.For more information, see the documentation of each subcomponent cache class. 
- def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): - """Ensures the current layer is on the original device.""" - if layer_idx < len(cache): - # Wait for the previous prefetch to be done - self.prefetch_stream.synchronize() + Example: - # Handle delayed beam search operations - if self.beam_idx is not None: - self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) - cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, self.beam_idx) - cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, self.beam_idx) + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") -class QuantizedCacheProcessor(CacheProcessor): + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` """ - A cache processor that applies quantization to cache tensors to reduce memory usage. - This processor quantizes cache tensors after they are stored, maintaining a residual - length in original precision and quantizing older tokens. + def __init__(self, config: PretrainedConfig, *args, **kwargs): + # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types + layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + + +class QuantizedCache(DynamicCache): """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - def init( - self, - cache: "Cache", - backend: str = "quanto", - nbits: Optional[int] = 4, - axis_key: Optional[int] = 0, - axis_value: Optional[int] = 0, - q_group_size: Optional[int] = 64, - residual_length: Optional[int] = 128, - compute_dtype: Optional[torch.dtype] = torch.float16, - device: Optional[str] = "cpu", - ): - """ - Parameters: - backend (`str`, *optional*, defaults to `"quanto"`): - Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] - nbits (`Optional[int]`, *optional*, defaults to 4): - Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. - axis_key (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - axis_value (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. 
- q_group_size (`Optional[int]`, *optional*, defaults to 64): - Size of the quantization group, should be a divisor of the model's hidden dimension. - Defaults to 64. - residual_length (`Optional[int]`, *optional*, defaults to 128): - Length of the residual cache which will always be stored in original precision. - Defaults to 128. - compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. - device (`str`, *optional*, defaults to `"cpu"`): - Device on which to perform computations, should be same as the model's device. - """ - self.backend = backend - self.nbits = nbits - self.axis_key = axis_key - self.axis_value = axis_value - self.q_group_size = q_group_size - self.residual_length = residual_length - self.compute_dtype = compute_dtype - self.device = device - self._quantized_keys: list[torch.Tensor] = [] - self._quantized_values: list[torch.Tensor] = [] + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - self.validate() - self.erased_length = 0 + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + """ - # Only compatible with DynamicCache - if not isinstance(cache.layers[0], DynamicLayer): - raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache") + def __init__(self, backend, **kwargs) -> None: + if backend == "quanto": + processor = QuantoQuantizedCacheProcessor + elif backend == "hqq": + processor = HQQQuantizedCacheProcessor + else: + raise ValueError(f"Unknown quantization backend `{backend}`") - def validate(self): - """Validates if the arguments passed are correct""" + super().__init__(cache_processor=processor, **kwargs) - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. 
`{key}` should be {correct_value}` " - "but found {found_value}" - ) - # Check that the values are reasonable in general (nbits, axis) - # Later in QuantizedCache init we check if they are supported for that particular backend - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - incorrect_arg_msg.format( - key="nbits", - correct_value="2 or 4 or 8", - found_value=self.nbits, - ), - ) - if self.q_group_size <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="q_group_size", - correct_value="a positive integer", - found_value=self.q_group_size, - ), - ) - if self.residual_length < 0: - raise ValueError( - incorrect_arg_msg.format( - key="residual_length", - correct_value="a positive integer", - found_value=self.residual_length, - ), - ) - if self.axis_key not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_key", - correct_value="`1` or `0`, `-1`", - found_value=self.axis_key, - ), - ) +class QuantoQuantizedCache(QuantizedCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - if self.axis_value not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_value", - correct_value="`1` or `0` or `-1`", - found_value=self.axis_value, - ), - ) + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - def post_update( - self, - cache: "Cache", - key_tensors: torch.Tensor, - value_tensors: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Apply quantization after cache update.""" + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - if len(cache) < layer_idx: - raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.") + Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. - # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer - # On the first forward pass, we quantize the whole prompt (prefill, quantize_length=0) - # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full. 
- if self._is_quantized_length_zero(layer_idx): - self._quantized_keys.append(self._quantize(key_tensors.contiguous(), axis=self.axis_key)) - self._quantized_values.append(self._quantize(value_tensors.contiguous(), axis=self.axis_value)) + Example: - # Clear the residual cache - self.erased_length = key_tensors.shape[-2] - cache.key_cache[layer_idx] = torch.zeros( - 0, - dtype=key_tensors.dtype, - device=key_tensors.device, - ) - cache.value_cache[layer_idx] = torch.zeros( - 0, - dtype=value_tensors.dtype, - device=value_tensors.device, - ) - # On prefill, we return the original prompt - keys_to_return, values_to_return = key_tensors, value_tensors + ```python + >>> # Run pip install quanto first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig - else: - # Prepend the previously quantized cache - dequant_key = self._dequantize(self._quantized_keys[layer_idx]) - dequant_value = self._dequantize(self._quantized_values[layer_idx]) - keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2) - values_to_return = torch.cat([dequant_value, value_tensors], dim=-2) - if key_tensors.shape[-2] >= self.residual_length: - # Quantize and store - self._quantized_keys[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) - self._quantized_values[layer_idx] = self._quantize(values_to_return.contiguous(), axis=self.axis_value) + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4) + >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + QuantoQuantizedCache() + ``` + """ + + def __init__(self, **kwargs) -> None: + Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs) + + +class HQQQuantizedCache(QuantizedCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - # Clear the residual cache - self.erased_length += key_tensors.shape[-2] - cache.key_cache[layer_idx] = torch.zeros( - 0, - dtype=key_tensors.dtype, - device=key_tensors.device, - ) - cache.value_cache[layer_idx] = torch.zeros( - 0, - dtype=value_tensors.dtype, - device=value_tensors.device, - ) + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. 
The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - return keys_to_return, values_to_return + Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. - def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: - """Quantize a tensor - to be implemented by specific quantization backends.""" - raise NotImplementedError("Quantization backend must implement _quantize method") + Example: - def _dequantize(self, tensor: torch.Tensor) -> torch.Tensor: - """Dequantize a tensor - to be implemented by specific quantization backends.""" - raise NotImplementedError("Quantization backend must implement _dequantize method") + ```python + >>> # Run pip install hqq first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig - def _is_quantized_length_zero(self, layer_idx: int) -> bool: - """Check if quantized cache is empty for layer. Note: shape[-2] is unreliable since quantized tensors are bit-packed and flattened.""" - return layer_idx >= len(self._quantized_keys) + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") -class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor): - """ - Quantized cache processor that uses `quanto` as a backend to perform quantization. - Current implementation supports `int2` and `int4` dtypes only. + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) + >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HQQQuantizedCache() + ``` """ - def init(self, cache: "Cache", **kwargs) -> None: - """Initialize the quanto quantization processor.""" - super().init(cache, **kwargs) + def __init__(self, backend="HQQ", **kwargs) -> None: + assert backend == "HQQ" + Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) - if is_optimum_quanto_available(): - optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) - if optimum_quanto_version <= version.parse("0.2.5"): - raise ImportError( - f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCacheProcessor`. Detected version {optimum_quanto_version}." - ) - from optimum.quanto import MaxOptimizer, qint2, qint4 - if self.nbits not in [2, 4]: - raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") +class OffloadedStaticCache(StaticCache): + """ + A drop-in replacement for StaticCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. - if self.axis_key not in [0, -1]: - raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") + This cache maintains the compilation-friendly properties of StaticCache while enabling + much longer sequences by offloading inactive layers to CPU memory. 
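+    Layers are moved back to the accelerator on demand, with the next layer prefetched asynchronously to keep the impact on generation speed small.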
- if self.axis_value not in [0, -1]: - raise ValueError( - f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" - ) + Example: + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache - self.qtype = qint4 if self.nbits == 4 else qint2 - self.optimizer = MaxOptimizer() + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: - """Quantize tensor using quanto backend.""" - if is_optimum_quanto_available(): - from optimum.quanto import quantize_weight + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) - qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) - return qtensor + >>> # Prepare a cache class with offloading + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = OffloadedStaticCache( + ... config=model.config, + ... max_batch_size=1, + ... max_cache_len=max_generated_length, + ... device=model.device, + ... dtype=model.dtype + ... ) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache with offloaded layers + OffloadedStaticCache() + ``` + """ - def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: - """Dequantize tensor using quanto backend.""" - return qtensor.dequantize() + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) -class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): +class HybridChunkedCache(Cache): """ - Quantized cache processor that uses `HQQ` as a backend to perform quantization. - Current implementation supports `int2`, `int4`, `int8` dtypes. + Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window + attention and global attention in every other layer, with support for chunked attention (originally implemented + for Llama4). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention. For more information, see the documentation of each subcomponent cache class. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` """ - def init(self, cache: "Cache", **kwargs) -> None: - """Initialize the HQQ quantization processor.""" - super().init(cache, **kwargs) + is_compileable = True + # Override Cache's @property methods since we will define them in the init + is_sliding = None + max_batch_size = None + max_cache_len = None - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" - ) + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: torch.dtype = torch.bfloat16, + layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, + ) -> None: + if not hasattr(config, "sliding_window") or config.sliding_window is None: + self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) + else: + self.sliding_window = config.sliding_window + self.max_cache_len = max_cache_len + # Sliding layers can't be larger than the overall max cache len + self.sliding_window = min(self.sliding_window, self.max_cache_len) + self.max_batch_size = max_batch_size + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self._dtype = dtype - if self.axis_key not in [0, 1]: - raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") + # If the attribute does not exist in the config, fallback to a simple StaticCache + if hasattr(config, "layer_types"): + self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] + else: + self.is_sliding = [False] * config.num_hidden_layers - if self.axis_value not in [0, 1]: - raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] - self.quantizer = HQQQuantizer + def initialise_cache_layer(self, layer_idx, key_states): + if len(self.key_cache) > layer_idx: + return - def _quantize(self, tensor: torch.Tensor, axis: int) -> tuple[torch.Tensor, dict]: - """Quantize tensor using HQQ backend.""" - qtensor, meta = self.quantizer.quantize( - tensor, - axis=axis, - device=self.device, - compute_dtype=self.compute_dtype, - nbits=self.nbits, - group_size=self.q_group_size, - ) - meta["compute_dtype"] = self.compute_dtype - self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype - meta["scale"] = 
meta["scale"].to(qtensor.device) - meta["zero"] = meta["zero"].to(qtensor.device) - return qtensor, meta + num_key_value_heads = key_states.shape[1] + device = key_states.device + global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) + sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) - def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tensor: - """Dequantize tensor using HQQ backend.""" - quant_tensor, meta = qtensor_and_meta - tensor = self.quantizer.dequantize(quant_tensor, meta) - return tensor + def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + cumulative_length = self.cumulative_length[layer_idx] + # Update it now that we saved the value above + self.cumulative_length[layer_idx] += key_states.shape[-2] + is_full = cumulative_length >= max_cache_len + if is_full: + full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) + full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) + # Fast decoding path -> here as the effective size is still sliding window, it is extremely important + # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address + # in memory (the values are the same as the full states, but not the address!!) 
+ if key_states.shape[-2] == 1: + self.key_cache[layer_idx].copy_(full_key_states) + self.value_cache[layer_idx].copy_(full_value_states) + return self.key_cache[layer_idx], self.value_cache[layer_idx] + elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: + # Fast prefill path, no need to cat() in this case (which creates a copy even if cating from 0 dim) + if cumulative_length == 0: + full_key_states = key_states + full_value_states = value_states + else: + full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) + full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2) + else: + try: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + except NotImplementedError: + # MPS does not support index_copy_ + self.key_cache[layer_idx][:, :, cache_position] = key_states + self.value_cache[layer_idx][:, :, cache_position] = value_states + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) + self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return full_key_states, full_value_states + + def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + return k_out, v_out -class QuantizedCache(DynamicCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + self.initialise_cache_layer(layer_idx, key_states) - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. 
The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - """ + update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + k_out, + v_out, + k_out.shape[2], + ) - def __init__(self, backend, *args, **kwargs) -> None: - if backend == "quanto": - processor = QuantoQuantizedCacheProcessor() - elif backend == "hqq": - processor = HQQQuantizedCacheProcessor() - else: - raise ValueError(f"Unknown quantization backend `{backend}`") + def get_max_cache_shape(self) -> int: + return self.max_cache_len - super().__init__(cache_processor=processor) + def get_seq_length(self, layer_idx: Optional[int] = 0): + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx != 0: + raise ValueError( + "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " + "Using the `layer_idx` argument is not supported." + ) + if len(self.key_cache) == 0: + return 0 + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] -class QuantoQuantizedCache(QuantizedCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx].numel(): + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + if self.is_sliding[layer_idx]: + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. 
The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) + # This is the true general case for any Cache using local attention (sliding or chunked) + if first_cache_position >= self.sliding_window: + # Here the Cache is already full + local_mask_kv_length = self.sliding_window + query_length - 1 + elif ( + first_cache_position < self.sliding_window + and first_cache_position + query_length > self.sliding_window + ): + # Here the Cache becomes full with the new input + local_mask_kv_length = first_cache_position + query_length + else: + # Here the Cache is still smaller than the local size, but we return the local size as it's static + local_mask_kv_length = self.sliding_window + return local_mask_kv_length, local_mask_kv_offset - Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + full_mask_kv_offset = 0 + full_mask_kv_length = self.get_max_cache_shape() + return full_mask_kv_length, full_mask_kv_offset - Example: - ```python - >>> # Run pip install quanto first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig +class OffloadedHybridCache(HybridChunkedCache): + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: torch.dtype = torch.bfloat16, + offload_device: Union[str, torch.device] = torch.device("cpu"), + layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, + ): + super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map) - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps + # track of the original device of each layer + unique_devices = set(layer_device_map.values()) if layer_device_map else set() + if len(unique_devices) > 1: + raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}") - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + self.offload_device = torch.device(offload_device) + # Create new CUDA stream for parallel prefetching. 
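+        # (prefetching only uses a dedicated stream on CUDA; on other accelerators the copy runs on the default stream)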
+ self._prefetch_stream = torch.cuda.Stream() if torch._C._get_accelerator().type == "cuda" else None + # Those will be dynamically created as the other layers (for TP) + self.device_key_cache = None + self.device_value_cache = None + # This gives the index of which on-device full layer to use (we need 2 to avoid race conditions when prefetching) + self.active_device_layer = 0 - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4) - >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - QuantoQuantizedCache() - ``` - """ + def initialise_cache_layer(self, layer_idx, key_states): + """Overridden to use the correct device if offloaded layer (and pin memory).""" + if len(self.key_cache) > layer_idx: + return - def __init__(self, *args, **kwargs) -> None: - Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor()) + num_key_value_heads = key_states.shape[1] + device = key_states.device if self.is_sliding[layer_idx] else self.offload_device + pin_memory = not self.is_sliding[layer_idx] + global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) + sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + # Make sure to initialize the on-device layer if it does not already exist + if self.device_key_cache is None and not self.is_sliding[layer_idx]: + self.device_key_cache = [] + self.device_value_cache = [] + # We need 2 layers to avoid race conditions when prefetching the next one + for _ in range(2): + device_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device) + device_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.device_key_cache.append(device_layer_key_cache) + self.device_value_cache.append(device_layer_value_cache) -class HQQQuantizedCache(QuantizedCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. 
+ def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + # Wait for prefetch stream if needed + if self._prefetch_stream is not None: + torch.cuda.default_stream(key_states.device).wait_stream(self._prefetch_stream) - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + # Get correct on-device layer + k_out = self.device_key_cache[self.active_device_layer] + v_out = self.device_value_cache[self.active_device_layer] - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + # Let's prefetch the next layer as soon as possible + self._prefetch_next_layer(layer_idx) - Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + # Copy to on-device layer + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states - Example: + # Copy to offloaded device + self.key_cache[layer_idx][:, :, cache_position] = key_states.to(self.offload_device) + self.value_cache[layer_idx][:, :, cache_position] = value_states.to(self.offload_device) - ```python - >>> # Run pip install hqq first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + return k_out, v_out - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + def _prefetch_next_layer(self, layer_idx: int) -> None: + """Based on current layer_idx, prefetch next full layer to the device.""" - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + # Switch the active layer + self.active_device_layer = 0 if self.active_device_layer == 1 else 1 - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) - >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HQQQuantizedCache() - ``` - """ + # Find the next non-sliding layer + try: + next_layer = layer_idx + 1 + self.is_sliding[layer_idx + 1 :].index(False) + # In this case, we are at the last layer, and we go back to prefect the first one + except ValueError: + next_layer = self.is_sliding.index(False) - def __init__(self, backend="HQQ", *args, **kwargs) -> None: - assert backend == "HQQ" - Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor()) + # Alternate between two on-device caches. 
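+        # Launch the copy on the prefetch stream when available so it overlaps with the current layer's compute.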
+ if self._prefetch_stream is not None: + with torch.cuda.stream(self._prefetch_stream): + self._prefetch_layer_in_context(next_layer) + else: + self._prefetch_layer_in_context(next_layer) + + def _prefetch_layer_in_context(self, layer_idx: int) -> None: + """Performs the actual copy of the layer to device cache.""" + if len(self.key_cache) > layer_idx: + self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True) + self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True) + # The layer was not yet initialized + else: + self.device_key_cache[self.active_device_layer].fill_(0.0) + self.device_value_cache[self.active_device_layer].fill_(0.0) class SinkCache(Cache): @@ -2123,6 +2230,13 @@ class CacheConfig: cache_implementation: None + def __post_init__(self): + warnings.warn( + ("CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."), + FutureWarning, + stacklevel=2, + ) + @classmethod def from_dict(cls, config_dict, **kwargs): """ @@ -2483,10 +2597,3 @@ def reset(self): return MambaCache raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -LAYER_CLASS_MAP = { - "full_attention": StaticLayer, - "sliding_attention": SlidingWindowLayer, - # "chunked_attention": ChunkedLayer, -} diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 6453d7c6d287..7d2cd21effb2 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -397,6 +397,16 @@ def __init__(self, **kwargs): self.use_cache = kwargs.pop("use_cache", True) self.cache_implementation = kwargs.pop("cache_implementation", None) self.cache_config = kwargs.pop("cache_config", None) + if self.cache_config is not None and not isinstance(self.cache_config, dict): + warnings.warn( + ( + "Passing a CacheConfig object is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." 
+ ), + FutureWarning, + stacklevel=2, + ) + self.cache_config = self.cache_config.to_dict() + self.return_legacy_cache = kwargs.pop("return_legacy_cache", None) self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3f2cea382bdb..7c749b6768c7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1959,7 +1959,7 @@ def _get_cache( layer_device_map = self._get_layer_device_map_for_cache_init() cache_kwargs = { - "model_config": self.config.get_text_config(), + "config": self.config.get_text_config(), "max_batch_size": batch_size, "max_cache_len": max_cache_len, "dtype": cache_dtype, diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 0acdcdc95efe..18b7b3a92980 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -275,7 +275,7 @@ def __init__(self, model: PreTrainedModel): self.model = model self.static_cache = StaticCache( - model_config=self.model.config, + config=self.model.config, max_batch_size=self.model.generation_config.cache_config.get("batch_size"), max_cache_len=self.model.generation_config.cache_config.get("max_cache_len"), device=self.model.generation_config.cache_config.get("device"), @@ -404,7 +404,7 @@ def __init__( # Initialize the HybridCache self.cache = HybridCache( - model_config=self.model.config, + config=self.model.config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=self.model.device, @@ -550,7 +550,7 @@ def __init__(self, model, max_static_cache_length, batch_size): # Initialize static cache for decoder and DynamicCache for encoder self.static_cache = StaticCache( - model_config=self.config, + config=self.config, max_batch_size=batch_size, max_cache_len=max_static_cache_length, device="cpu", diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index f729fb2b5413..8d5aab9f1342 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -753,10 +753,10 @@ def create_causal_mask( useful to easily overlay another mask on top of the causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the full layers - is_sliding = [] - if past_key_values is not None and getattr(past_key_values, "layers", None) is not None: - is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] - layer_idx = is_sliding.index(True) if True in is_sliding else 0 + if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding: + layer_idx = past_key_values.is_sliding.index(False) + else: + layer_idx = 0 early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments( config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx @@ -843,10 +843,10 @@ def create_sliding_window_causal_mask( useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. 
""" # If we have an HybridCache structure, here we want to create the mask for the sliding layers - is_sliding = [] - if past_key_values is not None and getattr(past_key_values, "layers", None) is not None: - is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] - layer_idx = is_sliding.index(True) if True in is_sliding else 0 + if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding: + layer_idx = past_key_values.is_sliding.index(True) + else: + layer_idx = 0 early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments( config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx @@ -938,10 +938,10 @@ def create_chunked_causal_mask( useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the sliding layers - is_sliding = [] - if past_key_values is not None: - is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] - layer_idx = is_sliding.index(True) if True in is_sliding else 0 + if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding: + layer_idx = past_key_values.is_sliding.index(True) + else: + layer_idx = 0 early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments( config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 8cc24857ce81..db00dc7dbfe1 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -171,9 +171,7 @@ def _random_kvs(config): return random_keys, random_values mha_config = LlamaConfig(num_attention_heads=32) - mha_static_cache = StaticCache( - model_config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device - ) + mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mha_static_cache.update( *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -181,9 +179,7 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 32, 10, 128)) gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) - gqa_static_cache = StaticCache( - model_config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device - ) + gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = gqa_static_cache.update( *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -191,9 +187,7 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 4, 10, 128)) mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) - mqa_static_cache = StaticCache( - model_config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device - ) + mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mqa_static_cache.update( *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -927,7 +921,7 @@ def setUp(self): def test_static_cache_out_of_bounds(self): """Test StaticCache raises IndexError for out-of-bounds positions.""" - static_cache = 
StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len with self.assertRaises(IndexError): @@ -949,7 +943,7 @@ def test_static_cache(self): update pos 3: [1.0, 2.0, 3.0, 4.0] """ # Scenario 1: Fill up to near capacity - static_cache = StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None) static_cache.update( @@ -989,9 +983,7 @@ def test_sliding_window_cache(self): result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ # Scenario 1: Update within window, no slide yet - sliding_cache = SlidingWindowCache( - model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len - ) + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -1012,9 +1004,7 @@ def test_sliding_window_cache(self): ) # Scenario 2: Update causing slide - sliding_cache = SlidingWindowCache( - model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len - ) + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -1035,9 +1025,7 @@ def test_sliding_window_cache(self): ) # Scenario 3: Long prompt handling - sliding_cache = SlidingWindowCache( - model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len - ) + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] sliding_cache.update( key_states=long_prefill, @@ -1065,7 +1053,7 @@ def test_hybrid_cache_static_mode(self): config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) # Scenario 1 - hybrid_cache_static_mode = HybridCache(model_config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] hybrid_cache_static_mode.update( key_states=prefill, @@ -1117,7 +1105,7 @@ def test_hybrid_cache_sliding_mode(self): result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ # Scenario 1: Update within window, no slide yet - hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1138,7 +1126,7 @@ def test_hybrid_cache_sliding_mode(self): ) # Scenario 2: Update causing first slide - hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 
4.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1172,7 +1160,7 @@ def test_hybrid_cache_sliding_mode(self): ) # Scenario 4: Long prompt handling - hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] hybrid_cache.update( key_states=long_prefill, From 0c6d2ff6c6c370df4a757acfa7e3abcea8baa401 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 19 Jun 2025 17:40:24 +0200 Subject: [PATCH 12/35] fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant` More verbose exceptions in `fix_docstring` on docstring formatting issues. --- utils/check_docstrings.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 35fe662beaaa..5d637e0f2eaf 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -681,8 +681,8 @@ def eval_math_expression(expression: str) -> Optional[Union[float, int]]: def eval_node(node): - if isinstance(node, ast.Num): # - return node.n + if isinstance(node, ast.Constant) and type(node.value) in (int, float, complex): + return node.value elif isinstance(node, ast.BinOp): # return MATH_OPERATORS[type(node.op)](eval_node(node.left), eval_node(node.right)) elif isinstance(node, ast.UnaryOp): # e.g., -1 @@ -942,7 +942,12 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str): if idx == len(source): # Args are not defined in the docstring of this object - return + obj_file = find_source_file(obj) + raise ValueError( + f"Cannot fix docstring of {obj.__name__} in {obj_file} because no argument section was found in the docstring. " + f"The docstring should contain a section starting with 'Args:', 'Arguments:', 'Parameters:', or similar. " + f"Current docstring:\n{obj.__doc__[:200]}{'...' if len(obj.__doc__) > 200 else ''}" + ) # Get to the line where we stop documenting arguments indent = find_indent(source[idx]) @@ -958,7 +963,17 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str): if "".join(source[start_idx:idx])[:-1] != old_doc_args: # Args are not fully defined in the docstring of this object - return + obj_file = find_source_file(obj) + actual_args_section = "".join(source[start_idx:idx])[:-1] + raise ValueError( + f"Cannot fix docstring of {obj.__name__} in {obj_file} because the argument section in the source code " + f"does not match the expected format. This usually happens when:\n" + f"1. The argument section is not properly indented\n" + f"2. The argument section contains unexpected formatting\n" + f"3. The docstring parsing failed to correctly identify the argument boundaries\n\n" + f"Expected argument section:\n{repr(old_doc_args)}\n\n" + f"Actual argument section found:\n{repr(actual_args_section)}\n\n" + ) obj_file = find_source_file(obj) with open(obj_file, "r", encoding="utf-8") as f: From 6a77408a80b4e3e98d64081015925ce71f630b63 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 14 Jul 2025 15:03:46 +0200 Subject: [PATCH 13/35] Revert "back to storage inside Cache()" This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913. 
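Before the revert's file list below, the `check_docstrings.py` change above is worth a concrete illustration. The following is a minimal, self-contained sketch of the same `ast.Constant`-based pattern for safely evaluating arithmetic expressions; the `MATH_OPERATORS` table shown here is an assumed subset for illustration only, not the utility's actual definition.

```py
import ast
import operator
from typing import Optional, Union

# Assumed operator table for this sketch; the real utility defines its own MATH_OPERATORS.
MATH_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.USub: operator.neg,
}


def eval_node(node: ast.AST) -> Union[int, float, complex]:
    # `ast.Num` is deprecated (and removed in recent Python versions), so numeric
    # literals are matched as `ast.Constant` nodes carrying a numeric `value`.
    if isinstance(node, ast.Constant) and type(node.value) in (int, float, complex):
        return node.value
    elif isinstance(node, ast.BinOp):  # e.g., 2 * 3
        return MATH_OPERATORS[type(node.op)](eval_node(node.left), eval_node(node.right))
    elif isinstance(node, ast.UnaryOp):  # e.g., -1
        return MATH_OPERATORS[type(node.op)](eval_node(node.operand))
    raise TypeError(f"Unsupported node: {ast.dump(node)}")


def eval_math_expression(expression: str) -> Optional[Union[int, float, complex]]:
    try:
        return eval_node(ast.parse(expression, mode="eval").body)
    except (TypeError, KeyError, SyntaxError):
        return None


print(eval_math_expression("2 * 3 + 1"))  # 7
print(eval_math_expression("-(4 ** 2)"))  # -16
```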
---
 docs/source/en/cache_explanation.md           |   4 +-
 src/transformers/cache_utils.py               | 405 +++++++-----------
 src/transformers/integrations/executorch.py   |  12 +-
 src/transformers/models/bart/modeling_bart.py |   4 +-
 .../modeling_bigbird_pegasus.py               |   4 +-
 .../models/biogpt/modeling_biogpt.py          |   4 +-
 .../models/blenderbot/modeling_blenderbot.py  |   4 +-
 .../modeling_blenderbot_small.py              |   4 +-
 src/transformers/models/dia/modeling_dia.py   |   4 +-
 src/transformers/models/dia/modular_dia.py    |   4 +-
 .../models/gemma3n/modeling_gemma3n.py        |   8 +-
 .../models/gemma3n/modular_gemma3n.py         |   8 +-
 src/transformers/models/gptj/modeling_gptj.py |   4 +-
 .../models/informer/modeling_informer.py      |   8 +-
 .../models/informer/modular_informer.py       |   4 +-
 .../models/longt5/modeling_longt5.py          |   4 +-
 .../models/m2m_100/modeling_m2m_100.py        |   4 +-
 .../models/marian/modeling_marian.py          |   4 +-
 .../models/mbart/modeling_mbart.py            |   4 +-
 .../models/minimax/modeling_minimax.py        |   6 +-
 .../models/minimax/modular_minimax.py         |   6 +-
 .../models/mllama/modeling_mllama.py          |   4 +-
 .../models/moonshine/modeling_moonshine.py    |   4 +-
 .../models/moonshine/modular_moonshine.py     |   4 +-
 src/transformers/models/mt5/modeling_mt5.py   |   4 +-
 .../models/pegasus/modeling_pegasus.py        |   4 +-
 .../models/pegasus_x/modeling_pegasus_x.py    |   4 +-
 .../models/pix2struct/modeling_pix2struct.py  |   4 +-
 .../models/plbart/modeling_plbart.py          |   4 +-
 .../models/pop2piano/modeling_pop2piano.py    |   4 +-
 .../modeling_switch_transformers.py           |   4 +-
 src/transformers/models/t5/modeling_t5.py     |   4 +-
 .../models/t5gemma/modeling_t5gemma.py        |   4 +-
 .../models/t5gemma/modular_t5gemma.py         |   4 +-
 .../modeling_time_series_transformer.py       |   4 +-
 src/transformers/models/udop/modeling_udop.py |   4 +-
 src/transformers/models/umt5/modeling_umt5.py |   4 +-
 .../models/whisper/generation_whisper.py      |   4 +-
 .../models/whisper/modeling_whisper.py        |   4 +-
 tests/generation/test_utils.py                |  80 ++--
 .../deepseek_v2/test_modeling_deepseek_v2.py  |  11 +-
 .../deepseek_v3/test_modeling_deepseek_v3.py  |   8 +-
 tests/models/dia/test_modeling_dia.py         |   8 +-
 .../models/gpt_neox/test_modeling_gpt_neox.py |   4 +-
 tests/models/t5gemma/test_modeling_t5gemma.py |  36 +-
 tests/utils/test_cache_utils.py               |  50 +--
 46 files changed, 341 insertions(+), 437 deletions(-)

diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md
index 410304d7356a..538e46ba846a 100644
--- a/docs/source/en/cache_explanation.md
+++ b/docs/source/en/cache_explanation.md
@@ -89,8 +89,8 @@ Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWi
 The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:
 
 ```py
-cache.key_cache[idx] = torch.cat([cache.key_cache[idx], key_states], dim=-2)
-cache.value_cache[idx] = torch.cat([cache.value_cache[idx], value_states], dim=-2)
+cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
+cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
 ```
 
 Other layer types like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
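To make the sliding behaviour described above concrete, here is a small, self-contained sketch (assuming nothing beyond plain PyTorch) of a fixed-length key/value buffer that drops the oldest token once the window is full. The class and parameter names (`SlidingBuffer`, `window_size`) are illustrative only; this is not the `SlidingWindowLayer` implementation itself.

```py
import torch


class SlidingBuffer:
    """Toy fixed-length KV buffer: once full, the oldest position is shifted out."""

    def __init__(self, window_size: int, batch: int = 1, heads: int = 1, head_dim: int = 1):
        self.window_size = window_size
        self.keys = torch.zeros(batch, heads, window_size, head_dim)
        self.values = torch.zeros(batch, heads, window_size, head_dim)
        self.seen = 0  # number of tokens processed so far

    def update(self, k: torch.Tensor, v: torch.Tensor):
        # k, v: [batch, heads, 1, head_dim] -- a single new token
        if self.seen < self.window_size:
            # Still filling the window: write into the next free slot.
            self.keys[:, :, self.seen] = k[:, :, 0]
            self.values[:, :, self.seen] = v[:, :, 0]
        else:
            # Window is full: shift everything left by one and append at the end.
            self.keys = torch.roll(self.keys, shifts=-1, dims=2)
            self.values = torch.roll(self.values, shifts=-1, dims=2)
            self.keys[:, :, -1] = k[:, :, 0]
            self.values[:, :, -1] = v[:, :, 0]
        self.seen += 1
        return self.keys, self.values


buf = SlidingBuffer(window_size=4)
for t in range(1, 7):  # feed tokens 1..6
    token = torch.full((1, 1, 1, 1), float(t))
    keys, _ = buf.update(token, token)
print(keys.flatten().tolist())  # [3.0, 4.0, 5.0, 6.0] -- only the last window_size tokens remain
```

This matches the expectation exercised by the sliding-window tests earlier in the series: a window of 4 fed tokens 1 through 6 ends up holding `[3.0, 4.0, 5.0, 6.0]`.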
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index a3d24bebac75..a7967b37bd61 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -63,7 +63,7 @@ class Cache: This class handles propagation of operations across layers. Parameters: - config (`PretrainedConfig`): + config (`PretrainedConfig`, *optional*): Model configuration for shape/device info. cache_processor (`CacheProcessor` or `str`, *optional*): Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") or @@ -94,8 +94,6 @@ def __init__( **kwargs, ): self.layers: list["CacheLayerMixin"] = [] - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor if ( @@ -119,9 +117,11 @@ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: sequence length. """ if layer_idx < len(self.layers): - return self.key_cache[layer_idx], self.value_cache[layer_idx] + return self.layers[layer_idx].keys, self.layers[layer_idx].values else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + raise KeyError( + f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" + ) def __iter__(self): """ @@ -129,21 +129,26 @@ def __iter__(self): keys and values """ for layer_idx in range(len(self)): - yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) def __len__(self): """ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds to the number of layers in the model. """ + # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__ + if getattr(self, "layers", None) is None: + if getattr(self, "key_cache", None) is not None: + return len(self.key_cache) + return 0 # Empty dynamic caches initialize an empty layer to be ready for first update dynamic_empty = ( getattr(self, "layers", None) is not None and len(self.layers) == 1 and isinstance(self.layers[0], DynamicLayer) - and self.key_cache[0] is None + and self.layers[0].keys is None ) - return len(self.key_cache) if not dynamic_empty else 0 + return len(self.layers) if not dynamic_empty else 0 def __repr__(self): return f"{self.__class__.__name__}(layers={self.layers})" @@ -163,9 +168,6 @@ def append_new_layers(self, layer_idx: int) -> None: args["device"] = args.pop("layer_device_map")[layer_idx] new_layer = self.layer_classes[layer_idx % len(self.layer_classes)](**args) self.layers.append(new_layer) - new_key, new_value = new_layer.new_tensors() - self.key_cache.append(new_key) - self.value_cache.append(new_value) @apply_processors def update( @@ -193,27 +195,16 @@ def update( A tuple containing the updated key and value states. 
""" self.append_new_layers(layer_idx) - self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].update( - key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], cache_kwargs - ) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def to(self, device: torch.device) -> "Cache": - """Moves the cache to the given device.""" - for idx in range(len(self.key_cache)): - self.key_cache[idx] = self.key_cache[idx].to(device) - self.value_cache[idx] = self.value_cache[idx].to(device) - return self + return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) def get_seq_length(self, layer_idx: int = 0) -> int: """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" if layer_idx >= len(self.layers): return 0 - seq_length = self.layers[layer_idx].get_seq_length(self.key_cache[layer_idx], self.value_cache[layer_idx]) # Hack since QuantizedCache messes with keys shape as it becomes the residual cache if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): - return self.cache_processor.erased_length + seq_length - return seq_length + return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length() + return self.layers[layer_idx].get_seq_length() def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ @@ -222,51 +213,39 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), for each layer. """ - kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes( - cache_position, self.key_cache[layer_idx], self.value_cache[layer_idx] - ) + kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes(cache_position) return kv_length, kv_offset ### Wrappers for layer operations and properties ### def get_max_cache_shape(self, layer_idx: int = 0) -> int: """Returns the maximum sequence length of the cache object. 
DynamicLayer does not have a maximum length.""" - return self.layers[layer_idx].get_max_cache_shape(self.key_cache[layer_idx], self.value_cache[layer_idx]) + return self.layers[layer_idx].get_max_cache_shape() def reset(self): """Recursively reset all layers tensors""" for layer_idx in range(len(self.layers)): - self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].reset( - self.key_cache[layer_idx], self.value_cache[layer_idx] - ) + self.layers[layer_idx].reset() def reorder_cache(self, beam_idx: torch.LongTensor): """Reorder the cache for beam search""" for layer_idx in range(len(self.layers)): - self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].reorder_cache( - beam_idx, self.key_cache[layer_idx], self.value_cache[layer_idx] - ) + self.layers[layer_idx].reorder_cache(beam_idx) def crop(self, max_length: int): """Crop the cache to the given length""" for layer_idx in range(len(self.layers)): - self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].crop( - max_length, self.key_cache[layer_idx], self.value_cache[layer_idx] - ) + self.layers[layer_idx].crop(max_length) def batch_repeat_interleave(self, repeats: int): """Repeat and interleave the cache""" for layer_idx in range(len(self.layers)): - self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].batch_repeat_interleave( - repeats, self.key_cache[layer_idx], self.value_cache[layer_idx] - ) + self.layers[layer_idx].batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): """Select indices from the cache""" for layer_idx in range(len(self.layers)): - self.key_cache[layer_idx], self.value_cache[layer_idx] = self.layers[layer_idx].batch_select_indices( - indices, self.key_cache[layer_idx], self.value_cache[layer_idx] - ) + self.layers[layer_idx].batch_select_indices(indices) @property def max_batch_size(self) -> int: @@ -304,16 +283,12 @@ def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Updates KV cache, returns updated keys/values for this layer.""" raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") - def get_seq_length( - self, key_cache: Optional[torch.Tensor] = None, value_cache: Optional[torch.Tensor] = None - ) -> int: + def get_seq_length(self) -> int: """Returns the sequence length of this layer's cache.""" raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") @@ -321,37 +296,22 @@ def get_max_cache_shape(self) -> int: """Returns the maximum sequence length (i.e. 
max capacity) of this layer's cache.""" raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") - def reset(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def reset(self) -> tuple[torch.Tensor, torch.Tensor]: """Resets this layer's cache.""" raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.") - def get_mask_sizes( - self, - cache_position: torch.Tensor, - key_cache: Optional[torch.Tensor] = None, - value_cache: Optional[torch.Tensor] = None, - ) -> tuple[int, int]: + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: """Returns mask sizes for this layer's cache.""" raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") - def new_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: - """Returns a new key and value tensor for this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `new_tensors` in {self.__class__.__name__}.") - - def reorder_cache( - self, - beam_idx: torch.LongTensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]: """Reorders this layer's cache for beam search.""" - if key_cache.numel(): - device = key_cache.device - key_cache = key_cache.index_select(0, beam_idx.to(device)) - if value_cache.numel(): - device = value_cache.device - value_cache = value_cache.index_select(0, beam_idx.to(device)) - return key_cache, value_cache + if self.keys.numel(): + device = self.keys.device + self.keys = self.keys.index_select(0, beam_idx.to(device)) + if self.values.numel(): + device = self.values.device + self.values = self.values.index_select(0, beam_idx.to(device)) class DynamicLayer(CacheLayerMixin): @@ -360,12 +320,19 @@ class DynamicLayer(CacheLayerMixin): It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. """ + keys, values = None, None + + @classmethod + def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": + layer = cls() + layer.keys = keys + layer.values = values + return layer + def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - key_cache: Optional[torch.Tensor], - value_cache: Optional[torch.Tensor], cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -382,108 +349,68 @@ def update( Return: A tuple containing the updated key and value states. 
""" - if key_cache is None: - key_cache, value_cache = key_states, value_states + if self.keys is None: + self.keys = key_states + self.values = value_states else: - key_cache = torch.cat([key_cache, key_states], dim=-2) - value_cache = torch.cat([value_cache, value_states], dim=-2) - return key_cache, value_cache + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) + return self.keys, self.values - def new_tensors(self) -> tuple[None, None]: - """Returns a new key and value tensor for this layer's cache.""" - return None, None - - def get_seq_length( - self, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cache_position: Optional[torch.LongTensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor, int]: + def get_seq_length(self) -> int: """Returns the sequence length of the cached states.""" - # TODO: deprecate this function in favor of `cache_position` - if key_cache is None or key_cache.numel() == 0: + if self.keys is None or self.keys.numel() == 0: return 0 - return key_cache.shape[-2] + return self.keys.shape[-2] - def get_max_cache_shape( - self, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - ) -> int: + def get_max_cache_shape(self) -> int: """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" return -1 - def reset( - self, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + def reset(self) -> None: """Resets the cache values while preserving the objects""" - key_cache.zero_() - value_cache.zero_() - return key_cache, value_cache + self.keys.zero_() + self.values.zero_() + return self.keys, self.values - def reorder_cache( - self, - beam_idx: torch.LongTensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: """Reorders the cache for beam search, given the selected beam indices.""" - if key_cache is not None and key_cache.numel(): - key_cache = key_cache.index_select(0, beam_idx.to(key_cache.device)) - value_cache = value_cache.index_select(0, beam_idx.to(value_cache.device)) - return key_cache, value_cache + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) + self.values = self.values.index_select(0, beam_idx.to(self.values.device)) - def crop( - self, - max_length: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + def crop(self, max_length: int) -> None: """ Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative to remove `max_length` tokens. 
""" if max_length < 0: - max_length = self.get_seq_length(key_cache, value_cache) - abs(max_length) + max_length = self.get_seq_length() - abs(max_length) - if self.get_seq_length(key_cache, value_cache) <= max_length: - return key_cache, value_cache + if self.get_seq_length() <= max_length: + return - if key_cache is not None and key_cache.numel(): - key_cache = key_cache[..., :max_length, :] - value_cache = value_cache[..., :max_length, :] - return key_cache, value_cache + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[..., :max_length, :] + self.values = self.values[..., :max_length, :] - def batch_repeat_interleave( - self, repeats: int, key_cache: torch.Tensor, value_cache: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def batch_repeat_interleave(self, repeats: int) -> None: """Repeat the cache `repeats` times in the batch dimension.""" - if key_cache.numel(): - key_cache = key_cache.repeat_interleave(repeats, dim=0) - value_cache = value_cache.repeat_interleave(repeats, dim=0) - return key_cache, value_cache + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.repeat_interleave(repeats, dim=0) + self.values = self.values.repeat_interleave(repeats, dim=0) - def batch_select_indices( - self, indices: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def batch_select_indices(self, indices: torch.Tensor) -> None: """Only keep the `indices` in the batch dimension of the cache.""" - if key_cache.numel(): - key_cache = key_cache[indices, ...] - value_cache = value_cache[indices, ...] - return key_cache, value_cache + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[indices, ...] + self.values = self.values[indices, ...] - def get_mask_sizes( - self, - cache_position: torch.Tensor, - key_cache: Optional[torch.Tensor] = None, - value_cache: Optional[torch.Tensor] = None, - ) -> tuple[int, int]: + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: """Return the length and offset of the cache, used to generate the mask""" full_mask_kv_offset = 0 query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length(key_cache, value_cache, cache_position) + past_seen_tokens = self.get_seq_length() kv_length = query_length + past_seen_tokens return kv_length, full_mask_kv_offset @@ -572,7 +499,23 @@ def __init__( self.dtype = dtype self.device = device - def get_max_cache_shape(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> int: + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + self.keys = torch.zeros( + (batch_size, num_heads, self.max_cache_len, head_dim), + dtype=dtype, + device=device, + ) + self.values = torch.zeros( + (batch_size, num_heads, self.max_cache_len, head_dim), + dtype=dtype, + device=device, + ) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(self.keys) + torch._dynamo.mark_static_address(self.values) + + def get_max_cache_shape(self) -> int: """Return the maximum cache shape of the cache""" return self.max_cache_len @@ -581,97 +524,68 @@ def _static_update( key_states: torch.Tensor, value_states: torch.Tensor, cache_position: Optional[torch.LongTensor], - key_cache: torch.Tensor, - value_cache: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ Updates the static cache tensors in place. 
Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. key_states (`torch.Tensor`): The new key states to add. value_states (`torch.Tensor`): The new value states to add. cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted. If None, the entire cache is overwritten (prefill). - key_cache (`torch.Tensor`): The key cache tensor to update. - value_cache (`torch.Tensor`): The value cache tensor to update. Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place). """ if cache_position is None: # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. - key_cache.copy_(key_states) - value_cache.copy_(value_states) + self.keys.copy_(key_states) + self.values.copy_(value_states) else: # Generation phase. Update specific positions. # Use index_copy_ for in-place update (compile-friendly). try: - key_cache.index_copy_(2, cache_position, key_states) - value_cache.index_copy_(2, cache_position, value_states) + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) except NotImplementedError: # Fallback for devices like MPS where index_copy_ might not be supported. - key_cache[:, :, cache_position] = key_states - value_cache[:, :, cache_position] = value_states - return key_cache, value_cache + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + return self.keys, self.values def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Update the static cache tensors in place""" cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - return self._static_update( - key_states.to(key_cache.dtype), value_states.to(value_cache.dtype), cache_position, key_cache, value_cache - ) + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) + return self._static_update(key_states, value_states, cache_position) - def new_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: - """Returns a new key and value tensor for this layer's cache.""" - # Note: There will be significant perf decrease if switching to use 5D tensors instead. - keys = torch.zeros( - (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), - dtype=self.dtype, - device=self.device, - ) - values = torch.zeros( - (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), - dtype=self.dtype, - device=self.device, - ) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(keys) - torch._dynamo.mark_static_address(values) - return keys, values - - def get_seq_length(self, key_cache: torch.Tensor, value_cache: torch.Tensor, cache_position=None) -> int: + def get_seq_length(self, cache_position=None) -> int: if cache_position is not None: return int(cache_position[-1] + 1) # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. 
- seq_length = (key_cache[0, 0].any(dim=-1)).sum() if key_cache is not None else 0 + seq_length = (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0 return seq_length - def reset(self, key_cache: torch.Tensor, value_cache: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - key_cache.zero_() - value_cache.zero_() - return key_cache, value_cache + def reset(self) -> None: + self.keys.zero_() + self.values.zero_() - def reorder_cache( - self, beam_idx: torch.LongTensor, key_cache: torch.Tensor, value_cache: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - dev = key_cache.device + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + dev = self.keys.device beam_idx_dev = beam_idx.to(dev) - key_cache = key_cache.index_select(0, beam_idx_dev) - value_cache = value_cache.index_select(0, beam_idx_dev) - return key_cache, value_cache + self.keys = self.keys.index_select(0, beam_idx_dev) + self.values = self.values.index_select(0, beam_idx_dev) - def get_mask_sizes( - self, cache_position: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor - ) -> tuple[int, int]: + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: full_mask_kv_offset = 0 full_mask_kv_length = self.max_cache_len return full_mask_kv_length, full_mask_kv_offset @@ -691,18 +605,17 @@ def _static_update( key_states: torch.Tensor, value_states: torch.Tensor, cache_position: torch.LongTensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ Updates the sliding window cache tensors, returning the potentially modified tensors. Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. key_states (`torch.Tensor`): The new key states to add. value_states (`torch.Tensor`): The new value states to add. cache_position (`torch.LongTensor`): The position indices where the new states should be inserted. - key_cache (`torch.Tensor`): The key cache tensor to update. - value_cache (`torch.Tensor`): The value cache tensor to update. + max_cache_len (`int`): The maximum length of the sliding window cache. Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update. 
@@ -717,9 +630,9 @@ def _static_update( if cache_position.shape[0] > self.max_cache_len: new_k = key_states[:, :, -self.max_cache_len :, :] new_v = value_states[:, :, -self.max_cache_len :, :] - key_cache.copy_(new_k) - value_cache.copy_(new_v) - return key_cache, value_cache + self.keys.copy_(new_k) + self.values.copy_(new_v) + return self.keys, self.values # Sliding window logic for generation phase or prefill < window slicing = torch.arange(self.max_cache_len, device=value_states.device) @@ -727,8 +640,8 @@ def _static_update( to_shift = current_seq_len > self.max_cache_len indices = (slicing + to_shift.sum()) % self.max_cache_len - k_out_shifted = key_cache[:, :, indices] - v_out_shifted = value_cache[:, :, indices] + k_out_shifted = self.keys[:, :, indices] + v_out_shifted = self.values[:, :, indices] # Clamp cache_position to determine the *target index* within the shifted cache view update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1) @@ -743,13 +656,11 @@ def _static_update( k_out_updated[:, :, update_position] = key_states v_out_updated[:, :, update_position] = value_states - key_cache.copy_(k_out_updated) - value_cache.copy_(v_out_updated) - return key_cache, value_cache + self.keys.copy_(k_out_updated) + self.values.copy_(v_out_updated) + return self.keys, self.values - def get_mask_sizes( - self, cache_position: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor - ) -> tuple[int, int]: + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: query_length = cache_position.shape[0] first_cache_position = cache_position[0] @@ -794,9 +705,7 @@ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.T # compatibility. The name of the argument doesn't matter. if ddp_cache_data is not None: for key_states, value_states in ddp_cache_data: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - self.layers.append(DynamicLayer()) + self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) super().__init__(*args, **kwargs) def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: @@ -805,8 +714,8 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: backward compatibility. 
""" legacy_cache = () - for keys, values in zip(self.key_cache, self.value_cache): - legacy_cache += ((keys, values),) + for layer in self.layers: + legacy_cache += ((layer.keys, layer.values),) return legacy_cache @classmethod @@ -839,16 +748,16 @@ def _flatten_dynamic_cache( ) dictionary = { - "key_cache": dynamic_cache.key_cache if dynamic_cache.key_cache[0] is not None else [], - "value_cache": dynamic_cache.value_cache if dynamic_cache.value_cache[0] is not None else [], + "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], } return torch.utils._pytree._dict_flatten(dictionary) def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): dictionary = { - "key_cache": dynamic_cache.key_cache if dynamic_cache.key_cache[0] is not None else [], - "value_cache": dynamic_cache.value_cache if dynamic_cache.value_cache[0] is not None else [], + "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], } return torch.utils._pytree._dict_flatten_with_keys(dictionary) @@ -871,8 +780,8 @@ def _unflatten_dynamic_cache( def _flatten_dynamic_cache_for_fx(cache, spec): dictionary = { - "key_cache": cache.key_cache if cache.key_cache[0] is not None else [], - "value_cache": cache.value_cache if cache.value_cache[0] is not None else [], + "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in cache.layers if layer.values is not None], } return torch.fx._pytree._dict_flatten_spec(dictionary, spec) @@ -1004,8 +913,8 @@ def __init__(self, cache: "Cache", offload_device: Union[str, torch.device] = "c if self.is_static: for i, layer in enumerate(cache.layers): device = cache.layer_init_args["device"] if i == 0 else self.offload_device - cache.key_cache[i] = cache.key_cache[i].to(device) - cache.value_cache[i] = cache.value_cache[i].to(device) + layer.keys = layer.keys.to(device) + layer.values = layer.values.to(device) self.original_device.append(cache.layer_init_args["device"]) if len(cache) != cache.model_num_layers: raise ValueError("If static layers are used, all cache layers must be initialized") @@ -1052,18 +961,18 @@ def _prefetch_layer(self, cache: "Cache", layer_idx: int): ): # Prefetch next layer tensors to GPU device = self.original_device[layer_idx] - cache.key_cache[layer_idx] = cache.key_cache[layer_idx].to(device, non_blocking=True) - cache.value_cache[layer_idx] = cache.value_cache[layer_idx].to(device, non_blocking=True) + cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.to(device, non_blocking=True) + cache.layers[layer_idx].values = cache.layers[layer_idx].values.to(device, non_blocking=True) def _evict_previous_layer(self, cache: "Cache", layer_idx: int): """Moves the previous layer cache to the CPU.""" if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created # We do it on the default stream so it occurs after all earlier computations on these tensors are done prev_layer_idx = (layer_idx - 1) % len(cache) - cache.key_cache[prev_layer_idx] = cache.key_cache[prev_layer_idx].to( + cache.layers[prev_layer_idx].keys = cache.layers[prev_layer_idx].keys.to( self.offload_device, non_blocking=True ) - cache.value_cache[prev_layer_idx] = cache.value_cache[prev_layer_idx].to( + cache.layers[prev_layer_idx].values = 
cache.layers[prev_layer_idx].values.to( self.offload_device, non_blocking=True ) @@ -1076,8 +985,8 @@ def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): # Handle delayed beam search operations if self.beam_idx is not None: self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) - cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, self.beam_idx) - cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, self.beam_idx) + cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.index_select(0, self.beam_idx) + cache.layers[layer_idx].values = cache.layers[layer_idx].values.index_select(0, self.beam_idx) class QuantizedCacheProcessor(CacheProcessor): @@ -1213,12 +1122,12 @@ def post_update( # Clear the residual cache self.erased_length = key_tensors.shape[-2] - cache.key_cache[layer_idx] = torch.zeros( + cache.layers[layer_idx].keys = torch.zeros( 0, dtype=key_tensors.dtype, device=key_tensors.device, ) - cache.value_cache[layer_idx] = torch.zeros( + cache.layers[layer_idx].values = torch.zeros( 0, dtype=value_tensors.dtype, device=value_tensors.device, @@ -1239,12 +1148,12 @@ def post_update( # Clear the residual cache self.erased_length += key_tensors.shape[-2] - cache.key_cache[layer_idx] = torch.zeros( + cache.layers[layer_idx].keys = torch.zeros( 0, dtype=key_tensors.dtype, device=key_tensors.device, ) - cache.value_cache[layer_idx] = torch.zeros( + cache.layers[layer_idx].values = torch.zeros( 0, dtype=value_tensors.dtype, device=value_tensors.device, @@ -1548,10 +1457,10 @@ def __iter__(self): """ for layer_idx in range(len(self)): yield ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, ) def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -1561,10 +1470,10 @@ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch """ if layer_idx < len(self): return ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, ) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 18b7b3a92980..27e41c482570 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -282,8 +282,8 @@ def __init__(self, model: PreTrainedModel): dtype=self.model.dtype, ) for i in range(len(self.static_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", 
self.static_cache.layers[i].values, persistent=False) def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): """ @@ -413,8 +413,8 @@ def __init__( # Register all key and value cache tensors as buffers for i in range(len(self.cache)): - self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False) + self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False) def forward( self, @@ -560,8 +560,8 @@ def __init__(self, model, max_static_cache_length, batch_size): # Register cache buffers to make them exportable for i in range(len(self.static_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): # Get outputs from decoder diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7d768a734827..00df14eb0c17 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -230,8 +230,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 1e126fbcaff8..e422a5733e43 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1293,8 +1293,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index a9fa3584d424..4625cf8f2a68 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -207,8 +207,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = 
curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 659b856a77c1..0169b8d70686 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index f82a0d322283..cfa344230d5f 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -213,8 +213,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 38bc0715fc6d..ccde9fc33d10 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -364,8 +364,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] - value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx] + key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index 18512856dd43..8346bcc93268 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -183,8 +183,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] - value_states = 
past_key_values.cross_attention_cache.value_cache[self.layer_idx] + key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 1e3641ddc6a2..d7159a59b2d5 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1329,7 +1329,7 @@ def forward( if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: # Device of past layer may be different from current one - indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device) + indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device) # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) if isinstance(past_key_value, HybridCache) and self.is_sliding: max_length = past_key_value.sliding_window @@ -1340,9 +1340,9 @@ def forward( ) # Device of past layer may be different from current one - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device) - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to( - query_states.device + key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device) + value_states = ( + past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device) ) else: key_states = self.k_proj(hidden_states).view(hidden_shape) diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index b2ff4d7daef5..bcecd8370801 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1770,7 +1770,7 @@ def forward( if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: # Device of past layer may be different from current one - indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device) + indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device) # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) if isinstance(past_key_value, HybridCache) and self.is_sliding: max_length = past_key_value.sliding_window @@ -1781,9 +1781,9 @@ def forward( ) # Device of past layer may be different from current one - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device) - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to( - query_states.device + key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device) + value_states = ( + past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device) ) else: key_states = self.k_proj(hidden_states).view(hidden_shape) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 3c62372abf22..4c0005009ec8 100644 --- 
a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -704,7 +704,9 @@ def forward( # Ensure layer_past is on same device as hidden_states (might not be correct) if past_key_values is not None: - past_key_values = past_key_values.to(hidden_states.device) + for layer in past_key_values.layers: + layer.keys = layer.keys.to(hidden_states.device) + layer.values = layer.values.to(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states if causal_mask is not None: diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 9718e8fb736e..c988a874c20f 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -484,8 +484,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) @@ -601,8 +601,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 3d46275bdc81..606627823514 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -290,8 +290,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 537b816b79f8..a103c6fe83d1 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -481,8 +481,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) 
value_states = self.v(current_states) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index df2f6cdaa96c..2a1192b7e302 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -294,8 +294,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index c4393a3948bd..5e6800ce11a5 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 861eeaf68ec0..964963b0d659 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -239,8 +239,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 2157315238ae..b27edb72c84e 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -108,16 +108,14 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.layers[layer_idx].batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] 
- self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + self.layers[layer_idx].batch_select_indices(indices) def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 9e2ee0b17b03..35d6e6c0f153 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -217,16 +217,14 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.layers[layer_idx].batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + self.layers[layer_idx].batch_select_indices(indices) def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 0f68c2d03d7e..4acd4e0e33ba 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -496,8 +496,8 @@ def forward( ) elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], + past_key_value.layers[self.layer_idx].keys, + past_key_value.layers[self.layer_idx].values, ) else: raise ValueError( diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 9b55fa8c961b..7c3137af622c 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -234,8 +234,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 7180d35e8e6e..1dd2a4af9f2c 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -333,8 +333,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = 
past_key_value.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 0341c3afcfde..cab7e59d7b3e 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -379,8 +379,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index d8f8a511bc0c..88064f135740 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -228,8 +228,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 500150bba4d3..2dc7f66ad1cd 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -249,8 +249,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 93c281dc5bab..89b4f542cb3f 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -773,8 +773,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.key(current_states) value_states = self.value(current_states) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 3b071a1fe3d9..506cdce64661 100644 --- 
a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -426,8 +426,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index dca3191be9e0..7235314c1ea0 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -323,8 +323,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 63f1bfadeebe..7dd0bd264d51 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -516,8 +516,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 216ea793fd87..99eca814eea8 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -504,8 +504,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 6e31f181327d..080c1f6d78f3 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -339,8 +339,8 @@ def forward( key_states, value_states = 
curr_past_key_value.update(key_states, value_states, self.layer_idx) past_key_value.is_updated[self.layer_idx] = True else: - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index 9360008e30e6..6b3837e7422e 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -291,8 +291,8 @@ def forward( key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) past_key_value.is_updated[self.layer_idx] = True else: - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 778a0485b4e8..4c37cd42ef63 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -394,8 +394,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 23af62e4a1dc..277c41feadbb 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -602,8 +602,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 4b9d96db2f21..02295602a1d6 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -288,8 +288,8 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = 
curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 1bed6ce27b46..e7b80dbf4720 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1149,8 +1149,8 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None for layer_idx in range(self.config.decoder_layers): layer_past_key_values = [] for cache_cls in [values.self_attention_cache, values.cross_attention_cache]: - for v in [cache_cls.key_cache, cache_cls.value_cache]: - layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu()) + for v in [cache_cls.layers[layer_idx].keys, cache_cls.layers[layer_idx].values]: + layer_past_key_values.append(v[batch_idx][None].cpu()) all_past_key_values.append(tuple(layer_past_key_values)) return tuple(all_past_key_values) else: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f8473cda9fe0..c869f8ae23d2 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -330,8 +330,8 @@ def forward( current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d22b9be51bee..b8006860f155 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1598,26 +1598,26 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] + self_attention_layer_keys = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys ) - self_attention_layer_value_cache = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] + self_attention_layer_values = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values ) - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_key_cache = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + cross_attention_layer_keys = ( + past_kv[i][2] if 
is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
)
- cross_attention_layer_value_cache = (
- past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
+ cross_attention_layer_values = (
+ past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
)
- cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
- cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
- self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
- self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
+ cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
+ cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
+ self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
+ self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
# 3.2. Decoder-only checks
else:
@@ -1629,10 +1629,18 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None):
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
# Self attention
- self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
- self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
- self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
- self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
+ if is_legacy_cache:
+ self_attention_layer_keys = past_kv[i][0]
+ self_attention_layer_values = past_kv[i][1]
+ elif getattr(past_kv, "layers", None) is None:
+ # Cache is not layered (i.e., Mamba derivatives)
+ self_attention_layer_keys = past_kv.key_cache[i]
+ self_attention_layer_values = past_kv.value_cache[i]
+ else:
+ self_attention_layer_keys = past_kv.layers[i].keys
+ self_attention_layer_values = past_kv.layers[i].values
+ self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
+ self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
@pytest.mark.generate
def test_generate_from_random_inputs_embeds(self):
@@ -1816,7 +1824,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
self.assertIsInstance(outputs.past_key_values, StaticCache)
self.assertEqual(len(outputs.past_key_values), num_hidden_layers)
- self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
+ self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape)
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
@@ -2039,7 +2047,7 @@ def test_generate_with_static_cache(self):
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers)
- self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
+ self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape)
# Check 2: The outputs must be similar to the case with dynamic cache
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
@@ -2628,12 +2636,12 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
-
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), + [layer.keys.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), + [layer.values.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) # Legacy cache format checks. This branch should be removed when all models use `Cache` by default @@ -4039,13 +4047,13 @@ def test_generate_with_static_cache_multi_accelerator(self): self.assertTrue(isinstance(results.past_key_values, StaticCache)) # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + keys_0 = results.past_key_values.layers[0].keys + values_0 = results.past_key_values.layers[0].values + self.assertTrue(keys_0.device == values_0.device == torch.device(0)) - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + keys_1 = results.past_key_values.layers[1].keys + values_1 = results.past_key_values.layers[1].values + self.assertTrue(keys_1.device == values_1.device == torch.device(1)) @pytest.mark.generate @require_torch_multi_accelerator @@ -4117,13 +4125,13 @@ def test_init_static_cache_multi_accelerator(self): results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + keys_0 = results.past_key_values.layers[0].keys + values_0 = results.past_key_values.layers[0].values + self.assertTrue(keys_0.device == values_0.device == torch.device(0)) - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + keys_1 = results.past_key_values.layers[1].keys + values_1 = results.past_key_values.layers[1].values + self.assertTrue(keys_1.device == values_1.device == torch.device(1)) @slow def test_padding_input_contrastive_search_gpt2(self): diff --git a/tests/models/deepseek_v2/test_modeling_deepseek_v2.py b/tests/models/deepseek_v2/test_modeling_deepseek_v2.py index 02d087cb8b9a..0bdc6884590f 100644 --- a/tests/models/deepseek_v2/test_modeling_deepseek_v2.py +++ b/tests/models/deepseek_v2/test_modeling_deepseek_v2.py @@ -168,14 +168,9 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value expected_value_shape = expected_common_shape + (config.v_head_dim,) if isinstance(decoder_past_key_values, Cache): - self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_key_shape] * len(decoder_past_key_values.key_cache), - ) - self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_value_shape] * len(decoder_past_key_values.value_cache), - ) + for layer in decoder_past_key_values.layers: + self.assertEqual(layer.keys.shape, expected_key_shape) + 
self.assertEqual(layer.values.shape, expected_value_shape) @unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape") def test_generate_compilation_all_outputs(self): diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py index 6c0c3a19d067..87f7b2abb0e9 100644 --- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py +++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py @@ -440,13 +440,11 @@ def test_past_key_values_format(self): # difference: last dim k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim v_embed_dim = config.v_head_dim - self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) - self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) + self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) + self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) # build the full cache shapes num_hidden_layers = config.num_hidden_layers - all_cache_shapes = [ - [self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers) - ] + all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)] super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes) @require_torch_large_accelerator diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py index 447491f90102..c6d1547afb40 100644 --- a/tests/models/dia/test_modeling_dia.py +++ b/tests/models/dia/test_modeling_dia.py @@ -399,12 +399,12 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), + [layer.keys.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), + [layer.values.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) def _check_scores(self, batch_size, scores, generated_length, config): diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index b0a0a6a3ccb4..ecd2af9fdc6c 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -235,8 +235,8 @@ def copy_cache(cache: DynamicCache): """Deep copy a DynamicCache to reuse the same one multiple times.""" new_cache = cache for i in range(len(cache)): - new_cache.key_cache[i] = cache.key_cache[i].clone() - new_cache.value_cache[i] = cache.value_cache[i].clone() + new_cache.layers[i].keys = cache.layers[i].keys.clone() + new_cache.layers[i].values = cache.layers[i].values.clone() # Cached forward once with the attention mask provided and the other time without it (which should assume full attention) # We need to run both on a copy of the cache, otherwise it is modified in-place diff --git a/tests/models/t5gemma/test_modeling_t5gemma.py b/tests/models/t5gemma/test_modeling_t5gemma.py index 1fab6427b1cc..e806973d88d9 100644 --- 
a/tests/models/t5gemma/test_modeling_t5gemma.py +++ b/tests/models/t5gemma/test_modeling_t5gemma.py @@ -1068,26 +1068,26 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] + self_attention_layer_keys = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys ) - self_attention_layer_value_cache = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] + self_attention_layer_values = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values ) - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_key_cache = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + cross_attention_layer_keys = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys ) - cross_attention_layer_value_cache = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] + cross_attention_layer_values = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values ) - cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] - cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] - self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) + cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] + cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] + self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) # 3.2. 
Decoder-only checks else: @@ -1099,10 +1099,10 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] - self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys + self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) @unittest.skip("Mismatch issue doesn't exist in T5Gemma.") def test_load_with_mismatched_shapes(self): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index db00dc7dbfe1..5c5ee08da6ee 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -673,11 +673,9 @@ def test_dynamic_cache_exportability(self): use_cache=True, ) self.assertTrue(torch.allclose(res.logits, res_eager.logits)) - for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache): - self.assertTrue(torch.allclose(k1, k2)) - - for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache): - self.assertTrue(torch.allclose(v1, v2)) + for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers): + self.assertTrue(torch.allclose(l1.keys, l2.keys)) + self.assertTrue(torch.allclose(l1.values, l2.values)) def test_dynamic_cache_exportability_multiple_run(self): # When exporting with DynamicCache, you should export two graphs: @@ -727,8 +725,8 @@ def test_dynamic_cache_exportability_multiple_run(self): dyn = torch.export.Dim("seq", max=512) for ix in range(len(past_key_values)): - shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None) - shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None) + shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None) + shapes[past_key_values.layers[ix].values] = (None, None, dyn, None) ep_second = torch.export.export( model, @@ -769,11 +767,9 @@ def test_dynamic_cache_exportability_multiple_run(self): use_cache=True, ) - for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache): - self.assertTrue(torch.allclose(k1, k2)) - - for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache): - self.assertTrue(torch.allclose(v1, v2)) + for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers): + self.assertTrue(torch.allclose(l1.keys, l2.keys)) + self.assertTrue(torch.allclose(l1.values, l2.values)) @unittest.skip("Runs on my machine locally, passed, no idea why it does not online") def test_static_cache_exportability(self): @@ -953,7 +949,7 @@ def test_static_cache(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" + static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" ) # Scenario 2: Fill to capacity @@ -964,7 +960,7 @@ def test_static_cache(self): 
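The StaticCache scenarios updated in this test file boil down to the following plain-torch behaviour; this is only an illustration of what the assertions check (a pre-allocated buffer written at `cache_position`, with untouched slots staying zero), not the cache class itself:

```python
import torch

max_cache_len = 4
keys = torch.zeros(1, 1, max_cache_len, 1)
keys[:, :, :2, 0] = torch.tensor([1.0, 2.0])  # prefill with 2 tokens
# decode step: write position 2 in place, leaving position 3 untouched
keys.index_copy_(2, torch.tensor([2]), torch.full((1, 1, 1, 1), 3.0))
assert keys[0, 0, :, 0].tolist() == [1.0, 2.0, 3.0, 0.0]
```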
cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" + static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" ) def test_sliding_window_cache(self): @@ -998,7 +994,7 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed", ) @@ -1019,7 +1015,7 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed", ) @@ -1034,7 +1030,7 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "SlidingWindowCache Scenario 3 failed", ) @@ -1068,7 +1064,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Static Scenario 1 failed", ) @@ -1081,7 +1077,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "HybridCache Static Scenario 2 failed", ) @@ -1120,7 +1116,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed", ) @@ -1141,7 +1137,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed", ) @@ -1154,7 +1150,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed", ) @@ -1169,7 +1165,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 4 failed", ) @@ -1189,10 +1185,10 @@ def test_dynamic_cache(self): cache = DynamicCache() cache.update(prefill, prefill, 0) cache.update(update3, update3, 0) - self.assertEqual(cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], 
"DynamicCache Scenario 1 failed") + self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") cache.update(update4, update4, 0) self.assertEqual( - cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" + cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" ) # Scenario 2: prefill and update for two layers independently @@ -1209,10 +1205,10 @@ def test_dynamic_cache(self): cache.update(update4, update4, 0) cache.update(update4_1, update4_1, 1) self.assertEqual( - cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" + cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" ) self.assertEqual( - cache.key_cache[1][0, 0, :, 0].tolist(), + cache.layers[1].keys[0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0], "DynamicCache Scenario 2 layer 1 failed", ) From 13ec4a44827fe5a3d9cb86d25ddd4bd1e0cf9343 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 14 Jul 2025 15:14:42 +0200 Subject: [PATCH 14/35] cyril review --- src/transformers/cache_utils.py | 60 +++++++++++++-------------------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index a7967b37bd61..46f7879ed108 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -519,26 +519,27 @@ def get_max_cache_shape(self) -> int: """Return the maximum cache shape of the cache""" return self.max_cache_len - def _static_update( + def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - cache_position: Optional[torch.LongTensor], + cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Updates the static cache tensors in place. + Update the static cache tensors in place. Args: - k_cache (`torch.Tensor`): The key cache tensor to update. - v_cache (`torch.Tensor`): The value cache tensor to update. - key_states (`torch.Tensor`): The new key states to add. - value_states (`torch.Tensor`): The new value states to add. - cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted. - If None, the entire cache is overwritten (prefill). + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place). + tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value states. """ + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) + if cache_position is None: # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. 
self.keys.copy_(key_states) @@ -555,18 +556,6 @@ def _static_update( self.values[:, :, cache_position] = value_states return self.keys, self.values - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Update the static cache tensors in place""" - cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - key_states = key_states.to(self.keys.dtype) - value_states = value_states.to(self.values.dtype) - return self._static_update(key_states, value_states, cache_position) - def get_seq_length(self, cache_position=None) -> int: if cache_position is not None: return int(cache_position[-1] + 1) @@ -600,33 +589,32 @@ class SlidingWindowLayer(StaticLayer): def __init__(self, sliding_window, max_cache_len=None, *args, **kwargs): super().__init__(*args, max_cache_len=sliding_window, *args, **kwargs) - def _static_update( + def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - cache_position: torch.LongTensor, + cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Updates the sliding window cache tensors, returning the potentially modified tensors. + Update the sliding window cache tensors in place. Args: - k_cache (`torch.Tensor`): The key cache tensor to update. - v_cache (`torch.Tensor`): The value cache tensor to update. - key_states (`torch.Tensor`): The new key states to add. - value_states (`torch.Tensor`): The new value states to add. - cache_position (`torch.LongTensor`): The position indices where the new states should be inserted. - max_cache_len (`int`): The maximum length of the sliding window cache. + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update. - For prefill > window, these are the full input states. - Otherwise, they are the updated cache tensors. + tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value states. """ - + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None if cache_position is None: raise ValueError("`cache_position` must be provided for SlidingWindowLayer.") - # Handle prefill phase when prompt length > sliding_window_size + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) + + # Handle prefill phase when prompt length > sliding_window_size. + # Note that we store cropped key/value states in the cache but return the full key/value states. 
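The comment above is the one subtlety of the sliding layer: a prefill longer than the window is cropped before being stored, but the uncropped states are what the attention call receives. A toy rendition of that rule with made-up shapes, independent of the class internals:

```python
import torch

max_cache_len = 4                                       # sliding window size
cache_keys = torch.zeros(1, 1, max_cache_len, 2)        # pre-allocated buffer
prefill_keys = torch.arange(12.0).reshape(1, 1, 6, 2)   # 6 prompt tokens > window of 4

if prefill_keys.shape[-2] > max_cache_len:
    cache_keys.copy_(prefill_keys[:, :, -max_cache_len:, :])  # keep only the last 4 positions
    returned_keys = prefill_keys                              # attention still sees all 6
```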
if cache_position.shape[0] > self.max_cache_len: new_k = key_states[:, :, -self.max_cache_len :, :] new_v = value_states[:, :, -self.max_cache_len :, :] From 7029a90c430ec7d3161f00e18a8baf6e47128c9b Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 14 Jul 2025 16:11:27 +0200 Subject: [PATCH 15/35] simplify cache export --- src/transformers/cache_utils.py | 88 ++++++++++++++------------------- 1 file changed, 36 insertions(+), 52 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 46f7879ed108..615b9b435af7 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -723,67 +723,51 @@ def from_legacy_cache( # Utilities for `DynamicCache` <> torch.export support -def _flatten_dynamic_cache( - dynamic_cache: DynamicCache, -): - """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" - if not isinstance(dynamic_cache, DynamicCache): - raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") - - if not is_torch_greater_or_equal_than_2_6: - logger.warning_once( - "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." - ) - dictionary = { - "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], - "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], - } - return torch.utils._pytree._dict_flatten(dictionary) - - -def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): - dictionary = { - "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], - "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], - } - return torch.utils._pytree._dict_flatten_with_keys(dictionary) - - -def _unflatten_dynamic_cache( - values, - context: torch.utils._pytree.Context, -): - dictionary = torch.utils._pytree._dict_unflatten(values, context) - cache = DynamicCache() - # Reconstruct layers from keys and values lists - key_list = dictionary.get("key_cache", []) - value_list = dictionary.get("value_cache", []) - for idx in range(max(len(key_list), len(value_list))): - key = key_list[idx] if idx < len(key_list) else None - value = value_list[idx] if idx < len(value_list) else None - cache.update(key, value, idx) - return cache - - -def _flatten_dynamic_cache_for_fx(cache, spec): - dictionary = { - "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], - "value_cache": [layer.values for layer in cache.layers if layer.values is not None], - } - return torch.fx._pytree._dict_flatten_spec(dictionary, spec) +if is_torch_greater_or_equal("2.3"): + def _get_cache_dict(cache: DynamicCache): + if any(not isinstance(layer, DynamicLayer) for layer in cache.layers): + raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") + + if not is_torch_greater_or_equal_than_2_6: + logger.warning_once( + "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." 
+ ) + + return { + "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in cache.layers if layer.values is not None], + } + + def _unflatten_dynamic_cache( + values, + context: torch.utils._pytree.Context, + ): + dictionary = torch.utils._pytree._dict_unflatten(values, context) + cache = DynamicCache() + # Reconstruct layers from keys and values lists + key_list = dictionary.get("key_cache", []) + value_list = dictionary.get("value_cache", []) + for idx in range(max(len(key_list), len(value_list))): + key = key_list[idx] if idx < len(key_list) else None + value = value_list[idx] if idx < len(value_list) else None + cache.update(key, value, idx) + return cache -if is_torch_greater_or_equal("2.3"): torch.utils._pytree.register_pytree_node( DynamicCache, - _flatten_dynamic_cache, + lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)), _unflatten_dynamic_cache, serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, + flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys( + _get_cache_dict(dynamic_cache) + ), ) # TODO (tmanlaibaatar) This won't be needed in torch 2.7. - torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) + torch.fx._pytree.register_pytree_flatten_spec( + DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec) + ) class OffloadedCache(DynamicCache): From dd7458b53eac0651f4b1ef149a4d69b340610e7e Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 14 Jul 2025 18:33:17 +0200 Subject: [PATCH 16/35] fix lfm2 cache --- src/transformers/models/lfm2/modeling_lfm2.py | 34 ++++++++++++++++++- src/transformers/models/lfm2/modular_lfm2.py | 34 ++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 4931a3a46e04..577f73538bae 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -128,6 +128,9 @@ class Lfm2HybridConvCache(DynamicCache): Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. """ + max_batch_size = None + is_compileable = False + def __init__( self, config: Lfm2Config, @@ -135,7 +138,8 @@ def __init__( dtype: torch.dtype = torch.float32, device: Union[torch.device, str, None] = None, ): - super().__init__() # initialize key and value cache + self.key_cache = [] + self.value_cache = [] self.max_batch_size = max_batch_size self.layer_types = config.layer_types self.first_attention_layer = self.layer_types.index("full_attention") @@ -218,6 +222,34 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. 
+ """ + full_mask_kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, full_mask_kv_offset + + def crop(self, max_length: int): + """Crop the cache to the given length""" + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + if self.key_cache is not None and self.key_cache.numel(): + self.key_cache = self.key_cache[..., :max_length, :] + self.value_cache = self.value_cache[..., :max_length, :] + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 338e6ec5242d..f1b65172435f 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -89,6 +89,9 @@ class Lfm2HybridConvCache(DynamicCache): Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. """ + max_batch_size = None + is_compileable = False + def __init__( self, config: Lfm2Config, @@ -96,7 +99,8 @@ def __init__( dtype: torch.dtype = torch.float32, device: Union[torch.device, str, None] = None, ): - super().__init__() # initialize key and value cache + self.key_cache = [] + self.value_cache = [] self.max_batch_size = max_batch_size self.layer_types = config.layer_types self.first_attention_layer = self.layer_types.index("full_attention") @@ -179,6 +183,34 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. 
+ """ + full_mask_kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, full_mask_kv_offset + + def crop(self, max_length: int): + """Crop the cache to the given length""" + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + if self.key_cache is not None and self.key_cache.numel(): + self.key_cache = self.key_cache[..., :max_length, :] + self.value_cache = self.value_cache[..., :max_length, :] + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") From dc08253c46ed1bd664a22c5418df740a6784fde8 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 14 Jul 2025 18:37:38 +0200 Subject: [PATCH 17/35] HybridChunked to layer --- src/transformers/cache_utils.py | 580 ++++++++++---------------------- 1 file changed, 187 insertions(+), 393 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 615b9b435af7..a7287d6435ac 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -415,69 +415,6 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: return kv_length, full_mask_kv_offset -class CacheProcessor: - """ - Base class for cache processors that can be applied to modify cache behavior. - This class should be subclassed. - """ - - def __init__(self, cache: "Cache", **kwargs) -> None: - """ - Initialize the processor and perform compatibility checks with the cache. - - Args: - cache (`Cache`): The cache instance this processor will be applied to. - **kwargs: Additional arguments that may be needed for initialization. - """ - raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") - - def pre_update( - self, - cache: "Cache", - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Function called before the cache update. Can modify the key/value states. - - Args: - cache (`Cache`): The cache instance. - key_states (`torch.Tensor`): The new key states to cache. - value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. - """ - return key_states, value_states - - def post_update( - self, - cache: "Cache", - key_tensors: torch.Tensor, - value_tensors: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Function called after the cache update. Can process the cached data. - - Args: - cache (`Cache`): The cache instance. - key_states (`torch.Tensor`): The key states that were cached. - value_states (`torch.Tensor`): The value states that were cached. - layer_idx (`int`): The index of the layer that was updated. - cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. 
- - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return. - """ - return key_tensors, value_tensors - - class StaticLayer(CacheLayerMixin): is_compileable = True is_sliding = False @@ -658,6 +595,84 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: return local_mask_kv_length, local_mask_kv_offset +class ChunkedAttentionLayer(StaticLayer): + """ + A static cache layer that implements chunked attention caching. + Inherits from StaticLayer but uses chunked attention update logic. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cumulative_length = 0 + self.is_sliding = True + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + if cache_position is None: + raise ValueError("`cache_position` must be provided for ChunkedAttentionLayer.") + + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) + + cumulative_length = self.cumulative_length + self.cumulative_length += key_states.shape[-2] + is_full = cumulative_length >= self.max_cache_len + + if is_full: + full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) + full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) + if key_states.shape[-2] == 1: + self.keys.copy_(full_key_states) + self.values.copy_(full_value_states) + return self.keys, self.values + elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len: + if cumulative_length == 0: + full_key_states = key_states + full_value_states = value_states + else: + full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2) + full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2) + else: + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + return self.keys, self.values + + self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :]) + self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :]) + return full_key_states, full_value_states + + def reset(self) -> None: + super().reset() + self.cumulative_length = 0 + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + sliding_window = self.max_cache_len + + local_mask_kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0) + # This is the true general case for any Cache using local attention (sliding or chunked) + if first_cache_position >= sliding_window: + # Here the Cache is already full + local_mask_kv_length = sliding_window + query_length - 1 + elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window: + # Here the Cache becomes full with the new input + local_mask_kv_length = first_cache_position + query_length + else: + # Here the Cache is still smaller than the local size, but we return the local size as it's static + local_mask_kv_length = sliding_window + return local_mask_kv_length, local_mask_kv_offset + + class DynamicCache(Cache): """ A cache that grows dynamically as more tokens 
are generated. This is the default for generative models. @@ -816,6 +831,39 @@ def __init__(self, *args, **kwargs): super().__init__(layer_classes=[StaticLayer], *args, **kwargs) +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window + attention and global attention in every other layer (originally implemented for Gemma2). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention.For more information, see the documentation of each subcomponent cache class. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + def __init__(self, config: PretrainedConfig, *args, **kwargs): + # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types + layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + + class SlidingWindowCache(Cache): """ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. @@ -856,6 +904,69 @@ def __init__(self, *args, **kwargs): super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) +class CacheProcessor: + """ + Base class for cache processors that can be applied to modify cache behavior. + This class should be subclassed. + """ + + def __init__(self, cache: "Cache", **kwargs) -> None: + """ + Initialize the processor and perform compatibility checks with the cache. + + Args: + cache (`Cache`): The cache instance this processor will be applied to. + **kwargs: Additional arguments that may be needed for initialization. + """ + raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") + + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Function called before the cache update. Can modify the key/value states. + + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + layer_idx (`int`): The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. 
+        """
+        return key_states, value_states
+
+    def post_update(
+        self,
+        cache: "Cache",
+        key_tensors: torch.Tensor,
+        value_tensors: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Function called after the cache update. Can process the cached data.
+
+        Args:
+            cache (`Cache`): The cache instance.
+            key_tensors (`torch.Tensor`): The key states that were cached.
+            value_tensors (`torch.Tensor`): The value states that were cached.
+            layer_idx (`int`): The index of the layer that was updated.
+            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+        Returns:
+            tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return.
+        """
+        return key_tensors, value_tensors
+
+
 class OffloadedCacheProcessor(CacheProcessor):
     """
     A cache processor that offloads cache tensors to conserve accelerator memory.
@@ -1374,7 +1485,7 @@ def parse_layer_args_from_model_config(
 LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = {
     "full_attention": StaticLayer,
     "sliding_attention": SlidingWindowLayer,
-    # "chunked_attention": ChunkedLayer,
+    "chunked_attention": ChunkedAttentionLayer,
 }
 PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = {
     "offloaded": OffloadedCacheProcessor,
@@ -1566,39 +1677,6 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[
         return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
 
 
-class HybridCache(Cache):
-    """
-    Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
-    attention and global attention in every other layer (originally implemented for Gemma2).
-    Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
-    for global attention.For more information, see the documentation of each subcomponent cache class.
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
-
-    >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
-    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
-
-    >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
-
-    >>> # Prepare a cache class and pass it to model's forward
-    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
-    >>> max_generated_length = inputs.input_ids.shape[1] + 10
-    >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
-    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
-    >>> outputs.past_key_values # access cache filled with key/values from generation
-    HybridCache()
-    ```
-    """
-
-    def __init__(self, config: PretrainedConfig, *args, **kwargs):
-        # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types
-        layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None
-        super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
-
-
 class QuantizedCache(DynamicCache):
     """
     A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
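As a usage sketch (not part of the diff itself): a custom processor only needs to subclass `CacheProcessor` and override the hooks above, then be passed through the `cache_processor` argument of a cache. The snippet below assumes the hook signatures and the `transformers.cache_utils` module layout exactly as introduced in this patch; the `TokenCountingProcessor` name and its counting logic are invented purely for illustration.

```python
import torch
from transformers.cache_utils import CacheProcessor, DynamicCache


class TokenCountingProcessor(CacheProcessor):
    """Hypothetical processor: counts the tokens written per layer, leaves the states untouched."""

    def __init__(self, cache, **kwargs):
        # The base __init__ raises NotImplementedError, so a subclass must provide its own.
        self.tokens_seen = {}

    def pre_update(self, cache, key_states, value_states, layer_idx, cache_kwargs=None):
        # Runs before the layer update; may transform the incoming states.
        self.tokens_seen[layer_idx] = self.tokens_seen.get(layer_idx, 0) + key_states.shape[-2]
        return key_states, value_states

    def post_update(self, cache, key_tensors, value_tensors, layer_idx, cache_kwargs=None):
        # Runs after the layer update; decides what the attention layer receives back.
        return key_tensors, value_tensors


past_key_values = DynamicCache(cache_processor=TokenCountingProcessor)
keys = torch.randn(1, 2, 5, 8)    # [batch_size, num_heads, seq_len, head_dim]
values = torch.randn(1, 2, 5, 8)
past_key_values.update(keys, values, 0)
print(past_key_values.cache_processor.tokens_seen)  # {0: 5}
```

The string registry in the `PROCESSOR_CLASS_MAP` hunk serves the same purpose for the built-in processors, so a caller can also pass `cache_processor="offloaded"` instead of a class, assuming the lookup shown in `Cache.__init__`.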
@@ -1785,307 +1863,23 @@ class HybridChunkedCache(Cache): ``` """ - is_compileable = True - # Override Cache's @property methods since we will define them in the init - is_sliding = None - max_batch_size = None - max_cache_len = None - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.bfloat16, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - if not hasattr(config, "sliding_window") or config.sliding_window is None: - self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) - else: - self.sliding_window = config.sliding_window - self.max_cache_len = max_cache_len - # Sliding layers can't be larger than the overall max cache len - self.sliding_window = min(self.sliding_window, self.max_cache_len) - self.max_batch_size = max_batch_size - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self._dtype = dtype - - # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] - else: - self.is_sliding = [False] * config.num_hidden_layers - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] - - def initialise_cache_layer(self, layer_idx, key_states): - if len(self.key_cache) > layer_idx: - return - - num_key_value_heads = key_states.shape[1] - device = key_states.device - global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - cumulative_length = self.cumulative_length[layer_idx] - # Update it now that we saved the value above - self.cumulative_length[layer_idx] += key_states.shape[-2] - is_full = cumulative_length >= max_cache_len - if is_full: - full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) - # Fast decoding path -> here as the effective size is still sliding window, it is extremely important - # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address - # in memory (the values are the same as the full states, but not the address!!) 
- if key_states.shape[-2] == 1: - self.key_cache[layer_idx].copy_(full_key_states) - self.value_cache[layer_idx].copy_(full_value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: - # Fast prefill path, no need to cat() in this case (which creates a copy even if cating from 0 dim) - if cumulative_length == 0: - full_key_states = key_states - full_value_states = value_states - else: - full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2) - else: - try: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - except NotImplementedError: - # MPS does not support index_copy_ - self.key_cache[layer_idx][:, :, cache_position] = key_states - self.value_cache[layer_idx][:, :, cache_position] = value_states - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) - self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return full_key_states, full_value_states - - def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - self.initialise_cache_layer(layer_idx, key_states) - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) - - def get_max_cache_shape(self) -> int: - return self.max_cache_len - - def get_seq_length(self, layer_idx: Optional[int] = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." 
- ) - if len(self.key_cache) == 0: - return 0 - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - if self.value_cache[layer_idx].numel(): - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - if self.is_sliding[layer_idx]: - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - - local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) - # This is the true general case for any Cache using local attention (sliding or chunked) - if first_cache_position >= self.sliding_window: - # Here the Cache is already full - local_mask_kv_length = self.sliding_window + query_length - 1 - elif ( - first_cache_position < self.sliding_window - and first_cache_position + query_length > self.sliding_window - ): - # Here the Cache becomes full with the new input - local_mask_kv_length = first_cache_position + query_length - else: - # Here the Cache is still smaller than the local size, but we return the local size as it's static - local_mask_kv_length = self.sliding_window - return local_mask_kv_length, local_mask_kv_offset - - full_mask_kv_offset = 0 - full_mask_kv_length = self.get_max_cache_shape() - return full_mask_kv_length, full_mask_kv_offset + def __init__(self, config: PretrainedConfig, *args, **kwargs): + # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types + layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) class OffloadedHybridCache(HybridChunkedCache): - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.bfloat16, - offload_device: Union[str, torch.device] = torch.device("cpu"), - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ): - super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map) - - # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps - # track of the original device of each layer - unique_devices = set(layer_device_map.values()) if layer_device_map else set() - if len(unique_devices) > 1: - raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}") + """ + A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. 
- self.offload_device = torch.device(offload_device) - # Create new CUDA stream for parallel prefetching. - self._prefetch_stream = torch.cuda.Stream() if torch._C._get_accelerator().type == "cuda" else None - # Those will be dynamically created as the other layers (for TP) - self.device_key_cache = None - self.device_value_cache = None - # This gives the index of which on-device full layer to use (we need 2 to avoid race conditions when prefetching) - self.active_device_layer = 0 - - def initialise_cache_layer(self, layer_idx, key_states): - """Overridden to use the correct device if offloaded layer (and pin memory).""" - if len(self.key_cache) > layer_idx: - return + This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling + much longer sequences by offloading inactive layers to CPU memory. + """ - num_key_value_heads = key_states.shape[1] - device = key_states.device if self.is_sliding[layer_idx] else self.offload_device - pin_memory = not self.is_sliding[layer_idx] - global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - # Make sure to initialize the on-device layer if it does not already exist - if self.device_key_cache is None and not self.is_sliding[layer_idx]: - self.device_key_cache = [] - self.device_value_cache = [] - # We need 2 layers to avoid race conditions when prefetching the next one - for _ in range(2): - device_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device) - device_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.device_key_cache.append(device_layer_key_cache) - self.device_value_cache.append(device_layer_value_cache) - - def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - # Wait for prefetch stream if needed - if self._prefetch_stream is not None: - torch.cuda.default_stream(key_states.device).wait_stream(self._prefetch_stream) - - # Get correct on-device layer - k_out = self.device_key_cache[self.active_device_layer] - v_out = self.device_value_cache[self.active_device_layer] - - # Let's prefetch the next layer as soon as possible - self._prefetch_next_layer(layer_idx) - - # Copy to on-device layer - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # Copy to offloaded device - self.key_cache[layer_idx][:, :, cache_position] = key_states.to(self.offload_device) - self.value_cache[layer_idx][:, :, cache_position] = value_states.to(self.offload_device) - - return k_out, v_out - - def _prefetch_next_layer(self, layer_idx: int) -> None: - """Based 
on current layer_idx, prefetch next full layer to the device.""" - - # Switch the active layer - self.active_device_layer = 0 if self.active_device_layer == 1 else 1 - - # Find the next non-sliding layer - try: - next_layer = layer_idx + 1 + self.is_sliding[layer_idx + 1 :].index(False) - # In this case, we are at the last layer, and we go back to prefect the first one - except ValueError: - next_layer = self.is_sliding.index(False) - - # Alternate between two on-device caches. - if self._prefetch_stream is not None: - with torch.cuda.stream(self._prefetch_stream): - self._prefetch_layer_in_context(next_layer) - else: - self._prefetch_layer_in_context(next_layer) - - def _prefetch_layer_in_context(self, layer_idx: int) -> None: - """Performs the actual copy of the layer to device cache.""" - if len(self.key_cache) > layer_idx: - self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True) - self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True) - # The layer was not yet initialized - else: - self.device_key_cache[self.active_device_layer].fill_(0.0) - self.device_value_cache[self.active_device_layer].fill_(0.0) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) class SinkCache(Cache): From a952124942808ee5c7686cc1ab3b3b3c504b51cc Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 14 Jul 2025 18:45:51 +0200 Subject: [PATCH 18/35] BC proxy object for cache.key_cache[i]=... --- src/transformers/cache_utils.py | 47 +++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index a7287d6435ac..459add5366d9 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -56,6 +56,37 @@ def _wrapped_update( return _wrapped_update +class KeyValuesBCWrapper: + """Efficiently simulates layer-indexed key or value lists from a layered cache. + This allows for BC access and writing, e.g., cache.key_cache[idx] = ...""" + + def __init__(self, layers, cache_type="key"): + self.layers = layers + self.cache_type = cache_type + + def __getitem__(self, idx): + if isinstance(idx, slice): + return [getattr(layer, f"{self.cache_type}_cache") for layer in self.layers[idx]] + return getattr(self.layers[idx], f"{self.cache_type}_cache") + + def __setitem__(self, idx, value): + if isinstance(idx, slice): + for layer, val in zip(self.layers[idx], value): + setattr(layer, f"{self.cache_type}_cache", val) + else: + setattr(self.layers[idx], f"{self.cache_type}_cache", value) + + def __len__(self): + return len(self.layers) + + def __iter__(self): + for layer in self.layers: + yield getattr(layer, f"{self.cache_type}_cache") + + def __bool__(self): + return bool(self.layers) + + class Cache: """ Base class for all caches. @@ -216,6 +247,22 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes(cache_position) return kv_length, kv_offset + @property + def key_cache(self) -> KeyValuesBCWrapper: + """Returns a list-like object of key cache tensors indexed by layer.""" + warnings.warn( + "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." 
+ ) + return KeyValuesBCWrapper(self.layers, "key") + + @property + def value_cache(self) -> KeyValuesBCWrapper: + """Returns a list-like object of value cache tensors indexed by layer.""" + warnings.warn( + "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." + ) + return KeyValuesBCWrapper(self.layers, "value") + ### Wrappers for layer operations and properties ### def get_max_cache_shape(self, layer_idx: int = 0) -> int: From dbbc4d51041ea1b6840bfebfcc4ee49f2d2393fd Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 14 Jul 2025 18:59:35 +0200 Subject: [PATCH 19/35] reorder classes --- src/transformers/cache_utils.py | 1880 ++++++++++++++++--------------- 1 file changed, 942 insertions(+), 938 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 459add5366d9..154eba290607 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -56,402 +56,137 @@ def _wrapped_update( return _wrapped_update -class KeyValuesBCWrapper: - """Efficiently simulates layer-indexed key or value lists from a layered cache. - This allows for BC access and writing, e.g., cache.key_cache[idx] = ...""" - - def __init__(self, layers, cache_type="key"): - self.layers = layers - self.cache_type = cache_type - - def __getitem__(self, idx): - if isinstance(idx, slice): - return [getattr(layer, f"{self.cache_type}_cache") for layer in self.layers[idx]] - return getattr(self.layers[idx], f"{self.cache_type}_cache") - - def __setitem__(self, idx, value): - if isinstance(idx, slice): - for layer, val in zip(self.layers[idx], value): - setattr(layer, f"{self.cache_type}_cache", val) - else: - setattr(self.layers[idx], f"{self.cache_type}_cache", value) - - def __len__(self): - return len(self.layers) - - def __iter__(self): - for layer in self.layers: - yield getattr(layer, f"{self.cache_type}_cache") - - def __bool__(self): - return bool(self.layers) - - -class Cache: - """ - Base class for all caches. - The actual data structure is specific to the layers. - This class handles propagation of operations across layers. +class CacheLayerMixin: + """Base, abstract class for a single layer's cache.""" - Parameters: - config (`PretrainedConfig`, *optional*): - Model configuration for shape/device info. - cache_processor (`CacheProcessor` or `str`, *optional*): - Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") or - a CacheProcessor class. - layer_classes (`list[type[CacheLayer]]`, *optional*): - List of layer classes to use for the cache. - Additional arguments for cache configuration: - max_batch_size/batch_size (`int`): Maximum batch size for static caches - max_cache_len (`int`): Maximum sequence length. 
For hybrid caches: - * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` - * StaticLayers: uses full `max_cache_len` - device (`torch.device`): Device for cache tensors - dtype (`torch.dtype`): Data type for cache tensors - layer_device_map (`dict[int, Union[str, torch.device]]`): Per-layer device mapping - tp_size (`int`): Tensor parallel size to adjust the number of key/value heads - device (`torch.device`): Device for cache tensors - dtype (`torch.dtype`): Data type for cache tensors - layer_device_map (`dict[int, Union[str, torch.device]]`): Per-layer device mapping - tp_size (`int`): Tensor parallel size to adjust the number of key/value heads - """ + is_compileable = False - def __init__( + def update( self, - config: Optional[PretrainedConfig] = None, - cache_processor: Optional[Union[str, type["CacheProcessor"]]] = None, - layer_classes: Optional[list[type["CacheLayerMixin"]]] = None, - *args, - **kwargs, - ): - self.layers: list["CacheLayerMixin"] = [] - processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Updates KV cache, returns updated keys/values for this layer.""" + raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") - if ( - layer_classes is None # setting layer_classes takes precedence - and config is not None - and getattr(config, "layer_types", None) is not None - ): - layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] - self.layer_classes = layer_classes or [DynamicLayer] + def get_seq_length(self) -> int: + """Returns the sequence length of this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") - processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) - self.layer_init_args = parse_layer_args_from_model_config(config, *args, **kwargs) - self.model_num_layers = getattr(config, "num_hidden_layers", 1) + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length (i.e. max capacity) of this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") - self.append_new_layers(self.model_num_layers - 1) - self.cache_processor = processor_class(self, **processor_kwargs) if processor_class is not None else None + def reset(self) -> tuple[torch.Tensor, torch.Tensor]: + """Resets this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.") - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self.layers): - return self.layers[layer_idx].keys, self.layers[layer_idx].values - else: - raise KeyError( - f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" - ) + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Returns mask sizes for this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. 
`for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) + def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]: + """Reorders this layer's cache for beam search.""" + if self.keys.numel(): + device = self.keys.device + self.keys = self.keys.index_select(0, beam_idx.to(device)) + if self.values.numel(): + device = self.values.device + self.values = self.values.index_select(0, beam_idx.to(device)) - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__ - if getattr(self, "layers", None) is None: - if getattr(self, "key_cache", None) is not None: - return len(self.key_cache) - return 0 - # Empty dynamic caches initialize an empty layer to be ready for first update - dynamic_empty = ( - getattr(self, "layers", None) is not None - and len(self.layers) == 1 - and isinstance(self.layers[0], DynamicLayer) - and self.layers[0].keys is None - ) - return len(self.layers) if not dynamic_empty else 0 - def __repr__(self): - return f"{self.__class__.__name__}(layers={self.layers})" +class DynamicLayer(CacheLayerMixin): + """ + A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. + It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. + """ - def append_new_layers(self, layer_idx: int) -> None: - """ - Appends layers to the cache until the layer `layer_idx` is reached. - Used for preallocation in static caches and on the fly in dynamic caches. + keys, values = None, None - Args: - layer_idx (`int`): - The index of the layer to append. - """ - while len(self.layers) <= layer_idx: - args = self.layer_init_args.copy() - if self.layer_init_args.get("layer_device_map", None) is not None: - args["device"] = args.pop("layer_device_map")[layer_idx] - new_layer = self.layer_classes[layer_idx % len(self.layer_classes)](**args) - self.layers.append(new_layer) + @classmethod + def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": + layer = cls() + layer.keys = keys + layer.values = values + return layer - @apply_processors def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Updates the cache with the new `key_states` and `value_states`. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. cache_kwargs (`dict[str, Any]`, *optional*): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. + Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`. Return: A tuple containing the updated key and value states. 
""" - self.append_new_layers(layer_idx) - return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + if self.keys is None: + self.keys = key_states + self.values = value_states + else: + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) + return self.keys, self.values - def get_seq_length(self, layer_idx: int = 0) -> int: - """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" - if layer_idx >= len(self.layers): + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + if self.keys is None or self.keys.numel() == 0: return 0 - # Hack since QuantizedCache messes with keys shape as it becomes the residual cache - if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): - return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length() - return self.layers[layer_idx].get_seq_length() + return self.keys.shape[-2] - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" + return -1 + + def reset(self) -> None: + """Resets the cache values while preserving the objects""" + self.keys.zero_() + self.values.zero_() + return self.keys, self.values + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorders the cache for beam search, given the selected beam indices.""" + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) + self.values = self.values.index_select(0, beam_idx.to(self.values.device)) + + def crop(self, max_length: int) -> None: """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. + Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. """ - kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes(cache_position) - return kv_length, kv_offset + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) - @property - def key_cache(self) -> KeyValuesBCWrapper: - """Returns a list-like object of key cache tensors indexed by layer.""" - warnings.warn( - "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." - ) - return KeyValuesBCWrapper(self.layers, "key") + if self.get_seq_length() <= max_length: + return - @property - def value_cache(self) -> KeyValuesBCWrapper: - """Returns a list-like object of value cache tensors indexed by layer.""" - warnings.warn( - "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." 
- ) - return KeyValuesBCWrapper(self.layers, "value") + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[..., :max_length, :] + self.values = self.values[..., :max_length, :] - ### Wrappers for layer operations and properties ### + def batch_repeat_interleave(self, repeats: int) -> None: + """Repeat the cache `repeats` times in the batch dimension.""" + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.repeat_interleave(repeats, dim=0) + self.values = self.values.repeat_interleave(repeats, dim=0) - def get_max_cache_shape(self, layer_idx: int = 0) -> int: - """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" - return self.layers[layer_idx].get_max_cache_shape() - - def reset(self): - """Recursively reset all layers tensors""" - for layer_idx in range(len(self.layers)): - self.layers[layer_idx].reset() - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorder the cache for beam search""" - for layer_idx in range(len(self.layers)): - self.layers[layer_idx].reorder_cache(beam_idx) - - def crop(self, max_length: int): - """Crop the cache to the given length""" - for layer_idx in range(len(self.layers)): - self.layers[layer_idx].crop(max_length) - - def batch_repeat_interleave(self, repeats: int): - """Repeat and interleave the cache""" - for layer_idx in range(len(self.layers)): - self.layers[layer_idx].batch_repeat_interleave(repeats) - - def batch_select_indices(self, indices: torch.Tensor): - """Select indices from the cache""" - for layer_idx in range(len(self.layers)): - self.layers[layer_idx].batch_select_indices(indices) - - @property - def max_batch_size(self) -> int: - """Return the maximum batch size of the cache""" - values = [layer.max_batch_size for layer in self.layers] - if len(set(values)) > 1: - raise ValueError(f"Max batch size is not consistent across layers: {values}") - return values[0] - - @property - def max_cache_len(self) -> int: - """Return the maximum cache length of the cache""" - values = [layer.max_cache_len for layer in self.layers] - if len(set(values)) > 1: - raise ValueError(f"Max cache length is not consistent across layers: {values}") - return values[0] - - @property - def is_compileable(self) -> bool: - """Return whether the cache is compileable""" - return all(layer.is_compileable for layer in self.layers) - - @property - def is_sliding(self) -> list[bool]: - """Return whether the layers of the cache are sliding window""" - return [getattr(layer, "is_sliding", False) for layer in self.layers] - - -class CacheLayerMixin: - """Base, abstract class for a single layer's cache.""" - - is_compileable = False - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Updates KV cache, returns updated keys/values for this layer.""" - raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") - - def get_seq_length(self) -> int: - """Returns the sequence length of this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") - - def get_max_cache_shape(self) -> int: - """Returns the maximum sequence length (i.e. 
max capacity) of this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") - - def reset(self) -> tuple[torch.Tensor, torch.Tensor]: - """Resets this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.") - - def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: - """Returns mask sizes for this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") - - def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]: - """Reorders this layer's cache for beam search.""" - if self.keys.numel(): - device = self.keys.device - self.keys = self.keys.index_select(0, beam_idx.to(device)) - if self.values.numel(): - device = self.values.device - self.values = self.values.index_select(0, beam_idx.to(device)) - - -class DynamicLayer(CacheLayerMixin): - """ - A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. - It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. - """ - - keys, values = None, None - - @classmethod - def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": - layer = cls() - layer.keys = keys - layer.values = values - return layer - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - cache_kwargs (`dict[str, Any]`, *optional*): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`. - - Return: - A tuple containing the updated key and value states. - """ - if self.keys is None: - self.keys = key_states - self.values = value_states - else: - self.keys = torch.cat([self.keys, key_states], dim=-2) - self.values = torch.cat([self.values, value_states], dim=-2) - return self.keys, self.values - - def get_seq_length(self) -> int: - """Returns the sequence length of the cached states.""" - if self.keys is None or self.keys.numel() == 0: - return 0 - return self.keys.shape[-2] - - def get_max_cache_shape(self) -> int: - """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" - return -1 - - def reset(self) -> None: - """Resets the cache values while preserving the objects""" - self.keys.zero_() - self.values.zero_() - return self.keys, self.values - - def reorder_cache(self, beam_idx: torch.LongTensor) -> None: - """Reorders the cache for beam search, given the selected beam indices.""" - if self.keys is not None and self.keys.numel(): - self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) - self.values = self.values.index_select(0, beam_idx.to(self.values.device)) - - def crop(self, max_length: int) -> None: - """ - Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens. 
- """ - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - if self.keys is not None and self.keys.numel(): - self.keys = self.keys[..., :max_length, :] - self.values = self.values[..., :max_length, :] - - def batch_repeat_interleave(self, repeats: int) -> None: - """Repeat the cache `repeats` times in the batch dimension.""" - if self.keys is not None and self.keys.numel(): - self.keys = self.keys.repeat_interleave(repeats, dim=0) - self.values = self.values.repeat_interleave(repeats, dim=0) - - def batch_select_indices(self, indices: torch.Tensor) -> None: - """Only keep the `indices` in the batch dimension of the cache.""" - if self.keys is not None and self.keys.numel(): - self.keys = self.keys[indices, ...] - self.values = self.values[indices, ...] + def batch_select_indices(self, indices: torch.Tensor) -> None: + """Only keep the `indices` in the batch dimension of the cache.""" + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[indices, ...] + self.values = self.values[indices, ...] def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: """Return the length and offset of the cache, used to generate the mask""" @@ -720,275 +455,44 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: return local_mask_kv_length, local_mask_kv_offset -class DynamicCache(Cache): +class CacheProcessor: + """ + Base class for cache processors that can be applied to modify cache behavior. + This class should be subclassed. """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + def __init__(self, cache: "Cache", **kwargs) -> None: + """ + Initialize the processor and perform compatibility checks with the cache. - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + Args: + cache (`Cache`): The cache instance this processor will be applied to. + **kwargs: Additional arguments that may be needed for initialization. + """ + raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = DynamicCache() - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - DynamicCache() - ``` - """ + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Function called before the cache update. Can modify the key/value states. - # Specialized constructor for DDP cache data, needed for BC - def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): - # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 - # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. 
each item in the - # iterable contains the key and value states for a layer gathered across replicas by torch.distributed - # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break - # compatibility. The name of the argument doesn't matter. - if ddp_cache_data is not None: - for key_states, value_states in ddp_cache_data: - self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) - super().__init__(*args, **kwargs) + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + layer_idx (`int`): The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: - """ - Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility. + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. """ - legacy_cache = () - for layer in self.layers: - legacy_cache += ((layer.keys, layer.values),) - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None - ) -> "Cache": - """ - Converts a cache in the legacy cache format into an equivalent `Cache`. Used for - backward compatibility. - """ - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - -# Utilities for `DynamicCache` <> torch.export support - -if is_torch_greater_or_equal("2.3"): - - def _get_cache_dict(cache: DynamicCache): - if any(not isinstance(layer, DynamicLayer) for layer in cache.layers): - raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") - - if not is_torch_greater_or_equal_than_2_6: - logger.warning_once( - "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." - ) - - return { - "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], - "value_cache": [layer.values for layer in cache.layers if layer.values is not None], - } - - def _unflatten_dynamic_cache( - values, - context: torch.utils._pytree.Context, - ): - dictionary = torch.utils._pytree._dict_unflatten(values, context) - cache = DynamicCache() - # Reconstruct layers from keys and values lists - key_list = dictionary.get("key_cache", []) - value_list = dictionary.get("value_cache", []) - for idx in range(max(len(key_list), len(value_list))): - key = key_list[idx] if idx < len(key_list) else None - value = value_list[idx] if idx < len(value_list) else None - cache.update(key, value, idx) - return cache - - torch.utils._pytree.register_pytree_node( - DynamicCache, - lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)), - _unflatten_dynamic_cache, - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys( - _get_cache_dict(dynamic_cache) - ), - ) - # TODO (tmanlaibaatar) This won't be needed in torch 2.7. 
- torch.fx._pytree.register_pytree_flatten_spec( - DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec) - ) - - -class OffloadedCache(DynamicCache): - """ - A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. - Useful for generating from models with very long context. - - In addition to the default accelerator stream, where all forward() computations happen, - this class uses another stream, the prefetch stream, which it creates itself. - Since scheduling of operations on separate streams happens independently, this class uses - the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. - The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to - ensure the eviction is scheduled after all computations on that cache are finished. - """ - - def __init__(self, config: Optional[PretrainedConfig] = None) -> None: - # Create the underlying cache with offload processor - super().__init__(cache_processor=OffloadedCacheProcessor, config=config) - - -class StaticCache(Cache): - """ - Static Cache class to be used with `torch.compile(model)` and `torch.export()`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache - - >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - - >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - StaticCache() - ``` - """ - - def __init__(self, *args, **kwargs): - super().__init__(layer_classes=[StaticLayer], *args, **kwargs) - - -class HybridCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window - attention and global attention in every other layer (originally implemented for Gemma2). - Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention.For more information, see the documentation of each subcomponent cache class. 
- - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` - """ - - def __init__(self, config: PretrainedConfig, *args, **kwargs): - # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types - layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None - super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) - - -class SlidingWindowCache(Cache): - """ - Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. - Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window - 1`, - if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), - we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - - The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: - - indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window - tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) - - We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache - - >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - - >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - SlidingWindowCache() - ``` - """ - - def __init__(self, *args, **kwargs): - super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) - - -class CacheProcessor: - """ - Base class for cache processors that can be applied to modify cache behavior. 
- This class should be subclassed. - """ - - def __init__(self, cache: "Cache", **kwargs) -> None: - """ - Initialize the processor and perform compatibility checks with the cache. - - Args: - cache (`Cache`): The cache instance this processor will be applied to. - **kwargs: Additional arguments that may be needed for initialization. - """ - raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") - - def pre_update( - self, - cache: "Cache", - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Function called before the cache update. Can modify the key/value states. - - Args: - cache (`Cache`): The cache instance. - key_states (`torch.Tensor`): The new key states to cache. - value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. - """ - return key_states, value_states + return key_states, value_states def post_update( self, @@ -1428,117 +932,674 @@ def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tens return tensor -def parse_processor_args(processor_class: Optional[type["CacheProcessor"]], kwargs: dict) -> tuple[dict, dict]: +class Cache: """ - Parse processor arguments from kwargs based on the processor class init signature. - - Args: - processor_class: The processor class to inspect, or None - kwargs: Dictionary of keyword arguments + Base class for all caches. + The actual data structure is specific to the layers. + This class handles propagation of operations across layers. - Returns: - tuple: (processor_kwargs, remaining_kwargs) + Parameters: + config (`PretrainedConfig`, *optional*): + Model configuration for shape/device info. + cache_processor (`CacheProcessor` or `str`, *optional*): + Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") or + a CacheProcessor class. + layer_classes (`list[type[CacheLayer]]`, *optional*): + List of layer classes to use for the cache. + Additional arguments for cache configuration: + max_batch_size/batch_size (`int`): Maximum batch size for static caches + max_cache_len (`int`): Maximum sequence length. 
For hybrid caches: + * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` + * StaticLayers: uses full `max_cache_len` + device (`torch.device`): Device for cache tensors + dtype (`torch.dtype`): Data type for cache tensors + layer_device_map (`dict[int, Union[str, torch.device]]`): Per-layer device mapping + tp_size (`int`): Tensor parallel size to adjust the number of key/value heads + device (`torch.device`): Device for cache tensors + dtype (`torch.dtype`): Data type for cache tensors + layer_device_map (`dict[int, Union[str, torch.device]]`): Per-layer device mapping + tp_size (`int`): Tensor parallel size to adjust the number of key/value heads """ - try: - params = list(inspect.signature(processor_class.__init__).parameters)[2:] - except Exception: - return {}, kwargs - processor_kwargs = {k: kwargs[k] for k in params if k in kwargs} - remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs} - return processor_kwargs, remaining_kwargs + def __init__( + self, + config: Optional[PretrainedConfig] = None, + cache_processor: Optional[Union[str, type["CacheProcessor"]]] = None, + layer_classes: Optional[list[type["CacheLayerMixin"]]] = None, + *args, + **kwargs, + ): + self.layers: list["CacheLayerMixin"] = [] + processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor + + if ( + layer_classes is None # setting layer_classes takes precedence + and config is not None + and getattr(config, "layer_types", None) is not None + ): + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + self.layer_classes = layer_classes or [DynamicLayer] + + processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) + self.layer_init_args = parse_layer_args_from_model_config(config, *args, **kwargs) + self.model_num_layers = getattr(config, "num_hidden_layers", 1) + + self.append_new_layers(self.model_num_layers - 1) + self.cache_processor = processor_class(self, **processor_kwargs) if processor_class is not None else None + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self.layers): + return self.layers[layer_idx].keys, self.layers[layer_idx].values + else: + raise KeyError( + f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. 
+ """ + # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__ + if getattr(self, "layers", None) is None: + if getattr(self, "key_cache", None) is not None: + return len(self.key_cache) + return 0 + # Empty dynamic caches initialize an empty layer to be ready for first update + dynamic_empty = ( + getattr(self, "layers", None) is not None + and len(self.layers) == 1 + and isinstance(self.layers[0], DynamicLayer) + and self.layers[0].keys is None + ) + return len(self.layers) if not dynamic_empty else 0 + + def __repr__(self): + return f"{self.__class__.__name__}(layers={self.layers})" + + def append_new_layers(self, layer_idx: int) -> None: + """ + Appends layers to the cache until the layer `layer_idx` is reached. + Used for preallocation in static caches and on the fly in dynamic caches. + + Args: + layer_idx (`int`): + The index of the layer to append. + """ + while len(self.layers) <= layer_idx: + args = self.layer_init_args.copy() + if self.layer_init_args.get("layer_device_map", None) is not None: + args["device"] = args.pop("layer_device_map")[layer_idx] + new_layer = self.layer_classes[layer_idx % len(self.layer_classes)](**args) + self.layers.append(new_layer) + + @apply_processors + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + self.append_new_layers(layer_idx) + return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" + if layer_idx >= len(self.layers): + return 0 + # Hack since QuantizedCache messes with keys shape as it becomes the residual cache + if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): + return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length() + return self.layers[layer_idx].get_seq_length() + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes(cache_position) + return kv_length, kv_offset + + @property + def key_cache(self) -> "KeyValuesWrapper": + """Returns a list-like object of key cache tensors indexed by layer.""" + warnings.warn( + "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." 
+ ) + return KeyValuesWrapper(self.layers, "key") + + @property + def value_cache(self) -> "KeyValuesWrapper": + """Returns a list-like object of value cache tensors indexed by layer.""" + warnings.warn( + "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." + ) + return KeyValuesWrapper(self.layers, "value") + + ### Wrappers for layer operations and properties ### + + def get_max_cache_shape(self, layer_idx: int = 0) -> int: + """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" + return self.layers[layer_idx].get_max_cache_shape() + + def reset(self): + """Recursively reset all layers tensors""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].reset() + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorder the cache for beam search""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].reorder_cache(beam_idx) + + def crop(self, max_length: int): + """Crop the cache to the given length""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].crop(max_length) + + def batch_repeat_interleave(self, repeats: int): + """Repeat and interleave the cache""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Select indices from the cache""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].batch_select_indices(indices) + + @property + def max_batch_size(self) -> int: + """Return the maximum batch size of the cache""" + values = [layer.max_batch_size for layer in self.layers] + if len(set(values)) > 1: + raise ValueError(f"Max batch size is not consistent across layers: {values}") + return values[0] + + @property + def max_cache_len(self) -> int: + """Return the maximum cache length of the cache""" + values = [layer.max_cache_len for layer in self.layers] + if len(set(values)) > 1: + raise ValueError(f"Max cache length is not consistent across layers: {values}") + return values[0] + + @property + def is_compileable(self) -> bool: + """Return whether the cache is compileable""" + return all(layer.is_compileable for layer in self.layers) + + @property + def is_sliding(self) -> list[bool]: + """Return whether the layers of the cache are sliding window""" + return [getattr(layer, "is_sliding", False) for layer in self.layers] + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicCache() + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicCache() + ``` + """ + + # Specialized constructor for DDP cache data, needed for BC + def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 + # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the + # iterable contains the key and value states for a layer gathered across replicas by torch.distributed + # (shape=[global batch size, num_heads, seq_len, head_dim]). + # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break + # compatibility. The name of the argument doesn't matter. + if ddp_cache_data is not None: + for key_states, value_states in ddp_cache_data: + self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) + super().__init__(*args, **kwargs) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: + """ + Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility. + """ + legacy_cache = () + for layer in self.layers: + legacy_cache += ((layer.keys, layer.values),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None + ) -> "Cache": + """ + Converts a cache in the legacy cache format into an equivalent `Cache`. Used for + backward compatibility. + """ + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + +# Utilities for `DynamicCache` <> torch.export support + +if is_torch_greater_or_equal("2.3"): + + def _get_cache_dict(cache: DynamicCache): + if any(not isinstance(layer, DynamicLayer) for layer in cache.layers): + raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") + + if not is_torch_greater_or_equal_than_2_6: + logger.warning_once( + "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." 
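A short sketch of the legacy tuple-format round trip handled by `from_legacy_cache`/`to_legacy_cache` above; the tensor shapes are arbitrary:

```python
import torch
from transformers import DynamicCache

legacy = ((torch.zeros(1, 2, 3, 8), torch.zeros(1, 2, 3, 8)),)   # one layer of (key, value)
cache = DynamicCache.from_legacy_cache(legacy)
print(cache[0][0].shape)                                          # torch.Size([1, 2, 3, 8])
roundtrip = cache.to_legacy_cache()
print(torch.equal(roundtrip[0][0], legacy[0][0]))                 # True
```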
+ ) + + return { + "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in cache.layers if layer.values is not None], + } + + def _unflatten_dynamic_cache( + values, + context: torch.utils._pytree.Context, + ): + dictionary = torch.utils._pytree._dict_unflatten(values, context) + cache = DynamicCache() + # Reconstruct layers from keys and values lists + key_list = dictionary.get("key_cache", []) + value_list = dictionary.get("value_cache", []) + for idx in range(max(len(key_list), len(value_list))): + key = key_list[idx] if idx < len(key_list) else None + value = value_list[idx] if idx < len(value_list) else None + cache.update(key, value, idx) + return cache + + torch.utils._pytree.register_pytree_node( + DynamicCache, + lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)), + _unflatten_dynamic_cache, + serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", + flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys( + _get_cache_dict(dynamic_cache) + ), + ) + # TODO (tmanlaibaatar) This won't be needed in torch 2.7. + torch.fx._pytree.register_pytree_flatten_spec( + DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec) + ) + + +class OffloadedCache(DynamicCache): + """ + A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. + Useful for generating from models with very long context. + + In addition to the default accelerator stream, where all forward() computations happen, + this class uses another stream, the prefetch stream, which it creates itself. + Since scheduling of operations on separate streams happens independently, this class uses + the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. + The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to + ensure the eviction is scheduled after all computations on that cache are finished. + """ + + def __init__(self, config: Optional[PretrainedConfig] = None) -> None: + # Create the underlying cache with offload processor + super().__init__(cache_processor=OffloadedCacheProcessor, config=config) + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. 
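A sketch of what the pytree registration above enables, assuming torch >= 2.3 (the `_pytree` helpers are private but are the same ones used by the registration):

```python
import torch
import torch.utils._pytree as pytree
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.ones(1, 2, 4, 8), torch.zeros(1, 2, 4, 8), layer_idx=0)

leaves, spec = pytree.tree_flatten(cache)       # the cache now flattens into its tensors
print(len(leaves))                              # 2: one key tensor and one value tensor
rebuilt = pytree.tree_unflatten(leaves, spec)
print(isinstance(rebuilt, DynamicCache))        # True
```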
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=[StaticLayer], *args, **kwargs) + + +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window + attention and global attention in every other layer (originally implemented for Gemma2). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention.For more information, see the documentation of each subcomponent cache class. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + def __init__(self, config: PretrainedConfig, *args, **kwargs): + # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types + layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + + +class SlidingWindowCache(Cache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. + Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window - 1`, + if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), + we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + + The `to_shift` is only true once we are above sliding_window. 
Thus with `sliding_window==64`: + + indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + SlidingWindowCache() + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) + + +class QuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + """ + + def __init__(self, backend, **kwargs) -> None: + if backend == "quanto": + processor = QuantoQuantizedCacheProcessor + elif backend == "hqq": + processor = HQQQuantizedCacheProcessor + else: + raise ValueError(f"Unknown quantization backend `{backend}`") + + super().__init__(cache_processor=processor, **kwargs) + + +class QuantoQuantizedCache(QuantizedCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. 
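A small self-contained sketch of the cyclic shift described above, using a window of 4 instead of 64 (values and positions are made up):

```python
import torch

window = 4
cache = torch.tensor([1.0, 2.0, 3.0, 4.0])                  # window already full
cache_position = torch.tensor([4])                          # next token falls outside the window
slicing = torch.ones(window, dtype=torch.long).cumsum(0)    # tensor([1, 2, 3, 4])
to_shift = cache_position >= window - 1
indices = (slicing + to_shift[-1].int() - 1) % window       # tensor([1, 2, 3, 0])
rolled = cache[indices]                                     # tensor([2., 3., 4., 1.]) -> oldest slot recycled
rolled[-1] = 5.0                                            # new token written at the clamped position
print(rolled)                                               # tensor([2., 3., 4., 5.])
```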
When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + + Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + + Example: + + ```python + >>> # Run pip install quanto first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4) + >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + QuantoQuantizedCache() + ``` + """ + + def __init__(self, **kwargs) -> None: + Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs) + + +class HQQQuantizedCache(QuantizedCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + + Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. 
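A rough sketch of the grouped quantize/dequantize round trip that the quantized processors perform; this is illustrative only (the real backends are `quanto` and `HQQ`, and the `nbits`/`q_group_size` values below are arbitrary):

```python
import torch

def quantize_groups(x: torch.Tensor, nbits: int = 4, q_group_size: int = 8):
    # Per-group affine quantization: each group gets its own scale and zero point.
    levels = 2**nbits - 1
    groups = x.reshape(-1, q_group_size)
    lo = groups.min(dim=1, keepdim=True).values
    hi = groups.max(dim=1, keepdim=True).values
    scale = (hi - lo).clamp(min=1e-8) / levels
    q = ((groups - lo) / scale).round().clamp(0, levels).to(torch.uint8)
    return q, scale, lo

def dequantize_groups(q, scale, lo, shape):
    return (q.float() * scale + lo).reshape(shape)

keys = torch.randn(1, 2, 16, 8)                    # [batch, num_heads, seq_len, head_dim]
q, scale, lo = quantize_groups(keys)
approx = dequantize_groups(q, scale, lo, keys.shape)
print((keys - approx).abs().max())                 # small reconstruction error
```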
+ + Example: + + ```python + >>> # Run pip install hqq first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) + >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HQQQuantizedCache() + ``` + """ + + def __init__(self, backend="HQQ", **kwargs) -> None: + assert backend == "HQQ" + Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) + + +class OffloadedStaticCache(StaticCache): + """ + A drop-in replacement for StaticCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. + + This cache maintains the compilation-friendly properties of StaticCache while enabling + much longer sequences by offloading inactive layers to CPU memory. + + Example: + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class with offloading + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = OffloadedStaticCache( + ... config=model.config, + ... max_batch_size=1, + ... max_cache_len=max_generated_length, + ... device=model.device, + ... dtype=model.dtype + ... ) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache with offloaded layers + OffloadedStaticCache() + ``` + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) + + +class HybridChunkedCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window + attention and global attention in every other layer, with support for chunked attention (originally implemented + for Llama4). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention. For more information, see the documentation of each subcomponent cache class. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`): + The default `dtype` to use when initializing the layer. 
+ layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + Example: + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache -def parse_layer_args_from_model_config( - config: Optional[PretrainedConfig], - batch_size: Optional[int] = None, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: Optional[torch.dtype] = None, - layer_device_map: Optional[dict[int, torch.device]] = None, - tp_size: Optional[int] = None, - max_batch_size: Optional[int] = None, -) -> dict: + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` """ - Parse layer arguments from model configuration for cache initialization. - Args: - config (`Optional[PretrainedConfig]`): Model configuration containing shape/device info. - batch_size (`Optional[int]`): Batch size for cache initialization. - max_cache_len (`Optional[int]`): Maximum sequence length for cache. - device (`Union[torch.device, str, None]`): Device for cache tensors. - dtype (`Optional[torch.dtype]`): Data type for cache tensors. - layer_device_map: Per-layer device mapping. - tp_size (`Optional[int]`): Tensor parallel size to adjust number of key/value heads. - max_batch_size (`Optional[int]`): Maximum batch size for cache initialization. + def __init__(self, config: PretrainedConfig, *args, **kwargs): + # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types + layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) - Returns: - `dict`: Dictionary containing parsed layer arguments for cache initialization. + +class OffloadedHybridCache(HybridChunkedCache): """ - # No model config -> must be a dynamic cache, return bare dict - if config is None: - return {} - # Build the args dict for hybrid, sliding or static - else: - # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) - if ( - getattr(config, "layer_types", None) is not None - and "sliding_attention" in config.layer_types - and "full_attention" in config.layer_types - ): - if getattr(config, "sliding_window", None) is None: - raise ValueError( - "Setting up a hybrid or sliding window KVCache requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) - # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) - max_cache_len = max_cache_len or config.max_position_embeddings - if getattr(config, "sliding_window", None) is not None: - sliding_window_len = min(config.sliding_window, max_cache_len) - else: - sliding_window_len = None - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: - head_dim = ( - config.head_dim - if getattr(config, "head_dim", None) is not None - else config.hidden_size // config.num_attention_heads - ) - num_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - if tp_size is not None and tp_size > 1: - if num_heads % tp_size != 0: - raise ValueError( - f"Number of key value heads {num_heads} must be divisible by tensor parallel size {tp_size}." - ) - # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - num_heads //= tp_size - layer_args = { - "batch_size": max_batch_size if max_batch_size is not None else batch_size, - "max_cache_len": max_cache_len, - "device": torch.device(device) if device is not None else None, - "dtype": dtype, - "layer_device_map": layer_device_map, - "head_dim": head_dim, - "num_heads": num_heads, - "sliding_window": sliding_window_len, - } - return {k: v for k, v in layer_args.items() if v is not None} + A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. + This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling + much longer sequences by offloading inactive layers to CPU memory. + """ -LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { - "full_attention": StaticLayer, - "sliding_attention": SlidingWindowLayer, - "chunked_attention": ChunkedAttentionLayer, -} -PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = { - "offloaded": OffloadedCacheProcessor, - "quanto_quantized": QuantizedCacheProcessor, - "hqq_quantized": HQQQuantizedCacheProcessor, -} + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) class EncoderDecoderCache(Cache): @@ -1676,257 +1737,200 @@ def check_dynamic_cache(self, method: str): isinstance(self.self_attention_cache, DynamicCache) and isinstance(self.cross_attention_cache, DynamicCache) ): - raise ValueError( - f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " - f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." - ) - - # TODO(gante, sanchit-gandhi): move following functionality into `.generate` - def crop(self, maximum_length: int): - """ - Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be - negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search. - """ - self.check_dynamic_cache(self.crop.__name__) - self.self_attention_cache.crop(maximum_length) - - def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": - """ - Split the current instance into a list of `DynamicCache` by the batch size. 
This will be used by - `_split_model_inputs()` in `generation.utils` - """ - self.check_dynamic_cache(self.batch_split.__name__) - self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) - cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) - - out = [] - for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): - out.append(EncoderDecoderCache(self_attn, cross_attn)) - return out - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_repeat_interleave.__name__) - self.self_attention_cache.batch_repeat_interleave(repeats) - self.cross_attention_cache.batch_repeat_interleave(repeats) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_select_indices.__name__) - self.self_attention_cache.batch_select_indices(indices) - self.cross_attention_cache.batch_select_indices(indices) - - def get_max_cache_shape(self) -> int: - """Returns the maximum sequence length (i.e. max capacity) of the cache object""" - return self.self_attention_cache.get_max_cache_shape() - - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) - - -class QuantizedCache(DynamicCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - """ - - def __init__(self, backend, **kwargs) -> None: - if backend == "quanto": - processor = QuantoQuantizedCacheProcessor - elif backend == "hqq": - processor = HQQQuantizedCacheProcessor - else: - raise ValueError(f"Unknown quantization backend `{backend}`") - - super().__init__(cache_processor=processor, **kwargs) - - -class QuantoQuantizedCache(QuantizedCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. 
When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - - Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. - - Example: - - ```python - >>> # Run pip install quanto first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4) - >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - QuantoQuantizedCache() - ``` - """ - - def __init__(self, **kwargs) -> None: - Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs) - - -class HQQQuantizedCache(QuantizedCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + raise ValueError( + f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " + f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." + ) - Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + # TODO(gante, sanchit-gandhi): move following functionality into `.generate` + def crop(self, maximum_length: int): + """ + Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search. 
+ """ + self.check_dynamic_cache(self.crop.__name__) + self.self_attention_cache.crop(maximum_length) - Example: + def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": + """ + Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils` + """ + self.check_dynamic_cache(self.batch_split.__name__) + self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) + cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) - ```python - >>> # Run pip install hqq first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + out = [] + for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): + out.append(EncoderDecoderCache(self_attn, cross_attn)) + return out - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_repeat_interleave.__name__) + self.self_attention_cache.batch_repeat_interleave(repeats) + self.cross_attention_cache.batch_repeat_interleave(repeats) - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_select_indices.__name__) + self.self_attention_cache.batch_select_indices(indices) + self.cross_attention_cache.batch_select_indices(indices) - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) - >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HQQQuantizedCache() - ``` - """ + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length (i.e. max capacity) of the cache object""" + return self.self_attention_cache.get_max_cache_shape() - def __init__(self, backend="HQQ", **kwargs) -> None: - assert backend == "HQQ" - Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) -class OffloadedStaticCache(StaticCache): +def parse_processor_args(processor_class: Optional[type["CacheProcessor"]], kwargs: dict) -> tuple[dict, dict]: """ - A drop-in replacement for StaticCache that conserves accelerator memory by offloading - cache tensors to CPU when not actively being used. + Parse processor arguments from kwargs based on the processor class init signature. - This cache maintains the compilation-friendly properties of StaticCache while enabling - much longer sequences by offloading inactive layers to CPU memory. 
+ Args: + processor_class: The processor class to inspect, or None + kwargs: Dictionary of keyword arguments - Example: - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache + Returns: + tuple: (processor_kwargs, remaining_kwargs) + """ + try: + params = list(inspect.signature(processor_class.__init__).parameters)[2:] + except Exception: + return {}, kwargs - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + processor_kwargs = {k: kwargs[k] for k in params if k in kwargs} + remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs} + return processor_kwargs, remaining_kwargs - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - >>> # Prepare a cache class with offloading - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = OffloadedStaticCache( - ... config=model.config, - ... max_batch_size=1, - ... max_cache_len=max_generated_length, - ... device=model.device, - ... dtype=model.dtype - ... ) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache with offloaded layers - OffloadedStaticCache() - ``` +def parse_layer_args_from_model_config( + config: Optional[PretrainedConfig], + batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: Optional[torch.dtype] = None, + layer_device_map: Optional[dict[int, torch.device]] = None, + tp_size: Optional[int] = None, + max_batch_size: Optional[int] = None, +) -> dict: """ + Parse layer arguments from model configuration for cache initialization. - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) - + Args: + config (`Optional[PretrainedConfig]`): Model configuration containing shape/device info. + batch_size (`Optional[int]`): Batch size for cache initialization. + max_cache_len (`Optional[int]`): Maximum sequence length for cache. + device (`Union[torch.device, str, None]`): Device for cache tensors. + dtype (`Optional[torch.dtype]`): Data type for cache tensors. + layer_device_map: Per-layer device mapping. + tp_size (`Optional[int]`): Tensor parallel size to adjust number of key/value heads. + max_batch_size (`Optional[int]`): Maximum batch size for cache initialization. -class HybridChunkedCache(Cache): + Returns: + `dict`: Dictionary containing parsed layer arguments for cache initialization. """ - Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window - attention and global attention in every other layer, with support for chunked attention (originally implemented - for Llama4). - Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention. For more information, see the documentation of each subcomponent cache class. 
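A standalone sketch of the signature-based kwargs split performed by `parse_processor_args` above; the `DummyProcessor` and its arguments are hypothetical:

```python
import inspect

class DummyProcessor:
    # Hypothetical processor: __init__ takes the cache plus its own options.
    def __init__(self, cache, offload_device="cpu", prefetch=True):
        self.offload_device = offload_device
        self.prefetch = prefetch

kwargs = {"offload_device": "cpu", "max_cache_len": 128, "prefetch": False}
params = list(inspect.signature(DummyProcessor.__init__).parameters)[2:]   # skip `self` and `cache`
processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
print(processor_kwargs)    # {'offload_device': 'cpu', 'prefetch': False}
print(remaining_kwargs)    # {'max_cache_len': 128}
```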
+ # No model config -> must be a dynamic cache, return bare dict + if config is None: + return {} + # Build the args dict for hybrid, sliding or static + else: + # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) + if ( + getattr(config, "layer_types", None) is not None + and "sliding_attention" in config.layer_types + and "full_attention" in config.layer_types + ): + if getattr(config, "sliding_window", None) is None: + raise ValueError( + "Setting up a hybrid or sliding window KVCache requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) + max_cache_len = max_cache_len or config.max_position_embeddings + if getattr(config, "sliding_window", None) is not None: + sliding_window_len = min(config.sliding_window, max_cache_len) + else: + sliding_window_len = None + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: + head_dim = ( + config.head_dim + if getattr(config, "head_dim", None) is not None + else config.hidden_size // config.num_attention_heads + ) + num_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + if tp_size is not None and tp_size > 1: + if num_heads % tp_size != 0: + raise ValueError( + f"Number of key value heads {num_heads} must be divisible by tensor parallel size {tp_size}." + ) + # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. + num_heads //= tp_size + layer_args = { + "batch_size": max_batch_size if max_batch_size is not None else batch_size, + "max_cache_len": max_cache_len, + "device": torch.device(device) if device is not None else None, + "dtype": dtype, + "layer_device_map": layer_device_map, + "head_dim": head_dim, + "num_heads": num_heads, + "sliding_window": sliding_window_len, + } + return {k: v for k, v in layer_args.items() if v is not None} - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. 
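A short sketch of the tensor-parallel head adjustment performed above; the head count and `tp_size` are made-up numbers:

```python
# Each tensor-parallel rank only caches its shard of the key/value heads.
num_key_value_heads = 8
tp_size = 4
if num_key_value_heads % tp_size != 0:
    raise ValueError("Number of key value heads must be divisible by tensor parallel size.")
num_key_value_heads //= tp_size
print(num_key_value_heads)   # 2 heads cached per tensor-parallel rank
```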
- Example: +LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { + "full_attention": StaticLayer, + "sliding_attention": SlidingWindowLayer, + "chunked_attention": ChunkedAttentionLayer, +} +PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = { + "offloaded": OffloadedCacheProcessor, + "quanto_quantized": QuantizedCacheProcessor, + "hqq_quantized": HQQQuantizedCacheProcessor, +} - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") +### Deprecated classes - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` - """ +class KeyValuesWrapper: + """Helper class for Cache that simulates layer-indexed key/value lists from a layered cache. + This allows for BC access and writing, e.g., cache.key_cache[idx] = ... + Deprecated in favor of Cache.layers[idx].keys/values. TODO: remove in v4.56.0""" - def __init__(self, config: PretrainedConfig, *args, **kwargs): - # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types - layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None - super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + def __init__(self, layers, cache_type="key"): + self.layers = layers + self.cache_type = cache_type + def __getitem__(self, idx): + if isinstance(idx, slice): + return [getattr(layer, f"{self.cache_type}_cache") for layer in self.layers[idx]] + return getattr(self.layers[idx], f"{self.cache_type}_cache") -class OffloadedHybridCache(HybridChunkedCache): - """ - A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading - cache tensors to CPU when not actively being used. + def __setitem__(self, idx, value): + if isinstance(idx, slice): + for layer, val in zip(self.layers[idx], value): + setattr(layer, f"{self.cache_type}_cache", val) + else: + setattr(self.layers[idx], f"{self.cache_type}_cache", value) - This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling - much longer sequences by offloading inactive layers to CPU memory. 
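A toy illustration of the BC shim above, assuming the `KeyValuesWrapper` just defined is in scope; the `Layer` class and string values are stand-ins for the real cache layers and tensors:

```python
class Layer:
    def __init__(self, keys, values):
        self.key_cache, self.value_cache = keys, values

layers = [Layer("k0", "v0"), Layer("k1", "v1")]
key_cache = KeyValuesWrapper(layers, "key")     # behaves like the old list of per-layer key tensors
print(key_cache[1])                             # 'k1'
key_cache[0] = "new_k0"                         # writes through to layers[0].key_cache
print(layers[0].key_cache)                      # 'new_k0'
print(len(key_cache), list(key_cache))          # 2 ['new_k0', 'k1']
```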
- """ + def __len__(self): + return len(self.layers) - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) + def __iter__(self): + for layer in self.layers: + yield getattr(layer, f"{self.cache_type}_cache") + + def __bool__(self): + return bool(self.layers) class SinkCache(Cache): From 4bb48fcfdab3633c1cd57a001f6a9032437673fc Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 14 Jul 2025 19:09:48 +0200 Subject: [PATCH 20/35] bfff come on LFM2 --- src/transformers/models/lfm2/modeling_lfm2.py | 3 +++ src/transformers/models/lfm2/modular_lfm2.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 577f73538bae..684eb859a24d 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -128,8 +128,11 @@ class Lfm2HybridConvCache(DynamicCache): Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. """ + # Override @property existing in Cache max_batch_size = None is_compileable = False + key_cache = None + value_cache = None def __init__( self, diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index f1b65172435f..d8f7de5fd9ee 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -89,8 +89,11 @@ class Lfm2HybridConvCache(DynamicCache): Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. """ + # Override @property existing in Cache max_batch_size = None is_compileable = False + key_cache = None + value_cache = None def __init__( self, From 00b1f96ad09c1137cfe75e453c4b2bc1dd2abe1e Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 15 Jul 2025 13:24:20 +0200 Subject: [PATCH 21/35] better tests for hybrid and hybridChunked --- src/transformers/cache_utils.py | 38 +++--- tests/utils/test_cache_utils.py | 228 +++++++++++++++++++++++++++++--- 2 files changed, 233 insertions(+), 33 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 154eba290607..7ebd7d10f979 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -339,7 +339,7 @@ def update( new_v = value_states[:, :, -self.max_cache_len :, :] self.keys.copy_(new_k) self.values.copy_(new_v) - return self.keys, self.values + return key_states, value_states # Sliding window logic for generation phase or prefill < window slicing = torch.arange(self.max_cache_len, device=value_states.device) @@ -377,16 +377,12 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: return local_mask_kv_length, local_mask_kv_offset -class ChunkedAttentionLayer(StaticLayer): - """ - A static cache layer that implements chunked attention caching. - Inherits from StaticLayer but uses chunked attention update logic. 
- """ +class ChunkedSlidingLayer(SlidingWindowLayer): + """An extended SlidingWindowLayer that supports prefill chunking.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.cumulative_length = 0 - self.is_sliding = True def update( self, @@ -396,7 +392,7 @@ def update( ) -> tuple[torch.Tensor, torch.Tensor]: cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None if cache_position is None: - raise ValueError("`cache_position` must be provided for ChunkedAttentionLayer.") + raise ValueError("`cache_position` must be provided for ChunkedSlidingLayer.") key_states = key_states.to(self.keys.dtype) value_states = value_states.to(self.values.dtype) @@ -979,6 +975,11 @@ def __init__( ): layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] self.layer_classes = layer_classes or [DynamicLayer] + hybrid_chunked = kwargs.pop("hybrid_chunked", "llama4" in getattr(config, "model_type", "")) + if hybrid_chunked: + self.layer_classes = [ + ChunkedSlidingLayer if cls == SlidingWindowLayer else cls for cls in self.layer_classes + ] processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) self.layer_init_args = parse_layer_args_from_model_config(config, *args, **kwargs) @@ -1042,7 +1043,7 @@ def append_new_layers(self, layer_idx: int) -> None: args = self.layer_init_args.copy() if self.layer_init_args.get("layer_device_map", None) is not None: args["device"] = args.pop("layer_device_map")[layer_idx] - new_layer = self.layer_classes[layer_idx % len(self.layer_classes)](**args) + new_layer = self.layer_classes[len(self.layers) % len(self.layer_classes)](**args) self.layers.append(new_layer) @apply_processors @@ -1098,7 +1099,7 @@ def key_cache(self) -> "KeyValuesWrapper": warnings.warn( "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." ) - return KeyValuesWrapper(self.layers, "key") + return KeyValuesWrapper(self.layers, "keys") @property def value_cache(self) -> "KeyValuesWrapper": @@ -1106,7 +1107,7 @@ def value_cache(self) -> "KeyValuesWrapper": warnings.warn( "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." ) - return KeyValuesWrapper(self.layers, "value") + return KeyValuesWrapper(self.layers, "values") ### Wrappers for layer operations and properties ### @@ -1586,6 +1587,7 @@ class HybridChunkedCache(Cache): def __init__(self, config: PretrainedConfig, *args, **kwargs): # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None + kwargs["hybrid_chunked"] = True super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) @@ -1889,7 +1891,7 @@ def parse_layer_args_from_model_config( LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { "full_attention": StaticLayer, "sliding_attention": SlidingWindowLayer, - "chunked_attention": ChunkedAttentionLayer, + "chunked_attention": SlidingWindowLayer, } PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = { "offloaded": OffloadedCacheProcessor, @@ -1906,28 +1908,28 @@ class KeyValuesWrapper: This allows for BC access and writing, e.g., cache.key_cache[idx] = ... Deprecated in favor of Cache.layers[idx].keys/values. 
TODO: remove in v4.56.0""" - def __init__(self, layers, cache_type="key"): + def __init__(self, layers, cache_type="keys"): self.layers = layers self.cache_type = cache_type def __getitem__(self, idx): if isinstance(idx, slice): - return [getattr(layer, f"{self.cache_type}_cache") for layer in self.layers[idx]] - return getattr(self.layers[idx], f"{self.cache_type}_cache") + return [getattr(layer, self.cache_type) for layer in self.layers[idx]] + return getattr(self.layers[idx], self.cache_type) def __setitem__(self, idx, value): if isinstance(idx, slice): for layer, val in zip(self.layers[idx], value): - setattr(layer, f"{self.cache_type}_cache", val) + setattr(layer, self.cache_type, val) else: - setattr(self.layers[idx], f"{self.cache_type}_cache", value) + setattr(self.layers[idx], self.cache_type, value) def __len__(self): return len(self.layers) def __iter__(self): for layer in self.layers: - yield getattr(layer, f"{self.cache_type}_cache") + yield getattr(layer, self.cache_type) def __bool__(self): return bool(self.layers) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 5c5ee08da6ee..59e6a18c49f4 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -45,20 +45,24 @@ from transformers import ( AutoModelForCausalLM, AutoTokenizer, - Cache, ClvpForCausalLM, - DynamicCache, Gemma2Config, GenerationConfig, - HybridCache, LlamaConfig, + convert_and_export_with_cache, + pipeline, + ) + from transformers.cache_utils import ( + Cache, + DynamicCache, + HQQQuantizedCacheProcessor, + HybridCache, + HybridChunkedCache, QuantizedCache, + QuantoQuantizedCacheProcessor, SlidingWindowCache, StaticCache, - convert_and_export_with_cache, - pipeline, ) - from transformers.cache_utils import HQQQuantizedCacheProcessor, QuantoQuantizedCacheProcessor from transformers.integrations.executorch import export_with_dynamic_cache @@ -912,7 +916,7 @@ def setUp(self): head_dim=1, hidden_size=1, sliding_window=self.window_size, - sliding_window_pattern=2, # Default pattern for hybrid sliding + layer_types=["full_attention"] * 1, # Static cache by default ) def test_static_cache_out_of_bounds(self): @@ -979,7 +983,9 @@ def test_sliding_window_cache(self): result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ # Scenario 1: Update within window, no slide yet - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + config = copy.deepcopy(self.config) + config.layer_types = ["sliding_attention"] * config.num_hidden_layers + sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -1000,7 +1006,7 @@ def test_sliding_window_cache(self): ) # Scenario 2: Update causing slide - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -1021,7 +1027,7 @@ def test_sliding_window_cache(self): ) # Scenario 3: Long prompt handling - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 
4.0, 5.0, 6.0])[None, None, :, None] sliding_cache.update( key_states=long_prefill, @@ -1036,7 +1042,7 @@ def test_sliding_window_cache(self): ) def test_hybrid_cache_static_mode(self): - """Test HybridCache in static mode with hardcoded assertions. + """Test HybridCache with only 1 static layer. Scenario 1: Static layer behavior prefill: [1.0, 2.0, 0.0, 0.0] @@ -1046,7 +1052,7 @@ def test_hybrid_cache_static_mode(self): update pos 3: [1.0, 2.0, 3.0, 4.0] """ config = copy.deepcopy(self.config) - config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) + config.layer_types = ["full_attention"] * config.num_hidden_layers # Scenario 1 hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) @@ -1100,8 +1106,10 @@ def test_hybrid_cache_sliding_mode(self): input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ + config = copy.deepcopy(self.config) + config.layer_types = ["sliding_attention"] * config.num_hidden_layers # Scenario 1: Update within window, no slide yet - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1122,7 +1130,7 @@ def test_hybrid_cache_sliding_mode(self): ) # Scenario 2: Update causing first slide - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1156,7 +1164,7 @@ def test_hybrid_cache_sliding_mode(self): ) # Scenario 4: Long prompt handling - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] hybrid_cache.update( key_states=long_prefill, @@ -1212,3 +1220,193 @@ def test_dynamic_cache(self): [10.0, 20.0, 30.0, 40.0], "DynamicCache Scenario 2 layer 1 failed", ) + + def test_hybrid_cache(self): + """ + Test HybridCache with a mix of static and sliding layers, + with prefill size bigger than sliding window. 
+ + prefill: + static: [1.0, 2.0, 3.0] + sliding: [10.0, 20.0, 30.0] + (stores only [20.0, 30.0]) + + update pos 4: + static: [1.0, 2.0, 3.0, 5.0] + sliding: [30.0, 50.0] + """ + config = copy.deepcopy(self.config) + config.num_hidden_layers = 2 + config.layer_types = ["full_attention", "sliding_attention"] + config.sliding_window = 2 + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + + # Prefill both layers up to cache capacity + prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None] + # Sliding window is 2, so it should return full [10.0, 20.0, 30.0], but store only [20.0, 30.0] + prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None] + + # Update static layer (layer 0) + res_static = hybrid_cache.update( + key_states=prefill_static, + value_states=prefill_static, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(3)}, + ) + + # Update sliding layer (layer 1) + res_sliding = hybrid_cache.update( + key_states=prefill_sliding, + value_states=prefill_sliding, + layer_idx=1, + cache_kwargs={"cache_position": torch.arange(3), "sliding_window": self.window_size}, + ) + + # Verify initial states + self.assertEqual( + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "Initial static layer state is wrong", + ) + self.assertEqual( + res_static[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "Static layer did not return the correct value.", + ) + self.assertEqual( + hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(), + [20.0, 30.0], + "Initial sliding layer state is wrong", + ) + self.assertEqual( + res_sliding[0][0, 0, :, 0].tolist(), + [10.0, 20.0, 30.0], + "Sliding layer did not return the correct value.", + ) + + # Update at position 4 + new_key_static = torch.tensor(5.0)[None, None, None, None] + new_key_sliding = torch.tensor(50.0)[None, None, None, None] + + # Update static layer (layer 0) + hybrid_cache.update( + key_states=new_key_static, + value_states=new_key_static, + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + + # Update sliding layer (layer 1) + hybrid_cache.update( + key_states=new_key_sliding, + value_states=new_key_sliding, + layer_idx=1, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + + # The static layer does not slide, so it should have updated the element at position 3 + self.assertEqual( + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 5.0], + "Static layer did not update as expected.", + ) + + # The sliding layer should have shifted, discarding the first element and adding the new one at the end + self.assertEqual( + hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(), + [30.0, 50.0], + "Sliding layer did not slide as expected.", + ) + + def test_hybrid_chunked_cache(self): + """ + Test HybridChunkedCache special cases that it handles: + 1. a pre-fill longer than the sliding window + 2. a single-token decoding step (normal generation) + 3. a multi-token decoding step after the window is already full + + Sliding-window size: 2 + Static layer is full-attention. 
+ ───────────────────────────────────────────── + Prefill: + static : [1, 2, 3] + sliding : [10, 20, 30] (cache keeps [20, 30]) + +1 token: + static : [1, 2, 3, 5] + sliding : [30, 50] (returned [30, 50]) + +2 tokens: + sliding : [60, 70] (returned [50, 60, 70]) + """ + + config = copy.deepcopy(self.config) + config.num_hidden_layers = 2 + config.layer_types = ["full_attention", "sliding_attention"] + config.sliding_window = 2 + max_cache_len = 4 # window == max_cache_len for sliding layer + chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len) + + # ------------------------------------------------------------------ # + # 1) PREFILL (3 tokens > window) + # ------------------------------------------------------------------ # + prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None] + prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None] + + res_static = chunked_cache.update( + key_states=prefill_static, + value_states=prefill_static, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(3)}, + ) + res_sliding = chunked_cache.update( + key_states=prefill_sliding, + value_states=prefill_sliding, + layer_idx=1, + cache_kwargs={"cache_position": torch.arange(3)}, + ) + + # Static layer keeps everything + self.assertEqual(res_static[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0]) + # Sliding layer returned full prompt but stored the tail + self.assertEqual(res_sliding[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0]) + self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [20.0, 30.0]) + + # ------------------------------------------------------------------ # + # 2) ONE-TOKEN UPDATE (normal decode) + # ------------------------------------------------------------------ # + new_static = torch.tensor(5.0)[None, None, None, None] + new_sliding = torch.tensor(50.0)[None, None, None, None] + + chunked_cache.update( + key_states=new_static, + value_states=new_static, + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + res_one = chunked_cache.update( + key_states=new_sliding, + value_states=new_sliding, + layer_idx=1, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + + # Static grew by one + self.assertEqual(chunked_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 5.0]) + # Sliding window slid by exactly 1 + self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [30.0, 50.0]) + self.assertEqual(res_one[0][0, 0, :, 0].tolist(), [30.0, 50.0]) + + # ------------------------------------------------------------------ # + # 3) TWO-TOKEN UPDATE after window is full + # ------------------------------------------------------------------ # + new_sliding_2 = torch.tensor([60.0, 70.0])[None, None, :, None] # shape (1,1,2,1) + res_two = chunked_cache.update( + key_states=new_sliding_2, + value_states=new_sliding_2, + layer_idx=1, + cache_kwargs={"cache_position": torch.tensor([4, 5])}, # arbitrary positions; ignored in full mode + ) + + # Cache now keeps the latest two tokens + self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [60.0, 70.0]) + # Returned tensor contains previous last token + new ones + self.assertEqual(res_two[0][0, 0, :, 0].tolist(), [50.0, 60.0, 70.0]) From 38e86034517b5e1aa2d67ca12db78963acac44de Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 15 Jul 2025 14:38:36 +0200 Subject: [PATCH 22/35] complete coverage for hybrid chunked caches (prefill chunking) --- tests/utils/test_cache_utils.py | 62 
++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 59e6a18c49f4..9e3645f9d013 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -1320,7 +1320,7 @@ def test_hybrid_cache(self): def test_hybrid_chunked_cache(self): """ - Test HybridChunkedCache special cases that it handles: + Test HybridChunkedCache with both static and sliding layers and special cases: 1. a pre-fill longer than the sliding window 2. a single-token decoding step (normal generation) 3. a multi-token decoding step after the window is already full @@ -1342,12 +1342,10 @@ def test_hybrid_chunked_cache(self): config.num_hidden_layers = 2 config.layer_types = ["full_attention", "sliding_attention"] config.sliding_window = 2 - max_cache_len = 4 # window == max_cache_len for sliding layer + max_cache_len = 4 chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len) - # ------------------------------------------------------------------ # - # 1) PREFILL (3 tokens > window) - # ------------------------------------------------------------------ # + # 1) PREFILL (3 tokens > sliding_window) prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None] prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None] @@ -1370,9 +1368,7 @@ def test_hybrid_chunked_cache(self): self.assertEqual(res_sliding[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0]) self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [20.0, 30.0]) - # ------------------------------------------------------------------ # # 2) ONE-TOKEN UPDATE (normal decode) - # ------------------------------------------------------------------ # new_static = torch.tensor(5.0)[None, None, None, None] new_sliding = torch.tensor(50.0)[None, None, None, None] @@ -1389,16 +1385,12 @@ def test_hybrid_chunked_cache(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) - # Static grew by one self.assertEqual(chunked_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 5.0]) - # Sliding window slid by exactly 1 self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [30.0, 50.0]) self.assertEqual(res_one[0][0, 0, :, 0].tolist(), [30.0, 50.0]) - # ------------------------------------------------------------------ # # 3) TWO-TOKEN UPDATE after window is full - # ------------------------------------------------------------------ # - new_sliding_2 = torch.tensor([60.0, 70.0])[None, None, :, None] # shape (1,1,2,1) + new_sliding_2 = torch.tensor([60.0, 70.0])[None, None, :, None] res_two = chunked_cache.update( key_states=new_sliding_2, value_states=new_sliding_2, @@ -1410,3 +1402,49 @@ def test_hybrid_chunked_cache(self): self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [60.0, 70.0]) # Returned tensor contains previous last token + new ones self.assertEqual(res_two[0][0, 0, :, 0].tolist(), [50.0, 60.0, 70.0]) + + def test_hybrid_chunked_cache_extra_cases(self): + """ + Covers the new cases that appear on prefill chunking: + 1) Not full multi-token update (cache_position[0] + update_len <= max_cache_len) + 2) Multi-token update crossing the window (cache_position[0] < max_cache_len and cache_position[0] + update_len > max_cache_len) + + Single sliding layer, max_cache_len = 3. 
+ + Step 0 (prefill 2 tokens, update_len < max_cache_len + cache = [10, 20, 0] returned [10, 20, 0] + + Step 1 (add 2 tokens, p = 2, update_len = 2, p + update_len = 4 > max_cache_len) + cache = [20, 30, 40] returned [10, 20, 30, 40] + """ + + config = copy.deepcopy(self.config) + config.num_hidden_layers = 1 + config.layer_types = ["sliding_attention"] + config.sliding_window = 3 + cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3) + + # Step 0 : multi-token prefill + first_chunk = torch.tensor([10.0, 20.0])[None, None, :, None] # L = 2 + returned_0 = cache.update( + key_states=first_chunk, + value_states=first_chunk, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(2)}, # p = 0,1 + ) + + # internal cache should have first two tokens and a zero pad + self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [10.0, 20.0, 0.0]) + self.assertEqual(returned_0[0][0, 0, :, 0].tolist(), [10.0, 20.0, 0.0]) + + # Step 1 : multi-token update crossing the window boundary + second_chunk = torch.tensor([30.0, 40.0])[None, None, :, None] # L = 2 + returned_1 = cache.update( + key_states=second_chunk, + value_states=second_chunk, + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2, 3])}, # p = 2 + ) + + self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [20.0, 30.0, 40.0]) + self.assertEqual(returned_1[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0]) From 34a3022fb0f348f68214c5a53bbfcccdfb5f3b20 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 15 Jul 2025 14:39:21 +0200 Subject: [PATCH 23/35] reimplementing HybridChunked --- src/transformers/cache_utils.py | 72 +++++++++++++++------------------ 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 7ebd7d10f979..5b897f445b73 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -380,57 +380,51 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: class ChunkedSlidingLayer(SlidingWindowLayer): """An extended SlidingWindowLayer that supports prefill chunking.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.cumulative_length = 0 - def update( self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - if cache_position is None: + cache_pos = cache_kwargs.get("cache_position") if cache_kwargs else None + if cache_pos is None: raise ValueError("`cache_position` must be provided for ChunkedSlidingLayer.") key_states = key_states.to(self.keys.dtype) value_states = value_states.to(self.values.dtype) - - cumulative_length = self.cumulative_length - self.cumulative_length += key_states.shape[-2] - is_full = cumulative_length >= self.max_cache_len - - if is_full: - full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) - full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) - if key_states.shape[-2] == 1: - self.keys.copy_(full_key_states) - self.values.copy_(full_value_states) + states_len = key_states.size(-2) + + # Case 1: states are longer than the window + if states_len > self.max_cache_len: + self.keys.copy_(key_states[:, :, -self.max_cache_len :, :]) + self.values.copy_(value_states[:, :, -self.max_cache_len :, :]) + return key_states, value_states # full prompt returned + + # Case 2: already full before 
the call + if cache_pos[0] >= self.max_cache_len: + full_k = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) + full_v = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) + if states_len == 1: # fast decode path, return tensors that have been marked as static address + self.keys.copy_(full_k) + self.values.copy_(full_v) return self.keys, self.values - elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len: - if cumulative_length == 0: - full_key_states = key_states - full_value_states = value_states else: - full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2) - full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2) - else: - try: - self.keys.index_copy_(2, cache_position, key_states) - self.values.index_copy_(2, cache_position, value_states) - except NotImplementedError: - self.keys[:, :, cache_position] = key_states - self.values[:, :, cache_position] = value_states - return self.keys, self.values - - self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :]) - self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :]) - return full_key_states, full_value_states - - def reset(self) -> None: - super().reset() - self.cumulative_length = 0 + self.keys.copy_(full_k[:, :, -self.max_cache_len :, :]) + self.values.copy_(full_v[:, :, -self.max_cache_len :, :]) + return full_k, full_v + + # Case 3: will overflow during this call + if cache_pos[0] + states_len > self.max_cache_len: + full_k = torch.cat((self.keys[:, :, : cache_pos[0], :], key_states), dim=-2) + full_v = torch.cat((self.values[:, :, : cache_pos[0], :], value_states), dim=-2) + self.keys.copy_(full_k[:, :, -self.max_cache_len :, :]) + self.values.copy_(full_v[:, :, -self.max_cache_len :, :]) + return full_k, full_v + + # Case 4: still filling + self.keys.index_copy_(2, cache_pos, key_states) + self.values.index_copy_(2, cache_pos, value_states) + return self.keys, self.values def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: query_length = cache_position.shape[0] From 4222653604baa0fa267f906038de0bebbbe957e0 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 16 Jul 2025 18:25:43 +0200 Subject: [PATCH 24/35] cyril review --- src/transformers/cache_utils.py | 233 +++++++++++++++----------------- 1 file changed, 111 insertions(+), 122 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 5b897f445b73..856944236d7e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -25,42 +25,14 @@ logger = logging.get_logger(__name__) -def apply_processors( - fn: Callable[..., tuple[torch.Tensor, torch.Tensor]], -) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: - @functools.wraps(fn) - def _wrapped_update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Wrapper around the update method to apply cache processors. 
- """ - if self.cache_processor is not None: - key_states, value_states = self.cache_processor.pre_update( - self, key_states, value_states, layer_idx, cache_kwargs - ) - - key_tensors, value_tensors = fn(self, key_states, value_states, layer_idx, cache_kwargs) - - if self.cache_processor is not None: - key_tensors, value_tensors = self.cache_processor.post_update( - self, key_tensors, value_tensors, layer_idx, cache_kwargs - ) - - return key_tensors, value_tensors - - return _wrapped_update - - class CacheLayerMixin: """Base, abstract class for a single layer's cache.""" is_compileable = False + def __init__(self): + self.keys, self.values = None, None + def update( self, key_states: torch.Tensor, @@ -70,7 +42,7 @@ def update( """Updates KV cache, returns updated keys/values for this layer.""" raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") - def get_seq_length(self) -> int: + def get_seq_length(self, cache_position=None) -> int: """Returns the sequence length of this layer's cache.""" raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") @@ -78,14 +50,15 @@ def get_max_cache_shape(self) -> int: """Returns the maximum sequence length (i.e. max capacity) of this layer's cache.""" raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") - def reset(self) -> tuple[torch.Tensor, torch.Tensor]: - """Resets this layer's cache.""" - raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.") - def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: """Returns mask sizes for this layer's cache.""" raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") + def reset(self) -> None: + """Resets the cache values while preserving the objects""" + self.keys.zero_() + self.values.zero_() + def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]: """Reorders this layer's cache for beam search.""" if self.keys.numel(): @@ -102,8 +75,6 @@ class DynamicLayer(CacheLayerMixin): It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. """ - keys, values = None, None - @classmethod def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": layer = cls() @@ -139,7 +110,7 @@ def update( self.values = torch.cat([self.values, value_states], dim=-2) return self.keys, self.values - def get_seq_length(self) -> int: + def get_seq_length(self, cache_position=None) -> int: """Returns the sequence length of the cached states.""" if self.keys is None or self.keys.numel() == 0: return 0 @@ -149,12 +120,6 @@ def get_max_cache_shape(self) -> int: """Returns the maximum sequence length of the cache object. 
DynamicLayer does not have a maximum length.""" return -1 - def reset(self) -> None: - """Resets the cache values while preserving the objects""" - self.keys.zero_() - self.values.zero_() - return self.keys, self.values - def reorder_cache(self, beam_idx: torch.LongTensor) -> None: """Reorders the cache for beam search, given the selected beam indices.""" if self.keys is not None and self.keys.numel(): @@ -190,11 +155,11 @@ def batch_select_indices(self, indices: torch.Tensor) -> None: def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: """Return the length and offset of the cache, used to generate the mask""" - full_mask_kv_offset = 0 + kv_offset = 0 query_length = cache_position.shape[0] past_seen_tokens = self.get_seq_length() kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset + return kv_length, kv_offset class StaticLayer(CacheLayerMixin): @@ -283,10 +248,6 @@ def get_seq_length(self, cache_position=None) -> int: seq_length = (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0 return seq_length - def reset(self) -> None: - self.keys.zero_() - self.values.zero_() - def reorder_cache(self, beam_idx: torch.LongTensor) -> None: dev = self.keys.device beam_idx_dev = beam_idx.to(dev) @@ -294,9 +255,9 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None: self.values = self.values.index_select(0, beam_idx_dev) def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: - full_mask_kv_offset = 0 - full_mask_kv_length = self.max_cache_len - return full_mask_kv_length, full_mask_kv_offset + kv_offset = 0 + kv_length = self.max_cache_len + return kv_length, kv_offset class SlidingWindowLayer(StaticLayer): @@ -371,78 +332,80 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: query_length = cache_position.shape[0] first_cache_position = cache_position[0] - local_mask_kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0) + kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0) # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - local_mask_kv_length = max(query_length, self.max_cache_len) - return local_mask_kv_length, local_mask_kv_offset + kv_length = max(query_length, self.max_cache_len) + return kv_length, kv_offset class ChunkedSlidingLayer(SlidingWindowLayer): """An extended SlidingWindowLayer that supports prefill chunking.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cumulative_length = 0 + def update( self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - cache_pos = cache_kwargs.get("cache_position") if cache_kwargs else None - if cache_pos is None: + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + if cache_position is None: raise ValueError("`cache_position` must be provided for ChunkedSlidingLayer.") - key_states = key_states.to(self.keys.dtype) - value_states = value_states.to(self.values.dtype) - states_len = key_states.size(-2) - - # Case 1: states are longer than the window - if states_len > self.max_cache_len: - self.keys.copy_(key_states[:, :, -self.max_cache_len :, :]) - self.values.copy_(value_states[:, :, -self.max_cache_len :, :]) - return key_states, value_states # full prompt returned - - # Case 2: already full before the call - if cache_pos[0] >= self.max_cache_len: - full_k = 
torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) - full_v = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) - if states_len == 1: # fast decode path, return tensors that have been marked as static address - self.keys.copy_(full_k) - self.values.copy_(full_v) - return self.keys, self.values + cumulative_length = self.cumulative_length + self.cumulative_length += key_states.shape[-2] + is_full = cumulative_length >= self.max_cache_len + + if is_full: + full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) + full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) + if key_states.shape[-2] == 1: + self.keys.copy_(full_key_states) + self.values.copy_(full_value_states) + elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len: + if cumulative_length == 0: + full_key_states = key_states + full_value_states = value_states else: - self.keys.copy_(full_k[:, :, -self.max_cache_len :, :]) - self.values.copy_(full_v[:, :, -self.max_cache_len :, :]) - return full_k, full_v - - # Case 3: will overflow during this call - if cache_pos[0] + states_len > self.max_cache_len: - full_k = torch.cat((self.keys[:, :, : cache_pos[0], :], key_states), dim=-2) - full_v = torch.cat((self.values[:, :, : cache_pos[0], :], value_states), dim=-2) - self.keys.copy_(full_k[:, :, -self.max_cache_len :, :]) - self.values.copy_(full_v[:, :, -self.max_cache_len :, :]) - return full_k, full_v - - # Case 4: still filling - self.keys.index_copy_(2, cache_pos, key_states) - self.values.index_copy_(2, cache_pos, value_states) - return self.keys, self.values + full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2) + full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2) + else: + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + return self.keys, self.values + + self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :]) + self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :]) + return full_key_states, full_value_states + + def reset(self) -> None: + super().reset() + self.cumulative_length = 0 def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: query_length = cache_position.shape[0] first_cache_position = cache_position[0] sliding_window = self.max_cache_len - local_mask_kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0) + kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0) # This is the true general case for any Cache using local attention (sliding or chunked) if first_cache_position >= sliding_window: # Here the Cache is already full - local_mask_kv_length = sliding_window + query_length - 1 + kv_length = sliding_window + query_length - 1 elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window: # Here the Cache becomes full with the new input - local_mask_kv_length = first_cache_position + query_length + kv_length = first_cache_position + query_length else: # Here the Cache is still smaller than the local size, but we return the local size as it's static - local_mask_kv_length = sliding_window - return local_mask_kv_length, local_mask_kv_offset + kv_length = sliding_window + return kv_length, kv_offset class CacheProcessor: @@ -922,6 +885,37 @@ 
def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tens return tensor +def apply_processors( + fn: Callable[..., tuple[torch.Tensor, torch.Tensor]], +) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: + @functools.wraps(fn) + def _wrapped_update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Wrapper around the update method to apply cache processors. + """ + if self.cache_processor is not None: + key_states, value_states = self.cache_processor.pre_update( + self, key_states, value_states, layer_idx, cache_kwargs + ) + + key_tensors, value_tensors = fn(self, key_states, value_states, layer_idx, cache_kwargs) + + if self.cache_processor is not None: + key_tensors, value_tensors = self.cache_processor.post_update( + self, key_tensors, value_tensors, layer_idx, cache_kwargs + ) + + return key_tensors, value_tensors + + return _wrapped_update + + class Cache: """ Base class for all caches. @@ -935,7 +929,7 @@ class Cache: Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") or a CacheProcessor class. layer_classes (`list[type[CacheLayer]]`, *optional*): - List of layer classes to use for the cache. + List of layer classes to use for the cache. Default is [DynamicLayer]. Additional arguments for cache configuration: max_batch_size/batch_size (`int`): Maximum batch size for static caches max_cache_len (`int`): Maximum sequence length. For hybrid caches: @@ -962,19 +956,10 @@ def __init__( self.layers: list["CacheLayerMixin"] = [] processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor - if ( - layer_classes is None # setting layer_classes takes precedence - and config is not None - and getattr(config, "layer_types", None) is not None - ): - layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] - self.layer_classes = layer_classes or [DynamicLayer] - hybrid_chunked = kwargs.pop("hybrid_chunked", "llama4" in getattr(config, "model_type", "")) - if hybrid_chunked: - self.layer_classes = [ - ChunkedSlidingLayer if cls == SlidingWindowLayer else cls for cls in self.layer_classes - ] + if layer_classes is None: + layer_classes = [DynamicLayer] + self.layer_classes = layer_classes processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) self.layer_init_args = parse_layer_args_from_model_config(config, *args, **kwargs) self.model_num_layers = getattr(config, "num_hidden_layers", 1) @@ -1068,14 +1053,14 @@ def update( self.append_new_layers(layer_idx) return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) - def get_seq_length(self, layer_idx: int = 0) -> int: + def get_seq_length(self, layer_idx: int = 0, cache_position=None) -> int: """Returns the sequence length of the cache for the given layer. 
TODO: deprecate in favor of cache_position""" if layer_idx >= len(self.layers): return 0 # Hack since QuantizedCache messes with keys shape as it becomes the residual cache if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): - return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length() - return self.layers[layer_idx].get_seq_length() + return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length(cache_position) + return self.layers[layer_idx].get_seq_length(cache_position) def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ @@ -1347,8 +1332,14 @@ class HybridCache(Cache): """ def __init__(self, config: PretrainedConfig, *args, **kwargs): - # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types - layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None + layer_classes = ( + [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None + else [StaticLayer] + ) + hybrid_chunked = kwargs.pop("hybrid_chunked", "llama4" in getattr(config, "model_type", "")) + if hybrid_chunked: + layer_classes = [ChunkedSlidingLayer if cls == SlidingWindowLayer else cls for cls in layer_classes] super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) @@ -1532,7 +1523,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) -class HybridChunkedCache(Cache): +class HybridChunkedCache(HybridCache): """ Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window attention and global attention in every other layer, with support for chunked attention (originally implemented @@ -1579,10 +1570,8 @@ class HybridChunkedCache(Cache): """ def __init__(self, config: PretrainedConfig, *args, **kwargs): - # Ugly hack for BC: if layer_types is not set, fallback to StaticCache. Otherwise, Cache init will use config.layer_types - layer_classes = [StaticLayer] if not hasattr(config, "layer_types") else None kwargs["hybrid_chunked"] = True - super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + super().__init__(config=config, *args, **kwargs) class OffloadedHybridCache(HybridChunkedCache): @@ -1703,10 +1692,10 @@ def from_legacy_cache( cache.is_updated[layer_idx] = True return cache - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position=None) -> int: """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` - return self.self_attention_cache.get_seq_length(layer_idx) + return self.self_attention_cache.get_seq_length(layer_idx, cache_position) def reset(self): if hasattr(self.self_attention_cache, "reset"): From 1acc648fb99822dc1ff028b0194d5f4d360d0259 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 17 Jul 2025 10:20:24 +0200 Subject: [PATCH 25/35] fix ci --- docs/source/en/internal/generation_utils.md | 6 ++++++ src/transformers/cache_utils.py | 18 ------------------ .../models/bamba/modeling_bamba.py | 6 +++++- .../modeling_granitemoehybrid.py | 6 +++++- 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index b956828fdcff..ed319a1e54b0 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -371,8 +371,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] QuantoQuantizedCache +[[autodoc]] QuantoQuantizedCacheProcessor + [[autodoc]] HQQQuantizedCache +[[autodoc]] HQQQuantizedCacheProcessor + [[autodoc]] OffloadedCache [[autodoc]] StaticCache @@ -381,6 +385,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] HybridCache +[[autodoc]] HybridChunkedCache + [[autodoc]] SlidingWindowCache [[autodoc]] EncoderDecoderCache diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 856944236d7e..ef3fb54831cb 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1531,24 +1531,6 @@ class HybridChunkedCache(HybridCache): Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponent cache class. - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. 
- Example: ```python diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index fa5dda36cd45..dcbc40de6373 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -85,7 +85,7 @@ class BambaFlashAttentionKwargs(TypedDict, total=False): seq_idx: torch.IntTensor -class HybridMambaAttentionDynamicCache(DynamicCache): +class HybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -99,6 +99,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): super().__init__() self.layers_block_type = config.layers_block_type diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index df823c30bd34..e601bc1c46cb 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -221,7 +221,7 @@ def forward( return attn_output, attn_weights -class HybridMambaAttentionDynamicCache(DynamicCache): +class HybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -235,6 +235,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None): super().__init__() self.layers_block_type = config.layers_block_type From ca39ffecbcc2f793c5e41c94f96ea9de5cf0394f Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 18 Jul 2025 13:57:54 +0200 Subject: [PATCH 26/35] docs for cache refactor --- docs/source/en/internal/generation_utils.md | 76 ++++++++++++-- src/transformers/__init__.py | 10 ++ src/transformers/cache_utils.py | 107 ++++++++++++++++---- 3 files changed, 165 insertions(+), 28 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index ed319a1e54b0..1888d76e4a43 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -356,38 +356,102 @@ A [`Constraint`] can be used to force the generation to include specific tokens ## Caches -[[autodoc]] Cache +[[autodoc]] CacheLayerMixin + - update + - get_seq_length + - get_mask_sizes + - get_max_cache_shape + - reset + - reorder_cache + +[[autodoc]] DynamicLayer + - update + - get_seq_length + - get_mask_sizes + - get_max_cache_shape + - reset + - reorder_cache + - crop + - batch_repeat_interleave + - batch_select_indices + +[[autodoc]] StaticLayer - update + - get_seq_length + - get_mask_sizes + - get_max_cache_shape + - reset + - reorder_cache -[[autodoc]] CacheConfig - - update +[[autodoc]] SlidingWindowLayer + - update + - get_seq_length + - get_mask_sizes + - get_max_cache_shape + - reset + - reorder_cache + +[[autodoc]] CacheProcessor + - pre_update + - post_update + +[[autodoc]] OffloadedCacheProcessor + - pre_update -[[autodoc]] QuantizedCacheConfig - - validate +[[autodoc]] QuantizedCacheProcessor + - post_update + +[[autodoc]] QuantoQuantizedCacheProcessor + - post_update + +[[autodoc]] HQQQuantizedCacheProcessor + - post_update + +[[autodoc]] Cache + - update + - get_seq_length + - get_mask_sizes + - get_max_cache_shape + - reset + - reorder_cache + - crop + - batch_repeat_interleave + - batch_select_indices [[autodoc]] DynamicCache + - update + - to_legacy_cache + - from_legacy_cache [[autodoc]] QuantizedCache + - update [[autodoc]] QuantoQuantizedCache + - update [[autodoc]] QuantoQuantizedCacheProcessor [[autodoc]] HQQQuantizedCache - + - update [[autodoc]] HQQQuantizedCacheProcessor [[autodoc]] OffloadedCache + - update [[autodoc]] StaticCache + - update [[autodoc]] OffloadedStaticCache + - update [[autodoc]] HybridCache + - update [[autodoc]] HybridChunkedCache + - update [[autodoc]] SlidingWindowCache + - update [[autodoc]] EncoderDecoderCache - get_seq_length diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b6a8e3da7e9e..84892590b1af 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -365,6 +365,16 @@ ] _import_structure["activations"] = [] _import_structure["cache_utils"] = [ + "CacheLayerMixin", + "DynamicLayer", + "StaticLayer", + "SlidingWindowLayer", + "ChunkedSlidingLayer", + "CacheProcessor", + "OffloadedCacheProcessor", + "QuantizedCacheProcessor", + "QuantoQuantizedCacheProcessor", + "HQQQuantizedCacheProcessor", "Cache", "CacheConfig", "DynamicCache", diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ef3fb54831cb..5d455e13aa77 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -77,6 +77,19 @@ class DynamicLayer(CacheLayerMixin): @classmethod 
def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": + """ + Build a `DynamicLayer` instance from pre-existing key/value tensors. + + Args: + keys (`torch.Tensor`): + Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + values (`torch.Tensor`): + Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + + Returns: + `DynamicLayer`: The newly constructed layer whose internal cache directly references + the supplied tensors. + """ layer = cls() layer.keys = keys layer.values = values @@ -176,6 +189,29 @@ def __init__( device: str = "cpu", sliding_window: Optional[int] = None, ): + """ + Args: + max_cache_len (`int`): + Maximum number of tokens that can be stored, used for tensor preallocation. + batch_size (`int`): + Maximum batch size the cache is pre-allocated for. + num_heads (`int`): + Number of attention heads. + head_dim (`int`): + Per-head hidden dimension. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the cache tensors. + device (`str` or `torch.device`, *optional*, defaults to `"cpu"`): + Device on which the cache tensors will be materialised. + sliding_window (`int`, *optional*): + When not ``None``, indicates that this layer will be used as a sliding-window + cache capped to ``sliding_window`` tokens (see `SlidingWindowLayer`). + + Notes: + Static layers allocate their full backing tensors up-front and mutate them + in-place. See the documentation of `Cache` for shared helper methods that + operate uniformly across all layer types. + """ self.max_cache_len = max_cache_len self.max_batch_size = batch_size self.num_heads = num_heads @@ -241,6 +277,7 @@ def update( return self.keys, self.values def get_seq_length(self, cache_position=None) -> int: + """Returns the sequence length of the cached states.""" if cache_position is not None: return int(cache_position[-1] + 1) # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's @@ -249,12 +286,14 @@ def get_seq_length(self, cache_position=None) -> int: return seq_length def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorders the cache for beam search, given the selected beam indices.""" dev = self.keys.device beam_idx_dev = beam_idx.to(dev) self.keys = self.keys.index_select(0, beam_idx_dev) self.values = self.values.index_select(0, beam_idx_dev) def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the attention mask""" kv_offset = 0 kv_length = self.max_cache_len return kv_length, kv_offset @@ -266,7 +305,13 @@ class SlidingWindowLayer(StaticLayer): Inherits from StaticLayer but uses sliding window update logic. """ - def __init__(self, sliding_window, max_cache_len=None, *args, **kwargs): + def __init__(self, sliding_window, *args, **kwargs): + """ + Args: + sliding_window (`int`): + Effective window size: number of tokens that are kept on each update call. 
+ """ + kwargs.pop("max_cache_len", None) super().__init__(*args, max_cache_len=sliding_window, *args, **kwargs) def update( @@ -329,6 +374,7 @@ def update( return self.keys, self.values def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the attention mask""" query_length = cache_position.shape[0] first_cache_position = cache_position[0] @@ -339,7 +385,7 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: class ChunkedSlidingLayer(SlidingWindowLayer): - """An extended SlidingWindowLayer that supports prefill chunking.""" + """An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -410,7 +456,7 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: class CacheProcessor: """ - Base class for cache processors that can be applied to modify cache behavior. + Base class for cache processors. It defines a pre-update and post-update methods that are called before and after the cache update. This class should be subclassed. """ @@ -518,7 +564,7 @@ def pre_update( layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Handle prefetching and eviction before cache update.""" + """Handles prefetching and eviction before cache update.""" # Update the cache if len(cache) < layer_idx: raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") @@ -918,20 +964,38 @@ def _wrapped_update( class Cache: """ - Base class for all caches. - The actual data structure is specific to the layers. - This class handles propagation of operations across layers. + Base container for per-layer key/value caches. + + A `Cache` behaves like a list of `CacheLayerMixin` objects, one per model layer. + Sub-classes such as `DynamicCache`, `StaticCache`, or `SlidingWindowCache` + simply pre-select which `CacheLayerMixin` class to use and may attach a + `CacheProcessor` (off-loading, quantization). + + Example + ------- + ```python + from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache + + model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + inputs = tok("Hello", return_tensors="pt") + + cache = DynamicCache() + outputs = model(**inputs, past_key_values=cache, use_cache=True) + ``` Parameters: config (`PretrainedConfig`, *optional*): - Model configuration for shape/device info. + Model configuration used to infer number of layers, head sizes, default + device/dtype, etc. cache_processor (`CacheProcessor` or `str`, *optional*): - Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") or - a CacheProcessor class. - layer_classes (`list[type[CacheLayer]]`, *optional*): - List of layer classes to use for the cache. Default is [DynamicLayer]. - Additional arguments for cache configuration: - max_batch_size/batch_size (`int`): Maximum batch size for static caches + Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") + or a CacheProcessor class. + layer_classes (`list[type[CacheLayerMixin]]`, *optional*): + List of `CacheLayerMixin` classes to instantiate for the cache. When shorter than the + required number of layers the list is cycled. Default is [DynamicLayer]. 
+ Additional arguments passed to the layers: + max_batch_size (`int`): Maximum batch size for static caches max_cache_len (`int`): Maximum sequence length. For hybrid caches: * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` * StaticLayers: uses full `max_cache_len` @@ -939,10 +1003,9 @@ class Cache: dtype (`torch.dtype`): Data type for cache tensors layer_device_map (`dict[int, Union[str, torch.device]]`): Per-layer device mapping tp_size (`int`): Tensor parallel size to adjust the number of key/value heads - device (`torch.device`): Device for cache tensors - dtype (`torch.dtype`): Data type for cache tensors - layer_device_map (`dict[int, Union[str, torch.device]]`): Per-layer device mapping - tp_size (`int`): Tensor parallel size to adjust the number of key/value heads + + Additional keyword arguments are forwarded to the chosen layers constructor(s). See the + documentation of the relevant `CacheLayerMixin` class for more details. """ def __init__( @@ -1074,7 +1137,7 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ @property def key_cache(self) -> "KeyValuesWrapper": - """Returns a list-like object of key cache tensors indexed by layer.""" + """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`""" warnings.warn( "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." ) @@ -1082,7 +1145,7 @@ def key_cache(self) -> "KeyValuesWrapper": @property def value_cache(self) -> "KeyValuesWrapper": - """Returns a list-like object of value cache tensors indexed by layer.""" + """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`""" warnings.warn( "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." ) @@ -1091,7 +1154,7 @@ def value_cache(self) -> "KeyValuesWrapper": ### Wrappers for layer operations and properties ### def get_max_cache_shape(self, layer_idx: int = 0) -> int: - """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" + """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length.""" return self.layers[layer_idx].get_max_cache_shape() def reset(self): @@ -1526,7 +1589,7 @@ def __init__(self, *args, **kwargs) -> None: class HybridChunkedCache(HybridCache): """ Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window - attention and global attention in every other layer, with support for chunked attention (originally implemented + attention and global attention in every other layer, with support for prefill chunking (originally implemented for Llama4). Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponent cache class. 
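[Editor's note] For reviewers who want to exercise the refactored API end to end, here is a minimal sketch based on the usage documented in this series (the checkpoint name is only illustrative; the `layers`, `get_seq_length`, and deprecated `key_cache` accessors are the ones introduced or touched by these patches):

```python
# Sketch only: exercises the layered-cache API introduced by this refactor.
# The model checkpoint is illustrative; any causal LM checkpoint should work.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tok("Hello", return_tensors="pt")

cache = DynamicCache()
with torch.no_grad():
    model(**inputs, past_key_values=cache, use_cache=True)

# One CacheLayerMixin object per decoder layer; keys/values live on the layer objects.
print(len(cache.layers), cache.layers[0].keys.shape)
print(cache.get_seq_length())  # equals the prompt length after prefill
# cache.key_cache[0] still works for backward compatibility but emits a
# deprecation warning and routes through KeyValuesWrapper (removal planned for v4.56.0).
```

The same pattern applies to the static and hybrid variants: only the cache class constructed up front changes, while `update`, `get_seq_length`, and `get_mask_sizes` are dispatched to the per-layer objects.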
From 731d0b73afbe9bbd666d46a3b47f458d87af185d Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 11 Jul 2025 18:28:19 +0200 Subject: [PATCH 27/35] docs --- docs/source/en/internal/generation_utils.md | 34 +--- src/transformers/cache_utils.py | 189 ++++++++++++-------- 2 files changed, 121 insertions(+), 102 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 1888d76e4a43..c64ba2a3ca43 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -366,30 +366,15 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] DynamicLayer - update - - get_seq_length - - get_mask_sizes - - get_max_cache_shape - - reset - - reorder_cache - crop - batch_repeat_interleave - batch_select_indices [[autodoc]] StaticLayer - update - - get_seq_length - - get_mask_sizes - - get_max_cache_shape - - reset - - reorder_cache [[autodoc]] SlidingWindowLayer - update - - get_seq_length - - get_mask_sizes - - get_max_cache_shape - - reset - - reorder_cache [[autodoc]] CacheProcessor - pre_update @@ -419,52 +404,45 @@ A [`Constraint`] can be used to force the generation to include specific tokens - batch_select_indices [[autodoc]] DynamicCache - - update - to_legacy_cache - from_legacy_cache [[autodoc]] QuantizedCache - - update [[autodoc]] QuantoQuantizedCache - - update [[autodoc]] QuantoQuantizedCacheProcessor [[autodoc]] HQQQuantizedCache - - update + [[autodoc]] HQQQuantizedCacheProcessor [[autodoc]] OffloadedCache - - update [[autodoc]] StaticCache - - update [[autodoc]] OffloadedStaticCache - - update [[autodoc]] HybridCache - - update [[autodoc]] HybridChunkedCache - - update [[autodoc]] SlidingWindowCache - - update [[autodoc]] EncoderDecoderCache - - get_seq_length - to_legacy_cache - from_legacy_cache - - reset - - reorder_cache [[autodoc]] MambaCache - update_conv_state - update_ssm_state - reset +[[autodoc]] CacheConfig + +[[autodoc]] QuantizedCacheConfig + + ## Watermark Utils [[autodoc]] WatermarkingConfig diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 5d455e13aa77..07a9e557a4b7 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -39,7 +39,7 @@ def update( value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Updates KV cache, returns updated keys/values for this layer.""" + """Updates KV cache, returns updated keys/values of the layer.""" raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") def get_seq_length(self, cache_position=None) -> int: @@ -51,7 +51,7 @@ def get_max_cache_shape(self) -> int: raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: - """Returns mask sizes for this layer's cache.""" + """Returns mask sizes for the layer.""" raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") def reset(self) -> None: @@ -73,6 +73,8 @@ class DynamicLayer(CacheLayerMixin): """ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. + + See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. 
""" @classmethod @@ -176,6 +178,13 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: class StaticLayer(CacheLayerMixin): + """ + A static cache layer that stores the Key and Value states as static tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. + It allocates its full backing tensors up-front and mutates them in-place. Built for `torch.compile` support. + + See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. + """ + is_compileable = True is_sliding = False @@ -199,13 +208,10 @@ def __init__( Number of attention heads. head_dim (`int`): Per-head hidden dimension. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + dtype (`torch.dtype`, defaults to `torch.float32`): Data type of the cache tensors. - device (`str` or `torch.device`, *optional*, defaults to `"cpu"`): + device (`str` or `torch.device`, defaults to `"cpu"`): Device on which the cache tensors will be materialised. - sliding_window (`int`, *optional*): - When not ``None``, indicates that this layer will be used as a sliding-window - cache capped to ``sliding_window`` tokens (see `SlidingWindowLayer`). Notes: Static layers allocate their full backing tensors up-front and mutate them @@ -302,7 +308,8 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: class SlidingWindowLayer(StaticLayer): """ A static cache layer that implements sliding window attention caching. - Inherits from StaticLayer but uses sliding window update logic. + + See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. """ def __init__(self, sliding_window, *args, **kwargs): @@ -385,7 +392,11 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: class ChunkedSlidingLayer(SlidingWindowLayer): - """An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4.""" + """ + An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4. + + See `SlidingWindowLayer` for details on common methods that are implemented by all cache layers. + """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -489,7 +500,7 @@ def pre_update( cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. + The modified key and value states. """ return key_states, value_states @@ -512,7 +523,7 @@ def post_update( cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return. + The final key and value states to return to the model. 
""" return key_tensors, value_tensors @@ -634,33 +645,33 @@ def __init__( self, cache: "Cache", backend: str = "quanto", - nbits: Optional[int] = 4, - axis_key: Optional[int] = 0, - axis_value: Optional[int] = 0, - q_group_size: Optional[int] = 64, - residual_length: Optional[int] = 128, - compute_dtype: Optional[torch.dtype] = torch.float16, - device: Optional[str] = "cpu", + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + compute_dtype: torch.dtype = torch.float16, + device: str = "cpu", ): """ Parameters: - backend (`str`, *optional*, defaults to `"quanto"`): + backend (`str`, defaults to `"quanto"`): Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] - nbits (`Optional[int]`, *optional*, defaults to 4): + nbits (`int`, defaults to 4): Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. - axis_key (`int`, *optional*, defaults to 0): + axis_key (`int`, defaults to 0): Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - axis_value (`int`, *optional*, defaults to 0): + axis_value (`int`, defaults to 0): Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - q_group_size (`Optional[int]`, *optional*, defaults to 64): + q_group_size (`int`, defaults to 64): Size of the quantization group, should be a divisor of the model's hidden dimension. Defaults to 64. - residual_length (`Optional[int]`, *optional*, defaults to 128): + residual_length (`int`, defaults to 128): Length of the residual cache which will always be stored in original precision. Defaults to 128. - compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + compute_dtype (`torch.dtype`, defaults to `torch.float16`): The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. - device (`str`, *optional*, defaults to `"cpu"`): + device (`str`, defaults to `"cpu"`): Device on which to perform computations, should be same as the model's device. 
""" self.backend = backend @@ -817,13 +828,13 @@ def __init__( self, cache: "Cache", backend: str = "quanto", - nbits: Optional[int] = 4, - axis_key: Optional[int] = 0, - axis_value: Optional[int] = 0, - q_group_size: Optional[int] = 64, - residual_length: Optional[int] = 128, - compute_dtype: Optional[torch.dtype] = torch.float16, - device: Optional[str] = "cpu", + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + compute_dtype: torch.dtype = torch.float16, + device: str = "cpu", ) -> None: """Initialize the quanto quantization processor.""" super().__init__( @@ -879,13 +890,13 @@ def __init__( self, cache: "Cache", backend: str = "quanto", - nbits: Optional[int] = 4, - axis_key: Optional[int] = 0, - axis_value: Optional[int] = 0, - q_group_size: Optional[int] = 64, - residual_length: Optional[int] = 128, - compute_dtype: Optional[torch.dtype] = torch.float16, - device: Optional[str] = "cpu", + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + compute_dtype: torch.dtype = torch.float16, + device: str = "cpu", ) -> None: """Initialize the HQQ quantization processor.""" super().__init__( @@ -994,18 +1005,16 @@ class Cache: layer_classes (`list[type[CacheLayerMixin]]`, *optional*): List of `CacheLayerMixin` classes to instantiate for the cache. When shorter than the required number of layers the list is cycled. Default is [DynamicLayer]. - Additional arguments passed to the layers: - max_batch_size (`int`): Maximum batch size for static caches - max_cache_len (`int`): Maximum sequence length. For hybrid caches: - * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` - * StaticLayers: uses full `max_cache_len` - device (`torch.device`): Device for cache tensors - dtype (`torch.dtype`): Data type for cache tensors - layer_device_map (`dict[int, Union[str, torch.device]]`): Per-layer device mapping - tp_size (`int`): Tensor parallel size to adjust the number of key/value heads - - Additional keyword arguments are forwarded to the chosen layers constructor(s). See the - documentation of the relevant `CacheLayerMixin` class for more details. + max_batch_size (`int`, *optional*): Maximum batch size for static caches. + max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are + clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`. + device (`torch.device`, *optional*): Device for cache tensors. + dtype (`torch.dtype`, *optional*): Data type for cache tensors. + layer_device_map (`dict[int, Union[str, torch.device]]`, *optional*): Per-layer device mapping. + tp_size (`int`, *optional*): Tensor parallel size to adjust the number of key/value heads. + + Additional keyword arguments are forwarded to the chosen layers constructor(s) and CacheProcessors. See the + documentation of the relevant `CacheLayerMixin` class and `CacheProcessor` class for more details. 
""" def __init__( @@ -1013,7 +1022,12 @@ def __init__( config: Optional[PretrainedConfig] = None, cache_processor: Optional[Union[str, type["CacheProcessor"]]] = None, layer_classes: Optional[list[type["CacheLayerMixin"]]] = None, - *args, + max_batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: Optional[torch.dtype] = None, + layer_device_map: Optional[dict[int, torch.device]] = None, + tp_size: Optional[int] = None, **kwargs, ): self.layers: list["CacheLayerMixin"] = [] @@ -1023,8 +1037,16 @@ def __init__( layer_classes = [DynamicLayer] self.layer_classes = layer_classes + kwargs.update( + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + tp_size=tp_size, + ) processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) - self.layer_init_args = parse_layer_args_from_model_config(config, *args, **kwargs) + self.layer_init_args = parse_layer_args_from_model_config(config, **kwargs) self.model_num_layers = getattr(config, "num_hidden_layers", 1) self.append_new_layers(self.model_num_layers - 1) @@ -1216,6 +1238,8 @@ class DynamicCache(Cache): It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. + See `Cache` for details on common methods that are implemented by all cache classes. + Example: ```python @@ -1258,18 +1282,15 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: return legacy_cache @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None - ) -> "Cache": + def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]) -> "Cache": """ Converts a cache in the legacy cache format into an equivalent `Cache`. Used for backward compatibility. """ cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) return cache @@ -1343,6 +1364,8 @@ class StaticCache(Cache): """ Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + See `Cache` for details on common methods that are implemented by all cache classes. + Example: ```python @@ -1372,7 +1395,9 @@ class HybridCache(Cache): Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window attention and global attention in every other layer (originally implemented for Gemma2). Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention.For more information, see the documentation of each subcomponent cache class. + for global attention. For more information, see the documentation of those layer types. + + See `Cache` for details on common methods that are implemented by all cache classes. Example: @@ -1422,6 +1447,9 @@ class SlidingWindowCache(Cache): 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + + See `Cache` for details on common methods that are implemented by all cache classes. 
+ Example: ```python @@ -1457,7 +1485,9 @@ class QuantizedCache(DynamicCache): It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + is `[batch_size, num_heads, seq_len - residual_length, head_dim]`. + + See `Cache` for details on common methods that are implemented by all cache classes. """ def __init__(self, backend, **kwargs) -> None: @@ -1486,6 +1516,8 @@ class QuantoQuantizedCache(QuantizedCache): Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + See `Cache` for details on common methods that are implemented by all cache classes. + Example: ```python @@ -1525,6 +1557,8 @@ class HQQQuantizedCache(QuantizedCache): Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + See `Cache` for details on common methods that are implemented by all cache classes. + Example: ```python @@ -1558,6 +1592,8 @@ class OffloadedStaticCache(StaticCache): This cache maintains the compilation-friendly properties of StaticCache while enabling much longer sequences by offloading inactive layers to CPU memory. + See `Cache` for details on common methods that are implemented by all cache classes. + Example: ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache @@ -1594,6 +1630,8 @@ class HybridChunkedCache(HybridCache): Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponent cache class. + See `Cache` for details on common methods that are implemented by all cache classes. + Example: ```python @@ -1626,6 +1664,8 @@ class OffloadedHybridCache(HybridChunkedCache): This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling much longer sequences by offloading inactive layers to CPU memory. + + See `Cache` for details on common methods that are implemented by all cache classes. """ def __init__(self, *args, **kwargs) -> None: @@ -1637,6 +1677,8 @@ class EncoderDecoderCache(Cache): Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and cross-attention caches. + See `Cache` for details on common methods that are implemented by all cache classes. + Example: ```python @@ -1720,21 +1762,20 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]: @classmethod def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None + cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...] 
) -> "EncoderDecoderCache": """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" cache = cls( self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache(), ) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx][:2] - cache.self_attention_cache.update(key_states, value_states, layer_idx) - if len(past_key_values[layer_idx]) > 2: - key_states, value_states = past_key_values[layer_idx][2:] - cache.cross_attention_cache.update(key_states, value_states, layer_idx) - cache.is_updated[layer_idx] = True + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True return cache def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position=None) -> int: @@ -1981,7 +2022,7 @@ def __init__(self, **kwargs) -> None: @dataclass class CacheConfig: """ - Base class for cache configs + Base class for cache configs. Deprecated in favor of a simpler dictionary. """ cache_implementation: None @@ -2090,7 +2131,7 @@ def update(self, **kwargs): @dataclass class QuantizedCacheConfig(CacheConfig): """ - Configuration class for quantized cache settings. + Configuration class for quantized cache settings. Deprecated in favor of a simpler dictionary. Attributes: backend (`str`, *optional*, defaults to `"quanto"`): From a4794705ed025ceef11d9fdb1a347337c0807bf5 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 18 Jul 2025 16:21:59 +0200 Subject: [PATCH 28/35] oopsie --- src/transformers/cache_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 07a9e557a4b7..f90287634e27 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1769,13 +1769,14 @@ def from_legacy_cache( self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache(), ) - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx][:2] - cache.self_attention_cache.update(key_states, value_states, layer_idx) - if len(past_key_values[layer_idx]) > 2: - key_states, value_states = past_key_values[layer_idx][2:] - cache.cross_attention_cache.update(key_states, value_states, layer_idx) - cache.is_updated[layer_idx] = True + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True return cache def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position=None) -> int: From 9c0bdcc57919a285991cc7f5d1c62d9892ed984f Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 18 Jul 2025 17:00:15 +0200 Subject: [PATCH 29/35] oopsie --- src/transformers/cache_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 
f90287634e27..1d4f0990db70 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1288,9 +1288,10 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.FloatTensor, torch backward compatibility. """ cache = cls() - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) return cache From 0c4700dcef2763dcf98c316551fcb90e928af24e Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 21 Jul 2025 15:29:35 +0200 Subject: [PATCH 30/35] fix after merge --- .../configuration_falcon_mamba.py | 2 +- .../falcon_mamba/modeling_falcon_mamba.py | 30 +++---------------- 2 files changed, 5 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py index fb6fa5215760..86a0e9ad22d8 100644 --- a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py @@ -29,7 +29,7 @@ class FalconMambaConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the FALCON_MAMBA - [state-spaces/falcon_mamba-2.8b](https://huggingface.co/state-spaces/falcon_mamba-2.8b) architecture. + [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 90b5fec865d9..56a5770ba7b6 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -54,10 +54,10 @@ else: selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None - if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update - else: - causal_conv1d_update, causal_conv1d_fn = None, None +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None logger = logging.get_logger(__name__) @@ -272,17 +272,6 @@ def cuda_kernels_forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): - if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None - - if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update - else: - causal_conv1d_update, causal_conv1d_fn = None, None - # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -398,10 +387,6 @@ def slow_forward(self, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): - if is_mambapy_available(): - from mambapy.pscan import pscan - else: - pscan = None batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated MLP's linear projection @@ -687,15 +672,8 @@ def __init__(self, config): self.gradient_checkpointing = False self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) # Initialize weights and apply final processing - self._register_load_state_dict_pre_hook(self.load_hook) self.post_init() - def load_hook(self, state_dict, prefix, *args): - for k in state_dict: - if "embedding." in k: - state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) - break - def get_input_embeddings(self): return self.embeddings From b3a35e922dc23ba87f906c31d8137436dc929c77 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 21 Jul 2025 18:01:01 +0200 Subject: [PATCH 31/35] cyril review --- src/transformers/cache_utils.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 6bc4559de354..49d652ebc589 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -225,7 +225,6 @@ def __init__( self.dtype = dtype self.device = device - # Note: There will be significant perf decrease if switching to use 5D tensors instead. self.keys = torch.zeros( (batch_size, num_heads, self.max_cache_len, head_dim), dtype=dtype, @@ -419,9 +418,13 @@ def update( if is_full: full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) + # Fast decoding path -> here as the effective size is still sliding window, it is extremely important + # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address + # in memory (the values are the same as the full states, but not the address!!) 
if key_states.shape[-2] == 1: self.keys.copy_(full_key_states) self.values.copy_(full_value_states) + return self.keys, self.values elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len: if cumulative_length == 0: full_key_states = key_states @@ -560,7 +563,7 @@ def __init__(self, cache: "Cache", offload_device: Union[str, torch.device] = "c layer.keys = layer.keys.to(device) layer.values = layer.values.to(device) self.original_device.append(cache.layer_init_args["device"]) - if len(cache) != cache.model_num_layers: + if len(cache) != cache.num_hidden_layers: raise ValueError("If static layers are used, all cache layers must be initialized") self.prefetch_stream = ( @@ -1047,9 +1050,9 @@ def __init__( ) processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) self.layer_init_args = parse_layer_args_from_model_config(config, **kwargs) - self.model_num_layers = getattr(config, "num_hidden_layers", 1) + self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) - self.append_new_layers(self.model_num_layers - 1) + self.append_new_layers(self.num_hidden_layers - 1) self.cache_processor = processor_class(self, **processor_kwargs) if processor_class is not None else None def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: @@ -1426,9 +1429,6 @@ def __init__(self, config: PretrainedConfig, *args, **kwargs): if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None else [StaticLayer] ) - hybrid_chunked = kwargs.pop("hybrid_chunked", "llama4" in getattr(config, "model_type", "")) - if hybrid_chunked: - layer_classes = [ChunkedSlidingLayer if cls == SlidingWindowLayer else cls for cls in layer_classes] super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) @@ -1623,7 +1623,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) -class HybridChunkedCache(HybridCache): +class HybridChunkedCache(Cache): """ Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window attention and global attention in every other layer, with support for prefill chunking (originally implemented @@ -1654,8 +1654,13 @@ class HybridChunkedCache(HybridCache): """ def __init__(self, config: PretrainedConfig, *args, **kwargs): - kwargs["hybrid_chunked"] = True - super().__init__(config=config, *args, **kwargs) + layer_classes = ( + [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None + else [StaticLayer] + ) + layer_classes = [ChunkedSlidingLayer if cls == SlidingWindowLayer else cls for cls in layer_classes] + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) class OffloadedHybridCache(HybridChunkedCache): From e4878ad00fd3d91a7ecb5a009b1b0bf8b450b551 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 22 Jul 2025 12:12:43 +0200 Subject: [PATCH 32/35] arthur review --- src/transformers/cache_utils.py | 129 ++++++++++++++------------------ 1 file changed, 58 insertions(+), 71 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 49d652ebc589..b31feb40ddaf 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -4,7 +4,6 @@ import inspect import json import os -import warnings from collections.abc import Iterable from dataclasses import dataclass from typing import Any, 
Callable, Optional, Union @@ -976,6 +975,38 @@ def _wrapped_update( return _wrapped_update +class KeyValuesWrapper: + """Helper class for Cache that simulates layer-indexed key/value lists from a layered cache. + This allows for BC access and writing, e.g., cache.key_cache[idx] = ... + Deprecated in favor of Cache.layers[idx].keys/values. TODO: remove in v4.56.0""" + + def __init__(self, layers, cache_type="keys"): + self.layers = layers + self.cache_type = cache_type + + def __getitem__(self, idx): + if isinstance(idx, slice): + return [getattr(layer, self.cache_type) for layer in self.layers[idx]] + return getattr(self.layers[idx], self.cache_type) + + def __setitem__(self, idx, value): + if isinstance(idx, slice): + for layer, val in zip(self.layers[idx], value): + setattr(layer, self.cache_type, val) + else: + setattr(self.layers[idx], self.cache_type, value) + + def __len__(self): + return len(self.layers) + + def __iter__(self): + for layer in self.layers: + yield getattr(layer, self.cache_type) + + def __bool__(self): + return bool(self.layers) + + class Cache: """ Base container for per-layer key/value caches. @@ -1161,17 +1192,17 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ return kv_length, kv_offset @property - def key_cache(self) -> "KeyValuesWrapper": + def key_cache(self) -> KeyValuesWrapper: """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`""" - warnings.warn( + logger.warning_once( "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." ) return KeyValuesWrapper(self.layers, "keys") @property - def value_cache(self) -> "KeyValuesWrapper": + def value_cache(self) -> KeyValuesWrapper: """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`""" - warnings.warn( + logger.warning_once( "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." 
) return KeyValuesWrapper(self.layers, "values") @@ -1424,11 +1455,10 @@ class HybridCache(Cache): """ def __init__(self, config: PretrainedConfig, *args, **kwargs): - layer_classes = ( - [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] - if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None - else [StaticLayer] - ) + if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None: + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + else: + layer_classes = [StaticLayer] super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) @@ -1654,12 +1684,13 @@ class HybridChunkedCache(Cache): """ def __init__(self, config: PretrainedConfig, *args, **kwargs): - layer_classes = ( - [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] - if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None - else [StaticLayer] - ) - layer_classes = [ChunkedSlidingLayer if cls == SlidingWindowLayer else cls for cls in layer_classes] + hybrid_map = LAYER_CLASS_MAP.copy() + hybrid_map["sliding_attention"] = ChunkedSlidingLayer + hybrid_map["chunked_attention"] = ChunkedSlidingLayer + if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None: + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + else: + layer_classes = [StaticLayer] super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) @@ -1979,38 +2010,6 @@ def parse_layer_args_from_model_config( ### Deprecated classes -class KeyValuesWrapper: - """Helper class for Cache that simulates layer-indexed key/value lists from a layered cache. - This allows for BC access and writing, e.g., cache.key_cache[idx] = ... - Deprecated in favor of Cache.layers[idx].keys/values. TODO: remove in v4.56.0""" - - def __init__(self, layers, cache_type="keys"): - self.layers = layers - self.cache_type = cache_type - - def __getitem__(self, idx): - if isinstance(idx, slice): - return [getattr(layer, self.cache_type) for layer in self.layers[idx]] - return getattr(self.layers[idx], self.cache_type) - - def __setitem__(self, idx, value): - if isinstance(idx, slice): - for layer, val in zip(self.layers[idx], value): - setattr(layer, self.cache_type, val) - else: - setattr(self.layers[idx], self.cache_type, value) - - def __len__(self): - return len(self.layers) - - def __iter__(self): - for layer in self.layers: - yield getattr(layer, self.cache_type) - - def __bool__(self): - return bool(self.layers) - - class SinkCache(Cache): """ Is its now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache. @@ -2035,10 +2034,8 @@ class CacheConfig: cache_implementation: None def __post_init__(self): - warnings.warn( - ("CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."), - FutureWarning, - stacklevel=2, + logger.warning_once( + "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." ) @classmethod @@ -2052,10 +2049,8 @@ def from_dict(cls, config_dict, **kwargs): Returns: CacheConfig: Instance of CacheConfig constructed from the dictionary. """ - warnings.warn( - ("CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."), - FutureWarning, - stacklevel=2, + logger.warning_once( + "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." 
) config = cls(**config_dict) to_remove = [] @@ -2172,10 +2167,8 @@ def __init__( compute_dtype: Optional[torch.dtype] = torch.float16, device: Optional[str] = "cpu", ): - warnings.warn( - ("CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."), - FutureWarning, - stacklevel=2, + logger.warning_once( + "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." ) self.backend = backend self.nbits = nbits @@ -2248,10 +2241,8 @@ class StaticCacheConfig(CacheConfig): cache_implementation = "static" def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): - warnings.warn( - ("CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."), - FutureWarning, - stacklevel=2, + logger.warning_once( + "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." ) self.batch_size = batch_size self.max_cache_len = max_cache_len @@ -2347,13 +2338,9 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None: # PEP 562: Lazy loading for deprecated location of MambaCache def __getattr__(name: str) -> Any: if name == "MambaCache": - warnings.warn( - ( - "Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed " - "in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead." - ), - FutureWarning, - stacklevel=2, + logger.warning_once( + "Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed " + "in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead." ) class MambaCache: From 8df15953bdda409896558bbb3c62c50f4c9a4caa Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 22 Jul 2025 12:23:06 +0200 Subject: [PATCH 33/35] opsie --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index b31feb40ddaf..c8471b60e449 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1688,7 +1688,7 @@ def __init__(self, config: PretrainedConfig, *args, **kwargs): hybrid_map["sliding_attention"] = ChunkedSlidingLayer hybrid_map["chunked_attention"] = ChunkedSlidingLayer if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None: - layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + layer_classes = [hybrid_map[layer_type] for layer_type in config.layer_types] else: layer_classes = [StaticLayer] super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) From ad65a02da6c34098b0b46b7b29adc706c561b84f Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 22 Jul 2025 12:50:27 +0200 Subject: [PATCH 34/35] fix lfm2 --- src/transformers/models/lfm2/modular_lfm2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index d8f7de5fd9ee..75ef05e3182c 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -207,9 +207,10 @@ def crop(self, max_length: int): if self.get_seq_length() <= max_length: return - if self.key_cache is not None and self.key_cache.numel(): - self.key_cache = self.key_cache[..., :max_length, :] - self.value_cache = self.value_cache[..., :max_length, :] + for idx in range(len(self.key_cache)): 
+ if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: return self.key_cache[layer_idx], self.value_cache[layer_idx] From d9fbb0442092dbe376d8a564cf5bfe27c692e74e Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 22 Jul 2025 13:09:27 +0200 Subject: [PATCH 35/35] opsie2 --- src/transformers/models/lfm2/modeling_lfm2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 9b250c3b97fb..ca653e4114af 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -246,9 +246,10 @@ def crop(self, max_length: int): if self.get_seq_length() <= max_length: return - if self.key_cache is not None and self.key_cache.numel(): - self.key_cache = self.key_cache[..., :max_length, :] - self.value_cache = self.value_cache[..., :max_length, :] + for idx in range(len(self.key_cache)): + if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: return self.key_cache[layer_idx], self.value_cache[layer_idx]
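
The final two patches apply the same fix to the modular and generated lfm2 files: `crop` must walk the per-layer key/value lists rather than slice a single tensor. A standalone sketch of that cropping pattern, with made-up shapes and plain Python lists standing in for the lfm2 cache attributes:

```python
import torch

# Stand-ins for the per-layer key/value lists kept by the lfm2 cache; shapes are illustrative only.
key_cache = [torch.randn(1, 8, 16, 64), torch.empty(0)]    # second entry mimics a layer with no KV states
value_cache = [torch.randn(1, 8, 16, 64), torch.empty(0)]
max_length = 10

for idx in range(len(key_cache)):
    # Skip layers whose tensors are empty, exactly as the patched `crop` does
    if key_cache[idx].numel():
        key_cache[idx] = key_cache[idx][..., :max_length, :]
        value_cache[idx] = value_cache[idx][..., :max_length, :]

print(key_cache[0].shape)  # torch.Size([1, 8, 10, 64]) -- cropped along the sequence dimension
print(key_cache[1].shape)  # torch.Size([0]) -- empty layer left untouched
```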