Avoid incorrect generations for KV caches containing more than sliding_window tokens#38156
Avoid incorrect generations for KV caches containing more than sliding_window tokens#38156TimFelixBeyer wants to merge 1 commit intohuggingface:mainfrom
Conversation
|
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the |
|
cc @gante as well! |
ArthurZucker
left a comment
There was a problem hiding this comment.
sounds good, we are gonna adress this with the cache refactor! #38077
|
This is being fixed more generally in the cache refactor, yes. Thanks for the pointer! I will add a test for it. @TimFelixBeyer do you have a simple test snippet at hand? I think it is better to wait for the general fix rather than adding a safeguard just for Gemma. |
|
As announced, this was fixed with cache refactors. PR can be closed! |
|
okay, closing |
What does this PR do?
Gemma3 generates incoherent output when manually calling
forwardwith an instance ofHybridCachewhich contains more thansliding_windowtokens of content.This is because the call to
past_key_values.get_seq_len()always returns the sequence length as measured by the cache of the very first layer. Because this is a local attention layer, itssequence_lengthnever extends beyondconfig.sliding_window. This leads to an incorrect computation ofcache_positionand incoherent generations down the line.To fix it you can simply provide the correct
cache_positionmanually.This behavior is impossible to fix without changing
get_seq_lenofHybridCache, so I propose to simply raise an informative error message for now.Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@ArthurZucker