Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
d14ea54
use dependency injection for press
maxjeblick Dec 2, 2024
1054ac2
rename press-> pruner class
maxjeblick Dec 2, 2024
d05437b
docstring fixes
maxjeblick Dec 2, 2024
95275a3
fix default observed attention press
maxjeblick Dec 2, 2024
68d6071
fix test
maxjeblick Dec 2, 2024
38e6cea
fix test
maxjeblick Dec 2, 2024
362f723
fix test
maxjeblick Dec 2, 2024
4687bad
fix test
maxjeblick Dec 2, 2024
6ded859
fix typo
maxjeblick Dec 2, 2024
774e468
update notebooks
maxjeblick Dec 3, 2024
1162b4e
Merge branch 'main' into max/refactor_press
maxjeblick Dec 3, 2024
76ca9f2
fix merge conflicts
maxjeblick Dec 3, 2024
926290e
add think_press to pruners
maxjeblick Dec 3, 2024
a370fc7
create a basepruner class
maxjeblick Dec 3, 2024
a54b17c
update readme
maxjeblick Dec 3, 2024
7aaef60
update readme
maxjeblick Dec 3, 2024
0e8a194
improve example notebook
maxjeblick Dec 3, 2024
44734a4
remove default compression_ratios argument
maxjeblick Dec 3, 2024
9161e4a
remove default compression_ratios argument
maxjeblick Dec 3, 2024
21102c3
fix circular import
maxjeblick Dec 4, 2024
33ab1a6
fix circular import
maxjeblick Dec 4, 2024
2bd6d93
fix circular import
maxjeblick Dec 4, 2024
437a6ae
fix import
maxjeblick Dec 4, 2024
3ef2970
fix import
maxjeblick Dec 4, 2024
d427b74
fix test
maxjeblick Dec 4, 2024
ee3a0d0
fix test
maxjeblick Dec 4, 2024
44917a4
fix type annotation
maxjeblick Dec 4, 2024
1e4a36f
make compression for think more explicit
maxjeblick Dec 4, 2024
ca7ef7c
address pr feedback
maxjeblick Dec 10, 2024
5bff26a
address pr feedback
maxjeblick Dec 10, 2024
2646ba4
address pr feedback
maxjeblick Dec 10, 2024
bbc03f8
address pr feedback
maxjeblick Dec 10, 2024
158d5c6
fix style
maxjeblick Dec 10, 2024
afc87a3
fix tests
maxjeblick Dec 10, 2024
fa78b82
make field non mutable
maxjeblick Dec 10, 2024
fc644f4
fix field init
maxjeblick Dec 10, 2024
4476336
fix failing tests
maxjeblick Dec 10, 2024
75ecce8
address pr feedback
maxjeblick Dec 10, 2024
c81d2ce
address pr feedback
maxjeblick Dec 10, 2024
91a542f
remove scorer
maxjeblick Dec 10, 2024
d831bd8
add license
maxjeblick Dec 10, 2024
a2737c4
improve pr
maxjeblick Dec 10, 2024
cd5b643
fix readme
maxjeblick Dec 10, 2024
34968af
fix notebooks
maxjeblick Dec 10, 2024
0015a00
fix notebooks
maxjeblick Dec 10, 2024
b0b6538
add back ThinKPress
maxjeblick Dec 10, 2024
f613307
fix tests
maxjeblick Dec 10, 2024
b2b6b89
update notebook
maxjeblick Dec 10, 2024
1c617dd
add back comment
maxjeblick Dec 10, 2024
4fe50bd
fix import
maxjeblick Dec 10, 2024
c9d20c1
address pr feedback
maxjeblick Dec 10, 2024
a20f0fa
fix tests
maxjeblick Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ However, the `generate` method does not allow to exclude the question from the c
### How to create a new press ?
</summary>

All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `BasePress` and implement a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide.
All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implements a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide.

Before opening a pull request with a new press, make sure to register it in the `__init__.py` file of the repository and to add it to [test_presses.py](tests/presses/test_presses.py).

Expand All @@ -181,9 +181,9 @@ Before opening a pull request with a new press, make sure to register it in the

We provide an experimental feature, which only works with flash attention:
```python
from kvpress import apply_per_layer_compression
from kvpress import PerLayerCompressionPress
# compression_ratios should have the same length as the number of layers
press = apply_per_layer_compression(press, compression_ratios=[...])
press = PerLayerCompressionPress(press, compression_ratios=[...])
```

Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more details.
Expand Down
9 changes: 6 additions & 3 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@
# SPDX-License-Identifier: Apache-2.0


