
Fix: StaticCache & inputs_embeds #32932

Merged
zucchini-nlp merged 4 commits into huggingface:main from zucchini-nlp:embeds-with-static-cache on Sep 6, 2024

Conversation

@zucchini-nlp (Member) commented Aug 22, 2024

What does this PR do?

Fixes #32911. Enables generation with StaticCache and inputs_embeds; previously it failed due to an incorrect calculation of max_cache_length.

Added a test for that, plus tests for Gemma2ForCausalLM. Some things to note:

  • Gemma2 doesn't support StaticCache. It could with some small changes, but imo we shouldn't add them
  • Static-shape cache classes have no support for contrastive search, DoLa, low-memory generation, or assisted decoding, so these tests are all skipped in Gemma2. If we want to enable them, I think it should go in another PR that upgrades the static cache classes
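The bug class being fixed can be sketched in plain Python (this is an illustrative model, not the actual transformers code): the cache length must be derived from the prompt length, and with an embeddings-only prompt that length lives in inputs_embeds, not input_ids.

```python
# Illustrative sketch (not the actual transformers implementation): before the
# fix, the prompt length was effectively taken from input_ids, which is None
# when only embeddings are passed, breaking the max_cache_length calculation.

def compute_max_cache_length(max_new_tokens, input_ids=None, inputs_embeds=None):
    """Total cache length = prompt length + tokens still to be generated."""
    if inputs_embeds is not None:
        # inputs_embeds has shape (batch, seq_len, hidden);
        # the sequence dimension gives the prompt length
        prompt_length = len(inputs_embeds[0])
    elif input_ids is not None:
        prompt_length = len(input_ids[0])
    else:
        raise ValueError("Provide input_ids or inputs_embeds")
    return prompt_length + max_new_tokens

# Embeddings-only prompt: batch of 1, 5 positions, hidden size 4
embeds = [[[0.0] * 4 for _ in range(5)]]
print(compute_max_cache_length(20, inputs_embeds=embeds))  # 25
```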

@HuggingFaceDocBuilderDev commented
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment

Thanks, two nits but good otherwise. Do we take the max of num_beams and num_return_sequences because they stem from beams?
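For context on the question above, a hedged sketch of the sizing logic (names illustrative, not the exact transformers code): when sampling, the batch is expanded num_return_sequences times; with beam search it is expanded num_beams times, and there num_return_sequences cannot exceed num_beams. Taking the max of the two therefore bounds the expanded batch in either mode.

```python
# Hypothetical sketch: size a static cache's batch dimension so it covers both
# sampling (expansion by num_return_sequences) and beam search (expansion by
# num_beams, where num_return_sequences <= num_beams).

def static_cache_batch_size(batch_size, num_beams=1, num_return_sequences=1):
    return batch_size * max(num_beams, num_return_sequences)

print(static_cache_batch_size(2, num_beams=4, num_return_sequences=2))  # beam search: 8
print(static_cache_batch_size(2, num_beams=1, num_return_sequences=3))  # sampling: 6
```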

Comment thread src/transformers/generation/utils.py Outdated
Comment thread src/transformers/generation/utils.py Outdated
@gante (Contributor) left a comment

Thank you for taking care of gemma 2 🤗

Comment thread src/transformers/generation/utils.py Outdated
Comment thread src/transformers/generation/utils.py Outdated
Comment thread tests/generation/test_utils.py Outdated
```diff
         else ()
     )
-    all_generative_model_classes = ()
+    all_generative_model_classes = (Gemma2ForCausalLM,) if is_torch_available() else ()
```
Contributor:
😱 good spot!

Collaborator:
This was removed because it was failing too many tests

Member Author:
Yes, I skipped the tests that shouldn't be triggered due to the model-specific cache and fixed the other failing ones

```python
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    def _check_attentions_for_generate(
```
Contributor:
Let's add the reason for the overwrite at the top of the fn as a comment, here and on the other functions that need an overwrite! That way, we immediately know why the function needs to exist :)

(I see that you added a few comments below, like HybridCache has fixed length for key/values, moving it to the top suffices)
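One way to realize the suggestion above (class name and wording are hypothetical, not taken from the PR) is to carry the reason in a unittest.skip decorator at the top of the overridden test, so it is visible both in the source and in the test runner's output:

```python
import unittest

# Illustrative pattern for recording the overwrite reason at the top of the
# function; the class and the reason string here are hypothetical examples.
class Gemma2GenerationTest(unittest.TestCase):
    @unittest.skip("Gemma2 uses a model-specific HybridCache, so the StaticCache path does not apply")
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

result = unittest.TestResult()
unittest.defaultTestLoader.loadTestsFromTestCase(Gemma2GenerationTest).run(result)
print(len(result.skipped))  # 1
```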

@zucchini-nlp force-pushed the embeds-with-static-cache branch from 4dd1494 to fce9e7e on August 30, 2024 17:33
@gante (Contributor) left a comment

Thank you for iterating 💛

@zzxslp commented Sep 6, 2024

Hi, I ran into similar errors as in #32911. Will this PR get merged?

@zucchini-nlp (Member Author):

Yes, merging now; it should be ready

@zucchini-nlp zucchini-nlp merged commit 1759bb9 into huggingface:main Sep 6, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024


Successfully merging this pull request may close these issues: Error in _prepare_generated_length

5 participants