Merged
6 changes: 3 additions & 3 deletions README.md
@@ -4,7 +4,7 @@

![kvpress](kvpress.jpg)

Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache pruning methods and benchmarks using [🤗 transformers](https://huggingface.co/docs/transformers/en/index), aiming to simplify the development of new methods for researchers and developers in this field.
Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache compression methods and benchmarks using [🤗 transformers](https://huggingface.co/docs/transformers/en/index), aiming to simplify the development of new methods for researchers and developers in this field.
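
As a sanity check on that 330GB figure, here is a back-of-envelope sketch (shapes taken from the public Llama 3.1-70B config; the arithmetic is illustrative, not part of this repo):

```python
n_layers, n_kv_heads, head_dim = 80, 8, 128  # Llama 3.1-70B (GQA)
bytes_per_scalar = 2                         # float16
n_tokens = 1_000_000

# Keys and values: 2 tensors per layer, each n_kv_heads * head_dim scalars per token
kv_bytes = 2 * n_layers * n_kv_heads * head_dim * bytes_per_scalar * n_tokens
print(f"{kv_bytes / 1e9:.0f} GB")  # ~328 GB
```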

## Installation

@@ -60,7 +60,7 @@ All current presses are training free. Several of them inherit from `ScorerPress`
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))

Some presses relying on a different logic:
Some presses rely on a different logic:
- `ThinKPress`: compress the dimensions of the keys based on the channel attention score on the last queries ([paper](https://arxiv.org/pdf/2407.21018))
- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846))
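
Whichever press is chosen, usage follows the same pattern (a sketch based on this repo's pipeline; the model name is just an example):

```python
from transformers import pipeline
from kvpress import ExpectedAttentionPress

pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device_map="auto",
    torch_dtype="auto",
)

context = "..."   # long context whose KV cache gets compressed
question = "..."  # question answered against the compressed cache

press = ExpectedAttentionPress(compression_ratio=0.5)
answer = pipe(context, question=question, press=press)["answer"]
```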

@@ -101,7 +101,7 @@ pipe(..., cache=cache)
By default, the `DynamicCache` is used (no quantization).

> [!IMPORTANT]
> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)).
> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`).
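
Once the dependency is installed, a quantized cache can be passed to the pipeline; a minimal sketch using the quanto backend from `transformers`:

```python
from transformers import QuantizedCacheConfig, QuantoQuantizedCache

config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)
pipe(context, question=question, press=press, cache=cache)
```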


## FAQ
27 changes: 17 additions & 10 deletions evaluation/evaluate.py
@@ -23,6 +23,8 @@
RandomPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
)

logger = logging.getLogger(__name__)
@@ -48,6 +50,8 @@
"random": RandomPress(),
"snapkv": SnapKVPress(),
"streaming_llm": StreamingLLMPress(),
"think": ThinKPress(),
"tova": TOVAPress(),
}
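
This registry is also the extension point for benchmarking new methods; a hypothetical registration (`MyPress` is a stand-in for any `BasePress` subclass):

```python
PRESS_DICT["my_press"] = MyPress(compression_ratio=0.5)
```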


@@ -110,6 +114,12 @@ def evaluate(
df = load_dataset(DATASET_DICT[dataset], data_dir=data_dir, split="test").to_pandas()
if fraction < 1.0:
df = df.sample(frac=fraction, random_state=42)
save_filename = save_filename.with_name(save_filename.stem + f"__fraction{fraction:.2f}" + save_filename.suffix)

if max_context_length is not None:
save_filename = save_filename.with_name(
save_filename.stem + f"__max_context{max_context_length}" + save_filename.suffix
)

if compress_questions:
df["context"] = df["context"] + df["question"]
@@ -119,27 +129,24 @@
# Load press
assert press_name in PRESS_DICT
press = PRESS_DICT[press_name]
press.compression_ratio = compression_ratio
press.compression_ratio = compression_ratio # type:ignore[attr-defined]

# Initialize pipeline with the correct attention implementation
model_kwargs = {"torch_dtype": "auto"}
if isinstance(press, ObservedAttentionPress):
model_kwargs = {"attn_implementation": "eager"}
model_kwargs["attn_implementation"] = "eager"
else:
try:
import flash_attn # noqa: F401

model_kwargs = {"attn_implementation": "flash_attention_2"}
model_kwargs["attn_implementation"] = "flash_attention_2"
except ImportError:
model_kwargs = {}
pass

if device == "auto":
pipe = pipeline(
"kv-press-text-generation", model=model, device_map="auto", torch_dtype="auto", model_kwargs=model_kwargs
)
pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs)
else:
pipe = pipeline(
"kv-press-text-generation", model=model, device=device, torch_dtype="auto", model_kwargs=model_kwargs
)
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)

# Run pipeline on each context
df["predicted_answer"] = None
1 change: 1 addition & 0 deletions kvpress/__init__.py
@@ -18,6 +18,7 @@
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress


__all__ = [
"BasePress",
"ComposedPress",
12 changes: 4 additions & 8 deletions kvpress/presses/base_press.py
@@ -86,17 +86,13 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
Modified output of the forward pass of the layer.

"""
# See e.g. LlamaDecoderLayer.forward for the output structure
if len(output) == 3:
_, attentions, cache = output
else:
attentions, cache = None, output[-1]

hidden_states = kwargs["hidden_states"]
cache = kwargs["past_key_value"]
q_len = hidden_states.shape[1]

# Don't compress after pre-filling
if cache.seen_tokens > q_len:
if kwargs["cache_position"][-1] > q_len:
return output

if isinstance(cache, QuantizedCache):
@@ -106,7 +102,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]

keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs)
keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs)

if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
@@ -138,8 +134,8 @@ def __call__(self, model: PreTrainedModel) -> Generator:
hooks = []
try:
for layer in model.model.layers:
layer.self_attn.rotary_emb = model.model.rotary_emb
hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))

yield
finally:
for forward_hook in hooks:
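
Since `__call__` yields inside a try/finally, a press acts as a context manager: the hooks compress each layer's KV cache during pre-filling and are removed on exit, even on error. A usage sketch (`inputs` is a placeholder):

```python
with press(model):
    outputs = model(**inputs)  # forward hooks prune the KV cache as it is built
# hooks are removed here, so later calls run without compression
```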
4 changes: 2 additions & 2 deletions kvpress/presses/composed_press.py
@@ -15,8 +15,8 @@ class ComposedPress(BasePress):
def __post_init__(self):
self.compression_ratio = None
assert not any(
isinstance(press, ObservedAttentionPress) for press in self.presses
), "ComposedPress cannot contains ObservedAttentionPress because attentions pruning is not handled"
isinstance(press, (ObservedAttentionPress)) for press in self.presses
), "ComposedPress cannot contains ObservedAttentionPress"

def forward_hook(self, module, input, kwargs, output):
self.compression_ratio = 1.0
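
Any other presses compose freely; a hypothetical example (constructor arguments assumed from this repo's presses):

```python
from kvpress import ComposedPress, SnapKVPress, ThinKPress

# Token pruning (SnapKV) followed by key-channel pruning (ThinK)
press = ComposedPress(presses=[SnapKVPress(compression_ratio=0.5), ThinKPress()])
```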
19 changes: 8 additions & 11 deletions kvpress/presses/expected_attention_press.py
@@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0


import inspect
import math
from dataclasses import dataclass

@@ -39,7 +38,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
"""

bsz, q_len, _ = hidden_states.shape
n, d = module.num_heads, module.head_dim
n, d = module.config.num_attention_heads, module.head_dim

# Remove first hidden_states that likely contain outliers
h = hidden_states[:, self.n_sink :]
@@ -66,13 +65,9 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
cov = cov.permute(0, 3, 1, 2)

# RoPE rotation matrix on next n_future_positions
if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device)
cos, sin = module.rotary_emb(mu, position_ids)
cos, sin = cos[0], sin[0]
else:
cos, sin = module.rotary_emb(mu, q_len + self.n_future_positions)
cos, sin = cos[q_len:], sin[q_len:]
position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device)
cos, sin = module.rotary_emb(mu, position_ids)
cos, sin = cos[0], sin[0]

Id = torch.eye(d, device=cos.device, dtype=cos.dtype)
P = torch.zeros((d, d), device=cos.device, dtype=cos.dtype)
@@ -117,14 +112,16 @@ def score(

# Compute scores
bsz, num_key_value_heads, q_len, d = keys.shape
keys = repeat_kv(keys, module.num_key_value_groups).transpose(2, 3)
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

keys = repeat_kv(keys, num_key_value_groups).transpose(2, 3)
scores = torch.matmul(mean_query.unsqueeze(2), keys).squeeze(2) / math.sqrt(d)
if self.use_covariance:
scores += torch.einsum("bhin, bhij, bhjn->bhn", keys, cov_query, keys) / d / 2
scores = F.softmax(scores, dim=-1)

# Average scores across groups
scores = scores.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len)
scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len)
scores = scores.mean(dim=2)

# Rescale scores by the norm of the values
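
For intuition, the score is a second-order approximation of the expected softmax attention a future query gives each key: for Gaussian queries, E[exp(k.q/sqrt(d))] = exp(k.mu/sqrt(d) + k^T Sigma k / (2d)). A single-head toy sketch (all tensors random, shapes simplified):

```python
import math
import torch
import torch.nn.functional as F

d, n_keys = 128, 16
mu = torch.randn(d)            # mean of future queries (RoPE applied)
cov = torch.randn(d, d)
cov = cov @ cov.T / d          # covariance of future queries (PSD)
keys = torch.randn(n_keys, d)  # cached keys

scores = keys @ mu / math.sqrt(d)                                       # mean term
scores = scores + torch.einsum("ni,ij,nj->n", keys, cov, keys) / d / 2  # covariance term
scores = F.softmax(scores, dim=-1)  # expected attention per key
```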
19 changes: 5 additions & 14 deletions kvpress/presses/key_rerotation_press.py
@@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0


import inspect
from dataclasses import dataclass

import torch
@@ -28,6 +27,9 @@ class KeyRerotationPress(BasePress):

press: ScorerPress

def __post_init__(self):
assert isinstance(self.press, ScorerPress)

def compress(
self,
module: nn.Module,
@@ -49,7 +51,7 @@ def compress(
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

cos, sin = get_rope_embeddings(module, keys)
cos, sin = kwargs["position_embeddings"]
# Rerotate as follows
# 1. keys = RoPE(W_k * hidden_states)
# 2. keys_unrotated = RoPE^-1(keys)
@@ -61,19 +63,8 @@
# 3. Prune keys
keys = keys.gather(2, indices).contiguous()
# 4. Apply RoPE
cos, sin = get_rope_embeddings(module, keys)
cos, sin = cos[:, :n_kept], sin[:, :n_kept]
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))

values = values.gather(2, indices).contiguous()
return keys, values


def get_rope_embeddings(module, x):
length = x.shape[2]
# rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape
if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
position_ids = torch.arange(length).unsqueeze(0).to(x.device)
cos, sin = module.rotary_emb(x, position_ids)
else:
cos, sin = module.rotary_emb(x, length)
return cos, sin
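
The rerotation hinges on RoPE being a plane rotation, so it is undone by keeping cos and negating sin. A sketch of steps 2-4 above (shapes as in the kept code; `indices` already expanded to head_dim):

```python
import torch
from transformers.models.llama.modeling_llama import rotate_half

def rerotate_keys(keys, cos, sin, indices, n_kept):
    # 2. Undo RoPE: rotate by -theta (cos unchanged, sin negated)
    keys = keys * cos.unsqueeze(1) - rotate_half(keys) * sin.unsqueeze(1)
    # 3. Prune to the kept positions
    keys = keys.gather(2, indices).contiguous()
    # 4. Re-apply RoPE for positions 0 .. n_kept - 1
    cos, sin = cos[:, :n_kept], sin[:, :n_kept]
    return keys * cos.unsqueeze(1) + rotate_half(keys) * sin.unsqueeze(1)
```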
4 changes: 1 addition & 3 deletions kvpress/presses/observed_attention_press.py
@@ -54,8 +54,6 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
# attentions are needed as input for the hook, but unless the user wants to return them in the output,
# we can remove them to save memory
if not self.output_attentions:
output = list(output)
output[-2] = None
output = tuple(output)
output = (output[0], None)

return output
1 change: 1 addition & 0 deletions kvpress/presses/per_layer_compression_press.py
@@ -32,6 +32,7 @@ def __post_init__(self):
self.press.__init__ # type:ignore[misc]
).parameters
), f"compression_ratio can't be set in the provided press: {self.press.__class__}"
assert isinstance(self.press, ScorerPress), "PerLayerCompressionPress requires a ScorerPress as input"

def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
original_compression_ratio = self.press.compression_ratio # type:ignore[attr-defined]
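
A hypothetical instantiation (the per-layer ratio argument name is assumed, it is not shown in this diff):

```python
from kvpress import PerLayerCompressionPress, SnapKVPress

# Compress early layers lightly and later layers aggressively (32-layer model)
press = PerLayerCompressionPress(press=SnapKVPress(), compression_ratios=[0.1] * 16 + [0.7] * 16)
```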
7 changes: 5 additions & 2 deletions kvpress/presses/simlayerkv_press.py
@@ -43,13 +43,16 @@ def is_lazy(
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
position_embeddings: torch.Tensor,
) -> bool:
"""
Compute the attention weights of the last tokens over the initial and recent tokens.
The layer is considered lazy if the sum of these attention weights is above the lazy_threshold.
"""

attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, self.n_last)
attn_weights = SnapKVPress.compute_window_attention(
module, hidden_states, keys, self.n_last, position_embeddings
)
attn_weights = attn_weights.mean((0, 1, 2)) # mean over bsz, heads and window size
score = attn_weights[: self.n_initial].sum() + attn_weights[-self.n_recent :].sum()
return score.item() > self.lazy_threshold
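
A toy example of the laziness test (numbers invented): with mean attention weights `[0.5, 0.2, 0.1, 0.05, 0.05, 0.1]`, `n_initial=2` and `n_recent=2`, the score is 0.7 + 0.15 = 0.85, so a `lazy_threshold` of 0.65 flags the layer as lazy:

```python
import torch

attn_weights = torch.tensor([0.5, 0.2, 0.1, 0.05, 0.05, 0.1])
n_initial, n_recent, lazy_threshold = 2, 2, 0.65
score = attn_weights[:n_initial].sum() + attn_weights[-n_recent:].sum()
is_lazy = score.item() > lazy_threshold  # 0.85 > 0.65 -> True
```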
@@ -91,7 +94,7 @@ def compress(
return keys, values

# Compression
if self.is_lazy(module, hidden_states, keys):
if self.is_lazy(module, hidden_states, keys, kwargs["position_embeddings"]):
# If layer is lazy, only keep the initial and recent KV pairs
keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2)
values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2)
26 changes: 15 additions & 11 deletions kvpress/presses/snapkv_press.py
@@ -26,34 +26,35 @@ class SnapKVPress(ScorerPress):
kernel_size: int = 5

@staticmethod
def compute_window_attention(
module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor, window_size: int
) -> torch.Tensor:
def compute_window_attention(module, hidden_states, keys, window_size, position_embeddings):
"""
Compute the last window_size queries and associated attention weights for the first q_len - window_size keys.
"""

bsz, q_len, _ = hidden_states.shape
num_heads = module.config.num_attention_heads
head_dim = module.head_dim
num_key_value_groups = num_heads // module.config.num_key_value_heads

# Get last window_size queries
if hasattr(module, "q_proj"):
query_states = module.q_proj(hidden_states[:, -window_size:])
elif hasattr(module, "qkv_proj"):
qkv = module.qkv_proj(hidden_states[:, -window_size:])
query_states = qkv[..., : module.num_heads * module.head_dim]
query_states = qkv[..., : num_heads * head_dim]
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, window_size, module.num_heads, module.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2)

# Apply RoPE
position_ids = torch.arange(q_len - window_size, q_len).unsqueeze(0).to(query_states.device)
cos, sin = module.rotary_emb(query_states, position_ids)
cos, sin = position_embeddings
cos, sin = cos[:, -window_size:], sin[:, -window_size:]
query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))

# Compute attention for first q_len - window_size tokens
key_states = repeat_kv(keys, module.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
key_states = repeat_kv(keys, num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attention_mask = torch.ones_like(attn_weights) * float("-inf")
attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1)
attn_weights += attention_mask
@@ -73,19 +74,22 @@ def score(
) -> torch.Tensor:

bsz, num_key_value_heads, q_len, _ = keys.shape
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

assert q_len > self.window_size, "Query length should be greater than the window size"

if attentions is not None:
attn_weights = attentions[..., -self.window_size :, : -self.window_size]
else:
attn_weights = self.compute_window_attention(module, hidden_states, keys, self.window_size)
attn_weights = self.compute_window_attention(
module, hidden_states, keys, self.window_size, kwargs["position_embeddings"]
)

scores = attn_weights.mean(dim=-2)
scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)

# Average per group (https://github.com/FasterDecoding/SnapKV/issues/22)
scores = scores.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len - self.window_size)
scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size)
scores = scores.mean(2)

# Add back the observation window. Use max score to make sure the window is not pruned.
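
Putting the pieces together, a toy sketch of SnapKV scoring (random attention stands in for the computed window attention; shapes invented):

```python
import torch
import torch.nn.functional as F

bsz, kv_heads, groups, window, q_len = 1, 8, 4, 64, 512
attn = torch.rand(bsz, kv_heads * groups, window, q_len - window)  # window attention

scores = attn.mean(dim=-2)                                           # mean over window queries
scores = F.avg_pool1d(scores, kernel_size=5, padding=2, stride=1)    # local smoothing
scores = scores.view(bsz, kv_heads, groups, q_len - window).mean(2)  # average per group
```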