New cache tests and modular Hybrid Cache #37972
@@ -21,6 +21,104 @@
 logger = logging.get_logger(__name__)

+# Utility functions for static/sliding cache update logic
+def _static_cache_update(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    cache_position: Optional[torch.LongTensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Updates the static cache tensors in place.
+
+    Args:
+        k_cache (`torch.Tensor`): The key cache tensor to update.
+        v_cache (`torch.Tensor`): The value cache tensor to update.
+        key_states (`torch.Tensor`): The new key states to add.
+        value_states (`torch.Tensor`): The new value states to add.
+        cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted.
+            If None, the entire cache is overwritten (prefill).
+
+    Returns:
+        Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place).
+    """
+    if cache_position is None:
+        # Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
+        k_cache.copy_(key_states)
+        v_cache.copy_(value_states)
+    else:
+        # Generation phase. Update specific positions.
+        # Use index_copy_ for in-place update (compile-friendly).
+        try:
+            k_cache.index_copy_(2, cache_position, key_states)
+            v_cache.index_copy_(2, cache_position, value_states)
+        except NotImplementedError:
+            # Fallback for devices like MPS where index_copy_ might not be supported.
+            k_cache[:, :, cache_position] = key_states
+            v_cache[:, :, cache_position] = value_states
+    return k_cache, v_cache
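To make the prefill/decode split concrete, here is a minimal sketch that drives `_static_cache_update` directly with toy tensors (the shapes and values are made up for illustration; in the library the helper is called from `StaticCache.update`, not by user code):

```python
import torch

# Toy dimensions: batch=1, kv_heads=2, max_cache_len=8, head_dim=4
k_cache = torch.zeros(1, 2, 8, 4)
v_cache = torch.zeros(1, 2, 8, 4)

# Prefill of 5 tokens: write them at positions 0..4
prompt_k = torch.randn(1, 2, 5, 4)
prompt_v = torch.randn(1, 2, 5, 4)
_static_cache_update(k_cache, v_cache, prompt_k, prompt_v, torch.arange(5))

# One decode step: write the new token at position 5
new_k = torch.randn(1, 2, 1, 4)
new_v = torch.randn(1, 2, 1, 4)
k_out, v_out = _static_cache_update(k_cache, v_cache, new_k, new_v, torch.tensor([5]))
assert torch.equal(k_out[:, :, 5], new_k[:, :, 0])
```

Passing `cache_position=None` instead overwrites the whole buffer, which only makes sense when the incoming states already span `max_cache_len`.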
+
+
+def _sliding_cache_update(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    cache_position: torch.LongTensor,
+    max_cache_len: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Updates the sliding window cache tensors, returning the potentially modified tensors.
+
+    Args:
+        k_cache (`torch.Tensor`): The key cache tensor to update.
+        v_cache (`torch.Tensor`): The value cache tensor to update.
+        key_states (`torch.Tensor`): The new key states to add.
+        value_states (`torch.Tensor`): The new value states to add.
+        cache_position (`torch.LongTensor`): The position indices where the new states should be inserted.
+        max_cache_len (`int`): The maximum length of the sliding window cache.
+
+    Returns:
+        Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update.
+            For prefill > window, these are the full input states. Otherwise, they are the updated cache tensors.
+    """
+    # Handle prefill phase when prompt length > sliding_window_size
+    if cache_position.shape[0] > max_cache_len:
+        new_k = key_states[:, :, -max_cache_len:, :]
+        new_v = value_states[:, :, -max_cache_len:, :]
+        k_cache.copy_(new_k)
+        v_cache.copy_(new_v)
+        return key_states, value_states
+
+    # Sliding window logic for generation phase or prefill < window
+    slicing = torch.arange(max_cache_len, device=value_states.device)
+    current_seq_len = cache_position[-1] + 1  # Use last position to determine current length
+    to_shift = current_seq_len > max_cache_len
+    indices = (slicing + to_shift.sum()) % max_cache_len
+
+    k_out_shifted = k_cache[:, :, indices]
+    v_out_shifted = v_cache[:, :, indices]
+
+    # Clamp cache_position to determine the *target index* within the shifted cache view
+    update_position = cache_position.clamp(min=0, max=max_cache_len - 1)
+
+    try:
+        k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
+        v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
+    except NotImplementedError:
+        # Fallback for MPS: clone and modify the clone
+        k_out_updated = k_out_shifted.clone()
+        v_out_updated = v_out_shifted.clone()
+        k_out_updated[:, :, update_position] = key_states
+        v_out_updated[:, :, update_position] = value_states
+
+    k_cache.copy_(k_out_updated)
+    v_cache.copy_(v_out_updated)
+
+    return k_out_updated, v_out_updated
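The index math above implements a ring buffer: once the last position passes the window, the cache view is rolled left by one so the oldest entry becomes the write target. A small standalone trace of that computation (the window size is a made-up toy value, and `shift_indices` is a hypothetical helper that only mirrors the lines above):

```python
import torch

max_cache_len = 4  # toy window size

def shift_indices(cache_position: torch.Tensor) -> torch.Tensor:
    # Mirrors the index math above: shift by one once the last position passes the window
    slicing = torch.arange(max_cache_len)
    to_shift = (cache_position[-1] + 1) > max_cache_len
    return (slicing + to_shift.sum()) % max_cache_len

print(shift_indices(torch.tensor([3])))  # tensor([0, 1, 2, 3]) -> still filling the window
print(shift_indices(torch.tensor([4])))  # tensor([1, 2, 3, 0]) -> roll left; the oldest slot is overwritten
```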
+
+
 class Cache:
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
@@ -1264,28 +1362,16 @@ def update(
         """
         if cache_kwargs is None:
             cache_kwargs = {}
-        cache_position = cache_kwargs.get("cache_position")
-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-        key_states = key_states.to(k_out.dtype)
-        value_states = value_states.to(v_out.dtype)
-
-        if cache_position is None:
-            k_out.copy_(key_states)
-            v_out.copy_(value_states)
-        else:
-            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
-            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
-            # operation, that avoids copies and uses less memory.
-            try:
-                k_out.index_copy_(2, cache_position, key_states)
-                v_out.index_copy_(2, cache_position, value_states)
-            except NotImplementedError:
-                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
-                k_out[:, :, cache_position] = key_states
-                v_out[:, :, cache_position] = value_states
-
-        return k_out, v_out
+        key_states = key_states.to(self.key_cache[layer_idx].dtype)
+        value_states = value_states.to(self.value_cache[layer_idx].dtype)
Review thread on lines +1366 to +1367:

Member: The dtype should already be correct here, no?

Contributor (author): I am happy to remove it; the tests pass without it. There are similar checks or casts that could be removed too, but I kept them in case existing code relies on them.

Contributor: I think we can remove the cast, yes. Looking at the original PR, these lines were added to handle the case where we don't cast RoPE-based KVs in the model forward pass (RoPE is FP32 by default, regardless of the model dtype).

Contributor (author): Thanks for the pointer @gante, I should have traced that down. We can't remove it: that PR's sample code fails after removing the casts. I am restoring the cast and adding a test. I agree with @Cyrilvallez, though, that casting everything is a poor solution compared to handling RoPE specifically. Since this PR has limited scope and is groundwork for #38077, I will try to solve it more elegantly there.

Contributor: [Perhaps for a subsequent PR, to avoid bloating/delaying this one:] it would be more transparent and precise to do the casting in the model architecture rather than in the cache. In the specific case of GPT-J loaded in FP16, it seems that without a cache the KV states stay in FP32, while with a cache they are cast to FP16 inside the cache class, so the cache introduces a performance degradation. I therefore think it would be positive to remove the cast and delegate control to the model architectures.

Contributor (author): Good, I will check whether this is only a GPT-J issue in a subsequent PR. In the meantime, I added the test. It unveiled three fixes like this that had been applied to StaticCache but not to the Hybrid and Offloaded caches. Please have a quick look at 772b0a0 before I merge.

Member: Especially if it's only there for a given old model, it's much better to fix the model rather than the general cache logic!
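A minimal sketch of the failure mode the thread is about (the shapes and dtypes are made up for illustration; this is not the exact test added in the PR): the cache buffers are allocated in the model dtype, but a RoPE path may emit FP32 key/value states, and `index_copy_` requires matching dtypes, which is why `update` casts the incoming states first.

```python
import torch

# Cache allocated in the model dtype (e.g. fp16)
k_cache = torch.zeros(1, 2, 8, 4, dtype=torch.float16)
v_cache = torch.zeros(1, 2, 8, 4, dtype=torch.float16)

# Incoming states in fp32, e.g. produced by an fp32 RoPE computation
key_states = torch.randn(1, 2, 1, 4, dtype=torch.float32)
value_states = torch.randn(1, 2, 1, 4, dtype=torch.float32)

# Without the explicit cast, index_copy_ raises a dtype-mismatch error;
# casting to the cache dtype keeps the update valid and the cache in fp16.
key_states = key_states.to(k_cache.dtype)
value_states = value_states.to(v_cache.dtype)

k_cache.index_copy_(2, torch.tensor([0]), key_states)
v_cache.index_copy_(2, torch.tensor([0]), value_states)
assert k_cache.dtype == torch.float16
```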
+        return _static_cache_update(
+            self.key_cache[layer_idx],
+            self.value_cache[layer_idx],
+            key_states,
+            value_states,
+            cache_kwargs.get("cache_position"),
+        )

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model."""
@@ -1314,7 +1400,7 @@ class SlidingWindowCache(StaticCache):

    The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:

-    indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
+    indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window
    tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
            37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
@@ -1398,46 +1484,21 @@ def update(
         if cache_kwargs is None:
             cache_kwargs = {}
         cache_position = cache_kwargs.get("cache_position")
-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-        key_states = key_states.to(k_out.dtype)
-        value_states = value_states.to(v_out.dtype)
-
-        # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
-        if cache_position.shape[0] >= self.max_cache_len:
-            k_out = key_states[:, :, -self.max_cache_len :, :]
-            v_out = value_states[:, :, -self.max_cache_len :, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
-
-        slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        to_shift = cache_position > self.max_cache_len - 1
-        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
-        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
-
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-
-        try:
-            k_out.index_copy_(2, cache_position, key_states)
-            v_out.index_copy_(2, cache_position, value_states)
-        except NotImplementedError:
-            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
-            k_out[:, :, cache_position] = key_states
-            v_out[:, :, cache_position] = value_states
-
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-
-        return k_out, v_out
+        if cache_position is None:
+            raise ValueError("`cache_position` must be provided for SlidingWindowCache.")
+
+        key_states = key_states.to(self.key_cache[layer_idx].dtype)
+        value_states = value_states.to(self.value_cache[layer_idx].dtype)
+
+        return _sliding_cache_update(
+            self.key_cache[layer_idx],
+            self.value_cache[layer_idx],
+            key_states,
+            value_states,
+            cache_position,
+            self.max_cache_len,
+        )

     def get_max_cache_shape(self) -> Optional[int]:
         return self.max_cache_len
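One behavior worth calling out: when the prompt is longer than the window, only the last `max_cache_len` positions are stored, but the full states are returned so the current forward pass can still attend to the whole prompt. A toy trace of that branch, calling the helper directly with made-up shapes:

```python
import torch

window = 4
k_cache = torch.zeros(1, 1, window, 2)
v_cache = torch.zeros(1, 1, window, 2)

# Prefill with a 6-token prompt, i.e. longer than the 4-token window
k = torch.arange(6, dtype=torch.float32).reshape(1, 1, 6, 1).expand(1, 1, 6, 2).contiguous()
v = k.clone()
k_ret, v_ret = _sliding_cache_update(k_cache, v_cache, k, v, torch.arange(6), window)

# The cache keeps only the last `window` tokens...
assert torch.equal(k_cache[0, 0, :, 0], torch.tensor([2.0, 3.0, 4.0, 5.0]))
# ...but the full prompt states are returned for this forward pass
assert k_ret.shape[2] == 6
```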
@@ -1680,12 +1741,13 @@ def __init__(
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
-                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
+                "Setting `cache_implementation` to 'hybrid' requires the model config supporting "
                 "sliding window attention, please check if there is a `sliding_window` field in the model "
                 "config and it's not set to None."
             )
-        self.max_cache_len = max_cache_len
-        self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
+        self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
         self.max_batch_size = max_batch_size
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         self.head_dim = (
@@ -1694,22 +1756,17 @@

         self._dtype = dtype
         self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
         )

         layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding = torch.tensor(
-            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
-        )
+        self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]

         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
-        sliding_cache_shape = (
-            self.max_batch_size,
-            self.num_key_value_heads,
-            self._sliding_window_max_len,
-            self.head_dim,
-        )
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
         device = torch.device(device) if device is not None else None
         for i in range(config.num_hidden_layers):
             if layer_device_map is not None:
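To see what the `layer_switch` pattern produces, here is a small standalone trace (the layer counts are made up for illustration, and `sliding_pattern` is just a hypothetical name for the list comprehension above; the backward-compatible default of 2 alternates sliding and global layers):

```python
# Every `layer_switch`-th layer is a global-attention layer; the rest use sliding attention.
def sliding_pattern(num_hidden_layers: int, layer_switch: int) -> list[bool]:
    return [bool((i + 1) % layer_switch) for i in range(num_hidden_layers)]

print(sliding_pattern(6, 2))  # [True, False, True, False, True, False]
print(sliding_pattern(6, 3))  # [True, True, False, True, True, False]
```

Layers marked `True` get window-sized cache buffers (`sliding_cache_shape`), the others get full-length buffers (`global_cache_shape`).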
@@ -1718,50 +1775,14 @@
                 layer_device = device
             # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
-            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
+            cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
             torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
-    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        if cache_position.shape[0] >= max_cache_len:
-            k_out = key_states[:, :, -max_cache_len:, :]
-            v_out = value_states[:, :, -max_cache_len:, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
-
-        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        to_shift = cache_position > max_cache_len - 1
-        cache_position = cache_position.clamp(0, max_cache_len - 1)
-        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-        return k_out, v_out
-
-    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-
-        self.key_cache[layer_idx] = k_out
-        self.value_cache[layer_idx] = v_out
-        return k_out, v_out
-
     def update(
         self,
         key_states: torch.Tensor,
@@ -1772,7 +1793,10 @@ def update(
         if cache_kwargs is None:
             cache_kwargs = {}
         cache_position = cache_kwargs.get("cache_position")
-        sliding_window = cache_kwargs.get("sliding_window")
+        if cache_position is None:
+            raise ValueError("`cache_position` must be provided for HybridCache.")
+
+        is_sliding_layer = self.is_sliding_list[layer_idx]

         # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
         # when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1781,25 +1805,22 @@
         if self.value_cache[layer_idx].device != value_states.device:
             self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)

-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-        key_states = key_states.to(k_out.dtype)
-        value_states = value_states.to(v_out.dtype)
-
-        if sliding_window:
-            update_fn = self._sliding_update
+        k_cache = self.key_cache[layer_idx]
+        v_cache = self.value_cache[layer_idx]
+        key_states = key_states.to(k_cache.dtype)
+        value_states = value_states.to(v_cache.dtype)
+
+        if is_sliding_layer:
+            return _sliding_cache_update(
+                k_cache,
+                v_cache,
+                key_states,
+                value_states,
+                cache_position,
+                k_cache.shape[2],  # Use actual cache dim as max cache len
+            )
         else:
-            update_fn = self._static_update
-
-        return update_fn(
-            cache_position,
-            layer_idx,
-            key_states,
-            value_states,
-            k_out,
-            v_out,
-            k_out.shape[2],
-        )
+            return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position)

     def get_max_cache_shape(self) -> Optional[int]:
         return self.max_cache_len
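For context, the hybrid cache is normally exercised through generation rather than by calling `update` directly. A minimal usage sketch, with the model name and dtype chosen as assumptions for illustration (recent transformers versions accept `cache_implementation="hybrid"` in `generate` for models with alternating sliding/global attention, such as Gemma 2):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"  # assumed example of a model with alternating sliding/global layers
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("The KV cache is", return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    max_new_tokens=32,
    cache_implementation="hybrid",  # allocates a HybridCache: sliding layers get window-sized buffers
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```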
@@ -2033,7 +2054,7 @@ def __init__(

         # TODO (joao): to enable this cache on multiple devices use the pattern from `OffloadedCache`, which keeps
         # track of the original device of each layer
-        unique_devices = set(layer_device_map.values())
+        unique_devices = set(layer_device_map.values()) if layer_device_map else set()
         if len(unique_devices) > 1:
             raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")
@@ -2292,7 +2313,7 @@ def __init__(

         # TODO (joao): to enable this cache on multiple devices use the pattern from `OffloadedCache`, which keeps
         # track of the original device of each layer
-        unique_devices = set(layer_device_map.values())
+        unique_devices = set(layer_device_map.values()) if layer_device_map else set()
         if len(unique_devices) > 1:
             raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")
@@ -2369,6 +2390,9 @@ def update(
             A tuple containing the updated key and value states.
         """

+        key_states = key_states.to(self.key_cache[layer_idx].dtype)
+        value_states = value_states.to(self.value_cache[layer_idx].dtype)
+
         if layer_idx == 0:
             # Update seen tokens.
             # TODO(gante): Remove this.