Merged
10 changes: 10 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
## PR description

Description of your PR. Fixes # (issue) (if applicable)

## New press checklist (if applicable)

- [ ] I added `mypress_press.py` in the `presses` directory
- [ ] I added `MyPress` in `__init__.py`
- [ ] I updated the `README.md` with a 1 liner about my new press in the Available presses section
- [ ] I added my press in the `default_presses` list in `tests/default_presses.py`
53 changes: 12 additions & 41 deletions README.md
@@ -19,9 +19,7 @@ pip install flash-attn --no-build-isolation

## Usage

This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:


This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:

```python
from kvpress import ExpectedAttentionPress
@@ -48,27 +46,28 @@ In the snippet above, the compression is only applied on the context tokens so t

## Contributing with a new press

We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.
We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide on how presses work and how to create a new one.

## Available presses

All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score used to prune the KV pairs with lowest importance:
All current presses are training-free. Several of them inherit from `ScorerPress` and rely on a score to prune the KV pairs with the lowest importance:

- `RandomPress`: random score
- `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430))
- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
- `SnapKVPress`: average attention weight of the last queries ([paper](https://arxiv.org/abs/2404.14469))
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the initial and recent tokens ([paper](https://arxiv.org/abs/2309.17453))
- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). The input of this press is the lazy threshold, not the compression ratio.
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))

We also provide presses relying on a different logic:
- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018))
Some presses rely on a different logic:
- `ThinKPress`: compress the dimensions of the keys based on the channel attention score on the last queries ([paper](https://arxiv.org/pdf/2407.21018))
- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846))

Finally we provide two special presses:
- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental)
- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks
Finally, we provide special presses:
- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows setting a `compression_ratio`.
- `ComposedPress`: compose multiple presses together by chaining their forward hooks
- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`.

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
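
The scoring-and-pruning scheme shared by the `ScorerPress` subclasses above can be sketched in a few lines of plain Python (an illustrative analogy with hypothetical helper names, not the actual kvpress implementation):

```python
# Minimal sketch of score-based KV pruning, in the spirit of KnormPress.
# Illustrative only; the real presses operate on torch tensors per attention head.

def knorm_scores(keys):
    """Score each key by the inverse of its L2 norm (cf. KnormPress)."""
    return [1.0 / (sum(x * x for x in k) ** 0.5) for k in keys]

def prune(keys, values, scores, compression_ratio):
    """Keep the (1 - compression_ratio) fraction of pairs with the highest scores."""
    n_kept = int(len(keys) * (1 - compression_ratio))
    kept = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:n_kept]
    kept.sort()  # preserve the original sequence order
    return [keys[i] for i in kept], [values[i] for i in kept]

keys = [[3.0, 4.0], [0.1, 0.1], [1.0, 0.0]]
values = [["v0"], ["v1"], ["v2"]]
scores = knorm_scores(keys)
new_keys, new_values = prune(keys, values, scores, compression_ratio=1 / 3)
# The large-norm key [3.0, 4.0] gets the lowest score and is pruned first.
```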

@@ -129,9 +128,7 @@ Memory usage should be reduced by around `compression_ratio * kv_cache_size`. As

### How does a press work ?
</summary>

A press registers a forward hook to each attention layer during the pre-filling phase:
1. Immediately after the forward pass, the hook is called, and it computes a score for each key-value pair using the `press.score` method
2. The key-value pairs with the lowest scores are then removed based on the `compression_ratio` parameter
A press registers a forward hook (`press.forward_hook` method) to each attention layer during the pre-filling phase. Registration is performed by using the press as a context manager (`press.__call__` method):

```python
import torch
@@ -170,29 +167,3 @@ with press(model):
However, the `generate` method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works regardless of what comes after the context (_e.g._ for use cases such as chat or document question answering). Finally, the `generate` method does not support generating answers for multiple questions at once.

</details>
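
The FAQ entry above describes presses as context managers that register and remove forward hooks. That registration pattern can be sketched framework-free (simplified, hypothetical names — the real `BasePress` hooks into `transformers` attention layers):

```python
from contextlib import contextmanager

class TinyLayer:
    """Stand-in for an attention layer that supports forward hooks."""
    def __init__(self):
        self.hooks = []

    def forward(self, x):
        out = x * 2  # placeholder computation
        for hook in self.hooks:
            out = hook(self, x, out)
        return out

@contextmanager
def press(layers, hook):
    """Register `hook` on every layer on entry, remove it on exit."""
    for layer in layers:
        layer.hooks.append(hook)
    try:
        yield
    finally:
        for layer in layers:
            layer.hooks.remove(hook)

def halve(layer, inputs, output):
    # The hook rewrites the layer output (cf. pruning the KV cache after the forward pass)
    return output / 2

layers = [TinyLayer(), TinyLayer()]
with press(layers, halve):
    compressed = layers[0].forward(10)   # hook active: 10 * 2 / 2 = 10
uncompressed = layers[0].forward(10)     # hook removed: 10 * 2 = 20
```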

<details><summary>

### How to create a new press ?
</summary>

All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide.

Before opening a pull request with a new press, make sure to register it in the `__init__.py` file of the repository and to add it in [test_presses.py](tests/presses/test_presses.py).

</details>

<details><summary>

### Can I change the compression ratio from one layer to another ?
</summary>

We provide an experimental feature, which only works with flash attention:
```python
from kvpress import PerLayerCompressionPress
# compression_ratios should have the same length as the number of layers
press = PerLayerCompressionPress(press, compression_ratios=[...])
```

Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more details.
</details>
2 changes: 2 additions & 0 deletions kvpress/__init__.py
@@ -6,6 +6,7 @@
from kvpress.presses.base_press import BasePress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
@@ -32,4 +33,5 @@
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"KeyRerotationPress",
]
20 changes: 18 additions & 2 deletions kvpress/pipeline.py
@@ -12,7 +12,10 @@
from transformers.pipelines.base import GenericTensor

from kvpress.presses.base_press import BasePress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress

logger = logging.getLogger(__name__)

@@ -167,7 +170,7 @@ def _forward(
self.model(
input_ids=context_ids,
past_key_values=cache,
output_attentions=isinstance(press, ObservedAttentionPress),
output_attentions=self.output_attentions(press),
num_logits_to_keep=1,
)

@@ -180,13 +183,26 @@
answer = self.generate_answer(
question_ids=question_ids.to(self.model.device),
cache=cache,
context_length=context_length,
context_length=(cache.get_seq_length() if isinstance(press, KeyRerotationPress) else context_length),
max_new_tokens=max_new_tokens,
)
answers.append(answer)

return answers

def output_attentions(self, press: BasePress) -> bool:
if isinstance(press, ObservedAttentionPress):
return True
if isinstance(press, (KeyRerotationPress, PerLayerCompressionPress)) and isinstance(
press.press, ObservedAttentionPress
):
return True
if isinstance(press, ComposedPress) and any(
isinstance(sub_press, ObservedAttentionPress) for sub_press in press.presses
):
return True
return False

def postprocess(self, model_outputs, single_question):
if single_question:
return {"answer": model_outputs[0]}
79 changes: 79 additions & 0 deletions kvpress/presses/key_rerotation_press.py
@@ -0,0 +1,79 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import inspect
from dataclasses import dataclass

import torch
from torch import nn
from transformers.models.llama.modeling_llama import rotate_half

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class KeyRerotationPress(BasePress):
"""
Rerotate keys to have a continuous RoPE representation after pruning.
This method is used in several key-value cache compression methods, such as
- SinkCache implementation in Hugging Face's transformers library
- FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models
Parameters
----------
press : ScorerPress
The press whose `score` method is used to select which KV pairs to keep.
"""

press: ScorerPress

def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.press.compression_ratio == 0:
return keys, values

# Compute scores from base press
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)

# Get indices of KV pairs with the lowest scores
q_len = hidden_states.shape[1]
n_kept = int(q_len * (1 - self.press.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

cos, sin = get_rope_embeddings(module, keys)
# Rerotate as follows
# 1. keys = RoPE(W_k * hidden_states)
# 2. keys_unrotated = RoPE^-1(keys)
# 3. keys_pruned = prune(keys_unrotated)
# 4. keys = RoPE(keys_pruned)

# 2. Inverse of rotation matrix is equivalent to setting sin -> -sin in the equation below
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
# 3. Prune keys
keys = keys.gather(2, indices).contiguous()
# 4. Apply RoPE
cos, sin = get_rope_embeddings(module, keys)
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))

values = values.gather(2, indices).contiguous()
return keys, values


def get_rope_embeddings(module, x):
length = x.shape[2]
# rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape
if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
position_ids = torch.arange(length).unsqueeze(0).to(x.device)
cos, sin = module.rotary_emb(x, position_ids)
else:
cos, sin = module.rotary_emb(x, length)
return cos, sin
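
The `sin -> -sin` inversion used in step 2 of `compress` is simply the inverse of a 2D rotation, since RoPE rotates key channels pairwise. A quick scalar check (plain Python, not the tensorized code above):

```python
import math

def rope_pair(x, y, cos_t, sin_t):
    # RoPE acts on each channel pair like a 2D rotation by angle t
    return (x * cos_t - y * sin_t, x * sin_t + y * cos_t)

theta = 0.7
cos_t, sin_t = math.cos(theta), math.sin(theta)

key = (1.0, 2.0)
rotated = rope_pair(*key, cos_t, sin_t)

# Flipping the sign of sin applies the inverse rotation and recovers the key
recovered = rope_pair(*rotated, cos_t, -sin_t)
```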
6 changes: 6 additions & 0 deletions kvpress/presses/random_press.py
@@ -3,6 +3,7 @@


from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
@@ -14,6 +15,9 @@
class RandomPress(ScorerPress):
"""Randomly prune KV pairs"""

compression_ratio: float = 0.0
seed: Optional[int] = None

def score(
self,
module: nn.Module,
@@ -23,4 +27,6 @@ def score(
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
if self.seed is not None:
torch.manual_seed(self.seed)
return torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype)
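
The new `seed` field makes the random scores, and hence the pruned KV pairs, reproducible across runs. The pattern is illustrated below with the stdlib `random` module standing in for `torch.manual_seed` / `torch.rand`:

```python
import random

def random_scores(n, seed=None):
    # Mirrors RandomPress.score: uniform random scores, optionally seeded
    if seed is not None:
        random.seed(seed)
    return [random.random() for _ in range(n)]

a = random_scores(5, seed=42)
b = random_scores(5, seed=42)  # same seed -> identical scores
c = random_scores(5)           # unseeded: continues the global stream
```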
4 changes: 4 additions & 0 deletions kvpress/presses/streaming_llm_press.py
@@ -16,6 +16,10 @@ class StreamingLLMPress(ScorerPress):
Prune a fixed number of KV pairs at the beginning and end of the sequence (https://arxiv.org/abs/2309.17453)
We keep the first n_sink tokens and the last n_local tokens.
n_local is computed using the compression ratio.

Note that the original implementation https://github.com/mit-han-lab/streaming-llm additionally rerotates keys.
This can be achieved by using
press = KeyRerotationPress(press=StreamingLLMPress(compression_ratio, n_sink))
"""

compression_ratio: float = 0.0
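
The keep-first-and-last rule from the docstring above can be sketched with plain index arithmetic (an illustrative sketch assuming `n_local` is derived from the compression ratio, as the docstring states; not the actual kvpress code):

```python
def streaming_llm_kept_indices(seq_len, n_sink, compression_ratio):
    # Keep the first n_sink "sink" tokens plus the most recent n_local tokens
    n_kept = int(seq_len * (1 - compression_ratio))
    n_local = n_kept - n_sink
    return list(range(n_sink)) + list(range(seq_len - n_local, seq_len))

kept = streaming_llm_kept_indices(seq_len=10, n_sink=2, compression_ratio=0.5)
# Tokens 0-1 (sink) and 7-9 (recent) are kept: [0, 1, 7, 8, 9]
```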
2 changes: 1 addition & 1 deletion notebooks/new_press.ipynb
@@ -242,7 +242,7 @@
"source": [
"All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to \n",
"- register it in the `__init__.py` file of repository\n",
"- add a test [test_presses.py](tests/presses/test_presses.py)\n",
"- register the press in [default_presses.py](tests/default_presses.py)\n",
"- update the README"
]
}
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -2,7 +2,7 @@
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
description = "Efficiently compress the KV cache of any pretrained transformer"
version = "0.0.4"
version = "0.1.0"
readme = "README.md"

[tool.poetry.dependencies]
42 changes: 42 additions & 0 deletions tests/default_presses.py
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from kvpress import (
ExpectedAttentionPress,
KnormPress,
RandomPress,
SimLayerKVPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
)

# contains all presses to be tested
# kwargs should be ordered from easiest to hardest compression
default_presses = [
{"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{
"cls": SnapKVPress,
"kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}],
},
{"cls": TOVAPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{
"cls": ThinKPress,
"kwargs": [
{"key_channel_compression_ratio": 0.2, "window_size": 2},
{"key_channel_compression_ratio": 0.8, "window_size": 2},
],
},
{
"cls": SimLayerKVPress,
"kwargs": [
{"lazy_threshold": 0.8, "n_initial": 1, "n_recent": 1, "n_last": 1},
{"lazy_threshold": 0.2, "n_initial": 1, "n_recent": 1, "n_last": 1},
],
},
]
27 changes: 7 additions & 20 deletions tests/integration/test_ruler.py
@@ -4,15 +4,7 @@
from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available

from kvpress import (
ExpectedAttentionPress,
KnormPress,
SimLayerKVPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
)
from tests.default_presses import default_presses
from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401


@@ -25,18 +17,13 @@ def df_ruler():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
@pytest.mark.parametrize(
"cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress, SimLayerKVPress]
)
@pytest.mark.parametrize("compression_ratio", [0.1, 0.2])
@pytest.mark.parametrize("press_dict", default_presses)
@pytest.mark.parametrize("cache", ["dynamic", "quantized"])
def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, compression_ratio, cache): # noqa: F811
if cls == ThinKPress:
press = cls(key_channel_compression_ratio=compression_ratio, window_size=2)
elif cls == SimLayerKVPress:
press = cls(lazy_threshold=1 - compression_ratio)
else:
press = cls(compression_ratio=compression_ratio)
def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press_dict, cache): # noqa: F811
cls = press_dict["cls"]
kwargs = press_dict["kwargs"][0]
press = cls(**kwargs)

if cache == "dynamic":
cache = DynamicCache()
elif cache == "quantized" and is_optimum_quanto_available():