diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 0c6740b32388..bac6c7eab829 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -18,13 +18,14 @@
 import inspect
 import warnings
 from dataclasses import dataclass
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 from torch import nn
 
-from ..cache_utils import Cache, DynamicCache, StaticCache
+from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from ..models.auto import (
@@ -92,6 +93,13 @@
 if is_accelerate_available():
     from accelerate.hooks import AlignDevicesHook, add_hook_to_module
 
+
+class CacheImplementation(str, Enum):
+    DYNAMIC = "dynamic"
+    STATIC = "static"
+    SINK = "sink"
+
+
 NEED_SETUP_CACHE_CLASSES_MAPPING = {
     "static": StaticCache,
 }
@@ -351,6 +359,43 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
             "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
         )
 
+    def set_cache_implementation(self, cache_implementation: Union[CacheImplementation, str], **kwargs):
+        """
+        Simple API to set the cache implementation used by a model. Users could also set
+        `model.generation_config.cache_implementation = xxx` directly, but will most likely run into unexpected behaviour.
+
+        Args:
+            cache_implementation (`Union[CacheImplementation, str]`):
+                The target cache implementation. Currently supported cache implementations are `["sink", "static", "dynamic"]`.
+            kwargs (`dict`, *optional*):
+                Optional keyword arguments to be passed along. E.g. for `"sink"`, it is required to
+                pass `window_length` and `num_sink_tokens`.
+        """
+        if cache_implementation.upper() not in CacheImplementation.__members__:
+            raise ValueError(
+                f"Unrecognized cache implementation - you passed {cache_implementation}. Supported cache implementations are"
+                f" {CacheImplementation.__members__}"
+            )
+
+        if not self._supports_cache_class:
+            raise ValueError(
+                "You cannot currently update the cache implementation for this model. Please raise a feature request on GitHub for adding"
+                " improved cache support for this model: https://github.com/huggingface/transformers"
+            )
+
+        if cache_implementation == CacheImplementation.SINK:
+            if "window_length" not in kwargs or "num_sink_tokens" not in kwargs:
+                raise ValueError(
+                    "You requested to use the Sink cache implementation, but you did not pass `window_length` and `num_sink_tokens` to"
+                    " `set_cache_implementation`. Try again, passing these arguments to the method, e.g."
+                    " `model.set_cache_implementation('sink', window_length=508, num_sink_tokens=4)`."
+                )
+            self.generation_config.sink_window_length = kwargs.get("window_length")
+            self.generation_config.num_sink_tokens = kwargs.get("num_sink_tokens")
+
+        self._reset_cache()
+        self.generation_config.cache_implementation = cache_implementation.lower()
+
     def _prepare_model_inputs(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -1194,6 +1239,24 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
                 UserWarning,
             )
 
+    def _sanitize_kwargs(self, generation_config: GenerationConfig, kwargs: dict):
+        """
+        Sanitize the kwargs for generation, e.g. by building the cache requested via `set_cache_implementation`.
+        """
+        if getattr(generation_config, "cache_implementation", "") == CacheImplementation.SINK:
+            if "past_key_values" not in kwargs:
+                kwargs["past_key_values"] = SinkCache(
+                    window_length=generation_config.sink_window_length,
+                    num_sink_tokens=generation_config.num_sink_tokens,
+                )
+            else:
+                warnings.warn(
+                    "You have already called `model.set_cache_implementation('sink')` and you are also passing a"
+                    " `SinkCache` to `generate`. The cache passed to `generate` takes precedence, and the arguments"
+                    " you passed to `set_cache_implementation` will be ignored."
+                )
+        return kwargs
+
     @torch.no_grad()
     def generate(
         self,
@@ -1329,6 +1392,7 @@ def generate(
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
         generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
+        model_kwargs = self._sanitize_kwargs(generation_config, model_kwargs)
 
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
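
Roughly, the user-facing flow enabled by the hunks above looks as follows (a minimal sketch; the checkpoint id simply
mirrors the tiny Llama checkpoint used in the new tests below, and any model whose `_supports_cache_class` flag is set
should behave the same way):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    # "static" needs no extra kwargs; the pre-allocated cache is set up on the next `generate` call
    model.set_cache_implementation("static")

    # "sink" requires `window_length` and `num_sink_tokens`, otherwise a ValueError is raised
    model.set_cache_implementation("sink", window_length=508, num_sink_tokens=4)

    # switching implementations calls `_reset_cache`, dropping any previously set up cache
    model.set_cache_implementation("dynamic")
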
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index 5f3af2acf572..bf4b6bdb02c2 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -168,6 +168,55 @@ def _random_kvs(config):
 @require_torch_gpu
 @slow
 class CacheIntegrationTest(unittest.TestCase):
+    def _check_static_cache_correct(self, model):
+        for m in model.modules():
+            if hasattr(m, "past_key_value") and isinstance(m.past_key_value, StaticCache):
+                return True
+        return False
+
+    def test_set_cache_api_errors(self):
+        model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+        with self.assertRaises(ValueError):
+            # Sink cache requires some specific args
+            model.set_cache_implementation("sink")
+
+        with self.assertRaises(ValueError):
+            # Dummy name
+            model.set_cache_implementation("dummy-name")
+
+    def test_set_cache_api(self):
+        model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+        # First, set it to static.
+        model.set_cache_implementation("static")
+
+        inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
+        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
+
+        # At this point the static cache should be correctly set
+        self.assertTrue(self._check_static_cache_correct(model))
+
+        _ = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+
+        model.set_cache_implementation("dynamic")
+        # This should return False, as the cache should have been reset at this point.
+        self.assertFalse(self._check_static_cache_correct(model))
+
+        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
+        _ = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+
+        model.set_cache_implementation("sink", window_length=128, num_sink_tokens=4)
+        # This should return False, as the cache should have been reset at this point.
+        self.assertFalse(self._check_static_cache_correct(model))
+
+        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
+        _ = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+
     def test_dynamic_cache_hard(self):
         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
         model = AutoModelForCausalLM.from_pretrained(
@@ -289,6 +338,44 @@ def test_sink_cache_iterative_prompts(self):
         )
         self.assertTrue(decoded[0].endswith(last_output))
 
+    def test_sink_cache_iterative_prompts_new_api(self):
+        """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        model = AutoModelForCausalLM.from_pretrained(
+            "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
+        )
+        prompt = (
+            "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
+            "and must-see attractions."
+        )
+
+        # Prepare generation settings
+        model.set_cache_implementation("sink", window_length=256, num_sink_tokens=4)
+
+        input_ids = torch.tensor([], device=model.device, dtype=torch.int)
+        for _ in range(3):
+            # Tokenize the prompt with the correct chat template
+            chat = [{"role": "user", "content": prompt}]
+            tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
+                model.device
+            )
+            input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
+
+            # Perform the generation
+            gen_out = model.generate(input_ids, do_sample=False, max_new_tokens=100, use_cache=True)
+            input_ids = gen_out
+
+        # And it still produces coherent English
+        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+        last_output = (
+            "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
+            "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
+            "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
+            "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
Here," + ) + self.assertTrue(decoded[0].endswith(last_output)) + @require_torch_gpu @parameterized.expand(["eager", "sdpa", "flash_attention_2"]) def test_static_cache_greedy_sampling_pad_left(self, attn_implementation): @@ -316,7 +403,7 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation): self.assertListEqual(decoded, EXPECTED_GENERATION) set_seed(0) - model.generation_config.cache_implementation = "static" + model.set_cache_implementation("static") gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) with self.subTest(f"{attn_implementation}, static, eager"): @@ -356,7 +443,7 @@ def test_static_cache_greedy_sampling_pad_right(self, attn_implementation): self.assertListEqual(decoded, EXPECTED_GENERATION) set_seed(0) - model.generation_config.cache_implementation = "static" + model.set_cache_implementation("static") gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) with self.subTest(f"{attn_implementation}, static, eager"):