
[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) #35993

Merged
gante merged 26 commits into huggingface:main from gante:static_cache_checks on Feb 10, 2025

Conversation

@gante (Contributor, Author) commented Jan 31, 2025

What does this PR do?

Fixes a few test cases exposed by #33212 (e.g. some models were failing cache shape checks when their default cache is not a dynamic cache). The fix is in a separate PR to avoid making #33212 a huge PR :)

This PR:

  1. Updates generate output checks to also work with fixed-length caches;
  2. Tests with output_attentions=True now explicitly use eager attention, since sdpa doesn't return the attentions (eager was previously being used implicitly, with warnings; see the sketch after this list);
  3. (enabled by 1. and 2.) Adds a test covering generate + extra outputs + compile, which was not being tested;
  4. Standardizes the previously inconsistent way of pulling the cache from the model outputs in the generation loop;
  5. (enabled by 1.) Deletes unnecessary test overwrites;
  6. [EDIT, added after PR reviews] Updates variable names used in common tests, so we can easily understand what's going on.
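
A minimal sketch of item 2 (the checkpoint and generation arguments below are illustrative, not the PR's test code):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# sdpa kernels don't materialize attention weights, so request eager explicitly
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=3,
    output_attentions=True,
    return_dict_in_generate=True,
)
assert outputs.attentions is not None  # one entry per generated token, each a per-layer tuple
```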

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante changed the title from "[generate] shape checks in tests compatible with static cache" to "[generate] shape checks in tests compatible with fixed-length caches" on Feb 3, 2025
@gante gante changed the title from "[generate] shape checks in tests compatible with fixed-length caches" to "[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes)" on Feb 3, 2025
```diff
  # - different models have a different cache name expected by the model (default = "past_key_values")
  # - `max_length`, prepared above, is used to determine the maximum cache length
- max_cache_length = generation_config.max_length
+ max_cache_length = generation_config.max_length - 1
```
@gante (Contributor, Author):

We were creating caches that were larger than what was set by our max length flags: the last token is not present in the cache.
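
A hypothetical walk-through of that off-by-one (not the PR's code; the helper below merely simulates how many cache slots decoding actually writes):

```python
def required_cache_slots(prompt_length: int, max_length: int) -> int:
    """Simulate greedy decoding and count the cache entries actually written."""
    cache_entries = 0
    next_input = prompt_length   # the first forward pass consumes the whole prompt
    total_tokens = prompt_length
    while total_tokens < max_length:
        cache_entries += next_input  # each pass writes its input tokens to the cache
        total_tokens += 1            # ...and emits one new token
        next_input = 1
    return cache_entries

# the final token is returned without another forward pass, so it never enters the cache
assert required_cache_slots(prompt_length=2, max_length=5) == 5 - 1
```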

@gante gante requested review from ydshieh and zucchini-nlp February 5, 2025 19:08
@gante gante marked this pull request as ready for review February 5, 2025 19:20
Comment thread src/transformers/models/gpt_neox/modeling_gpt_neox.py Outdated
Comment thread src/transformers/generation/utils.py Outdated
@ydshieh (Collaborator) commented Feb 6, 2025

Before diving into more details, I think it's best to move the fix-copies changes into a separate PR:

#36063

We can revert the changes to those files (in this PR) once the above PR is merged.

Comment thread tests/generation/test_utils.py
Comment thread tests/generation/test_utils.py Outdated
Comment thread tests/generation/test_utils.py
Comment thread tests/generation/test_utils.py Outdated
Comment thread tests/generation/test_utils.py Outdated
@gante (Contributor, Author) commented Feb 6, 2025

@ydshieh should be ready for a re-review.

Two notes:

  1. _check_outputs and especially its inner functions should be much easier to follow. LMK if any part is unclear—documenting these general functions is important to future-proof our test base!
  2. As I rewrote the test functions, I noticed some overwrites were not necessary (or at least could be simplified) ✂️

@ydshieh (Collaborator) left a review comment:

Left 2 nit comments. I don't know the removal of _update_model_kwargs_for_generation in the 2 modeling files well enough, so I will leave that to another reviewer 🙏

Otherwise all good to me 💯 thank you

Comment thread tests/generation/test_utils.py Outdated
Comment thread tests/generation/test_utils.py
@zucchini-nlp (Member) left a review comment:

Thanks for cleaning up! Left one suggestion for static cache tests :)


```diff
  inputs_embeds = model.get_input_embeddings()(input_ids)
- max_cache_len += inputs_embeds.shape[1]
+ max_cache_len += inputs_embeds.shape[1] - 1  # the last generated token has no cache
```
@zucchini-nlp (Member):

Maybe not super related, but imo using max_new_tokens instead of max_cache_len is better: we can't know if the input length is already longer than the preset max_cache_len.

Already had it fixed somewhere, but it was a huge PR so it's lost hehe

@gante (Contributor, Author):

(leaving this one for a separate PR, moving conversation to slack)
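
A hypothetical illustration of the suggestion above (not what this PR merges): deriving the cache length from the actual input size plus max_new_tokens means a long prompt can never overflow the pre-allocated cache.

```python
import torch

input_ids = torch.randint(0, 100, (1, 8))  # any prompt length works
max_new_tokens = 10
# the last generated token is never written to the cache, hence the -1
max_cache_len = input_ids.shape[1] + max_new_tokens - 1
```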

Comment on lines +2415 to +2418:

```python
batch_size=internal_batch_size,
attentions=output.decoder_attentions,
prompt_length=1,  # the BOS token
output_length=output.sequences.shape[-1],
```
@zucchini-nlp (Member):

love the naming!

```python
generated_length = (
    output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
)
decoder_past_key_values = getattr(output, "past_key_values", None)
```
@zucchini-nlp (Member):

Is it possible that we have no attr past_key_values?

@gante (Contributor, Author):

When use_cache=False :P (or models with different cache names, like RWKV)

```python
self.assertEqual(len(attentions), (output_length - prompt_length))

use_cache = decoder_past_key_values is not None
has_static_cache = isinstance(decoder_past_key_values, (StaticCache, HybridCache))
```
@zucchini-nlp (Member):

Possible edge case: HybridCache has a non-uniform max length across layers, and the sliding-window layers may have lengths different from the non-sliding ones. Do you think we need to account for that?

@gante (Contributor, Author) commented Feb 10, 2025:

I'll add a note here, so we can easily know what to do when the test breaks.

(but for now will leave as is, to contain test complexity)
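
A hedged sketch of the branch this thread discusses, reusing the variable names from the snippet above; treating `max_cache_len` as the relevant StaticCache attribute is an assumption, and the HybridCache sliding-layer case is deliberately left unhandled, per the note above:

```python
# not part of the PR; illustrates how the expected cache length differs per cache type
if not use_cache:
    expected_cache_length = None  # nothing to check
elif has_static_cache:
    # fixed-length caches are pre-allocated, independent of how much was generated
    expected_cache_length = decoder_past_key_values.max_cache_len  # assumed attribute
else:
    # dynamic caches grow with decoding; the last generated token is never cached
    expected_cache_length = output_length - 1
```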

@gante gante merged commit be2ac09 into huggingface:main Feb 10, 2025
@gante gante deleted the static_cache_checks branch February 10, 2025 17:50
sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Feb 16, 2025
…(+ some minor fixes) (huggingface#35993)

* shape checks compatible with static cache

* add test

* tmp

* manually turn on eager attn when we want to output attn

* typo

* generalize to encoder-decoder models

* force compilation on cpu

* tmp commit

* fix static cache shape checks

* models with odd caches

* fix copies

* shorter cache search loop

* use decoder_past_key_values everywhere

* better test variable names and comments

* signature

* rename _check_outputs into _check_generate_outputs

* add comments

* HybridCache future test note
