diff --git a/README.md b/README.md index 95f0e57b..2b9e59cf 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ![kvpress](kvpress.jpg) -Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache pruning methods and benchmarks using [🤗 transformers](https://huggingface.co/docs/transformers/en/index), aiming to simplify the development of new methods for researchers and developers in this field. +Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache compression methods and benchmarks using [🤗 transformers](https://huggingface.co/docs/transformers/en/index), aiming to simplify the development of new methods for researchers and developers in this field. ## Installation @@ -60,7 +60,7 @@ All current presses are training free. Several of them inherit from `ScorerPress - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104)) - `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048)) -Some presses relying on a different logic: +Some presses rely on a different logic: - `ThinKPress`: compress the dimensions of the keys based on the channel attention score on the last queries ([paper](https://arxiv.org/pdf/2407.21018)) - `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)) @@ -101,7 +101,7 @@ pipe(..., cache=cache) By default, the `DynamicCache` is used (no quantization). > [!IMPORTANT] -> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)). +> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`). ## FAQ diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index b82018b2..3b2b9b23 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -23,6 +23,8 @@ RandomPress, SnapKVPress, StreamingLLMPress, + ThinKPress, + TOVAPress, ) logger = logging.getLogger(__name__) @@ -48,6 +50,8 @@ "random": RandomPress(), "snapkv": SnapKVPress(), "streaming_llm": StreamingLLMPress(), + "think": ThinKPress(), + "tova": TOVAPress(), } @@ -110,6 +114,12 @@ def evaluate( df = load_dataset(DATASET_DICT[dataset], data_dir=data_dir, split="test").to_pandas() if fraction < 1.0: df = df.sample(frac=fraction, random_state=42) + save_filename = save_filename.with_name(save_filename.stem + f"__fraction{fraction:.2f}" + save_filename.suffix) + + if max_context_length is not None: + save_filename = save_filename.with_name( + save_filename.stem + f"__max_context{max_context_length}" + save_filename.suffix + ) if compress_questions: df["context"] = df["context"] + df["question"] @@ -119,27 +129,24 @@ def evaluate( # Load press assert press_name in PRESS_DICT press = PRESS_DICT[press_name] - press.compression_ratio = compression_ratio + press.compression_ratio = compression_ratio # type:ignore[attr-defined] # Initialize pipeline with the correct attention implementation + model_kwargs = {"torch_dtype": "auto"} if isinstance(press, ObservedAttentionPress): - model_kwargs = {"attn_implementation": "eager"} + model_kwargs["attn_implementation"] = "eager" else: try: import flash_attn # noqa: F401 - model_kwargs = {"attn_implementation": "flash_attention_2"} + model_kwargs["attn_implementation"] = "flash_attention_2" except ImportError: - model_kwargs = {} + pass if device == "auto": - pipe = pipeline( - "kv-press-text-generation", model=model, device_map="auto", torch_dtype="auto", model_kwargs=model_kwargs - ) + pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs) else: - pipe = pipeline( - "kv-press-text-generation", model=model, device=device, torch_dtype="auto", model_kwargs=model_kwargs - ) + pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs) # Run pipeline on each context df["predicted_answer"] = None diff --git a/kvpress/__init__.py b/kvpress/__init__.py index e2693aaf..a285240e 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -18,6 +18,7 @@ from kvpress.presses.think_press import ThinKPress from kvpress.presses.tova_press import TOVAPress + __all__ = [ "BasePress", "ComposedPress", diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 8a1bea2f..88a2d025 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -86,17 +86,13 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic Modified output of the forward pass of the layer. """ - # See e.g. LlamaDecoderLayer.forward for the output structure - if len(output) == 3: - _, attentions, cache = output - else: - attentions, cache = None, output[-1] hidden_states = kwargs["hidden_states"] + cache = kwargs["past_key_value"] q_len = hidden_states.shape[1] # Don't compress after pre-filling - if cache.seen_tokens > q_len: + if kwargs["cache_position"][-1] > q_len: return output if isinstance(cache, QuantizedCache): @@ -106,7 +102,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic keys = cache.key_cache[module.layer_idx] values = cache.value_cache[module.layer_idx] - keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs) + keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs) if isinstance(cache, QuantizedCache): cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) @@ -138,8 +134,8 @@ def __call__(self, model: PreTrainedModel) -> Generator: hooks = [] try: for layer in model.model.layers: + layer.self_attn.rotary_emb = model.model.rotary_emb hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) - yield finally: for forward_hook in hooks: diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py index 7fdf8fba..1c2b2bdc 100644 --- a/kvpress/presses/composed_press.py +++ b/kvpress/presses/composed_press.py @@ -15,8 +15,8 @@ class ComposedPress(BasePress): def __post_init__(self): self.compression_ratio = None assert not any( - isinstance(press, ObservedAttentionPress) for press in self.presses - ), "ComposedPress cannot contains ObservedAttentionPress because attentions pruning is not handled" + isinstance(press, (ObservedAttentionPress)) for press in self.presses + ), "ComposedPress cannot contains ObservedAttentionPress" def forward_hook(self, module, input, kwargs, output): self.compression_ratio = 1.0 diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 747a9597..3b1695e7 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import inspect import math from dataclasses import dataclass @@ -39,7 +38,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): """ bsz, q_len, _ = hidden_states.shape - n, d = module.num_heads, module.head_dim + n, d = module.config.num_attention_heads, module.head_dim # Remove first hidden_states that likely contain outliers h = hidden_states[:, self.n_sink :] @@ -66,13 +65,9 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): cov = cov.permute(0, 3, 1, 2) # RoPE rotation matrix on next n_future_positions - if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: - position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device) - cos, sin = module.rotary_emb(mu, position_ids) - cos, sin = cos[0], sin[0] - else: - cos, sin = module.rotary_emb(mu, q_len + self.n_future_positions) - cos, sin = cos[q_len:], sin[q_len:] + position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device) + cos, sin = module.rotary_emb(mu, position_ids) + cos, sin = cos[0], sin[0] Id = torch.eye(d, device=cos.device, dtype=cos.dtype) P = torch.zeros((d, d), device=cos.device, dtype=cos.dtype) @@ -117,14 +112,16 @@ def score( # Compute scores bsz, num_key_value_heads, q_len, d = keys.shape - keys = repeat_kv(keys, module.num_key_value_groups).transpose(2, 3) + num_key_value_groups = module.config.num_attention_heads // num_key_value_heads + + keys = repeat_kv(keys, num_key_value_groups).transpose(2, 3) scores = torch.matmul(mean_query.unsqueeze(2), keys).squeeze(2) / math.sqrt(d) if self.use_covariance: scores += torch.einsum("bhin, bhij, bhjn->bhn", keys, cov_query, keys) / d / 2 scores = F.softmax(scores, dim=-1) # Average scores across groups - scores = scores.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len) + scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len) scores = scores.mean(dim=2) # Rescale scores by the norm of the values diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 514178a5..0675fbcb 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import inspect from dataclasses import dataclass import torch @@ -28,6 +27,9 @@ class KeyRerotationPress(BasePress): press: ScorerPress + def __post_init__(self): + assert isinstance(self.press, ScorerPress) + def compress( self, module: nn.Module, @@ -49,7 +51,7 @@ def compress( indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) - cos, sin = get_rope_embeddings(module, keys) + cos, sin = kwargs["position_embeddings"] # Rerotate as follows # 1. keys = RoPE(W_k * hidden_states) # 2. keys_unrotated = RoPE^-1(keys) @@ -61,19 +63,8 @@ def compress( # 3. Prune keys keys = keys.gather(2, indices).contiguous() # 4. Apply RoPE - cos, sin = get_rope_embeddings(module, keys) + cos, sin = cos[:, :n_kept], sin[:, :n_kept] keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) values = values.gather(2, indices).contiguous() return keys, values - - -def get_rope_embeddings(module, x): - length = x.shape[2] - # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape - if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: - position_ids = torch.arange(length).unsqueeze(0).to(x.device) - cos, sin = module.rotary_emb(x, position_ids) - else: - cos, sin = module.rotary_emb(x, length) - return cos, sin diff --git a/kvpress/presses/observed_attention_press.py b/kvpress/presses/observed_attention_press.py index d11d2f01..4e9e78c9 100644 --- a/kvpress/presses/observed_attention_press.py +++ b/kvpress/presses/observed_attention_press.py @@ -54,8 +54,6 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic # attentions are needed as input for the hook, but unless the user wants to return them in the output, # we can remove them to save memory if not self.output_attentions: - output = list(output) - output[-2] = None - output = tuple(output) + output = (output[0], None) return output diff --git a/kvpress/presses/per_layer_compression_press.py b/kvpress/presses/per_layer_compression_press.py index 0e497375..80c6db77 100644 --- a/kvpress/presses/per_layer_compression_press.py +++ b/kvpress/presses/per_layer_compression_press.py @@ -32,6 +32,7 @@ def __post_init__(self): self.press.__init__ # type:ignore[misc] ).parameters ), f"compression_ratio can't be set in the provided press: {self.press.__class__}" + assert isinstance(self.press, ScorerPress), "PerLayerCompressionPress requires a ScorerPress as input" def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): original_compression_ratio = self.press.compression_ratio # type:ignore[attr-defined] diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 8693015c..6113899e 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -43,13 +43,16 @@ def is_lazy( module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor, + position_embeddings: torch.Tensor, ) -> bool: """ Compute the attention weights of the last tokens over the initial and recent tokens. The layer is considered lazy if the sum of these attention weights is above the lazy_threshold. """ - attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, self.n_last) + attn_weights = SnapKVPress.compute_window_attention( + module, hidden_states, keys, self.n_last, position_embeddings + ) attn_weights = attn_weights.mean((0, 1, 2)) # mean over bsz, heads and window size score = attn_weights[: self.n_initial].sum() + attn_weights[-self.n_recent :].sum() return score.item() > self.lazy_threshold @@ -91,7 +94,7 @@ def compress( return keys, values # Compression - if self.is_lazy(module, hidden_states, keys): + if self.is_lazy(module, hidden_states, keys, kwargs["position_embeddings"]): # If layer is lazy, only keep the initial and recent KV pairs keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2) values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2) diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 9265862e..371e9368 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -26,34 +26,35 @@ class SnapKVPress(ScorerPress): kernel_size: int = 5 @staticmethod - def compute_window_attention( - module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor, window_size: int - ) -> torch.Tensor: + def compute_window_attention(module, hidden_states, keys, window_size, position_embeddings): """ Compute the last window_size queries and associated attention weights for the first q_len - window_size keys. """ bsz, q_len, _ = hidden_states.shape + num_heads = module.config.num_attention_heads + head_dim = module.head_dim + num_key_value_groups = num_heads // module.config.num_key_value_heads # Get last window_size queries if hasattr(module, "q_proj"): query_states = module.q_proj(hidden_states[:, -window_size:]) elif hasattr(module, "qkv_proj"): qkv = module.qkv_proj(hidden_states[:, -window_size:]) - query_states = qkv[..., : module.num_heads * module.head_dim] + query_states = qkv[..., : num_heads * head_dim] else: raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") - query_states = query_states.view(bsz, window_size, module.num_heads, module.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2) # Apply RoPE - position_ids = torch.arange(q_len - window_size, q_len).unsqueeze(0).to(query_states.device) - cos, sin = module.rotary_emb(query_states, position_ids) + cos, sin = position_embeddings + cos, sin = cos[:, -window_size:], sin[:, -window_size:] query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1)) # Compute attention for first q_len - window_size tokens - key_states = repeat_kv(keys, module.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + key_states = repeat_kv(keys, num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim) attention_mask = torch.ones_like(attn_weights) * float("-inf") attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1) attn_weights += attention_mask @@ -73,19 +74,22 @@ def score( ) -> torch.Tensor: bsz, num_key_value_heads, q_len, _ = keys.shape + num_key_value_groups = module.config.num_attention_heads // num_key_value_heads assert q_len > self.window_size, "Query length should be greater than the window size" if attentions is not None: attn_weights = attentions[..., -self.window_size :, : -self.window_size] else: - attn_weights = self.compute_window_attention(module, hidden_states, keys, self.window_size) + attn_weights = self.compute_window_attention( + module, hidden_states, keys, self.window_size, kwargs["position_embeddings"] + ) scores = attn_weights.mean(dim=-2) scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1) # Average per grioup (https://github.com/FasterDecoding/SnapKV/issues/22) - scores = scores.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len - self.window_size) + scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size) scores = scores.mean(2) # Add back the observation window. Use max score to make sure the window is not pruned. diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index 2f882e56..6f1b829b 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -28,26 +28,28 @@ class ThinKPress(BasePress): key_channel_compression_ratio: float = 0.0 window_size: int = 32 - def compute_window_queries(self, module, hidden_states): + def compute_window_queries(self, module, hidden_states, position_embeddings): """ Re-compute the last window_size query states """ bsz, q_len, _ = hidden_states.shape + num_heads = module.config.num_attention_heads + head_dim = module.head_dim # Get last window_size queries if hasattr(module, "q_proj"): query_states = module.q_proj(hidden_states[:, -self.window_size :]) elif hasattr(module, "qkv_proj"): qkv = module.qkv_proj(hidden_states[:, -self.window_size :]) - query_states = qkv[..., : module.num_heads * module.head_dim] + query_states = qkv[..., : num_heads * head_dim] else: raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") - query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, self.window_size, num_heads, head_dim).transpose(1, 2) # Apply RoPE - position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device) - cos, sin = module.rotary_emb(query_states, position_ids) + cos, sin = position_embeddings + cos, sin = cos[:, -self.window_size:], sin[:, -self.window_size:] query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1)) return query_states @@ -71,9 +73,11 @@ def compress( # Compute scores per dimension bsz, num_key_value_heads, q_len, head_dim = keys.shape - queries = self.compute_window_queries(module, kwargs["hidden_states"]) + num_key_value_groups = module.config.num_attention_heads // num_key_value_heads + + queries = self.compute_window_queries(module, kwargs["hidden_states"], kwargs["position_embeddings"]) queries_norm = torch.pow(queries, 2).mean(dim=2) # (bsz, num_heads, head_dim) - queries_norm = queries_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, module.head_dim).mean(2) + queries_norm = queries_norm.view(bsz, num_key_value_heads, num_key_value_groups, module.head_dim).mean(2) keys_norm = torch.pow(keys, 2).mean(dim=2) key_scores = queries_norm * keys_norm # (bsz, num_key_value_heads, head_dim) diff --git a/kvpress/presses/tova_press.py b/kvpress/presses/tova_press.py index 6a9eb8e0..5ca96cd1 100644 --- a/kvpress/presses/tova_press.py +++ b/kvpress/presses/tova_press.py @@ -37,7 +37,9 @@ def score( if attentions is not None: attn_weights = attentions[..., -1:, :-1] else: - attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, 1) + attn_weights = SnapKVPress.compute_window_attention( + module, hidden_states, keys, 1, kwargs["position_embeddings"] + ) # Average across heads and repeat num_key_value_head times scores = attn_weights.mean(1) diff --git a/notebooks/new_press.ipynb b/notebooks/new_press.ipynb index c9ed6769..3ffb2799 100644 --- a/notebooks/new_press.ipynb +++ b/notebooks/new_press.ipynb @@ -157,7 +157,7 @@ " # For demonstration, we show some details on the shape for the first layer\n", " if module.layer_idx == 0:\n", " print(f\"module: {module}\")\n", - " print(f\"Number of key value heads: {module.num_key_value_heads}\")\n", + " print(f\"Number of key value heads: {module.config.num_key_value_heads}\")\n", " print(f\"Sequence length: {hidden_states.shape[1]}\")\n", " print()\n", " print(f\"hidden_states shape: {hidden_states.shape}\")\n", diff --git a/pyproject.toml b/pyproject.toml index 256f0bb0..4ded782b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "kvpress" authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"] description = "Efficiently compress the KV cache of any pretrained transformer" -version = "0.1.1" +version = "0.2.0" readme = "README.md" [tool.poetry.dependencies] @@ -14,7 +14,7 @@ scipy = "^1.13.1" matplotlib = "^3.9.0" bs4 = "^0.0.2" torch = "^2.3.1" -transformers = ">=4.45.1 <4.48" +transformers = ">=4.48.0" nvitop = "^1.3.2" sentencepiece = "^0.2.0" protobuf = "^5.27.2" diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py index b612a6b8..f890dc6f 100644 --- a/tests/presses/test_key_rerotation_press_rope.py +++ b/tests/presses/test_key_rerotation_press_rope.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 +import inspect from dataclasses import dataclass import pytest @@ -10,7 +11,6 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half from kvpress import KeyRerotationPress, ScorerPress -from kvpress.presses.key_rerotation_press import get_rope_embeddings from tests.fixtures import unit_test_model # noqa: F401 @@ -31,21 +31,27 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam original_press = RandomPressStoreIndices(compression_ratio=0.5) key_rerotation_press = KeyRerotationPress(press=original_press) - module = unit_test_model.model.layers[0].self_attn - hidden_states = torch.randn( - 8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype - ) + with key_rerotation_press(unit_test_model): + module = unit_test_model.model.layers[0].self_attn + hidden_states = torch.randn( + 8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype + ) - keys = get_keys_with_rope(module, hidden_states) + keys = get_keys_with_rope(module, hidden_states) - values = torch.randn_like(keys) - # Press result - keys_compressed, _ = key_rerotation_press.compress( - module, hidden_states, keys, values, attentions=None, kwargs=dict() - ) + values = torch.randn_like(keys) + # Press result + keys_compressed, _ = key_rerotation_press.compress( + module, + hidden_states, + keys, + values, + attentions=None, + kwargs={"position_embeddings": get_rope_embeddings(module, keys)}, + ) - indices = original_press.indices - keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices) + indices = original_press.indices + keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices) assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6 if precision == "full" else 1e-3) @@ -108,6 +114,17 @@ def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hid def get_keys_without_pos_embedding(module, hidden_states): key_states = module.k_proj(hidden_states) key_states = key_states.view( - key_states.shape[0], key_states.shape[1], module.num_key_value_heads, module.head_dim + key_states.shape[0], key_states.shape[1], module.config.num_key_value_heads, module.head_dim ).transpose(1, 2) return key_states + + +def get_rope_embeddings(module, x): + length = x.shape[2] + # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape + if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: + position_ids = torch.arange(length).unsqueeze(0).to(x.device) + cos, sin = module.rotary_emb(x, position_ids) + else: + cos, sin = module.rotary_emb(x, length) + return cos, sin diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 50056b9a..2a96e5c8 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -7,9 +7,14 @@ from torch import nn from transformers import DynamicCache -from kvpress import ComposedPress, KeyRerotationPress, KnormPress, ObservedAttentionPress -from kvpress.presses.scorer_press import ScorerPress -from kvpress.presses.think_press import ThinKPress +from kvpress import ( + ComposedPress, + KeyRerotationPress, + KnormPress, + ObservedAttentionPress, + ThinKPress, + ScorerPress, +) from tests.default_presses import default_presses from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 96993492..5e371db5 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -114,13 +114,14 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 model = unit_test_model questions = ["When was this article written?"] tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=torch.device("cpu")) - input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"] + compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device) + input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"].to(device) seq_len = 256 past_key_values: DynamicCache = model( - input_ids=torch.randint(0, 1000, (1, seq_len)), past_key_values=DynamicCache() + input_ids=torch.randint(0, 1000, (1, seq_len), device=device), past_key_values=DynamicCache() ).past_key_values assert past_key_values.get_seq_length() == seq_len