Skip to content

Adjust Positional Embedding For Compression #16

@giulio98

Description

@giulio98

Feature

Adjust the positional encoding when compressing the cache to improve output quality for long sequences. Based on our experiments, this adjustment significantly reduces gibberish output and enhances performance for extended contexts. The proposed feature integrates seamlessly with existing caching mechanisms and dynamically adjusts cosine and sine positional embeddings during cache compression.

Motivation

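Long sequences often produce incoherent output when the cache is compressed without accounting for positional encodings. The issue comes from misaligned positional embeddings. The retained keys still carry the rotary (RoPE) rotations of their original positions, but after compression the cache is shorter and subsequent tokens are positioned relative to the compressed length. As a result, the attention computation sees inconsistent relative distances between queries and keys. By adjusting the positional embeddings during compression, we observed a dramatic improvement in model performance and output quality.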

This proposal aligns with our experimental findings and can be implemented by extending the DynamicCache class from transformers. Below is an example implementation called FinchCache, which adjusts the positional encodings during cache compression:

from typing import Any, Dict, Optional, Tuple

import torch
from transformers import DynamicCache

class FinchCache(DynamicCache):
    """``DynamicCache`` that re-aligns rotary-embedded keys when the cache is compressed.

    ``compress_cache`` keeps only a selected subset of cached tokens. Each kept key is
    re-rotated from its original position to its new slot index, so the compressed
    cache looks as if the kept tokens had been encoded at positions ``0..new_length-1``.
    """

    def __init__(self) -> None:
        super().__init__()
        # cos/sin tables recorded from layer 0's ``cache_kwargs`` (concatenated along
        # dim 0); stay ``None`` when the model never passes ``cos``/``sin``.
        self._cos_cache = None
        self._sin_cache = None
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x):
        """Return ``cat(-x2, x1)`` where ``x1``/``x2`` are the two halves of the last dim."""
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states, cos, sin):
        """Rotate ``key_states`` by the angles encoded in ``cos``/``sin``.

        ``key_states`` is expected as (batch, heads, seq, head_dim), the layout used by
        ``gather_important_tokens`` and ``update``. ``cos``/``sin`` may be
        (batch, seq, head_dim), as returned by ``_rerotate_cos_sin``. In that case a head
        axis is inserted so they broadcast over every head. Tensors that already have the
        same rank as ``key_states`` are used unchanged.
        """
        if cos.dim() == key_states.dim() - 1:
            # Without the head axis, (b, s, d) would align its batch dim with the head dim
            # of (b, h, s, d). That errors when b not in {1, h} and silently mixes heads
            # and batches when b == h.
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        """Build cos/sin that move each kept token from its original position to its slot.

        Args:
            x: tensor whose dtype the returned cos/sin are cast to.
            inv_freq: 1-D rotary inverse frequencies of length ``head_dim // 2``.
            important_pos_batch: (batch, new_len) original positions of the kept tokens.

        Returns:
            ``(cos, sin)``, each of shape (batch, new_len, 2 * len(inv_freq)). The angle
            for slot ``i`` is ``inv_freq * (i - important_pos_batch[:, i])``. Rotary
            rotations compose additively, so applying it to a key already rotated to
            position ``p`` yields a key rotated to position ``i``.
        """
        batch_size, seq_len = important_pos_batch.shape
        idx = torch.arange(0, seq_len, dtype=torch.long, device=important_pos_batch.device)
        idx = idx.unsqueeze(0).expand(batch_size, -1)
        delta_pos = idx - important_pos_batch
        inv_freq_expanded = inv_freq[None, :, None].float().expand(batch_size, -1, 1)
        delta_pos_expanded = delta_pos[:, None, :].float()
        # (b, d/2, 1) @ (b, 1, s) -> (b, d/2, s) -> (b, s, d/2)
        freqs = (inv_freq_expanded @ delta_pos_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)

    @staticmethod
    def gather_important_tokens(states, indices):
        """Select sequence positions ``indices`` (batch, new_len) from ``states`` along dim 2.

        The same positions are taken for every index of dim 1 and all of dim 3.
        """
        expanded = indices.unsqueeze(1).unsqueeze(-1).expand(-1, states.size(1), -1, states.size(3))
        return torch.gather(states, 2, expanded)

    def compress_cache(self, important_indices, inv_freq):
        """Keep only the selected tokens in every layer and re-align their rotary positions.

        Args:
            important_indices: per-layer indexable collection. Element ``layer_idx`` is a
                (batch, new_len) integer tensor of positions to keep. ``new_len`` is read
                from element 0, so it is assumed equal for all layers.
            inv_freq: rotary inverse frequencies of the model, e.g.
                ``model.model.layers[0].self_attn.rotary_emb.inv_freq``.

        Kept tokens are stored in the order given by ``important_indices``. Slot ``i`` is
        re-rotated to position ``i``. NOTE(review): pass sorted indices if the original
        relative order should be preserved.
        """
        new_length = important_indices[0].size(1)
        for layer_idx in range(len(self.key_cache)):
            keys = self.key_cache[layer_idx]
            if len(keys) == 0:
                # ``update`` fills skipped layers with ``[]``; there is nothing to gather.
                continue
            positions = important_indices[layer_idx]
            new_cos, new_sin = self._rerotate_cos_sin(keys, inv_freq, positions)
            kept_keys = self.gather_important_tokens(keys, positions)
            self.key_cache[layer_idx] = self._apply_key_rotary_pos_emb(kept_keys, new_cos, new_sin)
            self.value_cache[layer_idx] = self.gather_important_tokens(self.value_cache[layer_idx], positions)
        # The tables are only recorded when the model passed cos/sin to ``update``.
        if self._cos_cache is not None:
            self._cos_cache = self._cos_cache[:new_length, :]
        if self._sin_cache is not None:
            self._sin_cache = self._sin_cache[:new_length, :]
        self._seen_tokens = new_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Optional `"sin"`/`"cos"` tensors. When both are present, layer 0 records
                them in `_sin_cache`/`_cos_cache`. `None` is treated as an empty dict.

        Return:
            A tuple containing the updated key and value states.
        """
        if cache_kwargs is None:
            cache_kwargs = {}
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        using_rope = cos is not None and sin is not None
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if using_rope and layer_idx == 0:
            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
            # after all RoPE models have a llama-like cache utilization.
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            elif self._cos_cache is None:
                # Keep only batch entry 0; assumes cos/sin are identical across the batch.
                self._cos_cache = cos[0, ...]
                self._sin_cache = sin[0, ...]
            else:
                self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

        if len(self.key_cache) <= layer_idx:
            # There may be skipped layers, fill them with empty lists
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append([])
                self.value_cache.append([])
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

The `compress_cache` function must be called after the top-k index selection and can be integrated with any press. It adjusts the cache using the selected indices and the inverse frequencies of the model's original rotary position embeddings (`model.model.layers[0].self_attn.rotary_emb.inv_freq`).

This proposal addresses a significant limitation in handling long sequences and can benefit all users working with compressed cache scenarios.

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions