Merged

Commits (30)
acb901e  squash rebase (manueldeprada, Apr 24, 2025)
4eacd7d  ruff (manueldeprada, May 6, 2025)
05d2ce6  Merge branch 'main' into cache-fix2 (manueldeprada, May 6, 2025)
6b765bd  ruff (manueldeprada, May 6, 2025)
32cd5f6  fix hybrid cache in torch compile (manueldeprada, May 6, 2025)
4ddd8d6  Merge branch 'main' into cache-fix2 (manueldeprada, May 6, 2025)
9bfdcbc  Merge branch 'main' of github.com:huggingface/transformers into cache… (manueldeprada, May 6, 2025)
ec26e69  joaos suggestions (manueldeprada, May 6, 2025)
9858f2c  Merge branch 'main' into cache-fix2 (manueldeprada, May 6, 2025)
95805f3  ruff (manueldeprada, May 6, 2025)
f08ea20  Trigger Build (manueldeprada, May 6, 2025)
b3b0133  ruff (manueldeprada, May 6, 2025)
016d9db  Merge branch 'main' into cache-fix2 (manueldeprada, May 7, 2025)
3de7505  Update src/transformers/cache_utils.py (manueldeprada, May 8, 2025)
214e517  suggestions (manueldeprada, May 9, 2025)
468d887  Merge branch 'main' of github.com:huggingface/transformers into cache… (manueldeprada, May 9, 2025)
36e07a2  ruff (manueldeprada, May 9, 2025)
deacc67  revert naming change (manueldeprada, May 10, 2025)
8548b8f  Merge branch 'main' of github.com:huggingface/transformers into cache… (manueldeprada, May 10, 2025)
326d2b2  Merge remote-tracking branch 'upstream/main' into cache-fix2 (manueldeprada, May 11, 2025)
772b0a0  added new test and fixes for gptj (manueldeprada, May 13, 2025)
e67049f  ruff (manueldeprada, May 13, 2025)
df95e1f  reinit instead of resetting stateful caches (manueldeprada, May 13, 2025)
86d3f21  Merge remote-tracking branch 'upstream/main' into cache-fix2 (manueldeprada, May 13, 2025)
2ce64c5  ruff (manueldeprada, May 13, 2025)
901c2a4  optimize short seqs (manueldeprada, May 13, 2025)
8a9b0e2  Revert "optimize short seqs" (manueldeprada, May 14, 2025)
bd0a245  apply suggestions (manueldeprada, May 14, 2025)
c16653b  Merge branch 'main' into cache-fix2 (manueldeprada, May 19, 2025)
d719ef4  Merge branch 'main' into cache-fix2 (manueldeprada, May 20, 2025)
286 changes: 155 additions & 131 deletions src/transformers/cache_utils.py
@@ -21,6 +21,104 @@
logger = logging.get_logger(__name__)


# Utility functions for static/sliding cache update logic
def _static_cache_update(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_position: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the static cache tensors in place.

Args:
k_cache (`torch.Tensor`): The key cache tensor to update.
v_cache (`torch.Tensor`): The value cache tensor to update.
key_states (`torch.Tensor`): The new key states to add.
value_states (`torch.Tensor`): The new value states to add.
cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted.
If None, the entire cache is overwritten (prefill).

Returns:
Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place).
"""
if cache_position is None:
# Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
k_cache.copy_(key_states)
v_cache.copy_(value_states)
else:
# Generation phase. Update specific positions.
# Use index_copy_ for in-place update (compile-friendly).
try:
k_cache.index_copy_(2, cache_position, key_states)
v_cache.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# Fallback for devices like MPS where index_copy_ might not be supported.
k_cache[:, :, cache_position] = key_states
v_cache[:, :, cache_position] = value_states
return k_cache, v_cache
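
A hedged usage sketch (hypothetical shapes, not part of the file): a decode-step call writes the new states into the preallocated cache in place.

    import torch

    B, H, L, D = 1, 2, 8, 4  # assumed batch, heads, max_cache_len, head_dim
    k_cache = torch.zeros(B, H, L, D)
    v_cache = torch.zeros(B, H, L, D)
    key_states = torch.randn(B, H, 1, D)   # one new token
    value_states = torch.randn(B, H, 1, D)

    k, v = _static_cache_update(k_cache, v_cache, key_states, value_states, torch.tensor([3]))
    assert k is k_cache                           # updated in place, same tensor returned
    assert torch.equal(k[:, :, 3:4], key_states)  # new token landed at position 3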


def _sliding_cache_update(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_position: torch.LongTensor,
max_cache_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the sliding window cache tensors, returning the potentially modified tensors.

Args:
k_cache (`torch.Tensor`): The key cache tensor to update.
v_cache (`torch.Tensor`): The value cache tensor to update.
key_states (`torch.Tensor`): The new key states to add.
value_states (`torch.Tensor`): The new value states to add.
cache_position (`torch.LongTensor`): The position indices where the new states should be inserted.
max_cache_len (`int`): The maximum length of the sliding window cache.

Returns:
Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update.
For prefill > window, these are the full input states.
Otherwise, they are the updated cache tensors.
"""
# Handle prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > max_cache_len:
new_k = key_states[:, :, -max_cache_len:, :]
new_v = value_states[:, :, -max_cache_len:, :]
k_cache.copy_(new_k)
v_cache.copy_(new_v)
return key_states, value_states

# Sliding window logic for generation phase or prefill < window
slicing = torch.arange(max_cache_len, device=value_states.device)
current_seq_len = cache_position[-1] + 1 # Use last position to determine current length
to_shift = current_seq_len > max_cache_len
indices = (slicing + to_shift.sum()) % max_cache_len

k_out_shifted = k_cache[:, :, indices]
v_out_shifted = v_cache[:, :, indices]

# Clamp cache_position to determine the *target index* within the shifted cache view
update_position = cache_position.clamp(min=0, max=max_cache_len - 1)

try:
k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
except NotImplementedError:
# Fallback for MPS: clone and modify the clone
k_out_updated = k_out_shifted.clone()
v_out_updated = v_out_shifted.clone()
k_out_updated[:, :, update_position] = key_states
v_out_updated[:, :, update_position] = value_states

k_cache.copy_(k_out_updated)
v_cache.copy_(v_out_updated)
return k_out_updated, v_out_updated
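
For intuition, a sketch of the long-prefill path (assumed shapes, not from the PR): the cache keeps only the last max_cache_len positions, while the returned tensors carry the full prompt so the current attention step still sees every token.

    import torch

    B, H, W, D = 1, 2, 4, 8  # assumed shapes; sliding window of 4
    k_cache = torch.zeros(B, H, W, D)
    v_cache = torch.zeros(B, H, W, D)
    key_states = torch.randn(B, H, 6, D)   # 6-token prompt, longer than the window
    value_states = torch.randn(B, H, 6, D)

    k, v = _sliding_cache_update(k_cache, v_cache, key_states, value_states, torch.arange(6), W)
    assert k.shape[2] == 6                              # full prompt is returned
    assert torch.equal(k_cache, key_states[:, :, -W:])  # cache holds the last 4 positions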


class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
@@ -1264,28 +1362,16 @@ def update(
"""
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)

if cache_position is None:
k_out.copy_(key_states)
v_out.copy_(value_states)
else:
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
# operation, that avoids copies and uses less memory.
try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

return k_out, v_out
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)
Comment on lines +1366 to +1367

Member: The dtype should already be correct here, no?

Contributor Author (manueldeprada): I am happy to remove it; it passes the tests. There are similar checks or casts that could be removed too, but I kept them in case existing code relies on them.

Contributor: I think we can remove the cast, yes. Looking at the original PR, the source of these lines was to handle the case where we don't cast RoPE-based KVs in the model forward pass (RoPE is by default FP32, regardless of the model dtype).

Contributor Author (manueldeprada): Thanks for the pointer @gante, I should have traced that down. We can't remove it: the PR's sample code fails after removing the casts. I am restoring the cast and adding a test...

I agree though with @Cyrilvallez that it is a bad solution to cast everything instead of doing something specific for RoPE. Since this PR has limited scope and is groundwork for #38077, I will try to solve it more elegantly there.

Contributor: [Perhaps for a subsequent PR, to avoid bloating/delaying this one:] It would be more transparent and precise if the casting were done in the model architecture rather than in the cache. In the specific case of GPT-J loaded in FP16, it seems that without a cache the KVs are kept in FP32, while with a cache the KVs are cast to FP16 in the cache class -> the cache introduces a performance degradation. As such, I think it would be positive to remove the cast and delegate control to the model architectures.

Contributor Author (manueldeprada, May 13, 2025): Good, I will check whether it's only a GPT-J thing in a subsequent PR. In the meantime, I added the test. It unveiled 3 fixes like this that were applied to StaticCache but not to Hybrid and Offloaded. Please have a quick look at 772b0a0 before I merge.

Member: Especially if it's only there for a given old model -> much better to fix the model rather than the general cache logic!
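
For context, a minimal repro sketch of the mismatch discussed above (hypothetical shapes): FP32 states coming out of RoPE cannot be written into an FP16 cache without the cast, because index_copy_ requires matching dtypes.

    import torch

    k_cache = torch.zeros(1, 1, 8, 4, dtype=torch.float16)     # cache allocated in fp16
    key_states = torch.randn(1, 1, 1, 4, dtype=torch.float32)  # fp32, e.g. after RoPE

    # k_cache.index_copy_(2, torch.tensor([0]), key_states)    # RuntimeError: dtype mismatch
    k_cache.index_copy_(2, torch.tensor([0]), key_states.to(k_cache.dtype))  # works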

return _static_cache_update(
self.key_cache[layer_idx],
self.value_cache[layer_idx],
key_states,
value_states,
cache_kwargs.get("cache_position"),
)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
@@ -1314,7 +1400,7 @@ class SlidingWindowCache(StaticCache):

The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:

indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
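
The permutation above can be reproduced in isolation (a standalone sketch, using a window of 8 instead of 64 for brevity):

    import torch

    sliding_window = 8
    slicing = torch.ones(sliding_window, dtype=torch.long).cumsum(0)  # [1, 2, ..., 8]

    # Window not yet full (to_shift is False): identity order
    print((slicing + 0 - 1) % sliding_window)  # tensor([0, 1, 2, 3, 4, 5, 6, 7])

    # Window full (to_shift is True): rolled left by one, oldest slot written last
    print((slicing + 1 - 1) % sliding_window)  # tensor([1, 2, 3, 4, 5, 6, 7, 0])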
@@ -1398,46 +1484,21 @@ def update(
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)

# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
if cache_position.shape[0] >= self.max_cache_len:
k_out = key_states[:, :, -self.max_cache_len :, :]
v_out = value_states[:, :, -self.max_cache_len :, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states

slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
to_shift = cache_position > self.max_cache_len - 1
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len

k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]

try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
if cache_position is None:
raise ValueError("`cache_position` must be provided for SlidingWindowCache.")

self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)

return k_out, v_out
return _sliding_cache_update(
self.key_cache[layer_idx],
self.value_cache[layer_idx],
key_states,
value_states,
cache_position,
self.max_cache_len,
)

def get_max_cache_shape(self) -> Optional[int]:
return self.max_cache_len
@@ -1680,12 +1741,13 @@ def __init__(
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"Setting `cache_implementation` to 'hybrid' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings
# Sliding layers can't be larger than the overall max cache len
self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
self.max_batch_size = max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
@@ -1694,22 +1756,17 @@

self._dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)

layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
sliding_cache_shape = (
self.max_batch_size,
self.num_key_value_heads,
self._sliding_window_max_len,
self.head_dim,
)
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
@@ -1718,50 +1775,14 @@ def __init__(
layer_device = device
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
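
To illustrate the per-layer routing (a sketch with assumed config values): with the backward-compatible layer_switch of 2, even-indexed layers get the sliding shape and odd-indexed layers the global one.

    layer_switch = 2  # default sliding_window_pattern, kept for BC
    num_hidden_layers = 6
    is_sliding_list = [bool((i + 1) % layer_switch) for i in range(num_hidden_layers)]
    print(is_sliding_list)  # [True, False, True, False, True, False]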

def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
if cache_position.shape[0] >= max_cache_len:
k_out = key_states[:, :, -max_cache_len:, :]
v_out = value_states[:, :, -max_cache_len:, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states

slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
to_shift = cache_position > max_cache_len - 1
cache_position = cache_position.clamp(0, max_cache_len - 1)
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]

k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out

def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out

def update(
self,
key_states: torch.Tensor,
@@ -1772,7 +1793,10 @@ def update(
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
if cache_position is None:
raise ValueError("`cache_position` must be provided for HybridCache.")

is_sliding_layer = self.is_sliding_list[layer_idx]

# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
# when the cache is initialized in the forward pass (e.g. Gemma2)
Expand All @@ -1781,25 +1805,22 @@ def update(
if self.value_cache[layer_idx].device != value_states.device:
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)

if sliding_window:
update_fn = self._sliding_update
k_cache = self.key_cache[layer_idx]
v_cache = self.value_cache[layer_idx]
key_states = key_states.to(k_cache.dtype)
value_states = value_states.to(v_cache.dtype)

if is_sliding_layer:
return _sliding_cache_update(
k_cache,
v_cache,
key_states,
value_states,
cache_position,
k_cache.shape[2], # Use actual cache dim as max cache len
)
else:
update_fn = self._static_update

return update_fn(
cache_position,
layer_idx,
key_states,
value_states,
k_out,
v_out,
k_out.shape[2],
)
return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position)

def get_max_cache_shape(self) -> Optional[int]:
return self.max_cache_len
@@ -2033,7 +2054,7 @@ def __init__(

# TODO (joao): to enable this cache on multiple devices, use the pattern from `OffloadedCache`, which keeps
# track of the original device of each layer
unique_devices = set(layer_device_map.values())
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
if len(unique_devices) > 1:
raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")

@@ -2292,7 +2313,7 @@ def __init__(

# TODO (joao): to enable this cache on multiple devices, use the pattern from `OffloadedCache`, which keeps
# track of the original device of each layer
unique_devices = set(layer_device_map.values())
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
if len(unique_devices) > 1:
raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")

@@ -2369,6 +2390,9 @@ def update(
A tuple containing the updated key and value states.
"""

key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)

if layer_idx == 0:
# Update seen tokens.
# TODO(gante): Remove this.