from kvpress.per_layer_compression_wrapper import apply_per_layer_compression
from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.base_press import BasePress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.think_press import ThinKPress

__all__ = [
"BasePress",
"ScorerPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
Expand All @@ -25,5 +26,7 @@
"ThinKPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
"apply_per_layer_compression",
"PerLayerCompressionPress",
]

from kvpress.presses.tova_press import TOVAPress
48 changes: 0 additions & 48 deletions kvpress/per_layer_compression_wrapper.py

This file was deleted.

2 changes: 1 addition & 1 deletion kvpress/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, Cache, DynamicCache, QuantizedCache, Pipeline
from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.pipelines.base import GenericTensor

Expand Down
103 changes: 10 additions & 93 deletions kvpress/presses/base_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,27 @@

import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Generator

import torch
from torch import nn
from transformers import (
LlamaForCausalLM,
MistralForCausalLM,
Phi3ForCausalLM,
PreTrainedModel,
Qwen2ForCausalLM,
QuantizedCache,
)
from transformers import LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, Qwen2ForCausalLM

logger = logging.getLogger(__name__)


@dataclass
class BasePress:
"""Base class for pruning methods.
Each pruning method should implement a `score` method that computes the scores for each KV pair in a layer.
This score is used to prune the KV pairs with the lowest scores in the `hook` method
The `hook` method is called after the forward pass of a layer and updates the cache with the pruned KV pairs.
The press can be applied to a model by calling it with the model as an argument.
"""

def __init__(self, compression_ratio: float = 0.0):
self.compression_ratio = compression_ratio
assert 0 <= compression_ratio < 1, "Compression ratio must be between 0 and 1"

def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
"""Compute the scores for each KV pair in the layer.

Parameters
----------
module :
Transformer layer, see `hook` method for more details.
hidden_states :
Hidden states of the layer.
keys :
Keys of the cache. Note keys are after RoPE.
values :
Values of the cache.
attentions :
Attention weights of the layer.
kwargs :
Keyword arguments, as given to the forward pass of the layer.

Returns
-------
Scores for each KV pair in the layer, shape keys.shape[:-1].

"""
raise NotImplementedError
Base class for all pruning methods.
The `forward_hook` method is called after the forward pass of an attention layer.
Any pruning/updating method should be implemented in the derived class.
"""

def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
"""Cache compression hook called after the forward pass of a decoder layer.
"""Cache compression hook called after the forward pass of an attention layer.
The hook is applied only during the pre-filling phase if there is some pruning ratio.
The current implementation only allows to remove a constant number of KV pairs.

Parameters
----------
Expand All @@ -84,47 +40,8 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
Returns
-------
Modified output of the forward pass of the layer.

"""
# See e.g. LlamaDecoderLayer.forward for the output structure
if len(output) == 3:
_, attentions, cache = output
else:
attentions, cache = None, output[-1]

hidden_states = kwargs["hidden_states"]
q_len = hidden_states.shape[1]

# Don't compress if the compression ratio is 0 or this is not pre-filling
if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
return output

if isinstance(cache, QuantizedCache):
keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
else:
keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]

with torch.no_grad():
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

# Prune KV pairs with the lowest scores
n_kept = int(q_len * (1 - self.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

# Update cache
keys = keys.gather(2, indices).contiguous()
values = values.gather(2, indices).contiguous()
if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
else:
cache.key_cache[module.layer_idx] = keys
cache.value_cache[module.layer_idx] = values

return output
raise NotImplementedError("forward_hook method must be implemented in the derived class")

@contextmanager
def __call__(self, model: PreTrainedModel) -> Generator:
Expand All @@ -141,8 +58,8 @@ def __call__(self, model: PreTrainedModel) -> Generator:
if not isinstance(model, (LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM)):
logger.warning(f"Model {type(model)} not tested")

hooks = []
try:
hooks = []
for layer in model.model.layers:
hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))

