diff --git a/README.md b/README.md index b72ec054..4ebf50b0 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ All current presses are training free. We provide the following presses associat - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb)) - `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453)) - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104)) +- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above. For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 2f1e0409..5f7fc8cb 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -12,6 +12,7 @@ from kvpress.presses.snapkv_press import SnapKVPress from kvpress.presses.streaming_llm_press import StreamingLLMPress from kvpress.presses.tova_press import TOVAPress +from kvpress.presses.think_press import ThinKPress __all__ = [ "BasePress", @@ -21,6 +22,7 @@ "RandomPress", "SnapKVPress", "StreamingLLMPress", + "ThinKPress", "TOVAPress", "KVPressTextGenerationPipeline", "apply_per_layer_compression", diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 7d450475..05753a14 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import inspect import math from dataclasses import dataclass @@ -45,13 +44,9 @@ def compute_window_attention(self, module, hidden_states, keys): query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2) # Apply RoPE - if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: - position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device) - cos, sin = module.rotary_emb(query_states, position_ids) - else: - cos, sin = module.rotary_emb(query_states, q_len) - cos, sin = cos[-self.window_size :].unsqueeze(0), sin[-self.window_size :].unsqueeze(0) - query_states = (query_states * cos) + (rotate_half(query_states) * sin) + position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device) + cos, sin = module.rotary_emb(query_states, position_ids) + query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1)) # Compute attention for first q_len - window_size tokens key_states = repeat_kv(keys, module.num_key_value_groups) diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py new file mode 100644 index 00000000..7aae5410 --- /dev/null +++ b/kvpress/presses/think_press.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn +from transformers.cache_utils import QuantizedCache +from transformers.models.llama.modeling_llama import rotate_half + +from kvpress.presses.base_press import BasePress + + +@dataclass +class ThinKPress(BasePress): + """ + ThinK (https://arxiv.org/pdf/2407.21018) compresses the dimensions of the keys, and not the sequence length. + Hence it can be combined with any other press that compresses the sequence length, e.g. + press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5)) + + Here, we zero out the pruned dimensions resulting in no memory gain (the shape of the keys remains the same). + To achieve memory savings, several options can be considered (see https://github.com/NVIDIA/kvpress/pull/18/), + we might implement them in the future, especially if other similar presses are requested. + + This press has been reviewed by Yuhui Xu, first author of the ThinK paper. + """ + + compression_ratio: float = 0.0 + inner_press: Optional[BasePress] = None + window_size: int = 32 + + def compute_window_queries(self, module, hidden_states): + """ + Re-compute the last window_size query states + """ + + bsz, q_len, _ = hidden_states.shape + + # Get last window_size queries + if hasattr(module, "q_proj"): + query_states = module.q_proj(hidden_states[:, -self.window_size :]) + elif hasattr(module, "qkv_proj"): + qkv = module.qkv_proj(hidden_states[:, -self.window_size :]) + query_states = qkv[..., : module.num_heads * module.head_dim] + else: + raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") + + query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2) + + # Apply RoPE + position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device) + cos, sin = module.rotary_emb(query_states, position_ids) + query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1)) + + return query_states + + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): + """ + We first apply the inner press, then we prune the key dimensions. If other similar presses are requested, + we will create a dedicated DimensionBasePress class to avoid code duplication. + """ + + # Apply the forward hook of the inner press + if self.inner_press is not None: + output = self.inner_press.forward_hook(module, input, kwargs, output) + + # Don't compress if the compression ratio is 0 or this is not pre-filling + cache = output[-1] + hidden_states = kwargs["hidden_states"] + q_len = hidden_states.shape[1] + assert q_len > self.window_size, "Query length should be greater than the window size" + + if (self.compression_ratio == 0) or (cache.seen_tokens > q_len): + return output + + # Get keys + if isinstance(cache, QuantizedCache): + keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx]) + else: + keys = cache.key_cache[module.layer_idx] + bsz, num_key_value_heads, q_len, head_dim = keys.shape + + # ThinK specific code + queries = self.compute_window_queries(module, kwargs["hidden_states"]) + + # Compute scores per dimension + queries_norm = torch.pow(queries, 2).mean(dim=2) # (bsz, num_heads, head_dim) + queries_norm = queries_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, module.head_dim).mean(2) + keys_norm = torch.pow(keys, 2).mean(dim=2) + key_scores = queries_norm * keys_norm # (bsz, num_key_value_heads, head_dim) + + # Prune dimensions with the lowest scores by setting them to 0 + n_pruned = int(head_dim * self.compression_ratio) + indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices + indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1) + keys = keys.scatter_(-1, indices, 0) + + # Update cache + if isinstance(cache, QuantizedCache): + cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) + else: + cache.key_cache[module.layer_idx] = keys + + return output diff --git a/notebooks/per_layer_compression_demo.ipynb b/notebooks/per_layer_compression_demo.ipynb index fc46c9de..9aa9bd88 100644 --- a/notebooks/per_layer_compression_demo.ipynb +++ b/notebooks/per_layer_compression_demo.ipynb @@ -216,9 +216,9 @@ ], "metadata": { "kernelspec": { - "display_name": "kvpress_2", + "display_name": ".venv", "language": "python", - "name": "kvpress_2" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -230,7 +230,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index e3589ef7..024ca89e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "kvpress" authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"] description = "Efficiently compress the KV cache of any pretrained transformer" -version = "0.0.3" +version = "0.0.4" readme = "README.md" [tool.poetry.dependencies] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index ba450160..4c2d361e 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -15,15 +15,24 @@ SnapKVPress, StreamingLLMPress, TOVAPress, + ThinKPress, ) + from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401 +def test_think_inner_press(unit_test_model): # noqa: F811 + press = ThinKPress(compression_ratio=0.5, window_size=2, inner_press=KnormPress(0.5)) + with press(unit_test_model): + input_ids = unit_test_model.dummy_inputs["input_ids"] + unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values + + def test_presses_run(unit_test_model): # noqa: F811 - for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress]: + for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]: for compression_ratio in [0.2, 0.4, 0.6, 0.8]: press = cls(compression_ratio=compression_ratio) - if cls == SnapKVPress: + if cls in [SnapKVPress, ThinKPress]: press.window_size = 2 with press(unit_test_model): input_ids = unit_test_model.dummy_inputs["input_ids"]