28 changes: 14 additions & 14 deletions src/transformers/cache_utils.py
@@ -1148,15 +1148,15 @@ class StaticCache(Cache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = StaticCache(max_cache_len=max_generated_length, config=model.config)
+ >>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
StaticCache()
```
"""

- # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
- def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+ # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
super().__init__(layers=layers)
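
The *args, **kwargs tail in the new signature is what preserves backward compatibility: extra arguments that older releases accepted are absorbed instead of raising a TypeError. A minimal sketch of that behavior (the config values and the legacy device keyword are illustrative assumptions, not taken from the PR):

from transformers import LlamaConfig, StaticCache

# Tiny config purely for illustration.
config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4)

# Canonical call after this change: config first, then max_cache_len.
cache = StaticCache(config=config, max_cache_len=32)

# A legacy-style call passing an extra keyword that older versions accepted
# is swallowed by *args/**kwargs rather than crashing.
legacy_cache = StaticCache(config, 32, device="cpu")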

@@ -1183,15 +1183,15 @@ class OffloadedStaticCache(Cache):

>>> # Prepare a cache class with offloading
>>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = OffloadedStaticCache(max_cache_len=max_generated_length, config=model.config)
+ >>> past_key_values = OffloadedStaticCache(config=model.config, max_cache_len=max_generated_length)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache with offloaded layers
OffloadedStaticCache()
```
"""

- # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
- def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+ # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
super().__init__(layers=layers, offloading=True)

@@ -1214,15 +1214,15 @@ class SlidingWindowCache(Cache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = SlidingWindowCache(max_cache_len=max_generated_length, config=model.config)
+ >>> past_key_values = SlidingWindowCache(config=model.config, max_cache_len=max_generated_length)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
SlidingWindowCache()
```
"""

- # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
- def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+ # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
layers = [SlidingWindowLayer(max_cache_len, config.sliding_window) for _ in range(config.num_hidden_layers)]
super().__init__(layers=layers)

@@ -1249,15 +1249,15 @@ class HybridCache(Cache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = HybridCache(max_cache_len=max_generated_length, config=model.config)
+ >>> past_key_values = HybridCache(config=model.config, max_cache_len=max_generated_length)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache()
```
"""

- # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
- def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+ # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
if hasattr(config, "layer_types"):
layers = []
for layer_type in config.layer_types:
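
The truncated hunk above chooses a cache layer class per entry in config.layer_types. A sketch of that pattern (the layer-type strings and window size are assumptions inferred from the visible code, not confirmed by the diff):

from transformers.cache_utils import SlidingWindowLayer, StaticLayer

layer_types = ["full_attention", "sliding_attention", "full_attention"]  # assumed labels
max_cache_len, sliding_window = 32, 8

layers = []
for layer_type in layer_types:
    # Sliding layers keep only the most recent sliding_window positions;
    # full-attention layers keep the entire max_cache_len window.
    if layer_type == "sliding_attention":
        layers.append(SlidingWindowLayer(max_cache_len, sliding_window))
    else:
        layers.append(StaticLayer(max_cache_len))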
@@ -1288,8 +1288,8 @@ class OffloadedHybridCache(Cache):
See `Cache` for details on common methods that are implemented by all cache classes.
"""

- # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
- def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
+ # Pass-in args and kwargs as well to avoid crashing for BC (it used more arguments before)
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
if hasattr(config, "layer_types"):
layers = []
for layer_type in config.layer_types:
5 changes: 1 addition & 4 deletions src/transformers/generation/utils.py
@@ -1859,10 +1859,7 @@ def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len:
)

if need_new_cache:
- cache_kwargs = {
-     "max_cache_len": max_cache_len,
-     "config": self.config.get_text_config(),
- }
+ cache_kwargs = {"config": self.config.get_text_config(), "max_cache_len": max_cache_len}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy()
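
Since every static-shaped cache class now shares the same two leading parameters, _get_cache can build a single kwargs dict and expand it into whichever class was requested. A sketch of that dispatch (config values are illustrative):

from transformers import MistralConfig, SlidingWindowCache, StaticCache

config = MistralConfig(
    num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=4, sliding_window=8
)
cache_kwargs = {"config": config, "max_cache_len": 32}

for cache_cls in (StaticCache, SlidingWindowCache):
    cache = cache_cls(**cache_kwargs)  # uniform construction across implementations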
4 changes: 2 additions & 2 deletions src/transformers/integrations/executorch.py
@@ -670,7 +670,7 @@ def __init__(
raise AssertionError("Model must have caching enabled.")

# Initialize the HybridCache
- self.cache = HybridCache(max_cache_len=generation_config.cache_config.get("max_cache_len"), config=config)
+ self.cache = HybridCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
max_batch_size = generation_config.cache_config.get("batch_size")
@@ -818,7 +818,7 @@ def __init__(self, model, max_static_cache_length, batch_size):
self.config = model.config

# Initialize static cache for decoder and DynamicCache for encoder
- self.static_cache = StaticCache(max_cache_len=max_static_cache_length, config=self.config)
+ self.static_cache = StaticCache(config=self.config, max_cache_len=max_static_cache_length)
head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, "cpu")
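
early_initialization, called with the same argument pattern visible above, eagerly allocates the fixed-shape key/value buffers so that an exported graph sees static shapes from the start. A sketch under assumed config values:

import torch
from transformers import LlamaConfig, StaticCache

config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4)
cache = StaticCache(config=config, max_cache_len=32)

# Mirror the head/dim computation used in the hunk above.
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)

# batch_size=1; dtype and device mirror the call shown in the diff.
cache.early_initialization(1, num_heads, head_dim, torch.float32, "cpu")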
4 changes: 2 additions & 2 deletions tests/models/llama/test_modeling_llama.py
@@ -504,7 +504,7 @@ def test_stacked_causal_mask_static_cache(self):

# upgrade the model with StaticCache
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
- past_key_values = StaticCache(max_cache_len=max_cache_len, config=self.model.config)
+ past_key_values = StaticCache(config=self.model.config, max_cache_len=max_cache_len)

padded_attention_mask = torch.nn.functional.pad(
input=mask_shared_prefix,
@@ -546,7 +546,7 @@ def test_partial_stacked_causal_mask_static_cache(self):

# upgrade the model with StaticCache
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
- past_key_values = StaticCache(max_cache_len=max_cache_len, config=self.model.config)
+ past_key_values = StaticCache(config=self.model.config, max_cache_len=max_cache_len)

# forward run for the first part of input
part_a = 3 # split point
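
Both tests pad the custom attention mask out to max_cache_len before running with a StaticCache, since the mask width must match the cache capacity rather than the prompt length. A minimal sketch of that padding step (the mask shape is an assumption; the tests use stacked causal masks):

import torch

max_cache_len = 16
mask = torch.ones(1, 10, dtype=torch.long)  # stand-in attention mask, width 10 < max_cache_len

# Right-pad the last dimension with zeros up to the cache capacity.
padded = torch.nn.functional.pad(mask, (0, max_cache_len - mask.shape[-1]), value=0)
print(padded.shape)  # torch.Size([1, 16])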