From 43908688b2ceb810f96ac4a3823ea4eca2caa7c7 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 10:41:49 +0100 Subject: [PATCH 01/23] add test for rerotation --- kvpress/__init__.py | 7 +- kvpress/pipeline.py | 3 +- kvpress/presses/base_press.py | 2 +- kvpress/presses/composed_press.py | 1 + kvpress/presses/key_rerotation_press.py | 90 ++++++++++++++++++ tests/presses/test_key_rerotation_press.py | 104 +++++++++++++++++++++ 6 files changed, 202 insertions(+), 5 deletions(-) create mode 100644 kvpress/presses/key_rerotation_press.py create mode 100644 tests/presses/test_key_rerotation_press.py diff --git a/kvpress/__init__.py b/kvpress/__init__.py index fe37d310..52d3bfcf 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -4,7 +4,9 @@ from kvpress.pipeline import KVPressTextGenerationPipeline from kvpress.presses.base_press import BasePress +from kvpress.presses.composed_press import ComposedPress from kvpress.presses.expected_attention_press import ExpectedAttentionPress +from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.knorm_press import KnormPress from kvpress.presses.observed_attention_press import ObservedAttentionPress from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress @@ -13,7 +15,7 @@ from kvpress.presses.snapkv_press import SnapKVPress from kvpress.presses.streaming_llm_press import StreamingLLMPress from kvpress.presses.think_press import ThinKPress -from kvpress.presses.composed_press import ComposedPress +from kvpress.presses.tova_press import TOVAPress __all__ = [ "BasePress", @@ -29,6 +31,5 @@ "TOVAPress", "KVPressTextGenerationPipeline", "PerLayerCompressionPress", + "KeyRerotationPress", ] - -from kvpress.presses.tova_press import TOVAPress diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index ca68e678..0d2cb1ae 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -12,6 +12,7 @@ from transformers.pipelines.base import GenericTensor from 
kvpress.presses.base_press import BasePress +from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.observed_attention_press import ObservedAttentionPress logger = logging.getLogger(__name__) @@ -180,7 +181,7 @@ def _forward( answer = self.generate_answer( question_ids=question_ids.to(self.model.device), cache=cache, - context_length=context_length, + context_length=(cache.get_seq_length() if isinstance(press, KeyRerotationPress) else context_length), max_new_tokens=max_new_tokens, ) answers.append(answer) diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 129d1769..8493395a 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -14,8 +14,8 @@ MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, - Qwen2ForCausalLM, QuantizedCache, + Qwen2ForCausalLM, ) logger = logging.getLogger(__name__) diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py index 3ebc5c16..ba3b65e7 100644 --- a/kvpress/presses/composed_press.py +++ b/kvpress/presses/composed_press.py @@ -1,4 +1,5 @@ from dataclasses import dataclass + from kvpress.presses.base_press import BasePress diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py new file mode 100644 index 00000000..9219b96f --- /dev/null +++ b/kvpress/presses/key_rerotation_press.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import inspect +from dataclasses import dataclass + +import torch +from torch import nn +from transformers.models.llama.modeling_llama import rotate_half + +from kvpress.presses.base_press import BasePress +from kvpress.presses.scorer_press import ScorerPress + + +@dataclass +class KeyRerotationPress(BasePress): + """ + Rerotate keys to have a uniform RoPE representation of keys after pruning. 
+ This function wraps the forward hook of the press object. + See FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models + https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280 + for more details on this method. + Parameters + ---------- + press : BasePress + The press object to apply per-layer compression to. + Returns + ------- + BasePress + The press object with rerotation applied. + """ + + press: ScorerPress + + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.press.compression_ratio == 0: + return keys, values + + # Compute scores from base press + scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) + + # Get indices of KV pairs with the lowest scores + q_len = hidden_states.shape[1] + n_kept = int(q_len * (1 - self.press.compression_ratio)) + indices = scores.topk(n_kept, dim=-1).indices + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) + + # Apply RoPE to the pruned keys + cos, sin = get_position_embeddings(module, keys) + keys = keys.gather(2, indices).contiguous() + + # On RoPE models, we need to recompute the Key rotation as the tokens are shifted + rerotation_cos, rerotation_sin = self.get_rerotation_cos_sin(keys, cos, sin) + keys = (keys * rerotation_cos.unsqueeze(1)) + (rotate_half(keys) * rerotation_sin.unsqueeze(1)) + + values = values.gather(2, indices).contiguous() + return keys, values + + def get_rerotation_cos_sin(self, keys, cos, sin): + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence + original_cos = cos[:, keys.shape[-2] :] + shifted_cos = cos[:, -keys.shape[-2]] + original_sin = sin[:, keys.shape[-2] :] + shifted_sin = sin[:, -keys.shape[-2]] + rerotation_cos = original_cos * 
shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + return rerotation_cos, rerotation_sin + + +def get_position_embeddings(module, x): + length = x.shape[2] + # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape + if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: + position_ids = torch.arange(length).unsqueeze(0).to(x.device) + cos, sin = module.rotary_emb(x, position_ids) + else: + cos, sin = module.rotary_emb(x, length) + return cos, sin diff --git a/tests/presses/test_key_rerotation_press.py b/tests/presses/test_key_rerotation_press.py new file mode 100644 index 00000000..1524de39 --- /dev/null +++ b/tests/presses/test_key_rerotation_press.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import torch +from torch import nn +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaRotaryEmbedding, rotate_half + +from kvpress import KeyRerotationPress, ScorerPress +from tests.fixtures import unit_test_model # noqa: F401 + + +def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: LlamaForCausalLM): # noqa: F811 + """ + Compare KeyRerotationPress' rerotation of keys with the reference implementation. + In KeyRerotationPress, we are using trigonometric functions to rerotate the keys. 
+ In the reference implementation, we are using the + """ + original_press = RandomPressWithSeed(compression_ratio=0.5) + key_rerotation_press = KeyRerotationPress(press=original_press) + + module = unit_test_model.model.layers[0].self_attn + hidden_states = torch.randn(8, 64, module.config.hidden_size) + + keys = get_keys_with_rope(module, hidden_states) + + values = torch.randn_like(keys) + keys_compressed, _ = key_rerotation_press.compress( + module, hidden_states, keys, values, attentions=None, kwargs=dict() + ) + + indices = original_press.indices + keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices) + + assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6) + + +def get_keys_with_rope(module, hidden_states): + # Compute keys with RoPE + keys = get_keys_without_pos_embedding(module, hidden_states) + cos, sin = get_position_embeddings(keys, module.rotary_emb) + keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) + return keys + + +@dataclass +class RandomPressWithSeed(ScorerPress): + compression_ratio: float = 0.0 + seed: int = 0 + + def __post_init__(self): + self.indices = None + super().__post_init__() + + def score( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs, + ) -> torch.Tensor: + torch.manual_seed(self.seed) + scores = torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype) + # Get indices of KV pairs with the lowest scores + q_len = hidden_states.shape[1] + n_kept = int(q_len * (1 - self.compression_ratio)) + indices = scores.topk(n_kept, dim=-1).indices + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) + self.indices = indices + + return scores + + +def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hidden_states, indices): + """ + Computes the rerotated keys for the given indices. 
+ This is a reference implementation for the rerotation of keys. + """ + keys = get_keys_without_pos_embedding(module, hidden_states) + + keys = keys.gather(2, indices).contiguous() + # apply position embeddings on the pruned keys + cos, sin = get_position_embeddings(keys, module.rotary_emb) + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + keys = (keys * cos) + (rotate_half(keys) * sin) + return keys + + +def get_keys_without_pos_embedding(module, hidden_states): + key_states = module.k_proj(hidden_states) + key_states = key_states.view( + key_states.shape[0], key_states.shape[1], module.num_key_value_heads, module.head_dim + ).transpose(1, 2) + return key_states + + +def get_position_embeddings(keys, rotary_emb: LlamaRotaryEmbedding): + length = keys.shape[2] + position_ids = torch.arange(length).unsqueeze(0).to(keys.device) + cos, sin = rotary_emb(keys, position_ids) + return cos, sin From 42c30d7a3495f5534130af38022b476866224d19 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 11:04:13 +0100 Subject: [PATCH 02/23] move to inv rope implementation --- kvpress/presses/key_rerotation_press.py | 30 +++++++++------------- tests/presses/test_key_rerotation_press.py | 20 +++++++++------ 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 9219b96f..daf28529 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -54,30 +54,24 @@ def compress( indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) - # Apply RoPE to the pruned keys cos, sin = get_position_embeddings(module, keys) + # Rerotate as follows + # 1. keys = RoPE(W_k * hidden_states) + # 2. keys_unrotated = RoPE^-1(keys) + # 3. keys_pruned = prune(keys_unrotated) + # 4. keys = RoPE(keys_pruned) + + # 2. 
Inverse of rotation matrix is equivalent to setting sin -> -sin in the equation below + keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1))) + # 3. Prune keys keys = keys.gather(2, indices).contiguous() - - # On RoPE models, we need to recompute the Key rotation as the tokens are shifted - rerotation_cos, rerotation_sin = self.get_rerotation_cos_sin(keys, cos, sin) - keys = (keys * rerotation_cos.unsqueeze(1)) + (rotate_half(keys) * rerotation_sin.unsqueeze(1)) + # 4. Apply RoPE + cos, sin = get_position_embeddings(module, keys) + keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) values = values.gather(2, indices).contiguous() return keys, values - def get_rerotation_cos_sin(self, keys, cos, sin): - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - - # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence - original_cos = cos[:, keys.shape[-2] :] - shifted_cos = cos[:, -keys.shape[-2]] - original_sin = sin[:, keys.shape[-2] :] - shifted_sin = sin[:, -keys.shape[-2]] - rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin - rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin - return rerotation_cos, rerotation_sin - def get_position_embeddings(module, x): length = x.shape[2] diff --git a/tests/presses/test_key_rerotation_press.py b/tests/presses/test_key_rerotation_press.py index 1524de39..9d3ca61f 100644 --- a/tests/presses/test_key_rerotation_press.py +++ b/tests/presses/test_key_rerotation_press.py @@ -13,8 +13,10 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: LlamaForCausalLM): # noqa: F811 """ Compare KeyRerotationPress' rerotation of keys with the reference implementation. - In KeyRerotationPress, we are using trigonometric functions to rerotate the keys. - In the reference implementation, we are using the + In the reference implementation, we are computing + 1. 
keys = W_k * hidden_states + 2. keys_pruned = prune(keys) + 3. keys = RoPE(keys_pruned) """ original_press = RandomPressWithSeed(compression_ratio=0.5) key_rerotation_press = KeyRerotationPress(press=original_press) @@ -25,6 +27,7 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam keys = get_keys_with_rope(module, hidden_states) values = torch.randn_like(keys) + # Press result keys_compressed, _ = key_rerotation_press.compress( module, hidden_states, keys, values, attentions=None, kwargs=dict() ) @@ -76,16 +79,17 @@ def score( def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hidden_states, indices): """ Computes the rerotated keys for the given indices. - This is a reference implementation for the rerotation of keys. + 1. keys = W_k * hidden_states + 2. keys_pruned = prune(keys) + 3. keys = RoPE(keys_pruned) """ + # 1. keys = get_keys_without_pos_embedding(module, hidden_states) - + # 2. keys = keys.gather(2, indices).contiguous() - # apply position embeddings on the pruned keys + # 3. cos, sin = get_position_embeddings(keys, module.rotary_emb) - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) - keys = (keys * cos) + (rotate_half(keys) * sin) + keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) return keys From 83d8813090eca6aeb705396de157d215fc823b03 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 11:17:04 +0100 Subject: [PATCH 03/23] more tests --- README.md | 5 ++- kvpress/pipeline.py | 16 +++++++- kvpress/presses/key_rerotation_press.py | 13 ++----- ...s.py => test_key_rerotation_press_rope.py} | 0 tests/presses/test_presses.py | 37 +++++++++++-------- 5 files changed, 44 insertions(+), 27 deletions(-) rename tests/presses/{test_key_rerotation_press.py => test_key_rerotation_press_rope.py} (100%) diff --git a/README.md b/README.md index 899497cf..e03dca84 100644 --- a/README.md +++ b/README.md @@ -65,9 +65,10 @@ All current presses are training free. 
Several of them inherit from `ScorerPress We also provide presses relying on a different logic: - `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)) -Finally we provide two special presses: -- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental) +Finally we provide special presses: +- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press. - `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks +- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press. For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 0d2cb1ae..b342f65e 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -11,6 +11,7 @@ from transformers.pipelines import PIPELINE_REGISTRY from transformers.pipelines.base import GenericTensor +from kvpress import ComposedPress, PerLayerCompressionPress from kvpress.presses.base_press import BasePress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.observed_attention_press import ObservedAttentionPress @@ -168,7 +169,7 @@ def _forward( self.model( input_ids=context_ids, past_key_values=cache, - output_attentions=isinstance(press, ObservedAttentionPress), + output_attentions=self.output_attentions, num_logits_to_keep=1, ) @@ -188,6 +189,19 @@ def _forward( return answers + def output_attentions(self, press: BasePress): + if isinstance(press, ObservedAttentionPress): + return True + if 
isinstance(press, (KeyRerotationPress, PerLayerCompressionPress)) and isinstance( + press.press, ObservedAttentionPress + ): + return True + if isinstance(press, ComposedPress) and any( + isinstance(sub_press, ObservedAttentionPress) for sub_press in press.presses + ): + return True + return False + def postprocess(self, model_outputs, single_question): if single_question: return {"answer": model_outputs[0]} diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index daf28529..3e9dfd8f 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -17,18 +17,13 @@ class KeyRerotationPress(BasePress): """ Rerotate keys to have a uniform RoPE representation of keys after pruning. - This function wraps the forward hook of the press object. - See FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models - https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280 - for more details on this method. + This method is used in several key-value cache compression methods, such as + - SinkCache implementation in Hugging Face's transformers library + - FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models Parameters ---------- - press : BasePress + press : ScorerPress The press object to apply per-layer compression to. - Returns - ------- - BasePress - The press object with rerotation applied. 
""" press: ScorerPress diff --git a/tests/presses/test_key_rerotation_press.py b/tests/presses/test_key_rerotation_press_rope.py similarity index 100% rename from tests/presses/test_key_rerotation_press.py rename to tests/presses/test_key_rerotation_press_rope.py diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 07f40085..d6886a66 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +import pytest import torch from torch import nn from transformers import DynamicCache @@ -14,7 +15,7 @@ RandomPress, SnapKVPress, StreamingLLMPress, - TOVAPress, + TOVAPress, KeyRerotationPress, ) from kvpress.presses.scorer_press import ScorerPress from kvpress.presses.think_press import ThinKPress @@ -30,20 +31,26 @@ def test_composed_press(unit_test_model): # noqa: F811 unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values -def test_presses_run(unit_test_model): # noqa: F811 - for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]: - for compression_ratio in [0.2, 0.4, 0.6, 0.8]: - if cls == ThinKPress: - press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) - else: - press = cls(compression_ratio=compression_ratio) - if cls in [SnapKVPress]: - press.window_size = 2 - with press(unit_test_model): - input_ids = unit_test_model.dummy_inputs["input_ids"] - unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values - # Check that the press has a compression_ratio attribute - assert hasattr(press, "compression_ratio") +@pytest.mark.parametrize("cls", [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]) +@pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8]) +@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress]) +def 
test_presses_run(unit_test_model, cls, compression_ratio, wrapper_press): # noqa: F811 + if cls == ThinKPress: + press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) + else: + press = cls(compression_ratio=compression_ratio) + if cls in [SnapKVPress]: + press.window_size = 2 + if isinstance(wrapper_press, ComposedPress): + press = ComposedPress(presses=[press]) + if isinstance(wrapper_press, KeyRerotationPress): + press = KeyRerotationPress(press=press) + + with press(unit_test_model): + input_ids = unit_test_model.dummy_inputs["input_ids"] + unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values + # Check that the press has a compression_ratio attribute + assert hasattr(press, "compression_ratio") def test_presses_run_observed_attention(unit_test_model_output_attention): # noqa: F811 From a34a941416464bfdf4ee1bf45c19ac292083d2a0 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 12:03:39 +0100 Subject: [PATCH 04/23] format --- tests/presses/test_presses.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index d6886a66..c95d4305 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -10,12 +10,13 @@ from kvpress import ( ComposedPress, ExpectedAttentionPress, + KeyRerotationPress, KnormPress, ObservedAttentionPress, RandomPress, SnapKVPress, StreamingLLMPress, - TOVAPress, KeyRerotationPress, + TOVAPress, ) from kvpress.presses.scorer_press import ScorerPress from kvpress.presses.think_press import ThinKPress @@ -31,7 +32,9 @@ def test_composed_press(unit_test_model): # noqa: F811 unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values -@pytest.mark.parametrize("cls", [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]) +@pytest.mark.parametrize( + "cls", [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, 
SnapKVPress, TOVAPress, ThinKPress] +) @pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8]) @pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress]) def test_presses_run(unit_test_model, cls, compression_ratio, wrapper_press): # noqa: F811 From 48c5d01d8b2a8ed2d0390310735750de4960a616 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 12:05:16 +0100 Subject: [PATCH 05/23] format --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e03dca84..31f2863b 100644 --- a/README.md +++ b/README.md @@ -66,9 +66,9 @@ We also provide presses relying on a different logic: - `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)) Finally we provide special presses: -- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press. -- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks -- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press. +- `PerLayerCompressionPress`: Compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows to set a compression_ratio. +- `ComposedPress`: A press that composes multiple presses together by chaining their forward hooks. +- `KeyRerotationPress`: Rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that allows to set a compression_ratio. 
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) From 16092406ac884645e9e455c9a5716c682b6e897b Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 12:08:59 +0100 Subject: [PATCH 06/23] resue rope --- kvpress/pipeline.py | 3 ++- kvpress/presses/key_rerotation_press.py | 6 +++--- tests/presses/test_key_rerotation_press_rope.py | 12 +++--------- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index b342f65e..9add806c 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -11,10 +11,11 @@ from transformers.pipelines import PIPELINE_REGISTRY from transformers.pipelines.base import GenericTensor -from kvpress import ComposedPress, PerLayerCompressionPress from kvpress.presses.base_press import BasePress +from kvpress.presses.composed_press import ComposedPress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress logger = logging.getLogger(__name__) diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 3e9dfd8f..514178a5 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -49,7 +49,7 @@ def compress( indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) - cos, sin = get_position_embeddings(module, keys) + cos, sin = get_rope_embeddings(module, keys) # Rerotate as follows # 1. keys = RoPE(W_k * hidden_states) # 2. keys_unrotated = RoPE^-1(keys) @@ -61,14 +61,14 @@ def compress( # 3. Prune keys keys = keys.gather(2, indices).contiguous() # 4. 
Apply RoPE - cos, sin = get_position_embeddings(module, keys) + cos, sin = get_rope_embeddings(module, keys) keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) values = values.gather(2, indices).contiguous() return keys, values -def get_position_embeddings(module, x): +def get_rope_embeddings(module, x): length = x.shape[2] # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py index 9d3ca61f..83bf937d 100644 --- a/tests/presses/test_key_rerotation_press_rope.py +++ b/tests/presses/test_key_rerotation_press_rope.py @@ -7,6 +7,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaRotaryEmbedding, rotate_half from kvpress import KeyRerotationPress, ScorerPress +from kvpress.presses.key_rerotation_press import get_rope_embeddings from tests.fixtures import unit_test_model # noqa: F401 @@ -41,7 +42,7 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam def get_keys_with_rope(module, hidden_states): # Compute keys with RoPE keys = get_keys_without_pos_embedding(module, hidden_states) - cos, sin = get_position_embeddings(keys, module.rotary_emb) + cos, sin = get_rope_embeddings(module, keys) keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) return keys @@ -88,7 +89,7 @@ def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hid # 2. keys = keys.gather(2, indices).contiguous() # 3. 
- cos, sin = get_position_embeddings(keys, module.rotary_emb) + cos, sin = get_rope_embeddings(module, keys) keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) return keys @@ -99,10 +100,3 @@ def get_keys_without_pos_embedding(module, hidden_states): key_states.shape[0], key_states.shape[1], module.num_key_value_heads, module.head_dim ).transpose(1, 2) return key_states - - -def get_position_embeddings(keys, rotary_emb: LlamaRotaryEmbedding): - length = keys.shape[2] - position_ids = torch.arange(length).unsqueeze(0).to(keys.device) - cos, sin = rotary_emb(keys, position_ids) - return cos, sin From 467e52c8def71ddc70e7ec6e75749d817007db2d Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 12:25:26 +0100 Subject: [PATCH 07/23] remove compression ratios n tests --- tests/presses/test_presses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index c95d4305..583626ce 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -35,7 +35,7 @@ def test_composed_press(unit_test_model): # noqa: F811 @pytest.mark.parametrize( "cls", [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress] ) -@pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8]) +@pytest.mark.parametrize("compression_ratio", [0.2, 0.8]) @pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress]) def test_presses_run(unit_test_model, cls, compression_ratio, wrapper_press): # noqa: F811 if cls == ThinKPress: @@ -58,7 +58,7 @@ def test_presses_run(unit_test_model, cls, compression_ratio, wrapper_press): # def test_presses_run_observed_attention(unit_test_model_output_attention): # noqa: F811 for cls in [ObservedAttentionPress]: - for compresion_ratio in [0.2, 0.4, 0.6, 0.8]: + for compresion_ratio in [0.2, 0.8]: press = cls(compression_ratio=compresion_ratio) with 
press(unit_test_model_output_attention): input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"] From 696da767cee141b4e05fc84b12ceccc6587d4691 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 12:26:46 +0100 Subject: [PATCH 08/23] fix test on gpu --- tests/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1d4c324b..174a74bb 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -79,7 +79,7 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 questions = ["When was this article written?"] tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) - compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer) + compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=torch.device("cpu")) input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"] seq_len = 256 From 427559588f73ab37134c75d5b19edf7347e2de4f Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 12:35:11 +0100 Subject: [PATCH 09/23] better readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 31f2863b..fdcb84a8 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ We also provide presses relying on a different logic: Finally we provide special presses: - `PerLayerCompressionPress`: Compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows to set a compression_ratio. - `ComposedPress`: A press that composes multiple presses together by chaining their forward hooks. -- `KeyRerotationPress`: Rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that allows to set a compression_ratio. 
+- `KeyRerotationPress`: Rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from ScorerPress. For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) From 5ed620840dd1c98111690a84ae10bf8a298565f8 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 12:43:19 +0100 Subject: [PATCH 10/23] fix style --- tests/presses/test_key_rerotation_press_rope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py index 83bf937d..fe7ed24b 100644 --- a/tests/presses/test_key_rerotation_press_rope.py +++ b/tests/presses/test_key_rerotation_press_rope.py @@ -4,7 +4,7 @@ import torch from torch import nn -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaRotaryEmbedding, rotate_half +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half from kvpress import KeyRerotationPress, ScorerPress from kvpress.presses.key_rerotation_press import get_rope_embeddings From 1079a03bb4ffd7d16bf5dd7ed30ccbea3bff4434 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 15:12:28 +0100 Subject: [PATCH 11/23] fix merge conflicts --- tests/presses/test_presses.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index f1566ff3..5b4e3aba 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -14,8 +14,8 @@ KnormPress, ObservedAttentionPress, RandomPress, - SnapKVPress, SimLayerKVPress, + SnapKVPress, StreamingLLMPress, TOVAPress, ) @@ -34,13 +34,25 @@ def 
test_composed_press(unit_test_model): # noqa: F811 @pytest.mark.parametrize( - "cls", [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress] + "cls", + [ + KnormPress, + ExpectedAttentionPress, + RandomPress, + StreamingLLMPress, + SnapKVPress, + TOVAPress, + ThinKPress, + SimLayerKVPress, + ], ) @pytest.mark.parametrize("compression_ratio", [0.2, 0.8]) @pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress]) def test_presses_run(unit_test_model, cls, compression_ratio, wrapper_press): # noqa: F811 if cls == ThinKPress: press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) + elif cls == SimLayerKVPress: + press = cls(lazy_threshold=compression_ratio, n_initial=1, n_recent=1, n_last=1) else: press = cls(compression_ratio=compression_ratio) if cls in [SnapKVPress]: From 4c9eb72321b2b846477539119dd128dd7fbf424f Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 17:15:45 +0100 Subject: [PATCH 12/23] update to 0.1.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 024ca89e..cb0a686f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "kvpress" authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"] description = "Efficiently compress the KV cache of any pretrained transformer" -version = "0.0.4" +version = "0.1.0" readme = "README.md" [tool.poetry.dependencies] From 1ad53cbe53e8ac66d00f589b8503d756d9fae840 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 17:26:05 +0100 Subject: [PATCH 13/23] add fp16 test --- tests/presses/test_key_rerotation_press_rope.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py index fe7ed24b..a7ad057f 100644 --- a/tests/presses/test_key_rerotation_press_rope.py +++ 
b/tests/presses/test_key_rerotation_press_rope.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +import pytest import torch from torch import nn from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half @@ -11,7 +12,8 @@ from tests.fixtures import unit_test_model # noqa: F401 -def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: LlamaForCausalLM): # noqa: F811 +@pytest.mark.parametrize("precision", ["full", "half"]) +def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: LlamaForCausalLM, precision): # noqa: F811 """ Compare KeyRerotationPress' rerotation of keys with the reference implementation. In the reference implementation, we are computing @@ -19,11 +21,18 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam 2. keys_pruned = prune(keys) 3. keys = RoPE(keys_pruned) """ + if precision == "half" and torch.cuda.is_available(): + unit_test_model = unit_test_model.cuda().half() + elif precision == "half" and not torch.cuda.is_available(): + pytest.skip("Half precision test is skipped because CUDA is not available.") + original_press = RandomPressWithSeed(compression_ratio=0.5) key_rerotation_press = KeyRerotationPress(press=original_press) module = unit_test_model.model.layers[0].self_attn - hidden_states = torch.randn(8, 64, module.config.hidden_size) + hidden_states = torch.randn( + 8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype + ) keys = get_keys_with_rope(module, hidden_states) @@ -36,7 +45,7 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam indices = original_press.indices keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices) - assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6) + assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6 if 
precision == "full" else 1e-3) def get_keys_with_rope(module, hidden_states): From 037b5c01234449d3a7750705f81941c151cc6591 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 18:00:16 +0100 Subject: [PATCH 14/23] refactor tests --- tests/default_presses.py | 35 +++++++++++++++++++ tests/integration/test_ruler.py | 27 ++++---------- tests/presses/test_presses.py | 62 +++++++++------------------------ 3 files changed, 59 insertions(+), 65 deletions(-) create mode 100644 tests/default_presses.py diff --git a/tests/default_presses.py b/tests/default_presses.py new file mode 100644 index 00000000..b688e2e5 --- /dev/null +++ b/tests/default_presses.py @@ -0,0 +1,35 @@ +from kvpress import ( + ExpectedAttentionPress, + KnormPress, + RandomPress, + SimLayerKVPress, + SnapKVPress, + StreamingLLMPress, + ThinKPress, + TOVAPress, +) + +# contains all presses to be tested +# kwargs should be ordered easy to hard compression +default_presses = [ + {"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": SnapKVPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": TOVAPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + { + "cls": ThinKPress, + "kwargs": [ + {"key_channel_compression_ratio": 0.2, "window_size": 2}, + {"key_channel_compression_ratio": 0.8, "window_size": 2}, + ], + }, + { + "cls": SimLayerKVPress, + "kwargs": [ + {"lazy_threshold": 0.8, "n_initial": 1, "n_recent": 1, "n_last": 1}, + {"lazy_threshold": 0.2, "n_initial": 1, "n_recent": 1, "n_last": 1}, + ], + }, +] diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index 460203a0..41a5a0a7 
100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -4,15 +4,7 @@ from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available -from kvpress import ( - ExpectedAttentionPress, - KnormPress, - SimLayerKVPress, - SnapKVPress, - StreamingLLMPress, - ThinKPress, - TOVAPress, -) +from tests.default_presses import default_presses from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 @@ -25,18 +17,13 @@ def df_ruler(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") -@pytest.mark.parametrize( - "cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress, SimLayerKVPress] -) -@pytest.mark.parametrize("compression_ratio", [0.1, 0.2]) +@pytest.mark.parametrize("press_dict", default_presses) @pytest.mark.parametrize("cache", ["dynamic", "quantized"]) -def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, compression_ratio, cache): # noqa: F811 - if cls == ThinKPress: - press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) - elif cls == SimLayerKVPress: - press = cls(lazy_threshold=1 - compression_ratio) - else: - press = cls(compression_ratio=compression_ratio) +def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press_dict, cache): # noqa: F811 + cls = press_dict["cls"] + kwargs = press_dict["kwargs"][0] + press = cls(**kwargs) + if cache == "dynamic": cache = DynamicCache() elif cache == "quantized" and is_optimum_quanto_available(): diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 5b4e3aba..50056b9a 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -7,20 +7,10 @@ from torch import nn from transformers import 
DynamicCache -from kvpress import ( - ComposedPress, - ExpectedAttentionPress, - KeyRerotationPress, - KnormPress, - ObservedAttentionPress, - RandomPress, - SimLayerKVPress, - SnapKVPress, - StreamingLLMPress, - TOVAPress, -) +from kvpress import ComposedPress, KeyRerotationPress, KnormPress, ObservedAttentionPress from kvpress.presses.scorer_press import ScorerPress from kvpress.presses.think_press import ThinKPress +from tests.default_presses import default_presses from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401 @@ -33,40 +23,22 @@ def test_composed_press(unit_test_model): # noqa: F811 unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values -@pytest.mark.parametrize( - "cls", - [ - KnormPress, - ExpectedAttentionPress, - RandomPress, - StreamingLLMPress, - SnapKVPress, - TOVAPress, - ThinKPress, - SimLayerKVPress, - ], -) -@pytest.mark.parametrize("compression_ratio", [0.2, 0.8]) +@pytest.mark.parametrize("press_dict", default_presses) @pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress]) -def test_presses_run(unit_test_model, cls, compression_ratio, wrapper_press): # noqa: F811 - if cls == ThinKPress: - press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) - elif cls == SimLayerKVPress: - press = cls(lazy_threshold=compression_ratio, n_initial=1, n_recent=1, n_last=1) - else: - press = cls(compression_ratio=compression_ratio) - if cls in [SnapKVPress]: - press.window_size = 2 - if isinstance(wrapper_press, ComposedPress): - press = ComposedPress(presses=[press]) - if isinstance(wrapper_press, KeyRerotationPress): - press = KeyRerotationPress(press=press) - - with press(unit_test_model): - input_ids = unit_test_model.dummy_inputs["input_ids"] - unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values - # Check that the press has a compression_ratio attribute - assert hasattr(press, "compression_ratio") +def 
test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 + cls = press_dict["cls"] + for kwargs in press_dict["kwargs"]: + press = cls(**kwargs) + if isinstance(wrapper_press, ComposedPress): + press = ComposedPress(presses=[press]) + if isinstance(wrapper_press, KeyRerotationPress): + press = KeyRerotationPress(press=press) + + with press(unit_test_model): + input_ids = unit_test_model.dummy_inputs["input_ids"] + unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values + # Check that the press has a compression_ratio attribute + assert hasattr(press, "compression_ratio") def test_presses_run_observed_attention(unit_test_model_output_attention): # noqa: F811 From dc96e97cd2a20a2984a926c25be890651d30cd24 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 18:41:42 +0100 Subject: [PATCH 15/23] fix broken test --- tests/default_presses.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/default_presses.py b/tests/default_presses.py index b688e2e5..a8ccda66 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -16,7 +16,10 @@ {"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, - {"cls": SnapKVPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + { + "cls": SnapKVPress, + "kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}], + }, {"cls": TOVAPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, { "cls": ThinKPress, From 9e51013ceadb74a7fbab42d4e906a5a6f5de855a Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 19:09:10 +0100 Subject: [PATCH 16/23] update readme --- README.md | 2 +- notebooks/new_press.ipynb | 2 +- 2 files changed, 2 
insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7b397c88..04728439 100644 --- a/README.md +++ b/README.md @@ -179,7 +179,7 @@ However, the `generate` method does not allow to exclude the question from the c All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide. -Before opening a pull request with a new press, make sure to register it in the `__init__.py` file of repository and to add it in [test_presses.py](tests/presses/test_presses.py). +Before opening a pull request with a new press, make sure to register it in the `__init__.py` file of repository and to add it in [default_presses.py](tests/default_presses.py). diff --git a/notebooks/new_press.ipynb b/notebooks/new_press.ipynb index 47bba3c2..c9ed6769 100644 --- a/notebooks/new_press.ipynb +++ b/notebooks/new_press.ipynb @@ -242,7 +242,7 @@ "source": [ "All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to \n", "- register it in the `__init__.py` file of repository\n", - "- add a test [test_presses.py](tests/presses/test_presses.py)\n", + "- register the press in [default_presses.py](tests/default_presses.py)\n", "- update the README" ] } From d99f287e3b158eb2baa806f17126076dc2a6a93d Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 19:12:39 +0100 Subject: [PATCH 17/23] address pr feedback --- README.md | 2 +- kvpress/presses/streaming_llm_press.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 04728439..e5a0b76e 100644 --- a/README.md +++ b/README.md @@ -59,12 +59,12 @@ All current presses are training free. 
Several of them inherit from `ScorerPress - `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469)) - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb)) - `StreamingLLMPress`: keep only the initial and recent tokens ([paper](https://arxiv.org/abs/2309.17453)) -- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). The input of this press is the lazy threshold, not the compression ratio. - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104)) - `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048)) We also provide presses relying on a different logic: - `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)) +- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). The input of this press is the lazy threshold, not the compression ratio. Finally we provide special presses: - `PerLayerCompressionPress`: Compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows to set a compression_ratio. diff --git a/kvpress/presses/streaming_llm_press.py b/kvpress/presses/streaming_llm_press.py index b0560541..c63a85b8 100644 --- a/kvpress/presses/streaming_llm_press.py +++ b/kvpress/presses/streaming_llm_press.py @@ -16,6 +16,10 @@ class StreamingLLMPress(ScorerPress): Prune a fixed number of KV pairs at the beginning and end of the sequence (https://arxiv.org/abs/2309.17453) We keep the first n_sink tokens and the last n_local tokens. 
n_local is computed using the compression ratio. + + Note that the original implementation https://github.com/mit-han-lab/streaming-llm additionally rerotates keys. + This can be achieved by using + press = KeyRerotationPress(press=StreamingLLMPress(compression_ratio, n_sink)) """ compression_ratio: float = 0.0 From aa355d571d74a31baa466c6d53d69f1523f4a022 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 19:21:44 +0100 Subject: [PATCH 18/23] address pr feedback --- tests/presses/test_key_rerotation_press_rope.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py index a7ad057f..33752ce6 100644 --- a/tests/presses/test_key_rerotation_press_rope.py +++ b/tests/presses/test_key_rerotation_press_rope.py @@ -1,5 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 + + from dataclasses import dataclass import pytest From 575c032b5458a4bfdf4b7ca8404be4927b2c792c Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 19:22:08 +0100 Subject: [PATCH 19/23] address pr feedback --- tests/default_presses.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/default_presses.py b/tests/default_presses.py index a8ccda66..08d2fb13 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + from kvpress import ( ExpectedAttentionPress, KnormPress, From d64ad7ca72656c0f5b18cb8af547d1c7d67bc7d9 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 12 Dec 2024 08:38:38 +0000 Subject: [PATCH 20/23] Update README --- README.md | 52 +++++++++++----------------------------------------- 1 file changed, 11 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index e5a0b76e..95f0e57b 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,7 @@ pip install flash-attn --no-build-isolation ## Usage -This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you: - - +This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you: ```python from kvpress import ExpectedAttentionPress @@ -48,28 +46,28 @@ In the snippet above, the compression is only applied on the context tokens so t ## Contributing with a new press -We welcome contributions! If you want to implement a new press, open an issue or a pull request. 
Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide. +We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide to understand how presses work and what should be done to create a new one. ## Available presses -All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score used to prune the KV pairs with lowest importance: +All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score to prune the KV pairs with lowest importance: - `RandomPress`: random score - `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430)) -- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469)) +- `SnapKVPress`: average attention weight of the last queries ([paper](https://arxiv.org/abs/2404.14469)) - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb)) - `StreamingLLMPress`: keep only the initial and recent tokens ([paper](https://arxiv.org/abs/2309.17453)) - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104)) - `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048)) -We also provide presses relying on a different logic: -- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)) -- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). 
The input of this press is the lazy threshold, not the compression ratio. +Some presses rely on a different logic: +- `ThinKPress`: compress the dimensions of the keys based on the channel attention score on the last queries ([paper](https://arxiv.org/pdf/2407.21018)) +- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)) Finally we provide special presses: -- `PerLayerCompressionPress`: Compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows to set a compression_ratio. -- `ComposedPress`: A press that composes multiple presses together by chaining their forward hooks. -- `KeyRerotationPress`: Rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from ScorerPress. +- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows setting a compression_ratio +- `ComposedPress`: compose multiple presses together by chaining their forward hooks +- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`. For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) @@ -130,9 +128,7 @@ Memory usage should be reduced by around `compression_ratio * kv_cache_size`. As ### How does a press work ? -A press registers a forward hook to each attention layer during the pre-filling phase: -1. Immediately after the forward pass, the hook is called, and it computes a score for each key-value pair using the `press.score` method -2.
The key-value pairs with the lowest scores are then removed based on the `compression_ratio` parameter +A press registers a forward hook (`press.forward_hook` method) to each attention layer during the pre-filling phase. Registration can be applied using the press as a context manager (`press.__call__` method): ```python import torch @@ -171,29 +167,3 @@ with press(model): However, the `generate` method does not allow to exclude the question from the compression, which would artificially favors methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally the `generate` method does not allow to provide generation for multiple questions at once. - -
- -### How to create a new press ? - - -All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide. - -Before opening a pull request with a new press, make sure to register it in the `__init__.py` file of repository and to add it in [default_presses.py](tests/default_presses.py). - -
- -
- -### Can I change the compression ratio from one layer to another ? - - -We provide an experimental feature, which only works with flash attention: -```python -from kvpress import PerLayerCompressionPress -# compression_ratios should have the same length as the number of layers -press = PerLayerCompressionPress(press, compression_ratios=[...]) -``` - -Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more details. -
From f9b5347ba186235bf12fe9fa9d08c18115e544e7 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Thu, 12 Dec 2024 09:49:28 +0100 Subject: [PATCH 21/23] address pr feedback --- kvpress/presses/random_press.py | 6 ++++++ tests/presses/test_key_rerotation_press_rope.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/kvpress/presses/random_press.py b/kvpress/presses/random_press.py index 7c013a18..063e8082 100644 --- a/kvpress/presses/random_press.py +++ b/kvpress/presses/random_press.py @@ -3,6 +3,7 @@ from dataclasses import dataclass +from typing import Optional import torch from torch import nn @@ -14,6 +15,9 @@ class RandomPress(ScorerPress): """Randomly prune KV pairs""" + compression_ratio: float = 0.0 + seed: Optional[int] = None + def score( self, module: nn.Module, @@ -23,4 +27,6 @@ def score( attentions: torch.Tensor, kwargs, ) -> torch.Tensor: + if self.seed is not None: + torch.manual_seed(self.seed) return torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype) diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py index 33752ce6..b612a6b8 100644 --- a/tests/presses/test_key_rerotation_press_rope.py +++ b/tests/presses/test_key_rerotation_press_rope.py @@ -28,7 +28,7 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam elif precision == "half" and not torch.cuda.is_available(): pytest.skip("Half precision test is skipped because CUDA is not available.") - original_press = RandomPressWithSeed(compression_ratio=0.5) + original_press = RandomPressStoreIndices(compression_ratio=0.5) key_rerotation_press = KeyRerotationPress(press=original_press) module = unit_test_model.model.layers[0].self_attn @@ -59,7 +59,7 @@ def get_keys_with_rope(module, hidden_states): @dataclass -class RandomPressWithSeed(ScorerPress): +class RandomPressStoreIndices(ScorerPress): compression_ratio: float = 0.0 seed: int = 0 From ab514ceeddcc2bfbd26b916cec22dd5806e81537 Mon Sep 
17 00:00:00 2001 From: SimJeg Date: Thu, 12 Dec 2024 09:02:00 +0000 Subject: [PATCH 22/23] Add PR template --- .github/PULL_REQUEST_TEMPLATE.md | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..2e3a65ef --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,10 @@ +## PR description + +Description of your PR. Fixes # (issue) (if applicable) + +## New press checklist (if applicable) + +- [ ] I added `mypress_press.py` in the `presses` directory +- [ ] I added `MyPress` in `__init__.py` +- [ ] I updated the `README.md` with a 1 liner about my new press in the Available presses section +- [ ] I added my press in the `PRESS_LIST` list in `tests/presses/test_presses.py` From a20df839d71ce420aefe5389f9807be94c65e6ce Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Thu, 12 Dec 2024 10:14:51 +0100 Subject: [PATCH 23/23] update PR template --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 2e3a65ef..d3f0e0f0 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -7,4 +7,4 @@ Description of your PR. Fixes # (issue) (if applicable) - [ ] I added `mypress_press.py` in the `presses` directory - [ ] I added `MyPress` in `__init__.py` - [ ] I updated the `README.md` with a 1 liner about my new press in the Available presses section -- [ ] I added my press in the `PRESS_LIST` list in `tests/presses/test_presses.py` +- [ ] I added my press in the `default_presses` list in `tests/default_presses.py`