diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 00000000..d3f0e0f0
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,10 @@
+## PR description
+
+Description of your PR. Fixes # (issue) (if applicable)
+
+## New press checklist (if applicable)
+
+- [ ] I added `mypress_press.py` in the `presses` directory
+- [ ] I added `MyPress` in `__init__.py`
+- [ ] I updated the `README.md` with a one-liner about my new press in the "Available presses" section
+- [ ] I added my press in the `default_presses` list in `tests/default_presses.py`
diff --git a/README.md b/README.md
index 6e01a5ca..95f0e57b 100644
--- a/README.md
+++ b/README.md
@@ -19,9 +19,7 @@ pip install flash-attn --no-build-isolation
## Usage
-This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
-
-
+This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
```python
from kvpress import ExpectedAttentionPress
@@ -48,27 +46,28 @@ In the snippet above, the compression is only applied on the context tokens so t
## Contributing with a new press
-We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.
+We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide explaining how presses work and how to create a new one.
## Available presses
-All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score used to prune the KV pairs with lowest importance:
+All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score to prune the KV pairs with the lowest importance:
- `RandomPress`: random score
- `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430))
-- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
+- `SnapKVPress`: average attention weight of the last `window_size` queries ([paper](https://arxiv.org/abs/2404.14469))
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the initial and recent tokens ([paper](https://arxiv.org/abs/2309.17453))
-- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). The input of this press is the lazy threshold, not the compression ratio.
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))
-We also provide presses relying on a different logic:
-- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018))
+Some presses rely on a different logic:
+- `ThinKPress`: compress the dimension of the keys based on the channel attention score of the last `window_size` queries ([paper](https://arxiv.org/pdf/2407.21018))
+- `SimLayerKVPress`: identify "lazy" layers and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846))
-Finally we provide two special presses:
-- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental)
-- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks
+Finally, we provide special presses:
+- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows setting a `compression_ratio`
+- `ComposedPress`: compose multiple presses together by chaining their forward hooks
+- `KeyRerotationPress`: rerotate the keys kept after pruning so that their RoPE embeddings are contiguous. This press can be used with any other press that inherits from `ScorerPress`.
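The two wrapper presses above take another press as input. Their composition semantics can be pictured with a toy model; this is an illustrative sketch only (plain Python functions standing in for presses, whereas the real classes chain forward hooks on the attention layers):

```python
def knorm_press(keys, values, compression_ratio=0.5):
    # Toy KnormPress: keep the pairs whose key has the smallest L2 norm
    n_kept = int(len(keys) * (1 - compression_ratio))
    by_norm = sorted(range(len(keys)), key=lambda i: sum(x * x for x in keys[i]))
    kept = sorted(by_norm[:n_kept])  # preserve the original sequence order
    return [keys[i] for i in kept], [values[i] for i in kept]

def compose(*presses):
    # Toy ComposedPress: each press consumes the previous press's output
    def composed(keys, values):
        for press in presses:
            keys, values = press(keys, values)
        return keys, values
    return composed

keys = [[3.0, 0.0], [1.0, 0.0], [2.0, 0.0], [0.5, 0.0]]
values = [[0], [1], [2], [3]]
pressed_keys, pressed_values = compose(knorm_press)(keys, values)
assert pressed_keys == [[1.0, 0.0], [0.5, 0.0]]
assert pressed_values == [[1], [3]]
```

The actual wrapper classes additionally expose a `compression_ratio` attribute, which the test suite checks with `hasattr`.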
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
@@ -129,9 +128,7 @@ Memory usage should be reduced by around `compression_ratio * kv_cache_size`. As
### How does a press work ?
-A press registers a forward hook to each attention layer during the pre-filling phase:
-1. Immediately after the forward pass, the hook is called, and it computes a score for each key-value pair using the `press.score` method
-2. The key-value pairs with the lowest scores are then removed based on the `compression_ratio` parameter
+A press registers a forward hook (`press.forward_hook` method) to each attention layer during the pre-filling phase. The hook is registered by using the press as a context manager (`press.__call__` method):
```python
import torch
@@ -170,29 +167,3 @@ with press(model):
However, the `generate` method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally, the `generate` method does not support generating answers to multiple questions at once.
-
-
-
-### How to create a new press ?
-
-
-All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide.
-
-Before opening a pull request with a new press, make sure to register it in the `__init__.py` file of repository and to add it in [test_presses.py](tests/presses/test_presses.py).
-
-
-
-
-
-### Can I change the compression ratio from one layer to another ?
-
-
-We provide an experimental feature, which only works with flash attention:
-```python
-from kvpress import PerLayerCompressionPress
-# compression_ratios should have the same length as the number of layers
-press = PerLayerCompressionPress(press, compression_ratios=[...])
-```
-
-Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more details.
-
diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index 913a0828..e2693aaf 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -6,6 +6,7 @@
from kvpress.presses.base_press import BasePress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
+from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
@@ -32,4 +33,5 @@
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
+ "KeyRerotationPress",
]
diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py
index b9159d1e..456f7e4a 100644
--- a/kvpress/pipeline.py
+++ b/kvpress/pipeline.py
@@ -12,7 +12,10 @@
from transformers.pipelines.base import GenericTensor
from kvpress.presses.base_press import BasePress
+from kvpress.presses.composed_press import ComposedPress
+from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
+from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
logger = logging.getLogger(__name__)
@@ -167,7 +170,7 @@ def _forward(
self.model(
input_ids=context_ids,
past_key_values=cache,
- output_attentions=isinstance(press, ObservedAttentionPress),
+            output_attentions=self.output_attentions(press),
num_logits_to_keep=1,
)
@@ -180,13 +183,26 @@ def _forward(
answer = self.generate_answer(
question_ids=question_ids.to(self.model.device),
cache=cache,
- context_length=context_length,
+ context_length=(cache.get_seq_length() if isinstance(press, KeyRerotationPress) else context_length),
max_new_tokens=max_new_tokens,
)
answers.append(answer)
return answers
+    def output_attentions(self, press: BasePress) -> bool:
+ if isinstance(press, ObservedAttentionPress):
+ return True
+ if isinstance(press, (KeyRerotationPress, PerLayerCompressionPress)) and isinstance(
+ press.press, ObservedAttentionPress
+ ):
+ return True
+ if isinstance(press, ComposedPress) and any(
+ isinstance(sub_press, ObservedAttentionPress) for sub_press in press.presses
+ ):
+ return True
+ return False
+
def postprocess(self, model_outputs, single_question):
if single_question:
return {"answer": model_outputs[0]}
diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py
new file mode 100644
index 00000000..514178a5
--- /dev/null
+++ b/kvpress/presses/key_rerotation_press.py
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import inspect
+from dataclasses import dataclass
+
+import torch
+from torch import nn
+from transformers.models.llama.modeling_llama import rotate_half
+
+from kvpress.presses.base_press import BasePress
+from kvpress.presses.scorer_press import ScorerPress
+
+
+@dataclass
+class KeyRerotationPress(BasePress):
+ """
+    Rerotate the keys kept after pruning so that they retain a contiguous RoPE representation.
+ This method is used in several key-value cache compression methods, such as
+ - SinkCache implementation in Hugging Face's transformers library
+ - FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models
+ Parameters
+ ----------
+ press : ScorerPress
+        The press used to score and select the KV pairs to keep before rerotation.
+ """
+
+ press: ScorerPress
+
+ def compress(
+ self,
+ module: nn.Module,
+ hidden_states: torch.Tensor,
+ keys: torch.Tensor,
+ values: torch.Tensor,
+ attentions: torch.Tensor,
+ kwargs: dict,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if self.press.compression_ratio == 0:
+ return keys, values
+
+ # Compute scores from base press
+ scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
+
+        # Get indices of the KV pairs to keep (highest scores)
+ q_len = hidden_states.shape[1]
+ n_kept = int(q_len * (1 - self.press.compression_ratio))
+ indices = scores.topk(n_kept, dim=-1).indices
+ indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
+
+ cos, sin = get_rope_embeddings(module, keys)
+ # Rerotate as follows
+ # 1. keys = RoPE(W_k * hidden_states)
+ # 2. keys_unrotated = RoPE^-1(keys)
+ # 3. keys_pruned = prune(keys_unrotated)
+ # 4. keys = RoPE(keys_pruned)
+
+ # 2. Inverse of rotation matrix is equivalent to setting sin -> -sin in the equation below
+ keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
+ # 3. Prune keys
+ keys = keys.gather(2, indices).contiguous()
+ # 4. Apply RoPE
+ cos, sin = get_rope_embeddings(module, keys)
+ keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
+
+ values = values.gather(2, indices).contiguous()
+ return keys, values
+
+
+def get_rope_embeddings(module, x):
+ length = x.shape[2]
+ # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape
+ if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
+ position_ids = torch.arange(length).unsqueeze(0).to(x.device)
+ cos, sin = module.rotary_emb(x, position_ids)
+ else:
+ cos, sin = module.rotary_emb(x, length)
+ return cos, sin
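The `sin -> -sin` inversion used in step 2 above can be sanity-checked numerically. A minimal NumPy sketch (reimplementing `rotate_half`; the angles are arbitrary since only the rotation structure matters, and RoPE duplicates each angle across the two halves of the head dimension):

```python
import numpy as np

def rotate_half(x):
    # (x1, x2) -> (-x2, x1), matching transformers' RoPE helper
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

rng = np.random.default_rng(0)
seq_len, head_dim = 5, 8
keys = rng.normal(size=(seq_len, head_dim))

# Each rotation angle appears once per half of the head dimension
angles = rng.normal(size=(seq_len, head_dim // 2))
cos = np.concatenate([np.cos(angles), np.cos(angles)], axis=-1)
sin = np.concatenate([np.sin(angles), np.sin(angles)], axis=-1)

rotated = keys * cos + rotate_half(keys) * sin            # RoPE(keys)
unrotated = rotated * cos + rotate_half(rotated) * -sin   # RoPE^-1(RoPE(keys))

max_err = np.abs(unrotated - keys).max()
assert max_err < 1e-12
```

The roundtrip is exact because each (first-half, second-half) channel pair undergoes a plane rotation, and negating `sin` rotates by the opposite angle.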
diff --git a/kvpress/presses/random_press.py b/kvpress/presses/random_press.py
index 7c013a18..063e8082 100644
--- a/kvpress/presses/random_press.py
+++ b/kvpress/presses/random_press.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass
+from typing import Optional
import torch
from torch import nn
@@ -14,6 +15,9 @@
class RandomPress(ScorerPress):
"""Randomly prune KV pairs"""
+ compression_ratio: float = 0.0
+ seed: Optional[int] = None
+
def score(
self,
module: nn.Module,
@@ -23,4 +27,6 @@ def score(
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
+ if self.seed is not None:
+ torch.manual_seed(self.seed)
return torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype)
diff --git a/kvpress/presses/streaming_llm_press.py b/kvpress/presses/streaming_llm_press.py
index b0560541..c63a85b8 100644
--- a/kvpress/presses/streaming_llm_press.py
+++ b/kvpress/presses/streaming_llm_press.py
@@ -16,6 +16,10 @@ class StreamingLLMPress(ScorerPress):
    Keep a fixed number of KV pairs at the beginning and end of the sequence, pruning the rest (https://arxiv.org/abs/2309.17453)
We keep the first n_sink tokens and the last n_local tokens.
n_local is computed using the compression ratio.
+
+ Note that the original implementation https://github.com/mit-han-lab/streaming-llm additionally rerotates keys.
+ This can be achieved by using
+ press = KeyRerotationPress(press=StreamingLLMPress(compression_ratio, n_sink))
"""
compression_ratio: float = 0.0
diff --git a/notebooks/new_press.ipynb b/notebooks/new_press.ipynb
index 47bba3c2..c9ed6769 100644
--- a/notebooks/new_press.ipynb
+++ b/notebooks/new_press.ipynb
@@ -242,7 +242,7 @@
"source": [
"All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to \n",
"- register it in the `__init__.py` file of repository\n",
- "- add a test [test_presses.py](tests/presses/test_presses.py)\n",
+ "- register the press in [default_presses.py](tests/default_presses.py)\n",
"- update the README"
]
}
diff --git a/pyproject.toml b/pyproject.toml
index 024ca89e..cb0a686f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,7 @@
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
description = "Efficiently compress the KV cache of any pretrained transformer"
-version = "0.0.4"
+version = "0.1.0"
readme = "README.md"
[tool.poetry.dependencies]
diff --git a/tests/default_presses.py b/tests/default_presses.py
new file mode 100644
index 00000000..08d2fb13
--- /dev/null
+++ b/tests/default_presses.py
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from kvpress import (
+ ExpectedAttentionPress,
+ KnormPress,
+ RandomPress,
+ SimLayerKVPress,
+ SnapKVPress,
+ StreamingLLMPress,
+ ThinKPress,
+ TOVAPress,
+)
+
+# Contains all presses to be tested
+# kwargs are ordered from easiest to hardest compression
+default_presses = [
+ {"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+ {"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+ {"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+ {"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+ {
+ "cls": SnapKVPress,
+ "kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}],
+ },
+ {"cls": TOVAPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+ {
+ "cls": ThinKPress,
+ "kwargs": [
+ {"key_channel_compression_ratio": 0.2, "window_size": 2},
+ {"key_channel_compression_ratio": 0.8, "window_size": 2},
+ ],
+ },
+ {
+ "cls": SimLayerKVPress,
+ "kwargs": [
+ {"lazy_threshold": 0.8, "n_initial": 1, "n_recent": 1, "n_last": 1},
+ {"lazy_threshold": 0.2, "n_initial": 1, "n_recent": 1, "n_last": 1},
+ ],
+ },
+]
diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py
index 460203a0..41a5a0a7 100644
--- a/tests/integration/test_ruler.py
+++ b/tests/integration/test_ruler.py
@@ -4,15 +4,7 @@
from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
-from kvpress import (
- ExpectedAttentionPress,
- KnormPress,
- SimLayerKVPress,
- SnapKVPress,
- StreamingLLMPress,
- ThinKPress,
- TOVAPress,
-)
+from tests.default_presses import default_presses
from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401
@@ -25,18 +17,13 @@ def df_ruler():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
-@pytest.mark.parametrize(
- "cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress, SimLayerKVPress]
-)
-@pytest.mark.parametrize("compression_ratio", [0.1, 0.2])
+@pytest.mark.parametrize("press_dict", default_presses)
@pytest.mark.parametrize("cache", ["dynamic", "quantized"])
-def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, compression_ratio, cache): # noqa: F811
- if cls == ThinKPress:
- press = cls(key_channel_compression_ratio=compression_ratio, window_size=2)
- elif cls == SimLayerKVPress:
- press = cls(lazy_threshold=1 - compression_ratio)
- else:
- press = cls(compression_ratio=compression_ratio)
+def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press_dict, cache): # noqa: F811
+ cls = press_dict["cls"]
+ kwargs = press_dict["kwargs"][0]
+ press = cls(**kwargs)
+
if cache == "dynamic":
cache = DynamicCache()
elif cache == "quantized" and is_optimum_quanto_available():
diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py
new file mode 100644
index 00000000..b612a6b8
--- /dev/null
+++ b/tests/presses/test_key_rerotation_press_rope.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from dataclasses import dataclass
+
+import pytest
+import torch
+from torch import nn
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half
+
+from kvpress import KeyRerotationPress, ScorerPress
+from kvpress.presses.key_rerotation_press import get_rope_embeddings
+from tests.fixtures import unit_test_model # noqa: F401
+
+
+@pytest.mark.parametrize("precision", ["full", "half"])
+def test_rerotate_keys_matches_reference_implementation(unit_test_model: LlamaForCausalLM, precision):  # noqa: F811
+ """
+ Compare KeyRerotationPress' rerotation of keys with the reference implementation.
+    In the reference implementation, we compute
+ 1. keys = W_k * hidden_states
+ 2. keys_pruned = prune(keys)
+ 3. keys = RoPE(keys_pruned)
+ """
+ if precision == "half" and torch.cuda.is_available():
+ unit_test_model = unit_test_model.cuda().half()
+ elif precision == "half" and not torch.cuda.is_available():
+ pytest.skip("Half precision test is skipped because CUDA is not available.")
+
+ original_press = RandomPressStoreIndices(compression_ratio=0.5)
+ key_rerotation_press = KeyRerotationPress(press=original_press)
+
+ module = unit_test_model.model.layers[0].self_attn
+ hidden_states = torch.randn(
+ 8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype
+ )
+
+ keys = get_keys_with_rope(module, hidden_states)
+
+ values = torch.randn_like(keys)
+ # Press result
+ keys_compressed, _ = key_rerotation_press.compress(
+ module, hidden_states, keys, values, attentions=None, kwargs=dict()
+ )
+
+ indices = original_press.indices
+ keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices)
+
+ assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6 if precision == "full" else 1e-3)
+
+
+def get_keys_with_rope(module, hidden_states):
+ # Compute keys with RoPE
+ keys = get_keys_without_pos_embedding(module, hidden_states)
+ cos, sin = get_rope_embeddings(module, keys)
+ keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
+ return keys
+
+
+@dataclass
+class RandomPressStoreIndices(ScorerPress):
+ compression_ratio: float = 0.0
+ seed: int = 0
+
+ def __post_init__(self):
+ self.indices = None
+ super().__post_init__()
+
+ def score(
+ self,
+ module: nn.Module,
+ hidden_states: torch.Tensor,
+ keys: torch.Tensor,
+ values: torch.Tensor,
+ attentions: torch.Tensor,
+ kwargs,
+ ) -> torch.Tensor:
+ torch.manual_seed(self.seed)
+ scores = torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype)
+        # Get indices of the KV pairs to keep (highest scores)
+ q_len = hidden_states.shape[1]
+ n_kept = int(q_len * (1 - self.compression_ratio))
+ indices = scores.topk(n_kept, dim=-1).indices
+ indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
+ self.indices = indices
+
+ return scores
+
+
+def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hidden_states, indices):
+ """
+ Computes the rerotated keys for the given indices.
+ 1. keys = W_k * hidden_states
+ 2. keys_pruned = prune(keys)
+ 3. keys = RoPE(keys_pruned)
+ """
+ # 1.
+ keys = get_keys_without_pos_embedding(module, hidden_states)
+ # 2.
+ keys = keys.gather(2, indices).contiguous()
+ # 3.
+ cos, sin = get_rope_embeddings(module, keys)
+ keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
+ return keys
+
+
+def get_keys_without_pos_embedding(module, hidden_states):
+ key_states = module.k_proj(hidden_states)
+ key_states = key_states.view(
+ key_states.shape[0], key_states.shape[1], module.num_key_value_heads, module.head_dim
+ ).transpose(1, 2)
+ return key_states
diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py
index 29ae9a40..50056b9a 100644
--- a/tests/presses/test_presses.py
+++ b/tests/presses/test_presses.py
@@ -2,23 +2,15 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
+import pytest
import torch
from torch import nn
from transformers import DynamicCache
-from kvpress import (
- ComposedPress,
- ExpectedAttentionPress,
- KnormPress,
- ObservedAttentionPress,
- RandomPress,
- SimLayerKVPress,
- SnapKVPress,
- StreamingLLMPress,
- TOVAPress,
-)
+from kvpress import ComposedPress, KeyRerotationPress, KnormPress, ObservedAttentionPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.think_press import ThinKPress
+from tests.default_presses import default_presses
from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401
@@ -31,40 +23,27 @@ def test_composed_press(unit_test_model): # noqa: F811
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
-def test_presses_run(unit_test_model): # noqa: F811
- for cls in [
- KnormPress,
- ExpectedAttentionPress,
- RandomPress,
- StreamingLLMPress,
- SimLayerKVPress,
- SnapKVPress,
- TOVAPress,
- ThinKPress,
- ]:
- for value in [0.2, 0.4, 0.6, 0.8]:
-
- # Load the press
- if cls == ThinKPress:
- press = cls(key_channel_compression_ratio=value, window_size=2)
- elif cls == SimLayerKVPress:
- press = cls(lazy_threshold=value, n_initial=1, n_recent=1, n_last=1)
- else:
- press = cls(compression_ratio=value)
- if cls == SnapKVPress:
- press.window_size = 2
-
- # Run the press
- with press(unit_test_model):
- input_ids = unit_test_model.dummy_inputs["input_ids"]
- unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
- # Check that the press has a compression_ratio attribute
- assert hasattr(press, "compression_ratio")
+@pytest.mark.parametrize("press_dict", default_presses)
+@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress])
+def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
+ cls = press_dict["cls"]
+ for kwargs in press_dict["kwargs"]:
+ press = cls(**kwargs)
+        if wrapper_press is ComposedPress:
+            press = ComposedPress(presses=[press])
+        if wrapper_press is KeyRerotationPress:
+            press = KeyRerotationPress(press=press)
+
+ with press(unit_test_model):
+ input_ids = unit_test_model.dummy_inputs["input_ids"]
+ unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
+ # Check that the press has a compression_ratio attribute
+ assert hasattr(press, "compression_ratio")
def test_presses_run_observed_attention(unit_test_model_output_attention): # noqa: F811
for cls in [ObservedAttentionPress]:
- for compresion_ratio in [0.2, 0.4, 0.6, 0.8]:
+    for compression_ratio in [0.2, 0.8]:
        press = cls(compression_ratio=compression_ratio)
with press(unit_test_model_output_attention):
input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"]