From 95edf6c6f2c4b8669fce03c1a8baf0d2764c8f7b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 9 Apr 2025 09:57:37 +0000
Subject: [PATCH] cache batch size arg

---
 benchmark/llama.py                          | 14 +--
 docs/source/en/kv_cache.md                  |  6 +-
 docs/source/en/llm_optims.md                |  4 +-
 docs/source/en/model_doc/gemma2.md          |  8 +-
 docs/source/ko/llm_optims.md                |  4 +-
 src/transformers/cache_utils.py             | 93 ++++++++++---------
 src/transformers/generation/utils.py        |  4 +-
 src/transformers/integrations/executorch.py |  4 +-
 .../models/cohere2/modeling_cohere2.py      |  2 +-
 .../models/cohere2/modular_cohere2.py       |  2 +-
 .../models/gemma2/modeling_gemma2.py        |  2 +-
 .../models/gemma2/modular_gemma2.py         |  2 +-
 .../models/gemma3/modeling_gemma3.py        |  2 +-
 .../models/gemma3/modular_gemma3.py         |  2 +-
 tests/generation/test_utils.py              |  6 +-
 .../diffllama/test_modeling_diffllama.py    |  4 +-
 tests/models/llama/test_modeling_llama.py   |  4 +-
 tests/models/mamba/test_modeling_mamba.py   |  2 +-
 tests/models/phi3/test_modeling_phi3.py     |  2 +-
 tests/models/phimoe/test_modeling_phimoe.py |  2 +-
 tests/utils/test_cache_utils.py             |  8 +-
 21 files changed, 90 insertions(+), 87 deletions(-)

