diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py
index 2b47a233..1e4975f5 100644
--- a/kvpress/attention_patch.py
+++ b/kvpress/attention_patch.py
@@ -78,6 +78,7 @@ def wrapper(module, query, key, value, attention_mask, dropout, **kwargs):
             batch_indices, head_indices, seq_indices = module.masked_key_indices
             key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices]
+            kwargs["cu_seq_lens_k"][-1] = key.shape[-2]
         return func(module, query, key, value, attention_mask, dropout, **kwargs)
 
     return wrapper
diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py
index c5c8511b..10e2df03 100644
--- a/kvpress/presses/base_press.py
+++ b/kvpress/presses/base_press.py
@@ -131,23 +131,23 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
             return output
 
         if isinstance(cache, QuantizedCache):
-            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
-            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
+            keys = cache.cache_processor._dequantize(cache.cache_processor._quantized_keys[module.layer_idx])
+            values = cache.cache_processor._dequantize(cache.cache_processor._quantized_values[module.layer_idx])
         else:
-            keys = cache.key_cache[module.layer_idx]
-            values = cache.value_cache[module.layer_idx]
+            keys = cache.layers[module.layer_idx].keys
+            values = cache.layers[module.layer_idx].values
 
         keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs)
 
         if isinstance(cache, QuantizedCache):
-            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
-            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
-            cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
-            cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
-            cache._seen_tokens = keys.shape[2]
+            cache.cache_processor._quantized_keys[module.layer_idx] = cache.cache_processor._quantize(keys, axis=cache.cache_processor.axis_key)
+            cache.cache_processor._quantized_values[module.layer_idx] = cache.cache_processor._quantize(values, axis=cache.cache_processor.axis_value)
+            cache.layers[module.layer_idx].keys = torch.zeros(0, dtype=keys.dtype, device=keys.device)
+            cache.layers[module.layer_idx].values = torch.zeros(0, dtype=keys.dtype, device=keys.device)
+            cache.cache_processor.erased_length = keys.shape[2]
         else:
-            cache.key_cache[module.layer_idx] = keys
-            cache.value_cache[module.layer_idx] = values
+            cache.layers[module.layer_idx].keys = keys
+            cache.layers[module.layer_idx].values = values
 
         return output
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 14ff3156..c111fc39 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -6,7 +6,7 @@
 import pytest
 import torch
-from transformers import AutoTokenizer, DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache
+from transformers import AutoTokenizer, DynamicCache, QuantoQuantizedCache
 from transformers.utils import is_optimum_quanto_available
 
 from kvpress import ExpectedAttentionPress
@@ -82,8 +82,7 @@ def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog):  # noqa
     context = "This is a test article. It was written on 2022-01-01."
     questions = ["When was this article written?"]
     press = ExpectedAttentionPress(compression_ratio=0.4)
-    config = QuantizedCacheConfig(nbits=4)
-    cache = QuantoQuantizedCache(config)
+    cache = QuantoQuantizedCache(nbits=4)
     answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
     assert len(answers) == 1
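
Note (not part of the diff): a minimal sketch of the transformers cache-layout change this patch adapts to, assuming a recent transformers release with the refactored Cache API. The old list attributes key_cache/value_cache are replaced by per-layer objects on cache.layers, and quantized caches route their internals through cache.cache_processor, which is why base_press.py now reads and writes the layer objects directly. Similarly, QuantizedCacheConfig is gone in this layout; quantization options are passed directly, e.g. QuantoQuantizedCache(nbits=4) as in the updated test.

import torch
from transformers import DynamicCache

# Populate one layer of a fresh cache; update() creates layers on demand.
cache = DynamicCache()
keys = torch.randn(1, 2, 8, 16)    # (batch, kv_heads, seq_len, head_dim)
values = torch.randn(1, 2, 8, 16)
cache.update(keys, values, layer_idx=0)

# Old layout: cache.key_cache[0] / cache.value_cache[0]
# New layout (what the forward_hook above reads and writes):
assert cache.layers[0].keys.shape == (1, 2, 8, 16)
assert cache.layers[0].values.shape == (1, 2, 8, 16)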