From 09fbcbb127e8c407be8b7bb361b25673a3b23bbb Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Wed, 23 Jul 2025 11:06:04 +0200 Subject: [PATCH 1/3] Update base_press.py --- kvpress/presses/base_press.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index c5c8511b..10e2df03 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -131,23 +131,23 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic return output if isinstance(cache, QuantizedCache): - keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx]) - values = cache._dequantize(cache._quantized_value_cache[module.layer_idx]) + keys = cache.cache_processor._dequantize(cache.cache_processor._quantized_keys[module.layer_idx]) + values = cache.cache_processor._dequantize(cache.cache_processor._quantized_values[module.layer_idx]) else: - keys = cache.key_cache[module.layer_idx] - values = cache.value_cache[module.layer_idx] + keys = cache.layers[module.layer_idx].keys + values = cache.layers[module.layer_idx].values keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs) if isinstance(cache, QuantizedCache): - cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) - cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value) - cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) - cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) - cache._seen_tokens = keys.shape[2] + cache.cache_processor._quantized_keys[module.layer_idx] = cache.cache_processor._quantize(keys, axis=cache.cache_processor.axis_key) + cache.cache_processor._quantized_values[module.layer_idx] = cache.cache_processor._quantize(values, axis=cache.cache_processor.axis_value) + cache.layers[module.layer_idx].keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) + cache.layers[module.layer_idx].values = torch.zeros(0, dtype=keys.dtype, device=keys.device) + cache.cache_processor.erased_length = keys.shape[2] else: - cache.key_cache[module.layer_idx] = keys - cache.value_cache[module.layer_idx] = values + cache.layers[module.layer_idx].keys = keys + cache.layers[module.layer_idx].values = values return output From 1738a027232f9d641b9f81a3cbc68abfa29df96b Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Wed, 23 Jul 2025 11:13:24 +0200 Subject: [PATCH 2/3] Update test_pipeline.py --- tests/test_pipeline.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 14ff3156..c111fc39 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -6,7 +6,7 @@ import pytest import torch -from transformers import AutoTokenizer, DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache +from transformers import AutoTokenizer, DynamicCache, QuantoQuantizedCache from transformers.utils import is_optimum_quanto_available from kvpress import ExpectedAttentionPress @@ -82,8 +82,7 @@ def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog): # noq context = "This is a test article. It was written on 2022-01-01." questions = ["When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) - config = QuantizedCacheConfig(nbits=4) - cache = QuantoQuantizedCache(config) + cache = QuantoQuantizedCache(nbits=4) answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"] assert len(answers) == 1 From 9986c31a92ba2fc7ea214550027a11bba93dc6a3 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Thu, 7 Aug 2025 20:14:47 +0200 Subject: [PATCH 3/3] Update attention_patch.py --- kvpress/attention_patch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index 2b47a233..1e4975f5 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -78,6 +78,7 @@ def wrapper(module, query, key, value, attention_mask, dropout, **kwargs): batch_indices, head_indices, seq_indices = module.masked_key_indices key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices] + kwargs["cu_seq_lens_k"][-1] = key.shape[-2] return func(module, query, key, value, attention_mask, dropout, **kwargs) return wrapper