66 changes: 65 additions & 1 deletion src/transformers/generation/utils.py
@@ -18,13 +18,14 @@
import inspect
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache, StaticCache
from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -92,6 +93,13 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module


class CacheImplementation(str, Enum):
DYNAMIC = "dynamic"
STATIC = "static"
SINK = "sink"
Comment on lines +97 to +100
Collaborator

this prevents anyone from adding / using a custom implementation why not just use strings?

Contributor Author

I think in general it is a good practice to use enums to avoid silent behaviours / errors.
E.g. if one passes

model.generation_config.cache_implementation = "Static"

The code will silently work out of the box, as it will use the dynamic cache, and can potentially lead to silent errors.

When using enums,

model.set_cache_implementation("Static")

Will also work out of the box, but this time it will correctly use the static cache implementation and not the dynamic cache, as opposed to the snippet above.

We can also use a mapping with hardcoded strings but I found it clearer to have enums

cc @amyeroberts @gante @tomaarsen wdyt?
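
A minimal sketch of the failure mode being described, using stand-ins for the real classes (the mapping value and the helper below are illustrative, not the actual implementation):

from enum import Enum


class CacheImplementation(str, Enum):
    DYNAMIC = "dynamic"
    STATIC = "static"
    SINK = "sink"


# Stand-in for NEED_SETUP_CACHE_CLASSES_MAPPING; the value is a placeholder, not the real class.
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": "StaticCache"}

# Writing the raw string onto the config: the casing mistake is never noticed, the lookup
# below misses, and generation silently falls back to the dynamic cache.
cache_implementation = "Static"
print(cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING)  # False


# Going through a setter that validates against the enum: the same input is either
# normalized to a known member or rejected with an explicit error.
def set_cache_implementation(value: str) -> CacheImplementation:
    try:
        return CacheImplementation(value.lower())
    except ValueError:
        raise ValueError(f"Unrecognized cache implementation: {value!r}") from None


print(set_cache_implementation("Static"))  # CacheImplementation.STATIC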

Collaborator

@ArthurZucker ArthurZucker Feb 19, 2024

good practice does not mean it should always be used.

Here it's:

  • useless: we need a mapping with keys from "string_cls":cls anyways.
  • cumbersome: anywhere you compare the generation_config.cache_cls (a string) you need to use the enum. Why?
  • not restrictive, and let me be clear here I do NOT want users to pass anything and expect it to work. If the class is not in the mapping, error out and we are done with it. We should raise an error

TL;DR: why we would use them when we don't need them, when it only adds extra calls like CacheImplementation.DYNAMIC vs "dynamic", is something I don't understand.

Contributor

@ArthurZucker I don't understand your objections here. For one, best practice does mean it should be followed wherever possible.

I don't believe using enums is useless or cumbersome (aside from these not being well defined)? As @younesbelkada highlights, they provide better guarantees when selecting valid types. The explicit enum means the user can pass a string, but the code uses stricter checks than string matching, which is error-prone.

not restrictive, and let me be clear here I do NOT want users to pass anything and expect it to work. If the class is not in the Mapping error out and we are done with it. We should raise and error

Enums are restrictive? In fact, they're far more restrictive than doing string checks and have stronger guarantees than checking a mapping, which is mutable.

Collaborator

But that is exactly what we want: the mapping should be mutable, because we want it to be easily adapted for custom code on the Hub.
Down for stricter checks, but in this specific case IMO it is cumbersome and useless, as we don't have the notion of "safety" here and we basically use mappings for model_type, config_type, etc.

Collaborator

Anyway it's not that important

Contributor

Ah, OK, I understand better now. I'm not sure how we can both let users modify a dict and still guarantee an error when they pass something arbitrary. If I've understood correctly, is the check you're wanting one on membership within the dictionary, rather than e.g. on how it's handled within the setting logic? In this case, I agree dictionaries are probably the simplest solution.
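
A rough sketch of that membership check, with placeholder names (CACHE_CLASSES, register_cache_class and resolve_cache_class are hypothetical, not part of this PR): the registry stays mutable, so custom cache classes can still be plugged in, while unknown strings fail loudly.

# CACHE_CLASSES and the helpers below are hypothetical names for illustration only.
CACHE_CLASSES = {"static": "StaticCache"}  # placeholder value standing in for the real class


def register_cache_class(name, cls):
    # Custom code (e.g. on the Hub) can extend the registry.
    CACHE_CLASSES[name] = cls


def resolve_cache_class(name):
    if name not in CACHE_CLASSES:
        raise ValueError(f"Unknown cache implementation {name!r}. Available: {sorted(CACHE_CLASSES)}")
    return CACHE_CLASSES[name]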

Collaborator

Yes exactly 😉



NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
}
@@ -351,6 +359,43 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
)

def set_cache_implementation(self, cache_implementation: Union[CacheImplementation, str], **kwargs):
Contributor

I'd rather have this function as set_cache, with past_key_values as an optional parameter. But this is a tougher discussion: I've opened the discussion on our internal slack here 🤗

"""
Simple API to set the cache implementation on a model. Users could also do
`model.generation_config.cache_implementation = xxx`, but they will most likely hit unexpected
behaviour

Args:
cache_implementation (`Union[CacheImplementation, str]`):
The target cache implementation, currently supported cache implementations are `["sink", "static", "dynamic"]`.
kwargs (`dict`, *optional*):
Optional keyword arguments to be passed. E.g. for "sink", it is required to
pass `window_length` and `num_sink_tokens`.
"""
if cache_implementation.upper() not in CacheImplementation.__members__:
raise ValueError(
f"Unrecognized cache implementation - you passed {cache_implementation}. Supported cache implementations are"
f" {CacheImplementation.__members__}"
)

if not self._supports_cache_class:
raise ValueError(
"You cannot currently update the cache implementation for this model. Please raise a feature request on GitHub for adding"
" improved cache support this model: https://github.com/huggingface/transformers"
)

if cache_implementation.lower() == CacheImplementation.SINK:
if "window_length" not in kwargs or "num_sink_tokens" not in kwargs:
raise ValueError(
"You requested to use the Sink cache implementation, but you did not pass `window_length` and `num_sink_tokens` to "
"`set_cache_implementation`. Try again with passing these arguments to the method. (e.g. `model.set_cache_implementation('sink', window_length=window_length=508, num_sink_tokens=4)`"
)
self.generation_config.sink_window_length = kwargs.get("window_length")
self.generation_config.num_sink_tokens = kwargs.get("num_sink_tokens")

self._reset_cache()
self.generation_config.cache_implementation = cache_implementation.lower()

def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
@@ -1194,6 +1239,24 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
UserWarning,
)

def _sanitize_kwargs(self, generation_config: GenerationConfig, kwargs: dict):
"""
Sanitize the kwargs for generation
"""
if getattr(generation_config, "cache_implementation", "") == CacheImplementation.SINK:
if "past_key_values" not in kwargs:
kwargs["past_key_values"] = SinkCache(
window_length=generation_config.sink_window_length,
num_sink_tokens=generation_config.num_sink_tokens,
)
Comment on lines +1246 to +1251
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not pass the class and do the same init scheme, with the generation config passed in, so you take these from the generation config and allow custom classes?

else:
warnings.warn(
"You have already called `model.set_cache_implementation('sink')` and you are passing a `SinkCache()`"
" to `generate`. We assume you know what you are doing and we'll silently ignore the `SinkCache` arguments you"
" passed into `set_cache_implementation`."
)
return kwargs

@torch.no_grad()
def generate(
self,
@@ -1329,6 +1392,7 @@ def generate(
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
kwargs = self._sanitize_kwargs(generation_config, kwargs)

# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
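
Taken together, the intended usage of the additions above would look roughly like this; the model id and generation arguments mirror the new tests below, and the comments describe the assumed flow rather than guaranteed behaviour.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt")

# Static cache: validated and recorded on generation_config, then picked up by generate().
model.set_cache_implementation("static")
_ = model.generate(**inputs, do_sample=False, max_new_tokens=10)

# Sink cache: the extra kwargs are stored on generation_config, and _sanitize_kwargs()
# builds the SinkCache and injects it as past_key_values when generate() is called.
model.set_cache_implementation("sink", window_length=256, num_sink_tokens=4)
_ = model.generate(**inputs, do_sample=False, max_new_tokens=10)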
91 changes: 89 additions & 2 deletions tests/test_cache_utils.py
@@ -168,6 +168,55 @@ def _random_kvs(config):
@require_torch_gpu
@slow
class CacheIntegrationTest(unittest.TestCase):
def _check_static_cache_correct(self, model):
for m in model.modules():
if hasattr(m, "past_key_value") and isinstance(m.past_key_value, StaticCache):
return True
return False

def test_set_cache_api_errors(self):
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

with self.assertRaises(ValueError):
# Sink cache requires some specific args
model.set_cache_implementation("sink")

with self.assertRaises(ValueError):
# Dummy name
model.set_cache_implementation("dummy-name")

def test_set_cache_api(self):
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# First, set it to static.
model.set_cache_implementation("static")

inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)

# At this point the static cache should be correctly set
self.assertTrue(self._check_static_cache_correct(model))

_ = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

model.set_cache_implementation("dynamic")
# This should return false as the cache should have been reset at this point.
self.assertFalse(self._check_static_cache_correct(model))

gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
_ = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

model.set_cache_implementation("sink", window_length=128, num_sink_tokens=4)
# This should return false as the cache should have been reset at this point.
self.assertFalse(self._check_static_cache_correct(model))

gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
_ = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

def test_dynamic_cache_hard(self):
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
@@ -289,6 +338,44 @@ def test_sink_cache_iterative_prompts(self):
)
self.assertTrue(decoded[0].endswith(last_output))

def test_sink_cache_iterative_prompts_new_api(self):
"""Tests that SinkCache supports more than one new token at once, when shifting the cache"""
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
)
prompt = (
"Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
"and must-see attractions."
)

# Prepare generation settings
model.set_cache_implementation("sink", window_length=256, num_sink_tokens=4)

input_ids = torch.tensor([], device=model.device, dtype=torch.int)
for _ in range(3):
# Tokenize the prompt with the correct chat template
chat = [{"role": "user", "content": prompt}]
tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
)
input_ids = torch.cat((input_ids, tokenized_chat), dim=1)

# Perform the generation
gen_out = model.generate(input_ids, do_sample=False, max_new_tokens=100, use_cache=True)
input_ids = gen_out

# And it still produces coherent English
decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
last_output = (
"<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
"Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
"beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
"and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
"was visiting the historic district of Honolulu. Here,"
)
self.assertTrue(decoded[0].endswith(last_output))

@require_torch_gpu
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
@@ -316,7 +403,7 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
self.assertListEqual(decoded, EXPECTED_GENERATION)

set_seed(0)
model.generation_config.cache_implementation = "static"
model.set_cache_implementation("static")
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
@@ -356,7 +443,7 @@ def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
self.assertListEqual(decoded, EXPECTED_GENERATION)

set_seed(0)
model.generation_config.cache_implementation = "static"
model.set_cache_implementation("static")
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):