Fix: StaticCache & inputs_embeds #32932
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
Thanks, two nits but good otherwise. Do we take the max of num_beams and num_return_sequences because they stem from beams?
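A minimal sketch of the expansion logic the question refers to (illustrative only; the helper name is hypothetical and this is not the actual transformers code): beam search expands the batch by num_beams while sampling expands it by num_return_sequences, so taking the max covers both decoding modes.

```python
# Hypothetical illustration, not the transformers implementation: the cache's
# batch dimension must cover the expanded batch, and the expansion factor
# depends on the decoding mode, so max() is a safe upper bound for both.
def expanded_batch_size(batch_size: int, num_beams: int, num_return_sequences: int) -> int:
    # beam search: batch_size * num_beams rows are tracked at once;
    # sampling:    batch_size * num_return_sequences rows are tracked at once
    return batch_size * max(num_beams, num_return_sequences)
```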
gante
left a comment
Thank you for taking care of gemma 2 🤗
        else ()
    )
-    all_generative_model_classes = ()
+    all_generative_model_classes = (Gemma2ForCausalLM,) if is_torch_available() else ()
This was removed because it was failing too many tests
Yes, I skipped those that shouldn't be triggered due to the model-specific cache and fixed the other failing ones
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    def _check_attentions_for_generate(
Let's add the reason for the overwrite at the top of the fn as a comment, here and on the other functions that need an overwrite! That way, we immediately know why the function needs to exist :)
(I see that you added a few comments below, like HybridCache has a fixed length for key/values; moving it to the top suffices)
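A minimal sketch of the suggested pattern (hypothetical signature; the reason text is taken from the comments in this thread, not from the actual test code):

```python
# Hypothetical sketch of the review suggestion, not the actual test override.
def _check_attentions_for_generate(self, *args, **kwargs):
    # Overwritten because Gemma2 uses HybridCache, which has a fixed length
    # for key/values, so the default attention-shape checks don't apply.
    ...
```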
force-pushed from 4dd1494 to fce9e7e
Hi, I ran into similar errors as in #32911, will this PR get merged?
Yes, merging now, should be ready
What does this PR do?
Fixes #32911. Enables generation with StaticCache and inputs_embeds; previously it was failing due to an incorrect calculation of max_cache_length. Added a test for that, and added tests for Gemma2ForCausalLM. Some things to note:
- Gemma2 currently cannot be used with StaticCache. It can with some small changes but imo we shouldn't
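As a hedged illustration of the bug described above (the helper and its arguments are assumptions for this sketch, not the actual transformers code): when only inputs_embeds is passed, the prompt length has to be read from the embeddings tensor, otherwise the static cache is sized too small.

```python
import torch

# Hypothetical sketch of the fix, not the actual transformers code: derive
# the prompt length from inputs_embeds when input_ids is absent, so the
# cache length covers the prompt plus the newly generated tokens.
def compute_max_cache_length(max_new_tokens, input_ids=None, inputs_embeds=None):
    if inputs_embeds is not None:
        input_len = inputs_embeds.shape[1]  # (batch, seq_len, hidden) -> seq_len
    else:
        input_len = input_ids.shape[1]      # (batch, seq_len) -> seq_len
    return input_len + max_new_tokens

# Example: prompt provided as embeddings only
emb = torch.randn(1, 7, 16)
assert compute_max_cache_length(max_new_tokens=5, inputs_embeds=emb) == 12
```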