From 5e402264b450f8e81f1d7cb6b625ef33bb8f3648 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 12 Aug 2025 11:52:43 +0200
Subject: [PATCH 1/2] switch order for BC and future logic

---
 src/transformers/cache_utils.py             | 28 ++++++++++-----------
 src/transformers/integrations/executorch.py |  4 +--
 tests/models/llama/test_modeling_llama.py   |  4 +--
 3 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 4887c88826e5..8e3aff3c8bd3 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1148,15 +1148,15 @@ class StaticCache(Cache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = StaticCache(max_cache_len=max_generated_length, config=model.config)
+        >>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> outputs.past_key_values # access cache filled with key/values from generation
         StaticCache()
         ```
     """

-    # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
-    def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+    # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
         layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
         super().__init__(layers=layers)

@@ -1183,15 +1183,15 @@ class OffloadedStaticCache(Cache):

         >>> # Prepare a cache class with offloading
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = OffloadedStaticCache(max_cache_len=max_generated_length, config=model.config)
+        >>> past_key_values = OffloadedStaticCache(config=model.config, max_cache_len=max_generated_length)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> outputs.past_key_values # access cache with offloaded layers
         OffloadedStaticCache()
         ```
     """

-    # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
-    def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+    # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
         layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
         super().__init__(layers=layers, offloading=True)

@@ -1214,15 +1214,15 @@ class SlidingWindowCache(Cache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = SlidingWindowCache(max_cache_len=max_generated_length, config=model.config)
+        >>> past_key_values = SlidingWindowCache(config=model.config, max_cache_len=max_generated_length)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> outputs.past_key_values # access cache filled with key/values from generation
         SlidingWindowCache()
         ```
     """

-    # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
-    def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+    # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
         layers = [SlidingWindowLayer(max_cache_len, config.sliding_window) for _ in range(config.num_hidden_layers)]
         super().__init__(layers=layers)

@@ -1249,15 +1249,15 @@ class HybridCache(Cache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = HybridCache(max_cache_len=max_generated_length, config=model.config)
+        >>> past_key_values = HybridCache(config=model.config, max_cache_len=max_generated_length)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> outputs.past_key_values # access cache filled with key/values from generation
         HybridCache()
         ```
     """

-    # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
-    def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+    # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
         if hasattr(config, "layer_types"):
             layers = []
             for layer_type in config.layer_types:
@@ -1288,8 +1288,8 @@ class OffloadedHybridCache(Cache):
     See `Cache` for details on common methods that are implemented by all cache classes.
     """

-    # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
-    def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+    # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
         if hasattr(config, "layer_types"):
             layers = []
             for layer_type in config.layer_types:
diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 0afe61ca78f5..a56a48ab844d 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -670,7 +670,7 @@ def __init__(
             raise AssertionError("Model must have caching enabled.")

         # Initialize the HybridCache
-        self.cache = HybridCache(max_cache_len=generation_config.cache_config.get("max_cache_len"), config=config)
+        self.cache = HybridCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
         max_batch_size = generation_config.cache_config.get("batch_size")
@@ -818,7 +818,7 @@ def __init__(self, model, max_static_cache_length, batch_size):
         self.config = model.config

         # Initialize static cache for decoder and DynamicCache for encoder
-        self.static_cache = StaticCache(max_cache_len=max_static_cache_length, config=self.config)
+        self.static_cache = StaticCache(config=self.config, max_cache_len=max_static_cache_length)
         head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
         num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, "cpu")
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index d58837cc0fbd..26be82b9da82 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -504,7 +504,7 @@ def test_stacked_causal_mask_static_cache(self):

         # upgrade the model with StaticCache
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
-        past_key_values = StaticCache(max_cache_len=max_cache_len, config=self.model.config)
+        past_key_values = StaticCache(config=self.model.config, max_cache_len=max_cache_len)

         padded_attention_mask = torch.nn.functional.pad(
             input=mask_shared_prefix,
@@ -546,7 +546,7 @@ def test_partial_stacked_causal_mask_static_cache(self):

         # upgrade the model with StaticCache
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
-        past_key_values = StaticCache(max_cache_len=max_cache_len, config=self.model.config)
+        past_key_values = StaticCache(config=self.model.config, max_cache_len=max_cache_len)

         # forward run for the first part of input
         part_a = 3  # split point

From b14bd0ad526b9869409b11fb8d9025528671e246 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 12 Aug 2025 11:58:05 +0200
Subject: [PATCH 2/2] in generate as well

---
 src/transformers/generation/utils.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 177cfeeda6a6..eabc6f2926d3 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1859,10 +1859,7 @@ def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len:
         )

         if need_new_cache:
-            cache_kwargs = {
-                "max_cache_len": max_cache_len,
-                "config": self.config.get_text_config(),
-            }
+            cache_kwargs = {"config": self.config.get_text_config(), "max_cache_len": max_cache_len}
             self._cache = cache_cls(**cache_kwargs)
             if requires_cross_attention_cache:
                 encoder_kwargs = cache_kwargs.copy()
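For context, a minimal sketch of the calling convention the two patches settle on, adapted from the docstring examples in the diff above; the checkpoint name is illustrative and any causal LM checkpoint would do. The constructors now take `config` first and `max_cache_len` second, while the swallowed `*args`/`**kwargs` keep older call sites that passed extra arguments from crashing.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer("Hello, my name is", return_tensors="pt")

# Leave empty space for 10 new tokens, then prefill the cache with one forward pass.
# Note the order the patches standardize on: config first, max_cache_len second.
max_generated_length = inputs.input_ids.shape[1] + 10
past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
print(outputs.past_key_values)  # StaticCache()
```

Because both parameters are still accepted as keywords, existing call sites that pass `config=...` and `max_cache_len=...` by name are unaffected; only positional callers and callers passing the old extra arguments rely on the new `*args`/`**kwargs` escape hatch.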