Cache: don't throw warnings on gemma2 when instantiating a new cache #33595
gante merged 3 commits into huggingface:main
    def get_seq_length(self, layer_idx: Optional[int] = 0):
-       return None
+       # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+       # limit the check to the first batch member and head dimension.
+       # TODO: deprecate this function in favor of `cache_position`
+       return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
HybridCache is a StaticCache with alternating sliding-window layers. The method to retrieve the cache length is copy/pasted from StaticCache.
We will want to use another method in the future, but let's leave this as a copy of StaticCache for now. This method is needed in the updated gemma 2 code.
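The length check can be sketched outside of transformers with numpy standing in for torch. The `[batch, num_heads, max_seq_len, head_dim]` buffer layout and the example values below are assumptions for illustration, not the actual transformers tensors:

```python
import numpy as np

# Hypothetical pre-allocated key cache for one layer:
# [batch, num_heads, max_seq_len, head_dim], zero-filled, written left to right.
key_cache = np.zeros((2, 4, 8, 16))
key_cache[:, :, :3, :] = 0.5  # pretend 3 tokens have been written

# Mirror of the snippet above: look only at batch member 0 and head 0,
# mark a position as occupied if any value along the head dim is non-zero,
# then count the occupied positions.
seq_length = int(key_cache[0, 0].any(axis=-1).sum())
print(seq_length)  # 3
```

The single-batch-member/single-head restriction is purely a compute saving: every batch member and head is written at the same sequence positions, so one slice is representative.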
        raise ValueError("When `past_key_values` is passed, `cache_position` must be too")

    # Probably a forward call with caching, so we set up cache for one call only
    if use_cache and past_key_values is None and not self.training:
Two changes here, both for consistency with other models:
- `self.training` should not control whether we instantiate a cache
- If a user respects the types in the docs, `past_key_values` is either a `Cache` or we instantiate a new one for the user without warnings
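A rough sketch of the pattern this comment describes, using stub classes (`Cache` and `HybridCache` here are stand-ins, not the real transformers classes, which pre-allocate tensors):

```python
class Cache:
    """Stand-in for transformers' `Cache` base class."""

class HybridCache(Cache):
    """Stand-in for gemma2's cache type."""

def prepare_cache(past_key_values, use_cache):
    # If the user passed nothing, instantiate a fresh cache silently --
    # no warning, and no dependence on self.training. A user-provided
    # Cache instance is used as-is.
    if use_cache and past_key_values is None:
        past_key_values = HybridCache()
    return past_key_values

cache = prepare_cache(None, use_cache=True)
print(isinstance(cache, Cache))  # True
```
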
            dtype=inputs_embeds.dtype,
        )

    if cache_position is None:
copy/paste from llama (and other Cache-supporting models)
Okay, this should always work since the seq length is taken at layer_idx=0. Just one question: isn't it a bit misleading if some layers will hold get_seq_length() tokens while others hold no more than the sliding-window length?
@zucchini-nlp yes, if get_seq_length gets called on the wrong layer we will have problems! I'm going to add an exception if it gets called on layer_idx != 0 (I doubt we need it).
Okay, sounds good. As long as the behavior of get_seq_length is transparent for users, it should reduce the number of cache-related questions we get 😄
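The `cache_position` fallback being discussed (copied from llama-style models) amounts to something like the following numpy sketch; the real code uses `torch.arange` on the right device, and the concrete numbers here are illustrative:

```python
import numpy as np

past_seen_tokens = 3  # what get_seq_length(layer_idx=0) would report
num_new_tokens = 2    # length of the current forward pass input

# Positions the new tokens occupy in the cache, built only when the
# caller did not provide cache_position explicitly.
cache_position = np.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)
print(cache_position.tolist())  # [3, 4]
```
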
    if use_cache and past_key_values is None and not self.training:
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    if use_cache and not isinstance(past_key_values, Cache):
copy/paste from llama (and other Cache-supporting models)
    def test_model_outputs_equivalence(self, **kwargs):
        pass

    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
Without re-applying this `parameterized.expand` decorator, the intended override of the parent test was not happening.
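Why the decorator must be re-applied in the subclass: `parameterized.expand` replaces the decorated method with suffixed variants (e.g. `test_model_0_float16`), so a plain un-decorated override defines a different name and never shadows the generated ones. A hand-rolled sketch of that mechanism (not the real `parameterized` internals; class and method names are made up):

```python
def expand_on_class(cls, name, params):
    # Hand-rolled stand-in for parameterized.expand: remove the original
    # method and add one suffixed method per parameter set.
    func = cls.__dict__[name]
    for i, (dtype,) in enumerate(params):
        setattr(cls, f"{name}_{i}_{dtype}",
                (lambda d: lambda self: func(self, d))(dtype))
    delattr(cls, name)

class ParentTests:
    def test_model(self, dtype):
        return f"parent-{dtype}"

expand_on_class(ParentTests, "test_model", [("float16",), ("bfloat16",)])

class ChildTests(ParentTests):
    # A plain, un-parameterized `test_model` here would merely coexist with
    # the inherited test_model_0_float16 / test_model_1_bfloat16, so the
    # parent's variants would still be collected and run. Re-expanding the
    # child's method shadows them with child versions.
    def test_model(self, dtype):
        return f"child-{dtype}"

expand_on_class(ChildTests, "test_model", [("float16",), ("bfloat16",)])
print(ChildTests().test_model_0_float16())  # child-float16
```
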
LysandreJik
left a comment
Thank you! Please merge once @zucchini-nlp has approved, as she knows this code better than I do.
cc @BenjaminBossan as well
zucchini-nlp
left a comment
LGTM, thanks for cleaning up the warnings! Left one question about HybridCache, since I was reluctant to add a seq-length method for that cache type, where lengths are not consistent across layers.
I'm not qualified to review this, but thanks for addressing this so quickly.
What does this PR do?
Related to #33541
The warning in question should only be thrown in the case we are converting from a legacy cache, which will be deprecated soon. Gemma 2 doesn't support the legacy cache format, so no warning should ever be thrown :)
In the process, updates a few related inconsistencies.
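The intended warning behavior can be sketched as follows, with a stub `Cache` class and a plain `warnings.warn` standing in for the real code path (which converts legacy tuples via `DynamicCache.from_legacy_cache`):

```python
import warnings

class Cache:
    """Stand-in for transformers' `Cache` base class."""

def maybe_convert(past_key_values):
    # Only the legacy tuple-of-tuples format should warn; None and proper
    # Cache instances (the only format gemma2 supports) pass silently.
    if past_key_values is None or isinstance(past_key_values, Cache):
        return past_key_values
    warnings.warn("tuple-format past_key_values is deprecated", FutureWarning)
    return past_key_values  # real code: DynamicCache.from_legacy_cache(...)
```
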
✅ slow `gemma2` tests ran locally. There are a few failures (also present on `main`); some were fixed in this PR.