[tests] Test all cache implementations #37873
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).
```python
class SinkCache(Cache):
    """
    Deprecated.
```
SinkCache has been broken on some edge cases for over a year, the issues are non-trivial to fix, and it is no longer relevant -- we can achieve a similar effect with a few other flags. See deprecation warning below.
```diff
         slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
         cache_position = cache_position.clamp(0, self.max_cache_len - 1)
-        to_shift = cache_position >= self.max_cache_len - 1
+        to_shift = cache_position > self.max_cache_len - 1
```
Off by one: we were applying the shifting update one token too early. This shows up on the last token when we initialize the sliding window cache with the exact size of the generation (e.g. with `model.generate(..., cache_implementation="sliding_window")`).
This effectively means our models were micro-underperforming with sliding window caches, more specifically on the last generated token :D One of the new tests caught this issue.
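For reference, a minimal standalone repro of the boundary condition (toy `max_cache_len`, not the real cache code):

```python
import torch

max_cache_len = 4
cache_position = torch.tensor([3])  # writing into the last valid slot of the cache

# old condition: triggers the shift while the token still fits in the cache
print(cache_position >= max_cache_len - 1)  # tensor([True])  -> shifted one token too early

# fixed condition: only shift once the position falls outside the cache
print(cache_position > max_cache_len - 1)  # tensor([False]) -> last slot is used normally
```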
On first glance, this likely fixes the issue(s) raised in #37574 👀
This change is wrong in general and leads to garbage generation on sequences longer than the sliding window! I am opening a PR to revert, with examples 😉 What you observed is the fact that the prefill and later stages should be treated separately in terms of the states they return.
@Cyrilvallez You should give #37972 a look before :D
| "config and it's not set to None." | ||
| ) | ||
| self.max_cache_len = max_cache_len | ||
| self._sliding_window_max_len = min(config.sliding_window, max_cache_len) |
HybridCache had the right pattern, but some of the other hybrid caches did not: generation was crashing if we tried to generate a max length shorter than the sliding window length. Caught by one of the new tests.
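The pattern in question, as a standalone sketch (illustrative values, not the actual cache class):

```python
# A sliding-window cache never needs more slots than the total generation
# length; allocating (and indexing) beyond max_cache_len is what crashed.
sliding_window = 4096  # config.sliding_window
max_cache_len = 256    # requested generation length, shorter than the window

sliding_window_max_len = min(sliding_window, max_cache_len)  # 256, not 4096
```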
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
 @require_torch_accelerator
 class CacheIntegrationTest(unittest.TestCase):
-    """Cache tests that require loading models"""
+    """Fast cache integration tests that share the same small model"""
```
Separated into two classes, to make best use of setUpClass. Loading the model is the most costly part of these tests, and we only do it once.
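A minimal sketch of the pattern (the tiny checkpoint name is illustrative):

```python
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer


class CacheIntegrationTest(unittest.TestCase):
    """Fast cache integration tests that share the same small model"""

    @classmethod
    def setUpClass(cls):
        # runs once per class: every test method reuses the same model,
        # instead of paying the loading cost in each test's setUp()
        cls.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        cls.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
```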
```diff
-        # DynamicCache and the legacy cache format should be equivalent
-        set_seed(0)
-        gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
```
The default is now DynamicCache(), so the two generate calls in this test were the same.
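In other words (sketch, assuming `model` and `inputs` are set up as in the test):

```python
from transformers import DynamicCache

# generate() builds a DynamicCache internally when none is passed, so these
# two calls exercise the same code path and return the same tokens
gen_out_default = model.generate(**inputs, do_sample=False, max_new_tokens=8)
gen_out_explicit = model.generate(**inputs, do_sample=False, max_new_tokens=8, past_key_values=DynamicCache())
```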
```diff
         self.assertEqual(decoded[0], expected_text)

-    @slow
-    def test_dynamic_cache_batched(self):
```
adapted into CacheIntegrationTest
```diff
         self.assertListEqual(decoded, expected_text)

-    @slow
-    def test_dynamic_cache_beam_search(self):
```
adapted into CacheIntegrationTest
```diff
         self.assertListEqual(decoded, expected_text)

-    @slow
-    def test_hybrid_cache_n_sequences(self):
```
redundant with the tests in CacheIntegrationTest (more specifically, test_cache_batched and test_cache_beam_search)
```diff
-    @require_non_xpu
-    @require_gptq
-    @slow
-    def test_sink_cache_hard(self):
```
test was broken and SinkCache is being deprecated
```diff
         self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))

-    @slow
-    def test_sink_cache_iterative_prompts(self):
```
test was broken and SinkCache is being deprecated
```diff
         self.assertListEqual(decoded, EXPECTED_GENERATION)

-    @slow
-    def test_dynamic_cache_extra_left_padding(self):
```
adapted into CacheIntegrationTest
```diff
         self.assertListEqual(decoded, EXPECTED_GENERATION)

-    @slow
-    def test_static_cache_extra_left_padding(self):
```
adapted into CacheIntegrationTest
```diff
-    @require_torch_accelerator
-    @slow
-    def test_offloaded_cache_equivalent_to_dynamic_cache(self):
```
we implicitly test this in CacheIntegrationTest
```python
            responses.append(response)

        EXPECTED_DECODED_TEXT = [
            "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
```
If we check out the commit that added this test, we get a different output 👀 possibly due to different hardware/software? (Anyway, I don't think it's worth pinning down the exact cause.)
```python
        # on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
        with CaptureStderr() as cap:
            model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
        self.assertEqual(cap.err, "")
```
This was failing on main if we have kernels installed; this change makes the test green regardless of the installed packages.
```python
            self.skipTest("Quanto is not available")

        if cache_implementation == "offloaded_hybrid_chunked":
            # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
```
I don't think offloaded_hybrid_chunked + beam_search is worth the dive for now 🤔
nope, agree with you!
```diff
 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...cache_utils import Cache, DynamicCache
```
(same diff on all models)
ArthurZucker left a comment:
Very nice! Thanks 🤗
Would be nice to have a fast test for HybridChunked to make sure compile is fine, using a dummy gemma2 model maybe?
TP is also an option to test 👀 but more of a TODO later!
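Something along these lines, perhaps (rough sketch; the tiny config values are made up and the cache/compile details would need checking against the final test):

```python
import torch

from transformers import Gemma2Config, Gemma2ForCausalLM

# tiny randomly-initialized gemma2, cheap enough for a fast (non-slow) test
config = Gemma2Config(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=1,
    sliding_window=8,
)
model = Gemma2ForCausalLM(config).eval()
model.forward = torch.compile(model.forward)

# hybrid cache -> static shapes, so decoding steps should not recompile
input_ids = torch.randint(0, config.vocab_size, (1, 4))
out = model.generate(input_ids, max_new_tokens=8, do_sample=False, cache_implementation="hybrid")
```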
@ArthurZucker yeah, generalist cache + compile tests will be up next! :D
What does this PR do?
The main purpose of this PR is to convert a few slow tests targeted at one cache implementation into fast tests that run on ALL cache implementations.
Secondarily, makes `RUN_SLOW=1 py.test tests/utils/test_cache_utils.py` green 🟢 These tests also become much, much faster (3 mins -> 1 min, on my machine), despite covering a larger number of features.

This is a follow-up to #37684, which paved the way for this PR. After this PR is merged, I can go back to #37394 and properly test things!
👉 torch.compile was benchmarked with gemma2/hybrid and qwen3/static, no speed regressions.
👉 no regressions in `RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py`
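A sketch of the conversion pattern described above (the implementation list and the tiny checkpoint are illustrative, not the exact test code):

```python
import pytest
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

# illustrative subset; the real tests parameterize over all supported values
# of `generate(cache_implementation=...)`
CACHE_IMPLEMENTATIONS = ["static", "offloaded"]


@pytest.mark.parametrize("cache_implementation", CACHE_IMPLEMENTATIONS)
def test_cache_greedy_generation(cache_implementation):
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    inputs = tokenizer("The cache must not change the output", return_tensors="pt")

    # every implementation must match the default (DynamicCache) greedy output
    expected = model.generate(**inputs, max_new_tokens=8, do_sample=False)
    got = model.generate(**inputs, max_new_tokens=8, do_sample=False, cache_implementation=cache_implementation)
    torch.testing.assert_close(got, expected)
```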