New cache tests and modular Hybrid Cache #37972
manueldeprada merged 30 commits into huggingface:main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
gante
left a comment
In general looks good to me, especially the added tests. 👍
ArthurZucker
left a comment
Thanks for working on this!
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
gante
left a comment
I like the hardcoded tests better (easier to follow) 👍 added a few comments to align code style with other tests
All suggestions applied, and all the tests moved into clear "hardcoded" ones! Thanks a lot for the feedback :)
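For readers following along, here is a minimal sketch of the "hardcoded" test style being discussed, using plain tensors rather than the actual cache classes (the window size and values are illustrative):

```python
import torch

def test_sliding_window_hardcoded():
    # Hardcoded style: explicit inputs and an explicit expected result,
    # instead of parametrized shapes computed on the fly.
    window = 3
    cache = torch.zeros(1, 1, window, 1)  # (batch, heads, window, head_dim)
    for t in range(5):  # stream tokens 0..4 one at a time
        new = torch.full((1, 1, 1, 1), float(t))
        cache = torch.cat([cache[:, :, 1:, :], new], dim=-2)  # shift left, append
    # After 5 tokens, only the last `window` of them remain.
    assert cache.flatten().tolist() == [2.0, 3.0, 4.0]

test_sliding_window_hardcoded()
```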
Thanks @Cyrilvallez for quickly merging the fix into main (#38046)! In hindsight, I could have split the PR into the fix and the refactor + tests. That said, I think this version is better long-term: the sliding logic now lives in one place, with clear names and comments: transformers/src/transformers/cache_utils.py Lines 97 to 105 in 8548b8f
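For context, a sketch of what the centralized sliding-window write can look like (illustrative names and simplifications, not the exact code at the permalink; the decoding branch assumes the window is already full):

```python
import torch

def _sliding_cache_update_logic(key_cache, value_cache, key_states, value_states,
                                sliding_window: int):
    """One shared implementation of the sliding-window write (sketch)."""
    if key_states.shape[-2] >= sliding_window:
        # Prefill longer than the window: only the last `sliding_window` tokens matter.
        key_cache.copy_(key_states[:, :, -sliding_window:, :])
        value_cache.copy_(value_states[:, :, -sliding_window:, :])
    else:
        # Decoding: shift the window left by the number of new tokens, then append.
        n = key_states.shape[-2]
        key_cache.copy_(torch.cat([key_cache[:, :, n:, :], key_states], dim=-2))
        value_cache.copy_(torch.cat([value_cache[:, :, n:, :], value_states], dim=-2))
    return key_cache, value_cache
```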
Cyrilvallez
left a comment
Hey! Thanks for working on this! Just a few thoughts/performance tips! 🤗
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)
The dtype should already be correct here, no?
I am happy to remove it; the tests pass without it. There are similar checks and casts that could also be removed, but I kept them in case existing code relies on them.
I think we can remove the cast, yes. Looking at the original PR, these lines were added to handle the case where RoPE-based KVs are not cast in the model forward pass (RoPE runs in FP32 by default, regardless of the model dtype).
Thanks for the pointer @gante, I should have traced that down. We can't remove it: the PR's sample code fails after removing the casts. I am restoring the cast and adding a test.
I agree with @Cyrilvallez, though, that casting everything is a bad solution compared to doing something specific for RoPE. Since this PR has limited scope and is groundwork for #38077, I will try to solve it more elegantly there.
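A minimal sketch of the failure mode the restored cast (and the new test) guards against, using plain tensors (shapes are illustrative):

```python
import torch

# Pre-allocated FP16 cache, as created for a model loaded in FP16.
key_cache = torch.zeros(1, 8, 128, 64, dtype=torch.float16)

# RoPE runs in FP32 by default, so incoming key states can arrive in FP32.
key_states = torch.randn(1, 8, 1, 64, dtype=torch.float32)

pos = torch.tensor([0])
# Without the cast, the in-place write fails with a dtype mismatch:
# key_cache.index_copy_(2, pos, key_states)  # RuntimeError

# The defensive cast in the cache update avoids it:
key_cache.index_copy_(2, pos, key_states.to(key_cache.dtype))
```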
[perhaps for a subsequent PR, to avoid bloating/delaying this one:]
It would be more transparent and precise if the casting were done in the model architecture rather than in the cache. In the specific case of GPT-J loaded in FP16, it seems that without a cache the KV is kept in FP32, while with a cache the KV is cast to FP16 inside the cache class -> the cache introduces a performance degradation.
As such, I think it would be positive to remove the cast, and delegate control to the model architectures.
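A sketch of what delegating the cast to the architecture could look like — a hypothetical attention-forward excerpt wrapped as a function, not actual GPT-J code:

```python
import torch

def write_kv_with_model_side_cast(cache, layer_idx: int,
                                  key_states: torch.Tensor,   # FP32 RoPE output
                                  value_states: torch.Tensor,
                                  model_dtype: torch.dtype):  # e.g. torch.float16
    # Cast at the architecture level, right where the FP32 RoPE output is
    # produced, so the generic cache update can assume dtype-consistent inputs.
    key_states = key_states.to(model_dtype)
    value_states = value_states.to(model_dtype)
    return cache.update(key_states, value_states, layer_idx)
```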
Good, I will check whether it's only a GPT-J thing in a subsequent PR. In the meantime, I added the test. It uncovered 3 fixes like this that were applied to StaticCache but not to Hybrid and Offloaded. Please have a quick look at 772b0a0 before I merge.
Especially if it's only there for a given old model -> much better to fix the model rather than general cache logic!
What does this PR do?
Extracts the core update logic into two new utility functions: _static_cache_update_logic and _sliding_cache_update_logic. This way, there is only one implementation for StaticCache, SlidingWindowCache, and HybridCache.
@ArthurZucker @gante this is a first step towards per-layer modular cache definitions.
This is preliminary work for #38077
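Roughly, the shape of the refactor (illustrative signatures; the actual helpers live in src/transformers/cache_utils.py):

```python
import torch

def _static_cache_update_logic(key_cache, value_cache, key_states, value_states,
                               cache_position):
    # Shared full-attention write: place the new states at their absolute positions.
    key_cache.index_copy_(2, cache_position, key_states)
    value_cache.index_copy_(2, cache_position, value_states)
    return key_cache, value_cache

def hybrid_layer_update(cache, key_states, value_states, layer_idx, cache_position):
    # HybridCache-style dispatch: every layer funnels into one of the two shared
    # helpers, so each update strategy has exactly one implementation.
    if cache.is_sliding[layer_idx]:
        return _sliding_cache_update_logic(  # sliding variant sketched earlier
            cache.key_cache[layer_idx], cache.value_cache[layer_idx],
            key_states, value_states, cache.sliding_window,
        )
    return _static_cache_update_logic(
        cache.key_cache[layer_idx], cache.value_cache[layer_idx],
        key_states, value_states, cache_position,
    )
```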