
New cache tests and modular Hybrid Cache#37972

Merged
manueldeprada merged 30 commits into huggingface:main from manueldeprada:cache-fix2
May 20, 2025

Conversation

@manueldeprada
Contributor

@manueldeprada manueldeprada commented May 6, 2025

What does this PR do?

  1. Refactor the cache update logic for static and sliding-window attention out into two new utility functions, _static_cache_update_logic and _sliding_cache_update_logic, so there is a single implementation shared by StaticCache, SlidingWindowCache, and HybridCache (a rough sketch of the two helpers follows below).
    @ArthurZucker @gante this is a first step towards per-layer modular cache definitions.
  2. Added new synthetic tests for caches. Fixes #37574 ("Wrong KV cache update for sliding-window attention (SWA) layers when total sequence length reaches window size") and should catch similar bugs.

This is preliminary work for #38077
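
For reference, a rough sketch of what the two shared helpers could look like (illustrative only; the names match the description above, but the actual signatures and details in src/transformers/cache_utils.py may differ):

import torch

def _static_cache_update_logic(k_cache, v_cache, key_states, value_states, cache_position):
    # Write the new key/value states into the preallocated buffers at the given positions.
    k_cache.index_copy_(2, cache_position, key_states)
    v_cache.index_copy_(2, cache_position, value_states)
    return k_cache, v_cache

def _sliding_cache_update_logic(k_cache, v_cache, key_states, value_states, cache_position, max_cache_len):
    # Prefill longer than the window: keep only the last `max_cache_len` tokens in the cache,
    # but return the full states so attention can still see the whole prefill.
    if key_states.shape[2] > max_cache_len:
        k_cache.copy_(key_states[:, :, -max_cache_len:])
        v_cache.copy_(value_states[:, :, -max_cache_len:])
        return key_states, value_states

    # Roll the window by one slot only once the total sequence length exceeds the window size.
    slicing = torch.arange(max_cache_len, device=k_cache.device)
    current_seq_len = cache_position[-1] + 1
    to_shift = current_seq_len > max_cache_len
    indices = (slicing + to_shift.sum()) % max_cache_len
    k_out = k_cache[:, :, indices]
    v_out = v_cache[:, :, indices]

    # Clamp the target positions into the window and write the new states into the shifted view.
    update_position = cache_position.clamp(min=0, max=max_cache_len - 1)
    k_out[:, :, update_position] = key_states
    v_out[:, :, update_position] = value_states

    # Persist the shifted and updated contents back into the cache buffers.
    k_cache.copy_(k_out)
    v_cache.copy_(v_out)
    return k_out, v_out

StaticCache, SlidingWindowCache, and HybridCache can then call these helpers from their update() methods instead of each carrying its own copy of the logic.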

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@gante gante left a comment


In general looks good to me, especially the added tests. 👍

Comment thread src/transformers/cache_utils.py
Comment thread src/transformers/cache_utils.py Outdated
Comment thread src/transformers/cache_utils.py Outdated
Comment thread src/transformers/cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
@manueldeprada manueldeprada requested a review from gante May 6, 2025 17:19
@manueldeprada manueldeprada marked this pull request as ready for review May 7, 2025 07:50
Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks for working on this!

Comment thread src/transformers/cache_utils.py Outdated
Comment thread src/transformers/cache_utils.py
Comment thread tests/utils/test_cache_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Contributor

@gante gante left a comment


I like the hardcoded tests better (easier to follow) 👍 Added a few comments to align the code style with the other tests.

Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
@manueldeprada
Contributor Author

All suggestions applied, and all the tests are now clear "hardcoded" ones! Thanks a lot for the feedback :)
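
For context, a "hardcoded" test spells out the expected cache contents literally rather than computing them. A rough sketch of the idea (my own toy example, not the actual code in tests/utils/test_cache_utils.py), here for the long-prefill case of a sliding window:

import torch

def test_sliding_window_keeps_only_last_tokens_on_long_prefill():
    max_cache_len = 4
    k_cache = torch.zeros(1, 1, max_cache_len, 1)

    # Prefill with 6 tokens (values 1..6) while the window only holds 4.
    key_states = torch.arange(1.0, 7.0).view(1, 1, 6, 1)
    k_cache.copy_(key_states[:, :, -max_cache_len:])

    # Hardcoded expectation: only the last 4 tokens remain in the cache.
    expected = torch.tensor([3.0, 4.0, 5.0, 6.0]).view(1, 1, 4, 1)
    torch.testing.assert_close(k_cache, expected)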

@manueldeprada
Contributor Author

Thanks @Cyrilvallez for quickly merging the fix into main (#38046)! In hindsight, I could have split this PR into the fix and the refactor + tests.

That said, I think this version is better long-term: the sliding logic now lives in one place, with clear names and comments:

current_seq_len = cache_position[-1] + 1 # Use last position to determine current length
to_shift = current_seq_len > max_cache_len
indices = (slicing + to_shift.sum()) % max_cache_len
k_out_shifted = k_cache[:, :, indices]
v_out_shifted = v_cache[:, :, indices]
# Clamp cache_position to determine the *target index* within the shifted cache view
update_position = cache_position.clamp(min=0, max=max_cache_len - 1)
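
To make the roll concrete, a toy example (mine, not from the PR) with a window of 4 that already holds tokens t0..t3 when token t4 arrives at absolute position 4; note that the shift triggers only once the total length exceeds the window, which is the boundary case from #37574:

import torch

max_cache_len = 4
k_cache = torch.tensor([0, 1, 2, 3])        # stand-in for cached tokens t0..t3 (one value per slot)
cache_position = torch.tensor([4])          # absolute position of the incoming token t4
slicing = torch.arange(max_cache_len)

current_seq_len = cache_position[-1] + 1    # 5
to_shift = current_seq_len > max_cache_len  # True -> roll by one slot
indices = (slicing + to_shift.sum()) % max_cache_len   # tensor([1, 2, 3, 0])

k_out_shifted = k_cache[indices]            # tensor([1, 2, 3, 0]) -> t1, t2, t3, stale t0
update_position = cache_position.clamp(min=0, max=max_cache_len - 1)  # tensor([3])
k_out_shifted[update_position] = 4          # overwrite the stale slot with t4
print(k_out_shifted)                        # tensor([1, 2, 3, 4]) -> the window now holds t1..t4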

Member

@Cyrilvallez Cyrilvallez left a comment


Hey! Thanks for working on this! Just a few thoughts/performance tips! 🤗

Comment thread src/transformers/cache_utils.py
Comment thread src/transformers/cache_utils.py
Comment thread src/transformers/cache_utils.py
Comment thread src/transformers/cache_utils.py
Comment on lines +1364 to +1365
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)
Member


The dtype should already be correct here, no?

Contributor Author


I am happy to remove it; the tests still pass without it. There are similar checks and casts that could be removed too, but I kept them in case existing code relies on them.

Contributor


I think we can remove the cast, yes. Looking at the original PR, these lines were added to handle the case where we don't cast RoPE-based KVs in the model forward pass (RoPE runs in FP32 by default, regardless of the model dtype).

Contributor Author


Thanks for the pointer @gante, I should have traced that down. We can't remove it: the original PR's sample code fails after removing the casts. I am restoring the cast and adding a test...

I agree though with @Cyrilvallez that it is a bad solution to cast everything instead of doing something specific for RoPE. Since this PR has limited scope and is ground work for #38077, I will try to solve it more elegantly there.
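
A hypothetical toy repro of the failure mode (mine, not the PR's actual test): with a model loaded in FP16 but RoPE applied in FP32, the incoming key states can be FP32 while the cache buffer is FP16, and dtype-strict ops such as index_copy_ then fail unless the states are cast first:

import torch

batch, heads, max_cache_len, head_dim = 1, 2, 8, 4
key_cache = torch.zeros(batch, heads, max_cache_len, head_dim, dtype=torch.float16)

# RoPE runs in FP32 by default, so the incoming key states can be FP32
# even though the model (and its cache) were loaded in FP16.
key_states = torch.randn(batch, heads, 1, head_dim, dtype=torch.float32)
cache_position = torch.tensor([0])

try:
    key_cache.index_copy_(2, cache_position, key_states)  # raises: dtypes differ
except RuntimeError as err:
    print("dtype mismatch:", err)

# The cast being discussed aligns the dtypes before the write, so this succeeds.
key_cache.index_copy_(2, cache_position, key_states.to(key_cache.dtype))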

Contributor


[perhaps for a subsequent PR, to avoid bloating/delaying this one:]

It would be more transparent and precise if the casting were done in the model architecture rather than in the cache. In the specific case of GPT-J loaded in FP16, it seems that without a cache the KVs are kept in FP32, while with a cache they are cast to FP16 in the cache class -> the cache introduces performance degradation.

As such, I think it would be positive to remove the cast, and delegate control to the model architectures.
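
A sketch of what delegating this to the model side could look like (hypothetical helper with a simplified RoPE; not an existing transformers function):

import torch

def rotate_half(x):
    # Standard RoPE trick: rotate the last dimension by swapping and negating its halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_then_cast(query_states, key_states, cos, sin, compute_dtype):
    # Run RoPE in FP32 for precision, then cast straight back to the model compute
    # dtype so the cache never has to re-cast what it receives.
    q32, k32 = query_states.float(), key_states.float()
    cos32, sin32 = cos.float(), sin.float()
    q_rot = q32 * cos32 + rotate_half(q32) * sin32
    k_rot = k32 * cos32 + rotate_half(k32) * sin32
    return q_rot.to(compute_dtype), k_rot.to(compute_dtype)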

Contributor Author

@manueldeprada manueldeprada May 13, 2025


Good, I will check whether it's only a GPT-J thing in a subsequent PR. In the meantime, I added the test. It uncovered 3 fixes like this that were applied to StaticCache but not to Hybrid and Offloaded. Please have a quick look at 772b0a0 before I merge.

Member


Especially if it's only there for a given old model -> much better to fix the model rather than the general cache logic!

Comment thread src/transformers/cache_utils.py
Comment thread src/transformers/cache_utils.py
Contributor

@gante gante left a comment


Thank you for iterating 🤗

Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
Comment thread tests/utils/test_cache_utils.py Outdated
@manueldeprada manueldeprada merged commit d34e21e into huggingface:main May 20, 2025
20 checks passed
faaany pushed a commit to faaany/transformers that referenced this pull request May 21, 2025
xvyv99 pushed a commit to xvyv99/transformers that referenced this pull request May 21, 2025


Development

Successfully merging this pull request may close these issues.

Wrong KV cache update for sliding-window attention (SWA) layers when total sequence length reaches window size

5 participants