Expand Down
6 changes: 3 additions & 3 deletions kvpress/presses/expected_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import repeat_kv

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class ExpectedAttentionPress(BasePress):
class ExpectedAttentionPress(ScorerPress):
"""
Compute scores based on the expected attention on next positions. To do so
1. Compute the mean and covariance matrix of the queries before RoPE.
Expand Down
7 changes: 5 additions & 2 deletions kvpress/presses/knorm_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


class KnormPress(BasePress):
@dataclass
class KnormPress(ScorerPress):
"""Prune KV pairs with highest L2 norm of keys (https://arxiv.org/pdf/2406.11430)"""

def score(
Expand Down
24 changes: 14 additions & 10 deletions kvpress/presses/observed_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,36 @@
# SPDX-License-Identifier: Apache-2.0


import logging
from dataclasses import dataclass

import torch
from torch import nn
from transformers.utils import logging

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress

logger = logging.get_logger(__name__)
logger = logging.getLogger(__name__)


@dataclass
class ObservedAttentionPress(BasePress):
"""The observed attention score is defined as the average attention weight over all prompt tokens
class ObservedAttentionPress(ScorerPress):
"""
The observed attention score is defined as the average attention weight over all prompt tokens
Requires output_attentions=True and attn_implementation="eager" to have access to attentions
This approach is related to H2O (https://arxiv.org/abs/2306.14048).
"""

compression_ratio: float = 0.0
output_attentions: bool = False

def __post_init__(self):
if not self.output_attentions:
logger.warning(
"Model will not return attentions in its output to save memory. Please use DefaultPruner if"
" attentions are needed in the output."
)
super().__post_init__()

def score(
self,
module: nn.Module,
Expand All @@ -42,14 +51,9 @@ def score(

def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
output = super().forward_hook(module, input, kwargs, output)

# attentions are needed as input for the hook, but unless the user wants to return them in the output,
# we can remove them to save memory
if not self.output_attentions:
logger.warning_once(
"Model will not return attentions in its output to save memory. "
"Set output_attentions=True in ObservedAttentionPress to return attentions."
)
output = list(output)
output[-2] = None
output = tuple(output)
Expand Down
49 changes: 49 additions & 0 deletions kvpress/presses/per_layer_compression_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import inspect
import logging
from dataclasses import dataclass
from typing import List

import torch
from torch import nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress

logger = logging.getLogger(__name__)


@dataclass
class PerLayerCompressionPress(BasePress):
    """
    Wrap a press so that each layer is compressed with its own compression ratio.

    On every forward hook call, the wrapped press's `compression_ratio` is temporarily
    replaced by `compression_ratios[module.layer_idx]`, the wrapped press's `forward_hook`
    is run, and the previous ratio is restored afterwards.

    This is an experimental feature; as stated by the warning emitted at construction,
    it only works with flash attention.
    """

    # Press whose `forward_hook` performs the actual compression.
    press: ScorerPress
    # One compression ratio per layer, looked up with `module.layer_idx`.
    compression_ratios: List[float]

    def __post_init__(self):
        """Warn about the experimental status and check that the wrapped press accepts `compression_ratio`."""
        logger.warning(
            "Per layer compression wrapper is an experimental feature and only works with flash attention. "
            "Please make sure that the model uses flash attention."
        )
        assert (
            "compression_ratio"
            in inspect.signature(
                self.press.__init__  # type:ignore[misc]
            ).parameters
        ), f"compression_ratio can't be set in the provided press: {self.press.__class__}"

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """
        Run the wrapped press's hook using the compression ratio of the current layer.

        Parameters
        ----------
        module :
            Layer the hook is attached to; its `layer_idx` selects the ratio.
        input, kwargs, output :
            Passed through unchanged to the wrapped press's `forward_hook`.

        Returns
        -------
        Whatever the wrapped press's `forward_hook` returns.
        """
        original_compression_ratio = self.press.compression_ratio  # type:ignore[attr-defined]
        self.press.compression_ratio = self.compression_ratios[module.layer_idx]  # type:ignore[attr-defined]
        try:
            return self.press.forward_hook(module, input, kwargs, output)
        finally:
            # Restore even when the wrapped hook raises, so the shared press object is
            # never left carrying a single layer's ratio.
            self.press.compression_ratio = original_compression_ratio  # type:ignore[attr-defined]

    @property
    def compression_ratio(self):
        """Mean of the per-layer compression ratios (raises ZeroDivisionError if the list is empty)."""
        return sum(self.compression_ratios) / len(self.compression_ratios)

    @compression_ratio.setter
    def compression_ratio(self, value):
        """Reject assignment: the ratio is derived from `compression_ratios`."""
        raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")
Loading