[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) #35993
gante merged 26 commits into huggingface:main
Conversation
    # - different models have a different cache name expected by the model (default = "past_key_values")
    # - `max_length`, prepared above, is used to determine the maximum cache length
-   max_cache_length = generation_config.max_length
+   max_cache_length = generation_config.max_length - 1
We were creating caches larger than what our max length flags specify: the last generated token is never written to the cache.
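To make the off-by-one concrete, here is a toy decoding loop (a sketch, not the real `generate` code): the final sampled token is appended to the output but never fed back through the model, so it never gets a cache entry.

```python
# Toy sketch of why a cache with `max_length - 1` slots is enough.
def prefill(prompt, cache):
    cache.extend(prompt)           # one cache entry per prompt token
    return prompt[-1] + 1          # stand-in for sampling the first new token

def decode_step(token, cache):
    cache.append(token)            # the consumed token is written to the cache
    return token + 1               # stand-in for sampling the next token

prompt = [0, 1, 2]
max_length = 6
cache, sequence = [], list(prompt)

next_token = prefill(prompt, cache)
while len(sequence) < max_length:
    sequence.append(next_token)
    if len(sequence) < max_length:                  # no forward pass after the last token
        next_token = decode_step(next_token, cache)

assert len(sequence) == max_length
assert len(cache) == max_length - 1  # the last token never enters the cache
```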
Before diving into more details, I think it's best to split the fix-copies changes into a separate PR. We can revert the changes (in this PR) to those files once the above PR is merged.
@ydshieh should be ready for a re-review. Two notes:
ydshieh
left a comment
Left 2 nit comments. I don't know the removal of `_update_model_kwargs_for_generation` in the 2 modeling files well enough, so I will leave that part to another reviewer 🙏
Otherwise all good to me 💯 thank you
zucchini-nlp
left a comment
Thanks for cleaning up! Left one suggestion for static cache tests :)
    inputs_embeds = model.get_input_embeddings()(input_ids)
-   max_cache_len += inputs_embeds.shape[1]
+   max_cache_len += inputs_embeds.shape[1] - 1  # the last generated token has no cache
Maybe not super related, but imo using `max_new_tokens` instead of `max_cache_len` is better: we can't know whether the input length is already longer than the preset `max_cache_len`.
Already had it fixed somewhere, but it was a huge PR so it's lost hehe
(leaving this one for a separate PR, moving conversation to slack)
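A hypothetical sketch of what the suggested sizing could look like (the helper and its arguments are illustrative, not the test's actual code):

```python
# Size the cache from the prompt length plus `max_new_tokens`, instead of adding
# the prompt length to a preset `max_cache_len` the prompt may already exceed.
def required_cache_len(prompt_length: int, max_new_tokens: int) -> int:
    # One slot per prompt token plus one per generated token, minus the final
    # generated token, which is returned without ever being written to the cache.
    return prompt_length + max_new_tokens - 1

assert required_cache_len(prompt_length=8, max_new_tokens=4) == 11
```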
            batch_size=internal_batch_size,
            attentions=output.decoder_attentions,
            prompt_length=1,  # the BOS token
            output_length=output.sequences.shape[-1],
        generated_length = (
            output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
        )
        decoder_past_key_values = getattr(output, "past_key_values", None)
Is it possible that the output has no `past_key_values` attribute?
When `use_cache=False` :P (or models with different cache names, like RWKV)
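Putting the two pieces above together, a self-contained sketch of the bookkeeping (the helper name `generated_length` comes from the diff; the rest is illustrative):

```python
# Encoder-decoder: the decoder prompt is a single BOS token, so all but one
# output token were generated. Decoder-only: subtract the full prompt.
def generated_length(output_length: int, prompt_length: int, is_encoder_decoder: bool) -> int:
    return output_length - 1 if is_encoder_decoder else output_length - prompt_length

assert generated_length(output_length=10, prompt_length=1, is_encoder_decoder=True) == 9
assert generated_length(output_length=10, prompt_length=4, is_encoder_decoder=False) == 6

# The guard discussed above: `getattr` with a default covers `use_cache=False`
# and models whose cache lives under a different name (e.g. RWKV).
class _Output:  # hypothetical stand-in for a generate output without a cache
    pass

decoder_past_key_values = getattr(_Output(), "past_key_values", None)
assert decoder_past_key_values is None  # cache shape checks would be skipped
```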
        self.assertEqual(len(attentions), (output_length - prompt_length))
        use_cache = decoder_past_key_values is not None
        has_static_cache = isinstance(decoder_past_key_values, (StaticCache, HybridCache))
Possible edge case: `HybridCache` has a non-uniform max length across layers, and the sliding layers can have lengths that differ from the non-sliding ones. Do you think we need to account for that?
I'll add a note here, so we can easily know what to do when the test breaks.
(But for now I'll leave it as is, to contain test complexity.)
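For future reference, a hedged sketch of what accounting for this could look like. It assumes the cache exposes per-layer key tensors via `key_cache` with the sequence dimension at index 2; treat both as assumptions about cache internals, not a stable API:

```python
# Hypothetical sketch for the HybridCache edge case flagged above: sliding-window
# layers can have a smaller fixed length than global layers, so a single expected
# max length is not enough. Assumes `key_cache` holds per-layer key tensors with
# shape (batch, num_heads, max_len, head_dim); an internal detail, not a stable API.
def distinct_layer_lengths(cache) -> set[int]:
    return {k.shape[2] for k in cache.key_cache}

# A HybridCache may legitimately yield two values: the sliding-window size for
# sliding layers, and `max_cache_len` for the non-sliding ones.
# assert len(distinct_layer_lengths(decoder_past_key_values)) <= 2
```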
…(+ some minor fixes) (huggingface#35993)

* shape checks compatible with static cache
* add test
* tmp
* manually turn on eager attn when we want to output attn
* typo
* generalize to encoder-decoder models
* force compilation on cpu
* tmp commit
* fix static cache shape checks
* models with odd caches
* fix copies
* shorter cache search loop
* use decoder_past_key_values everywhere
* better test variable names and comments
* signature
* rename _check_outputs into _check_generate_outputs
* add comments
* HybridCache future test note
What does this PR do?
Fixes a few test cases exposed by #33212 [e.g. some models fail the cache shape checks if their default cache is not a dynamic cache]. The fix is added in a separate PR to avoid making #33212 a huge PR :)
This PR:
`output_attentions=True` now explicitly uses `eager` attention (`sdpa` doesn't return the attentions; `eager` was being used implicitly, with warnings);
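For illustration, a minimal sketch using the public transformers API (the model name and prompt are arbitrary): requesting attentions goes hand in hand with the `eager` attention implementation, since `sdpa` cannot return attention weights.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Arbitrary small model for illustration; any causal LM works the same way.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")

inputs = tok("Hello", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=4,
    output_attentions=True,        # attention weights require the eager path
    return_dict_in_generate=True,
)
print(len(out.attentions))  # one entry per generation step
```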