-
Notifications
You must be signed in to change notification settings - Fork 33.1k
FEAT [Generation]: Introduce a centralized API to switch between cache implementations
#29030
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1430ca6
b641975
a7df058
989bd9b
16669aa
ed77747
84c7a53
2810ffa
efdcefc
3df9253
387a9e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,13 +18,14 @@ | |
| import inspect | ||
| import warnings | ||
| from dataclasses import dataclass | ||
| from enum import Enum | ||
| from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from torch import nn | ||
|
|
||
| from ..cache_utils import Cache, DynamicCache, StaticCache | ||
| from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache | ||
| from ..integrations.deepspeed import is_deepspeed_zero3_enabled | ||
| from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput | ||
| from ..models.auto import ( | ||
|
|
@@ -92,6 +93,13 @@ | |
| if is_accelerate_available(): | ||
| from accelerate.hooks import AlignDevicesHook, add_hook_to_module | ||
|
|
||
|
|
||
| class CacheImplementation(str, Enum): | ||
| DYNAMIC = "dynamic" | ||
| STATIC = "static" | ||
| SINK = "sink" | ||
|
|
||
|
|
||
| NEED_SETUP_CACHE_CLASSES_MAPPING = { | ||
| "static": StaticCache, | ||
| } | ||
|
|
@@ -351,6 +359,43 @@ def prepare_inputs_for_generation(self, *args, **kwargs): | |
| "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`." | ||
| ) | ||
|
|
||
| def set_cache_implementation(self, cache_implementation: Union[CacheImplementation, str], **kwargs): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rather have this function as |
||
| """ | ||
| Simple API to set cache implementation in a model. Users could also do | ||
| `model.config.generation_config.cache_implementation = xxx` but they will most likely unexpected | ||
| behaviour | ||
|
|
||
| Args: | ||
| cache_implementation (`Union[CacheImplementation, str]`): | ||
| The target cache implementation, currently supported cache implementations are `["sink", "static", "dynamic"]`. | ||
| kwargs (`dict`, *optional*): | ||
| Optional key word arguments to be passed. E.g. for "sink", it is required to | ||
| pass `window_length` and `num_sink_tokens`. | ||
| """ | ||
| if cache_implementation.upper() not in CacheImplementation.__members__: | ||
| raise ValueError( | ||
| f"Unrecognized cache implementation - you passed {cache_implementation}. Supported cache implementations are" | ||
| f" {CacheImplementation.__members__}" | ||
| ) | ||
|
|
||
| if not self._supports_cache_class: | ||
| raise ValueError( | ||
| "You cannot currently update the cache implementation for this model. Please raise a feature request on GitHub for adding" | ||
| " improved cache support this model: https://github.com/huggingface/transformers" | ||
| ) | ||
|
|
||
| if cache_implementation == CacheImplementation.SINK: | ||
| if "window_length" not in kwargs or "num_sink_tokens" not in kwargs: | ||
| raise ValueError( | ||
| "You requested to use the Sink cache implementation, but you did not pass `window_length` and `num_sink_tokens` to " | ||
| "`set_cache_implementation`. Try again with passing these arguments to the method. (e.g. `model.set_cache_implementation('sink', window_length=window_length=508, num_sink_tokens=4)`" | ||
| ) | ||
| self.generation_config.sink_window_length = kwargs.get("window_length") | ||
| self.generation_config.num_sink_tokens = kwargs.get("num_sink_tokens") | ||
|
|
||
| self._reset_cache() | ||
| self.generation_config.cache_implementation = cache_implementation.lower() | ||
|
|
||
| def _prepare_model_inputs( | ||
| self, | ||
| inputs: Optional[torch.Tensor] = None, | ||
|
|
@@ -1194,6 +1239,24 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de | |
| UserWarning, | ||
| ) | ||
|
|
||
| def _sanitize_kwargs(self, generation_config: GenerationConfig, kwargs: dict): | ||
| """ | ||
| Sanitize the kwargs for generation | ||
| """ | ||
| if getattr(generation_config, "cache_implementation", "") == CacheImplementation.SINK: | ||
| if "past_key_values" not in kwargs: | ||
| kwargs["past_key_values"] = SinkCache( | ||
| window_length=generation_config.sink_window_length, | ||
| num_sink_tokens=generation_config.num_sink_tokens, | ||
| ) | ||
|
Comment on lines
+1246
to
+1251
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not pass the class and do the same init scheme with generation config passed and you take these from the generation config to allow custom classes? |
||
| else: | ||
| warnings.warn( | ||
| "You have already called `model.set_cache_implementation('sink')` and you are passing a `SinkCache()`" | ||
| " to `generate`. We assume you know what you are doing and we'll silently ignore the `SinkCache` arguments you" | ||
| " passed into `set_cache_implementation`." | ||
| ) | ||
| return kwargs | ||
|
|
||
| @torch.no_grad() | ||
| def generate( | ||
| self, | ||
|
|
@@ -1329,6 +1392,7 @@ def generate( | |
| model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs | ||
| generation_config.validate() | ||
| self._validate_model_kwargs(model_kwargs.copy()) | ||
| kwargs = self._sanitize_kwargs(generation_config, kwargs) | ||
|
|
||
| # 2. Set generation parameters if not already defined | ||
| logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this prevents anyone from adding / using a custom implementation why not just use strings?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think in general it is a good practice to use enums to avoid silent behaviours / errors.
E.g. if one passes
The code will silently work out of the box as it will use dynamic cache, and can potentitally lead to silent errors.
When using enums,
Will also work out of the box but this time will correctly use the static cache implementation and not dynamic cache as opposed to the snippet above.
We can also use a mapping with hardcoded strings but I found it clearer to have enums
cc @amyeroberts @gante @tomaarsen wdyt?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good practice does not mean it should always be used.
Here it's:
"string_cls":clsanyways.TLDR; why would we use then when we don't need it and it only adds additional calls to
CacheImplementation.DYNAMICvs"dynamic"is something I don't understand.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ArthurZucker I don't understand your objections here. For one, best practices does mean it should be done wherever possible.
I don't believe using enums is useless or cumbersome (aside from these not being well defined)? As @younesbelkada highlights, they provide better guarantees of selection of valid types. The explicit enum means the user can pass a string, but the code uses stricter checks that string matching, which are error prone.
Enums are restrictive? In fact, they're far more restrictive than doing string checks and have stronger guarantee's than checking a mapping which is mutable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But that is exactly what we want: the mapping should be mutable, because we want it to be easily adapted for a custom code on the hub.
Down for stricter checks, but in this specific case IMO it is cumbersome and useless as we don't have the notion of "safety" and we basically use mapping for model_type, config_type etc etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Anyway it's not that important
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, OK, I understand better now. I'm not sure how we can both let users modify a dict and have guarantees about not working if users can pass anything. If I've understood correctly, is the check you're wanting on the membership within the dictionary rather than e.g. how it's handled within the setting logic? In this case, I agree dictionaries are probably the simplest solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes exactly 😉