diff --git a/benchmark/llama.py b/benchmark/llama.py
index 1857dee3d66b..bc91b29b581d 100644
--- a/benchmark/llama.py
+++ b/benchmark/llama.py
@@ -118,7 +118,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):
     with torch.no_grad():
         past_key_values = StaticCache(
             model.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + num_tokens_to_generate,
@@ -144,7 +144,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

         past_key_values = StaticCache(
             model.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + num_tokens_to_generate,
@@ -187,7 +187,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):
         # TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
         past_key_values = StaticCache(
             model.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + num_tokens_to_generate + 10,
@@ -254,7 +254,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

         past_key_values = StaticCache(
             model.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,
@@ -271,7 +271,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

         past_key_values = StaticCache(
             model.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,
@@ -287,7 +287,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

         past_key_values = StaticCache(
             model.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,
@@ -303,7 +303,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

         past_key_values = StaticCache(
             model.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,
diff --git a/docs/source/en/kv_cache.md b/docs/source/en/kv_cache.md
index 36f82fb3dc9a..999d644a892b 100644
--- a/docs/source/en/kv_cache.md
+++ b/docs/source/en/kv_cache.md
@@ -336,9 +336,9 @@ model_id = "meta-llama/Llama-2-7b-chat-hf"
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
 tokenizer = AutoTokenizer.from_pretrained(model_id)

-# Init StaticCache with big enough max-length (1024 tokens for the below example) 
+# Init StaticCache with big enough max-length (1024 tokens for the below example)
 # You can also init a DynamicCache, if that suits you better
-prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
+prompt_cache = StaticCache(config=model.config, batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)

 INITIAL_PROMPT = "You are a helpful assistant. "
 inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
@@ -351,7 +351,7 @@ responses = []
 for prompt in prompts:
     new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
     past_key_values = copy.deepcopy(prompt_cache)
-    outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20) 
+    outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
     response = tokenizer.batch_decode(outputs)[0]
     responses.append(response)

diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index e8e20dab5db6..7c9bc154ab6e 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -93,7 +93,7 @@ model.generation_config.max_new_tokens = 16

 past_key_values = StaticCache(
     config=model.config,
-    max_batch_size=1,
+    batch_size=1,
     # If you plan to reuse the cache, make sure the cache length is large enough for all cases
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,
@@ -159,7 +159,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
-        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+        config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
     )
     cache_position = torch.arange(seq_length, device=torch_device)
     generated_ids = torch.zeros(
diff --git a/docs/source/en/model_doc/gemma2.md b/docs/source/en/model_doc/gemma2.md
index 4380cae26903..e63e0193f5d5 100644
--- a/docs/source/en/model_doc/gemma2.md
+++ b/docs/source/en/model_doc/gemma2.md
@@ -58,7 +58,7 @@ pipe("Explain quantum computing simply. ", max_new_tokens=50)

- 
+
 ```python
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -89,7 +89,7 @@ echo -e "Explain quantum computing simply." | transformers-cli run --task text-g

 Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

- 
+
 The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.

 ```python
@@ -118,7 +118,7 @@ Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/bl
 ```python
 from transformers.utils.attention_visualizer import AttentionMaskVisualizer
 visualizer = AttentionMaskVisualizer("google/gemma-2b")
-visualizer("You are an assistant. Make sure you print me") 
+visualizer("You are an assistant. Make sure you print me")
 ```
@@ -137,7 +137,7 @@ visualizer("You are an assistant. Make sure you print me")
 inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
 max_generated_length = inputs.input_ids.shape[1] + 10
-past_key_values = HybridCache(config=model.config, max_batch_size=1,
+past_key_values = HybridCache(config=model.config, batch_size=1,
 max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
 outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
 ```
diff --git a/docs/source/ko/llm_optims.md b/docs/source/ko/llm_optims.md
index f6eaa58c0004..17495cbd0fe2 100644
--- a/docs/source/ko/llm_optims.md
+++ b/docs/source/ko/llm_optims.md
@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16

 past_key_values = StaticCache(
     config=model.config,
-    max_batch_size=1,
+    batch_size=1,
     # 캐시를 재사용할 계획이 있는 경우, 모든 경우에 충분한 캐시 길이를 설정해야 합니다
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,
@@ -161,7 +161,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
-        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+        config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
     )
     cache_position = torch.arange(seq_length, device=torch_device)
     generated_ids = torch.zeros(
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 12673a8d41a5..08e6fc0698a9 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -12,6 +12,7 @@

 from .configuration_utils import PretrainedConfig
 from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
+from .utils.deprecation import deprecate_kwarg


 if is_hqq_available():
@@ -1142,10 +1143,9 @@ class StaticCache(Cache):
     Parameters:
         config (`PretrainedConfig`):
             The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
-            smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
-            number of beams if you are running beam search
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
         max_cache_len (`int`, *optional*):
             The maximum sequence length with which the model will be used.
         device (`torch.device` or `str`, *optional*):
@@ -1172,7 +1172,7 @@ class StaticCache(Cache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> outputs.past_key_values # access cache filled with key/values from generation
         StaticCache()
@@ -1181,17 +1181,18 @@ class StaticCache(Cache):

     is_compileable = True

+    @deprecate_kwarg("max_batch_size", version="4.53", new_name="batch_size")
     def __init__(
         self,
         config: PretrainedConfig,
-        max_batch_size: int,
+        batch_size: int,
         max_cache_len: Optional[int] = None,
         device: Union[torch.device, str, None] = None,
         dtype: torch.dtype = torch.float32,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
-        self.max_batch_size = max_batch_size
+        self.batch_size = batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
@@ -1209,7 +1210,7 @@ def __init__(
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
-        cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         device = torch.device(device) if device is not None else None
         for idx in range(config.num_hidden_layers):
             if layer_device_map is not None:
@@ -1313,8 +1314,8 @@ class SlidingWindowCache(StaticCache):
     Parameters:
         config (`PretrainedConfig`):
             The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
             smaller batch size is used.
         max_cache_len (`int`, *optional*):
             The maximum sequence length with which the model will be used.
@@ -1341,7 +1342,7 @@ class SlidingWindowCache(StaticCache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> outputs.past_key_values # access cache filled with key/values from generation
         SlidingWindowCache()
@@ -1351,10 +1352,11 @@ class SlidingWindowCache(StaticCache):
     is_sliding = True
     is_compileable = True

+    @deprecate_kwarg("max_batch_size", version="4.53", new_name="batch_size")
     def __init__(
         self,
         config: PretrainedConfig,
-        max_batch_size: int,
+        batch_size: int,
         max_cache_len: Optional[int] = None,
         device: Union[torch.device, str, None] = None,
         dtype: torch.dtype = torch.float32,
@@ -1369,7 +1371,7 @@ def __init__(
         max_cache_len = min(config.sliding_window, max_cache_len)
         super().__init__(
             config=config,
-            max_batch_size=max_batch_size,
+            batch_size=batch_size,
             max_cache_len=max_cache_len,
             device=device,
             dtype=dtype,
@@ -1619,8 +1621,8 @@ class HybridCache(Cache):
     Parameters:
         config (`PretrainedConfig):
             The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
             smaller batch size is used.
         max_cache_len (`int`, *optional*):
             The maximum sequence length with which the model will be used.
@@ -1647,7 +1649,7 @@ class HybridCache(Cache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> outputs.past_key_values # access cache filled with key/values from generation
         HybridCache()
@@ -1658,10 +1660,11 @@ class HybridCache(Cache):
     # ALL changes from the PR that commented the line below when reactivating it.
     # is_compileable = True

+    @deprecate_kwarg("max_batch_size", version="4.53", new_name="batch_size")
     def __init__(
         self,
         config: PretrainedConfig,
-        max_batch_size: int,
+        batch_size: int,
         max_cache_len: Optional[int] = None,
         device: Union[torch.device, str, None] = None,
         dtype: torch.dtype = torch.float32,
@@ -1675,7 +1678,7 @@ def __init__(
                 "config and it's not set to None."
             )
         self.max_cache_len = max_cache_len
-        self.max_batch_size = max_batch_size
+        self.batch_size = batch_size
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         self.head_dim = (
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -1692,9 +1695,9 @@ def __init__(
         )
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
         sliding_cache_shape = (
-            self.max_batch_size,
+            self.batch_size,
             self.num_key_value_heads,
             min(config.sliding_window, max_cache_len),
             self.head_dim,
@@ -1823,8 +1826,8 @@ class HybridChunkedCache(Cache):
     Parameters:
         config (`PretrainedConfig):
             The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
             smaller batch size is used.
         max_cache_len (`int`, *optional*):
             The maximum sequence length with which the model will be used.
@@ -1851,7 +1854,7 @@ class HybridChunkedCache(Cache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> outputs.past_key_values # access cache filled with key/values from generation
         HybridCache()
@@ -1862,10 +1865,11 @@ class HybridChunkedCache(Cache):
     # ALL changes from the PR that commented the line below when reactivating it.
     is_compileable = True

+    @deprecate_kwarg("max_batch_size", version="4.53", new_name="batch_size")
     def __init__(
         self,
         config: PretrainedConfig,
-        max_batch_size: int,
+        batch_size: int,
         max_cache_len: Optional[int] = None,
         device: Union[torch.device, str, None] = None,
         dtype: torch.dtype = torch.bfloat16,
@@ -1877,7 +1881,7 @@ def __init__(
         else:
             self.sliding_window = config.sliding_window
         self.max_cache_len = max_cache_len
-        self.max_batch_size = max_batch_size
+        self.batch_size = batch_size
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self._dtype = dtype

@@ -1897,13 +1901,8 @@ def initialise_cache_layer(self, layer_idx, key_states):
         num_key_value_heads = key_states.shape[1]
         device = key_states.device

-        global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
-        sliding_cache_shape = (
-            self.max_batch_size,
-            num_key_value_heads,
-            self.sliding_window,
-            self.head_dim,
-        )
+        global_cache_shape = (self.batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.batch_size, num_key_value_heads, self.sliding_window, self.head_dim)
         # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
         # breaks when updating the cache.
         cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
@@ -2019,8 +2018,9 @@ class MambaCache:
     Arguments:
         config (`PretrainedConfig):
             The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
         dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
             The default `dtype` to use when initializing the layer.
         device (`torch.device` or `str`, *optional*):
@@ -2037,7 +2037,7 @@ class MambaCache:
         >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

         >>> # Prepare a cache class and pass it to model's forward
-        >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
+        >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> outputs.past_key_values
         MambaCache()
@@ -2047,14 +2047,15 @@ class MambaCache:
     is_compileable = True

     # TODO (joao): add layer_device_map arg and update code in `generate` accordingly
+    @deprecate_kwarg("max_batch_size", version="4.53", new_name="batch_size")
     def __init__(
         self,
         config: PretrainedConfig,
-        max_batch_size: int,
+        batch_size: int,
         dtype: torch.dtype = torch.float16,
         device: Union[torch.device, str, None] = None,
     ):
-        self.max_batch_size = max_batch_size
+        self.batch_size = batch_size
         self._dtype = dtype
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
@@ -2065,14 +2066,14 @@ def __init__(
         device = torch.device(device) if device is not None else None
         for _ in range(config.num_hidden_layers):
             conv_state: torch.Tensor = torch.zeros(
-                self.max_batch_size,
+                self.batch_size,
                 self.intermediate_size,
                 self.conv_kernel_size,
                 device=device,
                 dtype=self._dtype,
             )
             ssm_state: torch.Tensor = torch.zeros(
-                self.max_batch_size,
+                self.batch_size,
                 self.intermediate_size,
                 self.ssm_state_size,
                 device=device,
@@ -2121,8 +2122,9 @@ class OffloadedStaticCache(StaticCache):
         config (`PretrainedConfig):
             The configuration file defining the shape-related attributes required to initialize
             the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
         max_cache_len (`int`):
             The maximum sequence length with which the model will be used.
         device (`Union[str, torch.device]`):
@@ -2150,7 +2152,7 @@ class OffloadedStaticCache(StaticCache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> past_key_values = OffloadedStaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
         ```
@@ -2158,10 +2160,11 @@ class OffloadedStaticCache(StaticCache):

     is_compileable = True

+    @deprecate_kwarg("max_batch_size", version="4.53", new_name="batch_size")
     def __init__(
         self,
         config: PretrainedConfig,
-        max_batch_size: int,
+        batch_size: int,
         max_cache_len: Optional[int],
         device: Union[str, torch.device],
         dtype: Optional[torch.dtype] = None,
@@ -2169,7 +2172,7 @@ def __init__(
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super(Cache, self).__init__()
-        self.max_batch_size = max_batch_size
+        self.batch_size = batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
@@ -2184,7 +2187,7 @@ def __init__(
             else config.num_key_value_heads
         )

-        cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
+        cache_shape = (batch_size, num_key_value_heads, self.max_cache_len, head_dim)

         # Create offloaded CPU tensors.
         self.key_cache: List[torch.Tensor] = []
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 4ea1f88136d3..cb98849dbcf7 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1832,7 +1832,7 @@ def _get_cache(
         need_new_cache = (
             not hasattr(self, "_cache")
             or (not isinstance(cache_to_check, cache_cls))
-            or cache_to_check.max_batch_size != batch_size
+            or cache_to_check.batch_size != batch_size
         )
         if cache_implementation != "mamba":
             need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
@@ -1852,7 +1852,7 @@ def _get_cache(
         layer_device_map = self._get_layer_device_map_for_cache_init()
         cache_kwargs = {
             "config": self.config.get_text_config(),
-            "max_batch_size": batch_size,
+            "batch_size": batch_size,
             "max_cache_len": max_cache_len,
             "dtype": cache_dtype,
             "device": device,
diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 591c556e59f0..ed3322093a30 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -71,7 +71,7 @@ def __init__(self, model: PreTrainedModel):
         self.model = model
         self.static_cache = StaticCache(
             config=self.model.config,
-            max_batch_size=self.model.generation_config.cache_config.batch_size,
+            batch_size=self.model.generation_config.cache_config.batch_size,
             max_cache_len=self.model.generation_config.cache_config.max_cache_len,
             device=self.model.generation_config.cache_config.device,
             dtype=self.model.dtype,
@@ -263,7 +263,7 @@ def __init__(self, model, max_static_cache_length, batch_size):
         # Initialize static cache
         self.static_cache = StaticCache(
             config=self.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             max_cache_len=max_static_cache_length,
             device="cpu",
             dtype=torch.float32,
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index 18a3a50ac157..f5af771b427a 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -565,7 +565,7 @@ def forward(
             # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
             past_key_values = HybridCache(
                 self.config,
-                max_batch_size=batch_size,
+                batch_size=batch_size,
                 max_cache_len=seq_len,
                 dtype=inputs_embeds.dtype,
                 device=self.device,
diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py
index 3d1bdaeca944..f773e56fdb3a 100644
--- a/src/transformers/models/cohere2/modular_cohere2.py
+++ b/src/transformers/models/cohere2/modular_cohere2.py
@@ -489,7 +489,7 @@ def forward(
             # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
             past_key_values = HybridCache(
                 self.config,
-                max_batch_size=batch_size,
+                batch_size=batch_size,
                 max_cache_len=seq_len,
                 dtype=inputs_embeds.dtype,
                 device=self.device,
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 144a94ef33e9..97e3469eaefa 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -568,7 +568,7 @@ def forward(
             # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
             past_key_values = HybridCache(
                 self.config,
-                max_batch_size=batch_size,
+                batch_size=batch_size,
                 max_cache_len=seq_len,
                 dtype=inputs_embeds.dtype,
                 device=self.device,
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 384f3e08023d..ab2912ab66e5 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -439,7 +439,7 @@ def forward(
             # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
             past_key_values = HybridCache(
                 self.config,
-                max_batch_size=batch_size,
+                batch_size=batch_size,
                 max_cache_len=seq_len,
                 dtype=inputs_embeds.dtype,
                 device=self.device,
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 0988e2692aa4..226f267c5b79 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -655,7 +655,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                max_batch_size=batch_size,
+                batch_size=batch_size,
                 max_cache_len=seq_len,
                 dtype=inputs_embeds.dtype,
             )
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index 3f7292f13a07..6da37502d75c 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -620,7 +620,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                max_batch_size=batch_size,
+                batch_size=batch_size,
                 max_cache_len=seq_len,
                 dtype=inputs_embeds.dtype,
             )
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index e16d477335d5..021492281d7c 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -3979,7 +3979,7 @@ def test_init_static_cache_multi_gpu(self):
         # TODO: We need to raise a warning in case the cache is not set correctly
         # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
         #     past_key_values = StaticCache(
-        #         config=model.config, max_batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
+        #         config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
         #     )
         #     results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)

@@ -3987,7 +3987,7 @@ def test_init_static_cache_multi_gpu(self):
         layer_device_map = {0: 0, 1: 1}
         past_key_values = StaticCache(
             config=model.config,
-            max_batch_size=1,
+            batch_size=1,
             max_cache_len=30,
             device=torch_device,
             dtype=model.dtype,
@@ -4189,7 +4189,7 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
         query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
         static_cache = StaticCache(
             config=config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=torch.float32,
diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py
index c738fbf76d1a..f8b9612faf7c 100644
--- a/tests/models/diffllama/test_modeling_diffllama.py
+++ b/tests/models/diffllama/test_modeling_diffllama.py
@@ -778,7 +778,7 @@ def test_stacked_causal_mask_static_cache(self):
         max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            max_batch_size=1,
+            batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
@@ -826,7 +826,7 @@ def test_partial_stacked_causal_mask_static_cache(self):
         max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            max_batch_size=1,
+            batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index fe843938cf54..4d1fd7ad22f2 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -836,7 +836,7 @@ def test_stacked_causal_mask_static_cache(self):
         max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            max_batch_size=1,
+            batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
@@ -884,7 +884,7 @@ def test_partial_stacked_causal_mask_static_cache(self):
         max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            max_batch_size=1,
+            batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py
index 840493648ffc..ac27584b805e 100644
--- a/tests/models/mamba/test_modeling_mamba.py
+++ b/tests/models/mamba/test_modeling_mamba.py
@@ -420,7 +420,7 @@ def test_dtype_mismatch_handled_in_cache(self):
         model.eval()

         # Create cache with float32 dtype
-        cache_params = MambaCache(config, max_batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)
+        cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)

         # If code is correct, no error occurs and test passes
         outputs = model(
diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py
index c2af64ffd8ab..b76888b8f3a1 100644
--- a/tests/models/phi3/test_modeling_phi3.py
+++ b/tests/models/phi3/test_modeling_phi3.py
@@ -51,7 +51,7 @@ def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int):
         self.model = model
         self.cache = StaticCache(
             config=model.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             max_cache_len=max_seq_len,
             device=self.model.device,
             dtype=self.model.dtype,
diff --git a/tests/models/phimoe/test_modeling_phimoe.py b/tests/models/phimoe/test_modeling_phimoe.py
index 7f548bd2dc0b..bacda1bf3595 100644
--- a/tests/models/phimoe/test_modeling_phimoe.py
+++ b/tests/models/phimoe/test_modeling_phimoe.py
@@ -50,7 +50,7 @@ def __init__(self, model: PhimoeForCausalLM, batch_size: int, max_seq_len: int):
         self.model = model
         self.cache = StaticCache(
             config=model.config,
-            max_batch_size=batch_size,
+            batch_size=batch_size,
             max_cache_len=max_seq_len,
             device=self.model.device,
             dtype=self.model.dtype,
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index bbf0268c8b48..b45ebf7e8ae8 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -150,7 +150,7 @@ def _random_kvs(config):
             return random_keys, random_values

         mha_config = LlamaConfig(num_attention_heads=32)
-        mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
+        mha_static_cache = StaticCache(config=mha_config, batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = mha_static_cache.update(
             *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
         )
@@ -158,7 +158,7 @@ def _random_kvs(config):
         self.assertTrue(cached_values.shape == (1, 32, 10, 128))

         gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
-        gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
+        gqa_static_cache = StaticCache(config=gqa_config, batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = gqa_static_cache.update(
             *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
         )
@@ -166,7 +166,7 @@ def _random_kvs(config):
         self.assertTrue(cached_values.shape == (1, 4, 10, 128))

         mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
-        mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
+        mqa_static_cache = StaticCache(config=mqa_config, batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = mqa_static_cache.update(
             *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
         )
@@ -670,7 +670,7 @@ def test_cache_copy(self):
         model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)

         prompt_cache = StaticCache(
-            config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16
+            config=model.config, batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16
         )
         INITIAL_PROMPT = "You are a helpful assistant. "