27 changes: 22 additions & 5 deletions src/transformers/cache_utils.py
@@ -207,7 +207,9 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
         self.num_sink_tokens = num_sink_tokens
-        self.cos_sin_cache = {}
+        self.cos_sin_rerotation_cache = {}
+        self._cos_cache = None
+        self._sin_cache = None
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     @staticmethod
@@ -225,7 +227,7 @@ def _apply_key_rotary_pos_emb(
     def _get_rerotation_cos_sin(
         self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_cache:
+        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
             # Upcast to float32 temporarily for better accuracy
             cos = cos.to(torch.float32)
             sin = sin.to(torch.float32)
@@ -238,11 +240,11 @@ def _get_rerotation_cos_sin(
             rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
             rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
 
-            self.cos_sin_cache[key_states.shape[-2]] = (
+            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                 rerotation_cos.to(key_states.dtype).unsqueeze(0),
                 rerotation_sin.to(key_states.dtype).unsqueeze(0),
             )
-        return self.cos_sin_cache[key_states.shape[-2]]
+        return self.cos_sin_rerotation_cache[key_states.shape[-2]]
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
@@ -292,6 +294,21 @@ def update(
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
 
+        # Update the sin/cos cache, which holds sin/cos values for all possible positions
+        if using_rope and layer_idx == 0:
+            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
+            # after all RoPE models have a llama-like cache utilization.
+            if cos.dim() == 2:
+                self._cos_cache = cos
+                self._sin_cache = sin
+            else:
+                if self._cos_cache is None:
+                    self._cos_cache = cos[0, ...]
+                    self._sin_cache = sin[0, ...]
+                elif self._cos_cache.shape[0] < self.window_length:
+                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
+                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
Comment on lines +308 to +310
amyeroberts (Contributor):
Just for my own understanding of how the cache is meant to work, I have two questions:

1. Values passed in on the `update` call: if we call `update` with `sin` and `cos` passed in, does the cache keep old values + new values (i.e. `self._cos_cache[:self._cos_cache_prev.shape[0]]` holds the old values and `self._cos_cache[self._cos_cache_prev.shape[0]:]` holds the new values), or are the passed-in `cos` values just the new values to be appended?

2. Window length: is the assumption here that the window length is constant once the cache is created?

gante (Contributor, Author) commented on May 1, 2024:

@amyeroberts

1. The values passed in `cos` are the new values to be appended. In RoPE models, `sin` and `cos` are constants with shape [config.max_position_embeddings, rope_embedding_dims, config.hidden_size // config.num_attention_heads]. However, with the compile-optimized modeling code, we only materialize the needed parts of these matrices, with shape[0] = input_ids.shape[1] = input sequence length. Since SinkCache needs access to all sin and cos values up to shape[0] = self.window_length when going beyond the window length, this cache was created.

Alternatively, we could pass the model config to compute the full sin and cos, but that would be (IMO) an ugly interface (we would have to use the model config to instantiate a RoPE layer inside the cache, to then compute these values and discard the layer).

2. Yes. SinkCache is a fixed-length cache -- its purpose is to be used with self.window_length < config.max_position_embeddings, while enabling coherent outputs beyond full sequence length = self.window_length. In other words, coherent long outputs with a relatively short cache :) Its limitation is that it can only recall content back up to the size of the window length, so it quickly forgets things (see the usage sketch below).
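
A minimal usage sketch of the fixed-length behavior described in point 2, assuming a RoPE-based (Llama-style) checkpoint and that `generate` accepts a `SinkCache` instance via `past_key_values`, as documented at the time of this PR; the model name, window size, and generation settings are illustrative only:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

# Illustrative checkpoint; any RoPE-based causal LM is used the same way.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("Tell me a very long story about a lighthouse.", return_tensors="pt").to(model.device)

# Fixed-length cache: keeps `num_sink_tokens` initial tokens plus a rolling window
# of the most recent tokens, so the KV cache never grows past `window_length`.
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)

out = model.generate(
    **inputs,
    past_key_values=past_key_values,
    max_new_tokens=512,  # may exceed window_length; older tokens are evicted as generation proceeds
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Because the cache is bounded by `window_length`, memory stays constant during long generations, at the cost of the model forgetting anything older than the window (apart from the sink tokens).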

Contributor:
Got it - thanks for taking the time to write this up and explain!


+
         # [bsz, num_heads, seq_len, head_dim]
         if len(self.key_cache) <= layer_idx:
             # Empty cache
@@ -312,7 +329,7 @@ def update(
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
                 rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states, cos[: self.window_length], sin[: self.window_length]
+                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                 )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (