From 72ba76bdfe0009426dcd10a1cef7f34be4af0350 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 9 Apr 2025 14:02:56 +0000 Subject: [PATCH 1/6] support smaller bs --- src/transformers/cache_utils.py | 77 ++++++++++++++++------------ src/transformers/generation/utils.py | 26 +++------- 2 files changed, 51 insertions(+), 52 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 29a30f3ab70e..46311d25687c 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1139,15 +1139,14 @@ def update( class StaticCache(Cache): """ - Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. At initialization, the cache is + preallocated to its maximum possible shape, but can it be used with any shape that fits in it. Parameters: config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. If you are manually setting the batch size, make sure to take into account the - number of beams if you are running beam search + The maximum batch size with which the model will be used. max_cache_len (`int`, *optional*): The maximum sequence length with which the model will be used. device (`torch.device` or `str`, *optional*): @@ -1253,8 +1252,10 @@ def update( if cache_kwargs is None: cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] + batch_size = key_states.shape[0] + + k_out = self.key_cache[layer_idx][:batch_size, ...] + v_out = self.value_cache[layer_idx][:batch_size, ...] key_states = key_states.to(k_out.dtype) value_states = value_states.to(v_out.dtype) @@ -1295,10 +1296,14 @@ def reset(self): class SlidingWindowCache(StaticCache): """ - Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. - Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, - if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), - we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window + attention. As in `StaticCache`, the cache is preallocated to its maximum possible shape, but can it be used with + any shape that fits in it. + + Every time when we try to update the cache, we compute the `indices` based on + `cache_position >= self.config.sliding_window - 1`, if true(which means the cache can not hold all the old key + value states and new states together because of the sliding window constraint), we need to do a cycle shift based + on `indices` to replace the oldest states by the new key value states passed in. The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: @@ -1314,8 +1319,7 @@ class SlidingWindowCache(StaticCache): config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. 
max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. + The maximum batch size with which the model will be used. max_cache_len (`int`, *optional*): The maximum sequence length with which the model will be used. device (`torch.device` or `str`, *optional*): @@ -1386,8 +1390,10 @@ def update( if cache_kwargs is None: cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] + batch_size = key_states.shape[0] + + k_out = self.key_cache[layer_idx][:batch_size, ...] + v_out = self.value_cache[layer_idx][:batch_size, ...] key_states = key_states.to(k_out.dtype) value_states = value_states.to(v_out.dtype) @@ -1614,14 +1620,13 @@ class HybridCache(Cache): Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window attention and global attention in every other layer (originally implemented for Gemma2). Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention.For more information, see the documentation of each subcomponent cache class. + for global attention. For more information, see the documentation of each subcomponent cache class. Parameters: config (`PretrainedConfig): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. + The maximum batch size with which the model will be used. max_cache_len (`int`, *optional*): The maximum sequence length with which the model will be used. device (`torch.device` or `str`, *optional*): @@ -1762,6 +1767,7 @@ def update( cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") sliding_window = cache_kwargs.get("sliding_window") + batch_size = key_states.shape[0] # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used # when the cache is initialized in the forward pass (e.g. Gemma2) @@ -1770,8 +1776,8 @@ def update( if self.value_cache[layer_idx].device != value_states.device: self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] + k_out = self.key_cache[layer_idx][:batch_size, ...] + v_out = self.value_cache[layer_idx][:batch_size, ...] key_states = key_states.to(k_out.dtype) value_states = value_states.to(v_out.dtype) @@ -1824,8 +1830,7 @@ class HybridChunkedCache(Cache): config (`PretrainedConfig): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. + The maximum batch size with which the model will be used. max_cache_len (`int`, *optional*): The maximum sequence length with which the model will be used. 
device (`torch.device` or `str`, *optional*): @@ -1898,12 +1903,7 @@ def initialise_cache_layer(self, layer_idx, key_states): num_key_value_heads = key_states.shape[1] device = key_states.device global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = ( - self.max_batch_size, - num_key_value_heads, - self.sliding_window, - self.head_dim, - ) + sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape @@ -1966,10 +1966,11 @@ def update( if cache_kwargs is None: cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") + batch_size = key_states.shape[0] self.initialise_cache_layer(layer_idx, key_states) - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] + k_out = self.key_cache[layer_idx][:batch_size, ...] + v_out = self.value_cache[layer_idx][:batch_size, ...] key_states = key_states.to(k_out.dtype) value_states = value_states.to(v_out.dtype) @@ -2127,13 +2128,14 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None: class MambaCache: """ - Cache for mamba model which does not have attention mechanism and key value states. + Cache for mamba model which does not have attention mechanism and key value states. At initialization, the cache + is preallocated to its maximum possible shape, but can it be used with any shape that fits in it. Arguments: config (`PretrainedConfig): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + The maximum batch size with which the model will be used. dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): The default `dtype` to use when initializing the layer. device (`torch.device` or `str`, *optional*): @@ -2205,7 +2207,9 @@ def update_conv_state( if self.conv_states[layer_idx].device != new_conv_state.device: self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - conv_state = self.conv_states[layer_idx] + batch_size = new_conv_state.shape[0] + + conv_state = self.conv_states[layer_idx][:batch_size, ...] cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) conv_state = conv_state.roll(shifts=-1, dims=-1) @@ -2227,8 +2231,9 @@ def reset(self): class OffloadedStaticCache(StaticCache): """ - Static cache class to be used with `torch.compile(model)` that offloads to the CPU or - another device. + Static cache class to be used with `torch.compile(model)` that offloads to the CPU or another device. As in + `StaticCache`, the cache is preallocated to its maximum possible shape, but can it be used with any shape that + fits in it. Args: config (`PretrainedConfig): @@ -2374,6 +2379,10 @@ def update( self._prefetch_layer(layer_idx + 1) cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None + batch_size = key_states.shape[0] + k_out = k_out[:batch_size, ...] + v_out = v_out[:batch_size, ...] 
+ if cache_position is None: k_out.copy_(key_states) v_out.copy_(value_states) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ee5b79e6d3ac..04d07fd05c1b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1855,7 +1855,7 @@ def _get_cache( cache_dtype = self.dtype layer_device_map = self._get_layer_device_map_for_cache_init() - cache_kwargs = { + all_possible_cache_kwargs = { "config": self.config.get_text_config(), "max_batch_size": batch_size, "max_cache_len": max_cache_len, @@ -1863,6 +1863,8 @@ def _get_cache( "device": device, "layer_device_map": layer_device_map, } + cache_signature = inspect.signature(cache_cls.__init__) + cache_kwargs = {k: v for k, v in all_possible_cache_kwargs.items() if k in cache_signature.parameters} self._cache = cache_cls(**cache_kwargs) if requires_cross_attention_cache: encoder_kwargs = cache_kwargs.copy() @@ -1872,21 +1874,6 @@ def _get_cache( self._cache.reset() return self._cache - def _supports_default_dynamic_cache(self) -> bool: - """ - Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. - This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which - uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in - order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). - """ - return ( - self._supports_cache_class - and "jamba" not in self.__class__.__name__.lower() - and "zamba" not in self.__class__.__name__.lower() - and "bamba" not in self.__class__.__name__.lower() - ) - def _prepare_cache_for_generation( self, generation_config: GenerationConfig, @@ -1916,7 +1903,7 @@ def _prepare_cache_for_generation( f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " "Cache object) is unsupported. Please use only one of the two." 
) - if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache(): + if isinstance(user_defined_cache, tuple) and self._supports_cache_class: model_kwargs[cache_name] = ( DynamicCache.from_legacy_cache(user_defined_cache) if not requires_cross_attention_cache @@ -1930,7 +1917,10 @@ def _prepare_cache_for_generation( return # Quick escape route 3: model that only supports legacy caches = nothing to prepare - if not self._supports_default_dynamic_cache(): + is_model_with_custom_cache = any( + model_name in self.__class__.__name__.lower() for model_name in ["mamba", "jamba", "zamba", "bamba"] + ) + if not self._supports_cache_class and not is_model_with_custom_cache: if generation_config.cache_implementation is not None: warnings.warn( "This model does not support `Cache` instances, it only supports the legacy cache format (tuple " From 10304941562e8ae65db370cf2b9ddb9a8c69442d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 9 Apr 2025 16:32:14 +0000 Subject: [PATCH 2/6] add tests --- src/transformers/cache_utils.py | 18 +++---- src/transformers/generation/utils.py | 2 +- tests/generation/test_utils.py | 20 +++++++ tests/utils/test_cache_utils.py | 80 ++++++++++++++-------------- 4 files changed, 70 insertions(+), 50 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 46311d25687c..33f820ed0762 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2129,10 +2129,11 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None: class MambaCache: """ Cache for mamba model which does not have attention mechanism and key value states. At initialization, the cache - is preallocated to its maximum possible shape, but can it be used with any shape that fits in it. + is preallocated to its maximum possible shape. Contrarily to other caches, `max_batch_size` must match the + batch size used at inference time. Arguments: - config (`PretrainedConfig): + config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): The maximum batch size with which the model will be used. @@ -2206,10 +2207,7 @@ def update_conv_state( # when the cache is initialized in the forward pass (e.g. Mamba) if self.conv_states[layer_idx].device != new_conv_state.device: self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - batch_size = new_conv_state.shape[0] - - conv_state = self.conv_states[layer_idx][:batch_size, ...] + conv_state = self.conv_states[layer_idx] cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) conv_state = conv_state.roll(shifts=-1, dims=-1) @@ -2411,13 +2409,13 @@ def update( value_states = value_states.to(self.offload_device) try: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + self.key_cache[layer_idx][:batch_size, ...].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx][:batch_size, ...].index_copy_(2, cache_position, value_states) except NotImplementedError: # The operator 'aten::index_copy.out' is not currently implemented for the MPS # device. 
- self.key_cache[layer_idx][:, :, cache_position] = key_states - self.value_cache[layer_idx][:, :, cache_position] = value_states + self.key_cache[layer_idx][:batch_size, :, cache_position] = key_states + self.value_cache[layer_idx][:batch_size, :, cache_position] = value_states return k_out, v_out diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 04d07fd05c1b..4610a3ab2d96 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1834,7 +1834,7 @@ def _get_cache( need_new_cache = ( not hasattr(self, "_cache") or (not isinstance(cache_to_check, cache_cls)) - or cache_to_check.max_batch_size != batch_size + or cache_to_check.max_batch_size < batch_size or isinstance( cache_to_check, (HybridChunkedCache, OffloadedHybridCache) ) # due to internal slicing, we always re-init diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0672589769ad..81889fa82fe7 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -4855,6 +4855,26 @@ def test_cache_device_map_with_vision_layer_device_map(self): # If the generate doesn't infer the DECODER device map correctly, this will fail _ = model.generate(**inputs, max_new_tokens=2, do_sample=False) + @slow + def test_large_cache_is_reused_with_smaller_batch_size(self): + """ + Test that a large compilable cache is reused with a smaller batch size. + """ + model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM" + model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_repo) + + inputs_2 = tokenizer(["foo bar"] * 2, return_tensors="pt").to(model.device) + inputs_1 = tokenizer(["foo bar"], return_tensors="pt").to(model.device) + + # Generate with a large batch size, then with a smaller one + _ = model.generate(**inputs_2, max_new_tokens=3, do_sample=False, cache_implementation="static") + _ = model.generate(**inputs_1, max_new_tokens=3, do_sample=False, cache_implementation="static") + + # What is expected: the cache is reused, i.e. 
it retains the batch size of the first generation call + self.assertIsInstance(model._cache, StaticCache) + self.assertEqual(model._cache.max_batch_size, 2) + @require_torch class TokenHealingTestCase(unittest.TestCase): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index bbf0268c8b48..903407c75459 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -48,6 +48,7 @@ StaticCache, convert_and_export_with_cache, ) + from transformers.generation.configuration_utils import NEED_SETUP_CACHE_CLASSES_MAPPING from transformers.utils import is_torch_greater_or_equal @@ -285,8 +286,8 @@ def test_static_cache_exportability(self): @require_torch_accelerator -@slow class CacheIntegrationTest(unittest.TestCase): + @slow def test_dynamic_cache_hard(self): tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left") model = AutoModelForCausalLM.from_pretrained( @@ -316,6 +317,7 @@ def test_dynamic_cache_hard(self): ) self.assertEqual(decoded[0], expected_text) + @slow def test_dynamic_cache_batched(self): tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left") tokenizer.pad_token = tokenizer.eos_token @@ -352,6 +354,7 @@ def test_dynamic_cache_beam_search(self): ] self.assertListEqual(decoded, expected_text) + @slow def test_hybrid_cache_n_sequences(self): tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") model = AutoModelForCausalLM.from_pretrained( @@ -379,6 +382,7 @@ def test_hybrid_cache_n_sequences(self): @require_non_xpu @require_gptq + @slow def test_sink_cache_hard(self): tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ") model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto") @@ -392,6 +396,7 @@ def test_sink_cache_hard(self): decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) self.assertTrue(decoded[0].endswith("to perform a variety of tasks. 
The Transformer is a neural network")) + @slow def test_sink_cache_iterative_prompts(self): """Tests that SinkCache supports more than one new token at once, when shifting the cache""" tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") @@ -441,6 +446,7 @@ def test_sink_cache_iterative_prompts(self): ("sdpa", "static"), ] ) + @slow def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation): EXPECTED_GENERATION = [ "The best color is the one that complements the skin tone of the", @@ -479,44 +485,7 @@ def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_ with self.subTest(f"{attn_implementation}, static, compiled"): self.assertListEqual(decoded, EXPECTED_GENERATION) - @require_torch_gpu - @parameterized.expand( - [ - ("eager", "static"), - ("sdpa", "static"), - ] - ) - def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation): - EXPECTED_GENERATION = [ - "The best color isЋ the one that complements the skin tone of", - "We should not undermind the issues at hand.\nWe should not undermind the issues", - ] - - tokenizer = AutoTokenizer.from_pretrained( - "NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="" - ) - model = AutoModelForCausalLM.from_pretrained( - "NousResearch/Llama-2-7b-chat-hf", - torch_dtype=torch.bfloat16, - attn_implementation=attn_implementation, - ).to(torch_device) - inputs = tokenizer( - ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" - ).to(model.device) - - set_seed(0) - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - with self.subTest(f"{attn_implementation}, dynamic"): - self.assertListEqual(decoded, EXPECTED_GENERATION) - - set_seed(0) - model.generation_config.cache_implementation = cache_implementation - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - with self.subTest(f"{attn_implementation}, static, eager"): - self.assertListEqual(decoded, EXPECTED_GENERATION) - + @slow def test_dynamic_cache_extra_left_padding(self): """Tests that adding extra left-padding does not affect the generation with the dynamic cache""" EXPECTED_GENERATION = [ @@ -556,6 +525,7 @@ def test_dynamic_cache_extra_left_padding(self): "static", ] ) + @slow def test_static_cache_extra_left_padding(self, cache_implementation): """Tests that adding extra left-padding does not affect the generation with the static cache""" EXPECTED_GENERATION = [ @@ -597,6 +567,7 @@ def test_static_cache_beam_search(self): pass @require_torch_accelerator + @slow def test_offloaded_cache_equivalent_to_dynamic_cache(self): """Tests that OffloadedCache produces the same result as the default DynamicCache""" model_name = "microsoft/Phi-3-mini-4k-instruct" @@ -625,6 +596,7 @@ def test_offloaded_cache_equivalent_to_dynamic_cache(self): assert torch.all(original_output == offloaded_output).item() @require_torch_accelerator + @slow def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self): """Tests that OffloadedCache uses less memory than the default DynamicCache""" model_name = "microsoft/Phi-3-mini-4k-instruct" @@ -664,6 +636,7 @@ def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self): assert offloaded_peak_memory < original_peak_memory @require_torch_gpu + @slow def test_cache_copy(self): model_name = 
"microsoft/Phi-3-mini-4k-instruct" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -743,8 +716,10 @@ def test_static_cache_no_cuda_graph_skips(self): with CaptureStderr() as cap: model.generate(**inputs, max_new_tokens=2, cache_implementation="static") self.assertEqual(cap.err, "") + self.assertTrue(hasattr(model, "_compiled_call")) # Our auto compile should have been called @require_torch_multi_gpu + @slow def test_static_cache_multi_gpu(self): """Regression test for #35164: static cache with multi-gpu""" @@ -764,3 +739,30 @@ def test_static_cache_multi_gpu(self): inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0) _ = model(**inputs) _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid") + + @parameterized.expand( + (name, cache_cls) + for name, cache_cls in NEED_SETUP_CACHE_CLASSES_MAPPING.items() + if name != "mamba" # `MambaCache` doesn't support the feature tested here + ) + def test_compilable_cache_smaller_batch_size(self, name, cache_cls): + """ + Tests that compilable caches, whose shape needs to be set in advance, can be used with smaller batch sizes. + """ + model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM" # has sliding window + model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_repo) + inputs_ids = tokenizer(["foo bar"], return_tensors="pt").input_ids.to(torch_device) + cache_position = torch.arange(inputs_ids.shape[1]).to(torch_device) + + # cache with a large batch size, >> input batch size (1) + cache = cache_cls( + config=model.config, + max_batch_size=16, + max_cache_len=20, + device=torch_device, + dtype=model.dtype, + ) + + # the forward pass should work with this cache, even though the input batch size is smaller than the cache's + _ = model(inputs_ids, cache_position=cache_position, past_key_values=cache) From ff61203c9cace2118b6ca6a2cb627b1a47d22fa1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 9 Apr 2025 16:46:36 +0000 Subject: [PATCH 3/6] revert change --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 33f820ed0762..ca7d3d43099d 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2136,7 +2136,7 @@ class MambaCache: config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): - The maximum batch size with which the model will be used. + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): The default `dtype` to use when initializing the layer. device (`torch.device` or `str`, *optional*): From b2ba3252981bf763b1e8f5b104fafe532f44f649 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 9 Apr 2025 16:46:58 +0000 Subject: [PATCH 4/6] revert change --- src/transformers/cache_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ca7d3d43099d..0c69248d2de3 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2207,6 +2207,7 @@ def update_conv_state( # when the cache is initialized in the forward pass (e.g. 
Mamba) if self.conv_states[layer_idx].device != new_conv_state.device: self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + conv_state = self.conv_states[layer_idx] cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) From 83c82a1ba9145f18763c4a1d564978473506cc7f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 9 Apr 2025 17:03:13 +0000 Subject: [PATCH 5/6] offloaded cache doesn't work on all models --- src/transformers/generation/utils.py | 2 +- tests/generation/test_utils.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4610a3ab2d96..dfba063c61cf 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1991,7 +1991,7 @@ def _prepare_cache_for_generation( # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory - else: + elif not is_model_with_custom_cache: model_kwargs[cache_name] = ( DynamicCache() if not requires_cross_attention_cache diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 81889fa82fe7..fed32a04ab49 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1955,6 +1955,8 @@ def test_offloaded_cache_implementation(self, cache_implementation): for model_class in self.all_generative_model_classes: if not model_class._supports_cache_class: self.skipTest(reason="This model does not support the new cache format") + if any(model_name in model_class.__name__.lower() for model_name in ["mamba", "jamba", "zamba", "bamba"]): + self.skipTest(reason="This model does not support offloaded cache") config, inputs_dict = self.prepare_config_and_inputs_for_generate() From 883ac393c3efa6b00d3bbc08c3ec2dee17043a61 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 17 Apr 2025 13:15:10 +0000 Subject: [PATCH 6/6] harder test --- src/transformers/cache_utils.py | 61 +++++++++++++++++---------------- tests/utils/test_cache_utils.py | 13 +++++-- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0c69248d2de3..da92c0afac6e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1682,9 +1682,7 @@ def __init__( self.max_cache_len = max_cache_len self.max_batch_size = max_batch_size # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads self._dtype = dtype self.num_key_value_heads = ( @@ -1721,12 +1719,13 @@ def __init__( self.value_cache.append(new_layer_value_cache) def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + batch_size = key_states.shape[0] if cache_position.shape[0] > max_cache_len: k_out = key_states[:, :, -max_cache_len:, :] v_out = value_states[:, :, -max_cache_len:, :] # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out + self.key_cache[layer_idx][:batch_size, ...] += k_out + self.value_cache[layer_idx][:batch_size, ...] 
+= v_out # we should return the whole states instead of k_out, v_out to take the whole prompt # into consideration when building kv cache instead of just throwing away tokens outside of the window return key_states, value_states @@ -1741,19 +1740,20 @@ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k k_out[:, :, cache_position] = key_states v_out[:, :, cache_position] = value_states # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + self.key_cache[layer_idx][:batch_size, ...].zero_() + self.value_cache[layer_idx][:batch_size, ...].zero_() - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out + self.key_cache[layer_idx][:batch_size, ...] += k_out + self.value_cache[layer_idx][:batch_size, ...] += v_out return k_out, v_out def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): k_out[:, :, cache_position] = key_states v_out[:, :, cache_position] = value_states - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out + batch_size = key_states.shape[0] + self.key_cache[layer_idx][:batch_size, ...] = k_out + self.value_cache[layer_idx][:batch_size, ...] = v_out return k_out, v_out def update( @@ -1883,7 +1883,7 @@ def __init__( self.sliding_window = config.sliding_window self.max_cache_len = max_cache_len self.max_batch_size = max_batch_size - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads self._dtype = dtype if hasattr(config.get_text_config(), "no_rope_layers"): @@ -1916,6 +1916,7 @@ def initialise_cache_layer(self, layer_idx, key_states): def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): cumulative_length = self.cumulative_length[layer_idx] + batch_size = key_states.shape[0] # Update it now that we saved the value above self.cumulative_length[layer_idx] += key_states.shape[-2] is_full = cumulative_length >= max_cache_len @@ -1926,9 +1927,9 @@ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed adress # in memory (the values are the same as the full states, but not the address!!) if key_states.shape[-2] == 1: - self.key_cache[layer_idx].copy_(full_key_states) - self.value_cache[layer_idx].copy_(full_value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] + self.key_cache[layer_idx][:batch_size, ...].copy_(full_key_states) + self.value_cache[layer_idx][:batch_size, ...].copy_(full_value_states) + return self.key_cache[layer_idx][:batch_size, ...], self.value_cache[layer_idx][:batch_size, ...] 
elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: # Fast prefill path, no need to cat() in this case (which creates a copy even if cating from 0 dim) if cumulative_length == 0: @@ -1938,12 +1939,12 @@ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2) else: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] + self.key_cache[layer_idx][:batch_size, ...].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx][:batch_size, ...].index_copy_(2, cache_position, value_states) + return self.key_cache[layer_idx][:batch_size, ...], self.value_cache[layer_idx][:batch_size, ...] - self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) - self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) + self.key_cache[layer_idx][:batch_size, ...].copy_(full_key_states[:, :, -max_cache_len:, :]) + self.value_cache[layer_idx][:batch_size, ...].copy_(full_value_states[:, :, -max_cache_len:, :]) # we should return the whole states instead of k_out, v_out to take the whole prompt # into consideration when building kv cache instead of just throwing away tokens outside of the window return full_key_states, full_value_states @@ -1952,8 +1953,9 @@ def _static_update(self, cache_position, layer_idx, key_states, value_states, k_ k_out[:, :, cache_position] = key_states v_out[:, :, cache_position] = value_states - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out + batch_size = key_states.shape[0] + self.key_cache[layer_idx][:batch_size, ...] = k_out + self.value_cache[layer_idx][:batch_size, ...] = v_out return k_out, v_out def update( @@ -2074,13 +2076,14 @@ def initialise_cache_layer(self, layer_idx, key_states): self.device_value_cache.append(device_layer_value_cache) def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + batch_size = key_states.shape[0] # Wait for prefetch stream if needed if self._prefetch_stream is not None: torch.cuda.default_stream(key_states.device).wait_stream(self._prefetch_stream) # Get correct on-device layer - k_out = self.device_key_cache[self.active_device_layer] - v_out = self.device_value_cache[self.active_device_layer] + k_out = self.device_key_cache[self.active_device_layer][:batch_size, ...] + v_out = self.device_value_cache[self.active_device_layer][:batch_size, ...] 
# Let's prefetch the next layer as soon as possible self._prefetch_next_layer(layer_idx) @@ -2090,8 +2093,8 @@ def _static_update(self, cache_position, layer_idx, key_states, value_states, k_ v_out[:, :, cache_position] = value_states # Copy to offloaded device - self.key_cache[layer_idx][:, :, cache_position] = key_states.to(self.offload_device) - self.value_cache[layer_idx][:, :, cache_position] = value_states.to(self.offload_device) + self.key_cache[layer_idx][:batch_size, :, cache_position] = key_states.to(self.offload_device) + self.value_cache[layer_idx][:batch_size, :, cache_position] = value_states.to(self.offload_device) return k_out, v_out @@ -2293,7 +2296,7 @@ def __init__( self._dtype = dtype if dtype is not None else torch.float32 # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads num_key_value_heads = ( config.num_attention_heads @@ -2388,8 +2391,8 @@ def update( # Copy the values to the offloaded device as well. if layer_idx == 0: - self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) - self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) + self.key_cache[layer_idx][:batch_size, ...].copy_(key_states.to(self.offload_device)) + self.value_cache[layer_idx][:batch_size, ...].copy_(value_states.to(self.offload_device)) else: # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 903407c75459..91333a35db7c 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -747,18 +747,20 @@ def test_static_cache_multi_gpu(self): ) def test_compilable_cache_smaller_batch_size(self, name, cache_cls): """ - Tests that compilable caches, whose shape needs to be set in advance, can be used with smaller batch sizes. + Tests that compilable caches, whose shape need to be set in advance, can be used with smaller batch sizes. """ - model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM" # has sliding window + # Mistral has sliding window, can test related caches + model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM" model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device) tokenizer = AutoTokenizer.from_pretrained(model_repo) inputs_ids = tokenizer(["foo bar"], return_tensors="pt").input_ids.to(torch_device) cache_position = torch.arange(inputs_ids.shape[1]).to(torch_device) # cache with a large batch size, >> input batch size (1) + batch_size = 16 cache = cache_cls( config=model.config, - max_batch_size=16, + max_batch_size=batch_size, max_cache_len=20, device=torch_device, dtype=model.dtype, @@ -766,3 +768,8 @@ def test_compilable_cache_smaller_batch_size(self, name, cache_cls): # the forward pass should work with this cache, even though the input batch size is smaller than the cache's _ = model(inputs_ids, cache_position=cache_position, past_key_values=cache) + + # if we expand the input batch size to the cache's batch size, the same cache can be reused + cache.reset() + inputs_ids = torch.cat([inputs_ids] * batch_size, dim=0) + _ = model(inputs_ids, cache_position=cache_position, past_key_values=cache)
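
A minimal usage sketch of the behavior this series enables: a compilable cache preallocated with a large `max_batch_size` can now be fed inputs with a smaller batch size, both in a manual forward pass and through `generate()`. It mirrors the new tests above; the checkpoint is the tiny Mistral repo those tests use, while the batch sizes and cache length are illustrative assumptions rather than values taken from the patch.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

repo = "hf-internal-testing/tiny-random-MistralForCausalLM"  # checkpoint used in the new tests
model = AutoModelForCausalLM.from_pretrained(repo)
tokenizer = AutoTokenizer.from_pretrained(repo)

# 1) Manual forward pass: preallocate a StaticCache for a batch of 8, then run an
#    input with batch size 1 against it. Before this series, the cache had to be
#    re-instantiated whenever a smaller batch size was used.
cache = StaticCache(
    config=model.config,
    max_batch_size=8,   # illustrative; larger than the actual input batch size
    max_cache_len=32,   # illustrative
    device=model.device,
    dtype=model.dtype,
)
input_ids = tokenizer(["foo bar"], return_tensors="pt").input_ids.to(model.device)
cache_position = torch.arange(input_ids.shape[1], device=model.device)
_ = model(input_ids, cache_position=cache_position, past_key_values=cache)

# 2) generate(): the static cache built for a batch of 2 is reused for a later call
#    with a batch of 1 (previously, any batch size mismatch triggered a re-init).
inputs_2 = tokenizer(["foo bar"] * 2, return_tensors="pt").to(model.device)
inputs_1 = tokenizer(["foo bar"], return_tensors="pt").to(model.device)
_ = model.generate(**inputs_2, max_new_tokens=3, do_sample=False, cache_implementation="static")
_ = model.generate(**inputs_1, max_new_tokens=3, do_sample=False, cache_implementation="static")

Note that this does not apply to `MambaCache`, whose `max_batch_size` must still match the batch size used at inference time (see the docstring revert in patch 3/6).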