From 879d8c733e2c2ca73beaab41f84ff83b6d4ac558 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 30 Apr 2024 17:28:42 +0000 Subject: [PATCH 1/2] tmp commit --- src/transformers/cache_utils.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ceca9d3eeb35..5a3a314aba2b 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -207,7 +207,9 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None: self.value_cache: List[torch.Tensor] = [] self.window_length = window_length self.num_sink_tokens = num_sink_tokens - self.cos_sin_cache = {} + self.cos_sin_rerotation_cache = {} + self._cos_cache = None + self._sin_cache = None self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen @staticmethod @@ -225,7 +227,7 @@ def _apply_key_rotary_pos_emb( def _get_rerotation_cos_sin( self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - if key_states.shape[-2] not in self.cos_sin_cache: + if key_states.shape[-2] not in self.cos_sin_rerotation_cache: # Upcast to float32 temporarily for better accuracy cos = cos.to(torch.float32) sin = sin.to(torch.float32) @@ -238,11 +240,11 @@ def _get_rerotation_cos_sin( rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin - self.cos_sin_cache[key_states.shape[-2]] = ( + self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( rerotation_cos.to(key_states.dtype).unsqueeze(0), rerotation_sin.to(key_states.dtype).unsqueeze(0), ) - return self.cos_sin_cache[key_states.shape[-2]] + return self.cos_sin_rerotation_cache[key_states.shape[-2]] def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" @@ -292,6 +294,19 @@ def update( if layer_idx == 0: self._seen_tokens += key_states.shape[-2] + # Update the sin/cos cache, which holds sin/cos values for all possible positions + if using_rope and layer_idx==0: + # BC: some models still pass `sin`/`cos` with 2 dims. The expected is 3 dims, and no `if` to be needed. + if cos.dim() > 2: + cos = cos[..., :, :] + sin = sin[..., :, :] + if self._cos_cache is None: + self._cos_cache = cos + self._sin_cache = sin + elif self._cos_cache.shape[0] < self.window_length: + self._cos_cache = torch.cat([self._cos_cache, cos], dim=0) + self._sin_cache = torch.cat([self._sin_cache, sin], dim=0) + # [bsz, num_heads, seq_len, head_dim] if len(self.key_cache) <= layer_idx: # Empty cache @@ -312,7 +327,7 @@ def update( # On RoPE models, we need to recompute the Key rotation as the tokens are shifted if using_rope: rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( - key_states, cos[: self.window_length], sin[: self.window_length] + key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length] ) if partial_rotation_size is not None: keys_to_keep, keys_pass = ( From d19bc053f2a2c4325c2607048dcafa7768b8ab58 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 30 Apr 2024 18:06:13 +0000 Subject: [PATCH 2/2] passing sink cache tests --- src/transformers/cache_utils.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 5a3a314aba2b..2e29e19ade46 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -295,17 +295,19 @@ def update( self._seen_tokens += key_states.shape[-2] # Update the sin/cos cache, which holds sin/cos values for all possible positions - if using_rope and layer_idx==0: - # BC: some models still pass `sin`/`cos` with 2 dims. The expected is 3 dims, and no `if` to be needed. - if cos.dim() > 2: - cos = cos[..., :, :] - sin = sin[..., :, :] - if self._cos_cache is None: + if using_rope and layer_idx == 0: + # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove + # after all RoPE models have a llama-like cache utilization. + if cos.dim() == 2: self._cos_cache = cos self._sin_cache = sin - elif self._cos_cache.shape[0] < self.window_length: - self._cos_cache = torch.cat([self._cos_cache, cos], dim=0) - self._sin_cache = torch.cat([self._sin_cache, sin], dim=0) + else: + if self._cos_cache is None: + self._cos_cache = cos[0, ...] + self._sin_cache = sin[0, ...] + elif self._cos_cache.shape[0] < self.window_length: + self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) + self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) # [bsz, num_heads, seq_len, head_dim] if len(self.key_cache) <= layer_idx: