From acb901e0324b06a83b7a5364c7556a8b4ecead20 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Thu, 24 Apr 2025 11:52:39 +0200 Subject: [PATCH 01/19] squash rebase --- src/transformers/cache_utils.py | 276 +++++++++++++++++--------------- tests/utils/test_cache_utils.py | 201 +++++++++++++++++++++++ 2 files changed, 352 insertions(+), 125 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 85a09f03de22..ea049fd2291e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -21,6 +21,104 @@ logger = logging.get_logger(__name__) +# Utility functions for static/sliding cache update logic +def _static_cache_update_logic( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_position: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the static cache tensors in place. + + Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. + key_states (`torch.Tensor`): The new key states to add. + value_states (`torch.Tensor`): The new value states to add. + cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted. + If None, the entire cache is overwritten (prefill). + + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place). + """ + if cache_position is None: + # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. + k_cache.copy_(key_states) + v_cache.copy_(value_states) + else: + # Generation phase. Update specific positions. + # Use index_copy_ for in-place update (compile-friendly). + try: + k_cache.index_copy_(2, cache_position, key_states) + v_cache.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # Fallback for devices like MPS where index_copy_ might not be supported. + k_cache[:, :, cache_position] = key_states + v_cache[:, :, cache_position] = value_states + return k_cache, v_cache + + +def _sliding_cache_update_logic( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_position: torch.LongTensor, + max_cache_len: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the sliding window cache tensors, returning the potentially modified tensors. + + Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. + key_states (`torch.Tensor`): The new key states to add. + value_states (`torch.Tensor`): The new value states to add. + cache_position (`torch.LongTensor`): The position indices where the new states should be inserted. + max_cache_len (`int`): The maximum length of the sliding window cache. + + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update. + For prefill > window, these are the full input states. + Otherwise, they are the updated cache tensors. 
+ """ + # Handle prefill phase when prompt length > sliding_window_size + if cache_position.shape[0] > max_cache_len: + new_k = key_states[:, :, -max_cache_len:, :] + new_v = value_states[:, :, -max_cache_len:, :] + k_cache.copy_(new_k) + v_cache.copy_(new_v) + return key_states, value_states + + # Sliding window logic for generation phase or prefill < window + slicing = torch.arange(max_cache_len, device=value_states.device) + current_seq_len = cache_position[-1] + 1 # Use last position to determine current length + should_shift = current_seq_len > max_cache_len + indices = (slicing + should_shift.int()) % max_cache_len + + k_out_shifted = k_cache[:, :, indices] + v_out_shifted = v_cache[:, :, indices] + + # Clamp cache_position to determine the *target index* within the shifted cache view + update_position = cache_position.clamp(min=0, max=max_cache_len - 1) + + try: + k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) + v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) + except NotImplementedError: + # Fallback for MPS: clone and modify the clone + k_out_updated = k_out_shifted.clone() + v_out_updated = v_out_shifted.clone() + k_out_updated[:, :, update_position] = key_states + v_out_updated[:, :, update_position] = value_states + + k_cache.copy_(k_out_updated) + v_cache.copy_(v_out_updated) + return k_out_updated, v_out_updated + + class Cache: """ Base, abstract class for all caches. The actual data structure is specific to each subclass. @@ -1262,28 +1360,16 @@ def update( """ if cache_kwargs is None: cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place - # operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. 
- k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - return k_out, v_out + key_states = key_states.to(self.key_cache[layer_idx].dtype) + value_states = value_states.to(self.value_cache[layer_idx].dtype) + return _static_cache_update_logic( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + key_states, + value_states, + cache_kwargs.get("cache_position"), + ) def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model.""" @@ -1386,6 +1472,18 @@ def __init__( layer_device_map=layer_device_map, ) + def _sliding_update_logic(self, layer_idx, cache_position, key_states, value_states): + """Performs the actual cache update for sliding window cache.""" + # Call the standalone utility function + return _sliding_cache_update_logic( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + key_states, + value_states, + cache_position, + self.max_cache_len, # Pass the window size + ) + def update( self, key_states: torch.Tensor, @@ -1396,46 +1494,16 @@ def update( if cache_kwargs is None: cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) - if cache_position.shape[0] > self.max_cache_len: - k_out = key_states[:, :, -self.max_cache_len :, :] - v_out = value_states[:, :, -self.max_cache_len :, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) - cache_position = cache_position.clamp(0, self.max_cache_len - 1) - to_shift = cache_position > self.max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len - - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. 
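# --- Editor's illustrative sketch (not part of the patch): the removed `ones().cumsum(0)`
# bookkeeping above and the new `arange()`-based formula in `_sliding_cache_update_logic`
# select the same rotation indices, both before and after the window fills up.
import torch

max_cache_len = 4
for shifting in (False, True):  # False: window not yet full, True: past the window
    old_slicing = torch.ones(max_cache_len, dtype=torch.long).cumsum(0)           # [1, 2, 3, 4]
    old_indices = (old_slicing + int(shifting) - 1) % max_cache_len               # removed formula
    new_indices = (torch.arange(max_cache_len) + int(shifting)) % max_cache_len   # new formula
    assert torch.equal(old_indices, new_indices)  # [0, 1, 2, 3] or [1, 2, 3, 0]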
- k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + if cache_position is None: + raise ValueError("`cache_position` must be provided for SlidingWindowCache.") - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out + # Ensure correct dtype + key_states = key_states.to(self.key_cache[layer_idx].dtype) + value_states = value_states.to(self.value_cache[layer_idx].dtype) - return k_out, v_out + # Call the extracted logic + return self._sliding_update_logic(layer_idx, cache_position, key_states, value_states) def get_max_cache_shape(self) -> Optional[int]: return self.max_cache_len @@ -1678,12 +1746,13 @@ def __init__( super().__init__() if not hasattr(config, "sliding_window") or config.sliding_window is None: raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "Setting `cache_implementation` to 'hybrid' requires the model config supporting " "sliding window attention, please check if there is a `sliding_window` field in the model " "config and it's not set to None." ) - self.max_cache_len = max_cache_len - self._sliding_window_max_len = min(config.sliding_window, max_cache_len) + self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings + # Sliding layers can't be larger than the overall max cache len + self.sliding_window_len = min(config.sliding_window, self.max_cache_len) self.max_batch_size = max_batch_size # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads self.head_dim = ( @@ -1701,14 +1770,9 @@ def __init__( ) self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] - global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim) - sliding_cache_shape = ( - self.max_batch_size, - self.num_key_value_heads, - self._sliding_window_max_len, - self.head_dim, - ) - device = torch.device(device) if device is not None else None + global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) + sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim) + device = torch.device(device) if device is not None and isinstance(device, str) else None for i in range(config.num_hidden_layers): if layer_device_map is not None: layer_device = layer_device_map[i] @@ -1716,7 +1780,7 @@ def __init__( layer_device = device # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. 
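# --- Editor's illustrative sketch (not part of the patch): how the global vs. sliding
# cache shapes above are assigned per layer. The pattern formula is the one HybridCache
# uses (`bool((i + 1) % layer_switch)`); the model dimensions are invented for the example.
num_layers, batch, kv_heads, head_dim = 4, 1, 2, 8
max_cache_len, sliding_window = 16, 4
sliding_window_len = min(sliding_window, max_cache_len)

layer_switch = 2  # sliding_window_pattern
is_sliding = [bool((i + 1) % layer_switch) for i in range(num_layers)]  # [True, False, True, False]

shapes = [
    (batch, kv_heads, sliding_window_len if sliding else max_cache_len, head_dim)
    for sliding in is_sliding
]
# layers 0 and 2 get the small sliding buffers, layers 1 and 3 the full-length ones
assert shapes == [(1, 2, 4, 8), (1, 2, 16, 8), (1, 2, 4, 8), (1, 2, 16, 8)]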
- cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape + cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) torch._dynamo.mark_static_address(new_layer_key_cache) @@ -1724,42 +1788,6 @@ def __init__( self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) - def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - if cache_position.shape[0] > max_cache_len: - k_out = key_states[:, :, -max_cache_len:, :] - v_out = value_states[:, :, -max_cache_len:, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) - cache_position = cache_position.clamp(0, max_cache_len - 1) - to_shift = cache_position > max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % max_cache_len - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - return k_out, v_out - - def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out - def update( self, key_states: torch.Tensor, @@ -1770,34 +1798,32 @@ def update( if cache_kwargs is None: cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") - sliding_window = cache_kwargs.get("sliding_window") + if cache_position is None: + raise ValueError("`cache_position` must be provided for HybridCache.") + + is_sliding_layer = self.is_sliding[layer_idx].item() - # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used + # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used # when the cache is initialized in the forward pass (e.g. 
Gemma2) if self.key_cache[layer_idx].device != key_states.device: self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) if self.value_cache[layer_idx].device != value_states.device: self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) + k_cache = self.key_cache[layer_idx] + v_cache = self.value_cache[layer_idx] + key_states = key_states.to(k_cache.dtype) + value_states = value_states.to(v_cache.dtype) - if sliding_window: - update_fn = self._sliding_update + if is_sliding_layer: + return _sliding_cache_update_logic( + k_cache, v_cache, key_states, value_states, cache_position, + k_cache.shape[2] # Use actual cache dim as max cache len + ) else: - update_fn = self._static_update - - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) + return _static_cache_update_logic( + k_cache, v_cache, key_states, value_states, cache_position + ) def get_max_cache_shape(self) -> Optional[int]: return self.max_cache_len diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 980f57aa342e..4f97b27028ec 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -16,6 +16,7 @@ import unittest from parameterized import parameterized +from types import SimpleNamespace from transformers import set_seed from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS @@ -46,6 +47,8 @@ GenerationConfig, LlamaConfig, StaticCache, + SlidingWindowCache, + HybridCache, convert_and_export_with_cache, ) @@ -695,3 +698,201 @@ def test_hybrid_cache_exportability(self): dynamic_shapes=dynamic_shapes, strict=False, ) + +class SyntheticCacheTest(unittest.TestCase): + """ + Synthetic tests for StaticCache, SlidingWindowCache, and HybridCache. + Uses window_size=4, max_cache_len=4 for all tests. + """ + + def setUp(self): + """Set up common configuration for all tests.""" + self.window_size = 4 + self.max_cache_len = 4 + self.config = SimpleNamespace( + num_hidden_layers=1, + num_key_value_heads=1, + num_attention_heads=1, + head_dim=1, + hidden_size=1, + sliding_window=self.window_size, + sliding_window_pattern=2 # Example pattern + ) + + def _extract_both_caches(self, cache): + """Helper to extract flattened key and value states from the cache.""" + # Assumes layer 0, batch 0, head 0 + k = cache.key_cache[0][0, 0, :, 0].tolist() + v = cache.value_cache[0][0, 0, :, 0].tolist() + return k, v + + def test_static_cache_within_bounds(self): + """ + Test StaticCache preserves values at written positions within bounds. 
+ Example (max_cache_len=4): + step 1 (pos 0): [1.0, 0.0, 0.0, 0.0] + step 2 (pos 1): [1.0, 2.0, 0.0, 0.0] + step 3 (pos 2): [1.0, 2.0, 3.0, 0.0] + step 4 (pos 3): [1.0, 2.0, 3.0, 4.0] + """ + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + expected_state = [0.0] * self.max_cache_len + + for step in range(1, self.max_cache_len + 1): # Test up to cache capacity + pos_idx = step - 1 + value = float(step) + + static_cache.update( + key_states=torch.tensor([[[[value]]]]), + value_states=torch.tensor([[[[value]]]]), # Use same value for simplicity + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([pos_idx])} + ) + + expected_state[pos_idx] = value + stored_keys, stored_values = self._extract_both_caches(static_cache) + + self.assertEqual(stored_keys, expected_state, f"Static Key cache failed at step {step}") + self.assertEqual(stored_values, expected_state, f"Static Value cache failed at step {step}") + + def test_static_cache_out_of_bounds(self): + """Test StaticCache raises IndexError for out-of-bounds positions.""" + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len + + with self.assertRaises(IndexError): + static_cache.update( + key_states=torch.tensor([[[[1.0]]]]), + value_states=torch.tensor([[[[1.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": pos_out_of_bounds} + ) + + def test_sliding_window_cache(self): + """ + Test SlidingWindowCache accumulates then slides. + Example (window_size=4): + step 4 (pos 3): [1.0, 2.0, 3.0, 4.0] + step 5 (pos 4): [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1) + step 6 (pos 5): [3.0, 4.0, 5.0, 6.0] + """ + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + expected_state = [0.0] * self.window_size + + for step in range(1, 7): # Test beyond window size + pos_idx = step - 1 + value = float(step) + + sliding_cache.update( + key_states=torch.tensor([[[[value]]]]), + value_states=torch.tensor([[[[value]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([pos_idx]), "sliding_window": self.window_size} + ) + + # Calculate expected state based on corrected sliding logic + clamped_pos = min(pos_idx, self.window_size - 1) + to_shift = pos_idx > self.window_size - 1 + + if to_shift: + expected_state = expected_state[1:] + [0.0] # Shift left, add placeholder + + expected_state[clamped_pos] = value # Insert new value + + # Only verify key cache for simplicity, assuming value is symmetrical + stored_keys, _ = self._extract_both_caches(sliding_cache) + self.assertEqual(stored_keys, expected_state, f"SlidingWindowCache failed at step {step}") + + def test_hybrid_cache_static_mode(self): + """ + Test HybridCache acts like StaticCache when 'sliding_window' is absent. 
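# --- Editor's illustrative sketch (not part of the patch): the long-prompt prefill path of
# `_sliding_cache_update_logic`, which the "prompt longer than max_cache_len" tests further
# down exercise. Only the last `max_cache_len` tokens are stored, but the full states are
# returned so attention still sees the whole prompt on this forward pass. Shapes are invented.
import torch

max_cache_len = 4
k_cache = torch.zeros(1, 1, max_cache_len, 1)
prompt_keys = torch.arange(1.0, 7.0).view(1, 1, 6, 1)   # 6-token prompt: 1.0 .. 6.0
cache_position = torch.arange(6)

if cache_position.shape[0] > max_cache_len:
    k_cache.copy_(prompt_keys[:, :, -max_cache_len:, :])  # keep only the last 4 tokens
    returned = prompt_keys                                  # but return everything

assert k_cache[0, 0, :, 0].tolist() == [3.0, 4.0, 5.0, 6.0]
assert returned.shape[2] == 6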
+ Example (max_cache_len=4): + step 4 (pos 3): [1.0, 2.0, 3.0, 4.0] + step 5 (pos 3): [1.0, 2.0, 3.0, 5.0] (pos clamped, overwrites last) + step 6 (pos 3): [1.0, 2.0, 3.0, 6.0] (pos clamped, overwrites last) + """ + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + expected_state = [0.0] * self.max_cache_len + + for step in range(1, 7): # Test beyond cache size + pos_idx_clamped = min(step - 1, self.max_cache_len - 1) + value = float(step) + + hybrid_cache.update( + key_states=torch.tensor([[[[value]]]]), + value_states=torch.tensor([[[[value]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([pos_idx_clamped])} # Use clamped pos + ) + + expected_state[pos_idx_clamped] = value + stored_keys, _ = self._extract_both_caches(hybrid_cache) + self.assertEqual(stored_keys, expected_state, f"HybridCache (static) failed at step {step}") + + def test_hybrid_cache_sliding_mode(self): + """ + Test HybridCache acts like SlidingWindowCache when 'sliding_window' is present. + Example (window_size=4): + step 4 (pos 3): [1.0, 2.0, 3.0, 4.0] + step 5 (pos 4): [2.0, 3.0, 4.0, 5.0] + step 6 (pos 5): [3.0, 4.0, 5.0, 6.0] + """ + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + expected_state = [0.0] * self.window_size + + for step in range(1, 7): # Test beyond window size + pos_idx = step - 1 + value = float(step) + + hybrid_cache.update( + key_states=torch.tensor([[[[value]]]]), + value_states=torch.tensor([[[[value]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([pos_idx]), "sliding_window": self.window_size} + ) + + clamped_pos = min(pos_idx, self.window_size - 1) + to_shift = pos_idx > self.window_size - 1 + if to_shift: + expected_state = expected_state[1:] + [0.0] + expected_state[clamped_pos] = value + + stored_keys, _ = self._extract_both_caches(hybrid_cache) + self.assertEqual(stored_keys, expected_state, f"HybridCache (sliding) failed at step {step}") + + def test_sliding_window_cache_prompt_longer_than_max_cache_len(self): + """Test SlidingWindowCache when prompt length > max_cache_len (should keep only last max_cache_len tokens).""" + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prompt_len = self.max_cache_len + 2 # e.g., 6 if max_cache_len=4 + values = [float(i + 1) for i in range(prompt_len)] + key_states = torch.tensor([[[[v] for v in values]]]) # shape (1,1,6,1) + value_states = torch.tensor([[[[v] for v in values]]]) # shape (1,1,6,1) + cache_position = torch.arange(prompt_len) + sliding_cache.update( + key_states=key_states, + value_states=value_states, + layer_idx=0, + cache_kwargs={"cache_position": cache_position, "sliding_window": self.window_size} + ) + stored_keys, stored_values = self._extract_both_caches(sliding_cache) + self.assertEqual(stored_keys, values[-self.window_size:], "SlidingWindowCache did not keep last window tokens") + self.assertEqual(stored_values, values[-self.window_size:], "SlidingWindowCache did not keep last window tokens") + + def test_hybrid_cache_prompt_longer_than_max_cache_len(self): + """Test HybridCache when prompt length > max_cache_len (should keep only last max_cache_len tokens in sliding mode).""" + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prompt_len = self.max_cache_len + 2 # e.g., 6 if max_cache_len=4 + values = [float(i + 1) for i in range(prompt_len)] + key_states = torch.tensor([[[[v] for 
v in values]]]) # shape (1,1,6,1) + value_states = torch.tensor([[[[v] for v in values]]]) # shape (1,1,6,1) + cache_position = torch.arange(prompt_len) + # Use sliding mode + hybrid_cache.update( + key_states=key_states, + value_states=value_states, + layer_idx=0, + cache_kwargs={"cache_position": cache_position, "sliding_window": self.window_size} + ) + stored_keys, stored_values = self._extract_both_caches(hybrid_cache) + self.assertEqual(stored_keys, values[-self.window_size:], "HybridCache (sliding) did not keep last window tokens") + self.assertEqual(stored_values, values[-self.window_size:], "HybridCache (sliding) did not keep last window tokens") \ No newline at end of file From 4eacd7dad7f5c6fe07377dffd2a517363d7a5408 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 6 May 2025 12:05:46 +0200 Subject: [PATCH 02/19] ruff --- src/transformers/cache_utils.py | 20 ++++---- tests/utils/test_cache_utils.py | 87 ++++++++++++++++++--------------- 2 files changed, 59 insertions(+), 48 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ea049fd2291e..6affdd7a06f0 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -94,7 +94,7 @@ def _sliding_cache_update_logic( # Sliding window logic for generation phase or prefill < window slicing = torch.arange(max_cache_len, device=value_states.device) - current_seq_len = cache_position[-1] + 1 # Use last position to determine current length + current_seq_len = cache_position[-1] + 1 # Use last position to determine current length should_shift = current_seq_len > max_cache_len indices = (slicing + should_shift.int()) % max_cache_len @@ -1481,7 +1481,7 @@ def _sliding_update_logic(self, layer_idx, cache_position, key_states, value_sta key_states, value_states, cache_position, - self.max_cache_len, # Pass the window size + self.max_cache_len, ) def update( @@ -1496,7 +1496,7 @@ def update( cache_position = cache_kwargs.get("cache_position") if cache_position is None: - raise ValueError("`cache_position` must be provided for SlidingWindowCache.") + raise ValueError("`cache_position` must be provided for SlidingWindowCache.") # Ensure correct dtype key_states = key_states.to(self.key_cache[layer_idx].dtype) @@ -1803,7 +1803,7 @@ def update( is_sliding_layer = self.is_sliding[layer_idx].item() - # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used + # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used # when the cache is initialized in the forward pass (e.g. 
Gemma2) if self.key_cache[layer_idx].device != key_states.device: self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) @@ -1817,13 +1817,15 @@ def update( if is_sliding_layer: return _sliding_cache_update_logic( - k_cache, v_cache, key_states, value_states, cache_position, - k_cache.shape[2] # Use actual cache dim as max cache len + k_cache, + v_cache, + key_states, + value_states, + cache_position, + k_cache.shape[2], # Use actual cache dim as max cache len ) else: - return _static_cache_update_logic( - k_cache, v_cache, key_states, value_states, cache_position - ) + return _static_cache_update_logic(k_cache, v_cache, key_states, value_states, cache_position) def get_max_cache_shape(self) -> Optional[int]: return self.max_cache_len diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 4f97b27028ec..d20a30835172 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -699,12 +699,13 @@ def test_hybrid_cache_exportability(self): strict=False, ) + class SyntheticCacheTest(unittest.TestCase): """ Synthetic tests for StaticCache, SlidingWindowCache, and HybridCache. Uses window_size=4, max_cache_len=4 for all tests. """ - + def setUp(self): """Set up common configuration for all tests.""" self.window_size = 4 @@ -716,7 +717,7 @@ def setUp(self): head_dim=1, hidden_size=1, sliding_window=self.window_size, - sliding_window_pattern=2 # Example pattern + sliding_window_pattern=2, # Example pattern ) def _extract_both_caches(self, cache): @@ -737,35 +738,35 @@ def test_static_cache_within_bounds(self): """ static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) expected_state = [0.0] * self.max_cache_len - - for step in range(1, self.max_cache_len + 1): # Test up to cache capacity + + for step in range(1, self.max_cache_len + 1): # Test up to cache capacity pos_idx = step - 1 value = float(step) - + static_cache.update( key_states=torch.tensor([[[[value]]]]), - value_states=torch.tensor([[[[value]]]]), # Use same value for simplicity + value_states=torch.tensor([[[[value]]]]), # Use same value for simplicity layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([pos_idx])} + cache_kwargs={"cache_position": torch.tensor([pos_idx])}, ) - + expected_state[pos_idx] = value stored_keys, stored_values = self._extract_both_caches(static_cache) - + self.assertEqual(stored_keys, expected_state, f"Static Key cache failed at step {step}") self.assertEqual(stored_values, expected_state, f"Static Value cache failed at step {step}") def test_static_cache_out_of_bounds(self): """Test StaticCache raises IndexError for out-of-bounds positions.""" - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len - + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len + with self.assertRaises(IndexError): static_cache.update( key_states=torch.tensor([[[[1.0]]]]), value_states=torch.tensor([[[[1.0]]]]), layer_idx=0, - cache_kwargs={"cache_position": pos_out_of_bounds} + cache_kwargs={"cache_position": pos_out_of_bounds}, ) def test_sliding_window_cache(self): @@ -777,30 +778,30 @@ def test_sliding_window_cache(self): step 6 (pos 5): [3.0, 4.0, 5.0, 6.0] """ sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, 
max_cache_len=self.max_cache_len) - expected_state = [0.0] * self.window_size - - for step in range(1, 7): # Test beyond window size + expected_state = [0.0] * self.window_size + + for step in range(1, 7): # Test beyond window size pos_idx = step - 1 value = float(step) - + sliding_cache.update( key_states=torch.tensor([[[[value]]]]), value_states=torch.tensor([[[[value]]]]), layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([pos_idx]), "sliding_window": self.window_size} + cache_kwargs={"cache_position": torch.tensor([pos_idx]), "sliding_window": self.window_size}, ) - + # Calculate expected state based on corrected sliding logic clamped_pos = min(pos_idx, self.window_size - 1) - to_shift = pos_idx > self.window_size - 1 - + to_shift = pos_idx > self.window_size - 1 + if to_shift: - expected_state = expected_state[1:] + [0.0] # Shift left, add placeholder - - expected_state[clamped_pos] = value # Insert new value + expected_state = expected_state[1:] + [0.0] # Shift left, add placeholder + + expected_state[clamped_pos] = value # Insert new value # Only verify key cache for simplicity, assuming value is symmetrical - stored_keys, _ = self._extract_both_caches(sliding_cache) + stored_keys, _ = self._extract_both_caches(sliding_cache) self.assertEqual(stored_keys, expected_state, f"SlidingWindowCache failed at step {step}") def test_hybrid_cache_static_mode(self): @@ -813,18 +814,18 @@ def test_hybrid_cache_static_mode(self): """ hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) expected_state = [0.0] * self.max_cache_len - - for step in range(1, 7): # Test beyond cache size + + for step in range(1, 7): # Test beyond cache size pos_idx_clamped = min(step - 1, self.max_cache_len - 1) value = float(step) - + hybrid_cache.update( key_states=torch.tensor([[[[value]]]]), value_states=torch.tensor([[[[value]]]]), layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([pos_idx_clamped])} # Use clamped pos + cache_kwargs={"cache_position": torch.tensor([pos_idx_clamped])}, # Use clamped pos ) - + expected_state[pos_idx_clamped] = value stored_keys, _ = self._extract_both_caches(hybrid_cache) self.assertEqual(stored_keys, expected_state, f"HybridCache (static) failed at step {step}") @@ -839,8 +840,8 @@ def test_hybrid_cache_sliding_mode(self): """ hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) expected_state = [0.0] * self.window_size - - for step in range(1, 7): # Test beyond window size + + for step in range(1, 7): # Test beyond window size pos_idx = step - 1 value = float(step) @@ -848,7 +849,7 @@ def test_hybrid_cache_sliding_mode(self): key_states=torch.tensor([[[[value]]]]), value_states=torch.tensor([[[[value]]]]), layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([pos_idx]), "sliding_window": self.window_size} + cache_kwargs={"cache_position": torch.tensor([pos_idx]), "sliding_window": self.window_size}, ) clamped_pos = min(pos_idx, self.window_size - 1) @@ -872,11 +873,15 @@ def test_sliding_window_cache_prompt_longer_than_max_cache_len(self): key_states=key_states, value_states=value_states, layer_idx=0, - cache_kwargs={"cache_position": cache_position, "sliding_window": self.window_size} + cache_kwargs={"cache_position": cache_position, "sliding_window": self.window_size}, ) stored_keys, stored_values = self._extract_both_caches(sliding_cache) - self.assertEqual(stored_keys, values[-self.window_size:], "SlidingWindowCache did not keep last window tokens") - 
self.assertEqual(stored_values, values[-self.window_size:], "SlidingWindowCache did not keep last window tokens") + self.assertEqual( + stored_keys, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens" + ) + self.assertEqual( + stored_values, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens" + ) def test_hybrid_cache_prompt_longer_than_max_cache_len(self): """Test HybridCache when prompt length > max_cache_len (should keep only last max_cache_len tokens in sliding mode).""" @@ -891,8 +896,12 @@ def test_hybrid_cache_prompt_longer_than_max_cache_len(self): key_states=key_states, value_states=value_states, layer_idx=0, - cache_kwargs={"cache_position": cache_position, "sliding_window": self.window_size} + cache_kwargs={"cache_position": cache_position, "sliding_window": self.window_size}, ) stored_keys, stored_values = self._extract_both_caches(hybrid_cache) - self.assertEqual(stored_keys, values[-self.window_size:], "HybridCache (sliding) did not keep last window tokens") - self.assertEqual(stored_values, values[-self.window_size:], "HybridCache (sliding) did not keep last window tokens") \ No newline at end of file + self.assertEqual( + stored_keys, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens" + ) + self.assertEqual( + stored_values, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens" + ) From 6b765bd556c047aa2ecdd94118f24dc6df874ca9 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 6 May 2025 12:56:19 +0200 Subject: [PATCH 03/19] ruff --- tests/utils/test_cache_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index d20a30835172..263f1715b0cf 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -14,9 +14,9 @@ import copy import unittest +from types import SimpleNamespace from parameterized import parameterized -from types import SimpleNamespace from transformers import set_seed from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS @@ -45,10 +45,10 @@ ClvpForCausalLM, DynamicCache, GenerationConfig, + HybridCache, LlamaConfig, - StaticCache, SlidingWindowCache, - HybridCache, + StaticCache, convert_and_export_with_cache, ) From 32cd5f6b07ca4303a0eec9bef4e380b00f7b7ddd Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 6 May 2025 13:47:36 +0200 Subject: [PATCH 04/19] fix hybrid cache in torch compile --- src/transformers/cache_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 6affdd7a06f0..dbb2337e34dc 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1765,9 +1765,7 @@ def __init__( ) layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC - self.is_sliding = torch.tensor( - [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool - ) + self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)] self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) @@ -1780,7 +1778,7 @@ def __init__( layer_device = device # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, 
preventing cuda graph # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape + cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) torch._dynamo.mark_static_address(new_layer_key_cache) @@ -1801,7 +1799,7 @@ def update( if cache_position is None: raise ValueError("`cache_position` must be provided for HybridCache.") - is_sliding_layer = self.is_sliding[layer_idx].item() + is_sliding_layer = self.is_sliding_list[layer_idx] # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used # when the cache is initialized in the forward pass (e.g. Gemma2) From ec26e6995b9bd45af22ea8ace81c88710f57a140 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 6 May 2025 19:18:39 +0200 Subject: [PATCH 05/19] joaos suggestions --- src/transformers/cache_utils.py | 31 +++--- tests/utils/test_cache_utils.py | 186 ++++++++++++++++++++++++-------- 2 files changed, 155 insertions(+), 62 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index dbb2337e34dc..0b603875800e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -96,7 +96,7 @@ def _sliding_cache_update_logic( slicing = torch.arange(max_cache_len, device=value_states.device) current_seq_len = cache_position[-1] + 1 # Use last position to determine current length should_shift = current_seq_len > max_cache_len - indices = (slicing + should_shift.int()) % max_cache_len + indices = (slicing + should_shift.sum()) % max_cache_len k_out_shifted = k_cache[:, :, indices] v_out_shifted = v_cache[:, :, indices] @@ -1396,9 +1396,9 @@ class SlidingWindowCache(StaticCache): if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: + The `should_shift` is only true once we are above sliding_window. 
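# --- Editor's illustrative sketch (not part of the patch): the shift described above,
# shown with a window of 4 for brevity. Once the write position passes the window, the
# rotation indices drop the oldest slot and the clamped position lands on the last slot.
# Variable names are chosen for the example; only the arithmetic mirrors the cache code.
import torch

max_cache_len = 4
k_cache = torch.arange(1.0, 5.0).view(1, 1, 4, 1)              # cache holds tokens 1.0 .. 4.0

cache_position = torch.tensor([4])                               # first token past the window
should_shift = (cache_position[-1] + 1 > max_cache_len).int()    # 1 -> rotate left by one
indices = (torch.arange(max_cache_len) + should_shift) % max_cache_len  # [1, 2, 3, 0]

rolled = k_cache[:, :, indices]
write_pos = cache_position.clamp(max=max_cache_len - 1)          # 3
rolled = rolled.index_copy(2, write_pos, torch.tensor([[[[5.0]]]]))

assert rolled[0, 0, :, 0].tolist() == [2.0, 3.0, 4.0, 5.0]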
Thus with `sliding_window==64`: - indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window + indices = (slicing + should_shift[-1].sum()-1) % self.config.sliding_window tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, @@ -1472,18 +1472,6 @@ def __init__( layer_device_map=layer_device_map, ) - def _sliding_update_logic(self, layer_idx, cache_position, key_states, value_states): - """Performs the actual cache update for sliding window cache.""" - # Call the standalone utility function - return _sliding_cache_update_logic( - self.key_cache[layer_idx], - self.value_cache[layer_idx], - key_states, - value_states, - cache_position, - self.max_cache_len, - ) - def update( self, key_states: torch.Tensor, @@ -1498,12 +1486,17 @@ def update( if cache_position is None: raise ValueError("`cache_position` must be provided for SlidingWindowCache.") - # Ensure correct dtype key_states = key_states.to(self.key_cache[layer_idx].dtype) value_states = value_states.to(self.value_cache[layer_idx].dtype) - # Call the extracted logic - return self._sliding_update_logic(layer_idx, cache_position, key_states, value_states) + return _sliding_cache_update_logic( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + key_states, + value_states, + cache_position, + self.max_cache_len, + ) def get_max_cache_shape(self) -> Optional[int]: return self.max_cache_len @@ -1770,7 +1763,7 @@ def __init__( self.value_cache: List[torch.Tensor] = [] global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim) - device = torch.device(device) if device is not None and isinstance(device, str) else None + device = torch.device(device) if device is not None else None for i in range(config.num_hidden_layers): if layer_device_map is not None: layer_device = layer_device_map[i] diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 5923567f0997..40c1d2018474 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -14,7 +14,6 @@ import copy import unittest -from types import SimpleNamespace from parameterized import parameterized @@ -48,6 +47,7 @@ GenerationConfig, HybridCache, LlamaConfig, + Gemma2Config, SlidingWindowCache, StaticCache, convert_and_export_with_cache, @@ -719,32 +719,22 @@ def test_hybrid_cache_exportability(self): class SyntheticCacheTest(unittest.TestCase): - """ - Synthetic tests for StaticCache, SlidingWindowCache, and HybridCache. - Uses window_size=4, max_cache_len=4 for all tests. 
- """ + """Tests cache behavior with simple dummy data.""" def setUp(self): """Set up common configuration for all tests.""" self.window_size = 4 self.max_cache_len = 4 - self.config = SimpleNamespace( + self.config = Gemma2Config( num_hidden_layers=1, num_key_value_heads=1, num_attention_heads=1, head_dim=1, hidden_size=1, sliding_window=self.window_size, - sliding_window_pattern=2, # Example pattern + sliding_window_pattern=2, ) - def _extract_both_caches(self, cache): - """Helper to extract flattened key and value states from the cache.""" - # Assumes layer 0, batch 0, head 0 - k = cache.key_cache[0][0, 0, :, 0].tolist() - v = cache.value_cache[0][0, 0, :, 0].tolist() - return k, v - def test_static_cache_within_bounds(self): """ Test StaticCache preserves values at written positions within bounds. @@ -769,10 +759,11 @@ def test_static_cache_within_bounds(self): ) expected_state[pos_idx] = value - stored_keys, stored_values = self._extract_both_caches(static_cache) + k = static_cache.key_cache[0][0, 0, :, 0].tolist() + v = static_cache.value_cache[0][0, 0, :, 0].tolist() - self.assertEqual(stored_keys, expected_state, f"Static Key cache failed at step {step}") - self.assertEqual(stored_values, expected_state, f"Static Value cache failed at step {step}") + self.assertEqual(k, expected_state, f"Static Key cache failed at step {step}") + self.assertEqual(v, expected_state, f"Static Value cache failed at step {step}") def test_static_cache_out_of_bounds(self): """Test StaticCache raises IndexError for out-of-bounds positions.""" @@ -798,7 +789,7 @@ def test_sliding_window_cache(self): sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) expected_state = [0.0] * self.window_size - for step in range(1, 7): # Test beyond window size + for step in range(1, self.window_size * 2): # Test beyond window size pos_idx = step - 1 value = float(step) @@ -819,34 +810,31 @@ def test_sliding_window_cache(self): expected_state[clamped_pos] = value # Insert new value # Only verify key cache for simplicity, assuming value is symmetrical - stored_keys, _ = self._extract_both_caches(sliding_cache) - self.assertEqual(stored_keys, expected_state, f"SlidingWindowCache failed at step {step}") + k = sliding_cache.key_cache[0][0, 0, :, 0].tolist() + self.assertEqual(k, expected_state, f"SlidingWindowCache failed at step {step}") def test_hybrid_cache_static_mode(self): """ - Test HybridCache acts like StaticCache when 'sliding_window' is absent. + Test HybridCache acts like StaticCache in the static layers. 
Example (max_cache_len=4): step 4 (pos 3): [1.0, 2.0, 3.0, 4.0] - step 5 (pos 3): [1.0, 2.0, 3.0, 5.0] (pos clamped, overwrites last) - step 6 (pos 3): [1.0, 2.0, 3.0, 6.0] (pos clamped, overwrites last) """ - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + config = copy.deepcopy(self.config) + config.sliding_window_pattern = 1 # Force layer 0 to be 1 % 1 = 0 (static) + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) expected_state = [0.0] * self.max_cache_len - for step in range(1, 7): # Test beyond cache size - pos_idx_clamped = min(step - 1, self.max_cache_len - 1) - value = float(step) - + for step in range(0, self.max_cache_len): # Test up to cache size hybrid_cache.update( - key_states=torch.tensor([[[[value]]]]), - value_states=torch.tensor([[[[value]]]]), + key_states=torch.tensor([[[[float(step + 1)]]]]), + value_states=torch.tensor([[[[float(step + 1)]]]]), layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([pos_idx_clamped])}, # Use clamped pos + cache_kwargs={"cache_position": torch.tensor([step])}, ) - expected_state[pos_idx_clamped] = value - stored_keys, _ = self._extract_both_caches(hybrid_cache) - self.assertEqual(stored_keys, expected_state, f"HybridCache (static) failed at step {step}") + expected_state[step] = float(step + 1) + k = hybrid_cache.key_cache[0][0, 0, :, 0].tolist() + self.assertEqual(k, expected_state, f"HybridCache (static) failed at step {step}") def test_hybrid_cache_sliding_mode(self): """ @@ -859,7 +847,7 @@ def test_hybrid_cache_sliding_mode(self): hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) expected_state = [0.0] * self.window_size - for step in range(1, 7): # Test beyond window size + for step in range(1, self.window_size * 2): # Test beyond window size pos_idx = step - 1 value = float(step) @@ -876,8 +864,8 @@ def test_hybrid_cache_sliding_mode(self): expected_state = expected_state[1:] + [0.0] expected_state[clamped_pos] = value - stored_keys, _ = self._extract_both_caches(hybrid_cache) - self.assertEqual(stored_keys, expected_state, f"HybridCache (sliding) failed at step {step}") + k = hybrid_cache.key_cache[0][0, 0, :, 0].tolist() + self.assertEqual(k, expected_state, f"HybridCache (sliding) failed at step {step}") def test_sliding_window_cache_prompt_longer_than_max_cache_len(self): """Test SlidingWindowCache when prompt length > max_cache_len (should keep only last max_cache_len tokens).""" @@ -893,12 +881,13 @@ def test_sliding_window_cache_prompt_longer_than_max_cache_len(self): layer_idx=0, cache_kwargs={"cache_position": cache_position, "sliding_window": self.window_size}, ) - stored_keys, stored_values = self._extract_both_caches(sliding_cache) + k = sliding_cache.key_cache[0][0, 0, :, 0].tolist() + v = sliding_cache.value_cache[0][0, 0, :, 0].tolist() self.assertEqual( - stored_keys, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens" + k, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens" ) self.assertEqual( - stored_values, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens" + v, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens" ) def test_hybrid_cache_prompt_longer_than_max_cache_len(self): @@ -916,10 +905,121 @@ def test_hybrid_cache_prompt_longer_than_max_cache_len(self): layer_idx=0, cache_kwargs={"cache_position": cache_position, "sliding_window": 
self.window_size}, ) - stored_keys, stored_values = self._extract_both_caches(hybrid_cache) + k = hybrid_cache.key_cache[0][0, 0, :, 0].tolist() + v = hybrid_cache.value_cache[0][0, 0, :, 0].tolist() self.assertEqual( - stored_keys, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens" + k, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens" ) self.assertEqual( - stored_values, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens" + v, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens" + ) + + def test_static_cache_hardcoded(self): + """Test StaticCache with manually prefilled states and hardcoded assertions.""" + # Scenario 1: Fill up to near capacity + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 0.0, 0.0]]]], dtype=torch.float) + static_cache.key_cache[0].copy_(prefill) + static_cache.value_cache[0].copy_(prefill) + static_cache.update( + key_states=torch.tensor([[[[3.0]]]]), value_states=torch.tensor([[[[3.0]]]]), + layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2])} + ) + self.assertEqual(static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed") + + # Scenario 2: Fill to capacity + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 3.0, 0.0]]]], dtype=torch.float) + static_cache.key_cache[0].copy_(prefill) + static_cache.value_cache[0].copy_(prefill) + static_cache.update( + key_states=torch.tensor([[[[4.0]]]]), value_states=torch.tensor([[[[4.0]]]]), + layer_idx=0, cache_kwargs={"cache_position": torch.tensor([3])} + ) + self.assertEqual(static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed") + + def test_sliding_window_cache_hardcode(self): + """Test SlidingWindowCache with manually prefilled states and hardcoded assertions.""" + # Scenario 1: Update within window, no slide yet + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 0.0, 0.0]]]], dtype=torch.float) + sliding_cache.key_cache[0].copy_(prefill) + sliding_cache.value_cache[0].copy_(prefill) + sliding_cache.update( + key_states=torch.tensor([[[[3.0]]]]), value_states=torch.tensor([[[[3.0]]]]), layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size} + ) + self.assertEqual(sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed") + + # Scenario 2: Update causing slide + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float) + sliding_cache.key_cache[0].copy_(prefill) + sliding_cache.value_cache[0].copy_(prefill) + sliding_cache.update( + key_states=torch.tensor([[[[5.0]]]]), value_states=torch.tensor([[[[5.0]]]]), layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size} + ) + self.assertEqual(sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed") + + def test_hybrid_cache_static_mode_hardcoded(self): + """Test HybridCache in static mode with hardcoded 
assertions.""" + config = copy.deepcopy(self.config) + config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) + + # Scenario 1 + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 0.0, 0.0]]]], dtype=torch.float) + hybrid_cache.key_cache[0].copy_(prefill) + hybrid_cache.value_cache[0].copy_(prefill) + hybrid_cache.update( + key_states=torch.tensor([[[[3.0]]]]), value_states=torch.tensor([[[[3.0]]]]), + layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2])} + ) + self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Static Scenario 1 failed") + + # Scenario 2 + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 3.0, 0.0]]]], dtype=torch.float) + hybrid_cache.key_cache[0].copy_(prefill) + hybrid_cache.value_cache[0].copy_(prefill) + hybrid_cache.update( + key_states=torch.tensor([[[[4.0]]]]), value_states=torch.tensor([[[[4.0]]]]), + layer_idx=0, cache_kwargs={"cache_position": torch.tensor([3])} + ) + self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "HybridCache Static Scenario 2 failed") + + def test_hybrid_cache_sliding_mode_hardcoded(self): + """Test HybridCache in sliding mode with hardcoded assertions.""" + # Scenario 1: Update within window, no slide yet + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 0.0, 0.0]]]], dtype=torch.float) + hybrid_cache.key_cache[0].copy_(prefill) + hybrid_cache.value_cache[0].copy_(prefill) + hybrid_cache.update( + key_states=torch.tensor([[[[3.0]]]]), value_states=torch.tensor([[[[3.0]]]]), layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size} + ) + self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed") + + # Scenario 2: Update causing first slide + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float) + hybrid_cache.key_cache[0].copy_(prefill) + hybrid_cache.value_cache[0].copy_(prefill) + hybrid_cache.update( + key_states=torch.tensor([[[[5.0]]]]), value_states=torch.tensor([[[[5.0]]]]), layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size} + ) + self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed") + + # Scenario 3: Update causing subsequent slide + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([[[[v] for v in [2.0, 3.0, 4.0, 5.0]]]], dtype=torch.float) + hybrid_cache.key_cache[0].copy_(prefill) + hybrid_cache.value_cache[0].copy_(prefill) + hybrid_cache.update( + key_states=torch.tensor([[[[6.0]]]]), value_states=torch.tensor([[[[6.0]]]]), layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size} ) + self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed") From 95805f33a25c2c5a7edd7f6b421ecc6c54e9aec3 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 6 May 2025 19:20:27 +0200 Subject: 
[PATCH 06/19] ruff --- tests/utils/test_cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 40c1d2018474..badf3e1e1c5f 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -44,10 +44,10 @@ Cache, ClvpForCausalLM, DynamicCache, + Gemma2Config, GenerationConfig, HybridCache, LlamaConfig, - Gemma2Config, SlidingWindowCache, StaticCache, convert_and_export_with_cache, From f08ea2013358381c87558965b6e3f2ac49172282 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 6 May 2025 19:30:54 +0200 Subject: [PATCH 07/19] Trigger Build From b3b01330a1c7111ed7d819748653b158f21d4814 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 6 May 2025 19:33:12 +0200 Subject: [PATCH 08/19] ruff --- tests/utils/test_cache_utils.py | 120 +++++++++++++++++++++----------- 1 file changed, 81 insertions(+), 39 deletions(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index badf3e1e1c5f..73e86f03d70a 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -883,12 +883,8 @@ def test_sliding_window_cache_prompt_longer_than_max_cache_len(self): ) k = sliding_cache.key_cache[0][0, 0, :, 0].tolist() v = sliding_cache.value_cache[0][0, 0, :, 0].tolist() - self.assertEqual( - k, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens" - ) - self.assertEqual( - v, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens" - ) + self.assertEqual(k, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens") + self.assertEqual(v, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens") def test_hybrid_cache_prompt_longer_than_max_cache_len(self): """Test HybridCache when prompt length > max_cache_len (should keep only last max_cache_len tokens in sliding mode).""" @@ -907,12 +903,8 @@ def test_hybrid_cache_prompt_longer_than_max_cache_len(self): ) k = hybrid_cache.key_cache[0][0, 0, :, 0].tolist() v = hybrid_cache.value_cache[0][0, 0, :, 0].tolist() - self.assertEqual( - k, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens" - ) - self.assertEqual( - v, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens" - ) + self.assertEqual(k, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens") + self.assertEqual(v, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens") def test_static_cache_hardcoded(self): """Test StaticCache with manually prefilled states and hardcoded assertions.""" @@ -922,10 +914,14 @@ def test_static_cache_hardcoded(self): static_cache.key_cache[0].copy_(prefill) static_cache.value_cache[0].copy_(prefill) static_cache.update( - key_states=torch.tensor([[[[3.0]]]]), value_states=torch.tensor([[[[3.0]]]]), - layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2])} + key_states=torch.tensor([[[[3.0]]]]), + value_states=torch.tensor([[[[3.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2])}, + ) + self.assertEqual( + static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" ) - self.assertEqual(static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed") # Scenario 2: Fill to capacity static_cache = StaticCache(config=self.config, max_batch_size=1, 
max_cache_len=self.max_cache_len) @@ -933,10 +929,14 @@ def test_static_cache_hardcoded(self): static_cache.key_cache[0].copy_(prefill) static_cache.value_cache[0].copy_(prefill) static_cache.update( - key_states=torch.tensor([[[[4.0]]]]), value_states=torch.tensor([[[[4.0]]]]), - layer_idx=0, cache_kwargs={"cache_position": torch.tensor([3])} + key_states=torch.tensor([[[[4.0]]]]), + value_states=torch.tensor([[[[4.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + self.assertEqual( + static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" ) - self.assertEqual(static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed") def test_sliding_window_cache_hardcode(self): """Test SlidingWindowCache with manually prefilled states and hardcoded assertions.""" @@ -946,10 +946,16 @@ def test_sliding_window_cache_hardcode(self): sliding_cache.key_cache[0].copy_(prefill) sliding_cache.value_cache[0].copy_(prefill) sliding_cache.update( - key_states=torch.tensor([[[[3.0]]]]), value_states=torch.tensor([[[[3.0]]]]), layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size} + key_states=torch.tensor([[[[3.0]]]]), + value_states=torch.tensor([[[[3.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, + ) + self.assertEqual( + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "SlidingWindowCache Scenario 1 failed", ) - self.assertEqual(sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed") # Scenario 2: Update causing slide sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) @@ -957,10 +963,16 @@ def test_sliding_window_cache_hardcode(self): sliding_cache.key_cache[0].copy_(prefill) sliding_cache.value_cache[0].copy_(prefill) sliding_cache.update( - key_states=torch.tensor([[[[5.0]]]]), value_states=torch.tensor([[[[5.0]]]]), layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size} + key_states=torch.tensor([[[[5.0]]]]), + value_states=torch.tensor([[[[5.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, + ) + self.assertEqual( + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + [2.0, 3.0, 4.0, 5.0], + "SlidingWindowCache Scenario 2 failed", ) - self.assertEqual(sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed") def test_hybrid_cache_static_mode_hardcoded(self): """Test HybridCache in static mode with hardcoded assertions.""" @@ -973,10 +985,16 @@ def test_hybrid_cache_static_mode_hardcoded(self): hybrid_cache.key_cache[0].copy_(prefill) hybrid_cache.value_cache[0].copy_(prefill) hybrid_cache.update( - key_states=torch.tensor([[[[3.0]]]]), value_states=torch.tensor([[[[3.0]]]]), - layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2])} + key_states=torch.tensor([[[[3.0]]]]), + value_states=torch.tensor([[[[3.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2])}, + ) + self.assertEqual( + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "HybridCache Static Scenario 1 failed", ) - self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Static Scenario 1 failed") # Scenario 2 hybrid_cache = 
HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) @@ -984,10 +1002,16 @@ def test_hybrid_cache_static_mode_hardcoded(self): hybrid_cache.key_cache[0].copy_(prefill) hybrid_cache.value_cache[0].copy_(prefill) hybrid_cache.update( - key_states=torch.tensor([[[[4.0]]]]), value_states=torch.tensor([[[[4.0]]]]), - layer_idx=0, cache_kwargs={"cache_position": torch.tensor([3])} + key_states=torch.tensor([[[[4.0]]]]), + value_states=torch.tensor([[[[4.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + self.assertEqual( + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 4.0], + "HybridCache Static Scenario 2 failed", ) - self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "HybridCache Static Scenario 2 failed") def test_hybrid_cache_sliding_mode_hardcoded(self): """Test HybridCache in sliding mode with hardcoded assertions.""" @@ -997,10 +1021,16 @@ def test_hybrid_cache_sliding_mode_hardcoded(self): hybrid_cache.key_cache[0].copy_(prefill) hybrid_cache.value_cache[0].copy_(prefill) hybrid_cache.update( - key_states=torch.tensor([[[[3.0]]]]), value_states=torch.tensor([[[[3.0]]]]), layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size} + key_states=torch.tensor([[[[3.0]]]]), + value_states=torch.tensor([[[[3.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, + ) + self.assertEqual( + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "HybridCache Sliding Scenario 1 failed", ) - self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed") # Scenario 2: Update causing first slide hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) @@ -1008,10 +1038,16 @@ def test_hybrid_cache_sliding_mode_hardcoded(self): hybrid_cache.key_cache[0].copy_(prefill) hybrid_cache.value_cache[0].copy_(prefill) hybrid_cache.update( - key_states=torch.tensor([[[[5.0]]]]), value_states=torch.tensor([[[[5.0]]]]), layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size} + key_states=torch.tensor([[[[5.0]]]]), + value_states=torch.tensor([[[[5.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, + ) + self.assertEqual( + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [2.0, 3.0, 4.0, 5.0], + "HybridCache Sliding Scenario 2 failed", ) - self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed") # Scenario 3: Update causing subsequent slide hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) @@ -1019,7 +1055,13 @@ def test_hybrid_cache_sliding_mode_hardcoded(self): hybrid_cache.key_cache[0].copy_(prefill) hybrid_cache.value_cache[0].copy_(prefill) hybrid_cache.update( - key_states=torch.tensor([[[[6.0]]]]), value_states=torch.tensor([[[[6.0]]]]), layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size} + key_states=torch.tensor([[[[6.0]]]]), + value_states=torch.tensor([[[[6.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, + ) + self.assertEqual( + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [3.0, 4.0, 5.0, 6.0], + "HybridCache Sliding Scenario 3 
failed", ) - self.assertEqual(hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed") From 3de75053b4598b82d2ed9637218be6ff3d1cda73 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Thu, 8 May 2025 11:40:25 +0200 Subject: [PATCH 09/19] Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0b603875800e..6cf78906df3e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -22,7 +22,7 @@ # Utility functions for static/sliding cache update logic -def _static_cache_update_logic( +def _static_cache_update( k_cache: torch.Tensor, v_cache: torch.Tensor, key_states: torch.Tensor, From 214e517b92398afe5edb4a906a3a1234a234180e Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 9 May 2025 15:04:43 +0200 Subject: [PATCH 10/19] suggestions --- src/transformers/cache_utils.py | 10 +- tests/utils/test_cache_utils.py | 407 +++++++++++++------------------- 2 files changed, 174 insertions(+), 243 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 6cf78906df3e..a957366e2951 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -60,7 +60,7 @@ def _static_cache_update( return k_cache, v_cache -def _sliding_cache_update_logic( +def _sliding_cache_update( k_cache: torch.Tensor, v_cache: torch.Tensor, key_states: torch.Tensor, @@ -1363,7 +1363,7 @@ def update( key_states = key_states.to(self.key_cache[layer_idx].dtype) value_states = value_states.to(self.value_cache[layer_idx].dtype) - return _static_cache_update_logic( + return _static_cache_update( self.key_cache[layer_idx], self.value_cache[layer_idx], key_states, @@ -1489,7 +1489,7 @@ def update( key_states = key_states.to(self.key_cache[layer_idx].dtype) value_states = value_states.to(self.value_cache[layer_idx].dtype) - return _sliding_cache_update_logic( + return _sliding_cache_update( self.key_cache[layer_idx], self.value_cache[layer_idx], key_states, @@ -1807,7 +1807,7 @@ def update( value_states = value_states.to(v_cache.dtype) if is_sliding_layer: - return _sliding_cache_update_logic( + return _sliding_cache_update( k_cache, v_cache, key_states, @@ -1816,7 +1816,7 @@ def update( k_cache.shape[2], # Use actual cache dim as max cache len ) else: - return _static_cache_update_logic(k_cache, v_cache, key_states, value_states, cache_position) + return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position) def get_max_cache_shape(self) -> Optional[int]: return self.max_cache_len diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 73e86f03d70a..c079452ef93f 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -722,7 +722,7 @@ class SyntheticCacheTest(unittest.TestCase): """Tests cache behavior with simple dummy data.""" def setUp(self): - """Set up common configuration for all tests.""" + """Set up common configuration and cache instances for all tests.""" self.window_size = 4 self.max_cache_len = 4 self.config = Gemma2Config( @@ -732,336 +732,267 @@ def setUp(self): head_dim=1, hidden_size=1, sliding_window=self.window_size, - sliding_window_pattern=2, + sliding_window_pattern=2, # Default pattern for hybrid 
sliding ) - - def test_static_cache_within_bounds(self): - """ - Test StaticCache preserves values at written positions within bounds. - Example (max_cache_len=4): - step 1 (pos 0): [1.0, 0.0, 0.0, 0.0] - step 2 (pos 1): [1.0, 2.0, 0.0, 0.0] - step 3 (pos 2): [1.0, 2.0, 3.0, 0.0] - step 4 (pos 3): [1.0, 2.0, 3.0, 4.0] - """ - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - expected_state = [0.0] * self.max_cache_len - - for step in range(1, self.max_cache_len + 1): # Test up to cache capacity - pos_idx = step - 1 - value = float(step) - - static_cache.update( - key_states=torch.tensor([[[[value]]]]), - value_states=torch.tensor([[[[value]]]]), # Use same value for simplicity - layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([pos_idx])}, - ) - - expected_state[pos_idx] = value - k = static_cache.key_cache[0][0, 0, :, 0].tolist() - v = static_cache.value_cache[0][0, 0, :, 0].tolist() - - self.assertEqual(k, expected_state, f"Static Key cache failed at step {step}") - self.assertEqual(v, expected_state, f"Static Value cache failed at step {step}") + self.static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + self.sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + self.hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) def test_static_cache_out_of_bounds(self): """Test StaticCache raises IndexError for out-of-bounds positions.""" - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + self.static_cache.reset() pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len with self.assertRaises(IndexError): - static_cache.update( + self.static_cache.update( key_states=torch.tensor([[[[1.0]]]]), value_states=torch.tensor([[[[1.0]]]]), layer_idx=0, cache_kwargs={"cache_position": pos_out_of_bounds}, ) - def test_sliding_window_cache(self): - """ - Test SlidingWindowCache accumulates then slides. - Example (window_size=4): - step 4 (pos 3): [1.0, 2.0, 3.0, 4.0] - step 5 (pos 4): [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1) - step 6 (pos 5): [3.0, 4.0, 5.0, 6.0] - """ - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - expected_state = [0.0] * self.window_size - - for step in range(1, self.window_size * 2): # Test beyond window size - pos_idx = step - 1 - value = float(step) - - sliding_cache.update( - key_states=torch.tensor([[[[value]]]]), - value_states=torch.tensor([[[[value]]]]), - layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([pos_idx]), "sliding_window": self.window_size}, - ) - - # Calculate expected state based on corrected sliding logic - clamped_pos = min(pos_idx, self.window_size - 1) - to_shift = pos_idx > self.window_size - 1 - - if to_shift: - expected_state = expected_state[1:] + [0.0] # Shift left, add placeholder + def test_static_cache(self): + """Test StaticCache with manually prefilled states and hardcoded assertions. 
- expected_state[clamped_pos] = value # Insert new value + Scenario 1: Fill up to near capacity + prefill: [1.0, 2.0, 0.0, 0.0] + update pos 2: [1.0, 2.0, 3.0, 0.0] - # Only verify key cache for simplicity, assuming value is symmetrical - k = sliding_cache.key_cache[0][0, 0, :, 0].tolist() - self.assertEqual(k, expected_state, f"SlidingWindowCache failed at step {step}") - - def test_hybrid_cache_static_mode(self): + Scenario 2: Fill to capacity + update pos 3: [1.0, 2.0, 3.0, 4.0] """ - Test HybridCache acts like StaticCache in the static layers. - Example (max_cache_len=4): - step 4 (pos 3): [1.0, 2.0, 3.0, 4.0] - """ - config = copy.deepcopy(self.config) - config.sliding_window_pattern = 1 # Force layer 0 to be 1 % 1 = 0 (static) - hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) - expected_state = [0.0] * self.max_cache_len - - for step in range(0, self.max_cache_len): # Test up to cache size - hybrid_cache.update( - key_states=torch.tensor([[[[float(step + 1)]]]]), - value_states=torch.tensor([[[[float(step + 1)]]]]), - layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([step])}, - ) - - expected_state[step] = float(step + 1) - k = hybrid_cache.key_cache[0][0, 0, :, 0].tolist() - self.assertEqual(k, expected_state, f"HybridCache (static) failed at step {step}") - - def test_hybrid_cache_sliding_mode(self): - """ - Test HybridCache acts like SlidingWindowCache when 'sliding_window' is present. - Example (window_size=4): - step 4 (pos 3): [1.0, 2.0, 3.0, 4.0] - step 5 (pos 4): [2.0, 3.0, 4.0, 5.0] - step 6 (pos 5): [3.0, 4.0, 5.0, 6.0] - """ - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - expected_state = [0.0] * self.window_size - - for step in range(1, self.window_size * 2): # Test beyond window size - pos_idx = step - 1 - value = float(step) + self.static_cache.reset() - hybrid_cache.update( - key_states=torch.tensor([[[[value]]]]), - value_states=torch.tensor([[[[value]]]]), - layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([pos_idx]), "sliding_window": self.window_size}, - ) - - clamped_pos = min(pos_idx, self.window_size - 1) - to_shift = pos_idx > self.window_size - 1 - if to_shift: - expected_state = expected_state[1:] + [0.0] - expected_state[clamped_pos] = value - - k = hybrid_cache.key_cache[0][0, 0, :, 0].tolist() - self.assertEqual(k, expected_state, f"HybridCache (sliding) failed at step {step}") - - def test_sliding_window_cache_prompt_longer_than_max_cache_len(self): - """Test SlidingWindowCache when prompt length > max_cache_len (should keep only last max_cache_len tokens).""" - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - prompt_len = self.max_cache_len + 2 # e.g., 6 if max_cache_len=4 - values = [float(i + 1) for i in range(prompt_len)] - key_states = torch.tensor([[[[v] for v in values]]]) # shape (1,1,6,1) - value_states = torch.tensor([[[[v] for v in values]]]) # shape (1,1,6,1) - cache_position = torch.arange(prompt_len) - sliding_cache.update( - key_states=key_states, - value_states=value_states, - layer_idx=0, - cache_kwargs={"cache_position": cache_position, "sliding_window": self.window_size}, - ) - k = sliding_cache.key_cache[0][0, 0, :, 0].tolist() - v = sliding_cache.value_cache[0][0, 0, :, 0].tolist() - self.assertEqual(k, values[-self.window_size :], "SlidingWindowCache did not keep last window tokens") - self.assertEqual(v, values[-self.window_size :], "SlidingWindowCache 
did not keep last window tokens") - - def test_hybrid_cache_prompt_longer_than_max_cache_len(self): - """Test HybridCache when prompt length > max_cache_len (should keep only last max_cache_len tokens in sliding mode).""" - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - prompt_len = self.max_cache_len + 2 # e.g., 6 if max_cache_len=4 - values = [float(i + 1) for i in range(prompt_len)] - key_states = torch.tensor([[[[v] for v in values]]]) # shape (1,1,6,1) - value_states = torch.tensor([[[[v] for v in values]]]) # shape (1,1,6,1) - cache_position = torch.arange(prompt_len) - # Use sliding mode - hybrid_cache.update( - key_states=key_states, - value_states=value_states, - layer_idx=0, - cache_kwargs={"cache_position": cache_position, "sliding_window": self.window_size}, - ) - k = hybrid_cache.key_cache[0][0, 0, :, 0].tolist() - v = hybrid_cache.value_cache[0][0, 0, :, 0].tolist() - self.assertEqual(k, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens") - self.assertEqual(v, values[-self.window_size :], "HybridCache (sliding) did not keep last window tokens") - - def test_static_cache_hardcoded(self): - """Test StaticCache with manually prefilled states and hardcoded assertions.""" # Scenario 1: Fill up to near capacity - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 0.0, 0.0]]]], dtype=torch.float) - static_cache.key_cache[0].copy_(prefill) - static_cache.value_cache[0].copy_(prefill) - static_cache.update( - key_states=torch.tensor([[[[3.0]]]]), - value_states=torch.tensor([[[[3.0]]]]), + prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + self.static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None) + self.static_cache.update( + key_states=torch.tensor(3.0)[None, None, None, None], + value_states=torch.tensor(3.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" + self.static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" ) # Scenario 2: Fill to capacity - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 3.0, 0.0]]]], dtype=torch.float) - static_cache.key_cache[0].copy_(prefill) - static_cache.value_cache[0].copy_(prefill) - static_cache.update( - key_states=torch.tensor([[[[4.0]]]]), - value_states=torch.tensor([[[[4.0]]]]), + self.static_cache.update( + key_states=torch.tensor(4.0)[None, None, None, None], + value_states=torch.tensor(4.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" + self.static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" ) - def test_sliding_window_cache_hardcode(self): - """Test SlidingWindowCache with manually prefilled states and hardcoded assertions.""" + def test_sliding_window_cache(self): + """Test SlidingWindowCache with manually prefilled states and hardcoded assertions. 
+ + Scenario 1: Update within window, no slide yet + prefill: [1.0, 2.0, 0.0, 0.0] + update pos 2: [1.0, 2.0, 3.0, 0.0] + + Scenario 2: Update causing slide + prefill: [1.0, 2.0, 3.0, 4.0] + update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1) + + Scenario 3: Long prompt handling (prompt_len > window_size) + input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) + """ + self.sliding_cache.reset() + # Scenario 1: Update within window, no slide yet - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 0.0, 0.0]]]], dtype=torch.float) - sliding_cache.key_cache[0].copy_(prefill) - sliding_cache.value_cache[0].copy_(prefill) - sliding_cache.update( - key_states=torch.tensor([[[[3.0]]]]), - value_states=torch.tensor([[[[3.0]]]]), + prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + self.sliding_cache.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + ) + self.sliding_cache.update( + key_states=torch.tensor(3.0)[None, None, None, None], + value_states=torch.tensor(3.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + self.sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed", ) # Scenario 2: Update causing slide - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float) - sliding_cache.key_cache[0].copy_(prefill) - sliding_cache.value_cache[0].copy_(prefill) - sliding_cache.update( - key_states=torch.tensor([[[[5.0]]]]), - value_states=torch.tensor([[[[5.0]]]]), + self.sliding_cache.reset() + prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] + self.sliding_cache.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + ) + self.sliding_cache.update( + key_states=torch.tensor(5.0)[None, None, None, None], + value_states=torch.tensor(5.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + self.sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed", ) - def test_hybrid_cache_static_mode_hardcoded(self): - """Test HybridCache in static mode with hardcoded assertions.""" + # Scenario 3: Long prompt handling + self.sliding_cache.reset() + long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] + self.sliding_cache.update( + key_states=long_prefill, + value_states=long_prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, + ) + self.assertEqual( + self.sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + [3.0, 4.0, 5.0, 6.0], + "SlidingWindowCache Scenario 3 failed", + ) + + def test_hybrid_cache_static_mode(self): + """Test HybridCache in static mode with hardcoded assertions. 
+ + Scenario 1: Static layer behavior + prefill: [1.0, 2.0, 0.0, 0.0] + update pos 2: [1.0, 2.0, 3.0, 0.0] + + Scenario 2: Fill to capacity + update pos 3: [1.0, 2.0, 3.0, 4.0] + """ config = copy.deepcopy(self.config) config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) + hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache_static_mode.reset() # Ensure it's clean even if it's new # Scenario 1 - hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 0.0, 0.0]]]], dtype=torch.float) - hybrid_cache.key_cache[0].copy_(prefill) - hybrid_cache.value_cache[0].copy_(prefill) - hybrid_cache.update( - key_states=torch.tensor([[[[3.0]]]]), - value_states=torch.tensor([[[[3.0]]]]), + prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + hybrid_cache_static_mode.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4)}, + ) + hybrid_cache_static_mode.update( + key_states=torch.tensor(3.0)[None, None, None, None], + value_states=torch.tensor(3.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Static Scenario 1 failed", ) # Scenario 2 - hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 3.0, 0.0]]]], dtype=torch.float) - hybrid_cache.key_cache[0].copy_(prefill) - hybrid_cache.value_cache[0].copy_(prefill) - hybrid_cache.update( - key_states=torch.tensor([[[[4.0]]]]), - value_states=torch.tensor([[[[4.0]]]]), + hybrid_cache_static_mode.update( + key_states=torch.tensor(4.0)[None, None, None, None], + value_states=torch.tensor(4.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "HybridCache Static Scenario 2 failed", ) - def test_hybrid_cache_sliding_mode_hardcoded(self): - """Test HybridCache in sliding mode with hardcoded assertions.""" + def test_hybrid_cache_sliding_mode(self): + """Test HybridCache in sliding mode with hardcoded assertions. 
+ + Scenario 1: Update within window, no slide yet + prefill: [1.0, 2.0, 0.0, 0.0] + update pos 2: [1.0, 2.0, 3.0, 0.0] + + Scenario 2: Update causing first slide + prefill: [1.0, 2.0, 3.0, 4.0] + update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1) + + Scenario 3: Update causing subsequent slide + update pos 5: [3.0, 4.0, 5.0, 6.0] (shift continues) + + Scenario 4: Long prompt handling (prompt_len > window_size) + input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) + """ + self.hybrid_cache.reset() + # Scenario 1: Update within window, no slide yet - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 0.0, 0.0]]]], dtype=torch.float) - hybrid_cache.key_cache[0].copy_(prefill) - hybrid_cache.value_cache[0].copy_(prefill) - hybrid_cache.update( - key_states=torch.tensor([[[[3.0]]]]), - value_states=torch.tensor([[[[3.0]]]]), + prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + self.hybrid_cache.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + ) + self.hybrid_cache.update( + key_states=torch.tensor(3.0)[None, None, None, None], + value_states=torch.tensor(3.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + self.hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed", ) # Scenario 2: Update causing first slide - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([[[[v] for v in [1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float) - hybrid_cache.key_cache[0].copy_(prefill) - hybrid_cache.value_cache[0].copy_(prefill) - hybrid_cache.update( - key_states=torch.tensor([[[[5.0]]]]), - value_states=torch.tensor([[[[5.0]]]]), + self.hybrid_cache.reset() + prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] + self.hybrid_cache.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + ) + self.hybrid_cache.update( + key_states=torch.tensor(5.0)[None, None, None, None], + value_states=torch.tensor(5.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + self.hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed", ) # Scenario 3: Update causing subsequent slide - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([[[[v] for v in [2.0, 3.0, 4.0, 5.0]]]], dtype=torch.float) - hybrid_cache.key_cache[0].copy_(prefill) - hybrid_cache.value_cache[0].copy_(prefill) - hybrid_cache.update( - key_states=torch.tensor([[[[6.0]]]]), - value_states=torch.tensor([[[[6.0]]]]), + self.hybrid_cache.update( + key_states=torch.tensor(6.0)[None, None, None, None], + value_states=torch.tensor(6.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), 
+ self.hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed", ) + + # Scenario 4: Long prompt handling + self.hybrid_cache.reset() + long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] + self.hybrid_cache.update( + key_states=long_prefill, + value_states=long_prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, + ) + self.assertEqual( + self.hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [3.0, 4.0, 5.0, 6.0], + "HybridCache Sliding Scenario 4 failed", + ) From 36e07a2a65101a3ee8420370fd7b091ccfc18e80 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Fri, 9 May 2025 15:06:26 +0200 Subject: [PATCH 11/19] ruff --- tests/utils/test_cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 5b8686f85d22..52d7c22eaf63 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -880,7 +880,7 @@ def test_hybrid_cache_static_mode(self): config = copy.deepcopy(self.config) config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) - hybrid_cache_static_mode.reset() # Ensure it's clean even if it's new + hybrid_cache_static_mode.reset() # Ensure it's clean even if it's new # Scenario 1 prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] From deacc6770e75a7025d8a7c3e521e2a32fda08a9f Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Sat, 10 May 2025 11:36:39 +0200 Subject: [PATCH 12/19] revert naming change --- src/transformers/cache_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 3d3f7cb63831..92ada609dc3d 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -95,8 +95,8 @@ def _sliding_cache_update( # Sliding window logic for generation phase or prefill < window slicing = torch.arange(max_cache_len, device=value_states.device) current_seq_len = cache_position[-1] + 1 # Use last position to determine current length - should_shift = current_seq_len > max_cache_len - indices = (slicing + should_shift.sum()) % max_cache_len + to_shift = current_seq_len > max_cache_len + indices = (slicing + to_shift.sum()) % max_cache_len k_out_shifted = k_cache[:, :, indices] v_out_shifted = v_cache[:, :, indices] @@ -1396,9 +1396,9 @@ class SlidingWindowCache(StaticCache): if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - The `should_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: + The `to_shift` is only true once we are above sliding_window. 
Thus with `sliding_window==64`: - indices = (slicing + should_shift[-1].sum()-1) % self.config.sliding_window + indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, From 772b0a02fab4c3bbf304e959b250bb7cde88ef7f Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 13 May 2025 15:23:00 +0200 Subject: [PATCH 13/19] added new test and fixes for gptj --- src/transformers/cache_utils.py | 11 ++++++++--- tests/utils/test_cache_utils.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 92ada609dc3d..7e1f1cf49525 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1754,7 +1754,9 @@ def __init__( self._dtype = dtype self.num_key_value_heads = ( - config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads ) layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC @@ -2050,7 +2052,7 @@ def __init__( # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps # track of the original device of each layer - unique_devices = set(layer_device_map.values()) + unique_devices = set(layer_device_map.values()) if layer_device_map else set() if len(unique_devices) > 1: raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}") @@ -2309,7 +2311,7 @@ def __init__( # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps # track of the original device of each layer - unique_devices = set(layer_device_map.values()) + unique_devices = set(layer_device_map.values()) if layer_device_map else set() if len(unique_devices) > 1: raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}") @@ -2386,6 +2388,9 @@ def update( A tuple containing the updated key and value states. """ + key_states = key_states.to(self.key_cache[layer_idx].dtype) + value_states = value_states.to(self.value_cache[layer_idx].dtype) + if layer_idx == 0: # Update seen tokens. # TODO(gante): Remove this. 
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 52d7c22eaf63..ad2916cdb2ac 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -306,6 +306,20 @@ def test_cache_extra_left_padding(self, cache_implementation): decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True) self.assertListEqual(decoded, EXPECTED_GENERATION) + @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) + def test_cache_pipe_rope_model(self, cache_implementation): + """Tests caches with a RoPE model""" + self._skip_on_failed_cache_prerequisites(cache_implementation) + # if cache_implementation in ["off + from transformers import pipeline + + model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM" + pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16) + pipe.model.config.sliding_window = 10 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None + out = pipe("h", cache_implementation=cache_implementation, max_new_tokens=10, do_sample=False, disable_compile=True, return_tensors=True)[0]['generated_token_ids'][-10:] + EXPECTED_OUTPUT = [914, 134, 124, 889, 48, 233, 541, 27, 380, 365] + self.assertListEqual(out, EXPECTED_OUTPUT) + @require_torch_accelerator class CacheHardIntegrationTest(unittest.TestCase): From e67049fe0aa172c1b7a1c660906ccd1bb73548f6 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 13 May 2025 15:30:31 +0200 Subject: [PATCH 14/19] ruff --- tests/utils/test_cache_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index ad2916cdb2ac..7d75b71f44f3 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -310,13 +310,13 @@ def test_cache_extra_left_padding(self, cache_implementation): def test_cache_pipe_rope_model(self, cache_implementation): """Tests caches with a RoPE model""" self._skip_on_failed_cache_prerequisites(cache_implementation) - # if cache_implementation in ["off from transformers import pipeline - + model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM" pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16) pipe.model.config.sliding_window = 10 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None - out = pipe("h", cache_implementation=cache_implementation, max_new_tokens=10, do_sample=False, disable_compile=True, return_tensors=True)[0]['generated_token_ids'][-10:] + out = pipe("h", cache_implementation=cache_implementation, max_new_tokens=10, do_sample=False, disable_compile=True, return_tensors=True) + out = out[0]['generated_token_ids'][-10:] EXPECTED_OUTPUT = [914, 134, 124, 889, 48, 233, 541, 27, 380, 365] self.assertListEqual(out, EXPECTED_OUTPUT) From df95e1fe482195b616dad495917366329765b9bb Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 13 May 2025 15:36:13 +0200 Subject: [PATCH 15/19] reinit instead of resetting stateful caches --- tests/utils/test_cache_utils.py | 73 +++++++++++++++------------------ 1 file changed, 33 insertions(+), 40 deletions(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 7d75b71f44f3..5b2852ddddda 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -755,17 +755,14 @@ def setUp(self): sliding_window=self.window_size, sliding_window_pattern=2, # Default pattern for hybrid sliding ) - self.static_cache = StaticCache(config=self.config, 
max_batch_size=1, max_cache_len=self.max_cache_len) - self.sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) - self.hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) def test_static_cache_out_of_bounds(self): """Test StaticCache raises IndexError for out-of-bounds positions.""" - self.static_cache.reset() + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len with self.assertRaises(IndexError): - self.static_cache.update( + static_cache.update( key_states=torch.tensor([[[[1.0]]]]), value_states=torch.tensor([[[[1.0]]]]), layer_idx=0, @@ -782,30 +779,29 @@ def test_static_cache(self): Scenario 2: Fill to capacity update pos 3: [1.0, 2.0, 3.0, 4.0] """ - self.static_cache.reset() - # Scenario 1: Fill up to near capacity + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] - self.static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None) - self.static_cache.update( + static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None) + static_cache.update( key_states=torch.tensor(3.0)[None, None, None, None], value_states=torch.tensor(3.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - self.static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" + static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" ) # Scenario 2: Fill to capacity - self.static_cache.update( + static_cache.update( key_states=torch.tensor(4.0)[None, None, None, None], value_states=torch.tensor(4.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - self.static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" + static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" ) def test_sliding_window_cache(self): @@ -823,60 +819,59 @@ def test_sliding_window_cache(self): input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ - self.sliding_cache.reset() - # Scenario 1: Update within window, no slide yet + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] - self.sliding_cache.update( + sliding_cache.update( key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, ) - self.sliding_cache.update( + sliding_cache.update( key_states=torch.tensor(3.0)[None, None, None, None], value_states=torch.tensor(3.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - self.sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed", ) # Scenario 2: Update causing slide - self.sliding_cache.reset() + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, 
max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] - self.sliding_cache.update( + sliding_cache.update( key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, ) - self.sliding_cache.update( + sliding_cache.update( key_states=torch.tensor(5.0)[None, None, None, None], value_states=torch.tensor(5.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - self.sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed", ) # Scenario 3: Long prompt handling - self.sliding_cache.reset() + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] - self.sliding_cache.update( + sliding_cache.update( key_states=long_prefill, value_states=long_prefill, layer_idx=0, cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - self.sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "SlidingWindowCache Scenario 3 failed", ) @@ -893,10 +888,9 @@ def test_hybrid_cache_static_mode(self): """ config = copy.deepcopy(self.config) config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) - hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) - hybrid_cache_static_mode.reset() # Ensure it's clean even if it's new # Scenario 1 + hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] hybrid_cache_static_mode.update( key_states=prefill, @@ -947,73 +941,72 @@ def test_hybrid_cache_sliding_mode(self): input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ - self.hybrid_cache.reset() - # Scenario 1: Update within window, no slide yet + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] - self.hybrid_cache.update( + hybrid_cache.update( key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, ) - self.hybrid_cache.update( + hybrid_cache.update( key_states=torch.tensor(3.0)[None, None, None, None], value_states=torch.tensor(3.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - self.hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed", ) # Scenario 2: Update causing first slide - self.hybrid_cache.reset() + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] - self.hybrid_cache.update( + hybrid_cache.update( key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, ) - self.hybrid_cache.update( + hybrid_cache.update( 
key_states=torch.tensor(5.0)[None, None, None, None], value_states=torch.tensor(5.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - self.hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed", ) # Scenario 3: Update causing subsequent slide - self.hybrid_cache.update( + hybrid_cache.update( key_states=torch.tensor(6.0)[None, None, None, None], value_states=torch.tensor(6.0)[None, None, None, None], layer_idx=0, cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, ) self.assertEqual( - self.hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed", ) # Scenario 4: Long prompt handling - self.hybrid_cache.reset() + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] - self.hybrid_cache.update( + hybrid_cache.update( key_states=long_prefill, value_states=long_prefill, layer_idx=0, cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - self.hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 4 failed", ) From 2ce64c53151a257f24bfcea8141654907f62f48e Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 13 May 2025 15:37:45 +0200 Subject: [PATCH 16/19] ruff --- tests/utils/test_cache_utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 5b2852ddddda..9ea8dc892510 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -314,9 +314,18 @@ def test_cache_pipe_rope_model(self, cache_implementation): model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM" pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16) - pipe.model.config.sliding_window = 10 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None - out = pipe("h", cache_implementation=cache_implementation, max_new_tokens=10, do_sample=False, disable_compile=True, return_tensors=True) - out = out[0]['generated_token_ids'][-10:] + pipe.model.config.sliding_window = ( + 10 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None + ) + out = pipe( + "h", + cache_implementation=cache_implementation, + max_new_tokens=10, + do_sample=False, + disable_compile=True, + return_tensors=True, + ) + out = out[0]["generated_token_ids"][-10:] EXPECTED_OUTPUT = [914, 134, 124, 889, 48, 233, 541, 27, 380, 365] self.assertListEqual(out, EXPECTED_OUTPUT) From 901c2a47155b6fdc7c87f9d5dd6ec8b937fc744e Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Tue, 13 May 2025 17:11:39 +0200 Subject: [PATCH 17/19] optimize short seqs --- src/transformers/cache_utils.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 7e1f1cf49525..8780df870e87 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -92,16 +92,28 @@ def _sliding_cache_update( v_cache.copy_(new_v) return key_states, 
value_states - # Sliding window logic for generation phase or prefill < window - slicing = torch.arange(max_cache_len, device=value_states.device) current_seq_len = cache_position[-1] + 1 # Use last position to determine current length + + # Use optimized path for short sequences when not compiling + if not torch.compiler.is_compiling() and current_seq_len <= max_cache_len: + try: + k_cache.index_copy_(2, cache_position, key_states) + v_cache.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # Fallback for MPS: direct assignment + k_cache[:, :, cache_position] = key_states + v_cache[:, :, cache_position] = value_states + return k_cache, v_cache + + # Sliding window logic for when we exceed max_cache_len + slicing = torch.arange(max_cache_len, device=value_states.device) to_shift = current_seq_len > max_cache_len indices = (slicing + to_shift.sum()) % max_cache_len k_out_shifted = k_cache[:, :, indices] v_out_shifted = v_cache[:, :, indices] - # Clamp cache_position to determine the *target index* within the shifted cache view + # Only need to clamp when we're doing sliding window update_position = cache_position.clamp(min=0, max=max_cache_len - 1) try: From 8a9b0e2abe48cc6ea6bd7be8a10b6c85cfc150e9 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 14 May 2025 15:56:57 +0200 Subject: [PATCH 18/19] Revert "optimize short seqs" This reverts commit 901c2a47155b6fdc7c87f9d5dd6ec8b937fc744e. --- src/transformers/cache_utils.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 8780df870e87..7e1f1cf49525 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -92,28 +92,16 @@ def _sliding_cache_update( v_cache.copy_(new_v) return key_states, value_states - current_seq_len = cache_position[-1] + 1 # Use last position to determine current length - - # Use optimized path for short sequences when not compiling - if not torch.compiler.is_compiling() and current_seq_len <= max_cache_len: - try: - k_cache.index_copy_(2, cache_position, key_states) - v_cache.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # Fallback for MPS: direct assignment - k_cache[:, :, cache_position] = key_states - v_cache[:, :, cache_position] = value_states - return k_cache, v_cache - - # Sliding window logic for when we exceed max_cache_len + # Sliding window logic for generation phase or prefill < window slicing = torch.arange(max_cache_len, device=value_states.device) + current_seq_len = cache_position[-1] + 1 # Use last position to determine current length to_shift = current_seq_len > max_cache_len indices = (slicing + to_shift.sum()) % max_cache_len k_out_shifted = k_cache[:, :, indices] v_out_shifted = v_cache[:, :, indices] - # Only need to clamp when we're doing sliding window + # Clamp cache_position to determine the *target index* within the shifted cache view update_position = cache_position.clamp(min=0, max=max_cache_len - 1) try: From bd0a245d9ed068f80cead9bd94f563ce1bf9c6dc Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral Date: Wed, 14 May 2025 16:38:05 +0200 Subject: [PATCH 19/19] apply suggestions --- tests/utils/test_cache_utils.py | 81 +++++++++++++++++---------------- 1 file changed, 41 insertions(+), 40 deletions(-) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 9ea8dc892510..24c5c650e97c 100644 --- a/tests/utils/test_cache_utils.py +++ 
b/tests/utils/test_cache_utils.py @@ -52,6 +52,7 @@ SlidingWindowCache, StaticCache, convert_and_export_with_cache, + pipeline, ) @@ -190,6 +191,21 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 1, 10, 128)) +def _skip_on_failed_cache_prerequisites(test, cache_implementation): + """Function to skip tests on failed cache prerequisites, given a cache implementation""" + # Installed dependencies + if cache_implementation == "quantized" and not is_optimum_quanto_available(): + test.skipTest("Quanto is not available") + # Devices + if "offloaded" in cache_implementation: + has_accelerator = torch_device is not None and torch_device != "cpu" + if not has_accelerator: + test.skipTest("Offloaded caches require an accelerator") + if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]: + if torch.cuda.device_count() != 1: + test.skipTest("Offloaded static caches require exactly 1 GPU") + + class CacheIntegrationTest(unittest.TestCase): """Fast cache integration tests that share the same small model""" @@ -202,24 +218,10 @@ def setUpClass(cls): ) cls.model.config.sliding_window = 256 # hack to enable the use of caches with sliding windows - def _skip_on_failed_cache_prerequisites(self, cache_implementation): - """Function to skip tests on failed cache prerequisites, given a cache implementation""" - # Installed dependencies - if cache_implementation == "quantized" and not is_optimum_quanto_available(): - self.skipTest("Quanto is not available") - # Devices - if "offloaded" in cache_implementation: - has_accelerator = torch_device is not None and torch_device != "cpu" - if not has_accelerator: - self.skipTest("Offloaded caches require an accelerator") - if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]: - if torch.cuda.device_count() != 1: - self.skipTest("Offloaded static caches require exactly 1 GPU") - @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_batched(self, cache_implementation): """Sanity check: caches' `.update` function expects batched inputs""" - self._skip_on_failed_cache_prerequisites(cache_implementation) + _skip_on_failed_cache_prerequisites(self, cache_implementation) EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] @@ -248,7 +250,7 @@ def test_cache_beam_search(self, cache_implementation): Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices (an output sequence contains multiple beam indices). 
""" - self._skip_on_failed_cache_prerequisites(cache_implementation) + _skip_on_failed_cache_prerequisites(self, cache_implementation) if cache_implementation == "offloaded_hybrid_chunked": # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly @@ -282,7 +284,7 @@ def test_cache_beam_search(self, cache_implementation): @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_extra_left_padding(self, cache_implementation): """Tests that adding extra left-padding does not affect the generation with the cache""" - self._skip_on_failed_cache_prerequisites(cache_implementation) + _skip_on_failed_cache_prerequisites(self, cache_implementation) EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."] @@ -306,29 +308,6 @@ def test_cache_extra_left_padding(self, cache_implementation): decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True) self.assertListEqual(decoded, EXPECTED_GENERATION) - @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) - def test_cache_pipe_rope_model(self, cache_implementation): - """Tests caches with a RoPE model""" - self._skip_on_failed_cache_prerequisites(cache_implementation) - from transformers import pipeline - - model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM" - pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16) - pipe.model.config.sliding_window = ( - 10 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None - ) - out = pipe( - "h", - cache_implementation=cache_implementation, - max_new_tokens=10, - do_sample=False, - disable_compile=True, - return_tensors=True, - ) - out = out[0]["generated_token_ids"][-10:] - EXPECTED_OUTPUT = [914, 134, 124, 889, 48, 233, 541, 27, 380, 365] - self.assertListEqual(out, EXPECTED_OUTPUT) - @require_torch_accelerator class CacheHardIntegrationTest(unittest.TestCase): @@ -577,6 +556,28 @@ def test_static_cache_multi_accelerator(self): _ = model(**inputs) _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid") + @require_torch_gpu + @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) + def test_cache_gptj_model(self, cache_implementation): + """Tests caches with GPT-J model. Regression test for https://github.com/huggingface/transformers/pull/34799""" + _skip_on_failed_cache_prerequisites(self, cache_implementation) + + model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM" + pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16) + pipe.model.config.sliding_window = ( + 256 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None + ) + out = pipe( + "hello world", + cache_implementation=cache_implementation, + max_new_tokens=10, + do_sample=False, + disable_compile=True, + return_tensors=True, + )[0]["generated_token_ids"][-10:] + EXPECTED_OUTPUT = [879, 175, 39, 141, 1000, 975, 951, 991, 683, 441] + self.assertListEqual(out, EXPECTED_OUTPUT) + @require_torch class CacheExportIntegrationTest(unittest.TestCase):