From 1c3cbcccc89c2af50e649c51430b846bedf2c157 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Sun, 29 Jun 2025 11:46:42 +0200 Subject: [PATCH 1/4] Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077) - Introduces CacheLayer and Cache base classes - Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers - Implements method/attr dispatch across layers to reduce boilerplate - Adds CacheProcessor hooks for offloading, quantization, etc. - Updates and passes tests --- docs/source/en/internal/generation_utils.md | 21 - docs/source/en/model_doc/falcon_mamba.md | 7 + docs/source/en/model_doc/mamba.md | 7 + src/transformers/__init__.py | 2 - src/transformers/cache_utils.py | 2530 ++++++++--------- .../generation/configuration_utils.py | 9 +- src/transformers/generation/utils.py | 117 +- .../models/falcon_h1/modeling_falcon_h1.py | 6 +- .../configuration_falcon_mamba.py | 26 +- .../falcon_mamba/modeling_falcon_mamba.py | 235 +- .../falcon_mamba/modular_falcon_mamba.py | 537 ++++ .../models/jamba/modeling_jamba.py | 6 +- .../models/mamba/modeling_mamba.py | 162 +- .../models/mamba2/modeling_mamba2.py | 7 +- .../models/zamba/modeling_zamba.py | 7 +- .../models/zamba2/modeling_zamba2.py | 8 +- .../models/zamba2/modular_zamba2.py | 2 +- src/transformers/utils/dummy_pt_objects.py | 7 - .../test_modeling_falcon_mamba.py | 43 +- tests/models/mamba/test_modeling_mamba.py | 37 +- tests/utils/test_cache_utils.py | 43 +- 21 files changed, 2246 insertions(+), 1573 deletions(-) create mode 100644 src/transformers/models/falcon_mamba/modular_falcon_mamba.py diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 1c17e99d5da3..b956828fdcff 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -366,43 +366,22 @@ A [`Constraint`] can be used to force the generation to include specific tokens - validate [[autodoc]] 
DynamicCache - - update - - get_seq_length - - reorder_cache - - to_legacy_cache - - from_legacy_cache [[autodoc]] QuantizedCache - - update - - get_seq_length [[autodoc]] QuantoQuantizedCache [[autodoc]] HQQQuantizedCache [[autodoc]] OffloadedCache - - update - - prefetch_layer - - evict_previous_layer [[autodoc]] StaticCache - - update - - get_seq_length - - reset [[autodoc]] OffloadedStaticCache - - update - - get_seq_length - - reset [[autodoc]] HybridCache - - update - - get_seq_length - - reset [[autodoc]] SlidingWindowCache - - update - - reset [[autodoc]] EncoderDecoderCache - get_seq_length diff --git a/docs/source/en/model_doc/falcon_mamba.md b/docs/source/en/model_doc/falcon_mamba.md index a8d7886894b2..0b797c7c7829 100644 --- a/docs/source/en/model_doc/falcon_mamba.md +++ b/docs/source/en/model_doc/falcon_mamba.md @@ -110,6 +110,13 @@ outputs = model.generate(**inputs, max_new_tokens=100) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` +## FalconMambaCache + +[[autodoc]] FalconMambaCache + - update_conv_state + - update_ssm_state + - reset + ## FalconMambaConfig [[autodoc]] FalconMambaConfig diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md index 9ce98d8516a8..2d5543cdcf84 100644 --- a/docs/source/en/model_doc/mamba.md +++ b/docs/source/en/model_doc/mamba.md @@ -115,6 +115,13 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) trainer.train() ``` +## MambaCache + +[[autodoc]] MambaCache + - update_conv_state + - update_ssm_state + - reset + ## MambaConfig [[autodoc]] MambaConfig diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 892acd32ead7..b07b0e849a9e 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -358,7 +358,6 @@ "EncoderDecoderCache", "HQQQuantizedCache", "HybridCache", - "MambaCache", "OffloadedCache", "OffloadedStaticCache", "QuantizedCache", @@ -839,7 +838,6 @@ EncoderDecoderCache, HQQQuantizedCache, HybridCache, - 
MambaCache, OffloadedCache, OffloadedStaticCache, QuantizedCache, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 04ccc6f7efcf..d05220cb62a5 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2,6 +2,7 @@ import importlib.metadata import json import os +import warnings from collections.abc import Iterable from dataclasses import dataclass from typing import Any, Optional, Union @@ -18,116 +19,295 @@ if is_hqq_available(): from hqq.core.quantize import Quantizer as HQQQuantizer + logger = logging.get_logger(__name__) -# Utility functions for static/sliding cache update logic -def _static_cache_update( - k_cache: torch.Tensor, - v_cache: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_position: Optional[torch.LongTensor], -) -> tuple[torch.Tensor, torch.Tensor]: +class CacheProcessor: """ - Updates the static cache tensors in place. - - Args: - k_cache (`torch.Tensor`): The key cache tensor to update. - v_cache (`torch.Tensor`): The value cache tensor to update. - key_states (`torch.Tensor`): The new key states to add. - value_states (`torch.Tensor`): The new value states to add. - cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted. - If None, the entire cache is overwritten (prefill). - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place). + Base class for cache processors that can be applied to modify cache behavior. + This class should be subclassed. """ - if cache_position is None: - # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. - k_cache.copy_(key_states) - v_cache.copy_(value_states) - else: - # Generation phase. Update specific positions. - # Use index_copy_ for in-place update (compile-friendly). 
- try: - k_cache.index_copy_(2, cache_position, key_states) - v_cache.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # Fallback for devices like MPS where index_copy_ might not be supported. - k_cache[:, :, cache_position] = key_states - v_cache[:, :, cache_position] = value_states - return k_cache, v_cache - - -def _sliding_cache_update( - k_cache: torch.Tensor, - v_cache: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_position: torch.LongTensor, - max_cache_len: int, -) -> tuple[torch.Tensor, torch.Tensor]: + + def init(self, cache: "Cache", **kwargs) -> None: + """ + Initialize the processor and perform compatibility checks with the cache. + + Args: + cache (`Cache`): The cache instance this processor will be applied to. + **kwargs: Additional arguments that may be needed for initialization. + """ + raise NotImplementedError("Make sure to implement `init` in a subclass.") + + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Hook called before the cache update. Can modify the key/value states. + + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + layer_idx (`int`): The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states. + """ + return key_states, value_states + + def post_update( + self, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Hook called after the cache update. Can process the cached data. 
+ + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The key states that were cached. + value_states (`torch.Tensor`): The value states that were cached. + layer_idx (`int`): The index of the layer that was updated. + cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return. + """ + return key_tensors, value_tensors + + +class CacheProcessorList(list): """ - Updates the sliding window cache tensors, returning the potentially modified tensors. - - Args: - k_cache (`torch.Tensor`): The key cache tensor to update. - v_cache (`torch.Tensor`): The value cache tensor to update. - key_states (`torch.Tensor`): The new key states to add. - value_states (`torch.Tensor`): The new value states to add. - cache_position (`torch.LongTensor`): The position indices where the new states should be inserted. - max_cache_len (`int`): The maximum length of the sliding window cache. - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update. - For prefill > window, these are the full input states. - Otherwise, they are the updated cache tensors. + list of cache processors that can be applied to a cache. 
""" - # Handle prefill phase when prompt length > sliding_window_size - if cache_position.shape[0] > max_cache_len: - new_k = key_states[:, :, -max_cache_len:, :] - new_v = value_states[:, :, -max_cache_len:, :] - k_cache.copy_(new_k) - v_cache.copy_(new_v) + + def init(self, cache: "Cache", **kwargs) -> None: + """Initialize all processors in the list.""" + for processor in self: + processor.init(cache, **kwargs) + + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply pre_update hook for all processors.""" + for processor in self: + key_states, value_states = processor.pre_update(cache, key_states, value_states, layer_idx, cache_kwargs) return key_states, value_states - # Sliding window logic for generation phase or prefill < window - slicing = torch.arange(max_cache_len, device=value_states.device) - current_seq_len = cache_position[-1] + 1 # Use last position to determine current length - to_shift = current_seq_len > max_cache_len - indices = (slicing + to_shift.sum()) % max_cache_len + def post_update( + self, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply post_update hook for all processors.""" + for processor in self: + key_tensors, value_tensors = processor.post_update( + cache, key_tensors, value_tensors, layer_idx, cache_kwargs + ) + return key_tensors, value_tensors + + +class KVList: + """Efficiently simulates layer-indexed key or value lists from a layered cache. 
+ This allows for BC access, e.g., cache.key_cache[idx] or cache.value_cache[idx].""" + + def __init__(self, layers, cache_type="key"): + self.layers = layers + self.cache_type = cache_type + + def __getitem__(self, idx): + if isinstance(idx, slice): + return [getattr(layer, f"{self.cache_type}_cache") for layer in self.layers[idx]] + return getattr(self.layers[idx], f"{self.cache_type}_cache") + + def __setitem__(self, idx, value): + if isinstance(idx, slice): + for layer, val in zip(self.layers[idx], value): + setattr(layer, f"{self.cache_type}_cache", val) + else: + setattr(self.layers[idx], f"{self.cache_type}_cache", value) + + def __len__(self): + return len(self.layers) + + def __iter__(self): + for layer in self.layers: + yield getattr(layer, f"{self.cache_type}_cache") + + def __bool__(self): + return bool(self.layers) + + +class CacheLayer: + """Base, abstract class for a single layer's cache.""" + + is_compileable = False + + def __init__( + self, + config: Optional["CacheConfig"] = None, + ): + self.key_cache = None + self.value_cache = None + + @classmethod + def from_kv(cls, key_cache: torch.Tensor, value_cache: torch.Tensor) -> None: + cache = cls() + cache.key_cache = key_cache + cache.value_cache = value_cache + return cache + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Updates KV cache, returns updated K and V for this layer.""" + raise NotImplementedError("Make sure to implement `update` in a subclass.") - k_out_shifted = k_cache[:, :, indices] - v_out_shifted = v_cache[:, :, indices] + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache. 
+ Early stops since first layer is enough to compute sequence length""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length, _ = self.get_seq_length(layer_idx) + if max_length != -1 and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length - # Clamp cache_position to determine the *target index* within the shifted cache view - update_position = cache_position.clamp(min=0, max=max_cache_len - 1) + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length (i.e. max capacity) of this layer's cache.""" + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") - try: - k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) - v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) - except NotImplementedError: - # Fallback for MPS: clone and modify the clone - k_out_updated = k_out_shifted.clone() - v_out_updated = v_out_shifted.clone() - k_out_updated[:, :, update_position] = key_states - v_out_updated[:, :, update_position] = value_states + def reset(self) -> None: + """Resets this layer's cache.""" + raise NotImplementedError("Make sure to implement `reset` in a subclass.") + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorders this layer's cache for beam search.""" + if self.key_cache.numel(): + device = self.key_cache.device + self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) + if self.value_cache.numel(): + device = self.value_cache.device + self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) - k_cache.copy_(k_out_updated) - v_cache.copy_(v_out_updated) - return k_out_updated, v_out_updated + 
def __repr__(self): + return f"{self.__class__.__name__}(K={self.key_cache}, V={self.value_cache})" class Cache: """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. + Base, abstract class for all caches. The actual data structure is specific to the layers. + This class handles propagation of operations across layers. + + Parameters: + config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): + Model configuration for shape/device info, or DDP-distributed cache data for compatibility. + If DDP-distributed cache data, must be an iterable of (key_states, value_states) tuples for each layer. + processors (`CacheProcessorList`, *optional*): + List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): + Pattern of cache layer types to use. Defaults to `(DynamicLayer,)`. Must be a tuple whose length divides + the total number of layers. The pattern repeats to fill all layers. Examples: `(StaticLayer,)` for a + uniform cache, `(StaticLayer, StaticLayer, SlidingWindowLayer)` for a hybrid cache with repeating pattern, + or specify the full structure like `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)`. + Additional arguments for cache configuration: + - `max_batch_size`/`batch_size` (`int`): Maximum batch size for static caches + - `max_cache_len` (`int`): Maximum sequence length. 
For hybrid caches: + * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` + * StaticLayers: uses full `max_cache_len` + - `device` (`torch.device`): Device for cache tensors + - `dtype` (`torch.dtype`): Data type for cache tensors + - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping + + Note for hybrid caches (blocks of (StaticLayer, ..., SlidingWindowLayer) repeated across layers): + - Requires `model_config.sliding_window` to be set + - Uses `sliding_window_pattern` (default: 2) to determine layer alternation if pattern not specified + - SlidingWindow layers are limited to sliding window size, Static layers use full max_cache_len """ - is_compileable = False + layers = [] + pattern_block = () # Subclasses can define their layer pattern statically - def __init__(self): - super().__init__() + def __init__( + self, + config_or_ddp_cache_data: Optional[ + Union[PretrainedConfig, Iterable[tuple[torch.Tensor, torch.Tensor]]] + ] = None, + processors: Optional[CacheProcessorList] = None, + pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, + *args, + **kwargs, + ): + self.layers: list[CacheLayer] = [] + self.processors = processors if processors is not None else CacheProcessorList() + pattern_block = pattern_block or self.pattern_block or (DynamicLayer,) + + if isinstance(config_or_ddp_cache_data, PretrainedConfig): + model_config = config_or_ddp_cache_data + elif isinstance(config_or_ddp_cache_data, Iterable): + _distributed_cache_data = config_or_ddp_cache_data + # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 + # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the + # iterable contains the key and value states for a layer gathered across replicas by torch.distributed + # (shape=[global batch size, num_heads, seq_len, head_dim]). 
+ # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break + # compatibility. The name of the argument doesn't matter. + assert pattern_block == (DynamicLayer,), "torch DDP is only supported for DynamicCache" + for key_states, value_states in _distributed_cache_data: + self.layers.append(DynamicLayer.from_kv(key_states, value_states)) + self.processors.init(self, **kwargs) + return + else: + model_config = kwargs.pop("config", None) + + self.config, self.pattern_block = CacheConfig.from_model_config(model_config, pattern_block, *args, **kwargs) + self.layer_types = [self.pattern_block[i % len(self.pattern_block)] for i in range(self.config.num_layers)] + + for idx, layer_type in enumerate(self.layer_types): + layer = layer_type(self.config.to_layer(idx)) + self.layers.append(layer) + + self.processors.init(self, **kwargs) + + def grow_layers_to(self, layer_idx): + while len(self.layers) <= layer_idx: + next_type_idx = len(self.layer_types) % len(self.pattern_block) + next_layer_type = self.pattern_block[next_type_idx] + self.layer_types.append(next_layer_type) + self.layers.append(next_layer_type()) + + @property + def key_cache(self) -> KVList: + """Returns a list-like object of key cache tensors indexed by layer.""" + return KVList(self.layers, "key") + + @property + def value_cache(self) -> KVList: + """Returns a list-like object of value cache tensors indexed by layer.""" + return KVList(self.layers, "value") def update( self, @@ -153,48 +333,99 @@ def update( Return: A tuple containing the updated key and value states. 
""" - raise NotImplementedError("Make sure to implement `update` in a subclass.") + key_states, value_states = self.processors.pre_update(self, key_states, value_states, layer_idx, cache_kwargs) + self.grow_layers_to(layer_idx) + key_tensors, value_tensors = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + key_tensors, value_tensors = self.processors.post_update( + self, key_tensors, value_tensors, layer_idx, cache_kwargs + ) + return key_tensors, value_tensors - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self.layers): + return self.layers[layer_idx].key_cache, self.layers[layer_idx].value_cache + else: + raise KeyError( + f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" + ) - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length (i.e. max capacity) of the cache object""" - raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. 
`for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.layers[layer_idx].key_cache, self.layers[layer_idx].value_cache) - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_cache_shape() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. 
+ """ + # Best effort BC support for subclasses + if self.layers is None: + if getattr(self, "key_cache", None) is not None: + return len(self.key_cache) + return 0 + dynamic_empty = ( + len(self.layers) == 1 and isinstance(self.layers[0], DynamicLayer) and self.layers[0].key_cache is None + ) + return len(self.layers) if not dynamic_empty else 0 - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - if self.value_cache[layer_idx].numel(): - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + def __getattr__(self, name): + """ + Dynamically handle any method call / attribute access by forwarding to layers. + If it doesn't exist on all layers, raises AttributeError. + For attributes, checks all different layer types and reduces to an set of unique values (or to a single value) + For methods, returns a function that propagates the call to all layers with carry-over state (e.g. update, reset) + """ + if name in ("__getstate__", "__setstate__"): + raise AttributeError(name) + + # Check if the attribute/method exists and gather values if it is an attribute + attribute_values = [] + for i, layer in enumerate(self.layers[: len(self.pattern_block)]): + if not hasattr(layer, name): + raise AttributeError( + f"Layer {i} ({layer.__class__.__name__}) of {self.__class__.__name__} does not support `{name}`" + ) + if not callable(getattr(layer, name)): + attribute_values.append(getattr(layer, name)) - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." 
- ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None + if attribute_values: + assert len(attribute_values) == len(self.pattern_block), ( + f"Cache {self.__class__.__name__} gathered {len(attribute_values)} values for {name}, but there are {len(self.pattern_block)} layers." + ) + values = set(attribute_values) + if len(values) == 1: + return values.pop() + else: + if all(isinstance(value, bool) for value in values): + return all(values) + else: + raise ValueError( + f"Cache {self.__class__.__name__}:{self.pattern_block} has multiple values for {name}: {attribute_values}. This is not supported." + ) + + # If the attribute is a method, we propagate it to all layers + def propagate_to_layers(*args, **kwargs): + for layer in self.layers: + return_value = getattr(layer, name)(*args, **kwargs) + if return_value is not None: + break + return return_value + + return propagate_to_layers + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" + if layer_idx >= len(self.layers): + return 0 + return self.layers[layer_idx].get_seq_length() def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ @@ -203,10 +434,49 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), for each layer. 
""" - query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, 0 + if isinstance(self.layers[layer_idx], SlidingWindowLayer): + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + + local_mask_kv_offset = torch.clamp(first_cache_position - self.config.sliding_window + 1, min=0) + # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns + local_mask_kv_length = max(query_length, self.config.sliding_window) + return local_mask_kv_length, local_mask_kv_offset + + full_mask_kv_offset = 0 + if isinstance(self.layers[layer_idx], StaticLayer): + full_mask_kv_length = self.get_max_cache_shape() + return full_mask_kv_length, full_mask_kv_offset + else: + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, full_mask_kv_offset + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: + """Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer in self.layers: + if layer is not None: + legacy_cache += ((layer.key_cache, layer.value_cache),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None + ) -> "Cache": + """Converts a cache in the legacy cache format into an equivalent `Cache`. 
Used for + backward compatibility.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def __repr__(self): + return f"{self.__class__.__name__}(layers={self.layers})" @dataclass @@ -215,7 +485,65 @@ class CacheConfig: Base class for cache configs """ - cache_implementation: None + def __init__(self, num_layers: Optional[int] = None, cache_implementation: Optional[str] = None): + self.num_layers = num_layers + self.cache_implementation = cache_implementation + + @classmethod + def from_model_config( + cls, + config: Optional[PretrainedConfig], + pattern_block: tuple[type["CacheLayer"], ...], + batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: Optional[torch.dtype] = None, + layer_device_map=None, + max_batch_size: Optional[int] = None, + ) -> "CacheConfig": + num_layers = getattr(config, "num_hidden_layers", len(pattern_block)) + # No model config -> must be a dynamic cache, return bare CacheConfig + if config is None: + return cls(num_layers=num_layers), pattern_block + # Build a StaticCacheConfig for any kind of static: hybrid, sliding or static + else: + # Rename max_batch_size to batch_size + if max_batch_size is not None: + batch_size = max_batch_size + # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) + if StaticLayer in pattern_block and SlidingWindowLayer in pattern_block: + if getattr(config, "sliding_window", None) is None: + raise ValueError( + "Setting up a hybrid or sliding window KVCache requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." 
+ ) + # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) + max_cache_len = max_cache_len or config.max_position_embeddings + sliding_window_len = min(getattr(config, "sliding_window", max_cache_len) or max_cache_len, max_cache_len) + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: + head_dim = ( + config.head_dim + if getattr(config, "head_dim", None) is not None + else config.hidden_size // config.num_attention_heads + ) + num_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_config = StaticCacheConfig( + batch_size=batch_size, + max_cache_len=max_cache_len, + device=torch.device(device) if device is not None else None, + dtype=dtype, + layer_device_map=layer_device_map, + head_dim=head_dim, + num_heads=num_heads, + sliding_window=sliding_window_len, + num_layers=num_layers, + ) + return cache_config, pattern_block @classmethod def from_dict(cls, config_dict, **kwargs): @@ -238,6 +566,9 @@ def from_dict(cls, config_dict, **kwargs): kwargs.pop(key, None) return config + def to_layer(self, layer_idx: int) -> "CacheLayer": + return self + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file def to_json_file(self, json_file_path: Union[str, os.PathLike]): """ @@ -408,15 +739,61 @@ def validate(self): @dataclass class StaticCacheConfig(CacheConfig): """ - Configuration class for static cache settings. + Configuration class for static and sliding window cache settings. 
""" - cache_implementation = "static" + batch_size: Optional[int] = None + max_cache_len: Optional[int] = None + device: Union[str, torch.device] = None + dtype: Optional[torch.dtype] = None + layer_device_map: Optional[dict[int, Union[str, torch.device]]] = None + head_dim: Optional[int] = None + num_heads: Optional[int] = None + sliding_window: Optional[int] = None + num_layers: Optional[int] = None + cache_implementation: Optional[str] = None + + def __post_init__(self): + self.cache_implementation = "static" + if self.batch_size is None: + raise ValueError("`batch_size` is required for static cache") + if self.max_cache_len is None: + raise ValueError("`max_cache_len` is required for static cache") + if self.device is None: + self.device = "cpu" + logger.warning_once("`device` not set in cache initialization, using default `cpu`") + if self.dtype is None: + self.dtype = torch.float32 + logger.warning_once("`dtype` not set in cache initialization, using default `float32`") + + def for_layer(self, layer_idx: int): + """ + Returns a StaticCacheConfig for a given layer index. 
+ """ + device = self.layer_device_map[layer_idx] if self.layer_device_map is not None else self.device + return StaticCacheConfig( + self.batch_size, + self.max_cache_len, + device, + self.dtype, + None, + self.head_dim, + self.num_heads, + self.sliding_window, + ) - def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): - self.batch_size = batch_size - self.max_cache_len = max_cache_len - self.device = device + @property + def dtype(self): # noqa: F811 + return getattr(torch, self._dtype) if self._dtype is not None else None + + @dtype.setter + def dtype(self, value): + if isinstance(value, torch.dtype): + self._dtype = str(value).split(".")[-1] + elif isinstance(value, str): + self._dtype = value + else: + self._dtype = None def validate(self): """Validates if the arguments passed are correct""" @@ -445,6 +822,90 @@ def validate(self): ) +class DynamicLayer(CacheLayer): + """ + A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. + It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. + """ + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + cache_kwargs (`dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`. + + Return: + A tuple containing the updated key and value states. 
+ """ + if self.key_cache is None: + self.key_cache = key_states + self.value_cache = value_states + else: + self.key_cache = torch.cat([self.key_cache, key_states], dim=-2) + self.value_cache = torch.cat([self.value_cache, value_states], dim=-2) + + return self.key_cache, self.value_cache + + def get_seq_length(self, cache_position: Optional[torch.LongTensor] = None) -> int: + """Returns the sequence length of the cached states.""" + # TODO: deprecate this function in favor of `cache_position` + if self is None or self.key_cache is None: + return 0 + return self.key_cache.shape[-2] + + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" + return -1 + + def reset(self) -> None: + """Resets the cache values while preserving the objects""" + self.key_cache = torch.tensor([], dtype=self.key_cache.dtype, device=self.key_cache.device) + self.value_cache = torch.tensor([], dtype=self.value_cache.dtype, device=self.value_cache.device) + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorders the cache for beam search, given the selected beam indices.""" + if self.key_cache is not None and self.key_cache.numel(): + self.key_cache = self.key_cache.index_select(0, beam_idx.to(self.key_cache.device)) + if self.value_cache is not None and self.value_cache.numel(): + self.value_cache = self.value_cache.index_select(0, beam_idx.to(self.value_cache.device)) + + def crop(self, max_length: int) -> None: + """Crop the past key values up to a new `max_length` in terms of tokens. 
`max_length` can also be + negative to remove `max_length` tokens.""" + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + if self.key_cache.numel(): + self.key_cache = self.key_cache[..., :max_length, :] + self.value_cache = self.value_cache[..., :max_length, :] + + def batch_repeat_interleave(self, repeats: int) -> None: + """Repeat the cache `repeats` times in the batch dimension.""" + if self.key_cache.numel(): + self.key_cache = self.key_cache.repeat_interleave(repeats, dim=0) + self.value_cache = self.value_cache.repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor) -> None: + """Only keep the `indices` in the batch dimension of the cache.""" + if self.key_cache.numel(): + self.key_cache = self.key_cache[indices, ...] + self.value_cache = self.value_cache[indices, ...] + + class DynamicCache(Cache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. @@ -470,184 +931,7 @@ class DynamicCache(Cache): ``` """ - def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: - super().__init__() - self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121 - # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the - # iterable contains the key and value states for a layer gathered across replicas by torch.distributed - # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break - # compatibility. The name of the argument doesn't matter. 
- if _distributed_cache_data is not None: - for key_states, value_states in _distributed_cache_data: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return (self.key_cache[layer_idx], self.value_cache[layer_idx]) - else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. 
- """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache - if key_states is not None: - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif ( - not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model - ): # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or not self.key_cache[layer_idx].numel() # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" - return None - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. 
Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None - ) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for - backward compatibility.""" - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def crop(self, max_length: int): - """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" - # In case it is negative - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - self._seen_tokens = max_length - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def batch_split(self, full_batch_size: int, split_size: int) -> list["DynamicCache"]: - """Split the current instance into a list of `DynamicCache` by the batch size. 
This will be used by - `_split_model_inputs()` in `generation.utils`""" - out = [] - for i in range(0, full_batch_size, split_size): - current_split = DynamicCache() - current_split._seen_tokens = self._seen_tokens - current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] - current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] - out.append(current_split) - return out - - @classmethod - def from_batch_splits(cls, splits: list["DynamicCache"]) -> "DynamicCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - cache = cls() - for idx in range(len(splits[0])): - key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()] - value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()] - if key_cache != []: - layer_keys = torch.cat(key_cache, dim=0) - layer_values = torch.cat(value_cache, dim=0) - cache.update(layer_keys, layer_values, idx) - return cache - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] 
+ pattern_block = (DynamicLayer,) # Utilities for `DynamicCache` <> torch.export support @@ -663,18 +947,17 @@ def _flatten_dynamic_cache( "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." ) - # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), + "key_cache": [layer.key_cache for layer in dynamic_cache.layers if layer.key_cache is not None], + "value_cache": [layer.value_cache for layer in dynamic_cache.layers if layer.value_cache is not None], } return torch.utils._pytree._dict_flatten(dictionary) def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), + "key_cache": [layer.key_cache for layer in dynamic_cache.layers if layer.key_cache is not None], + "value_cache": [layer.value_cache for layer in dynamic_cache.layers if layer.value_cache is not None], } return torch.utils._pytree._dict_flatten_with_keys(dictionary) @@ -685,15 +968,20 @@ def _unflatten_dynamic_cache( ): dictionary = torch.utils._pytree._dict_unflatten(values, context) cache = DynamicCache() - for k, v in dictionary.items(): - setattr(cache, k, v) + # Reconstruct layers from key_cache and value_cache lists + key_list = dictionary.get("key_cache", []) + value_list = dictionary.get("value_cache", []) + for idx in range(max(len(key_list), len(value_list))): + key = key_list[idx] if idx < len(key_list) else None + value = value_list[idx] if idx < len(value_list) else None + cache.update(key, value, idx) return cache def _flatten_dynamic_cache_for_fx(cache, spec): dictionary = { - "key_cache": getattr(cache, "key_cache"), - "value_cache": getattr(cache, "value_cache"), + "key_cache": [layer.key_cache for layer in cache.layers if layer.key_cache is not None], + "value_cache": 
[layer.value_cache for layer in cache.layers if layer.value_cache is not None], } return torch.fx._pytree._dict_flatten_spec(dictionary, spec) @@ -723,125 +1011,13 @@ class OffloadedCache(DynamicCache): ensure the eviction is scheduled after all computations on that cache are finished. """ - def __init__(self) -> None: - if not ( - torch.cuda.is_available() - or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) - ): - raise RuntimeError( - "OffloadedCache can only be used with a GPU" - + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") - ) - - super().__init__() - self.original_device = [] - self.prefetch_stream = None - self.prefetch_stream = ( - torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() - ) - self.beam_idx = None # used to delay beam search operations - - def prefetch_layer(self, layer_idx: int): - "Starts prefetching the next layer cache" - if layer_idx < len(self): - with ( - self.prefetch_stream - if is_torch_greater_or_equal("2.7", accept_dev=True) - else torch.cuda.stream(self.prefetch_stream) - ): - # Prefetch next layer tensors to GPU - device = self.original_device[layer_idx] - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True) - - def evict_previous_layer(self, layer_idx: int): - "Moves the previous layer cache to the CPU" - if len(self) > 2: - # We do it on the default stream so it occurs after all earlier computations on these tensors are done - prev_layer_idx = (layer_idx - 1) % len(self) - self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True) - self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True) - - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: - "Gets the cache for this layer to the device. 
Prefetches the next and evicts the previous layer." - if layer_idx < len(self): - # Evict the previous layer if necessary - if is_torch_greater_or_equal("2.7", accept_dev=True): - torch.accelerator.current_stream().synchronize() - else: - torch.cuda.current_stream().synchronize() - self.evict_previous_layer(layer_idx) - # Load current layer cache to its original device if not already there - original_device = self.original_device[layer_idx] - self.prefetch_stream.synchronize() - key_tensor = self.key_cache[layer_idx] - value_tensor = self.value_cache[layer_idx] - # Now deal with beam search ops which were delayed - if self.beam_idx is not None: - self.beam_idx = self.beam_idx.to(original_device) - key_tensor = key_tensor.index_select(0, self.beam_idx) - value_tensor = value_tensor.index_select(0, self.beam_idx) - # Prefetch the next layer - self.prefetch_layer((layer_idx + 1) % len(self)) - return (key_tensor, value_tensor) - else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Saves the beam indices and reorders the cache when the tensor is back to its device.""" - # We delay this operation until the tensors are back to their original - # device because performing torch.index_select on the CPU is very slow - del self.beam_idx - self.beam_idx = beam_idx.clone() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. 
No additional arguments are used in `OffloadedCache`. - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] + def __init__(self, config: Optional[CacheConfig] = None) -> None: + # Create the underlying cache with offload processor + processors = CacheProcessorList([OffloadedCacheProcessor()]) + super().__init__(processors=processors, config=config) - # Update the cache - if len(self.key_cache) < layer_idx: - raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") - elif len(self.key_cache) == layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - self.original_device.append(key_states.device) - self.evict_previous_layer(layer_idx) - else: - key_tensor, value_tensor = self[layer_idx] - self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError - # if a method is not supposed to be supported in a subclass we should set it to None - from_legacy_cache = None - - to_legacy_cache = None - - -class QuantizedCache(DynamicCache): +class QuantoQuantizedCache(DynamicCache): """ A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. @@ -853,87 +1029,8 @@ class QuantizedCache(DynamicCache): It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. 
Additionally, it stores the Key and Value in original precision states as a list of tensors, one for each layer. The size of each tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - """ - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - super().__init__() - self._quantized_key_cache: list[torch.Tensor] = [] - self._quantized_value_cache: list[torch.Tensor] = [] - - self.nbits = cache_config.nbits - self.residual_length = cache_config.residual_length - self.q_group_size = cache_config.q_group_size - self.axis_key = cache_config.axis_key - self.axis_value = cache_config.axis_value - self.compute_dtype = cache_config.compute_dtype - self.device = cache_config.device - - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - if len(self.key_cache) < layer_idx: - raise ValueError("QuantizedCache does not support model usage where layers are skipped. 
Use DynamicCache.") - elif len(self.key_cache) == layer_idx: - self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key)) - self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value)) - self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) - self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) - keys_to_return, values_to_return = key_states, value_states - else: - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) - keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] - values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] - - keys_to_return = torch.cat(keys_to_return, dim=-2) - values_to_return = torch.cat(values_to_return, dim=-2) - if ( - self.key_cache[layer_idx].dim() == 4 - and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length - ): - self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) - self._quantized_value_cache[layer_idx] = self._quantize( - values_to_return.contiguous(), axis=self.axis_value - ) - self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) - self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is - # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def _quantize(self, tensor, axis): - """Quantizes a key/value using a defined quantization method.""" - raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") - - def _dequantize(self, q_tensor): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") - - -class QuantoQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. Parameters: cache_config (`QuantizedCacheConfig`): @@ -959,47 +1056,25 @@ class QuantoQuantizedCache(QuantizedCache): ``` """ - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - - if is_optimum_quanto_available(): - optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) - if optimum_quanto_version <= version.parse("0.2.5"): - raise ImportError( - f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}." 
- ) - from optimum.quanto import MaxOptimizer, qint2, qint4 - - if self.nbits not in [2, 4]: - raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") - - if self.axis_key not in [0, -1]: - raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") - - if self.axis_value not in [0, -1]: - raise ValueError( - f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" - ) - - self.qtype = qint4 if self.nbits == 4 else qint2 - self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) + super(DynamicCache, self).__init__(processors=processors) - def _quantize(self, tensor, axis): - # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore - if is_optimum_quanto_available(): - from optimum.quanto import quantize_weight - scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) - qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) - return qtensor +class HQQQuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - def _dequantize(self, qtensor): - return qtensor.dequantize() + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. 
The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` -class HQQQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. Parameters: cache_config (`QuantizedCacheConfig`): @@ -1025,55 +1100,99 @@ class HQQQuantizedCache(QuantizedCache): ``` """ - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" - ) - - if self.axis_key not in [0, 1]: - raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)]) + super(DynamicCache, self).__init__(processors=processors) - if self.axis_value not in [0, 1]: - raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") - self.quantizer = HQQQuantizer +class StaticLayer(CacheLayer): + is_compileable = True - def _quantize(self, tensor, axis): - qtensor, meta = self.quantizer.quantize( - tensor, - axis=axis, - device=self.device, - compute_dtype=self.compute_dtype, - nbits=self.nbits, - group_size=self.q_group_size, + def __init__( + self, + config: StaticCacheConfig, + max_len: 
Optional[int] = None, + ): + self.max_cache_len = max_len or config.max_cache_len + self.max_batch_size = config.batch_size + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + self.key_cache = torch.zeros( + (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), + dtype=config.dtype, + device=config.device, ) - meta["compute_dtype"] = self.compute_dtype - self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype - meta["scale"] = meta["scale"].to(qtensor.device) - meta["zero"] = meta["zero"].to(qtensor.device) - return qtensor, meta + self.value_cache = torch.zeros( + (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), + dtype=config.dtype, + device=config.device, + ) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(self.key_cache) + torch._dynamo.mark_static_address(self.value_cache) - def _dequantize(self, qtensor): - quant_tensor, meta = qtensor - tensor = self.quantizer.dequantize(quant_tensor, meta) - return tensor + def get_max_cache_shape(self) -> int: + return self.max_cache_len + + def _static_update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_position: Optional[torch.LongTensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Utility functions for static/sliding cache update logic + """ + Updates the static cache tensors in place. + Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. + key_states (`torch.Tensor`): The new key states to add. + value_states (`torch.Tensor`): The new value states to add. + cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted. + If None, the entire cache is overwritten (prefill). 
-class SinkCache(Cache): - """ - Is its now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache. - See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for - general `custom_generate`usage. - """ + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place). + """ + if cache_position is None: + # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. + self.key_cache.copy_(key_states) + self.value_cache.copy_(value_states) + else: + # Generation phase. Update specific positions. + # Use index_copy_ for in-place update (compile-friendly). + try: + self.key_cache.index_copy_(2, cache_position, key_states) + self.value_cache.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # Fallback for devices like MPS where index_copy_ might not be supported. + self.key_cache[:, :, cache_position] = key_states + self.value_cache[:, :, cache_position] = value_states + return self.key_cache, self.value_cache + + def update(self, key_states, value_states, cache_kwargs=None) -> tuple[torch.Tensor, torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + key_states = key_states.to(self.key_cache.dtype) + value_states = value_states.to(self.value_cache.dtype) + return self._static_update(key_states, value_states, cache_position) + + def get_seq_length(self, cache_position=None) -> int: + if cache_position is not None: + return int(cache_position[-1] + 1) + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. 
+ return (self.key_cache[0, 0].any(dim=-1)).sum() - # TODO (joao, manuel): Remove this class in v4.59.0 - def __init__(self, **kwargs) -> None: - raise NotImplementedError( - "`SinkCache` has been moved as a `custom_generate` repository on the Hub: " - "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples." - ) + def reset(self): + self.key_cache.zero_() + self.value_cache.zero_() + + def reorder_cache(self, beam_idx): + dev = self.key_cache.device + beam_idx_dev = beam_idx.to(dev) + self.key_cache = self.key_cache.index_select(0, beam_idx_dev) + self.value_cache = self.value_cache.index_select(0, beam_idx_dev) class StaticCache(Cache): @@ -1081,23 +1200,9 @@ class StaticCache(Cache): Static Cache class to be used with `torch.compile(model)` and `torch.export()`. Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. If you are manually setting the batch size, make sure to take into account the - number of beams if you are running beam search - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. 
You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. + config_or_ddp_cache_data (`Union`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility. + processors (`Optional`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + pattern_block (`Optional`, *optional*): Pattern of cache layer types to use. Defaults to `(StaticLayer,)` for backward compatibility. Example: @@ -1120,117 +1225,80 @@ class StaticCache(Cache): ``` """ - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.float32, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - super().__init__() - self.max_batch_size = max_batch_size - self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + pattern_block = (StaticLayer,) - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - self._dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) +class SlidingWindowLayer(StaticLayer): + """ + A static cache layer that implements sliding window attention caching. + Inherits from StaticLayer but uses sliding window update logic. + """ - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - # Note: There will be significant perf decrease if switching to use 5D tensors instead. 
- cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - device = torch.device(device) if device is not None else None - for idx in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[idx] - else: - layer_device = device - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) + def __init__(self, config: CacheConfig): + super().__init__(config, max_len=config.sliding_window) - def update( + def _static_update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, + cache_position: torch.LongTensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + Updates the sliding window cache tensors, returning the potentially modified tensors. - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. + Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. 
+            key_states (`torch.Tensor`): The new key states to add.
+            value_states (`torch.Tensor`): The new value states to add.
+            cache_position (`torch.LongTensor`): The position indices where the new states should be inserted.
+                The window length is taken from `self.max_cache_len` (no separate argument).

-        Return:
-            A tuple containing the updated key and value states.
+        Returns:
+            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value cache tensors after the in-place update.
+            When the prompt is longer than the window, only its last `max_cache_len` positions are kept
+            in the cache before it is returned.
         """
-        if cache_kwargs is None:
-            cache_kwargs = {}
-        key_states = key_states.to(self.key_cache[layer_idx].dtype)
-        value_states = value_states.to(self.value_cache[layer_idx].dtype)
-        return _static_cache_update(
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            key_states,
-            value_states,
-            cache_kwargs.get("cache_position"),
-        )
+        if cache_position is None:
+            raise ValueError("`cache_position` must be provided for SlidingWindowLayer.")

-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states that were seen by the model."""
-        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
-        # limit the check to the first batch member and head dimension.
- # TODO: deprecate this function in favor of `cache_position` - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + # Handle prefill phase when prompt length > sliding_window_size + if cache_position.shape[0] > self.max_cache_len: + new_k = key_states[:, :, -self.max_cache_len :, :] + new_v = value_states[:, :, -self.max_cache_len :, :] + self.key_cache.copy_(new_k) + self.value_cache.copy_(new_v) + return self.key_cache, self.value_cache - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len + # Sliding window logic for generation phase or prefill < window + slicing = torch.arange(self.max_cache_len, device=value_states.device) + current_seq_len = cache_position[-1] + 1 # Use last position to determine current length + to_shift = current_seq_len > self.max_cache_len + indices = (slicing + to_shift.sum()) % self.max_cache_len - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + k_out_shifted = self.key_cache[:, :, indices] + v_out_shifted = self.value_cache[:, :, indices] - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
- """ - kv_length = self.get_max_cache_shape() - return kv_length, 0 + # Clamp cache_position to determine the *target index* within the shifted cache view + update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1) + + try: + k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) + v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) + except NotImplementedError: + # Fallback for MPS: clone and modify the clone + k_out_updated = k_out_shifted.clone() + v_out_updated = v_out_shifted.clone() + k_out_updated[:, :, update_position] = key_states + v_out_updated[:, :, update_position] = value_states + + self.key_cache.copy_(k_out_updated) + self.value_cache.copy_(v_out_updated) + return self.key_cache, self.value_cache -class SlidingWindowCache(StaticCache): +class SlidingWindowCache(Cache): """ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, @@ -1246,25 +1314,6 @@ class SlidingWindowCache(StaticCache): 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) - - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. 
- dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - Example: ```python @@ -1285,83 +1334,7 @@ class SlidingWindowCache(StaticCache): ``` """ - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.float32, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) - max_cache_len = min(config.sliding_window, max_cache_len) - self.sliding_window = config.sliding_window - super().__init__( - config=config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - layer_device_map=layer_device_map, - ) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - - if cache_position is None: - raise ValueError("`cache_position` must be provided for SlidingWindowCache.") - - key_states = key_states.to(self.key_cache[layer_idx].dtype) - value_states = value_states.to(self.value_cache[layer_idx].dtype) - - return _sliding_cache_update( - self.key_cache[layer_idx], - self.value_cache[layer_idx], - key_states, - value_states, - cache_position, - self.max_cache_len, - ) - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - def reset(self): - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
- """ - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - # torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor - kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) - # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - kv_length = max(query_length, self.get_max_cache_shape()) - return kv_length, kv_offset + pattern_block = (SlidingWindowLayer,) class EncoderDecoderCache(Cache): @@ -1436,7 +1409,7 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]: @classmethod def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None + cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None ) -> "EncoderDecoderCache": """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" cache = cls( @@ -1507,22 +1480,6 @@ def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDec out.append(EncoderDecoderCache(self_attn, cross_attn)) return out - @classmethod - def from_batch_splits(cls, splits: list["EncoderDecoderCache"]) -> "EncoderDecoderCache": - """This is the opposite of the above `batch_split()` method. 
This will be used by `stack_model_outputs` in - `generation.utils`""" - self_attention_cache = DynamicCache() - cross_attention_cache = DynamicCache() - for idx in range(len(splits[0])): - layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0) - layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0) - self_attention_cache.update(layer_keys, layer_values, idx) - - layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0) - layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0) - cross_attention_cache.update(layer_keys, layer_values, idx) - return cls(self_attention_cache, cross_attention_cache) - def batch_repeat_interleave(self, repeats: int): """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" self.check_dynamic_cache(self.batch_repeat_interleave.__name__) @@ -1535,7 +1492,7 @@ def batch_select_indices(self, indices: torch.Tensor): self.self_attention_cache.batch_select_indices(indices) self.cross_attention_cache.batch_select_indices(indices) - def get_max_cache_shape(self) -> Optional[int]: + def get_max_cache_shape(self) -> int: """Returns the maximum sequence length (i.e. max capacity) of the cache object""" return self.self_attention_cache.get_max_cache_shape() @@ -1557,23 +1514,10 @@ class HybridCache(Cache): for global attention.For more information, see the documentation of each subcomponent cache class. Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. 
- device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - + config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info. No DDP-distributed cache data is supported. + processors (`CacheProcessorList`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)` + for backward compatibility. Example: ```python @@ -1594,146 +1538,24 @@ class HybridCache(Cache): ``` """ - is_compileable = True - def __init__( self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.float32, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'hybrid' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) - self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings - # Sliding layers can't be larger than the overall max cache len - self.sliding_window_len = min(config.sliding_window, self.max_cache_len) - self.max_batch_size = max_batch_size - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - - self._dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - + config_or_ddp_cache_data=None, + processors: Optional[CacheProcessorList] = None, + pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, + *args, + **kwargs, + ): + model_config = config_or_ddp_cache_data or kwargs.get("config", None) + assert model_config is not None, "HybridCache requires a model config" # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] - else: - self.is_sliding = [False] * config.num_hidden_layers - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim) - self.sliding_window = min(config.sliding_window, max_cache_len) - device = torch.device(device) if device is not None else None - for i in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[i] - else: - layer_device = device - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the 
cache. - cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - if cache_position is None: - raise ValueError("`cache_position` must be provided for HybridCache.") - - is_sliding_layer = self.is_sliding[layer_idx] - - # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used - # when the cache is initialized in the forward pass (e.g. 
Gemma2) - if self.key_cache[layer_idx].device != key_states.device: - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) - if self.value_cache[layer_idx].device != value_states.device: - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) - - k_cache = self.key_cache[layer_idx] - v_cache = self.value_cache[layer_idx] - key_states = key_states.to(k_cache.dtype) - value_states = value_states.to(v_cache.dtype) - - if is_sliding_layer: - return _sliding_cache_update( - k_cache, - v_cache, - key_states, - value_states, - cache_position, - k_cache.shape[2], # Use actual cache dim as max cache len - ) - else: - return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position) - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - def get_seq_length(self, layer_idx: Optional[int] = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." - ) - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. 
sliding_window, chunk_size), - for each layer. - """ - if self.is_sliding[layer_idx]: - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - - local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) - # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - local_mask_kv_length = max(query_length, self.sliding_window) - return local_mask_kv_length, local_mask_kv_offset + if hasattr(model_config, "layer_types"): + self.is_sliding = [layer_type != "full_attention" for layer_type in model_config.layer_types] + else: + self.is_sliding = [False] * model_config.num_hidden_layers - full_mask_kv_offset = 0 - full_mask_kv_length = self.get_max_cache_shape() - return full_mask_kv_length, full_mask_kv_offset + pattern_block = tuple(SlidingWindowLayer if sl else StaticLayer for sl in self.is_sliding) + super().__init__(config_or_ddp_cache_data, processors, pattern_block, *args, **kwargs) class HybridChunkedCache(Cache): @@ -1783,6 +1605,9 @@ class HybridChunkedCache(Cache): """ is_compileable = True + # Override @property since HybridChunked does its own thing + key_cache = None + value_cache = None def __init__( self, @@ -1857,8 +1682,13 @@ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2) else: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + try: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + except NotImplementedError: + # MPS does not support index_copy_ + self.key_cache[layer_idx][:, :, cache_position] = key_states + self.value_cache[layer_idx][:, :, 
cache_position] = value_states return self.key_cache[layer_idx], self.value_cache[layer_idx] self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) @@ -1903,7 +1733,7 @@ def update( k_out.shape[2], ) - def get_max_cache_shape(self) -> Optional[int]: + def get_max_cache_shape(self) -> int: return self.max_cache_len def get_seq_length(self, layer_idx: Optional[int] = 0): @@ -1927,6 +1757,16 @@ def reset(self): self.value_cache[layer_idx].zero_() self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx].numel(): + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for @@ -2073,133 +1913,24 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None: self.device_value_cache[self.active_device_layer].fill_(0.0) -class MambaCache: - """ - Cache for mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. 
- device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - - Example: - - ```python - >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache - - >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") - - >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values - MambaCache() - ``` - """ - - is_compileable = True - - # TODO (joao): add layer_device_map arg and update code in `generate` accordingly - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: Union[torch.device, str, None] = None, - ): - self.max_batch_size = max_batch_size - self._dtype = dtype - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - - self.conv_states: list[torch.Tensor] = [] - self.ssm_states: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - conv_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=self._dtype, - ) - ssm_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=self._dtype, - ) - - torch._dynamo.mark_static_address(conv_state) - torch._dynamo.mark_static_address(ssm_state) - self.conv_states.append(conv_state) - self.ssm_states.append(ssm_state) - - def update_conv_state( - self, layer_idx: int, 
new_conv_state: torch.Tensor, cache_position: torch.LongTensor - ) -> torch.Tensor: - # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used - # when the cache is initialized in the forward pass (e.g. Mamba) - if self.conv_states[layer_idx].device != new_conv_state.device: - self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) - return self.ssm_states[layer_idx] - - def reset(self): - for layer_idx in range(len(self.conv_states)): - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() - - class OffloadedStaticCache(StaticCache): """ - Static cache class to be used with `torch.compile(model)` that offloads to the CPU or - another device. + A drop-in replacement for StaticCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. - Args: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize - the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - device (`Union[str, torch.device]`): - The device on which the cache should be initialized. 
If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*): - The default `dtype` to use when initializing the cache. - offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): - The device to offload to. Defaults to CPU. - layer_device_map (`dict[int, Union[str, torch.device, int]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. + This cache maintains the compilation-friendly properties of StaticCache while enabling + much longer sequences by offloading inactive layers to CPU memory. - Example: + Parameters: + config (`PretrainedConfig`): Model configuration for shape/device info. + max_batch_size (`int`): Maximum batch size for static caches. + max_cache_len (`int`, *optional*): Maximum sequence length. + device (`torch.device` or `str`, *optional*): Device for cache tensors. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type for cache tensors. + offload_device (`Union`, *optional*, defaults to `"cpu"`): Device to offload cache tensors to. + layer_device_map (`dict[int, Union[str, torch.device, int]]`, *optional*): Per-layer device mapping. 
+ Example: ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache @@ -2208,240 +1939,493 @@ class OffloadedStaticCache(StaticCache): >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> # Prepare a cache class with offloading >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = OffloadedStaticCache( + ... config=model.config, + ... max_batch_size=1, + ... max_cache_len=max_generated_length, + ... device=model.device, + ... dtype=model.dtype + ... ) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache with offloaded layers + OffloadedStaticCache() ``` """ - is_compileable = True - def __init__( self, config: PretrainedConfig, max_batch_size: int, - max_cache_len: Optional[int], - device: Union[str, torch.device], + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, dtype: Optional[torch.dtype] = None, - offload_device: Union[str, torch.device] = torch.device("cpu"), + offload_device: Union[str, torch.device] = "cpu", layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, ) -> None: - super(Cache, self).__init__() + # Create offload processor + processors = CacheProcessorList([OffloadedCacheProcessor(offload_device)]) - # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps - # track of the original device of each layer - unique_devices = set(layer_device_map.values()) 
if layer_device_map else set() - if len(unique_devices) > 1: - raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}") + # Initialize the base StaticCache with the processor + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + processors=processors, + ) - self.max_batch_size = max_batch_size - self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len - self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) + +class OffloadedCacheProcessor(CacheProcessor): + """ + A cache processor that offloads cache tensors to conserve accelerator memory. + + This processor manages moving cache tensors between accelerator and CPU memory, + using asynchronous prefetching to minimize performance impact. Works with both + dynamic and static layers. + """ + + def __init__(self, offload_device: Union[str, torch.device] = "cpu"): self.offload_device = torch.device(offload_device) - self._dtype = dtype if dtype is not None else torch.float32 + self.original_device = [] + self.prefetch_stream = None + self.beam_idx = None + + def init(self, cache: "Cache", **kwargs) -> None: + """Initialize the offload processor and check device compatibility.""" + if not ( + torch.cuda.is_available() + or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) + ): + raise RuntimeError( + "OffloadedCacheProcessor can only be used with a GPU" + + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") + ) - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) + if self.is_static: + 
for i, layer in enumerate(cache.layers): + device = cache.config.device if i == 0 else self.offload_device + layer.key_cache = layer.key_cache.to(device) + layer.value_cache = layer.value_cache.to(device) + self.original_device.append(cache.config.device) + if len(cache) != cache.config.num_layers: + raise ValueError("If static layers are used, all cache layers must be initialized") - num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads + self.prefetch_stream = ( + torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() ) - cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Handle prefetching and eviction before cache update.""" + # Update the cache + if len(cache) < layer_idx: + raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") + elif len(cache) == layer_idx: + self.original_device.append(key_states.device) + self._evict_previous_layer(cache, layer_idx) + else: + # Wait for the previous layer to be evicted (on default stream) + if is_torch_greater_or_equal("2.7", accept_dev=True): + torch.accelerator.current_stream().synchronize() + else: + torch.cuda.current_stream().synchronize() + self._evict_previous_layer(cache, layer_idx) + self._ensure_layer_on_device(cache, layer_idx) + + # Prefetch the next layer + self._prefetch_layer(cache, (layer_idx + 1) % len(cache)) + return key_states, value_states - # Create offloaded CPU tensors. 
- self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] + def _prefetch_layer(self, cache: "Cache", layer_idx: int): + """Starts prefetching the next layer cache.""" + if layer_idx < len(cache): + with ( + self.prefetch_stream + if is_torch_greater_or_equal("2.7", accept_dev=True) + else torch.cuda.stream(self.prefetch_stream) + ): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + cache.key_cache[layer_idx] = cache.key_cache[layer_idx].to(device, non_blocking=True) + cache.value_cache[layer_idx] = cache.value_cache[layer_idx].to(device, non_blocking=True) - for i in range(config.num_hidden_layers): - # First layer is always on-device. - device = self.device if i == 0 else self.offload_device + def _evict_previous_layer(self, cache: "Cache", layer_idx: int): + """Moves the previous layer cache to the CPU.""" + if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + prev_layer_idx = (layer_idx - 1) % len(cache) + cache.key_cache[prev_layer_idx] = cache.key_cache[prev_layer_idx].to( + self.offload_device, non_blocking=True + ) + cache.value_cache[prev_layer_idx] = cache.value_cache[prev_layer_idx].to( + self.offload_device, non_blocking=True + ) - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device) + def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): + """Ensures the current layer is on the original device.""" + if layer_idx < len(cache): + # Wait for the previous prefetch to be done + self.prefetch_stream.synchronize() - self.key_cache.append(key_cache) - self.value_cache.append(value_cache) + # Handle delayed beam search operations + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) + cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, self.beam_idx) 
+ cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, self.beam_idx) - # Create device tensors. - self._device_key_cache: list[torch.Tensor] = [] - self._device_value_cache: list[torch.Tensor] = [] - for i in range(2): - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device) +class QuantizedCacheProcessor(CacheProcessor): + """ + A cache processor that applies quantization to cache tensors to reduce memory usage. - self._device_key_cache.append(key_cache) - self._device_value_cache.append(value_cache) + This processor quantizes cache tensors after they are stored, maintaining a residual + length in original precision and quantizing older tokens. + """ - # For backwards compatibility. - # TODO(gante): Remove this. + def __init__(self, cache_config: QuantizedCacheConfig): + self.config = cache_config + self._quantized_key_cache: list[torch.Tensor] = [] + self._quantized_value_cache: list[torch.Tensor] = [] self._seen_tokens = 0 - # Create new CUDA stream for parallel prefetching. - self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None + def init(self, cache: "Cache", **kwargs) -> None: + """Initialize the quantized processor and validate configuration.""" + self.config.validate() - def update( + # Only compatible with DynamicCache + if not isinstance(cache, DynamicCache): + raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache") + + def post_update( self, - key_states: torch.Tensor, - value_states: torch.Tensor, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. 
+ """Apply quantization after cache update.""" + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_tensors.shape[-2] - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, *optional*): - Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the - `cache_position` input to know how where to write in the cache. + # Extend quantized cache if needed + while len(self._quantized_key_cache) <= layer_idx: + self._quantized_key_cache.append(torch.empty(0)) + self._quantized_value_cache.append(torch.empty(0)) - Return: - A tuple containing the updated key and value states. - """ + # Check if we need to quantize + if layer_idx < len(cache.key_cache): + current_key = cache.key_cache[layer_idx] + current_value = cache.value_cache[layer_idx] - key_states = key_states.to(self.key_cache[layer_idx].dtype) - value_states = value_states.to(self.value_cache[layer_idx].dtype) + if ( + current_key.dim() == 4 + and current_key.shape[-2] >= self.config.residual_length + and current_key.shape[-2] > self._get_quantized_length(layer_idx) + ): + # Quantize the older part, keep recent tokens in original precision + split_idx = current_key.shape[-2] - self.config.residual_length + + # Get the part to quantize + key_to_quantize = current_key[:, :, :split_idx, :].contiguous() + value_to_quantize = current_value[:, :, :split_idx, :].contiguous() + + # Quantize and store + self._quantized_key_cache[layer_idx] = self._quantize(key_to_quantize, axis=self.config.axis_key) + self._quantized_value_cache[layer_idx] = self._quantize(value_to_quantize, axis=self.config.axis_value) + + # Keep only the recent tokens in original precision + cache.key_cache[layer_idx] = current_key[:, :, split_idx:, :] + cache.value_cache[layer_idx] = current_value[:, :, split_idx:, 
:] + + # Return the full tensors for this update + if self._quantized_key_cache[layer_idx].numel() > 0: + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + full_key = torch.cat([dequant_key, cache.key_cache[layer_idx]], dim=-2) + full_value = torch.cat([dequant_value, cache.value_cache[layer_idx]], dim=-2) + return full_key, full_value + + return key_tensors, value_tensors + + def _get_quantized_length(self, layer_idx: int) -> int: + """Get the length of quantized cache for a layer.""" + if layer_idx < len(self._quantized_key_cache) and self._quantized_key_cache[layer_idx].numel() > 0: + # This would depend on the specific quantization implementation + return ( + self._quantized_key_cache[layer_idx].shape[-2] + if hasattr(self._quantized_key_cache[layer_idx], "shape") + else 0 + ) + return 0 - if layer_idx == 0: - # Update seen tokens. - # TODO(gante): Remove this. - self._seen_tokens += key_states.shape[-2] + def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: + """Quantize a tensor - to be implemented by specific quantization backends.""" + raise NotImplementedError("Quantization backend must implement _quantize method") - # Always there. - k_out = self.key_cache[0] - v_out = self.value_cache[0] - else: - # Wait for prefetch stream. 
- if self._prefetch_stream is not None: - torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream) + def _dequantize(self, tensor: torch.Tensor) -> torch.Tensor: + """Dequantize a tensor - to be implemented by specific quantization backends.""" + raise NotImplementedError("Quantization backend must implement _dequantize method") - k_out = self._device_key_cache[layer_idx & 1] - v_out = self._device_value_cache[layer_idx & 1] - self._prefetch_layer(layer_idx + 1) +class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor): + """ + Quantized cache processor that uses `quanto` as a backend to perform quantization. + Current implementation supports `int2` and `int4` dtypes only. + """ - cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) + def init(self, cache: "Cache", **kwargs) -> None: + """Initialize the quanto quantization processor.""" + super().init(cache, **kwargs) - # Copy the values to the offloaded device as well. - if layer_idx == 0: - self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) - self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does - # explicitly an in-place operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # Copy the values to the offloaded device as well. 
- if layer_idx != 0: - cache_position = cache_position.to(self.offload_device) - key_states = key_states.to(self.offload_device) - value_states = value_states.to(self.offload_device) - - try: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. - self.key_cache[layer_idx][:, :, cache_position] = key_states - self.value_cache[layer_idx][:, :, cache_position] = value_states + if is_optimum_quanto_available(): + optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) + if optimum_quanto_version <= version.parse("0.2.5"): + raise ImportError( + f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCacheProcessor`. Detected version {optimum_quanto_version}." + ) + from optimum.quanto import MaxOptimizer, qint2, qint4 - return k_out, v_out + if self.config.nbits not in [2, 4]: + raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.config.nbits}") - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" + if self.config.axis_key not in [0, -1]: + raise ValueError( + f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.config.axis_key}" + ) - # TODO(gante): Remove this. 
- return self._seen_tokens + if self.config.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.config.axis_value}" + ) - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states.""" + self.qtype = qint4 if self.config.nbits == 4 else qint2 + self.optimizer = MaxOptimizer() - return self.max_cache_len + def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: + """Quantize tensor using quanto backend.""" + if is_optimum_quanto_available(): + from optimum.quanto import quantize_weight - def reset(self) -> None: - """Resets the cache values while preserving the objects.""" + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.config.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.config.q_group_size) + return qtensor - # For backwards compatibility. - # TODO(gante): Remove this. - self._seen_tokens = 0 + def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: + """Dequantize tensor using quanto backend.""" + return qtensor.dequantize() - # Zero out cache. - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address. - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - @property - def seen_tokens(self) -> int: - # For backwards compatibility. - # TODO(gante): Remove this. - return self._seen_tokens +class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): + """ + Quantized cache processor that uses `HQQ` as a backend to perform quantization. + Current implementation supports `int2`, `int4`, `int8` dtypes. + """ - def _create_key_value_cache_tensors( - self, shape: tuple[int, ...], device: torch.device - ) -> tuple[torch.Tensor, torch.Tensor]: - """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static - addresses for non-CPU tensors. 
+ def init(self, cache: "Cache", **kwargs) -> None: + """Initialize the HQQ quantization processor.""" + super().init(cache, **kwargs) - Args: - shape (`tuple[int, ...]`): Shape. - device (`torch.device`): Device. + if self.config.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.config.nbits}" + ) - Returns: - Key and value cache tensors as a tuple. - """ + if self.config.axis_key not in [0, 1]: + raise ValueError( + f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.config.axis_key}" + ) - is_cpu_device = device == torch.device("cpu") + if self.config.axis_value not in [0, 1]: + raise ValueError( + f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.config.axis_value}" + ) - key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) - value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) + self.quantizer = HQQQuantizer - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. 
- torch._dynamo.mark_static_address(key_cache) - torch._dynamo.mark_static_address(value_cache) + def _quantize(self, tensor: torch.Tensor, axis: int) -> tuple[torch.Tensor, dict]: + """Quantize tensor using HQQ backend.""" + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.config.device, + compute_dtype=self.config.compute_dtype, + nbits=self.config.nbits, + group_size=self.config.q_group_size, + ) + meta["compute_dtype"] = self.config.compute_dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.config.device) # Move to device and cast to dtype + meta["scale"] = meta["scale"].to(qtensor.device) + meta["zero"] = meta["zero"].to(qtensor.device) + return qtensor, meta - return key_cache, value_cache + def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tensor: + """Dequantize tensor using HQQ backend.""" + quant_tensor, meta = qtensor_and_meta + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor - def _prefetch_layer(self, layer_idx: int) -> None: - """Prefetch a layer to the device. Needs to be called in order of layer indices.""" - # Don't fetch layers that do not exist. - if layer_idx >= len(self.key_cache): - return +class QuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - # Alternate between two on-device caches. - if self._prefetch_stream is not None: - with torch.cuda.stream(self._prefetch_stream): - self._prefetch_layer_in_context(layer_idx) - else: - self._prefetch_layer_in_context(layer_idx) + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. 
When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) + super().__init__(processors=processors) + + +class SinkCache(Cache): + """ + It is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache. + See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for + general `custom_generate` usage. + """ + + # TODO (joao, manuel): Remove this class in v4.59.0 + def __init__(self, **kwargs) -> None: + raise NotImplementedError( + "`SinkCache` has been moved as a `custom_generate` repository on the Hub: " + "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
+ ) - self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) - self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True) + +# TODO (manuel, joao): remove this class, it is here only for backwards compatibility +# PEP 562: Lazy loading for deprecated location of MambaCache +def __getattr__(name: str) -> Any: + if name == "MambaCache": + warnings.warn( + ( + "Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed " + "in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead." + ), + FutureWarning, + stacklevel=2, + ) + + class MambaCache: + """ + Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed + in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead. + + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if `layer_device_map` 
is not passed. It is used + # when the cache is initialized in the forward pass (e.g. Mamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() + + return MambaCache + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 76587659371b..baae6690a94d 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -54,7 +54,6 @@ HQQQuantizedCache, HybridCache, HybridChunkedCache, - MambaCache, OffloadedHybridCache, OffloadedStaticCache, QuantizedCacheConfig, @@ -67,6 +66,8 @@ CACHE_CONFIG_MAPPING["quantized"] = QuantizedCacheConfig CACHE_CONFIG_MAPPING["static"] = StaticCacheConfig + CACHE_CONFIG_MAPPING["sliding_window"] = StaticCacheConfig + CACHE_CONFIG_MAPPING["hybrid"] = StaticCacheConfig NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache, "offloaded_static": OffloadedStaticCache, @@ -75,12 +76,9 @@ "hybrid_chunked": HybridChunkedCache, "offloaded_hybrid": OffloadedHybridCache, 
"offloaded_hybrid_chunked": OffloadedHybridCache, - "mamba": MambaCache, } QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} - ALL_CACHE_IMPLEMENTATIONS = ( - list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded", "dynamic"] - ) + ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + ["offloaded", "dynamic", "quantized"] class GenerationMode(ExplicitEnum): @@ -186,7 +184,6 @@ class GenerationConfig(PushToHubMixin): - `"offloaded_static"`: [`OffloadedStaticCache`] - `"sliding_window"`: [`SlidingWindowCache`] - `"hybrid"`: [`HybridCache`] - - `"mamba"`: [`MambaCache`] - `"quantized"`: [`QuantizedCache`] If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bb15454c7f5b..74cdbbc9d982 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1929,9 +1929,8 @@ def _get_cache( or isinstance( cache_to_check, (HybridChunkedCache, OffloadedHybridCache) ) # due to internal slicing, we always re-init + or cache_to_check.max_cache_len < max_cache_len ) - if cache_implementation != "mamba": - need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len if requires_cross_attention_cache and hasattr(self, "_cache"): need_new_cache = ( @@ -1966,13 +1965,14 @@ def _get_cache( def _supports_default_dynamic_cache(self) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. 
- This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which - uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in + This is mostly the same as `_supports_cache_class` attribute, but add exception for `Mamba` models which + use their own caches and do not need to initialize the Cache in advance in order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). + for mamba-based models). """ return ( self._supports_cache_class + and "mamba" not in self.__class__.__name__.lower() and "jamba" not in self.__class__.__name__.lower() and "zamba" not in self.__class__.__name__.lower() and "bamba" not in self.__class__.__name__.lower() @@ -2023,7 +2023,7 @@ def _prepare_cache_for_generation( if generation_config.use_cache is False: return - # Quick escape route 3: model that only supports legacy caches = nothing to prepare + # Quick escape route 3: model that only supports legacy caches or models that supply it in `prepare_inputs_for_generation` (mamba, zamba, ...) if not self._supports_default_dynamic_cache(): if generation_config.cache_implementation is not None: warnings.warn( @@ -5199,106 +5199,6 @@ def _ranking_fast( return selected_idx -def _split(data, full_batch_size: int, split_size: int): - """ - Takes care of three cases: - 1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim - 2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and - return a list of tuples - 3. data is a tuple of tuples, e.g. past_key_values. 
Keep the tuple as it is and split each tuple in it and - return a list of tuples of tuples - (see documentation of ModelOutput) - """ - if data is None: - return [None] * (full_batch_size // split_size) - if isinstance(data, torch.Tensor): - return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] - # New cache format - elif isinstance(data, DynamicCache) or ( - isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) - ): - return data.batch_split(full_batch_size, split_size) - elif isinstance(data, tuple): - # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) - if isinstance(data[0], tuple): - return [ - tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) - for i in range(0, full_batch_size, split_size) - ] - - else: - return [ - tuple(sub_tensor[i : i + split_size] for sub_tensor in data) - for i in range(0, full_batch_size, split_size) - ] - else: - raise TypeError(f"Unexpected attribute type: {type(data)}") - - -def _split_model_inputs( - model_input: Union[ModelOutput, dict], split_size: int, full_batch_size: int, config: PretrainedConfig -) -> list[Union[ModelOutput, dict]]: - """ - Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split - size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from - previous forward pass. 
- """ - # Edge case: if model_input is None, return a list of Nones - # this happens with Whisper where encoder_outputs is None - if model_input is None: - return [model_input] * (full_batch_size // split_size) - # Infer the class from the object - model_output_cls = type(model_input) - if (full_batch_size % split_size) != 0: - raise ValueError("`full_batch_size` must be divisible by `split_size`") - - if split_size > full_batch_size: - raise ValueError("`split_size` must be smaller or equal to `full_batch_size`") - - # Helper function to split tensors or tuples of tensors - - # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them - keys = ( - model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys() - ) - # We only keep keys that are in the model_input - keys = [k for k in keys if k in model_input] - # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a - # ModelOutput object. 
- # bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] - - # we split the tensors and tuples of tensors - data_split_list = [ - {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys} - for i in range(full_batch_size // split_size) - ] - # bool values are the same and replicated for each split - bool_data = {k: model_input[k] for k in bool_keys} - # encoder_outputs is a ModelOutput object and should be split by its own - if "encoder_outputs" in model_input: - encoder_outputs_split = _split_model_inputs( - model_input["encoder_outputs"], split_size, full_batch_size, config.get_text_config() - ) - data_split_list = [ - {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) - ] - # logits_to_keep should be replicated for each split, similar to bool values - if "logits_to_keep" in model_input: - data_split_list = [ - {**data_split, "logits_to_keep": model_input["logits_to_keep"]} for data_split in data_split_list - ] - - # Convert each dictionary in the list to an object of the inferred class - split_model_inputs: list[Union[ModelOutput, dict]] = [ - model_output_cls(**data_split, **bool_data) for data_split in data_split_list - ] - - return split_model_inputs - - def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput: """ Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. 
The function infers the @@ -5323,11 +5223,6 @@ def _concat(data): return None if isinstance(data[0], torch.Tensor): return torch.cat(data, dim=0) - # New cache format - elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data) - elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index eb75d8f2b80a..fe066898d5ce 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -65,7 +65,7 @@ logger = logging.get_logger(__name__) -class FalconHybridMambaAttentionDynamicCache(DynamicCache): +class FalconHybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -79,6 +79,10 @@ class FalconHybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ + key_cache = None + value_cache = None + is_compileable = False + def __init__( self, config: FalconH1Config, diff --git a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py index 4099920f4028..f7e7719b37f9 100644 --- a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py @@ -1,5 +1,11 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_falcon_mamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,15 +18,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""FALCONMAMBA configuration""" import math from ...configuration_utils import PretrainedConfig -from ...utils import logging - - -logger = logging.get_logger(__name__) class FalconMambaConfig(PretrainedConfig): @@ -28,7 +29,7 @@ class FalconMambaConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the FALCON_MAMBA - [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture. + [state-spaces/falcon_mamba-2.8b](https://huggingface.co/state-spaces/falcon_mamba-2.8b) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -79,10 +80,12 @@ class FalconMambaConfig(PretrainedConfig): Whether or not to rescale `out_proj` weights when initializing. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. 
- use_mambapy (`bool`, *optional*, defaults to `False`): + use_falcon_mambapy (`bool`, *optional*, defaults to `False`): Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited. mixer_rms_eps (`float`, *optional*, defaults to 1e-06): The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states. + + Example: ```python @@ -125,10 +128,11 @@ def __init__( time_step_floor=1e-4, rescale_prenorm_residual=False, use_cache=True, - use_mambapy=False, + use_falcon_mambapy=False, mixer_rms_eps=1e-6, **kwargs, ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.state_size = state_size @@ -153,10 +157,8 @@ def __init__( self.rescale_prenorm_residual = rescale_prenorm_residual self.residual_in_fp32 = residual_in_fp32 self.use_cache = use_cache - self.use_mambapy = use_mambapy + self.use_falcon_mambapy = use_falcon_mambapy self.mixer_rms_eps = mixer_rms_eps - super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) - __all__ = ["FalconMambaConfig"] diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 426e557d9d3c..07fb3c1d849c 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# modular_falcon_mamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. # @@ -12,19 +18,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch FALCONMAMBA model.""" import math from dataclasses import dataclass from typing import Any, Optional, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import MambaCache +from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel @@ -35,27 +39,127 @@ logger = logging.get_logger(__name__) -if is_mambapy_available(): - from mambapy.pscan import pscan -else: - pscan = None -if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn - from mamba_ssm.ops.triton.selective_state_update import selective_state_update +class FalconMambaCache: + """ + Cache for falcon_mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. 
Should be the same as the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache + + >>> model = FalconMambaForCausalLM.from_pretrained("state-spaces/falcon_mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/falcon_mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + FalconMambaCache() + ``` + """ - from ...kernels.falcon_mamba import mamba_inner_fn -else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + is_compileable = True -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + 
ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. FalconMamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() + + +def is_fast_path_available(): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + 
else: + causal_conv1d_update, causal_conv1d_fn = None, None + + return ( + all((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)), + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, + ) def rms_forward(hidden_states, variance_epsilon=1e-6): @@ -107,7 +211,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): self.activation = config.hidden_act self.act = ACT2FN[config.hidden_act] - self.use_mambapy = config.use_mambapy + self.use_falcon_mambapy = config.use_falcon_mambapy # projection of the input hidden states self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias) @@ -126,6 +230,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + self.warn_slow_implementation() # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here self.register_buffer( "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False @@ -135,6 +240,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): ) self.rms_eps = config.mixer_rms_eps + def warn_slow_implementation(self): if not is_fast_path_available: if self.use_mambapy: if is_mambapy_available(): @@ -157,10 +263,21 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + 
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + else: + causal_conv1d_update, causal_conv1d_fn = None, None + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -269,13 +386,17 @@ def cuda_kernels_forward( contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states - def slow_forward( - self, + # fmt: off + def slow_forward(self, input_states, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): + if is_mambapy_available(): + from mambapy.pscan import pscan + else: + pscan = None batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated MLP's linear projection @@ -344,7 +465,7 @@ def slow_forward( deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) - if self.use_mambapy and self.training and cache_params is None: + if self.use_falcon_mambapy and self.training and cache_params is None: hs = pscan( discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2) ) # [batch, seq_len, intermediate_size, ssm_state_size] @@ -371,12 +492,12 @@ def slow_forward( # 4. 
Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] return contextualized_states + # fmt: on - # Copied from transformers.models.mamba.modeling_mamba.MambaMixer.forward def forward( self, hidden_states, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): @@ -385,7 +506,6 @@ def forward( return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) -# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba class FalconMambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -395,17 +515,15 @@ def __init__(self, hidden_size, eps=1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def extra_repr(self): - return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" - - # Ignore copy def forward(self, hidden_states): return self.weight.to(hidden_states.device) * rms_forward( hidden_states, variance_epsilon=self.variance_epsilon ) + def extra_repr(self): + return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" + -# Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache class FalconMambaBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() @@ -418,7 +536,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): @@ -435,7 +553,6 @@ def forward( @auto_docstring -# Copied from transformers.models.mamba.modeling_mamba.MambaPreTrainedModel with Mamba->FalconMamba class FalconMambaPreTrainedModel(PreTrainedModel): config_class 
= FalconMambaConfig base_model_prefix = "backbone" @@ -494,13 +611,12 @@ def _init_weights(self, module): @dataclass @auto_docstring( custom_intro=""" - Class for the FALCONMAMBA model outputs. + Class for the FALCON_MAMBA model outputs. """ ) -# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->FALCONMAMBA,Mamba->FalconMamba,FalconMambaCache->MambaCache class FalconMambaOutput(ModelOutput): r""" - cache_params (`MambaCache`): + cache_params (`FalconMambaCache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -508,7 +624,7 @@ class FalconMambaOutput(ModelOutput): """ last_hidden_state: Optional[torch.FloatTensor] = None - cache_params: Optional[MambaCache] = None + cache_params: Optional[FalconMambaCache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None @@ -518,14 +634,13 @@ class FalconMambaOutput(ModelOutput): Base class for causal language model (or autoregressive) outputs. """ ) -# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->FalconMamba,FalconMambaCache->MambaCache class FalconMambaCausalLMOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cache_params (`MambaCache`): + cache_params (`FalconMambaCache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. 
@@ -534,7 +649,7 @@ class FalconMambaCausalLMOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None - cache_params: Optional[MambaCache] = None + cache_params: Optional[FalconMambaCache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None @@ -551,8 +666,15 @@ def __init__(self, config): self.gradient_checkpointing = False self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) self.post_init() + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + def get_input_embeddings(self): return self.embeddings @@ -564,7 +686,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -572,7 +694,7 @@ def forward( attention_mask: Optional[torch.LongTensor] = None, ) -> Union[tuple, FalconMambaOutput]: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`FalconMambaCache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
use_cache (`bool`, *optional*): @@ -595,7 +717,7 @@ def forward( if use_cache: if cache_params is None: - cache_params = MambaCache( + cache_params = FalconMambaCache( self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype ) cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) @@ -610,6 +732,7 @@ def forward( ) else: cache_params = None + hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: @@ -640,11 +763,10 @@ def forward( @auto_docstring( custom_intro=""" - The FALCONMAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input + The FALCON_MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). """ ) -# Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -691,14 +813,21 @@ def prepare_inputs_for_generation( input_ids, inputs_embeds=None, use_cache=None, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, **kwargs, ): # Overwritten -- uses `cache_params` as opposed to `past_key_values` + cache_params_not_initialized = cache_params is None if use_cache: + if cache_params_not_initialized: + max_batch_size = inputs_embeds.size(0) if inputs_embeds is not None else input_ids.size(0) + cache_params = FalconMambaCache( + self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype + ) + cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) # `cache_position` should have been initialized in `generate` if 
cache_position is None: raise ValueError( @@ -719,7 +848,7 @@ def prepare_inputs_for_generation( # the length of `cache_params.conv_states`, which is `config.conv_kernel` cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) - if inputs_embeds is not None and cache_params is None: + if inputs_embeds is not None and cache_params_not_initialized: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} @@ -740,7 +869,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, labels: Optional[torch.LongTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -749,7 +878,7 @@ def forward( **kwargs, # for now we need this for generation ) -> Union[tuple, FalconMambaCausalLMOutput]: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`FalconMambaCache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -798,4 +927,4 @@ def forward( ) -__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel"] +__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"] diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py new file mode 100644 index 000000000000..7001cf982d0e --- /dev/null +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -0,0 +1,537 @@ +# coding=utf-8 +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch FALCONMAMBA model.""" + +from typing import Optional + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...utils import auto_docstring, logging +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available +from ..mamba.configuration_mamba import MambaConfig +from ..mamba.modeling_mamba import ( + MambaBlock, + MambaCache, + MambaCausalLMOutput, + MambaForCausalLM, + MambaMixer, + MambaModel, + MambaOutput, + MambaPreTrainedModel, + MambaRMSNorm, +) + + +logger = logging.get_logger(__name__) + + +def is_fast_path_available(): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + else: + causal_conv1d_update, causal_conv1d_fn = None, None + + return ( + all((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)), + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, + ) + + +( + is_fast_path_available, + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, +) = is_fast_path_available() + + +class FalconMambaConfig(MambaConfig): + """ + This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the FALCON_MAMBA + [state-spaces/falcon_mamba-2.8b](https://huggingface.co/state-spaces/falcon_mamba-2.8b) architecture. 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50280): + Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`FalconMambaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 16): shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the model. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 0): + The id of the end of sentence token in the vocabulary. + expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.1): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + residual_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether or not residuals should be in `float32`. 
If set to `False` residuals will keep the same `dtype` as the rest of the model + time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + time_step_scale (`float`, *optional*, defaults to 1.0): + Scale used to scale `dt_proj.bias`. + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_init_scheme (`str`, *optional*, defaults to `"random"`): + Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]` + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): + Whether or not to rescale `out_proj` weights when initializing. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the cache should be used. + use_falcon_mambapy (`bool`, *optional*, defaults to `False`): + Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited. + mixer_rms_eps (`float`, *optional*, defaults to 1e-06): + The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
+ + + Example: + + ```python + >>> from transformers import FalconMambaConfig, FalconMambaModel + + >>> # Initializing a FalconMamba configuration + >>> configuration = FalconMambaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = FalconMambaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + vocab_size=50280, + hidden_size=768, + state_size=16, + num_hidden_layers=32, + layer_norm_epsilon=1e-5, + pad_token_id=0, + bos_token_id=0, + eos_token_id=0, + expand=2, + conv_kernel=4, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=True, + time_step_rank="auto", + time_step_scale=1.0, + time_step_min=0.001, + time_step_max=0.1, + time_step_init_scheme="random", + time_step_floor=1e-4, + rescale_prenorm_residual=False, + use_cache=True, + use_falcon_mambapy=False, + mixer_rms_eps=1e-6, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + state_size=state_size, + num_hidden_layers=num_hidden_layers, + layer_norm_epsilon=layer_norm_epsilon, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + expand=expand, + conv_kernel=conv_kernel, + use_bias=use_bias, + use_conv_bias=use_conv_bias, + hidden_act=hidden_act, + initializer_range=initializer_range, + residual_in_fp32=residual_in_fp32, + time_step_rank=time_step_rank, + time_step_scale=time_step_scale, + time_step_min=time_step_min, + time_step_max=time_step_max, + time_step_init_scheme=time_step_init_scheme, + time_step_floor=time_step_floor, + rescale_prenorm_residual=rescale_prenorm_residual, + use_cache=use_cache, + use_falcon_mambapy=use_falcon_mambapy, + **kwargs, + ) + self.mixer_rms_eps = mixer_rms_eps + + +class FalconMambaCache(MambaCache): + pass + + +def rms_forward(hidden_states, variance_epsilon=1e-6): + """ + Calculates simple RMSNorm with no learnable 
weights. `MambaRMSNorm` will + leverage this in order to multiply the final result with the RMSNorm weight + + Args: + hidden_states (`torch.Tensor`): + Hidden states to normalize + variance_epsilon (`float`): + The eps value to add in the square root scaling factor + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) + return hidden_states.to(input_dtype) + + +class FalconMambaMixer(MambaMixer): + def warn_slow_implementation(self): + if not is_fast_path_available: + if self.use_mambapy: + if is_mambapy_available(): + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + raise ImportError( + "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py." + ) + else: + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py." 
+ ) + + def __init__(self, config: FalconMambaConfig, layer_idx: int): + super().__init__(config, layer_idx) + # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here + self.register_buffer( + "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False + ) + self.register_buffer( + "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False + ) + self.rms_eps = config.mixer_rms_eps + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + else: + causal_conv1d_update, causal_conv1d_fn = None, None + + # 1. 
Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias if self.use_conv_bias else None, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias.float() if self.use_bias else None, + -torch.exp(self.A_log.float()), + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + b_rms_weight=self.b_c_rms, + c_rms_weight=self.b_c_rms, + dt_rms_weight=self.dt_rms, + b_c_dt_rms_eps=self.rms_eps, + ) + + else: + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if cache_params is not None and cache_position[0] > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. 
input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module + # at the price of a small overhead. + if hasattr(self.config, "_pre_quantization_dtype"): + discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2) + else: + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if cache_params is not None and cache_position[0] > 0: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + def slow_forward( + self, + input_states, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + if is_mambapy_available(): + from mambapy.pscan import pscan + else: + pscan = None + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + # use `cache_position.shape[0]` to check whether we are in prefill + # stage, it's equivalent to check `cache_position[0] == 0`, which + # breaks dynamo fullgraph constraints + if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size: + conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + + cache_params.update_conv_state(self.layer_idx, conv_state, cache_position) + hidden_states = self.act( + self.conv1d(hidden_states)[..., :seq_len] + ) # [batch, intermediate_size, seq_len] + else: + conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position) + conv_state = conv_state.to(self.conv1d.weight.device) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = ( + self.act(hidden_states).to(dtype).unsqueeze(-1) + ) # [batch, intermediate_size, 1] : decoding + else: + ssm_state = torch.zeros( + (batch_size, 
self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose( + 1, 2 + ) # [batch, intermediate_size, seq_len] + + # 3.b. 
Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] + discrete_A = torch.exp( + A[None, :, None, :] * discrete_time_step[:, :, :, None] + ) # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_B = ( + discrete_time_step[:, :, :, None] * B[:, None, :, :].float() + ) # [batch, intermediate_size, seq_len, ssm_state_size] + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + if self.use_falcon_mambapy and self.training and cache_params is None: + hs = pscan( + discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2) + ) # [batch, seq_len, intermediate_size, ssm_state_size] + scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] + scan_output = scan_output + hidden_states * self.D[None, :, None] + scan_output = scan_output * self.act(gate) + else: + scan_outputs = [] + for i in range(seq_len): + ssm_state = ( + discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] + ) # [batch, intermediate_size, ssm_state] + scan_output = torch.matmul( + ssm_state.to(dtype), C[:, i, :].unsqueeze(-1) + ) # [batch, intermediate_size, 1] + scan_outputs.append(scan_output[:, :, 0]) + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len] + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = scan_output * self.act(gate) + + if cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + return contextualized_states + + +class FalconMambaRMSNorm(MambaRMSNorm): + def forward(self, hidden_states): + return self.weight.to(hidden_states.device) * rms_forward( + hidden_states, variance_epsilon=self.variance_epsilon + ) + + +class FalconMambaBlock(MambaBlock): + pass + + +@auto_docstring +class FalconMambaPreTrainedModel(MambaPreTrainedModel): + pass + + +class FalconMambaOutput(MambaOutput): + pass + + +class FalconMambaCausalLMOutput(MambaCausalLMOutput): + pass + + +class FalconMambaModel(MambaModel): + pass + + +class FalconMambaForCausalLM(MambaForCausalLM): + pass + + +__all__ = [ + "FalconMambaForCausalLM", + "FalconMambaModel", + "FalconMambaPreTrainedModel", + "FalconMambaCache", + "FalconMambaConfig", +] diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 0b93d4484c9f..44e29c4f0f91 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -182,7 +182,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionDynamicCache(DynamicCache): +class HybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -196,6 +196,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config, batch_size, dtype=torch.float16, device=None): super().__init__() self.dtype = dtype diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index f2347833db6c..5d63c166372b 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -24,7 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import MambaCache +from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel @@ -39,25 +39,137 @@ logger = logging.get_logger(__name__) -if is_mambapy_available(): - from mambapy.pscan import pscan -else: - pscan = None -if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None +def is_fast_path_available(): + if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None + if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + else: + causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) + return ( + 
all((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)), + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, + ) + + +( + is_fast_path_available, + selective_state_update, + selective_scan_fn, + mamba_inner_fn, + causal_conv1d_update, + causal_conv1d_fn, +) = is_fast_path_available() + + +class MambaCache: + """ + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if 
`layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. Mamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() class MambaMixer(nn.Module): @@ -109,6 +221,9 @@ def __init__(self, config: MambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + self.warn_slow_implementation() + + def warn_slow_implementation(self): if not is_fast_path_available: if self.use_mambapy: if is_mambapy_available(): @@ -231,6 +346,10 @@ def cuda_kernels_forward( # fmt: off def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None): + if is_mambapy_available(): + from mambapy.pscan import pscan + else: + pscan = None batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. 
Gated MLP's linear projection @@ -638,7 +757,12 @@ def prepare_inputs_for_generation( ): # Overwritten -- uses `cache_params` as opposed to `past_key_values` + cache_params_not_initialized = cache_params is None if use_cache: + if cache_params_not_initialized: + max_batch_size = inputs_embeds.size(0) if inputs_embeds is not None else input_ids.size(0) + cache_params = MambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype) + cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) # `cache_position` should have been initialized in `generate` if cache_position is None: raise ValueError( @@ -659,7 +783,7 @@ def prepare_inputs_for_generation( # the length of `cache_params.conv_states`, which is `config.conv_kernel` cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) - if inputs_embeds is not None and cache_params is None: + if inputs_embeds is not None and cache_params_not_initialized: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} @@ -738,4 +862,4 @@ def forward( ) -__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel"] +__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel", "MambaCache"] diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 1f663462d5e1..69737bb7c4b5 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -959,7 +959,12 @@ def prepare_inputs_for_generation( ): # Overwritten -- uses `cache_params` as opposed to `past_key_values` + cache_params_not_initialized = cache_params is None if use_cache: + if cache_params_not_initialized: + max_batch_size = inputs_embeds.size(0) if inputs_embeds is not None else input_ids.size(0) + cache_params = Mamba2Cache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype) + 
cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) # `cache_position` should have been initialized in `generate` if cache_position is None: raise ValueError( @@ -979,7 +984,7 @@ def prepare_inputs_for_generation( # the length of `cache_params.conv_states`, which is `config.conv_kernel` cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) - if inputs_embeds is not None and cache_params is None: + if inputs_embeds is not None and cache_params_not_initialized: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 6d77801d4ef6..b64d0107395e 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -93,7 +93,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class ZambaHybridDynamicCache(DynamicCache): +class ZambaHybridDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -107,8 +107,13 @@ class ZambaHybridDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.dtype = dtype + self.is_compileable = False self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba self.intermediate_size = config.mamba_expand * config.hidden_size diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index ecd0abcb0263..1c7a489784f3 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -97,7 +97,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Zamba2HybridDynamicCache(DynamicCache): +class Zamba2HybridDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -111,6 +111,10 @@ class Zamba2HybridDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" + key_cache = None + value_cache = None + is_compileable = False + def __init__( self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None ): @@ -1387,7 +1391,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if past_key_values and not past_key_values.has_previous_state: + if past_key_values is not None and not past_key_values.has_previous_state: past_key_values.has_previous_state = True output = BaseModelOutputWithPast( diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index a89ab2729f21..990d725cd0f0 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1142,7 +1142,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if past_key_values and not past_key_values.has_previous_state: + if past_key_values is not None and not past_key_values.has_previous_state: past_key_values.has_previous_state = True output = BaseModelOutputWithPast( diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e865237485a6..7026bf1697c8 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -44,13 +44,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class MambaCache(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class OffloadedCache(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index e59787fb8c65..ad63326ac0d4 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -26,7 +26,6 @@ require_torch_accelerator, 
require_torch_large_accelerator, require_torch_multi_accelerator, - require_torch_multi_gpu, slow, torch_device, ) @@ -41,10 +40,10 @@ import torch from transformers import ( + FalconMambaCache, FalconMambaForCausalLM, FalconMambaModel, ) - from transformers.cache_utils import MambaCache # Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba @@ -312,31 +311,6 @@ def assertInterval(self, member, container, msg=None): def test_config(self): self.config_tester.run_common_tests() - @require_torch_multi_gpu - def test_multi_gpu_data_parallel_forward(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # some params shouldn't be scattered by nn.DataParallel - # so just remove them if they are present. - blacklist_non_batched_params = ["cache_params"] - for k in blacklist_non_batched_params: - inputs_dict.pop(k, None) - - # move input tensors to cuda:O - for k, v in inputs_dict.items(): - if torch.is_tensor(v): - inputs_dict[k] = v.to(0) - - for model_class in self.all_model_classes: - model = model_class(config=config) - model.to(0) - model.eval() - - # Wrap model in nn.DataParallel - model = torch.nn.DataParallel(model) - with torch.no_grad(): - _ = model(**self._prepare_for_class(inputs_dict, model_class)) - def test_falcon_mamba_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_falcon_mamba_model(*config_and_inputs) @@ -396,7 +370,7 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, MambaCache): # MODIFIED PART START + if isinstance(tuple_object, FalconMambaCache): # MODIFIED PART START recursive_check(tuple_object.conv_states, dict_object.conv_states) recursive_check(tuple_object.ssm_states, dict_object.ssm_states) elif 
isinstance(tuple_object, (list, tuple)): # MODIFIED PART END @@ -443,6 +417,10 @@ def recursive_check(tuple_object, dict_object): dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + @unittest.skip("Mamba models do not support DDP.") + def test_multi_gpu_data_parallel_forward(self): + pass + @require_torch @require_torch_accelerator @@ -482,7 +460,9 @@ def test_generation_fp16(self): @require_bitsandbytes def test_generation_4bit(self): quantization_config = BitsAndBytesConfig(load_in_4bit=True) - model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config) + model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config).to( + torch_device + ) inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device) out = model.generate(**inputs, max_new_tokens=20, do_sample=False) @@ -498,6 +478,7 @@ def test_generation_torch_compile(self): inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device) out = model.generate(**inputs, max_new_tokens=20, do_sample=False) + print(self.tokenizer.batch_decode(out, skip_special_tokens=False)[0]) self.assertEqual( self.tokenizer.batch_decode(out, skip_special_tokens=False)[0], @@ -528,7 +509,7 @@ def test_batched_generation(self): inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device) model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.float16) - out = model.generate(**inputs, max_new_tokens=20) + out = model.generate(**inputs, max_new_tokens=20, do_sample=False) out = tok.batch_decode(out, skip_special_tokens=True) self.assertListEqual(out, EXPECTED_OUTPUT) @@ -538,7 +519,7 @@ def test_batched_generation(self): inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids")) inputs["inputs_embeds"] = inputs_embeds - out = 
model.generate(**inputs, max_new_tokens=20) + out = model.generate(**inputs, max_new_tokens=20, do_sample=False) out = tok.batch_decode(out, skip_special_tokens=True) EXPECTED_OUTPUTS = Expectations( diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 840493648ffc..3f87481b0703 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -20,7 +20,7 @@ from parameterized import parameterized from transformers import AutoTokenizer, MambaConfig, is_torch_available -from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device +from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -32,10 +32,10 @@ import torch from transformers import ( + MambaCache, MambaForCausalLM, MambaModel, ) - from transformers.models.mamba.modeling_mamba import MambaCache class MambaModelTester: @@ -279,31 +279,6 @@ def assertInterval(self, member, container, msg=None): def test_config(self): self.config_tester.run_common_tests() - @require_torch_multi_gpu - def test_multi_gpu_data_parallel_forward(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # some params shouldn't be scattered by nn.DataParallel - # so just remove them if they are present. 
- blacklist_non_batched_params = ["cache_params"] - for k in blacklist_non_batched_params: - inputs_dict.pop(k, None) - - # move input tensors to cuda:O - for k, v in inputs_dict.items(): - if torch.is_tensor(v): - inputs_dict[k] = v.to(0) - - for model_class in self.all_model_classes: - model = model_class(config=config) - model.to(0) - model.eval() - - # Wrap model in nn.DataParallel - model = torch.nn.DataParallel(model) - with torch.no_grad(): - _ = model(**self._prepare_for_class(inputs_dict, model_class)) - def test_mamba_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_mamba_model(*config_and_inputs) @@ -437,6 +412,10 @@ def test_dtype_mismatch_handled_in_cache(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size), ) + @unittest.skip("Mamba models do not support DDP.") + def test_multi_gpu_data_parallel_forward(self): + pass + @require_torch class MambaIntegrationTests(unittest.TestCase): @@ -532,11 +511,11 @@ def test_compile_mamba_cache(self): torch_device ) - output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba") + output = model.generate(input_ids, max_new_tokens=20) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") - output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba") + output = model.generate(input_ids, max_new_tokens=20) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 8c864f9b64f1..06bf5a381a40 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -63,8 +63,6 @@ TEST_CACHE_IMPLEMENTATIONS = [ cache_name for cache_name in ALL_CACHE_IMPLEMENTATIONS - # TODO (joao): 
Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`? - if cache_name != "mamba" # TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them if cache_name != "offloaded_hybrid" ] @@ -1119,3 +1117,44 @@ def test_hybrid_cache_sliding_mode(self): [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 4 failed", ) + + def test_dynamic_cache(self): + """Test DynamicCache with manually prefilled states and hardcoded assertions. + Scenario 1: prefill and update for one layer + prefill: [1.0, 2.0] + update pos 2: [1.0, 2.0, 3.0] + Scenario 2: prefill and update for two layers independently + """ + prefill = torch.tensor([1.0, 2.0])[None, None, :, None] + update3 = torch.tensor(3.0)[None, None, None, None] + update4 = torch.tensor(4.0)[None, None, None, None] + + # Scenario 1: prefill and update for one layer + cache = DynamicCache() + cache.update(prefill, prefill, 0) + cache.update(update3, update3, 0) + self.assertEqual(cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") + cache.update(update4, update4, 0) + self.assertEqual( + cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" + ) + + # Scenario 2: prefill and update for two layers independently + prefill1 = torch.tensor([10.0, 20.0])[None, None, :, None] + update3_1 = torch.tensor(30.0)[None, None, None, None] + update4_1 = torch.tensor(40.0)[None, None, None, None] + + cache = DynamicCache() + cache.update(prefill, prefill, 0) + cache.update(prefill1, prefill1, 1) + + cache.update(update3, update3, 0) + cache.update(update3_1, update3_1, 1) + cache.update(update4, update4, 0) + cache.update(update4_1, update4_1, 1) + self.assertEqual( + cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" + ) + self.assertEqual( + cache.key_cache[1][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0], "DynamicCache Scenario 2 layer 1 failed" + ) From 
04d7a0b12a597bb0a4d9a80c23193b5ae8357404 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 30 Jun 2025 15:51:22 +0200 Subject: [PATCH 2/4] fix quantized, add tests --- src/transformers/cache_utils.py | 302 +++++++++++++++------------ src/transformers/generation/utils.py | 7 +- tests/utils/test_cache_utils.py | 59 +++++- 3 files changed, 232 insertions(+), 136 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d05220cb62a5..97c2c765dc13 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -214,7 +214,9 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None: self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) def __repr__(self): - return f"{self.__class__.__name__}(K={self.key_cache}, V={self.value_cache})" + key_repr = "None" if self.key_cache is None else f"t({tuple(self.key_cache.shape)})" + value_repr = "None" if self.value_cache is None else f"t({tuple(self.value_cache.shape)})" + return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})" class Cache: @@ -860,7 +862,7 @@ def update( def get_seq_length(self, cache_position: Optional[torch.LongTensor] = None) -> int: """Returns the sequence length of the cached states.""" # TODO: deprecate this function in favor of `cache_position` - if self is None or self.key_cache is None: + if self is None or self.key_cache is None or self.key_cache.numel() == 0: return 0 return self.key_cache.shape[-2] @@ -1017,94 +1019,6 @@ def __init__(self, config: Optional[CacheConfig] = None) -> None: super().__init__(processors=processors, config=config) -class QuantoQuantizedCache(DynamicCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. 
- - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - - Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. - - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. 
- - Example: - - ```python - >>> # Run pip install quanto first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4) - >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - QuantoQuantizedCache() - ``` - """ - - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) - super(DynamicCache, self).__init__(processors=processors) - - -class HQQQuantizedCache(DynamicCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. 
Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - - Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. - - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. - - Example: - - ```python - >>> # Run pip install hqq first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) - >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HQQQuantizedCache() - ``` - """ - - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)]) - super(DynamicCache, self).__init__(processors=processors) - - class StaticLayer(CacheLayer): is_compileable = True @@ -2120,56 +2034,60 @@ def post_update( if layer_idx == 0: self._seen_tokens += key_tensors.shape[-2] - # Extend quantized cache if needed - while len(self._quantized_key_cache) <= layer_idx: - self._quantized_key_cache.append(torch.empty(0)) - self._quantized_value_cache.append(torch.empty(0)) + if len(cache.key_cache) < layer_idx: + raise ValueError("QuantizedCache does not support model usage where 
layers are skipped. Use DynamicCache.") - # Check if we need to quantize - if layer_idx < len(cache.key_cache): - current_key = cache.key_cache[layer_idx] - current_value = cache.value_cache[layer_idx] + # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer + # On the first forward pass, we quantize the whole prompt. + # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full. + is_prefill = self._get_quantized_length(layer_idx) == 0 - if ( - current_key.dim() == 4 - and current_key.shape[-2] >= self.config.residual_length - and current_key.shape[-2] > self._get_quantized_length(layer_idx) - ): - # Quantize the older part, keep recent tokens in original precision - split_idx = current_key.shape[-2] - self.config.residual_length + if is_prefill: + self._quantized_key_cache.append(self._quantize(key_tensors.contiguous(), axis=self.config.axis_key)) + self._quantized_value_cache.append(self._quantize(value_tensors.contiguous(), axis=self.config.axis_value)) - # Get the part to quantize - key_to_quantize = current_key[:, :, :split_idx, :].contiguous() - value_to_quantize = current_value[:, :, :split_idx, :].contiguous() + # Clear the residual cache + cache.key_cache[layer_idx] = torch.zeros( + 0, + dtype=key_tensors.dtype, + device=key_tensors.device, + ) + cache.value_cache[layer_idx] = torch.zeros( + 0, + dtype=value_tensors.dtype, + device=value_tensors.device, + ) + # On prefill, we return the original prompt + keys_to_return, values_to_return = key_tensors, value_tensors + else: + # Prepend the previously quantized cache + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2) + values_to_return = torch.cat([dequant_value, value_tensors], dim=-2) + if key_tensors.shape[-2] >= self.config.residual_length: # Quantize and store - 
self._quantized_key_cache[layer_idx] = self._quantize(key_to_quantize, axis=self.config.axis_key) - self._quantized_value_cache[layer_idx] = self._quantize(value_to_quantize, axis=self.config.axis_value) - - # Keep only the recent tokens in original precision - cache.key_cache[layer_idx] = current_key[:, :, split_idx:, :] - cache.value_cache[layer_idx] = current_value[:, :, split_idx:, :] - - # Return the full tensors for this update - if self._quantized_key_cache[layer_idx].numel() > 0: - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) - full_key = torch.cat([dequant_key, cache.key_cache[layer_idx]], dim=-2) - full_value = torch.cat([dequant_value, cache.value_cache[layer_idx]], dim=-2) - return full_key, full_value + self._quantized_key_cache[layer_idx] = self._quantize( + keys_to_return.contiguous(), axis=self.config.axis_key + ) + self._quantized_value_cache[layer_idx] = self._quantize( + values_to_return.contiguous(), axis=self.config.axis_value + ) - return key_tensors, value_tensors + # Clear the residual cache + cache.key_cache[layer_idx] = torch.zeros( + 0, + dtype=key_tensors.dtype, + device=key_tensors.device, + ) + cache.value_cache[layer_idx] = torch.zeros( + 0, + dtype=value_tensors.dtype, + device=value_tensors.device, + ) - def _get_quantized_length(self, layer_idx: int) -> int: - """Get the length of quantized cache for a layer.""" - if layer_idx < len(self._quantized_key_cache) and self._quantized_key_cache[layer_idx].numel() > 0: - # This would depend on the specific quantization implementation - return ( - self._quantized_key_cache[layer_idx].shape[-2] - if hasattr(self._quantized_key_cache[layer_idx], "shape") - else 0 - ) - return 0 + return keys_to_return, values_to_return def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: """Quantize a tensor - to be implemented by specific quantization backends.""" @@ -2227,6 +2145,12 @@ def 
_dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: """Dequantize tensor using quanto backend.""" return qtensor.dequantize() + def _get_quantized_length(self, layer_idx: int) -> int: + """Get the length of quantized cache for a layer.""" + if layer_idx < len(self._quantized_key_cache): + return self._quantized_key_cache[layer_idx].shape[-2] + return 0 + class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): """ @@ -2277,6 +2201,12 @@ def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tens tensor = self.quantizer.dequantize(quant_tensor, meta) return tensor + def _get_quantized_length(self, layer_idx: int) -> int: + """Get the length of quantized cache for a layer.""" + if layer_idx < len(self._quantized_key_cache): + return self._quantized_key_cache[layer_idx][0].shape[-2] + return 0 + class QuantizedCache(DynamicCache): """ @@ -2293,9 +2223,113 @@ class QuantizedCache(DynamicCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) + if cache_config.backend == "quanto": + processor = QuantoQuantizedCacheProcessor(cache_config) + elif cache_config.backend == "hqq": + processor = HQQQuantizedCacheProcessor(cache_config) + else: + raise ValueError(f"Unknown quantization backend `{cache_config.backend}`") + + processors = CacheProcessorList([processor]) super().__init__(processors=processors) + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is + # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx + # this part of code otherwise fails when used to verify attn_weight shape in some models + return self.processors[0]._seen_tokens if layer_idx == 0 else self.processors[0]._seen_tokens - 1 + + +class QuantoQuantizedCache(QuantizedCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values as a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + + Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
+ + Example: + + ```python + >>> # Run pip install quanto first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4) + >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + QuantoQuantizedCache() + ``` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) + Cache.__init__(self, processors=processors) + + +class HQQQuantizedCache(QuantizedCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. 
Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + + Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + + Example: + + ```python + >>> # Run pip install hqq first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) + >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HQQQuantizedCache() + ``` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)]) + Cache.__init__(self, processors=processors) + class SinkCache(Cache): """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 74cdbbc9d982..242b16195460 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -574,7 +574,12 @@ def prepare_inputs_for_generation( # function may be called outside of `generate`. 
Handle most use cases by creating `cache_position` on the fly # (this alternative is not as robust as calling `generate` and letting it create `cache_position`) elif cache_position is None: - past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_length = 0 + if past_key_values is not None: + if not isinstance(past_key_values, Cache): + past_length = past_key_values[0][0].shape[2] + elif hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() is not None: + past_length = past_key_values.get_seq_length() cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) # 2. Generic cache-dependent input preparation diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 06bf5a381a40..de110b872a7e 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -36,7 +36,7 @@ slow, torch_device, ) -from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal +from transformers.utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal if is_torch_available(): @@ -52,11 +52,13 @@ GenerationConfig, HybridCache, LlamaConfig, + QuantizedCache, SlidingWindowCache, StaticCache, convert_and_export_with_cache, pipeline, ) + from transformers.cache_utils import HQQQuantizedCacheProcessor, QuantoQuantizedCacheProcessor from transformers.integrations.executorch import export_with_dynamic_cache @@ -283,6 +285,61 @@ def test_cache_beam_search(self, cache_implementation): decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) self.assertListEqual(decoded, EXPECTED_GENERATION) + @parameterized.expand([("quanto"), ("HQQ")]) + def test_quantized_cache_generation(self, backend): + """Tests that QuantizedCache works as expected for both `quanto` and `hqq` backends.""" + if backend == "quanto": + if not is_optimum_quanto_available(): + self.skipTest("Quanto is 
not available") + axis_key, axis_value = 0, 0 + # This output is taken from a run with the same parameters, and is known to be correct + expected_generation = ["The cat's whiskers are also a sign of anxiety."] + elif backend == "HQQ": + if not is_hqq_available(): + self.skipTest("HQQ is not available") + axis_key, axis_value = 1, 1 + # HQQ has slightly different numerics + expected_generation = ["The cat's whiskers are also a sign of anxiety."] + else: + return + + inputs = self.tokenizer(["The cat"], return_tensors="pt").to(self.model.device) + + gen_out = self.model.generate( + **inputs, + do_sample=False, + max_new_tokens=10, + return_dict_in_generate=True, + cache_implementation="quantized", + cache_config={ + "backend": backend, + "nbits": 4, + "q_group_size": 16, + "residual_length": 4, + "axis_key": axis_key, + "axis_value": axis_value, + }, + disable_compile=True, + ) + + self.assertIsInstance(gen_out.past_key_values, QuantizedCache) + processor = gen_out.past_key_values.processors[0] + if backend == "quanto": + self.assertIsInstance(processor, QuantoQuantizedCacheProcessor) + elif backend == "hqq": + self.assertIsInstance(processor, HQQQuantizedCacheProcessor) + + decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) + self.assertListEqual(decoded, expected_generation) + + self.assertTrue(len(processor._quantized_key_cache) > 0) + + # Check that something is actually quantized + has_been_quantized = any( + (q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_key_cache + ) + self.assertTrue(has_been_quantized) + @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_extra_left_padding(self, cache_implementation): """Tests that adding extra left-padding does not affect the generation with the cache""" From 26c28af61632dca73ce0f784e7596928e773768a Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Mon, 30 Jun 2025 18:24:18 +0200 Subject: [PATCH 3/4] remove CacheProcessorList --- 
src/transformers/cache_utils.py | 103 ++++++++++---------------------- tests/utils/test_cache_utils.py | 2 +- 2 files changed, 32 insertions(+), 73 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 97c2c765dc13..233b72bcbfff 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -86,46 +86,7 @@ def post_update( return key_tensors, value_tensors -class CacheProcessorList(list): - """ - list of cache processors that can be applied to a cache. - """ - - def init(self, cache: "Cache", **kwargs) -> None: - """Initialize all processors in the list.""" - for processor in self: - processor.init(cache, **kwargs) - - def pre_update( - self, - cache: "Cache", - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Apply pre_update hook for all processors.""" - for processor in self: - key_states, value_states = processor.pre_update(cache, key_states, value_states, layer_idx, cache_kwargs) - return key_states, value_states - - def post_update( - self, - cache: "Cache", - key_tensors: torch.Tensor, - value_tensors: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Apply post_update hook for all processors.""" - for processor in self: - key_tensors, value_tensors = processor.post_update( - cache, key_tensors, value_tensors, layer_idx, cache_kwargs - ) - return key_tensors, value_tensors - - -class KVList: +class KVProxy: """Efficiently simulates layer-indexed key or value lists from a layered cache. This allows for BC access, e.g., cache.key_cache[idx] or cache.value_cache[idx].""" @@ -228,8 +189,8 @@ class Cache: config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility. 
If DDP-distributed cache data, must be an iterable of (key_states, value_states) tuples for each layer. - processors (`CacheProcessorList`, *optional*): - List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + processor (`CacheProcessor`, *optional*): + Cache processor to apply (e.g., quantization, offloading). pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(DynamicLayer,)`. Must be a tuple whose length divides the total number of layers. The pattern repeats to fill all layers. Examples: `(StaticLayer,)` for a @@ -258,13 +219,13 @@ def __init__( config_or_ddp_cache_data: Optional[ Union[PretrainedConfig, Iterable[tuple[torch.Tensor, torch.Tensor]]] ] = None, - processors: Optional[CacheProcessorList] = None, + processor: Optional[CacheProcessor] = None, pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, *args, **kwargs, ): self.layers: list[CacheLayer] = [] - self.processors = processors if processors is not None else CacheProcessorList() + self.processor = processor pattern_block = pattern_block or self.pattern_block or (DynamicLayer,) if isinstance(config_or_ddp_cache_data, PretrainedConfig): @@ -280,7 +241,8 @@ def __init__( assert pattern_block == (DynamicLayer,), "torch DDP is only supported for DynamicCache" for key_states, value_states in _distributed_cache_data: self.layers.append(DynamicLayer.from_kv(key_states, value_states)) - self.processors.init(self, **kwargs) + if self.processor is not None: + self.processor.init(self, **kwargs) return else: model_config = kwargs.pop("config", None) @@ -292,7 +254,8 @@ def __init__( layer = layer_type(self.config.to_layer(idx)) self.layers.append(layer) - self.processors.init(self, **kwargs) + if self.processor is not None: + self.processor.init(self, **kwargs) def grow_layers_to(self, layer_idx): while len(self.layers) <= layer_idx: @@ -302,14 +265,14 @@ def grow_layers_to(self, layer_idx): 
self.layers.append(next_layer_type()) @property - def key_cache(self) -> KVList: + def key_cache(self) -> KVProxy: """Returns a list-like object of key cache tensors indexed by layer.""" - return KVList(self.layers, "key") + return KVProxy(self.layers, "key") @property - def value_cache(self) -> KVList: + def value_cache(self) -> KVProxy: """Returns a list-like object of value cache tensors indexed by layer.""" - return KVList(self.layers, "value") + return KVProxy(self.layers, "value") def update( self, @@ -335,12 +298,16 @@ def update( Return: A tuple containing the updated key and value states. """ - key_states, value_states = self.processors.pre_update(self, key_states, value_states, layer_idx, cache_kwargs) + if self.processor is not None: + key_states, value_states = self.processor.pre_update( + self, key_states, value_states, layer_idx, cache_kwargs + ) self.grow_layers_to(layer_idx) key_tensors, value_tensors = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) - key_tensors, value_tensors = self.processors.post_update( - self, key_tensors, value_tensors, layer_idx, cache_kwargs - ) + if self.processor is not None: + key_tensors, value_tensors = self.processor.post_update( + self, key_tensors, value_tensors, layer_idx, cache_kwargs + ) return key_tensors, value_tensors def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: @@ -1015,8 +982,7 @@ class OffloadedCache(DynamicCache): def __init__(self, config: Optional[CacheConfig] = None) -> None: # Create the underlying cache with offload processor - processors = CacheProcessorList([OffloadedCacheProcessor()]) - super().__init__(processors=processors, config=config) + super().__init__(processor=OffloadedCacheProcessor(), config=config) class StaticLayer(CacheLayer): @@ -1115,7 +1081,7 @@ class StaticCache(Cache): Parameters: config_or_ddp_cache_data (`Union`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility. 
- processors (`Optional`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + processor (`Optional`, *optional*): Cache processor to apply (e.g., quantization, offloading). pattern_block (`Optional`, *optional*): Pattern of cache layer types to use. Defaults to `(StaticLayer,)` for backward compatibility. @@ -1429,7 +1395,7 @@ class HybridCache(Cache): Parameters: config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info. No DDP-distributed cache data is supported. - processors (`CacheProcessorList`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list. + processor (`CacheProcessor`, *optional*): Cache processor to apply (e.g., quantization, offloading). pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)` for backward compatibility. 
Example: @@ -1455,7 +1421,7 @@ class HybridCache(Cache): def __init__( self, config_or_ddp_cache_data=None, - processors: Optional[CacheProcessorList] = None, + processor: Optional[CacheProcessor] = None, pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, *args, **kwargs, @@ -1469,7 +1435,7 @@ def __init__( self.is_sliding = [False] * model_config.num_hidden_layers pattern_block = tuple(SlidingWindowLayer if sl else StaticLayer for sl in self.is_sliding) - super().__init__(config_or_ddp_cache_data, processors, pattern_block, *args, **kwargs) + super().__init__(config_or_ddp_cache_data, processor, pattern_block, *args, **kwargs) class HybridChunkedCache(Cache): @@ -1878,10 +1844,6 @@ def __init__( offload_device: Union[str, torch.device] = "cpu", layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, ) -> None: - # Create offload processor - processors = CacheProcessorList([OffloadedCacheProcessor(offload_device)]) - - # Initialize the base StaticCache with the processor super().__init__( config=config, max_batch_size=max_batch_size, @@ -1889,7 +1851,7 @@ def __init__( device=device, dtype=dtype, layer_device_map=layer_device_map, - processors=processors, + processor=OffloadedCacheProcessor(offload_device), ) @@ -2230,8 +2192,7 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None: else: raise ValueError(f"Unknown quantization backend `{cache_config.backend}`") - processors = CacheProcessorList([processor]) - super().__init__(processors=processors) + super().__init__(processor=processor) def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" @@ -2240,7 +2201,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx # this part of code otherwise fails when used to verify attn_weight shape in some models - return self.processors[0]._seen_tokens if layer_idx == 0 else self.processors[0]._seen_tokens - 1 + return self.processor._seen_tokens if layer_idx == 0 else self.processor._seen_tokens - 1 class QuantoQuantizedCache(QuantizedCache): @@ -2283,8 +2244,7 @@ class QuantoQuantizedCache(QuantizedCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)]) - Cache.__init__(self, processors=processors) + Cache.__init__(self, processor=QuantoQuantizedCacheProcessor(cache_config)) class HQQQuantizedCache(QuantizedCache): @@ -2327,8 +2287,7 @@ class HQQQuantizedCache(QuantizedCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)]) - Cache.__init__(self, processors=processors) + Cache.__init__(self, processor=HQQQuantizedCacheProcessor(cache_config)) class SinkCache(Cache): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index de110b872a7e..8580316d26bd 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -323,7 +323,7 @@ def test_quantized_cache_generation(self, backend): ) self.assertIsInstance(gen_out.past_key_values, QuantizedCache) - processor = gen_out.past_key_values.processors[0] + processor = gen_out.past_key_values.processor if backend == "quanto": self.assertIsInstance(processor, QuantoQuantizedCacheProcessor) elif backend == "hqq": From 16a662408777d3bb66124192d92a53a5464cb871 Mon Sep 17 00:00:00 2001 
From: Manuel de Prada Corral Date: Wed, 2 Jul 2025 17:15:16 +0200 Subject: [PATCH 4/4] raushan review, arthur review --- docs/source/en/cache_explanation.md | 20 +- src/transformers/cache_utils.py | 710 ++++++++---------- src/transformers/generation/utils.py | 2 +- src/transformers/integrations/executorch.py | 24 +- src/transformers/masking_utils.py | 24 +- src/transformers/models/bart/modeling_bart.py | 4 +- .../modeling_bigbird_pegasus.py | 4 +- .../models/biogpt/modeling_biogpt.py | 4 +- .../models/blenderbot/modeling_blenderbot.py | 4 +- .../modeling_blenderbot_small.py | 4 +- src/transformers/models/dia/modeling_dia.py | 4 +- src/transformers/models/dia/modular_dia.py | 4 +- .../models/gemma3n/modeling_gemma3n.py | 4 +- .../models/gemma3n/modular_gemma3n.py | 4 +- src/transformers/models/gptj/modeling_gptj.py | 5 +- .../models/informer/modeling_informer.py | 8 +- .../models/informer/modular_informer.py | 4 +- .../models/longt5/modeling_longt5.py | 4 +- .../models/m2m_100/modeling_m2m_100.py | 4 +- .../models/marian/modeling_marian.py | 4 +- .../models/mbart/modeling_mbart.py | 4 +- .../models/minimax/modeling_minimax.py | 6 +- .../models/minimax/modular_minimax.py | 6 +- .../models/mllama/modeling_mllama.py | 4 +- .../models/moonshine/modeling_moonshine.py | 4 +- .../models/moonshine/modular_moonshine.py | 4 +- src/transformers/models/mt5/modeling_mt5.py | 4 +- .../models/pegasus/modeling_pegasus.py | 4 +- .../models/pegasus_x/modeling_pegasus_x.py | 4 +- .../models/pix2struct/modeling_pix2struct.py | 4 +- .../models/plbart/modeling_plbart.py | 4 +- .../models/pop2piano/modeling_pop2piano.py | 4 +- .../modeling_switch_transformers.py | 4 +- src/transformers/models/t5/modeling_t5.py | 4 +- .../models/t5gemma/modeling_t5gemma.py | 4 +- .../models/t5gemma/modular_t5gemma.py | 4 +- .../modeling_time_series_transformer.py | 4 +- src/transformers/models/udop/modeling_udop.py | 4 +- src/transformers/models/umt5/modeling_umt5.py | 4 +- 
.../models/whisper/generation_whisper.py | 4 +- .../models/whisper/modeling_whisper.py | 4 +- tests/generation/test_utils.py | 82 +- .../deepseek_v3/test_modeling_deepseek_v3.py | 8 +- tests/models/dia/test_modeling_dia.py | 8 +- .../models/gpt_neox/test_modeling_gpt_neox.py | 4 +- tests/models/t5gemma/test_modeling_t5gemma.py | 44 +- tests/utils/test_cache_utils.py | 102 +-- utils/check_docstrings.py | 5 +- 48 files changed, 535 insertions(+), 651 deletions(-) diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 4190cefdb8a1..2adcc0c78012 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -82,24 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s ## Cache storage implementation -The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`]. +Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`. +Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`), which mostly changes how sequence length is handled and how the cache is updated. -In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`. -- `key_cache`: A list of tensors, one for each layer. -- `value_cache`: A list of tensors, one for each layer. +The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token: -When new tokens are processed: - -1. For each layer, the new key and value states are concatenated with the existing cache. 
```py -self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) -self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) +cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2) +cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2) ``` -2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token. - -3. The cache maintains a count of seen tokens through `self._seen_tokens`. This is updated when the first layer processes a new token. +Other layers like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added. The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token. @@ -143,7 +137,7 @@ The legacy format is essentially the same data structure but organized different - The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`. - The format is less flexible and doesn't support features like quantization or offloading. -If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format. +If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~Cache.from_legacy_cache`] and [`Cache.to_legacy_cache`] functions. 
This is helpful if you have custom logic for manipulating a cache in a specific format. ```py import torch diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 233b72bcbfff..9219365e76ad 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -48,7 +48,7 @@ def pre_update( cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Hook called before the cache update. Can modify the key/value states. + Function called before the cache update. Can modify the key/value states. Args: cache (`Cache`): The cache instance. @@ -71,7 +71,7 @@ def post_update( cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Hook called after the cache update. Can process the cached data. + Function called after the cache update. Can process the cached data. Args: cache (`Cache`): The cache instance. @@ -86,37 +86,6 @@ def post_update( return key_tensors, value_tensors -class KVProxy: - """Efficiently simulates layer-indexed key or value lists from a layered cache. 
- This allows for BC access, e.g., cache.key_cache[idx] or cache.value_cache[idx].""" - - def __init__(self, layers, cache_type="key"): - self.layers = layers - self.cache_type = cache_type - - def __getitem__(self, idx): - if isinstance(idx, slice): - return [getattr(layer, f"{self.cache_type}_cache") for layer in self.layers[idx]] - return getattr(self.layers[idx], f"{self.cache_type}_cache") - - def __setitem__(self, idx, value): - if isinstance(idx, slice): - for layer, val in zip(self.layers[idx], value): - setattr(layer, f"{self.cache_type}_cache", val) - else: - setattr(self.layers[idx], f"{self.cache_type}_cache", value) - - def __len__(self): - return len(self.layers) - - def __iter__(self): - for layer in self.layers: - yield getattr(layer, f"{self.cache_type}_cache") - - def __bool__(self): - return bool(self.layers) - - class CacheLayer: """Base, abstract class for a single layer's cache.""" @@ -126,14 +95,14 @@ def __init__( self, config: Optional["CacheConfig"] = None, ): - self.key_cache = None - self.value_cache = None + self.keys = None + self.values = None @classmethod - def from_kv(cls, key_cache: torch.Tensor, value_cache: torch.Tensor) -> None: + def from_kv(cls, keys: torch.Tensor, values: torch.Tensor) -> None: cache = cls() - cache.key_cache = key_cache - cache.value_cache = value_cache + cache.keys = keys + cache.values = values return cache def update( @@ -167,44 +136,44 @@ def reset(self) -> None: def reorder_cache(self, beam_idx: torch.LongTensor) -> None: """Reorders this layer's cache for beam search.""" - if self.key_cache.numel(): - device = self.key_cache.device - self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) - if self.value_cache.numel(): - device = self.value_cache.device - self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) + if self.keys.numel(): + device = self.keys.device + self.keys = self.keys.index_select(0, beam_idx.to(device)) + if self.values.numel(): + device = 
self.values.device + self.values = self.values.index_select(0, beam_idx.to(device)) def __repr__(self): - key_repr = "None" if self.key_cache is None else f"t({tuple(self.key_cache.shape)})" - value_repr = "None" if self.value_cache is None else f"t({tuple(self.value_cache.shape)})" + key_repr = "None" if self.keys is None else f"t({tuple(self.keys.shape)})" + value_repr = "None" if self.values is None else f"t({tuple(self.values.shape)})" return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})" -class Cache: +class CacheBase: + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.cache_processor is not None: + key_states, value_states = self.cache_processor.pre_update( + self, key_states, value_states, layer_idx, cache_kwargs + ) + key_tensors, value_tensors = self._update(key_states, value_states, layer_idx, cache_kwargs) + if self.cache_processor is not None: + key_tensors, value_tensors = self.cache_processor.post_update( + self, key_tensors, value_tensors, layer_idx, cache_kwargs + ) + return key_tensors, value_tensors + + +class Cache(CacheBase): """ Base, abstract class for all caches. The actual data structure is specific to the layers. This class handles propagation of operations across layers. - Parameters: - config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): - Model configuration for shape/device info, or DDP-distributed cache data for compatibility. - If DDP-distributed cache data, must be an iterable of (key_states, value_states) tuples for each layer. - processor (`CacheProcessor`, *optional*): - Cache processor to apply (e.g., quantization, offloading). - pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): - Pattern of cache layer types to use. Defaults to `(DynamicLayer,)`. Must be a tuple whose length divides - the total number of layers. 
The pattern repeats to fill all layers. Examples: `(StaticLayer,)` for a - uniform cache, `(StaticLayer, StaticLayer, SlidingWindowLayer)` for a hybrid cache with repeating pattern, - or specify the full structure like `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)`. - Additional arguments for cache configuration: - - `max_batch_size`/`batch_size` (`int`): Maximum batch size for static caches - - `max_cache_len` (`int`): Maximum sequence length. For hybrid caches: - * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` - * StaticLayers: uses full `max_cache_len` - - `device` (`torch.device`): Device for cache tensors - - `dtype` (`torch.dtype`): Data type for cache tensors - - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping - Note for hybrid caches (blocks of (StaticLayer, ..., SlidingWindowLayer) repeated across layers): - Requires `model_config.sliding_window` to be set - Uses `sliding_window_pattern` (default: 2) to determine layer alternation if pattern not specified @@ -212,69 +181,61 @@ class Cache: """ layers = [] - pattern_block = () # Subclasses can define their layer pattern statically def __init__( self, - config_or_ddp_cache_data: Optional[ - Union[PretrainedConfig, Iterable[tuple[torch.Tensor, torch.Tensor]]] - ] = None, - processor: Optional[CacheProcessor] = None, - pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, + model_config: Optional[PretrainedConfig] = None, + cache_processor: Optional[CacheProcessor] = None, + layer_classes: Optional[list[type[CacheLayer]]] = None, *args, **kwargs, ): + """ + Parameters: + model_config (`PretrainedConfig`): + Model configuration for shape/device info. + cache_processor (`CacheProcessor`, *optional*): + Cache processor to apply (e.g., quantization, offloading). + layer_classes (`list[type[CacheLayer]]`, *optional*): + List of layer classes to use for the cache. 
+ Additional arguments for cache configuration: + - `max_batch_size`/`batch_size` (`int`): Maximum batch size for static caches + - `max_cache_len` (`int`): Maximum sequence length. For hybrid caches: + * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)` + * StaticLayers: uses full `max_cache_len` + - `device` (`torch.device`): Device for cache tensors + - `dtype` (`torch.dtype`): Data type for cache tensors + - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping + """ self.layers: list[CacheLayer] = [] - self.processor = processor - pattern_block = pattern_block or self.pattern_block or (DynamicLayer,) - - if isinstance(config_or_ddp_cache_data, PretrainedConfig): - model_config = config_or_ddp_cache_data - elif isinstance(config_or_ddp_cache_data, Iterable): - _distributed_cache_data = config_or_ddp_cache_data - # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 - # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the - # iterable contains the key and value states for a layer gathered across replicas by torch.distributed - # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break - # compatibility. The name of the argument doesn't matter. 
- assert pattern_block == (DynamicLayer,), "torch DDP is only supported for DynamicCache" - for key_states, value_states in _distributed_cache_data: - self.layers.append(DynamicLayer.from_kv(key_states, value_states)) - if self.processor is not None: - self.processor.init(self, **kwargs) - return - else: - model_config = kwargs.pop("config", None) + self.cache_processor = cache_processor - self.config, self.pattern_block = CacheConfig.from_model_config(model_config, pattern_block, *args, **kwargs) - self.layer_types = [self.pattern_block[i % len(self.pattern_block)] for i in range(self.config.num_layers)] + if ( + layer_classes is None # setting layer_classes takes precedence + and model_config is not None + and getattr(model_config, "layer_types", None) is not None + ): + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in model_config.layer_types] + self.layer_classes = layer_classes or [DynamicLayer] - for idx, layer_type in enumerate(self.layer_types): - layer = layer_type(self.config.to_layer(idx)) - self.layers.append(layer) + self.config = CacheConfig.from_model_config(model_config, *args, **kwargs) - if self.processor is not None: - self.processor.init(self, **kwargs) + self.append_new_layers(self.config.num_layers - 1) - def grow_layers_to(self, layer_idx): - while len(self.layers) <= layer_idx: - next_type_idx = len(self.layer_types) % len(self.pattern_block) - next_layer_type = self.pattern_block[next_type_idx] - self.layer_types.append(next_layer_type) - self.layers.append(next_layer_type()) + if self.cache_processor is not None: + self.cache_processor.init(self, **kwargs) - @property - def key_cache(self) -> KVProxy: - """Returns a list-like object of key cache tensors indexed by layer.""" - return KVProxy(self.layers, "key") - - @property - def value_cache(self) -> KVProxy: - """Returns a list-like object of value cache tensors indexed by layer.""" - return KVProxy(self.layers, "value") + def append_new_layers(self, layer_idx): + """ + 
Appends layers to the cache until the layer `layer_idx` is reached. + Used in prefill and for skipped layers. + """ + while len(self.layers) <= layer_idx: + self.layers.append( + self.layer_classes[layer_idx % len(self.layer_classes)](self.config.for_layer(layer_idx)) + ) - def update( + def _update( self, key_states: torch.Tensor, value_states: torch.Tensor, @@ -298,17 +259,8 @@ def update( Return: A tuple containing the updated key and value states. """ - if self.processor is not None: - key_states, value_states = self.processor.pre_update( - self, key_states, value_states, layer_idx, cache_kwargs - ) - self.grow_layers_to(layer_idx) - key_tensors, value_tensors = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) - if self.processor is not None: - key_tensors, value_tensors = self.processor.post_update( - self, key_tensors, value_tensors, layer_idx, cache_kwargs - ) - return key_tensors, value_tensors + self.append_new_layers(layer_idx) + return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -316,7 +268,7 @@ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: sequence length. """ if layer_idx < len(self.layers): - return self.layers[layer_idx].key_cache, self.layers[layer_idx].value_cache + return self.layers[layer_idx].keys, self.layers[layer_idx].values else: raise KeyError( f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" @@ -328,20 +280,20 @@ def __iter__(self): keys and values """ for layer_idx in range(len(self)): - yield (self.layers[layer_idx].key_cache, self.layers[layer_idx].value_cache) + yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) def __len__(self): """ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds to the number of layers in the model. 
""" - # Best effort BC support for subclasses + # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked. if self.layers is None: if getattr(self, "key_cache", None) is not None: return len(self.key_cache) return 0 dynamic_empty = ( - len(self.layers) == 1 and isinstance(self.layers[0], DynamicLayer) and self.layers[0].key_cache is None + len(self.layers) == 1 and isinstance(self.layers[0], DynamicLayer) and self.layers[0].keys is None ) return len(self.layers) if not dynamic_empty else 0 @@ -355,45 +307,42 @@ def __getattr__(self, name): if name in ("__getstate__", "__setstate__"): raise AttributeError(name) - # Check if the attribute/method exists and gather values if it is an attribute - attribute_values = [] - for i, layer in enumerate(self.layers[: len(self.pattern_block)]): - if not hasattr(layer, name): - raise AttributeError( - f"Layer {i} ({layer.__class__.__name__}) of {self.__class__.__name__} does not support `{name}`" - ) - if not callable(getattr(layer, name)): - attribute_values.append(getattr(layer, name)) - - if attribute_values: - assert len(attribute_values) == len(self.pattern_block), ( - f"Cache {self.__class__.__name__} gathered {len(attribute_values)} values for {name}, but there are {len(self.pattern_block)} layers." 
- ) - values = set(attribute_values) - if len(values) == 1: - return values.pop() + is_attribute = ( + all(hasattr(layer, name) and not callable(getattr(layer, name)) for layer in self.layers) + and len(self.layers) > 0 + ) + is_method = all(callable(getattr(layer, name)) for layer in self.layers) and len(self.layers) > 0 + + if is_attribute: + attribute_values = [getattr(layer, name) for layer in self.layers] + if all(isinstance(value, bool) for value in attribute_values): + return all(attribute_values) + elif len(set(attribute_values)) == 1: + return attribute_values[0] else: - if all(isinstance(value, bool) for value in values): - return all(values) - else: - raise ValueError( - f"Cache {self.__class__.__name__}:{self.pattern_block} has multiple values for {name}: {attribute_values}. This is not supported." - ) + raise ValueError( + f"{self.__class__.__name__}: layers have multiple values for layer.{name}: {attribute_values}. This is not supported." + ) + elif is_method: - # If the attribute is a method, we propagate it to all layers - def propagate_to_layers(*args, **kwargs): - for layer in self.layers: - return_value = getattr(layer, name)(*args, **kwargs) - if return_value is not None: - break - return return_value + def propagate_to_layers(*args, **kwargs): + for layer in self.layers: + return_value = getattr(layer, name)(*args, **kwargs) + if return_value is not None: + break + return return_value - return propagate_to_layers + return propagate_to_layers + else: + raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") def get_seq_length(self, layer_idx: int = 0) -> int: """Returns the sequence length of the cache for the given layer. 
TODO: deprecate in favor of cache_position""" if layer_idx >= len(self.layers): return 0 + # Hack since QuantizedCache messes with keys shape as it becomes the residual cache + if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): + return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length() return self.layers[layer_idx].get_seq_length() def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: @@ -403,24 +352,7 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), for each layer. """ - if isinstance(self.layers[layer_idx], SlidingWindowLayer): - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - - local_mask_kv_offset = torch.clamp(first_cache_position - self.config.sliding_window + 1, min=0) - # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - local_mask_kv_length = max(query_length, self.config.sliding_window) - return local_mask_kv_length, local_mask_kv_offset - - full_mask_kv_offset = 0 - if isinstance(self.layers[layer_idx], StaticLayer): - full_mask_kv_length = self.get_max_cache_shape() - return full_mask_kv_length, full_mask_kv_offset - else: - query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset + return self.layers[layer_idx].get_mask_sizes(cache_position) def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: """Converts the `Cache` instance into the its equivalent in the legacy cache format. 
Used for @@ -428,7 +360,7 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: legacy_cache = () for layer in self.layers: if layer is not None: - legacy_cache += ((layer.key_cache, layer.value_cache),) + legacy_cache += ((layer.keys, layer.values),) return legacy_cache @classmethod @@ -461,8 +393,7 @@ def __init__(self, num_layers: Optional[int] = None, cache_implementation: Optio @classmethod def from_model_config( cls, - config: Optional[PretrainedConfig], - pattern_block: tuple[type["CacheLayer"], ...], + model_config: Optional[PretrainedConfig], batch_size: Optional[int] = None, max_cache_len: Optional[int] = None, device: Union[torch.device, str, None] = None, @@ -470,39 +401,41 @@ def from_model_config( layer_device_map=None, max_batch_size: Optional[int] = None, ) -> "CacheConfig": - num_layers = getattr(config, "num_hidden_layers", len(pattern_block)) # No model config -> must be a dynamic cache, return bare CacheConfig - if config is None: - return cls(num_layers=num_layers), pattern_block - # Build a StaticCacheConfig for any kind of static: hybrid, sliding or static + if model_config is None: + return cls(num_layers=getattr(model_config, "num_hidden_layers", 1)) + # Build a StaticCacheConfig for hybrid, sliding or static else: - # Rename max_batch_size to batch_size - if max_batch_size is not None: - batch_size = max_batch_size # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) - if StaticLayer in pattern_block and SlidingWindowLayer in pattern_block: - if getattr(config, "sliding_window", None) is None: + if ( + getattr(model_config, "layer_types", None) is not None + and "sliding_attention" in model_config.layer_types + and "full_attention" in model_config.layer_types + ): + if getattr(model_config, "sliding_window", None) is None: raise ValueError( "Setting up a hybrid or sliding window KVCache requires the model config supporting " "sliding window attention, please check if there 
is a `sliding_window` field in the model " "config and it's not set to None." ) # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) - max_cache_len = max_cache_len or config.max_position_embeddings - sliding_window_len = min(getattr(config, "sliding_window", max_cache_len) or max_cache_len, max_cache_len) + max_cache_len = max_cache_len or model_config.max_position_embeddings + sliding_window_len = min( + getattr(model_config, "sliding_window", max_cache_len) or max_cache_len, max_cache_len + ) # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: head_dim = ( - config.head_dim - if getattr(config, "head_dim", None) is not None - else config.hidden_size // config.num_attention_heads + model_config.head_dim + if getattr(model_config, "head_dim", None) is not None + else model_config.hidden_size // model_config.num_attention_heads ) num_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads + model_config.num_attention_heads + if getattr(model_config, "num_key_value_heads", None) is None + else model_config.num_key_value_heads ) cache_config = StaticCacheConfig( - batch_size=batch_size, + batch_size=max_batch_size if max_batch_size is not None else batch_size, max_cache_len=max_cache_len, device=torch.device(device) if device is not None else None, dtype=dtype, @@ -510,9 +443,9 @@ def from_model_config( head_dim=head_dim, num_heads=num_heads, sliding_window=sliding_window_len, - num_layers=num_layers, + num_layers=model_config.num_hidden_layers, ) - return cache_config, pattern_block + return cache_config @classmethod def from_dict(cls, config_dict, **kwargs): @@ -535,7 +468,7 @@ def from_dict(cls, config_dict, **kwargs): kwargs.pop(key, None) return config - def to_layer(self, layer_idx: int) -> "CacheLayer": + def for_layer(self, layer_idx: int) -> "CacheConfig": return self # Copied from 
transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file @@ -817,21 +750,20 @@ def update( Return: A tuple containing the updated key and value states. """ - if self.key_cache is None: - self.key_cache = key_states - self.value_cache = value_states + if self.keys is None: + self.keys = key_states + self.values = value_states else: - self.key_cache = torch.cat([self.key_cache, key_states], dim=-2) - self.value_cache = torch.cat([self.value_cache, value_states], dim=-2) - - return self.key_cache, self.value_cache + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) + return self.keys, self.values def get_seq_length(self, cache_position: Optional[torch.LongTensor] = None) -> int: """Returns the sequence length of the cached states.""" # TODO: deprecate this function in favor of `cache_position` - if self is None or self.key_cache is None or self.key_cache.numel() == 0: + if self is None or self.keys is None or self.keys.numel() == 0: return 0 - return self.key_cache.shape[-2] + return self.keys.shape[-2] def get_max_cache_shape(self) -> int: """Returns the maximum sequence length of the cache object. 
DynamicLayer does not have a maximum length.""" @@ -839,15 +771,15 @@ def get_max_cache_shape(self) -> int: def reset(self) -> None: """Resets the cache values while preserving the objects""" - self.key_cache = torch.tensor([], dtype=self.key_cache.dtype, device=self.key_cache.device) - self.value_cache = torch.tensor([], dtype=self.value_cache.dtype, device=self.value_cache.device) + self.keys = torch.tensor([], dtype=self.keys.dtype, device=self.keys.device) + self.values = torch.tensor([], dtype=self.values.dtype, device=self.values.device) def reorder_cache(self, beam_idx: torch.LongTensor) -> None: """Reorders the cache for beam search, given the selected beam indices.""" - if self.key_cache is not None and self.key_cache.numel(): - self.key_cache = self.key_cache.index_select(0, beam_idx.to(self.key_cache.device)) - if self.value_cache is not None and self.value_cache.numel(): - self.value_cache = self.value_cache.index_select(0, beam_idx.to(self.value_cache.device)) + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) + if self.values is not None and self.values.numel(): + self.values = self.values.index_select(0, beam_idx.to(self.values.device)) def crop(self, max_length: int) -> None: """Crop the past key values up to a new `max_length` in terms of tokens. 
`max_length` can also be @@ -858,21 +790,28 @@ def crop(self, max_length: int) -> None: if self.get_seq_length() <= max_length: return - if self.key_cache.numel(): - self.key_cache = self.key_cache[..., :max_length, :] - self.value_cache = self.value_cache[..., :max_length, :] + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[..., :max_length, :] + self.values = self.values[..., :max_length, :] def batch_repeat_interleave(self, repeats: int) -> None: """Repeat the cache `repeats` times in the batch dimension.""" - if self.key_cache.numel(): - self.key_cache = self.key_cache.repeat_interleave(repeats, dim=0) - self.value_cache = self.value_cache.repeat_interleave(repeats, dim=0) + if self.keys.numel(): + self.keys = self.keys.repeat_interleave(repeats, dim=0) + self.values = self.values.repeat_interleave(repeats, dim=0) def batch_select_indices(self, indices: torch.Tensor) -> None: """Only keep the `indices` in the batch dimension of the cache.""" - if self.key_cache.numel(): - self.key_cache = self.key_cache[indices, ...] - self.value_cache = self.value_cache[indices, ...] + if self.keys.numel(): + self.keys = self.keys[indices, ...] + self.values = self.values[indices, ...] + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + full_mask_kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length(cache_position) + kv_length = query_length + past_seen_tokens + return kv_length, full_mask_kv_offset class DynamicCache(Cache): @@ -900,7 +839,18 @@ class DynamicCache(Cache): ``` """ - pattern_block = (DynamicLayer,) + # Specialized constructor for DDP cache data, needed for BC + def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 + # and #36373 for more information. 
In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the + # iterable contains the key and value states for a layer gathered across replicas by torch.distributed + # (shape=[global batch size, num_heads, seq_len, head_dim]). + # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break + # compatibility. The name of the argument doesn't matter. + if ddp_cache_data is not None: + for key_states, value_states in ddp_cache_data: + self.layers.append(DynamicLayer.from_kv(key_states, value_states)) + super().__init__(*args, **kwargs) # Utilities for `DynamicCache` <> torch.export support @@ -917,16 +867,16 @@ def _flatten_dynamic_cache( ) dictionary = { - "key_cache": [layer.key_cache for layer in dynamic_cache.layers if layer.key_cache is not None], - "value_cache": [layer.value_cache for layer in dynamic_cache.layers if layer.value_cache is not None], + "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], } return torch.utils._pytree._dict_flatten(dictionary) def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): dictionary = { - "key_cache": [layer.key_cache for layer in dynamic_cache.layers if layer.key_cache is not None], - "value_cache": [layer.value_cache for layer in dynamic_cache.layers if layer.value_cache is not None], + "key_cache": [layer.keys for layer in dynamic_cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in dynamic_cache.layers if layer.values is not None], } return torch.utils._pytree._dict_flatten_with_keys(dictionary) @@ -937,7 +887,7 @@ def _unflatten_dynamic_cache( ): dictionary = torch.utils._pytree._dict_unflatten(values, context) cache = DynamicCache() - # Reconstruct layers from key_cache and value_cache lists + # Reconstruct layers from keys and values lists key_list = dictionary.get("key_cache", []) value_list = 
dictionary.get("value_cache", []) for idx in range(max(len(key_list), len(value_list))): @@ -949,8 +899,8 @@ def _unflatten_dynamic_cache( def _flatten_dynamic_cache_for_fx(cache, spec): dictionary = { - "key_cache": [layer.key_cache for layer in cache.layers if layer.key_cache is not None], - "value_cache": [layer.value_cache for layer in cache.layers if layer.value_cache is not None], + "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in cache.layers if layer.values is not None], } return torch.fx._pytree._dict_flatten_spec(dictionary, spec) @@ -980,13 +930,14 @@ class OffloadedCache(DynamicCache): ensure the eviction is scheduled after all computations on that cache are finished. """ - def __init__(self, config: Optional[CacheConfig] = None) -> None: + def __init__(self, model_config: Optional[PretrainedConfig] = None) -> None: # Create the underlying cache with offload processor - super().__init__(processor=OffloadedCacheProcessor(), config=config) + super().__init__(cache_processor=OffloadedCacheProcessor(), model_config=model_config) class StaticLayer(CacheLayer): is_compileable = True + is_sliding = False def __init__( self, @@ -996,20 +947,20 @@ def __init__( self.max_cache_len = max_len or config.max_cache_len self.max_batch_size = config.batch_size # Note: There will be significant perf decrease if switching to use 5D tensors instead. - self.key_cache = torch.zeros( + self.keys = torch.zeros( (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), dtype=config.dtype, device=config.device, ) - self.value_cache = torch.zeros( + self.values = torch.zeros( (config.batch_size, config.num_heads, self.max_cache_len, config.head_dim), dtype=config.dtype, device=config.device, ) # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, # preventing compiled graph breaks when updating the cache. 
- torch._dynamo.mark_static_address(self.key_cache) - torch._dynamo.mark_static_address(self.value_cache) + torch._dynamo.mark_static_address(self.keys) + torch._dynamo.mark_static_address(self.values) def get_max_cache_shape(self) -> int: return self.max_cache_len @@ -1037,24 +988,24 @@ def _static_update( """ if cache_position is None: # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. - self.key_cache.copy_(key_states) - self.value_cache.copy_(value_states) + self.keys.copy_(key_states) + self.values.copy_(value_states) else: # Generation phase. Update specific positions. # Use index_copy_ for in-place update (compile-friendly). try: - self.key_cache.index_copy_(2, cache_position, key_states) - self.value_cache.index_copy_(2, cache_position, value_states) + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) except NotImplementedError: # Fallback for devices like MPS where index_copy_ might not be supported. - self.key_cache[:, :, cache_position] = key_states - self.value_cache[:, :, cache_position] = value_states - return self.key_cache, self.value_cache + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + return self.keys, self.values def update(self, key_states, value_states, cache_kwargs=None) -> tuple[torch.Tensor, torch.Tensor]: cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - key_states = key_states.to(self.key_cache.dtype) - value_states = value_states.to(self.value_cache.dtype) + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) return self._static_update(key_states, value_states, cache_position) def get_seq_length(self, cache_position=None) -> int: @@ -1062,29 +1013,28 @@ def get_seq_length(self, cache_position=None) -> int: return int(cache_position[-1] + 1) # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. 
To save on compute, let's # limit the check to the first batch member and head dimension. - return (self.key_cache[0, 0].any(dim=-1)).sum() + return (self.keys[0, 0].any(dim=-1)).sum() def reset(self): - self.key_cache.zero_() - self.value_cache.zero_() + self.keys.zero_() + self.values.zero_() def reorder_cache(self, beam_idx): - dev = self.key_cache.device + dev = self.keys.device beam_idx_dev = beam_idx.to(dev) - self.key_cache = self.key_cache.index_select(0, beam_idx_dev) - self.value_cache = self.value_cache.index_select(0, beam_idx_dev) + self.keys = self.keys.index_select(0, beam_idx_dev) + self.values = self.values.index_select(0, beam_idx_dev) + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + full_mask_kv_offset = 0 + full_mask_kv_length = self.max_cache_len + return full_mask_kv_length, full_mask_kv_offset class StaticCache(Cache): """ Static Cache class to be used with `torch.compile(model)` and `torch.export()`. - Parameters: - config_or_ddp_cache_data (`Union`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility. - processor (`Optional`, *optional*): Cache processor to apply (e.g., quantization, offloading). - pattern_block (`Optional`, *optional*): Pattern of cache layer types to use. Defaults to `(StaticLayer,)` for backward compatibility. 
- - Example: ```python @@ -1105,7 +1055,8 @@ class StaticCache(Cache): ``` """ - pattern_block = (StaticLayer,) + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=[StaticLayer], *args, **kwargs) class SlidingWindowLayer(StaticLayer): @@ -1147,9 +1098,9 @@ def _static_update( if cache_position.shape[0] > self.max_cache_len: new_k = key_states[:, :, -self.max_cache_len :, :] new_v = value_states[:, :, -self.max_cache_len :, :] - self.key_cache.copy_(new_k) - self.value_cache.copy_(new_v) - return self.key_cache, self.value_cache + self.keys.copy_(new_k) + self.values.copy_(new_v) + return self.keys, self.values # Sliding window logic for generation phase or prefill < window slicing = torch.arange(self.max_cache_len, device=value_states.device) @@ -1157,8 +1108,8 @@ def _static_update( to_shift = current_seq_len > self.max_cache_len indices = (slicing + to_shift.sum()) % self.max_cache_len - k_out_shifted = self.key_cache[:, :, indices] - v_out_shifted = self.value_cache[:, :, indices] + k_out_shifted = self.keys[:, :, indices] + v_out_shifted = self.values[:, :, indices] # Clamp cache_position to determine the *target index* within the shifted cache view update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1) @@ -1173,9 +1124,18 @@ def _static_update( k_out_updated[:, :, update_position] = key_states v_out_updated[:, :, update_position] = value_states - self.key_cache.copy_(k_out_updated) - self.value_cache.copy_(v_out_updated) - return self.key_cache, self.value_cache + self.keys.copy_(k_out_updated) + self.values.copy_(v_out_updated) + return self.keys, self.values + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + + local_mask_kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0) + # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns + 
local_mask_kv_length = max(query_length, self.max_cache_len) + return local_mask_kv_length, local_mask_kv_offset class SlidingWindowCache(Cache): @@ -1214,7 +1174,8 @@ class SlidingWindowCache(Cache): ``` """ - pattern_block = (SlidingWindowLayer,) + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) class EncoderDecoderCache(Cache): @@ -1250,7 +1211,7 @@ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) self.is_updated = {} - for layer_idx in range(len(cross_attention_cache.key_cache)): + for layer_idx in range(len(cross_attention_cache)): self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -1260,10 +1221,10 @@ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch """ if layer_idx < len(self): return ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, ) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") @@ -1377,12 +1338,6 @@ def get_max_cache_shape(self) -> int: return self.self_attention_cache.get_max_cache_shape() def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. 
- The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. - """ return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) @@ -1393,11 +1348,6 @@ class HybridCache(Cache): Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] for global attention.For more information, see the documentation of each subcomponent cache class. - Parameters: - config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info. No DDP-distributed cache data is supported. - processor (`CacheProcessor`, *optional*): Cache processor to apply (e.g., quantization, offloading). - pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)` - for backward compatibility. Example: ```python @@ -1418,24 +1368,10 @@ class HybridCache(Cache): ``` """ - def __init__( - self, - config_or_ddp_cache_data=None, - processor: Optional[CacheProcessor] = None, - pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None, - *args, - **kwargs, - ): - model_config = config_or_ddp_cache_data or kwargs.get("config", None) - assert model_config is not None, "HybridCache requires a model config" - # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(model_config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in model_config.layer_types] - else: - self.is_sliding = [False] * model_config.num_hidden_layers - - pattern_block = tuple(SlidingWindowLayer if sl else StaticLayer for sl in self.is_sliding) - super().__init__(config_or_ddp_cache_data, processor, pattern_block, *args, **kwargs) + def __init__(self, model_config: PretrainedConfig, *args, **kwargs): + # Ugly but needed for BC + layer_classes = [StaticLayer] if not 
hasattr(model_config, "layer_types") else None + super().__init__(model_config=model_config, layer_classes=layer_classes, *args, **kwargs) class HybridChunkedCache(Cache): @@ -1447,7 +1383,7 @@ class HybridChunkedCache(Cache): for global attention. For more information, see the documentation of each subcomponent cache class. Parameters: - config (`PretrainedConfig): + model_config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a @@ -1485,40 +1421,39 @@ class HybridChunkedCache(Cache): """ is_compileable = True - # Override @property since HybridChunked does its own thing + # Override @property since HybridChunked does not conform to layered caches yet key_cache = None value_cache = None def __init__( self, - config: PretrainedConfig, + model_config: PretrainedConfig, max_batch_size: int, max_cache_len: Optional[int] = None, device: Union[torch.device, str, None] = None, dtype: torch.dtype = torch.bfloat16, layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) + if not hasattr(model_config, "sliding_window") or model_config.sliding_window is None: + self.sliding_window = getattr(model_config.get_text_config(), "attention_chunk_size", 8192) else: - self.sliding_window = config.sliding_window + self.sliding_window = model_config.sliding_window self.max_cache_len = max_cache_len # Sliding layers can't be larger than the overall max cache len self.sliding_window = min(self.sliding_window, self.max_cache_len) self.max_batch_size = max_batch_size - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + 
self.head_dim = getattr(model_config, "head_dim", model_config.hidden_size // model_config.num_attention_heads) self._dtype = dtype # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] + if hasattr(model_config, "layer_types"): + self.is_sliding = [layer_type != "full_attention" for layer_type in model_config.layer_types] else: - self.is_sliding = [False] * config.num_hidden_layers + self.is_sliding = [False] * model_config.num_hidden_layers self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] - self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] + self.cumulative_length = [0 for _ in range(model_config.num_hidden_layers)] def initialise_cache_layer(self, layer_idx, key_states): if len(self.key_cache) > layer_idx: @@ -1648,12 +1583,6 @@ def reorder_cache(self, beam_idx: torch.LongTensor): self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
- """ if self.is_sliding[layer_idx]: query_length = cache_position.shape[0] first_cache_position = cache_position[0] @@ -1682,7 +1611,7 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ class OffloadedHybridCache(HybridChunkedCache): def __init__( self, - config: PretrainedConfig, + model_config: PretrainedConfig, max_batch_size: int, max_cache_len: Optional[int] = None, device: Union[torch.device, str, None] = None, @@ -1690,7 +1619,7 @@ def __init__( offload_device: Union[str, torch.device] = torch.device("cpu"), layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, ): - super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map) + super().__init__(model_config, max_batch_size, max_cache_len, device, dtype, layer_device_map) # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps # track of the original device of each layer @@ -1801,15 +1730,6 @@ class OffloadedStaticCache(StaticCache): This cache maintains the compilation-friendly properties of StaticCache while enabling much longer sequences by offloading inactive layers to CPU memory. - Parameters: - config (`PretrainedConfig`): Model configuration for shape/device info. - max_batch_size (`int`): Maximum batch size for static caches. - max_cache_len (`int`, *optional*): Maximum sequence length. - device (`torch.device` or `str`, *optional*): Device for cache tensors. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type for cache tensors. - offload_device (`Union`, *optional*, defaults to `"cpu"`): Device to offload cache tensors to. - layer_device_map (`dict[int, Union[str, torch.device, int]]`, *optional*): Per-layer device mapping. 
- Example: ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache @@ -1834,25 +1754,8 @@ class OffloadedStaticCache(StaticCache): ``` """ - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: Optional[torch.dtype] = None, - offload_device: Union[str, torch.device] = "cpu", - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - super().__init__( - config=config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - layer_device_map=layer_device_map, - processor=OffloadedCacheProcessor(offload_device), - ) + def __init__(self, *args, offload_device: Union[str, torch.device] = "cpu", **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor(offload_device), **kwargs) class OffloadedCacheProcessor(CacheProcessor): @@ -1885,8 +1788,8 @@ def init(self, cache: "Cache", **kwargs) -> None: if self.is_static: for i, layer in enumerate(cache.layers): device = cache.config.device if i == 0 else self.offload_device - layer.key_cache = layer.key_cache.to(device) - layer.value_cache = layer.value_cache.to(device) + layer.keys = layer.keys.to(device) + layer.values = layer.values.to(device) self.original_device.append(cache.config.device) if len(cache) != cache.config.num_layers: raise ValueError("If static layers are used, all cache layers must be initialized") @@ -1933,18 +1836,18 @@ def _prefetch_layer(self, cache: "Cache", layer_idx: int): ): # Prefetch next layer tensors to GPU device = self.original_device[layer_idx] - cache.key_cache[layer_idx] = cache.key_cache[layer_idx].to(device, non_blocking=True) - cache.value_cache[layer_idx] = cache.value_cache[layer_idx].to(device, non_blocking=True) + cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.to(device, non_blocking=True) + 
cache.layers[layer_idx].values = cache.layers[layer_idx].values.to(device, non_blocking=True) def _evict_previous_layer(self, cache: "Cache", layer_idx: int): """Moves the previous layer cache to the CPU.""" if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created # We do it on the default stream so it occurs after all earlier computations on these tensors are done prev_layer_idx = (layer_idx - 1) % len(cache) - cache.key_cache[prev_layer_idx] = cache.key_cache[prev_layer_idx].to( + cache.layers[prev_layer_idx].keys = cache.layers[prev_layer_idx].keys.to( self.offload_device, non_blocking=True ) - cache.value_cache[prev_layer_idx] = cache.value_cache[prev_layer_idx].to( + cache.layers[prev_layer_idx].values = cache.layers[prev_layer_idx].values.to( self.offload_device, non_blocking=True ) @@ -1957,8 +1860,8 @@ def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): # Handle delayed beam search operations if self.beam_idx is not None: self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) - cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, self.beam_idx) - cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, self.beam_idx) + cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.index_select(0, self.beam_idx) + cache.layers[layer_idx].values = cache.layers[layer_idx].values.index_select(0, self.beam_idx) class QuantizedCacheProcessor(CacheProcessor): @@ -1971,16 +1874,16 @@ class QuantizedCacheProcessor(CacheProcessor): def __init__(self, cache_config: QuantizedCacheConfig): self.config = cache_config - self._quantized_key_cache: list[torch.Tensor] = [] - self._quantized_value_cache: list[torch.Tensor] = [] - self._seen_tokens = 0 + self._quantized_keys: list[torch.Tensor] = [] + self._quantized_values: list[torch.Tensor] = [] def init(self, cache: "Cache", **kwargs) -> None: """Initialize the quantized processor and validate configuration.""" 
self.config.validate() + self.erased_length = 0 # Only compatible with DynamicCache - if not isinstance(cache, DynamicCache): + if not isinstance(cache.layers[0], DynamicLayer): raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache") def post_update( @@ -1992,29 +1895,25 @@ def post_update( cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Apply quantization after cache update.""" - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_tensors.shape[-2] - if len(cache.key_cache) < layer_idx: + if len(cache) < layer_idx: raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.") # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer - # On the first forward pass, we quantize the whole prompt. + # On the first forward pass, we quantize the whole prompt (prefill, quantize_length=0) # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full. 
- is_prefill = self._get_quantized_length(layer_idx) == 0 - - if is_prefill: - self._quantized_key_cache.append(self._quantize(key_tensors.contiguous(), axis=self.config.axis_key)) - self._quantized_value_cache.append(self._quantize(value_tensors.contiguous(), axis=self.config.axis_value)) + if self._is_quantized_length_zero(layer_idx): + self._quantized_keys.append(self._quantize(key_tensors.contiguous(), axis=self.config.axis_key)) + self._quantized_values.append(self._quantize(value_tensors.contiguous(), axis=self.config.axis_value)) # Clear the residual cache - cache.key_cache[layer_idx] = torch.zeros( + self.erased_length = key_tensors.shape[-2] + cache.layers[layer_idx].keys = torch.zeros( 0, dtype=key_tensors.dtype, device=key_tensors.device, ) - cache.value_cache[layer_idx] = torch.zeros( + cache.layers[layer_idx].values = torch.zeros( 0, dtype=value_tensors.dtype, device=value_tensors.device, @@ -2024,26 +1923,27 @@ def post_update( else: # Prepend the previously quantized cache - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + dequant_key = self._dequantize(self._quantized_keys[layer_idx]) + dequant_value = self._dequantize(self._quantized_values[layer_idx]) keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2) values_to_return = torch.cat([dequant_value, value_tensors], dim=-2) if key_tensors.shape[-2] >= self.config.residual_length: # Quantize and store - self._quantized_key_cache[layer_idx] = self._quantize( + self._quantized_keys[layer_idx] = self._quantize( keys_to_return.contiguous(), axis=self.config.axis_key ) - self._quantized_value_cache[layer_idx] = self._quantize( + self._quantized_values[layer_idx] = self._quantize( values_to_return.contiguous(), axis=self.config.axis_value ) # Clear the residual cache - cache.key_cache[layer_idx] = torch.zeros( + self.erased_length += key_tensors.shape[-2] + cache.layers[layer_idx].keys = torch.zeros( 
0, dtype=key_tensors.dtype, device=key_tensors.device, ) - cache.value_cache[layer_idx] = torch.zeros( + cache.layers[layer_idx].values = torch.zeros( 0, dtype=value_tensors.dtype, device=value_tensors.device, @@ -2059,6 +1959,10 @@ def _dequantize(self, tensor: torch.Tensor) -> torch.Tensor: """Dequantize a tensor - to be implemented by specific quantization backends.""" raise NotImplementedError("Quantization backend must implement _dequantize method") + def _is_quantized_length_zero(self, layer_idx: int) -> bool: + """Check if quantized cache is empty for layer. Note: shape[-2] is unreliable since quantized tensors are bit-packed and flattened.""" + return layer_idx >= len(self._quantized_keys) + class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor): """ @@ -2107,12 +2011,6 @@ def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: """Dequantize tensor using quanto backend.""" return qtensor.dequantize() - def _get_quantized_length(self, layer_idx: int) -> int: - """Get the length of quantized cache for a layer.""" - if layer_idx < len(self._quantized_key_cache): - return self._quantized_key_cache[layer_idx].shape[-2] - return 0 - class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): """ @@ -2163,12 +2061,6 @@ def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tens tensor = self.quantizer.dequantize(quant_tensor, meta) return tensor - def _get_quantized_length(self, layer_idx: int) -> int: - """Get the length of quantized cache for a layer.""" - if layer_idx < len(self._quantized_key_cache): - return self._quantized_key_cache[layer_idx][0].shape[-2] - return 0 - class QuantizedCache(DynamicCache): """ @@ -2192,16 +2084,7 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None: else: raise ValueError(f"Unknown quantization backend `{cache_config.backend}`") - super().__init__(processor=processor) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached 
states. A layer index can be optionally passed.""" - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is - # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to verify attn_weight shape in some models - return self.processor._seen_tokens if layer_idx == 0 else self.processor._seen_tokens - 1 + super().__init__(cache_processor=processor) class QuantoQuantizedCache(QuantizedCache): @@ -2244,7 +2127,7 @@ class QuantoQuantizedCache(QuantizedCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - Cache.__init__(self, processor=QuantoQuantizedCacheProcessor(cache_config)) + Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor(cache_config)) class HQQQuantizedCache(QuantizedCache): @@ -2287,7 +2170,7 @@ class HQQQuantizedCache(QuantizedCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - Cache.__init__(self, processor=HQQQuantizedCacheProcessor(cache_config)) + Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor(cache_config)) class SinkCache(Cache): @@ -2422,3 +2305,10 @@ def reset(self): return MambaCache raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +LAYER_CLASS_MAP = { + "full_attention": StaticLayer, + "sliding_attention": SlidingWindowLayer, + # "chunked_attention": ChunkedLayer, +} diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 242b16195460..53095c19121a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1951,7 +1951,7 @@ def _get_cache( layer_device_map = self._get_layer_device_map_for_cache_init() cache_kwargs = { - "config": self.config.get_text_config(), + "model_config": self.config.get_text_config(), "max_batch_size": batch_size, "max_cache_len": max_cache_len, "dtype": 
cache_dtype, diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 0df283c83b71..36aab8699a81 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -275,15 +275,15 @@ def __init__(self, model: PreTrainedModel): self.model = model self.static_cache = StaticCache( - config=self.model.config, + model_config=self.model.config, max_batch_size=self.model.generation_config.cache_config.batch_size, max_cache_len=self.model.generation_config.cache_config.max_cache_len, device=self.model.generation_config.cache_config.device, dtype=self.model.dtype, ) - for i in range(len(self.static_cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + for i in range(len(self.static_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): """ @@ -404,7 +404,7 @@ def __init__( # Initialize the HybridCache self.cache = HybridCache( - config=self.model.config, + model_config=self.model.config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=self.model.device, @@ -412,9 +412,9 @@ def __init__( ) # Register all key and value cache tensors as buffers - for i in range(len(self.cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False) + for i in range(len(self.cache)): + self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False) def forward( self, @@ -550,7 +550,7 
@@ def __init__(self, model, max_static_cache_length, batch_size): # Initialize static cache self.static_cache = StaticCache( - config=self.config, + model_config=self.config, max_batch_size=batch_size, max_cache_len=max_static_cache_length, device="cpu", @@ -558,9 +558,9 @@ def __init__(self, model, max_static_cache_length, batch_size): ) # Register cache buffers to make them exportable - for i in range(len(self.static_cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + for i in range(len(self.static_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): # Get outputs from decoder diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 128abd56ffac..e06056d7c0be 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -692,10 +692,10 @@ def create_causal_mask( useful to easily overlay another mask on top of the causal one, for example for image tokens handling. 
""" # If we have an HybridCache structure, here we want to create the mask for the full layers - if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding: - layer_idx = past_key_values.is_sliding.index(False) - else: - layer_idx = 0 + is_sliding = [] + if past_key_values is not None: + is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] + layer_idx = is_sliding.index(True) if True in is_sliding else 0 early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx @@ -774,10 +774,10 @@ def create_sliding_window_causal_mask( useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the sliding layers - if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding: - layer_idx = past_key_values.is_sliding.index(True) - else: - layer_idx = 0 + is_sliding = [] + if past_key_values is not None: + is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] + layer_idx = is_sliding.index(True) if True in is_sliding else 0 early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx @@ -861,10 +861,10 @@ def create_chunked_causal_mask( useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. 
""" # If we have an HybridCache structure, here we want to create the mask for the sliding layers - if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding: - layer_idx = past_key_values.is_sliding.index(True) - else: - layer_idx = 0 + is_sliding = [] + if past_key_values is not None: + is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers] + layer_idx = is_sliding.index(True) if True in is_sliding else 0 early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 994bf9d85dca..57a1eff22c65 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -230,8 +230,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 465b94e13bee..bb4660ae0998 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1293,8 +1293,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = 
curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 8a0c43eafd3f..87cd60b2ca66 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -207,8 +207,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 7821e1c7b4fb..974a3fb7e9ec 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = 
self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 550e51221929..3f08d6804f9c 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -213,8 +213,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 19cac3e8c3ac..12677705002c 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -356,8 +356,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] - value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx] + key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index 
fe437fde84ed..dfe345968da6 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -182,8 +182,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] - value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx] + key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 0817e16451ac..57c8c4900e07 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1332,8 +1332,8 @@ def forward( else: indices = cache_position - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices] + value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices] else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index a3ffa710d842..7b8bcc6d37ec 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1774,8 +1774,8 @@ def forward( else: indices = cache_position - key_states = 
past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices] + value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices] else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a8504db42c82..0526a067b020 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -728,8 +728,9 @@ def forward( # Ensure layer_past is on same device as hidden_states (might not be correct) if past_key_values is not None: - past_key_values.key_cache = past_key_values.key_cache.to(hidden_states.device) - past_key_values.value_cache = past_key_values.value_cache.to(hidden_states.device) + for layer in past_key_values.layers: + layer.keys = layer.keys.to(hidden_states.device) + layer.values = layer.values.to(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states if causal_mask is not None: diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 9718e8fb736e..c988a874c20f 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -484,8 +484,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: 
key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) @@ -601,8 +601,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 3d46275bdc81..606627823514 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -290,8 +290,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 081869ec8fc5..1a15f11b3850 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -478,8 +478,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = 
curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 7d5a73667ee4..05368453e9e3 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -294,8 +294,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 7319671b485e..f199b3edfe19 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git 
a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 2585d91a3e3e..26f7b53caa67 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -239,8 +239,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 66ed4adcea4c..e3e438c0714b 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -105,16 +105,14 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.layers[layer_idx].batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] 
+ self.layers[layer_idx].batch_select_indices(indices) def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 9b6fc12ae3de..0477a942a695 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -215,16 +215,14 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.layers[layer_idx].batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] 
+ self.layers[layer_idx].batch_select_indices(indices) def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index d33edcb3dd00..985a35448e99 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -496,8 +496,8 @@ def forward( ) elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], + past_key_value.layers[self.layer_idx].keys, + past_key_value.layers[self.layer_idx].values, ) else: raise ValueError( diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 2909fb386fb5..307bbe9ae90e 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -235,8 +235,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 500231f3b48b..df864b0c1ff2 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -331,8 +331,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and 
past_key_value and is_updated: - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 5584b2ee8255..f467467e7fba 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -376,8 +376,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 2ffb53ee9e01..5ade9ee41d8a 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -228,8 +228,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py 
b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 13f0ea27a6e4..396b49dfb6e4 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -249,8 +249,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 6b90ae80d7c7..a01f88b5443b 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -770,8 +770,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.key(current_states) value_states = self.value(current_states) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 327b70b5ec73..04384b8265a6 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -425,8 +425,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if 
is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5c4285afe728..13741a20ac18 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -320,8 +320,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index b0273c8a4a33..e5bdc624feb9 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -513,8 +513,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = 
curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 2a1a84b81523..de7dbfa3e740 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -501,8 +501,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index feccf6d7d9fd..acf3bac94bcb 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -352,8 +352,8 @@ def forward( past_key_value.is_updated[self.layer_idx] = True # cross-attention: reuse cached states else: - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index ae69ae991009..522b60ebc83b 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ 
b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -308,8 +308,8 @@ def forward( past_key_value.is_updated[self.layer_idx] = True # cross-attention: reuse cached states else: - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 778a0485b4e8..4c37cd42ef63 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -394,8 +394,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 8d4e368e945b..74bbce0259d8 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -599,8 +599,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = 
curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 2b1f650c6789..bd2f3dd10ea6 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -285,8 +285,8 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 248d17cac404..c5ce00016e6a 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1140,8 +1140,8 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None for layer_idx in range(self.config.decoder_layers): layer_past_key_values = [] for cache_cls in [values.self_attention_cache, values.cross_attention_cache]: - for v in [cache_cls.key_cache, cache_cls.value_cache]: - layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu()) + for v in [cache_cls.layers[layer_idx].keys, cache_cls.layers[layer_idx].values]: + layer_past_key_values.append(v[batch_idx][None].cpu()) all_past_key_values.append(tuple(layer_past_key_values)) return 
tuple(all_past_key_values) else: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d3e9c8e03a2b..e5dc5d59e7f3 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -329,8 +329,8 @@ def forward( current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 746fd2179d20..533c72be199f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1595,9 +1595,7 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): # 3. Check cache shapes # 3.1. 
Encoder-Decoder checks if config.is_encoder_decoder: - num_cache_decoder_layers = ( - len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache) - ) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1605,30 +1603,30 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] + self_attention_layer_keys = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys ) - self_attention_layer_value_cache = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] + self_attention_layer_values = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values ) - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_key_cache = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + cross_attention_layer_keys = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys ) - cross_attention_layer_value_cache = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] + cross_attention_layer_values = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values ) - cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 
0, :] - cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] - self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) + cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] + cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] + self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) # 3.2. Decoder-only checks else: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1636,10 +1634,10 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] - self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys + self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) @pytest.mark.generate @parameterized.expand([("greedy", 1), ("beam search", 2)]) @@ -1798,8 +1796,8 @@ def test_generate_from_inputs_embeds_with_static_cache(self): max_length = max_new_tokens + inputs_embeds.shape[1] - 1 cache_shape = [batch_size, 
num_key_value_heads, max_length, head_dim] self.assertIsInstance(outputs.past_key_values, StaticCache) - self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers) - self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape) + self.assertEqual(len(outputs.past_key_values), num_hidden_layers) + self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape) @pytest.mark.generate def test_generate_continue_from_past_key_values(self): @@ -2029,8 +2027,8 @@ def test_generate_with_static_cache(self): num_hidden_layers = text_config.num_hidden_layers cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) - self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape) + self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers) + self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape) # Check 2: The outputs must be similar to the case with dynamic cache dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) @@ -2612,12 +2610,12 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), + [layer.keys.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), + [layer.values.shape for layer in decoder_past_key_values.layers], + [expected_shape] * 
len(decoder_past_key_values.layers), ) # Legacy cache format checks. This branch should be removed when all models use `Cache` by default @@ -3976,13 +3974,13 @@ def test_generate_with_static_cache_multi_accelerator(self): self.assertTrue(isinstance(results.past_key_values, StaticCache)) # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + keys_0 = results.past_key_values.layers[0].keys + values_0 = results.past_key_values.layers[0].values + self.assertTrue(keys_0.device == values_0.device == torch.device(0)) - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + keys_1 = results.past_key_values.layers[1].keys + values_1 = results.past_key_values.layers[1].values + self.assertTrue(keys_1.device == values_1.device == torch.device(1)) @pytest.mark.generate @require_torch_multi_accelerator @@ -4054,13 +4052,13 @@ def test_init_static_cache_multi_accelerator(self): results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + keys_0 = results.past_key_values.layers[0].keys + values_0 = results.past_key_values.layers[0].values + self.assertTrue(keys_0.device == values_0.device == torch.device(0)) - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + keys_1 = results.past_key_values.layers[1].keys + values_1 = results.past_key_values.layers[1].values + self.assertTrue(keys_1.device == 
values_1.device == torch.device(1)) @slow def test_padding_input_contrastive_search_gpt2(self): diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py index 6c0c3a19d067..87f7b2abb0e9 100644 --- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py +++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py @@ -440,13 +440,11 @@ def test_past_key_values_format(self): # difference: last dim k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim v_embed_dim = config.v_head_dim - self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) - self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) + self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) + self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) # build the full cache shapes num_hidden_layers = config.num_hidden_layers - all_cache_shapes = [ - [self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers) - ] + all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)] super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes) @require_torch_large_accelerator diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py index f9427160c254..34a7a9884728 100644 --- a/tests/models/dia/test_modeling_dia.py +++ b/tests/models/dia/test_modeling_dia.py @@ -399,12 +399,12 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), + [layer.keys.shape for layer in decoder_past_key_values.layers], + [expected_shape] * 
len(decoder_past_key_values.layers), ) self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), + [layer.values.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) def _check_scores(self, batch_size, scores, generated_length, config): diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index b0a0a6a3ccb4..ecd2af9fdc6c 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -235,8 +235,8 @@ def copy_cache(cache: DynamicCache): """Deep copy a DynamicCache to reuse the same one multiple times.""" new_cache = cache for i in range(len(cache)): - new_cache.key_cache[i] = cache.key_cache[i].clone() - new_cache.value_cache[i] = cache.value_cache[i].clone() + new_cache.layers[i].keys = cache.layers[i].keys.clone() + new_cache.layers[i].values = cache.layers[i].values.clone() # Cached forward once with the attention mask provided and the other time without it (which should assume full attention) # We need to run both on a copy of the cache, otherwise it is modified in-place diff --git a/tests/models/t5gemma/test_modeling_t5gemma.py b/tests/models/t5gemma/test_modeling_t5gemma.py index fd61e5e5c5db..269fee53d165 100644 --- a/tests/models/t5gemma/test_modeling_t5gemma.py +++ b/tests/models/t5gemma/test_modeling_t5gemma.py @@ -272,7 +272,7 @@ def create_and_check_model( self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertIsNotNone(decoder_past) self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers) - self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers) + self.parent.assertEqual(len(decoder_past.cross_attention_cache), 
config.decoder.num_hidden_layers) def check_prepare_lm_labels_via_shift_left( self, @@ -1069,9 +1069,7 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): # 3. Check cache shapes # 3.1. Encoder-Decoder checks if config.is_encoder_decoder: - num_cache_decoder_layers = ( - len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache) - ) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1079,30 +1077,30 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] + self_attention_layer_keys = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys ) - self_attention_layer_value_cache = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] + self_attention_layer_values = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values ) - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_key_cache = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + cross_attention_layer_keys = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys ) - cross_attention_layer_value_cache = ( - past_kv[i][3] if is_legacy_cache else 
past_kv.cross_attention_cache.value_cache[i] + cross_attention_layer_values = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values ) - cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] - cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] - self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) + cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] + cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] + self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) # 3.2. Decoder-only checks else: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1110,10 +1108,10 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] - self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys + self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) 
@unittest.skip("Mismatch issue doesn't exist in T5Gemma.") def test_load_with_mismatched_shapes(self): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 8580316d26bd..d5c1463cf618 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -171,7 +171,9 @@ def _random_kvs(config): return random_keys, random_values mha_config = LlamaConfig(num_attention_heads=32) - mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) + mha_static_cache = StaticCache( + model_config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device + ) cached_keys, cached_values = mha_static_cache.update( *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -179,7 +181,9 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 32, 10, 128)) gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) - gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + gqa_static_cache = StaticCache( + model_config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device + ) cached_keys, cached_values = gqa_static_cache.update( *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -187,7 +191,9 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 4, 10, 128)) mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) - mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + mqa_static_cache = StaticCache( + model_config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device + ) cached_keys, cached_values = mqa_static_cache.update( *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -323,7 +329,7 @@ def test_quantized_cache_generation(self, backend): ) 
self.assertIsInstance(gen_out.past_key_values, QuantizedCache) - processor = gen_out.past_key_values.processor + processor = gen_out.past_key_values.cache_processor if backend == "quanto": self.assertIsInstance(processor, QuantoQuantizedCacheProcessor) elif backend == "hqq": @@ -332,12 +338,10 @@ def test_quantized_cache_generation(self, backend): decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) self.assertListEqual(decoded, expected_generation) - self.assertTrue(len(processor._quantized_key_cache) > 0) + self.assertTrue(len(processor._quantized_keys) > 0) # Check that something is actually quantized - has_been_quantized = any( - (q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_key_cache - ) + has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys) self.assertTrue(has_been_quantized) @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) @@ -654,7 +658,7 @@ def test_dynamic_cache_exportability(self): past_key_values=DynamicCache(), use_cache=True, ) - self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers) + self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers) self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs)) self.assertEqual( 3, @@ -675,11 +679,9 @@ def test_dynamic_cache_exportability(self): use_cache=True, ) self.assertTrue(torch.allclose(res.logits, res_eager.logits)) - for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache): - self.assertTrue(torch.allclose(k1, k2)) - - for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache): - self.assertTrue(torch.allclose(v1, v2)) + for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers): + self.assertTrue(torch.allclose(l1.keys, l2.keys)) + self.assertTrue(torch.allclose(l1.values, l2.values)) def 
test_dynamic_cache_exportability_multiple_run(self): # When exporting with DynamicCache, you should export two graphs: @@ -703,7 +705,7 @@ def test_dynamic_cache_exportability_multiple_run(self): past_key_values=DynamicCache(), use_cache=True, ) - self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers) + self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers) self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs)) self.assertEqual( 3, @@ -728,9 +730,9 @@ def test_dynamic_cache_exportability_multiple_run(self): shapes = torch.export.ShapesCollection() dyn = torch.export.Dim("seq", max=512) - for ix in range(len(past_key_values.key_cache)): - shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None) - shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None) + for ix in range(len(past_key_values)): + shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None) + shapes[past_key_values.layers[ix].values] = (None, None, dyn, None) ep_second = torch.export.export( model, @@ -771,11 +773,9 @@ def test_dynamic_cache_exportability_multiple_run(self): use_cache=True, ) - for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache): - self.assertTrue(torch.allclose(k1, k2)) - - for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache): - self.assertTrue(torch.allclose(v1, v2)) + for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers): + self.assertTrue(torch.allclose(l1.keys, l2.keys)) + self.assertTrue(torch.allclose(l1.values, l2.values)) def test_static_cache_exportability(self): """ @@ -922,7 +922,7 @@ def setUp(self): def test_static_cache_out_of_bounds(self): """Test StaticCache raises IndexError for out-of-bounds positions.""" - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + static_cache = 
StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len with self.assertRaises(IndexError): @@ -944,7 +944,7 @@ def test_static_cache(self): update pos 3: [1.0, 2.0, 3.0, 4.0] """ # Scenario 1: Fill up to near capacity - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + static_cache = StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None) static_cache.update( @@ -954,7 +954,7 @@ def test_static_cache(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" + static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" ) # Scenario 2: Fill to capacity @@ -965,7 +965,7 @@ def test_static_cache(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" + static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" ) def test_sliding_window_cache(self): @@ -984,7 +984,9 @@ def test_sliding_window_cache(self): result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ # Scenario 1: Update within window, no slide yet - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache( + model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len + ) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -999,13 +1001,15 @@ def 
test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed", ) # Scenario 2: Update causing slide - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache( + model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len + ) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -1020,13 +1024,15 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed", ) # Scenario 3: Long prompt handling - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache( + model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len + ) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] sliding_cache.update( key_states=long_prefill, @@ -1035,7 +1041,7 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "SlidingWindowCache Scenario 3 failed", ) @@ -1054,7 +1060,7 @@ def test_hybrid_cache_static_mode(self): config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) # Scenario 1 - hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + 
hybrid_cache_static_mode = HybridCache(model_config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] hybrid_cache_static_mode.update( key_states=prefill, @@ -1069,7 +1075,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Static Scenario 1 failed", ) @@ -1082,7 +1088,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "HybridCache Static Scenario 2 failed", ) @@ -1106,7 +1112,7 @@ def test_hybrid_cache_sliding_mode(self): result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ # Scenario 1: Update within window, no slide yet - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1121,13 +1127,13 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed", ) # Scenario 2: Update causing first slide - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 
3.0, 4.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1142,7 +1148,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed", ) @@ -1155,13 +1161,13 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed", ) # Scenario 4: Long prompt handling - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] hybrid_cache.update( key_states=long_prefill, @@ -1170,7 +1176,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 4 failed", ) @@ -1190,10 +1196,10 @@ def test_dynamic_cache(self): cache = DynamicCache() cache.update(prefill, prefill, 0) cache.update(update3, update3, 0) - self.assertEqual(cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") + self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") cache.update(update4, update4, 0) self.assertEqual( - cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" + 
cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" ) # Scenario 2: prefill and update for two layers independently @@ -1210,8 +1216,10 @@ def test_dynamic_cache(self): cache.update(update4, update4, 0) cache.update(update4_1, update4_1, 1) self.assertEqual( - cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" + cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" ) self.assertEqual( - cache.key_cache[1][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0], "DynamicCache Scenario 2 layer 1 failed" + cache.layers[1].keys[0, 0, :, 0].tolist(), + [10.0, 20.0, 30.0, 40.0], + "DynamicCache Scenario 2 layer 1 failed", ) diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index bc247b2b6011..eb101ab566aa 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -956,8 +956,9 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str): idx += 1 if "".join(source[start_idx:idx])[:-1] != old_doc_args: - # Args are not fully defined in the docstring of this object - return + raise ValueError( + f"Expected\n{old_doc_args}\nbut got\n{''.join(source[start_idx:idx])[:-1]}\n in {find_source_file(obj)}: {obj.__name__}" + ) obj_file = find_source_file(obj) with open(obj_file, "r", encoding="utf-8") as f: