FEAT [Generation]: Introduce a centralized API to switch between cache implementations #29030

younesbelkada wants to merge 11 commits into huggingface:main

Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
| "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`." | ||
| ) | ||
|
|
||
| def switch_cache_implementation(self, cache_implementation: Union[CacheImplementation, str], **kwargs): |
In practice we are obviously always switching from one cache implementation to another, but for users it's likely more intuitive to simply set a new cache implementation, so `set_cache_implementation` might be a better name for the method.
I agree! Thanks for pointing that out!
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
tomaarsen left a comment
Some more small nits regarding the phrasing now that we're using set_...
| "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`." | ||
| ) | ||
|
|
||
| def set_cache_implementation(self, cache_implementation: Union[CacheImplementation, str], **kwargs): |
I'd rather have this function as set_cache, with past_key_values as an optional parameter. But this is a tougher discussion: I've opened the discussion on our internal slack here 🤗
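For context, a minimal sketch of what such a `set_cache` variant could look like; the method body and the `_default_cache` attribute are hypothetical, only the idea of accepting an optional `past_key_values` comes from the comment above:

```python
from typing import Optional, Union

def set_cache(
    self,
    cache_implementation: Optional[Union["CacheImplementation", str]] = None,
    past_key_values: Optional["Cache"] = None,
    **kwargs,
):
    if past_key_values is not None:
        # A ready-made cache object takes precedence and is stored for later generate() calls.
        self._default_cache = past_key_values
    elif cache_implementation is not None:
        # Otherwise fall back to the name-based selection introduced in this PR.
        self.set_cache_implementation(cache_implementation, **kwargs)
    else:
        raise ValueError("Provide either `cache_implementation` or `past_key_values`.")
```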
amyeroberts left a comment
Looks great, thanks for adding this!
Agreed `set_` is preferable to `switch_`.
Only request is for tests!
```python
self.generation_config.sink_window_length = kwargs.get("window_length")
self.generation_config.num_sink_tokens = kwargs.get("num_sink_tokens")
```
Just for my own understanding: if the currently set cache implementation is sink, with e.g. `window_length=508, num_sink_tokens=4`, should I be allowed to call `set_cache_implementation` with updated parameters, i.e. `config.set_cache_implementation(CacheImplementation.SINK, window_length=1_016, num_sink_tokens=8)`?
Yes, I think so!
Per my understanding, the way the API is currently designed for the sink cache is to pass a `SinkCache()` object to `generate()` through `past_key_values`. The way I designed things in this PR, the cache from `set_cache_implementation` is only used when `past_key_values` is not passed to `generate()`.
TL;DR: if one calls `model.set_cache_implementation(CacheImplementation.SINK)`, they shouldn't also call `model.generate(xxx, past_key_values=SinkCache())`. Maybe we should raise a warning saying that they already called `set_cache_implementation()`, wdyt?
Commit is here: 2810ffa
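To make that precedence concrete, here is a minimal sketch of the behaviour being described (not the actual `generate()` code; the helper name and the warning are assumptions):

```python
import warnings

from transformers import SinkCache


def resolve_cache(generation_config, kwargs):
    # An explicit cache passed to generate() always wins over set_cache_implementation().
    if "past_key_values" in kwargs:
        if getattr(generation_config, "cache_implementation", None) is not None:
            warnings.warn(
                "`past_key_values` was passed to `generate()` although a cache implementation "
                "was already set via `set_cache_implementation()`; using the explicit cache."
            )
        return kwargs["past_key_values"]
    # Otherwise build the cache selected through set_cache_implementation().
    if getattr(generation_config, "cache_implementation", None) == "sink":
        return SinkCache(
            window_length=generation_config.sink_window_length,
            num_sink_tokens=generation_config.num_sink_tokens,
        )
    return None  # default dynamic cache
```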
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
ArthurZucker left a comment
Down for that! Let's not add enums where we don't need them; let's make it easy to use custom cache classes and init them with a generation config, no?
```python
class CacheImplementation(str, Enum):
    DYNAMIC = "dynamic"
    STATIC = "static"
    SINK = "sink"
```
This prevents anyone from adding or using a custom implementation; why not just use strings?
I think in general it is good practice to use enums to avoid silent behaviours / errors.
E.g. if one passes

```python
model.generation_config.cache_implementation = "Static"
```

the code will silently work out of the box, as it will use the dynamic cache, and can potentially lead to silent errors.
When using enums,

```python
model.set_cache_implementation("Static")
```

will also work out of the box, but this time it will correctly use the static cache implementation and not the dynamic cache, as opposed to the snippet above.
We can also use a mapping with hardcoded strings, but I found it clearer to have enums.
cc @amyeroberts @gante @tomaarsen wdyt?
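A small sketch of the difference being argued, assuming `set_cache_implementation` lower-cases its input before converting it to the enum (that normalization step is an assumption, not shown in the PR):

```python
from enum import Enum


class CacheImplementation(str, Enum):
    DYNAMIC = "dynamic"
    STATIC = "static"
    SINK = "sink"


# Plain string comparison: a casing typo silently falls through to the default branch.
impl = "Static"
if impl == "static":
    print("static cache")  # never reached
else:
    print("dynamic cache used silently")

# Enum conversion: the same typo is either normalized or rejected loudly.
print(CacheImplementation("Static".lower()))  # CacheImplementation.STATIC
try:
    CacheImplementation("Static")
except ValueError as err:
    print(err)  # 'Static' is not a valid CacheImplementation
```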
Good practice does not mean it should always be used. Here it's:

- useless: we need a mapping with keys from `"string_cls": cls` anyways.
- cumbersome: anywhere you compare `generation_config.cache_cls` (a string) you need to use the enum. Why?
- not restrictive, and let me be clear here: I do NOT want users to pass anything and expect it to work. If the class is not in the mapping, we error out and we are done with it. We should raise an error.

TL;DR: why we would use enums when we don't need them, and they only add extra calls (`CacheImplementation.DYNAMIC` vs `"dynamic"`), is something I don't understand.
@ArthurZucker I don't understand your objections here. For one, best practice does mean it should be done wherever possible.
I don't believe using enums is useless or cumbersome (aside from these not being well defined)? As @younesbelkada highlights, they provide better guarantees that valid types are selected. With an explicit enum the user can still pass a string, but the code uses stricter checks than string matching, which is error prone.

> not restrictive, and let me be clear here: I do NOT want users to pass anything and expect it to work. If the class is not in the mapping, we error out and we are done with it. We should raise an error.

Enums are restrictive? In fact, they're far more restrictive than doing string checks and have stronger guarantees than checking a mapping, which is mutable.
But that is exactly what we want: the mapping should be mutable, because we want it to be easily adapted for custom code on the Hub.
Down for stricter checks, but in this specific case IMO it is cumbersome and useless, as we don't have a notion of "safety" here and we basically use mappings for `model_type`, `config_type`, etc.
Anyway it's not that important
Ah, OK, I understand better now. I'm not sure how we can both let users modify a dict and have guarantees about not working if users can pass anything. If I've understood correctly, is the check you're wanting on the membership within the dictionary rather than e.g. how it's handled within the setting logic? In this case, I agree dictionaries are probably the simplest solution.
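For illustration, a sketch of the mapping-based alternative discussed here: a mutable string-to-class registry with a membership check that errors out on unknown names. The registry and helper names are assumptions, not the PR's code, and `DynamicCache`, `StaticCache`, and `SinkCache` are assumed to be importable from `transformers`:

```python
from transformers import DynamicCache, SinkCache, StaticCache

# Mutable registry: custom code (e.g. on the Hub) can register its own cache class.
CACHE_CLASSES = {
    "dynamic": DynamicCache,
    "static": StaticCache,
    "sink": SinkCache,
}


def get_cache_class(name: str):
    # Membership check: unknown names fail loudly instead of silently falling back.
    if name not in CACHE_CLASSES:
        raise ValueError(
            f"Unknown cache implementation {name!r}; expected one of {sorted(CACHE_CLASSES)} "
            "or a class registered in CACHE_CLASSES."
        )
    return CACHE_CLASSES[name]
```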
```python
if getattr(generation_config, "cache_implementation", "") == CacheImplementation.SINK:
    if "past_key_values" not in kwargs:
        kwargs["past_key_values"] = SinkCache(
            window_length=generation_config.sink_window_length,
            num_sink_tokens=generation_config.num_sink_tokens,
        )
```
Why not pass the class and use the same init scheme, with the generation config passed in, so these values are taken from the generation config and custom classes are allowed?
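A sketch of that suggestion: each cache implementation (including custom ones) builds itself from the generation config, so `generate()` only needs one generic lookup. The factory registry and attribute names below are assumptions; the same idea could equally be a `from_generation_config` classmethod on each cache class:

```python
from transformers import SinkCache


def sink_cache_from_generation_config(generation_config):
    # Each implementation knows which generation_config fields it needs.
    return SinkCache(
        window_length=generation_config.sink_window_length,
        num_sink_tokens=generation_config.num_sink_tokens,
    )


# String -> factory registry; custom classes can add their own entry.
CACHE_FACTORIES = {"sink": sink_cache_from_generation_config}


def maybe_build_cache(generation_config, kwargs):
    name = getattr(generation_config, "cache_implementation", None)
    if name is not None and "past_key_values" not in kwargs:
        kwargs["past_key_values"] = CACHE_FACTORIES[name](generation_config)
```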
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
I would like to introduce a new API before the release to centralize switching between cache implementations!
Right now, to use `SinkCache` one needs to build the cache object manually and pass it to `generate()`, and the static cache is selected by setting `generation_config.cache_implementation` by hand. With this PR:

For the sink cache:

```diff
from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")

- cache = SinkCache(window_length=508, num_sink_tokens=4)
+ model.set_cache_implementation("sink", sink_window_length=508, num_sink_tokens=4)

inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)

- gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=300, past_key_values=cache)
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=300)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
```

For the static cache:

```diff
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")

- model.generation_config.cache_implementation = "static"
+ model.set_cache_implementation("static")

inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=300)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
```

What do you think @gante @tomaarsen @ArthurZucker @amyeroberts?
If you are happy with the design and idea, I can move forward with adding tests and docs!