diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..d3f0e0f0 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,10 @@ +## PR description + +Description of your PR. Fixes # (issue) (if applicable) + +## New press checklist (if applicable) + +- [ ] I added `mypress_press.py` in the `presses` directory +- [ ] I added `MyPress` in `__init__.py` +- [ ] I updated the `README.md` with a 1 liner about my new press in the Available presses section +- [ ] I added my press in the `default_presses` list in `tests/default_presses.py` diff --git a/README.md b/README.md index 6e01a5ca..95f0e57b 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,7 @@ pip install flash-attn --no-build-isolation ## Usage -This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you: - - +This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you: ```python from kvpress import ExpectedAttentionPress @@ -48,27 +46,28 @@ In the snippet above, the compression is only applied on the context tokens so t ## Contributing with a new press -We welcome contributions! 
If you want to implement a new press, open an issue or a pull request. Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide. +We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide that explains how presses work and how to create a new one. ## Available presses -All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score used to prune the KV pairs with lowest importance: +All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score to prune the KV pairs with the lowest importance: - `RandomPress`: random score - `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430)) -- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469)) +- `SnapKVPress`: average attention weight of the last queries ([paper](https://arxiv.org/abs/2404.14469)) - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb)) - `StreamingLLMPress`: keep only the initial and recent tokens ([paper](https://arxiv.org/abs/2309.17453)) -- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). The input of this press is the lazy threshold, not the compression ratio. 
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104)) - `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048)) -We also provide presses relying on a different logic: -- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)) +Some presses rely on a different logic: +- `ThinKPress`: compress the dimensions of the keys based on the channel attention score on the last queries ([paper](https://arxiv.org/pdf/2407.21018)) +- `SimLayerKVPress`: identify "lazy" layers and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)) -Finally we provide two special presses: -- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental) -- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks +Finally we provide special presses: +- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows setting a `compression_ratio` +- `ComposedPress`: compose multiple presses together by chaining their forward hooks +- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`. For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) @@ -129,9 +128,7 @@ Memory usage should be reduced by around `compression_ratio * kv_cache_size`. As ### How does a press work ? 
-A press registers a forward hook to each attention layer during the pre-filling phase: -1. Immediately after the forward pass, the hook is called, and it computes a score for each key-value pair using the `press.score` method -2. The key-value pairs with the lowest scores are then removed based on the `compression_ratio` parameter +A press registers a forward hook (`press.forward_hook` method) to each attention layer during the pre-filling phase. Hooks are registered by using the press as a context manager (`press.__call__` method): ```python import torch @@ -170,29 +167,3 @@ with press(model): However, the `generate` method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally, the `generate` method does not support generating answers to multiple questions at once. - -
- -### How to create a new press ? - - -All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide. - -Before opening a pull request with a new press, make sure to register it in the `__init__.py` file of repository and to add it in [test_presses.py](tests/presses/test_presses.py). - -
- -
- -### Can I change the compression ratio from one layer to another ? - - -We provide an experimental feature, which only works with flash attention: -```python -from kvpress import PerLayerCompressionPress -# compression_ratios should have the same length as the number of layers -press = PerLayerCompressionPress(press, compression_ratios=[...]) -``` - -Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more details. -
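The README text above describes a press as a forward hook that is registered on each attention layer while the press is used as a context manager. As an illustration only, here is a minimal, dependency-free sketch of that pattern; `ToyLayer` and `ToyPress` are hypothetical stand-ins, not the kvpress API (in kvpress, `press(model)` registers `press.forward_hook` on each attention layer and removes it on exit):

```python
# Sketch of the hook-based compression pattern. ToyLayer/ToyPress are
# illustrative stand-ins, NOT the real kvpress classes.

class ToyLayer:
    """Stand-in for an attention layer supporting forward hooks."""

    def __init__(self):
        self._hooks = []

    def register_forward_hook(self, hook):
        self._hooks.append(hook)

    def remove_forward_hook(self, hook):
        self._hooks.remove(hook)

    def forward(self, tokens):
        output = list(tokens)
        for hook in self._hooks:
            # A hook may replace the layer output, here pruning it.
            output = hook(self, tokens, output)
        return output


class ToyPress:
    """While active as a context manager, prunes each layer's output."""

    def __init__(self, layers, compression_ratio=0.5):
        self.layers = layers
        self.compression_ratio = compression_ratio

    def forward_hook(self, module, inputs, output):
        # Keep only the first (1 - compression_ratio) fraction of entries.
        n_kept = int(len(output) * (1 - self.compression_ratio))
        return output[:n_kept]

    def __enter__(self):
        for layer in self.layers:
            layer.register_forward_hook(self.forward_hook)
        return self

    def __exit__(self, *exc_info):
        for layer in self.layers:
            layer.remove_forward_hook(self.forward_hook)


layers = [ToyLayer()]
with ToyPress(layers, compression_ratio=0.5):
    compressed = layers[0].forward(range(8))  # hook active: 4 entries kept
restored = layers[0].forward(range(8))        # hook removed: all 8 entries
```

The real presses additionally score KV pairs and edit the cache in place, but the register-on-enter / remove-on-exit lifecycle is the same.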
diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 913a0828..e2693aaf 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -6,6 +6,7 @@ from kvpress.presses.base_press import BasePress from kvpress.presses.composed_press import ComposedPress from kvpress.presses.expected_attention_press import ExpectedAttentionPress +from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.knorm_press import KnormPress from kvpress.presses.observed_attention_press import ObservedAttentionPress from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress @@ -32,4 +33,5 @@ "TOVAPress", "KVPressTextGenerationPipeline", "PerLayerCompressionPress", + "KeyRerotationPress", ] diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index b9159d1e..456f7e4a 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -12,7 +12,10 @@ from transformers.pipelines.base import GenericTensor from kvpress.presses.base_press import BasePress +from kvpress.presses.composed_press import ComposedPress +from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress logger = logging.getLogger(__name__) @@ -167,7 +170,7 @@ def _forward( self.model( input_ids=context_ids, past_key_values=cache, - output_attentions=isinstance(press, ObservedAttentionPress), + output_attentions=self.output_attentions(press), num_logits_to_keep=1, ) @@ -180,13 +183,26 @@ def _forward( answer = self.generate_answer( question_ids=question_ids.to(self.model.device), cache=cache, - context_length=context_length, + context_length=(cache.get_seq_length() if isinstance(press, KeyRerotationPress) else context_length), max_new_tokens=max_new_tokens, ) answers.append(answer) return answers + def output_attentions(self, press: BasePress): + if isinstance(press, ObservedAttentionPress): + return True + if 
isinstance(press, (KeyRerotationPress, PerLayerCompressionPress)) and isinstance( + press.press, ObservedAttentionPress + ): + return True + if isinstance(press, ComposedPress) and any( + isinstance(sub_press, ObservedAttentionPress) for sub_press in press.presses + ): + return True + return False + def postprocess(self, model_outputs, single_question): if single_question: return {"answer": model_outputs[0]} diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py new file mode 100644 index 00000000..514178a5 --- /dev/null +++ b/kvpress/presses/key_rerotation_press.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import inspect +from dataclasses import dataclass + +import torch +from torch import nn +from transformers.models.llama.modeling_llama import rotate_half + +from kvpress.presses.base_press import BasePress +from kvpress.presses.scorer_press import ScorerPress + + +@dataclass +class KeyRerotationPress(BasePress): + """ + Rerotate keys so that they keep a uniform RoPE representation after pruning. + This method is used in several key-value cache compression methods, such as + - SinkCache implementation in Hugging Face's transformers library + - FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models + Parameters + ---------- + press : ScorerPress + The press whose `score` method selects which key-value pairs are kept before rerotation. 
+ """ + + press: ScorerPress + + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.press.compression_ratio == 0: + return keys, values + + # Compute scores from base press + scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) + + # Get indices of KV pairs with the lowest scores + q_len = hidden_states.shape[1] + n_kept = int(q_len * (1 - self.press.compression_ratio)) + indices = scores.topk(n_kept, dim=-1).indices + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) + + cos, sin = get_rope_embeddings(module, keys) + # Rerotate as follows + # 1. keys = RoPE(W_k * hidden_states) + # 2. keys_unrotated = RoPE^-1(keys) + # 3. keys_pruned = prune(keys_unrotated) + # 4. keys = RoPE(keys_pruned) + + # 2. Inverse of rotation matrix is equivalent to setting sin -> -sin in the equation below + keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1))) + # 3. Prune keys + keys = keys.gather(2, indices).contiguous() + # 4. 
Apply RoPE + cos, sin = get_rope_embeddings(module, keys) + keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) + + values = values.gather(2, indices).contiguous() + return keys, values + + +def get_rope_embeddings(module, x): + length = x.shape[2] + # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape + if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: + position_ids = torch.arange(length).unsqueeze(0).to(x.device) + cos, sin = module.rotary_emb(x, position_ids) + else: + cos, sin = module.rotary_emb(x, length) + return cos, sin diff --git a/kvpress/presses/random_press.py b/kvpress/presses/random_press.py index 7c013a18..063e8082 100644 --- a/kvpress/presses/random_press.py +++ b/kvpress/presses/random_press.py @@ -3,6 +3,7 @@ from dataclasses import dataclass +from typing import Optional import torch from torch import nn @@ -14,6 +15,9 @@ class RandomPress(ScorerPress): """Randomly prune KV pairs""" + compression_ratio: float = 0.0 + seed: Optional[int] = None + def score( self, module: nn.Module, @@ -23,4 +27,6 @@ def score( attentions: torch.Tensor, kwargs, ) -> torch.Tensor: + if self.seed is not None: + torch.manual_seed(self.seed) return torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype) diff --git a/kvpress/presses/streaming_llm_press.py b/kvpress/presses/streaming_llm_press.py index b0560541..c63a85b8 100644 --- a/kvpress/presses/streaming_llm_press.py +++ b/kvpress/presses/streaming_llm_press.py @@ -16,6 +16,10 @@ class StreamingLLMPress(ScorerPress): Keep a fixed number of KV pairs at the beginning and end of the sequence, pruning the rest (https://arxiv.org/abs/2309.17453). We keep the first n_sink tokens and the last n_local tokens. n_local is computed using the compression ratio. + + Note that the original implementation https://github.com/mit-han-lab/streaming-llm additionally rerotates keys. 
+ This can be achieved by using + press = KeyRerotationPress(press=StreamingLLMPress(compression_ratio, n_sink)) """ compression_ratio: float = 0.0 diff --git a/notebooks/new_press.ipynb b/notebooks/new_press.ipynb index 47bba3c2..c9ed6769 100644 --- a/notebooks/new_press.ipynb +++ b/notebooks/new_press.ipynb @@ -242,7 +242,7 @@ "source": [ "All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to \n", "- register it in the `__init__.py` file of repository\n", - "- add a test [test_presses.py](tests/presses/test_presses.py)\n", + "- register the press in [default_presses.py](tests/default_presses.py)\n", "- update the README" ] } diff --git a/pyproject.toml b/pyproject.toml index 024ca89e..cb0a686f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "kvpress" authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"] description = "Efficiently compress the KV cache of any pretrained transformer" -version = "0.0.4" +version = "0.1.0" readme = "README.md" [tool.poetry.dependencies] diff --git a/tests/default_presses.py b/tests/default_presses.py new file mode 100644 index 00000000..08d2fb13 --- /dev/null +++ b/tests/default_presses.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + +from kvpress import ( + ExpectedAttentionPress, + KnormPress, + RandomPress, + SimLayerKVPress, + SnapKVPress, + StreamingLLMPress, + ThinKPress, + TOVAPress, +) + +# contains all presses to be tested +# kwargs should be ordered easy to hard compression +default_presses = [ + {"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + { + "cls": SnapKVPress, + "kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}], + }, + {"cls": TOVAPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + { + "cls": ThinKPress, + "kwargs": [ + {"key_channel_compression_ratio": 0.2, "window_size": 2}, + {"key_channel_compression_ratio": 0.8, "window_size": 2}, + ], + }, + { + "cls": SimLayerKVPress, + "kwargs": [ + {"lazy_threshold": 0.8, "n_initial": 1, "n_recent": 1, "n_last": 1}, + {"lazy_threshold": 0.2, "n_initial": 1, "n_recent": 1, "n_last": 1}, + ], + }, +] diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index 460203a0..41a5a0a7 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -4,15 +4,7 @@ from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available -from kvpress import ( - ExpectedAttentionPress, - KnormPress, - SimLayerKVPress, - SnapKVPress, - StreamingLLMPress, - ThinKPress, - TOVAPress, -) +from tests.default_presses import default_presses from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 @@ -25,18 +17,13 @@ def df_ruler(): 
@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") -@pytest.mark.parametrize( - "cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress, SimLayerKVPress] -) -@pytest.mark.parametrize("compression_ratio", [0.1, 0.2]) +@pytest.mark.parametrize("press_dict", default_presses) @pytest.mark.parametrize("cache", ["dynamic", "quantized"]) -def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, compression_ratio, cache): # noqa: F811 - if cls == ThinKPress: - press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) - elif cls == SimLayerKVPress: - press = cls(lazy_threshold=1 - compression_ratio) - else: - press = cls(compression_ratio=compression_ratio) +def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press_dict, cache): # noqa: F811 + cls = press_dict["cls"] + kwargs = press_dict["kwargs"][0] + press = cls(**kwargs) + if cache == "dynamic": cache = DynamicCache() elif cache == "quantized" and is_optimum_quanto_available(): diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py new file mode 100644 index 00000000..b612a6b8 --- /dev/null +++ b/tests/presses/test_key_rerotation_press_rope.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + +from dataclasses import dataclass + +import pytest +import torch +from torch import nn +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half + +from kvpress import KeyRerotationPress, ScorerPress +from kvpress.presses.key_rerotation_press import get_rope_embeddings +from tests.fixtures import unit_test_model # noqa: F401 + + +@pytest.mark.parametrize("precision", ["full", "half"]) +def test_rerotate_keys_matches_reference_implementation(unit_test_model: LlamaForCausalLM, precision): # noqa: F811 + """ + Compare KeyRerotationPress' rerotation of keys with the reference implementation. + In the reference implementation, we are computing + 1. keys = W_k * hidden_states + 2. keys_pruned = prune(keys) + 3. keys = RoPE(keys_pruned) + """ + if precision == "half" and torch.cuda.is_available(): + unit_test_model = unit_test_model.cuda().half() + elif precision == "half" and not torch.cuda.is_available(): + pytest.skip("Half precision test is skipped because CUDA is not available.") + + original_press = RandomPressStoreIndices(compression_ratio=0.5) + key_rerotation_press = KeyRerotationPress(press=original_press) + + module = unit_test_model.model.layers[0].self_attn + hidden_states = torch.randn( + 8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype + ) + + keys = get_keys_with_rope(module, hidden_states) + + values = torch.randn_like(keys) + # Press result + keys_compressed, _ = key_rerotation_press.compress( + module, hidden_states, keys, values, attentions=None, kwargs=dict() + ) + + indices = original_press.indices + keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices) + + assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6 if precision == "full" else 1e-3) + + +def get_keys_with_rope(module, hidden_states): + # Compute keys with RoPE + keys = 
get_keys_without_pos_embedding(module, hidden_states) + cos, sin = get_rope_embeddings(module, keys) + keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) + return keys + + +@dataclass +class RandomPressStoreIndices(ScorerPress): + compression_ratio: float = 0.0 + seed: int = 0 + + def __post_init__(self): + self.indices = None + super().__post_init__() + + def score( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs, + ) -> torch.Tensor: + torch.manual_seed(self.seed) + scores = torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype) + # Get indices of KV pairs with the lowest scores + q_len = hidden_states.shape[1] + n_kept = int(q_len * (1 - self.compression_ratio)) + indices = scores.topk(n_kept, dim=-1).indices + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) + self.indices = indices + + return scores + + +def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hidden_states, indices): + """ + Computes the rerotated keys for the given indices. + 1. keys = W_k * hidden_states + 2. keys_pruned = prune(keys) + 3. keys = RoPE(keys_pruned) + """ + # 1. + keys = get_keys_without_pos_embedding(module, hidden_states) + # 2. + keys = keys.gather(2, indices).contiguous() + # 3. 
+ cos, sin = get_rope_embeddings(module, keys) + keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) + return keys + + +def get_keys_without_pos_embedding(module, hidden_states): + key_states = module.k_proj(hidden_states) + key_states = key_states.view( + key_states.shape[0], key_states.shape[1], module.num_key_value_heads, module.head_dim + ).transpose(1, 2) + return key_states diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 29ae9a40..50056b9a 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -2,23 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +import pytest import torch from torch import nn from transformers import DynamicCache -from kvpress import ( - ComposedPress, - ExpectedAttentionPress, - KnormPress, - ObservedAttentionPress, - RandomPress, - SimLayerKVPress, - SnapKVPress, - StreamingLLMPress, - TOVAPress, -) +from kvpress import ComposedPress, KeyRerotationPress, KnormPress, ObservedAttentionPress from kvpress.presses.scorer_press import ScorerPress from kvpress.presses.think_press import ThinKPress +from tests.default_presses import default_presses from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401 @@ -31,40 +23,27 @@ def test_composed_press(unit_test_model): # noqa: F811 unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values -def test_presses_run(unit_test_model): # noqa: F811 - for cls in [ - KnormPress, - ExpectedAttentionPress, - RandomPress, - StreamingLLMPress, - SimLayerKVPress, - SnapKVPress, - TOVAPress, - ThinKPress, - ]: - for value in [0.2, 0.4, 0.6, 0.8]: - - # Load the press - if cls == ThinKPress: - press = cls(key_channel_compression_ratio=value, window_size=2) - elif cls == SimLayerKVPress: - press = cls(lazy_threshold=value, n_initial=1, n_recent=1, n_last=1) - else: - press = cls(compression_ratio=value) - if cls == SnapKVPress: - press.window_size = 2 - - 
# Run the press - with press(unit_test_model): - input_ids = unit_test_model.dummy_inputs["input_ids"] - unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values - # Check that the press has a compression_ratio attribute - assert hasattr(press, "compression_ratio") +@pytest.mark.parametrize("press_dict", default_presses) +@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress]) +def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 + cls = press_dict["cls"] + for kwargs in press_dict["kwargs"]: + press = cls(**kwargs) + if wrapper_press is ComposedPress: + press = ComposedPress(presses=[press]) + if wrapper_press is KeyRerotationPress: + press = KeyRerotationPress(press=press) + + with press(unit_test_model): + input_ids = unit_test_model.dummy_inputs["input_ids"] + unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values + # Check that the press has a compression_ratio attribute + assert hasattr(press, "compression_ratio") def test_presses_run_observed_attention(unit_test_model_output_attention): # noqa: F811 for cls in [ObservedAttentionPress]: - for compresion_ratio in [0.2, 0.8]: press = cls(compression_ratio=compresion_ratio) with press(unit_test_model_output_attention): input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"]
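As a closing sanity check on the math behind `KeyRerotationPress` (the un-rotate, prune, re-rotate steps in its `compress` method and the reference path tested above), the identity it relies on can be demonstrated with a single-frequency, 2-d toy version of RoPE. This is a self-contained sketch in plain Python with hypothetical helper names (`rope`, `inverse_rope`), not the kvpress/transformers implementation:

```python
import math

def rope(vec, pos):
    # Single-frequency RoPE on a 2-d vector: rotate by angle `pos`.
    x, y = vec
    c, s = math.cos(pos), math.sin(pos)
    return (x * c - y * s, x * s + y * c)

def inverse_rope(vec, pos):
    # Inverting the rotation is the same formula with sin -> -sin.
    return rope(vec, -pos)

# Unrotated keys (i.e. W_k @ hidden_states) at positions 0..4.
raw_keys = [(1.0, 0.0), (0.5, 0.5), (0.2, 0.9), (0.7, 0.1), (0.3, 0.3)]
rotated_keys = [rope(k, pos) for pos, k in enumerate(raw_keys)]

kept = [0, 2, 4]  # indices surviving pruning

# KeyRerotationPress order: un-rotate, prune, re-rotate at new positions.
rerotated = [
    rope(inverse_rope(rotated_keys[i], i), new_pos)
    for new_pos, i in enumerate(kept)
]

# Reference order: prune the unrotated keys, then apply RoPE.
reference = [rope(raw_keys[i], new_pos) for new_pos, i in enumerate(kept)]

# Both orders agree up to floating-point error, which is exactly what
# test_rerotate_keys_matches_reference_implementation checks with tensors.
match = all(
    math.isclose(a, b, abs_tol=1e-9)
    for u, v in zip(rerotated, reference)
    for a, b in zip(u, v)
)
```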