From cfa4e58c208bf986b5fbad0646fe249306af22d4 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Mon, 9 Dec 2024 12:00:56 +0100 Subject: [PATCH 01/29] fix generate_answer for quantized cache --- kvpress/__init__.py | 2 +- kvpress/pipeline.py | 23 ++++++++++++++--------- kvpress/presses/base_press.py | 2 +- kvpress/presses/tova_press.py | 2 +- tests/presses/test_presses.py | 3 +-- 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 5f7fc8cb..4f99c63a 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -11,8 +11,8 @@ from kvpress.presses.random_press import RandomPress from kvpress.presses.snapkv_press import SnapKVPress from kvpress.presses.streaming_llm_press import StreamingLLMPress -from kvpress.presses.tova_press import TOVAPress from kvpress.presses.think_press import ThinKPress +from kvpress.presses.tova_press import TOVAPress __all__ = [ "BasePress", diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 1ae45aef..6a533619 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -7,7 +7,7 @@ from typing import Optional import torch -from transformers import AutoModelForCausalLM, Cache, DynamicCache, QuantizedCache, Pipeline +from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline from transformers.pipelines import PIPELINE_REGISTRY from transformers.pipelines.base import GenericTensor @@ -215,7 +215,11 @@ def generate_answer( The generated answer. """ - cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] + if hasattr(cache, "_quantized_key_cache"): + cache_seq_lengths = [cache._quantized_key_cache[layer_idx].shape[-2] for layer_idx in range(len(cache))] + else: + cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] + position_ids = torch.arange( context_length, context_length + question_ids.shape[1], device=self.model.device ).unsqueeze(0) @@ -248,13 +252,14 @@ def generate_answer( answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True) # Remove the generated tokens from the cache - if isinstance(cache, QuantizedCache): - key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache" - else: - key_attr, value_attr = "key_cache", "value_cache" - - setattr(cache, key_attr, [key[:, :, :c] for key, c in zip(getattr(cache, key_attr), cache_seq_lengths)]) - setattr(cache, value_attr, [value[:, :, :c] for value, c in zip(getattr(cache, value_attr), cache_seq_lengths)]) + cache.key_cache = [ + cache.key_cache[layer_idx][:, :, :sequence_length] + for layer_idx, sequence_length in enumerate(cache_seq_lengths) + ] + cache.key_cache = [ + cache.key_cache[layer_idx][:, :, :sequence_length] + for layer_idx, sequence_length in enumerate(cache_seq_lengths) + ] return answer diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 553187aa..28bb856f 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -13,8 +13,8 @@ MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, - Qwen2ForCausalLM, QuantizedCache, + Qwen2ForCausalLM, ) logger = logging.getLogger(__name__) diff --git a/kvpress/presses/tova_press.py b/kvpress/presses/tova_press.py index 0addcd95..28eb96f6 100644 --- a/kvpress/presses/tova_press.py +++ b/kvpress/presses/tova_press.py @@ -4,8 +4,8 @@ from dataclasses import dataclass import torch -from torch import nn import torch.nn.functional as F +from torch import nn from kvpress.presses.snapkv_press import SnapKVPress diff --git 
a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 4c2d361e..dbb6d805 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -14,10 +14,9 @@ RandomPress, SnapKVPress, StreamingLLMPress, - TOVAPress, ThinKPress, + TOVAPress, ) - from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401 From 190b3c40fc4932de5589f022e9c28d0f130c13eb Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Mon, 9 Dec 2024 13:10:32 +0100 Subject: [PATCH 02/29] fix value chache pruning --- kvpress/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 6a533619..9efc563a 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -256,8 +256,8 @@ def generate_answer( cache.key_cache[layer_idx][:, :, :sequence_length] for layer_idx, sequence_length in enumerate(cache_seq_lengths) ] - cache.key_cache = [ - cache.key_cache[layer_idx][:, :, :sequence_length] + cache.value_cache = [ + cache.value_cache[layer_idx][:, :, :sequence_length] for layer_idx, sequence_length in enumerate(cache_seq_lengths) ] From 4255ba1e1fd629428e3e1485c908e96101deb414 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Mon, 9 Dec 2024 13:38:56 +0100 Subject: [PATCH 03/29] improve test --- tests/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 07a95f2b..f82ea7ac 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -99,7 +99,7 @@ def generate_answer(model): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model.to(device) context = "This is a test article. It was written on 2022-01-01." - questions = ["When was this article written?"] + questions = ["When was this article written?", "When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)( From d766ce9ce2ebb2f676cee11e457e511ff2f24497 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Mon, 9 Dec 2024 13:48:15 +0100 Subject: [PATCH 04/29] improve test --- tests/test_pipeline.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f82ea7ac..f46f12f4 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -27,6 +27,17 @@ def test_pipeline(kv_press_pipeline, caplog): # noqa: F811 assert "Compressed Context Length: 13" in messages, messages +def test_pipeline_With_cache(kv_press_pipeline, caplog): # noqa: F811 + context = "This is a test article. It was written on 2022-01-01." 
+ questions = ["When was this article written?"] + press = ExpectedAttentionPress(compression_ratio=0.4) + cache = DynamicCache() + answers = kv_press_pipeline(context, questions=questions, press=press, cache=cache)["answers"] + + assert len(answers) == 1 + assert isinstance(answers[0], str) + + @pytest.mark.parametrize("question", ["When was this article written?", ""]) def test_pipeline_single_or_no_question(kv_press_pipeline, question, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): From 9c30df16d333e88d22d82566ec6f24770c9425ff Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 09:22:37 +0100 Subject: [PATCH 05/29] add integration tests --- tests/fixtures.py | 17 +++++- tests/integration/test_quantized_cache.py | 59 +++++++++++++++++++++ tests/test_generate.py | 10 ++-- tests/test_per_layer_compression_wrapper.py | 6 +-- tests/test_pipeline.py | 18 +++---- 5 files changed, 92 insertions(+), 18 deletions(-) create mode 100644 tests/integration/test_quantized_cache.py diff --git a/tests/fixtures.py b/tests/fixtures.py index 856404f7..68942e0f 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -25,9 +25,24 @@ def danube_500m_model(): @pytest.fixture(scope="session") -def kv_press_pipeline(): +def kv_press_unit_test_pipeline(): return pipeline( "kv-press-text-generation", model="maxjeblick/llama2-0b-unit-test", device=0 if torch.cuda.is_available() else -1, ) + + +@pytest.fixture(scope="session") +def kv_press_llama3_1_flash_attn_pipeline(): + device = "cuda:0" + ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct" + attn_implementation = "flash_attention_2" + pipe = pipeline( + "kv-press-new-text-generation", + model=ckpt, + device=device, + torch_dtype="auto", + model_kwargs={"attn_implementation": attn_implementation}, + ) + return pipe diff --git a/tests/integration/test_quantized_cache.py b/tests/integration/test_quantized_cache.py new file mode 100644 index 00000000..e8f48736 --- /dev/null +++ b/tests/integration/test_quantized_cache.py @@ -0,0 +1,59 @@ +import datasets +import pytest +import torch +from transformers import QuantizedCacheConfig, QuantoQuantizedCache + +from kvpress import ExpectedAttentionPress +from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 + +gpu_available = torch.cuda.is_available() +try: + import flash_attn + + flash_attn_installed = True +except: + flash_attn_installed = False + +dynamic_cache_available = False +try: + config = QuantizedCacheConfig(nbits=4) + cache = QuantoQuantizedCache(config) + dynamic_cache_available = True +except: + dynamic_cache_available = False + + +@pytest.mark.skipif(not gpu_available, reason="GPU is not available") +@pytest.mark.skipif(not flash_attn_installed, reason="flash_attn is not installed") +def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 + df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() + df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) + press = ExpectedAttentionPress(0.3) + + idx = 0 + context = df.iloc[idx]["context"] + question = df.iloc[idx]["question"] + true_answer = df.iloc[idx]["answer"][0] + + pred_answer = kv_press_llama3_1_flash_attn_pipeline(context, question=question, press=press)["answer"] + assert true_answer in pred_answer + + +@pytest.mark.skipif(not gpu_available, reason="GPU is not available") +@pytest.mark.skipif(not flash_attn_installed, reason="flash_attn is not installed") +@pytest.mark.skipif(not dynamic_cache_available, reason="QuantizedCache is not 
available") +def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 + df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() + df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) + press = ExpectedAttentionPress(0.15) + + idx = 0 + context = df.iloc[idx]["context"] + question = df.iloc[idx]["question"] + true_answer = df.iloc[idx]["answer"][0] + + config = QuantizedCacheConfig(nbits=4) + cache = QuantoQuantizedCache(config) + + pred_answer = kv_press_llama3_1_flash_attn_pipeline(context, question=question, press=press, cache=cache)["answer"] + assert true_answer in pred_answer diff --git a/tests/test_generate.py b/tests/test_generate.py index 7bad5974..c5315312 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -3,20 +3,20 @@ from kvpress import KnormPress -from tests.fixtures import kv_press_pipeline # noqa: F401 +from tests.fixtures import kv_press_unit_test_pipeline # noqa: F401 -def test_generate(kv_press_pipeline): # noqa: F811 +def test_generate(kv_press_unit_test_pipeline): # noqa: F811 context = "This is a test article. It was written on 2022-01-01." press = KnormPress(compression_ratio=0.4) # Answer with pipeline - pipe_answer = kv_press_pipeline(context, press=press, max_new_tokens=10)["answer"] + pipe_answer = kv_press_unit_test_pipeline(context, press=press, max_new_tokens=10)["answer"] # Answer with model.generate context += "\n" # kv press pipeline automatically adds a newline if no chat template - model = kv_press_pipeline.model - tokenizer = kv_press_pipeline.tokenizer + model = kv_press_unit_test_pipeline.model + tokenizer = kv_press_unit_test_pipeline.tokenizer with press(model): inputs = tokenizer(context, return_tensors="pt").to(model.device) outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False) diff --git a/tests/test_per_layer_compression_wrapper.py b/tests/test_per_layer_compression_wrapper.py index f68808eb..1df9b0da 100644 --- a/tests/test_per_layer_compression_wrapper.py +++ b/tests/test_per_layer_compression_wrapper.py @@ -9,7 +9,7 @@ from transformers import DynamicCache from kvpress import KnormPress, apply_per_layer_compression -from tests.fixtures import kv_press_pipeline, unit_test_model # noqa: F401 +from tests.fixtures import kv_press_unit_test_pipeline, unit_test_model # noqa: F401 @dataclass @@ -22,7 +22,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic return super().forward_hook(module, input, kwargs, output) -def test_presses_run(kv_press_pipeline): # noqa: F811 +def test_presses_run(kv_press_unit_test_pipeline): # noqa: F811 press = RecordCompressionKnormPress(compression_ratio=0) compression_ratios = [0.1, 0.2] wrapped_press = apply_per_layer_compression(press, compression_ratios) @@ -30,7 +30,7 @@ def test_presses_run(kv_press_pipeline): # noqa: F811 context = "This is a test article. It was written on 2022-01-01." 
questions = ["When was this article written?"] - answers = kv_press_pipeline(context, questions=questions, press=wrapped_press)["answers"] + answers = kv_press_unit_test_pipeline(context, questions=questions, press=wrapped_press)["answers"] assert len(answers) == 1 assert isinstance(answers[0], str) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f46f12f4..fdff0794 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -9,15 +9,15 @@ from transformers import AutoTokenizer, DynamicCache from kvpress import ExpectedAttentionPress, KVPressTextGenerationPipeline -from tests.fixtures import danube_500m_model, kv_press_pipeline, unit_test_model # noqa: F401 +from tests.fixtures import danube_500m_model, kv_press_unit_test_pipeline, unit_test_model # noqa: F401 -def test_pipeline(kv_press_pipeline, caplog): # noqa: F811 +def test_pipeline(kv_press_unit_test_pipeline, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): context = "This is a test article. It was written on 2022-01-01." questions = ["When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) - answers = kv_press_pipeline(context, questions=questions, press=press)["answers"] + answers = kv_press_unit_test_pipeline(context, questions=questions, press=press)["answers"] assert len(answers) == 1 assert isinstance(answers[0], str) @@ -27,23 +27,23 @@ def test_pipeline(kv_press_pipeline, caplog): # noqa: F811 assert "Compressed Context Length: 13" in messages, messages -def test_pipeline_With_cache(kv_press_pipeline, caplog): # noqa: F811 +def test_pipeline_With_cache(kv_press_unit_test_pipeline, caplog): # noqa: F811 context = "This is a test article. It was written on 2022-01-01." questions = ["When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) cache = DynamicCache() - answers = kv_press_pipeline(context, questions=questions, press=press, cache=cache)["answers"] + answers = kv_press_unit_test_pipeline(context, questions=questions, press=press, cache=cache)["answers"] assert len(answers) == 1 assert isinstance(answers[0], str) @pytest.mark.parametrize("question", ["When was this article written?", ""]) -def test_pipeline_single_or_no_question(kv_press_pipeline, question, caplog): # noqa: F811 +def test_pipeline_single_or_no_question(kv_press_unit_test_pipeline, question, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): context = "This is a test article. It was written on 2022-01-01." press = ExpectedAttentionPress(compression_ratio=0.4) - answer = kv_press_pipeline(context, question=question, press=press)["answer"] + answer = kv_press_unit_test_pipeline(context, question=question, press=press)["answer"] assert isinstance(answer, str) @@ -52,10 +52,10 @@ def test_pipeline_single_or_no_question(kv_press_pipeline, question, caplog): # assert "Compressed Context Length: 13" in messages, messages -def test_pipeline_no_press_works(kv_press_pipeline, caplog): # noqa: F811 +def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog): # noqa: F811 context = "This is a test article. It was written on 2022-01-01." question = "When was this article written?" 
- kv_press_pipeline(context, question=question) + kv_press_unit_test_pipeline(context, question=question) @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") From f753eae07a0d5d8e375074b5acb1669700bb7c2b Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 15:01:49 +0100 Subject: [PATCH 06/29] get correct context length --- kvpress/pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 39dc52c1..b0581cfd 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -172,7 +172,10 @@ def _forward( ) logger.debug(f"Context Length: {context_length}") - logger.debug(f"Compressed Context Length: {cache.get_seq_length()}") + compressed_context_length = ( + cache._quantized_key_cache[0].shape[-2] if isinstance(cache, QuantizedCache) else cache.get_seq_length() + ) + logger.debug(f"Compressed Context Length: {compressed_context_length}") # Greedy decoding for each question answers = [] From 7649ab915231e55130eae91aa40b9cb503bc9f65 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 16:11:05 +0100 Subject: [PATCH 07/29] fix qunatized key cache --- kvpress/pipeline.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index b0581cfd..136444f2 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -263,6 +263,15 @@ def generate_answer( cache.value_cache[layer_idx][:, :, :sequence_length] for layer_idx, sequence_length in enumerate(cache_seq_lengths) ] + if hasattr(cache, "_quantized_key_cache"): + cache._quantized_key_cache = [ + cache._quantized_key_cache[layer_idx][:, :, :sequence_length] + for layer_idx, sequence_length in enumerate(cache_seq_lengths) + ] + cache._quantized_value_cache = [ + cache._quantized_value_cache[layer_idx][:, :, :sequence_length] + for layer_idx, sequence_length in enumerate(cache_seq_lengths) + ] return answer From 11a4c45b91a27c602e22612820085b2132f87bcf Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 16:22:08 +0100 Subject: [PATCH 08/29] fix qunatized key cache --- README.md | 2 +- kvpress/__init__.py | 2 -- kvpress/pipeline.py | 5 +---- kvpress/presses/scorer_press.py | 3 +++ tests/integration/test_quantized_cache.py | 15 ++++----------- tests/test_pipeline.py | 19 +++++++++++++++---- 6 files changed, 24 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 7e4acd4d..139de4e3 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,7 @@ pipe(..., cache=cache) By default, the `DynamicCache` is used (no quantization). > [!IMPORTANT] -> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto==0.2.4`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)). +> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)). ## FAQ diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 2314d754..3a379d50 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -29,5 +29,3 @@ "KVPressTextGenerationPipeline", "PerLayerCompressionPress", ] - -from kvpress.presses.tova_press import TOVAPress diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 136444f2..9f6847e8 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -218,10 +218,7 @@ def generate_answer( The generated answer. 
""" - if hasattr(cache, "_quantized_key_cache"): - cache_seq_lengths = [cache._quantized_key_cache[layer_idx].shape[-2] for layer_idx in range(len(cache))] - else: - cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] + cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] position_ids = torch.arange( context_length, context_length + question_ids.shape[1], device=self.model.device diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index a696cc7e..8bcc5a0d 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -115,6 +115,9 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic if isinstance(cache, QuantizedCache): cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value) + cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) + cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) + cache._seen_tokens = n_kept else: cache.key_cache[module.layer_idx] = keys cache.value_cache[module.layer_idx] = values diff --git a/tests/integration/test_quantized_cache.py b/tests/integration/test_quantized_cache.py index e8f48736..4aa51fd2 100644 --- a/tests/integration/test_quantized_cache.py +++ b/tests/integration/test_quantized_cache.py @@ -2,26 +2,19 @@ import pytest import torch from transformers import QuantizedCacheConfig, QuantoQuantizedCache +from transformers.utils import is_optimum_quanto_available from kvpress import ExpectedAttentionPress from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 gpu_available = torch.cuda.is_available() try: - import flash_attn + import flash_attn # noqa: F401 flash_attn_installed = True -except: +except: # noqa: E722 flash_attn_installed = False -dynamic_cache_available = False -try: - config = QuantizedCacheConfig(nbits=4) - cache = QuantoQuantizedCache(config) - dynamic_cache_available = True -except: - dynamic_cache_available = False - @pytest.mark.skipif(not gpu_available, reason="GPU is not available") @pytest.mark.skipif(not flash_attn_installed, reason="flash_attn is not installed") @@ -41,7 +34,7 @@ def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline) @pytest.mark.skipif(not gpu_available, reason="GPU is not available") @pytest.mark.skipif(not flash_attn_installed, reason="flash_attn is not installed") -@pytest.mark.skipif(not dynamic_cache_available, reason="QuantizedCache is not available") +@pytest.mark.skipif(not is_optimum_quanto_available(), reason="QuantizedCache is not available") def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index ebfb16d5..77cb8d55 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -6,19 +6,30 @@ import pytest import torch -from transformers import AutoTokenizer, DynamicCache +from transformers import AutoTokenizer, DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache +from transformers.utils import is_optimum_quanto_available from kvpress import ExpectedAttentionPress from kvpress.pipeline import KVPressTextGenerationPipeline -from tests.fixtures import 
danube_500m_model, kv_press_pipeline, unit_test_model # noqa: F401 +from tests.fixtures import danube_500m_model, kv_press_unit_test_pipeline, unit_test_model # noqa: F401 -def test_pipeline(kv_press_unit_test_pipeline, caplog): # noqa: F811 +@pytest.mark.parametrize("cache", ["dynamic", "quantized"]) +def test_pipeline(kv_press_unit_test_pipeline, caplog, cache): # noqa: F811 with caplog.at_level(logging.DEBUG): context = "This is a test article. It was written on 2022-01-01." questions = ["When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) - answers = kv_press_unit_test_pipeline(context, questions=questions, press=press)["answers"] + if cache == "dynamic": + cache = DynamicCache() + elif cache == "quantized" and is_optimum_quanto_available(): + config = QuantizedCacheConfig(nbits=4) + cache = QuantoQuantizedCache(config) + elif cache == "quantized": + pytest.skip("Optimum Quanto is not available") + else: + raise ValueError(f"Unknown cache type: {cache}") + answers = kv_press_unit_test_pipeline(context, questions=questions, press=press, cache=cache)["answers"] assert len(answers) == 1 assert isinstance(answers[0], str) From cfec9b85c0dd9ccef9b97004dc99519ab50d8e77 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 16:31:32 +0100 Subject: [PATCH 09/29] fix test --- tests/test_per_layer_compression_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_per_layer_compression_wrapper.py b/tests/test_per_layer_compression_wrapper.py index b0bf4155..6db01eb2 100644 --- a/tests/test_per_layer_compression_wrapper.py +++ b/tests/test_per_layer_compression_wrapper.py @@ -7,7 +7,7 @@ from kvpress.presses.knorm_press import KnormPress from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress -from tests.fixtures import kv_press_pipeline, unit_test_model # noqa: F401 +from tests.fixtures import unit_test_model # noqa: F401 def test_per_layer_compression_press(unit_test_model): # noqa: F811 From 1bfe667073eb0b3851d30a87ac24d61ce3e02730 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 16:44:01 +0100 Subject: [PATCH 10/29] fix test --- tests/fixtures.py | 9 +++++++++ tests/test_pipeline.py | 36 +++++++++++++++++++++++------------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index 68942e0f..baac63f3 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -33,6 +33,15 @@ def kv_press_unit_test_pipeline(): ) +@pytest.fixture(scope="session") +def kv_press_danube_pipeline(): + return pipeline( + "kv-press-text-generation", + model="h2oai/h2o-danube3-500m-chat", + device=0 if torch.cuda.is_available() else -1, + ) + + @pytest.fixture(scope="session") def kv_press_llama3_1_flash_attn_pipeline(): device = "cuda:0" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 77cb8d55..64d07600 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -11,25 +11,35 @@ from kvpress import ExpectedAttentionPress from kvpress.pipeline import KVPressTextGenerationPipeline -from tests.fixtures import danube_500m_model, kv_press_unit_test_pipeline, unit_test_model # noqa: F401 +from tests.fixtures import danube_500m_model # noqa: F401 +from tests.fixtures import kv_press_danube_pipeline # noqa: F401 +from tests.fixtures import kv_press_unit_test_pipeline # noqa: F401 +from tests.fixtures import unit_test_model # noqa: F401 -@pytest.mark.parametrize("cache", ["dynamic", "quantized"]) -def 
test_pipeline(kv_press_unit_test_pipeline, caplog, cache): # noqa: F811 +def test_pipeline(kv_press_unit_test_pipeline, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): context = "This is a test article. It was written on 2022-01-01." questions = ["When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) - if cache == "dynamic": - cache = DynamicCache() - elif cache == "quantized" and is_optimum_quanto_available(): - config = QuantizedCacheConfig(nbits=4) - cache = QuantoQuantizedCache(config) - elif cache == "quantized": - pytest.skip("Optimum Quanto is not available") - else: - raise ValueError(f"Unknown cache type: {cache}") - answers = kv_press_unit_test_pipeline(context, questions=questions, press=press, cache=cache)["answers"] + answers = kv_press_unit_test_pipeline(context, questions=questions, press=press)["answers"] + + assert len(answers) == 1 + assert isinstance(answers[0], str) + + messages = [record.message for record in caplog.records] + assert "Context Length: 23" in messages, messages + assert "Compressed Context Length: 13" in messages, messages + + +@pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available") +def test_pipeline_qunatized(kv_press_danube_pipeline, caplog, cache): # noqa: F811 + with caplog.at_level(logging.DEBUG): + context = "This is a test article. It was written on 2022-01-01." + questions = ["When was this article written?"] + press = ExpectedAttentionPress(compression_ratio=0.4) + config = QuantizedCacheConfig(nbits=4) + answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"] assert len(answers) == 1 assert isinstance(answers[0], str) From d46cc55e7b64b0d6e8a8ce9b8ca37b2813fe6bc2 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 16:53:30 +0100 Subject: [PATCH 11/29] fix test --- tests/test_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 64d07600..9a97ec23 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -33,12 +33,13 @@ def test_pipeline(kv_press_unit_test_pipeline, caplog): # noqa: F811 @pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available") -def test_pipeline_qunatized(kv_press_danube_pipeline, caplog, cache): # noqa: F811 +def test_pipeline_quantized(kv_press_danube_pipeline, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): context = "This is a test article. It was written on 2022-01-01." 
questions = ["When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) config = QuantizedCacheConfig(nbits=4) + cache = QuantoQuantizedCache(config) answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"] assert len(answers) == 1 From 524fe471193ecdfb2f8e06859e95c9e540c4d4bd Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 16:56:35 +0100 Subject: [PATCH 12/29] add more asserts --- tests/test_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 9a97ec23..29728f7f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -44,6 +44,7 @@ def test_pipeline_quantized(kv_press_danube_pipeline, caplog): # noqa: F811 assert len(answers) == 1 assert isinstance(answers[0], str) + assert cache.get_seq_length() == 13 messages = [record.message for record in caplog.records] assert "Context Length: 23" in messages, messages From c07cc4dbf5796bd961d5c9ccccb0769bde3f6fd6 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 16:59:22 +0100 Subject: [PATCH 13/29] fix test --- tests/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 29728f7f..77f063d0 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -99,7 +99,7 @@ def test_pipeline_compresses_context(unit_test_model, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): answers = generate_answer(unit_test_model) - assert len(answers) == 1 + assert len(answers) == 2 assert isinstance(answers[0], str) messages = [record.message for record in caplog.records] From 6ff4bc963b851d92b2bc9226c419fe314e77baee Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:25:14 +0100 Subject: [PATCH 14/29] fix test --- tests/test_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 77f063d0..23dcb966 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -44,7 +44,6 @@ def test_pipeline_quantized(kv_press_danube_pipeline, caplog): # noqa: F811 assert len(answers) == 1 assert isinstance(answers[0], str) - assert cache.get_seq_length() == 13 messages = [record.message for record in caplog.records] assert "Context Length: 23" in messages, messages From 960e05e4626c0340da707fe2b47e102cf6109b00 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:28:26 +0100 Subject: [PATCH 15/29] fix test --- tests/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 23dcb966..99c9dc5b 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -112,7 +112,7 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 questions = ["When was this article written?"] tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) - compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer) + compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=torch.device("cpu")) input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"] seq_len = 256 From 58bc8ae2acca57abebe8ddd7bb9c8fa9a4bbf95b Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:32:24 +0100 Subject: [PATCH 16/29] fix merge conflicts --- kvpress/__init__.py | 2 +- kvpress/presses/base_press.py | 5 ++++- 
kvpress/presses/composed_press.py | 1 + kvpress/presses/scorer_press.py | 10 ---------- 4 files changed, 6 insertions(+), 12 deletions(-) diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 95ae161e..18a287c3 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -4,6 +4,7 @@ from kvpress.pipeline import KVPressTextGenerationPipeline from kvpress.presses.base_press import BasePress +from kvpress.presses.composed_press import ComposedPress from kvpress.presses.expected_attention_press import ExpectedAttentionPress from kvpress.presses.knorm_press import KnormPress from kvpress.presses.observed_attention_press import ObservedAttentionPress @@ -13,7 +14,6 @@ from kvpress.presses.snapkv_press import SnapKVPress from kvpress.presses.streaming_llm_press import StreamingLLMPress from kvpress.presses.think_press import ThinKPress -from kvpress.presses.composed_press import ComposedPress from kvpress.presses.tova_press import TOVAPress __all__ = [ diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 129d1769..8a1bea2f 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -14,8 +14,8 @@ MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, - Qwen2ForCausalLM, QuantizedCache, + Qwen2ForCausalLM, ) logger = logging.getLogger(__name__) @@ -111,6 +111,9 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic if isinstance(cache, QuantizedCache): cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value) + cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) + cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) + cache._seen_tokens = keys.shape[2] else: cache.key_cache[module.layer_idx] = keys cache.value_cache[module.layer_idx] = values diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py index 3ebc5c16..ba3b65e7 100644 --- a/kvpress/presses/composed_press.py +++ b/kvpress/presses/composed_press.py @@ -1,4 +1,5 @@ from dataclasses import dataclass + from kvpress.presses.base_press import BasePress diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index 747e1877..79d496f3 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -67,14 +67,4 @@ def compress( # Prune keys and values keys = keys.gather(2, indices).contiguous() values = values.gather(2, indices).contiguous() - if isinstance(cache, QuantizedCache): - cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) - cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value) - cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) - cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) - cache._seen_tokens = n_kept - else: - cache.key_cache[module.layer_idx] = keys - cache.value_cache[module.layer_idx] = values - return keys, values From 16e8671fdecf4de8403f9df535db4dab79e06365 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:34:42 +0100 Subject: [PATCH 17/29] fix failing tests --- kvpress/pipeline.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 9f6847e8..0a980b8c 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -172,9 +172,7 @@ 
def _forward( ) logger.debug(f"Context Length: {context_length}") - compressed_context_length = ( - cache._quantized_key_cache[0].shape[-2] if isinstance(cache, QuantizedCache) else cache.get_seq_length() - ) + compressed_context_length = cache.get_seq_length() logger.debug(f"Compressed Context Length: {compressed_context_length}") # Greedy decoding for each question From d9887ea6a5fd78c6658241ecbbcab3c3c86f34cc Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:37:57 +0100 Subject: [PATCH 18/29] import flash attn skip --- tests/integration/test_quantized_cache.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/integration/test_quantized_cache.py b/tests/integration/test_quantized_cache.py index 4aa51fd2..57f2cdbb 100644 --- a/tests/integration/test_quantized_cache.py +++ b/tests/integration/test_quantized_cache.py @@ -2,22 +2,14 @@ import pytest import torch from transformers import QuantizedCacheConfig, QuantoQuantizedCache -from transformers.utils import is_optimum_quanto_available +from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available from kvpress import ExpectedAttentionPress from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 -gpu_available = torch.cuda.is_available() -try: - import flash_attn # noqa: F401 - flash_attn_installed = True -except: # noqa: E722 - flash_attn_installed = False - - -@pytest.mark.skipif(not gpu_available, reason="GPU is not available") -@pytest.mark.skipif(not flash_attn_installed, reason="flash_attn is not installed") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") +@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) @@ -32,8 +24,8 @@ def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline) assert true_answer in pred_answer -@pytest.mark.skipif(not gpu_available, reason="GPU is not available") -@pytest.mark.skipif(not flash_attn_installed, reason="flash_attn is not installed") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") +@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") @pytest.mark.skipif(not is_optimum_quanto_available(), reason="QuantizedCache is not available") def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() From f5d9d59653442f3467df9127cff2b7ac489c4610 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:41:39 +0100 Subject: [PATCH 19/29] fix test --- kvpress/pipeline.py | 2 +- tests/test_pipeline.py | 41 ++++++++++++++++++++++------------------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 0a980b8c..f341a39d 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -7,7 +7,7 @@ from typing import Optional import torch -from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache +from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline from transformers.pipelines import PIPELINE_REGISTRY from transformers.pipelines.base import GenericTensor diff --git 
a/tests/test_pipeline.py b/tests/test_pipeline.py index 99c9dc5b..96993492 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -32,25 +32,7 @@ def test_pipeline(kv_press_unit_test_pipeline, caplog): # noqa: F811 assert "Compressed Context Length: 13" in messages, messages -@pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available") -def test_pipeline_quantized(kv_press_danube_pipeline, caplog): # noqa: F811 - with caplog.at_level(logging.DEBUG): - context = "This is a test article. It was written on 2022-01-01." - questions = ["When was this article written?"] - press = ExpectedAttentionPress(compression_ratio=0.4) - config = QuantizedCacheConfig(nbits=4) - cache = QuantoQuantizedCache(config) - answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"] - - assert len(answers) == 1 - assert isinstance(answers[0], str) - - messages = [record.message for record in caplog.records] - assert "Context Length: 23" in messages, messages - assert "Compressed Context Length: 13" in messages, messages - - -def test_pipeline_With_cache(kv_press_unit_test_pipeline, caplog): # noqa: F811 +def test_pipeline_with_cache(kv_press_unit_test_pipeline, caplog): # noqa: F811 context = "This is a test article. It was written on 2022-01-01." questions = ["When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) @@ -94,6 +76,27 @@ def test_pipeline_answer_is_correct(danube_500m_model, caplog): # noqa: F811 assert "Compressed Context Length: 16" in messages +@pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available") +def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog): # noqa: F811 + with caplog.at_level(logging.DEBUG): + context = "This is a test article. It was written on 2022-01-01." + questions = ["When was this article written?"] + press = ExpectedAttentionPress(compression_ratio=0.4) + config = QuantizedCacheConfig(nbits=4) + cache = QuantoQuantizedCache(config) + answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"] + + assert len(answers) == 1 + assert isinstance(answers[0], str) + + for answer in answers: + assert answer == "This article was written on January 1, 2022." 
+ + messages = [record.message for record in caplog.records] + assert "Context Length: 28" in messages + assert "Compressed Context Length: 16" in messages + + def test_pipeline_compresses_context(unit_test_model, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): answers = generate_answer(unit_test_model) From c620cb0be6a3f59b651edfc0f6524f4889a22424 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:47:10 +0100 Subject: [PATCH 20/29] add integration tests --- tests/integration/__init__.py | 0 tests/integration/test_quantized_cache.py | 4 ++-- tests/presses/__init__.py | 0 3 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 tests/integration/__init__.py create mode 100644 tests/presses/__init__.py diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_quantized_cache.py b/tests/integration/test_quantized_cache.py index 57f2cdbb..f0f64fc6 100644 --- a/tests/integration/test_quantized_cache.py +++ b/tests/integration/test_quantized_cache.py @@ -10,7 +10,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") -def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 +def test_kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) press = ExpectedAttentionPress(0.3) @@ -27,7 +27,7 @@ def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline) @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") @pytest.mark.skipif(not is_optimum_quanto_available(), reason="QuantizedCache is not available") -def kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 +def test_kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) press = ExpectedAttentionPress(0.15) diff --git a/tests/presses/__init__.py b/tests/presses/__init__.py new file mode 100644 index 00000000..e69de29b From ca4f0ba8aa03eb5342ec280249988d71abdc4e7a Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:51:55 +0100 Subject: [PATCH 21/29] add integration tests --- tests/fixtures.py | 2 +- tests/integration/{test_quantized_cache.py => test_ruler.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/integration/{test_quantized_cache.py => test_ruler.py} (100%) diff --git a/tests/fixtures.py b/tests/fixtures.py index baac63f3..912c90fe 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -48,7 +48,7 @@ def kv_press_llama3_1_flash_attn_pipeline(): ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct" attn_implementation = "flash_attention_2" pipe = pipeline( - "kv-press-new-text-generation", + "kv-press-text-generation", model=ckpt, device=device, torch_dtype="auto", diff --git a/tests/integration/test_quantized_cache.py b/tests/integration/test_ruler.py similarity index 100% rename from tests/integration/test_quantized_cache.py rename to tests/integration/test_ruler.py From 
702888aedcb585ad4b99162ad0b16163223564da Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:55:53 +0100 Subject: [PATCH 22/29] add integration tests --- tests/integration/test_ruler.py | 47 ++++++++++++++++----------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index f0f64fc6..6d1fb93f 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -1,44 +1,43 @@ import datasets import pytest import torch -from transformers import QuantizedCacheConfig, QuantoQuantizedCache +from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available -from kvpress import ExpectedAttentionPress +from kvpress import ExpectedAttentionPress, KnormPress, SnapKVPress, StreamingLLMPress, ThinKPress, TOVAPress from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") -def test_kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 +@pytest.mark.parametrize( + "cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress] +) +@pytest.mark.parametrize("compression_ratio", [0.1, 0.2]) +@pytest.mark.parametrize("cache", ["dynamic", "quantized"]) +def test_kv_press_llama3_1_flash_attn_pipeline( + kv_press_llama3_1_flash_attn_pipeline, cls, compression_ratio, cache # noqa: F811 +): df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) - press = ExpectedAttentionPress(0.3) + if cls == ThinKPress: + press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) + else: + press = cls(compression_ratio=compression_ratio) + if cache == "dynamic": + cache = DynamicCache() + elif cache == "quantized" and is_optimum_quanto_available(): + config = QuantizedCacheConfig(nbits=4) + cache = QuantoQuantizedCache(config) + elif cache == "quantized" and not is_optimum_quanto_available(): + pytest.skip("Quanto is not installed") + else: + raise ValueError(f"Unknown cache type: {cache}") idx = 0 context = df.iloc[idx]["context"] question = df.iloc[idx]["question"] true_answer = df.iloc[idx]["answer"][0] - pred_answer = kv_press_llama3_1_flash_attn_pipeline(context, question=question, press=press)["answer"] - assert true_answer in pred_answer - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") -@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") -@pytest.mark.skipif(not is_optimum_quanto_available(), reason="QuantizedCache is not available") -def test_kv_press_llama3_1_flash_attn_pipeline(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 - df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() - df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) - press = ExpectedAttentionPress(0.15) - - idx = 0 - context = df.iloc[idx]["context"] - question = df.iloc[idx]["question"] - true_answer = df.iloc[idx]["answer"][0] - - config = QuantizedCacheConfig(nbits=4) - cache = QuantoQuantizedCache(config) - pred_answer = kv_press_llama3_1_flash_attn_pipeline(context, question=question, press=press, cache=cache)["answer"] assert true_answer in pred_answer 
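For reference, the end-to-end flow exercised by the integration test above can be driven as follows. This is an illustrative sketch, not part of the patch series: it assumes kvpress is installed together with flash-attn and optimum-quanto (see the README note earlier in this series) and that the Llama 3.1 checkpoint from the test fixtures is accessible on a GPU.

    # Illustrative sketch: running the kv-press pipeline with a 4-bit quantized cache,
    # mirroring the fixtures and integration test above. Checkpoint, device and nbits
    # are taken from the test code in this series, not fixed requirements.
    import torch
    from transformers import QuantizedCacheConfig, QuantoQuantizedCache, pipeline

    from kvpress import ExpectedAttentionPress  # importing kvpress registers the pipeline

    pipe = pipeline(
        "kv-press-text-generation",
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        device="cuda:0",
        torch_dtype="auto",
        model_kwargs={"attn_implementation": "flash_attention_2"},
    )

    context = "This is a test article. It was written on 2022-01-01."
    question = "When was this article written?"

    # The press prunes the KV cache during prefill; the QuantoQuantizedCache stores the
    # kept key/value pairs in 4-bit, which is the combination fixed by this patch series.
    press = ExpectedAttentionPress(compression_ratio=0.4)
    cache = QuantoQuantizedCache(QuantizedCacheConfig(nbits=4))

    answer = pipe(context, question=question, press=press, cache=cache)["answer"]
    print(answer)

With a DynamicCache the same call works unchanged; the quantized variant additionally relies on the trimming of _quantized_key_cache/_quantized_value_cache added to generate_answer in this series, so the compressed context cache can be reused across questions.
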
From 4369c026390a0358d82bb2d56450acfcf889f17a Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 17:58:04 +0100 Subject: [PATCH 23/29] add fixture --- tests/integration/test_ruler.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index 6d1fb93f..fc81a408 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -8,6 +8,13 @@ from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 +@pytest.fixture(scope="session") +def df_ruler(): + df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() + df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) + return df + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") @pytest.mark.parametrize( @@ -15,11 +22,7 @@ ) @pytest.mark.parametrize("compression_ratio", [0.1, 0.2]) @pytest.mark.parametrize("cache", ["dynamic", "quantized"]) -def test_kv_press_llama3_1_flash_attn_pipeline( - kv_press_llama3_1_flash_attn_pipeline, cls, compression_ratio, cache # noqa: F811 -): - df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() - df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) +def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, compression_ratio, cache): # noqa: F811 if cls == ThinKPress: press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) else: @@ -35,9 +38,9 @@ def test_kv_press_llama3_1_flash_attn_pipeline( raise ValueError(f"Unknown cache type: {cache}") idx = 0 - context = df.iloc[idx]["context"] - question = df.iloc[idx]["question"] - true_answer = df.iloc[idx]["answer"][0] + context = df_ruler.iloc[idx]["context"] + question = df_ruler.iloc[idx]["question"] + true_answer = df_ruler.iloc[idx]["answer"][0] pred_answer = kv_press_llama3_1_flash_attn_pipeline(context, question=question, press=press, cache=cache)["answer"] assert true_answer in pred_answer From 22c9549d2313fbb6b8b21b9752e69fcf9ebd07ae Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 18:02:22 +0100 Subject: [PATCH 24/29] easen up test --- tests/integration/test_ruler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index fc81a408..e2a45089 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -11,7 +11,7 @@ @pytest.fixture(scope="session") def df_ruler(): df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas() - df = df.loc[df["task"] == "niah_single_3"].reset_index(drop=True) + df = df.loc[df["task"] == "niah_multikey_1"].reset_index(drop=True) return df From 4b66477107499cbb3b60b5963024eab0149f7697 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 18:04:29 +0100 Subject: [PATCH 25/29] undo vvariable extraction --- kvpress/pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index f341a39d..93235216 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -172,8 +172,7 @@ def _forward( ) logger.debug(f"Context Length: {context_length}") - compressed_context_length = cache.get_seq_length() - logger.debug(f"Compressed Context Length: {compressed_context_length}") + logger.debug(f"Compressed Context Length: {cache.get_seq_length()}") # Greedy 
decoding for each question answers = [] From 8006a75efa036ff33666fc87a36cc60d969693df Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Tue, 10 Dec 2024 18:05:57 +0100 Subject: [PATCH 26/29] undo newlines --- kvpress/pipeline.py | 1 - kvpress/presses/composed_press.py | 1 - kvpress/presses/scorer_press.py | 1 + 3 files changed, 1 insertion(+), 2 deletions(-) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 93235216..b9159d1e 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -216,7 +216,6 @@ def generate_answer( """ cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] - position_ids = torch.arange( context_length, context_length + question_ids.shape[1], device=self.model.device ).unsqueeze(0) diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py index ba3b65e7..3ebc5c16 100644 --- a/kvpress/presses/composed_press.py +++ b/kvpress/presses/composed_press.py @@ -1,5 +1,4 @@ from dataclasses import dataclass - from kvpress.presses.base_press import BasePress diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index 79d496f3..ea97eab4 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -67,4 +67,5 @@ def compress( # Prune keys and values keys = keys.gather(2, indices).contiguous() values = values.gather(2, indices).contiguous() + return keys, values From 519222b4685ab5928b3d0a9ea5bfcd3f08e92b91 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 15:35:41 +0100 Subject: [PATCH 27/29] address pr feedback --- kvpress/presses/composed_press.py | 1 + tests/integration/test_ruler.py | 12 ++++++++++-- tests/presses/test_presses.py | 2 +- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py index 3ebc5c16..ba3b65e7 100644 --- a/kvpress/presses/composed_press.py +++ b/kvpress/presses/composed_press.py @@ -1,4 +1,5 @@ from dataclasses import dataclass + from kvpress.presses.base_press import BasePress diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index e2a45089..460a63bb 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -4,7 +4,15 @@ from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available -from kvpress import ExpectedAttentionPress, KnormPress, SnapKVPress, StreamingLLMPress, ThinKPress, TOVAPress +from kvpress import ( + ExpectedAttentionPress, + KnormPress, + SimLayerKVPress, + SnapKVPress, + StreamingLLMPress, + ThinKPress, + TOVAPress, +) from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 @@ -18,7 +26,7 @@ def df_ruler(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") @pytest.mark.parametrize( - "cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress] + "cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress, SimLayerKVPress] ) @pytest.mark.parametrize("compression_ratio", [0.1, 0.2]) @pytest.mark.parametrize("cache", ["dynamic", "quantized"]) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 86dd35ac..29ae9a40 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -12,8 +12,8 @@ KnormPress, 
ObservedAttentionPress, RandomPress, - SnapKVPress, SimLayerKVPress, + SnapKVPress, StreamingLLMPress, TOVAPress, ) From f75cb61733d83d92ec87820ddd0713c635a25b37 Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 16:44:27 +0100 Subject: [PATCH 28/29] fix broken test --- tests/integration/test_ruler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index 460a63bb..5a9cc5bc 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -33,6 +33,8 @@ def df_ruler(): def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, compression_ratio, cache): # noqa: F811 if cls == ThinKPress: press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) + elif cls == SimLayerKVPress: + press = cls(lazy_threshold=compression_ratio) else: press = cls(compression_ratio=compression_ratio) if cache == "dynamic": From 261bceacd7e8c8b076ddb492253747a1c3c033dd Mon Sep 17 00:00:00 2001 From: maxjeblick Date: Wed, 11 Dec 2024 17:08:28 +0100 Subject: [PATCH 29/29] fix broken test --- tests/integration/test_ruler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index 5a9cc5bc..460203a0 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -34,7 +34,7 @@ def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, if cls == ThinKPress: press = cls(key_channel_compression_ratio=compression_ratio, window_size=2) elif cls == SimLayerKVPress: - press = cls(lazy_threshold=compression_ratio) + press = cls(lazy_threshold=1 - compression_ratio) else: press = cls(compression_ratio=compression_ratio) if cache == "dynamic":