diff --git a/README.md b/README.md
index 1cd960d6..6e01a5ca 100644
--- a/README.md
+++ b/README.md
@@ -102,7 +102,7 @@ pipe(..., cache=cache)
 By default, the `DynamicCache` is used (no quantization).

 > [!IMPORTANT]
-> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto==0.2.4`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)).
+> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)).

 ## FAQ

diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index 1d5c4c60..913a0828 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -4,6 +4,7 @@

 from kvpress.pipeline import KVPressTextGenerationPipeline
 from kvpress.presses.base_press import BasePress
+from kvpress.presses.composed_press import ComposedPress
 from kvpress.presses.expected_attention_press import ExpectedAttentionPress
 from kvpress.presses.knorm_press import KnormPress
 from kvpress.presses.observed_attention_press import ObservedAttentionPress
@@ -14,7 +15,7 @@
 from kvpress.presses.snapkv_press import SnapKVPress
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
 from kvpress.presses.think_press import ThinKPress
-from kvpress.presses.composed_press import ComposedPress
+from kvpress.presses.tova_press import TOVAPress

 __all__ = [
     "BasePress",
@@ -32,5 +33,3 @@
     "KVPressTextGenerationPipeline",
     "PerLayerCompressionPress",
 ]
-
-from kvpress.presses.tova_press import TOVAPress
diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py
index ca68e678..b9159d1e 100644
--- a/kvpress/pipeline.py
+++ b/kvpress/pipeline.py
@@ -7,7 +7,7 @@
 from typing import Optional

 import torch
-from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache
+from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline
 from transformers.pipelines import PIPELINE_REGISTRY
 from transformers.pipelines.base import GenericTensor

@@ -248,13 +248,23 @@ def generate_answer(
         answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)

         # Remove the generated tokens from the cache
-        if isinstance(cache, QuantizedCache):
-            key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache"
-        else:
-            key_attr, value_attr = "key_cache", "value_cache"
-
-        setattr(cache, key_attr, [key[:, :, :c] for key, c in zip(getattr(cache, key_attr), cache_seq_lengths)])
-        setattr(cache, value_attr, [value[:, :, :c] for value, c in zip(getattr(cache, value_attr), cache_seq_lengths)])
+        cache.key_cache = [
+            cache.key_cache[layer_idx][:, :, :sequence_length]
+            for layer_idx, sequence_length in enumerate(cache_seq_lengths)
+        ]
+        cache.value_cache = [
+            cache.value_cache[layer_idx][:, :, :sequence_length]
+            for layer_idx, sequence_length in enumerate(cache_seq_lengths)
+        ]
+        if hasattr(cache, "_quantized_key_cache"):
+            cache._quantized_key_cache = [
+                cache._quantized_key_cache[layer_idx][:, :, :sequence_length]
+                for layer_idx, sequence_length in enumerate(cache_seq_lengths)
+            ]
+            cache._quantized_value_cache = [
+                cache._quantized_value_cache[layer_idx][:, :, :sequence_length]
+                for layer_idx, sequence_length in enumerate(cache_seq_lengths)
+            ]

         return answer

diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py
index 129d1769..8a1bea2f 100644
--- a/kvpress/presses/base_press.py
+++ b/kvpress/presses/base_press.py
@@ -14,8 +14,8 @@
     MistralForCausalLM,
     Phi3ForCausalLM,
     PreTrainedModel,
-    Qwen2ForCausalLM,
     QuantizedCache,
+    Qwen2ForCausalLM,
 )

 logger = logging.getLogger(__name__)
@@ -111,6 +111,9 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         if isinstance(cache, QuantizedCache):
             cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
             cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
+            cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
+            cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
+            cache._seen_tokens = keys.shape[2]
         else:
             cache.key_cache[module.layer_idx] = keys
             cache.value_cache[module.layer_idx] = values
diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py
index 3ebc5c16..ba3b65e7 100644
--- a/kvpress/presses/composed_press.py
+++ b/kvpress/presses/composed_press.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+
 from kvpress.presses.base_press import BasePress


diff --git a/tests/fixtures.py b/tests/fixtures.py
index 856404f7..912c90fe 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -25,9 +25,33 @@ def danube_500m_model():


 @pytest.fixture(scope="session")
-def kv_press_pipeline():
+def kv_press_unit_test_pipeline():
     return pipeline(
         "kv-press-text-generation",
         model="maxjeblick/llama2-0b-unit-test",
         device=0 if torch.cuda.is_available() else -1,
     )
+
+
+@pytest.fixture(scope="session")
+def kv_press_danube_pipeline():
+    return pipeline(
+        "kv-press-text-generation",
+        model="h2oai/h2o-danube3-500m-chat",
+        device=0 if torch.cuda.is_available() else -1,
+    )
+
+
+@pytest.fixture(scope="session")
+def kv_press_llama3_1_flash_attn_pipeline():
+    device = "cuda:0"
+    ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    attn_implementation = "flash_attention_2"
+    pipe = pipeline(
+        "kv-press-text-generation",
+        model=ckpt,
+        device=device,
+        torch_dtype="auto",
+        model_kwargs={"attn_implementation": attn_implementation},
+    )
+    return pipe
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py
new file mode 100644
index 00000000..460203a0
--- /dev/null
+++ b/tests/integration/test_ruler.py
@@ -0,0 +1,56 @@
+import datasets
+import pytest
+import torch
+from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache
+from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
+
+from kvpress import (
+    ExpectedAttentionPress,
+    KnormPress,
+    SimLayerKVPress,
+    SnapKVPress,
+    StreamingLLMPress,
+    ThinKPress,
+    TOVAPress,
+)
+from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline  # noqa: F401
+
+
+@pytest.fixture(scope="session")
+def df_ruler():
+    df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas()
+    df = df.loc[df["task"] == "niah_multikey_1"].reset_index(drop=True)
+    return df
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
+@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
+@pytest.mark.parametrize(
+    "cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress, SimLayerKVPress]
+)
+@pytest.mark.parametrize("compression_ratio", [0.1, 0.2])
+@pytest.mark.parametrize("cache", ["dynamic", "quantized"])
+def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, compression_ratio, cache):  # noqa: F811
+    if cls == ThinKPress:
+        press = cls(key_channel_compression_ratio=compression_ratio, window_size=2)
+    elif cls == SimLayerKVPress:
+        press = cls(lazy_threshold=1 - compression_ratio)
+    else:
+        press = cls(compression_ratio=compression_ratio)
+    if cache == "dynamic":
+        cache = DynamicCache()
+    elif cache == "quantized" and is_optimum_quanto_available():
+        config = QuantizedCacheConfig(nbits=4)
+        cache = QuantoQuantizedCache(config)
+    elif cache == "quantized" and not is_optimum_quanto_available():
+        pytest.skip("Quanto is not installed")
+    else:
+        raise ValueError(f"Unknown cache type: {cache}")
+
+    idx = 0
+    context = df_ruler.iloc[idx]["context"]
+    question = df_ruler.iloc[idx]["question"]
+    true_answer = df_ruler.iloc[idx]["answer"][0]
+
+    pred_answer = kv_press_llama3_1_flash_attn_pipeline(context, question=question, press=press, cache=cache)["answer"]
+    assert true_answer in pred_answer
diff --git a/tests/presses/__init__.py b/tests/presses/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py
index 86dd35ac..29ae9a40 100644
--- a/tests/presses/test_presses.py
+++ b/tests/presses/test_presses.py
@@ -12,8 +12,8 @@
     KnormPress,
     ObservedAttentionPress,
     RandomPress,
-    SnapKVPress,
     SimLayerKVPress,
+    SnapKVPress,
     StreamingLLMPress,
     TOVAPress,
 )
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 7bad5974..c5315312 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -3,20 +3,20 @@


 from kvpress import KnormPress
-from tests.fixtures import kv_press_pipeline  # noqa: F401
+from tests.fixtures import kv_press_unit_test_pipeline  # noqa: F401


-def test_generate(kv_press_pipeline):  # noqa: F811
+def test_generate(kv_press_unit_test_pipeline):  # noqa: F811
     context = "This is a test article. It was written on 2022-01-01."
     press = KnormPress(compression_ratio=0.4)

     # Answer with pipeline
-    pipe_answer = kv_press_pipeline(context, press=press, max_new_tokens=10)["answer"]
+    pipe_answer = kv_press_unit_test_pipeline(context, press=press, max_new_tokens=10)["answer"]

     # Answer with model.generate
     context += "\n"  # kv press pipeline automatically adds a newline if no chat template
-    model = kv_press_pipeline.model
-    tokenizer = kv_press_pipeline.tokenizer
+    model = kv_press_unit_test_pipeline.model
+    tokenizer = kv_press_unit_test_pipeline.tokenizer
     with press(model):
         inputs = tokenizer(context, return_tensors="pt").to(model.device)
         outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
diff --git a/tests/test_per_layer_compression_press.py b/tests/test_per_layer_compression_press.py
index b0bf4155..6db01eb2 100644
--- a/tests/test_per_layer_compression_press.py
+++ b/tests/test_per_layer_compression_press.py
@@ -7,7 +7,7 @@

 from kvpress.presses.knorm_press import KnormPress
 from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
-from tests.fixtures import kv_press_pipeline, unit_test_model  # noqa: F401
+from tests.fixtures import unit_test_model  # noqa: F401


 def test_per_layer_compression_press(unit_test_model):  # noqa: F811
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 1d4c324b..96993492 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -6,19 +6,23 @@

 import pytest
 import torch
-from transformers import AutoTokenizer, DynamicCache
+from transformers import AutoTokenizer, DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache
+from transformers.utils import is_optimum_quanto_available

 from kvpress import ExpectedAttentionPress
 from kvpress.pipeline import KVPressTextGenerationPipeline
-from tests.fixtures import danube_500m_model, kv_press_pipeline, unit_test_model  # noqa: F401
+from tests.fixtures import danube_500m_model  # noqa: F401
+from tests.fixtures import kv_press_danube_pipeline  # noqa: F401
+from tests.fixtures import kv_press_unit_test_pipeline  # noqa: F401
+from tests.fixtures import unit_test_model  # noqa: F401


-def test_pipeline(kv_press_pipeline, caplog):  # noqa: F811
+def test_pipeline(kv_press_unit_test_pipeline, caplog):  # noqa: F811
     with caplog.at_level(logging.DEBUG):
         context = "This is a test article. It was written on 2022-01-01."
         questions = ["When was this article written?"]
         press = ExpectedAttentionPress(compression_ratio=0.4)
-        answers = kv_press_pipeline(context, questions=questions, press=press)["answers"]
+        answers = kv_press_unit_test_pipeline(context, questions=questions, press=press)["answers"]

     assert len(answers) == 1
     assert isinstance(answers[0], str)
@@ -28,12 +32,23 @@
     assert "Compressed Context Length: 13" in messages, messages


+def test_pipeline_with_cache(kv_press_unit_test_pipeline, caplog):  # noqa: F811
+    context = "This is a test article. It was written on 2022-01-01."
+    questions = ["When was this article written?"]
+    press = ExpectedAttentionPress(compression_ratio=0.4)
+    cache = DynamicCache()
+    answers = kv_press_unit_test_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
+
+    assert len(answers) == 1
+    assert isinstance(answers[0], str)
+
+
 @pytest.mark.parametrize("question", ["When was this article written?", ""])
-def test_pipeline_single_or_no_question(kv_press_pipeline, question, caplog):  # noqa: F811
+def test_pipeline_single_or_no_question(kv_press_unit_test_pipeline, question, caplog):  # noqa: F811
     with caplog.at_level(logging.DEBUG):
         context = "This is a test article. It was written on 2022-01-01."
         press = ExpectedAttentionPress(compression_ratio=0.4)
-        answer = kv_press_pipeline(context, question=question, press=press)["answer"]
+        answer = kv_press_unit_test_pipeline(context, question=question, press=press)["answer"]

     assert isinstance(answer, str)

@@ -42,10 +57,10 @@
     assert "Compressed Context Length: 13" in messages, messages


-def test_pipeline_no_press_works(kv_press_pipeline, caplog):  # noqa: F811
+def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog):  # noqa: F811
     context = "This is a test article. It was written on 2022-01-01."
     question = "When was this article written?"
-    kv_press_pipeline(context, question=question)
+    kv_press_unit_test_pipeline(context, question=question)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
@@ -61,11 +76,32 @@ def test_pipeline_answer_is_correct(danube_500m_model, caplog):  # noqa: F811
     assert "Compressed Context Length: 16" in messages


+@pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available")
+def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog):  # noqa: F811
+    with caplog.at_level(logging.DEBUG):
+        context = "This is a test article. It was written on 2022-01-01."
+        questions = ["When was this article written?"]
+        press = ExpectedAttentionPress(compression_ratio=0.4)
+        config = QuantizedCacheConfig(nbits=4)
+        cache = QuantoQuantizedCache(config)
+        answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
+
+    assert len(answers) == 1
+    assert isinstance(answers[0], str)
+
+    for answer in answers:
+        assert answer == "This article was written on January 1, 2022."
+
+    messages = [record.message for record in caplog.records]
+    assert "Context Length: 28" in messages
+    assert "Compressed Context Length: 16" in messages
+
+
 def test_pipeline_compresses_context(unit_test_model, caplog):  # noqa: F811
     with caplog.at_level(logging.DEBUG):
         answers = generate_answer(unit_test_model)

-    assert len(answers) == 1
+    assert len(answers) == 2
     assert isinstance(answers[0], str)

     messages = [record.message for record in caplog.records]
@@ -79,7 +115,7 @@ def test_pipeline_context_cache_is_invariant(unit_test_model):  # noqa: F811
     questions = ["When was this article written?"]
     tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)

-    compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)
+    compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=torch.device("cpu"))

     input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"]
     seq_len = 256
@@ -100,7 +136,7 @@ def generate_answer(model):
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model.to(device)
     context = "This is a test article. It was written on 2022-01-01."
-    questions = ["When was this article written?"]
+    questions = ["When was this article written?", "When was this article written?"]
     press = ExpectedAttentionPress(compression_ratio=0.4)
     tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
     answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)(