From d8944058ac3e29926c014d278643ae72d9a8ed4d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 17:19:21 +0200 Subject: [PATCH 01/48] Add new dynamic cache --- src/transformers/cache_utils.py | 79 +++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 4e4a1ee26c12..37cc8f24b221 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -545,6 +545,85 @@ def batch_select_indices(self, indices: torch.Tensor): self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] +class DynamicSlidingWindowCache(DynamicCache): + """ + A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. + This is the default for generative models with sliding window attention. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicSlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicSlidingWindowCache(model.config.sliding_window) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicSlidingWindowCache() + ``` + """ + + def __init__(self, sliding_window: int): + super().__init__() + self.sliding_window = sliding_window + self.slicing_ + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous + tokens according to the sliding window if needed. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + # Add only up to sliding window size if larger + self.key_cache.append(key_states[:, :, -self.sliding_window:, :]) + self.value_cache.append(value_states[:, :, -self.sliding_window:, :]) + else: + new_seq_len = key_states.shape[-2] + current_seq_len = self.get_seq_length(layer_idx) + if new_seq_len + current_seq_len > self.sliding_window: + # We need to slice + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) + else: + # Similar to DynamicCache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + class OffloadedCache(DynamicCache): """ A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. From 3b0984b99e73bca779b6288b953f488479323acd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 17:46:05 +0200 Subject: [PATCH 02/48] Add cache by default in generate for models supporting it --- src/transformers/generation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5da4878513eb..e49e9d9636bb 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,6 +28,7 @@ from ..cache_utils import ( Cache, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig, From 345e695d73b6828886d99bbb1ecfcce952e9d235 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 18:02:54 +0200 Subject: [PATCH 03/48] Add to __init__ and correct typo --- src/transformers/__init__.py | 1 + src/transformers/cache_utils.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ab829c6894c0..9bae0dff2d09 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1268,6 +1268,7 @@ "Cache", "CacheConfig", "DynamicCache", + "DynamicSlidingWindowCache", "EncoderDecoderCache", "HQQQuantizedCache", "HybridCache", diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 37cc8f24b221..99d2d418bd0f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -574,7 +574,6 @@ class DynamicSlidingWindowCache(DynamicCache): def __init__(self, sliding_window: int): super().__init__() self.sliding_window = sliding_window - self.slicing_ def update( self, From 38e82b5432189a95f265054e9e0f7898ea290625 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 19:41:44 +0200 Subject: [PATCH 04/48] Correct output if prefill larger than sliding window + compatibility --- src/transformers/cache_utils.py | 35 ++++++++++++++++++++++++---- src/transformers/generation/utils.py | 6 ++--- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 99d2d418bd0f..a517d9ef90b5 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -606,15 +606,17 @@ def update( # Update the cache if len(self.key_cache) <= layer_idx: # Add only up to sliding window size if larger - self.key_cache.append(key_states[:, :, -self.sliding_window:, :]) - self.value_cache.append(value_states[:, :, -self.sliding_window:, :]) + self.key_cache.append(key_states[..., -self.sliding_window:, :]) + self.value_cache.append(value_states[..., -self.sliding_window:, :]) + # We should return full states during prefill even though we only save up to sliding window + return key_states, value_states else: new_seq_len = key_states.shape[-2] current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: # We need to slice - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) else: # Similar to DynamicCache self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) @@ -622,6 +624,31 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]: + """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicSlidingWindowCache(self.sliding_window) + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": + """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + cache = cls(splits[0].sliding_window) + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + # Legacy format does not really make sense here even though it is a DynamicCache -> we set methods to None + from_legacy_cache = None + to_legacy_cache = None + class OffloadedCache(DynamicCache): """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e49e9d9636bb..0f61cc84db88 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4604,10 +4604,8 @@ def _concat(data): if isinstance(data[0], torch.Tensor): return torch.cat(data, dim=0) # New cache format - elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) - elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) + elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)): + return data[0].__class__.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): From c46a92a23b4c25fea9bf9fc9b9df9d8daba98dde Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 22:08:36 +0200 Subject: [PATCH 05/48] Add legacy format handling --- src/transformers/cache_utils.py | 21 +++++++++++++++------ src/transformers/generation/utils.py | 2 +- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index a517d9ef90b5..df34b8de6919 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -623,9 +623,21 @@ def update( self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] + + @classmethod + def from_legacy_cache(cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. Used for + backward compatibility.""" + cache = cls(sliding_window) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]: - """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" out = [] for i in range(0, full_batch_size, split_size): current_split = DynamicSlidingWindowCache(self.sliding_window) @@ -637,17 +649,14 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSli @classmethod def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": - """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" cache = cls(splits[0].sliding_window) for idx in range(len(splits[0])): layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) cache.update(layer_keys, layer_values, idx) return cache - - # Legacy format does not really make sense here even though it is a DynamicCache -> we set methods to None - from_legacy_cache = None - to_legacy_cache = None class OffloadedCache(DynamicCache): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0f61cc84db88..919a0bcc6b79 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2354,7 +2354,7 @@ def typeerror(): should_convert_cache = generation_config.return_legacy_cache is_user_defined_cache = user_defined_cache is not None is_default_cache_type = ( - type(result.past_key_values) == DynamicCache # noqa E721 + type(result.past_key_values) in (DynamicCache, DynamicSlidingWindowCache) # noqa E721 or ( isinstance(result.past_key_values, EncoderDecoderCache) and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 From 02b8506eb542545011c72c6f2977ec4e96e8d323 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 13:01:41 +0200 Subject: [PATCH 06/48] style --- src/transformers/cache_utils.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index df34b8de6919..08eb329cab83 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -606,8 +606,8 @@ def update( # Update the cache if len(self.key_cache) <= layer_idx: # Add only up to sliding window size if larger - self.key_cache.append(key_states[..., -self.sliding_window:, :]) - self.value_cache.append(value_states[..., -self.sliding_window:, :]) + self.key_cache.append(key_states[..., -self.sliding_window :, :]) + self.value_cache.append(value_states[..., -self.sliding_window :, :]) # We should return full states during prefill even though we only save up to sliding window return key_states, value_states else: @@ -615,17 +615,23 @@ def update( current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: # We need to slice - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], value_states], dim=-2 + ) else: # Similar to DynamicCache self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] - + @classmethod - def from_legacy_cache(cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + def from_legacy_cache( + cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "DynamicCache": """Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. Used for backward compatibility.""" cache = cls(sliding_window) @@ -646,7 +652,7 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSli current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] out.append(current_split) return out - + @classmethod def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in From 7a98aac880bf130669b295694997270e00965ae7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 14:29:44 +0200 Subject: [PATCH 07/48] add docs --- docs/source/en/internal/generation_utils.md | 7 +++++++ src/transformers/cache_utils.py | 5 +++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index a81d202c6634..3d4dfef70267 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -367,6 +367,13 @@ A [`Constraint`] can be used to force the generation to include specific tokens - to_legacy_cache - from_legacy_cache +[[autodoc]] DynamicSlidingWindowCache + - update + - get_seq_length + - reorder_cache + - to_legacy_cache + - from_legacy_cache + [[autodoc]] QuantizedCache - update - get_seq_length diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 08eb329cab83..256126ff9b5e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -350,7 +350,8 @@ def validate(self): class DynamicCache(Cache): """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. + A cache that grows dynamically as more tokens are generated. This is the default for generative models without sliding window attention + (see `DynamicSlidingWindowCache` in this case). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. @@ -548,7 +549,7 @@ def batch_select_indices(self, indices: torch.Tensor): class DynamicSlidingWindowCache(DynamicCache): """ A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. - This is the default for generative models with sliding window attention. + This is the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. From ebe6dc91b171793c42eef2ed4d893e340b4238e9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 14:41:33 +0200 Subject: [PATCH 08/48] fix import --- src/transformers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9bae0dff2d09..af795e60ba47 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -6157,6 +6157,7 @@ Cache, CacheConfig, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, HQQQuantizedCache, HybridCache, From af95f2ad467b64ee8105733526b749225f1c50fb Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 14:49:50 +0200 Subject: [PATCH 09/48] Update dummy_pt_objects.py --- src/transformers/utils/dummy_pt_objects.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 048de1cc8ae7..514277b10766 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -37,6 +37,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class DynamicSlidingWindowCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EncoderDecoderCache(metaclass=DummyObject): _backends = ["torch"] From 08d1a9f09b15a0da94ccd2d2223ef76ec213d034 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 16:59:30 +0200 Subject: [PATCH 10/48] Update test --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1727aed1117b..19920dab7d03 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,7 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, From b73655ae7d7a0482aec056183f750a9c2169fcae Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 17:01:44 +0200 Subject: [PATCH 11/48] style --- tests/generation/test_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 19920dab7d03..0b2f47a95265 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,13 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import ( + DynamicCache, + DynamicSlidingWindowCache, + EncoderDecoderCache, + QuantoQuantizedCache, + StaticCache, + ) from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, From ff16af0165678c3942cc9009e9cde2240f4bd6c2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 18:25:25 +0200 Subject: [PATCH 12/48] update cache conversion in test --- tests/generation/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0b2f47a95265..691ab457c488 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1816,7 +1816,10 @@ def test_new_cache_format(self, num_beams, do_sample): ) new_cache = new_results.past_key_values - legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) + if cache_cls == DynamicSlidingWindowCache: + legacy_cache_converted = cache_cls.from_legacy_cache(config.sliding_window, legacy_results.past_key_values) + else: + legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): for kv_idx in range(len(new_cache[layer_idx])): # TODO: @raushan, please look into this for new cache format From 5e3fef01ef257e12ef80c6ac5c297da3197e8fe2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Sep 2024 16:34:41 +0200 Subject: [PATCH 13/48] style --- tests/generation/test_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 691ab457c488..c4177b79c35d 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1817,7 +1817,9 @@ def test_new_cache_format(self, num_beams, do_sample): new_cache = new_results.past_key_values if cache_cls == DynamicSlidingWindowCache: - legacy_cache_converted = cache_cls.from_legacy_cache(config.sliding_window, legacy_results.past_key_values) + legacy_cache_converted = cache_cls.from_legacy_cache( + config.sliding_window, legacy_results.past_key_values + ) else: legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): From 3d1bfd0b52f052d2918fcc863f43ac43d5538c00 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Sep 2024 10:54:45 +0200 Subject: [PATCH 14/48] Allow the cache to support new states of more than 1 token, even after prefill stage --- src/transformers/cache_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 256126ff9b5e..62e0ca7c2fc5 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -587,6 +587,10 @@ def update( Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous tokens according to the sliding window if needed. + Note: we always keep `sliding_window` tokens in the cache, instead of the `sliding_window - 1` tokens that + are strictly necesary. This allows to roll back one token in the past with `cache.crop(-1)` in contrastive search. + Assisted decoding would need to roll back additional tokens, and is therefore not supported with this Cache class. + Parameters: key_states (`torch.Tensor`): The new key states to cache. @@ -615,13 +619,17 @@ def update( new_seq_len = key_states.shape[-2] current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: - # We need to slice - self.key_cache[layer_idx] = torch.cat( - [self.key_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], key_states], dim=-2 + # We may need to return longer states (e.g. to continue generation with previous cache, with added tokens), but we only keep + # the last `sliding_window` states in the cache for next forward + full_key_states = torch.cat( + [self.key_cache[layer_idx][..., -(self.sliding_window - 1) :, :], key_states], dim=-2 ) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], value_states], dim=-2 + full_value_states = torch.cat( + [self.value_cache[layer_idx][..., -(self.sliding_window - 1) :, :], value_states], dim=-2 ) + self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window :, :] + self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window :, :] + return full_key_states, full_value_states else: # Similar to DynamicCache self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) From 6a02bdca13cf19e0234cd027575058e28bac2a6f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Sep 2024 10:56:32 +0200 Subject: [PATCH 15/48] Update cache_utils.py --- src/transformers/cache_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 62e0ca7c2fc5..1df6f6e36ef0 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -587,10 +587,10 @@ def update( Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous tokens according to the sliding window if needed. - Note: we always keep `sliding_window` tokens in the cache, instead of the `sliding_window - 1` tokens that + Note: we always keep `sliding_window` tokens in the cache if it is full, instead of the `sliding_window - 1` tokens that are strictly necesary. This allows to roll back one token in the past with `cache.crop(-1)` in contrastive search. Assisted decoding would need to roll back additional tokens, and is therefore not supported with this Cache class. - + Parameters: key_states (`torch.Tensor`): The new key states to cache. From 838712ded0619b5da09f05199f7dfe7a276ef5cc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Sep 2024 13:37:39 +0200 Subject: [PATCH 16/48] maybe change test --- tests/generation/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c4177b79c35d..dec7e46f6a5e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1726,7 +1726,10 @@ def test_generate_continue_from_past_key_values(self): outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True) # Continue from the tokens generated above, preparing the inputs accordingly - inputs["past_key_values"] = outputs_cached.past_key_values + if getattr(config, "sliding_window", None) is not None: + inputs["past_key_values"] = DynamicSlidingWindowCache(config.sliding_window, outputs_cached.past_key_values) + else: + inputs["past_key_values"] = outputs_cached.past_key_values new_attention_len = outputs_cached.sequences.shape[-1] if config.is_encoder_decoder: inputs["decoder_input_ids"] = outputs_cached.sequences From 6afd20d3d7f06eb44547fa9ff600b35d2cfe9332 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 15:27:28 +0200 Subject: [PATCH 17/48] revert tests diffs --- tests/generation/test_utils.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index dec7e46f6a5e..1727aed1117b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,13 +62,7 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import ( - DynamicCache, - DynamicSlidingWindowCache, - EncoderDecoderCache, - QuantoQuantizedCache, - StaticCache, - ) + from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1726,10 +1720,7 @@ def test_generate_continue_from_past_key_values(self): outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True) # Continue from the tokens generated above, preparing the inputs accordingly - if getattr(config, "sliding_window", None) is not None: - inputs["past_key_values"] = DynamicSlidingWindowCache(config.sliding_window, outputs_cached.past_key_values) - else: - inputs["past_key_values"] = outputs_cached.past_key_values + inputs["past_key_values"] = outputs_cached.past_key_values new_attention_len = outputs_cached.sequences.shape[-1] if config.is_encoder_decoder: inputs["decoder_input_ids"] = outputs_cached.sequences @@ -1819,12 +1810,7 @@ def test_new_cache_format(self, num_beams, do_sample): ) new_cache = new_results.past_key_values - if cache_cls == DynamicSlidingWindowCache: - legacy_cache_converted = cache_cls.from_legacy_cache( - config.sliding_window, legacy_results.past_key_values - ) - else: - legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) + legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): for kv_idx in range(len(new_cache[layer_idx])): # TODO: @raushan, please look into this for new cache format From 217e803405f6f3f20c2cdd80709aeda6aedc4f53 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 16:32:53 +0200 Subject: [PATCH 18/48] define get_seen_tokens --- src/transformers/cache_utils.py | 74 +++++++++++++++++----------- src/transformers/generation/utils.py | 2 +- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 1df6f6e36ef0..08653aaf54a7 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -63,6 +63,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + def get_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + """Returns the number of already processed tokens. For all Cache classes except SlidingWindow caches, this is the same as + `get_seq_length()`. However, with sliding window we can process more tokens than the cache size. A layer index can be optionally passed. + """ + return self.get_seq_length(layer_idx) # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles @@ -350,8 +356,7 @@ def validate(self): class DynamicCache(Cache): """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models without sliding window attention - (see `DynamicSlidingWindowCache` in this case). + A cache that grows dynamically as more tokens are generated. This is the default for generative models. It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. @@ -549,7 +554,7 @@ def batch_select_indices(self, indices: torch.Tensor): class DynamicSlidingWindowCache(DynamicCache): """ A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. - This is the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). + This will be the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. @@ -572,9 +577,18 @@ class DynamicSlidingWindowCache(DynamicCache): ``` """ - def __init__(self, sliding_window: int): - super().__init__() + def __init__(self, sliding_window: int, num_hidden_layers: Optional[int] = None) -> None: + super().__init__(num_hidden_layers) self.sliding_window = sliding_window + # We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update` + self._seen_tokens = [0]*num_hidden_layers if num_hidden_layers is not None else [] + + def get_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + """This needs to be overriden because the number of processed tokens may be larger than the cache length.""" + if len(self._seen_tokens) <= layer_idx: + return 0 + else: + return self._seen_tokens[layer_idx] def update( self, @@ -604,18 +618,23 @@ def update( Return: A tuple containing the updated key and value states. """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache if len(self.key_cache) <= layer_idx: + # Update the number of seen tokens + self._seen_tokens.append(key_states.shape[-2]) # Add only up to sliding window size if larger self.key_cache.append(key_states[..., -self.sliding_window :, :]) self.value_cache.append(value_states[..., -self.sliding_window :, :]) # We should return full states during prefill even though we only save up to sliding window return key_states, value_states + # In case we initialized empty lists + elif self.key_cache[layer_idx] == []: + self._seen_tokens[layer_idx] += key_states.shape[-2] + self.key_cache[layer_idx] = key_states[..., -self.sliding_window :, :] + self.value_cache[layer_idx] = value_states[..., -self.sliding_window :, :] + # We should return full states during prefill even though we only save up to sliding window + return key_states, value_states else: + self._seen_tokens[layer_idx] += key_states.shape[-2] new_seq_len = key_states.shape[-2] current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: @@ -637,25 +656,12 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] - @classmethod - def from_legacy_cache( - cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. Used for - backward compatibility.""" - cache = cls(sliding_window) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]: + def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]: """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by `_split_model_inputs()` in `generation.utils`""" out = [] for i in range(0, full_batch_size, split_size): - current_split = DynamicSlidingWindowCache(self.sliding_window) + current_split = DynamicSlidingWindowCache(self.sliding_window, num_hidden_layers) current_split._seen_tokens = self._seen_tokens current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] @@ -663,16 +669,24 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSli return out @classmethod - def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": + def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"], num_hidden_layers: int) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" - cache = cls(splits[0].sliding_window) + cache = cls(splits[0].sliding_window, num_hidden_layers) for idx in range(len(splits[0])): - layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) - layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) - cache.update(layer_keys, layer_values, idx) + key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] + value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + + # We need this because _seen_tokens may be bigger than what will be automatically set with `update` (if cache > sliding_window) + cache._seen_tokens = splits[0]._seen_tokens return cache + from_legacy_cache = None + to_legacy_cache = None class OffloadedCache(DynamicCache): """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 919a0bcc6b79..0f61cc84db88 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2354,7 +2354,7 @@ def typeerror(): should_convert_cache = generation_config.return_legacy_cache is_user_defined_cache = user_defined_cache is not None is_default_cache_type = ( - type(result.past_key_values) in (DynamicCache, DynamicSlidingWindowCache) # noqa E721 + type(result.past_key_values) == DynamicCache # noqa E721 or ( isinstance(result.past_key_values, EncoderDecoderCache) and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 From 582301cb03aef93b2a337d3cc03a6483111ea2ae Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 17:09:08 +0200 Subject: [PATCH 19/48] Modify all current .get_seq_length names --- docs/source/en/internal/generation_utils.md | 16 ++++++++-------- examples/modular-transformers/modeling_dummy.py | 4 ++-- .../modeling_my_new_model2.py | 4 ++-- examples/modular-transformers/modeling_super.py | 2 +- src/transformers/cache_utils.py | 4 ++-- src/transformers/generation/utils.py | 8 ++++---- src/transformers/models/bloom/modeling_bloom.py | 4 ++-- .../models/chameleon/modeling_chameleon.py | 4 ++-- .../models/codegen/modeling_codegen.py | 4 ++-- .../models/cohere/modeling_cohere.py | 4 ++-- src/transformers/models/dbrx/modeling_dbrx.py | 4 ++-- .../models/falcon/modeling_falcon.py | 4 ++-- src/transformers/models/gemma/modeling_gemma.py | 4 ++-- src/transformers/models/gemma/modular_gemma.py | 2 +- .../models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/gemma2/modular_gemma2.py | 2 +- src/transformers/models/git/modeling_git.py | 4 ++-- .../models/gpt_neo/modeling_gpt_neo.py | 4 ++-- .../models/gpt_neox/modeling_gpt_neox.py | 4 ++-- .../modeling_gpt_neox_japanese.py | 4 ++-- src/transformers/models/gptj/modeling_gptj.py | 4 ++-- .../models/granite/modeling_granite.py | 4 ++-- .../models/granitemoe/modeling_granitemoe.py | 4 ++-- .../models/idefics/modeling_idefics.py | 4 ++-- .../models/idefics2/modeling_idefics2.py | 7 ++++++- .../models/idefics3/modeling_idefics3.py | 7 ++++++- .../models/jetmoe/modeling_jetmoe.py | 4 ++-- src/transformers/models/llama/modeling_llama.py | 4 ++-- src/transformers/models/mimi/modeling_mimi.py | 4 ++-- .../models/mistral/modeling_mistral.py | 4 ++-- .../models/mixtral/modeling_mixtral.py | 4 ++-- .../models/mllama/modeling_mllama.py | 4 ++-- .../models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 4 ++-- src/transformers/models/olmoe/modeling_olmoe.py | 4 ++-- .../models/paligemma/modeling_paligemma.py | 2 +- .../models/persimmon/modeling_persimmon.py | 4 ++-- src/transformers/models/phi/modeling_phi.py | 4 ++-- src/transformers/models/phi3/modeling_phi3.py | 4 ++-- src/transformers/models/qwen2/modeling_qwen2.py | 4 ++-- .../models/qwen2_audio/modeling_qwen2_audio.py | 2 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 4 ++-- .../models/qwen2_vl/modeling_qwen2_vl.py | 4 ++-- .../models/stablelm/modeling_stablelm.py | 4 ++-- .../models/starcoder2/modeling_starcoder2.py | 4 ++-- .../models/whisper/modeling_whisper.py | 8 ++++---- 46 files changed, 103 insertions(+), 93 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 3d4dfef70267..3004e89300fc 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -362,21 +362,21 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] DynamicCache - update - - get_seq_length + - get_past_seen_tokens - reorder_cache - to_legacy_cache - from_legacy_cache [[autodoc]] DynamicSlidingWindowCache - update - - get_seq_length + - get_past_seen_tokens - reorder_cache - to_legacy_cache - from_legacy_cache [[autodoc]] QuantizedCache - update - - get_seq_length + - get_past_seen_tokens [[autodoc]] QuantoQuantizedCache @@ -384,7 +384,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] SinkCache - update - - get_seq_length + - get_past_seen_tokens - reorder_cache [[autodoc]] OffloadedCache @@ -394,17 +394,17 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] StaticCache - update - - get_seq_length + - get_past_seen_tokens - reset [[autodoc]] OffloadedStaticCache - update - - get_seq_length + - get_past_seen_tokens - reset [[autodoc]] HybridCache - update - - get_seq_length + - get_past_seen_tokens - reset [[autodoc]] SlidingWindowCache @@ -412,7 +412,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens - reset [[autodoc]] EncoderDecoderCache - - get_seq_length + - get_past_seen_tokens - to_legacy_cache - from_legacy_cache - reset diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index b5b1fc6aec85..420fe6d6d2c1 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -906,7 +906,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -997,7 +997,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 49cdd2741620..8b20b43c20b3 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -784,7 +784,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -879,7 +879,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index d91bdb1820c2..71b14bb8051a 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -902,7 +902,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 08653aaf54a7..a7b1c1f4a29a 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -64,7 +64,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - def get_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """Returns the number of already processed tokens. For all Cache classes except SlidingWindow caches, this is the same as `get_seq_length()`. However, with sliding window we can process more tokens than the cache size. A layer index can be optionally passed. """ @@ -583,7 +583,7 @@ def __init__(self, sliding_window: int, num_hidden_layers: Optional[int] = None) # We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update` self._seen_tokens = [0]*num_hidden_layers if num_hidden_layers is not None else [] - def get_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """This needs to be overriden because the number of processed tokens may be larger than the cache length.""" if len(self._seen_tokens) <= layer_idx: return 0 diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0f61cc84db88..9a1e7db6373b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1533,8 +1533,8 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): past_length = 0 if not isinstance(cache, Cache): past_length = cache[0][0].shape[2] - elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: - past_length = cache.get_seq_length() + elif hasattr(cache, "get_past_seen_tokens") and cache.get_past_seen_tokens() is not None: + past_length = cache.get_past_seen_tokens() # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, # end-to-end compilation will yield bad results because `cache_position` will be incorrect. @@ -2765,7 +2765,7 @@ def _contrastive_search( # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step if model_kwargs.get("past_key_values") is None or ( isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache)) - and model_kwargs["past_key_values"].get_seq_length() == 0 + and model_kwargs["past_key_values"].get_past_seen_tokens() == 0 ): # prepare inputs model_kwargs["use_cache"] = True @@ -4167,7 +4167,7 @@ def _assisted_decoding( isinstance(past_key_values, EncoderDecoderCache) and isinstance(past_key_values.self_attention_cache, DynamicCache) ): - if past_key_values.get_seq_length() == 0: + if past_key_values.get_past_seen_tokens() == 0: start_from_empty_dynamic_cache = True this_peer_finished = False diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 75f8e5830f44..b0bf53901f47 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -646,7 +646,7 @@ def forward( ) batch_size, seq_length, _ = inputs_embeds.shape - past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 seq_length_with_past = seq_length + past_length if cache_position is None: cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) @@ -747,7 +747,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index fd76c0b11522..f9cf103e6f50 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1295,7 +1295,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1386,7 +1386,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 478745b2c59e..2b51bad8ab7b 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -487,7 +487,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -590,7 +590,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a5d3721f5bdb..69505b581470 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -868,7 +868,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -959,7 +959,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index ef81e43d0294..d07bc864839d 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1019,7 +1019,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1120,7 +1120,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index f48accab44bf..94c3930f2de7 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -992,7 +992,7 @@ def forward( # Compute alibi tensor: check build_alibi_tensor documentation alibi = None - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 batch_size, seq_length, _ = inputs_embeds.shape if self.use_alibi: mask = ( @@ -1114,7 +1114,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ff206a470bc3..c7cb92fde0ff 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -780,7 +780,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -875,7 +875,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 7130a30dc9be..ce39025cfbdf 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -860,7 +860,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 0b99aa59c65b..be4e5b8e0c97 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -790,7 +790,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index c0f76dbe5bfc..7ea5b3bdd23e 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -627,7 +627,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index c7f9ceafe194..168935b467ab 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1283,7 +1283,7 @@ def forward( past_key_values_length = ( past_key_values[0][0].shape[2] if not isinstance(past_key_values, Cache) - else past_key_values.get_seq_length() + else past_key_values.get_past_seen_tokens() ) # Prepare head mask if needed @@ -1611,7 +1611,7 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values.get_seq_length() + past_length = past_key_values.get_past_seen_tokens() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7bba7608e6c1..0099341d5262 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -702,7 +702,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -804,7 +804,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index f4636db0a97b..f0adac7642a7 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -904,7 +904,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -1001,7 +1001,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index b618f531e52f..7215e8e05076 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -624,7 +624,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -705,7 +705,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 5c80485823c1..8145d9b250ab 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -774,7 +774,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) @@ -899,7 +899,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 0eb27d452f08..420d971b2ac4 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -794,7 +794,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -892,7 +892,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index ebdea826fa04..7ac0829509a4 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1020,7 +1020,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1125,7 +1125,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 02de8d61ae20..e757a043e68d 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1153,7 +1153,7 @@ def forward( ) batch_size, seq_length, _ = inputs_embeds.shape - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 seq_length_with_past = seq_length + past_key_values_length if cache_position is None: @@ -1384,7 +1384,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index b53d0722587d..9ca404d3cb45 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1359,7 +1359,7 @@ def forward( "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" ) - past_seen_tokens = past_key_values.get_seq_length() + past_seen_tokens = past_key_values.get_past_seen_tokens() if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") @@ -1664,8 +1664,13 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore +<<<<<<< HEAD past_length = past_key_values.get_seq_length() max_cache_length = past_key_values.get_max_cache_shape() +======= + past_length = past_key_values.get_past_seen_tokens() + max_cache_length = past_key_values.get_max_length() +>>>>>>> 0c098e35c (Modify all current .get_seq_length names) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 757391175ea6..f30c83e7ade3 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -953,7 +953,7 @@ def forward( past_seen_tokens = 0 if use_cache: - past_seen_tokens = past_key_values.get_seq_length() + past_seen_tokens = past_key_values.get_past_seen_tokens() if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") @@ -1252,8 +1252,13 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore +<<<<<<< HEAD past_length = past_key_values.get_seq_length() max_cache_length = past_key_values.get_max_cache_shape() +======= + past_length = past_key_values.get_past_seen_tokens() + max_cache_length = past_key_values.get_max_length() +>>>>>>> 0c098e35c (Modify all current .get_seq_length names) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index bbc70b26d1f8..ca2651871cbe 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -998,7 +998,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1101,7 +1101,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index dde017bbb927..bbcbe5ef96bb 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -904,7 +904,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -995,7 +995,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 514f9de706ec..49d83cf800e0 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -958,7 +958,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device ) @@ -1044,7 +1044,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b0ffe3e56e59..21053e361310 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -775,7 +775,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -872,7 +872,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9c7fadbb8f88..2619d5792f0c 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -986,7 +986,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1085,7 +1085,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 0bc77eaeec33..89b9813aa4a8 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1620,7 +1620,7 @@ def forward( hidden_states = inputs_embeds if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1726,7 +1726,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 7d0390adc3c0..54e5ff3698ab 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -872,7 +872,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 7ab54146c974..3b1de2709f9a 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -825,7 +825,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -914,7 +914,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 8c29f89ff3e7..47167a62677b 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -973,7 +973,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1073,7 +1073,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index d75a05bda0e1..e19deac906ea 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -472,7 +472,7 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 7ae3469a4c93..88c2941ed503 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -648,7 +648,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -741,7 +741,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 3f770c9ec00b..e68fb3d07de3 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -938,7 +938,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1032,7 +1032,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 0380c6cd49d6..225e590c58a1 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -965,7 +965,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1052,7 +1052,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 50f273ba766c..6f4b4fd61524 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -878,7 +878,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -971,7 +971,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 6422baac5feb..928433c58d53 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -1266,7 +1266,7 @@ def prepare_inputs_for_generation( if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens + past_length = past_key_values.get_past_seen_tokens() else: cache_length = past_length = past_key_values[0][0].shape[2] diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 2ab13b7227ad..4123f833090c 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1048,7 +1048,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1152,7 +1152,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 283e38d3a7d5..52379787ab30 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1139,7 +1139,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1234,7 +1234,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index fe3ad6498172..779034f51b82 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -923,7 +923,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1016,7 +1016,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index e0fdbef1a3ba..9688acc27497 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -851,7 +851,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -945,7 +945,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 079965fc174a..e860d099f9c2 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1248,7 +1248,7 @@ def forward( if cache_position is not None: past_key_values_length = cache_position[0] elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() + past_key_values_length = past_key_values.get_past_seen_tokens() if cache_position is None: cache_position = torch.arange( @@ -1383,7 +1383,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward @@ -1824,7 +1824,7 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() else: past_length = past_key_values[0][0].shape[2] @@ -2105,7 +2105,7 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, (Cache, EncoderDecoderCache)): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() else: past_length = past_key_values[0][0].shape[2] From b239a5782fc824bebb6ba8c333d8ed1951ee6ba9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 17:13:38 +0200 Subject: [PATCH 20/48] style --- src/transformers/cache_utils.py | 9 ++++++--- src/transformers/generation/utils.py | 3 +-- src/transformers/models/whisper/modeling_whisper.py | 8 ++++++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index a7b1c1f4a29a..c6ba9d1d55e0 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -63,7 +63,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - + def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """Returns the number of already processed tokens. For all Cache classes except SlidingWindow caches, this is the same as `get_seq_length()`. However, with sliding window we can process more tokens than the cache size. A layer index can be optionally passed. @@ -581,7 +581,7 @@ def __init__(self, sliding_window: int, num_hidden_layers: Optional[int] = None) super().__init__(num_hidden_layers) self.sliding_window = sliding_window # We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update` - self._seen_tokens = [0]*num_hidden_layers if num_hidden_layers is not None else [] + self._seen_tokens = [0] * num_hidden_layers if num_hidden_layers is not None else [] def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """This needs to be overriden because the number of processed tokens may be larger than the cache length.""" @@ -669,7 +669,9 @@ def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: return out @classmethod - def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"], num_hidden_layers: int) -> "DynamicSlidingWindowCache": + def from_batch_splits( + cls, splits: List["DynamicSlidingWindowCache"], num_hidden_layers: int + ) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" cache = cls(splits[0].sliding_window, num_hidden_layers) @@ -688,6 +690,7 @@ def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"], num_hidden from_legacy_cache = None to_legacy_cache = None + class OffloadedCache(DynamicCache): """ A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9a1e7db6373b..3bb8615cef4c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,7 +28,6 @@ from ..cache_utils import ( Cache, DynamicCache, - DynamicSlidingWindowCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig, @@ -4605,7 +4604,7 @@ def _concat(data): return torch.cat(data, dim=0) # New cache format elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)): - return data[0].__class__.from_batch_splits(data) + return data[0].__class__.from_batch_splits(data, num_hidden_layers=num_hidden_layers) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e860d099f9c2..408ff54f5c55 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1824,7 +1824,9 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + past_length = ( + cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + ) else: past_length = past_key_values[0][0].shape[2] @@ -2105,7 +2107,9 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, (Cache, EncoderDecoderCache)): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + past_length = ( + cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + ) else: past_length = past_key_values[0][0].shape[2] From ee30eb91db21360dea7c66b693f23cf73c420941 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 17:44:25 +0200 Subject: [PATCH 21/48] trigger CIs From f3af18023df43ff4d407aec58e0c7b66dc66ff9b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 18:41:14 +0200 Subject: [PATCH 22/48] Add tests --- tests/generation/test_utils.py | 78 +++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1727aed1117b..ce3a46dae3a1 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,7 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -2024,6 +2024,82 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + @parameterized.expand([{"do_sample": False}, {'do_sample': False, 'top_k': 2, 'penalty_alpha': 0.5}]) + @pytest.mark.generate + def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): + """ + Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. The first expand + is for greedy, and the other is for contrasting search, as contrastive search needs to correctly roll back 1 token + of the cache even with DynamicSlidingWindowCache. + """ + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + if getattr(config, "sliding_window", None) is None: + self.skipTest(reason="This model does not support sliding window.") + + # Make sure we will go beyond the sliding window + config.sliding_window = 3 + model = model_class(config).to(torch_device).eval() + all_generation_kwargs = { + **generation_kwargs, + "max_new_tokens": 20, + "min_new_tokens": 20, + "use_cache": True, + } + + dynamic_cache = DynamicCache() + dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) + + results_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) + results_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + + self.assertListEqual(results_dynamic, results_sliding_dynamic) + + + @parameterized.expand([False, True]) + @pytest.mark.generate + def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_tokens_than_window: bool): + """ + Tests if we can correctly continue generation with DynamicSlidingWindowCache, even after the cache is "full" (bigger than sliding + window), and we provide more than 1 new token to add to the cache. + """ + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + if getattr(config, "sliding_window", None) is None: + self.skipTest(reason="This model does not support sliding window.") + + # Make sure we will go beyond the sliding window + config.sliding_window = 3 + model = model_class(config).to(torch_device).eval() + all_generation_kwargs = { + "do_sample": False, + "max_new_tokens": 5, + "min_new_tokens": 5, + "use_cache": True, + "return_dict_in_generate": True, + } + + dynamic_cache = DynamicCache() + dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) + + out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) + out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + + results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values + results_sliding_dynamic, dynamic_sliding_cache = out_sliding_dynamic.sequences, out_sliding_dynamic.past_key_values + + self.assertListEqual(results_dynamic, results_sliding_dynamic) + + bs = results_dynamic.shape[0] + num_added_tokens = 2 if not add_more_tokens_than_window else 4 + added_tokens = ids_tensor((bs, num_added_tokens), vocab_size=config.vocab_size) + input_ids = torch.cat([results_dynamic, added_tokens], dim=-1) + + out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) + out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + + self.assertListEqual(out_dynamic.sequences, out_sliding_dynamic.sequences) + def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1): batch_size = main_input.shape[0] seq_length = main_input.shape[-1] From 25cd9c071afd6956c5c75feba817dc5231830e5d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 18:43:47 +0200 Subject: [PATCH 23/48] Update test_utils.py --- tests/generation/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ce3a46dae3a1..e728483739b0 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2053,7 +2053,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic results_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) results_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertListEqual(results_dynamic, results_sliding_dynamic) + self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) @parameterized.expand([False, True]) @@ -2088,7 +2088,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values results_sliding_dynamic, dynamic_sliding_cache = out_sliding_dynamic.sequences, out_sliding_dynamic.past_key_values - self.assertListEqual(results_dynamic, results_sliding_dynamic) + self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) bs = results_dynamic.shape[0] num_added_tokens = 2 if not add_more_tokens_than_window else 4 @@ -2098,7 +2098,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertListEqual(out_dynamic.sequences, out_sliding_dynamic.sequences) + self.assertTrue((out_dynamic.sequences == out_sliding_dynamic.sequences).all()) def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1): batch_size = main_input.shape[0] From b2f7dee6996f652044f3870837d5f6587951d75f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 18:45:23 +0200 Subject: [PATCH 24/48] Update test_utils.py --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e728483739b0..6d3d3b63ffc2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2056,7 +2056,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) - @parameterized.expand([False, True]) + @parameterized.expand([(False,), (True,)]) @pytest.mark.generate def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_tokens_than_window: bool): """ From b5492900832785c96da390c2699eaab7119b776f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 18:49:30 +0200 Subject: [PATCH 25/48] Update test_utils.py --- tests/generation/test_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6d3d3b63ffc2..e10a741a0363 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2024,7 +2024,8 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - @parameterized.expand([{"do_sample": False}, {'do_sample': False, 'top_k': 2, 'penalty_alpha': 0.5}]) + + @parameterized.expand([({"do_sample": False},), ({'do_sample': False, 'top_k': 2, 'penalty_alpha': 0.5},)]) @pytest.mark.generate def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): """ @@ -2053,7 +2054,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic results_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) results_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) + self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) @parameterized.expand([(False,), (True,)]) @@ -2088,7 +2089,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values results_sliding_dynamic, dynamic_sliding_cache = out_sliding_dynamic.sequences, out_sliding_dynamic.past_key_values - self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) + self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) bs = results_dynamic.shape[0] num_added_tokens = 2 if not add_more_tokens_than_window else 4 @@ -2098,7 +2099,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertTrue((out_dynamic.sequences == out_sliding_dynamic.sequences).all()) + self.assertListEqual(out_dynamic.sequences.tolist(), out_sliding_dynamic.sequences.tolist()) def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1): batch_size = main_input.shape[0] From f052bede920398f947caec34f743c810c267f2e1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 3 Oct 2024 15:55:31 +0200 Subject: [PATCH 26/48] Update causal mask generation in case of DynamicSlidingCache (only Mistral) --- .../models/mistral/modeling_mistral.py | 11 ++++- tests/generation/test_utils.py | 49 ++++++++++++++----- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 21053e361310..5d596eaaf725 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -918,6 +918,15 @@ def _update_causal_mask( past_key_values=past_key_values, ) + if isinstance(past_key_values, DynamicSlidingWindowCache): + current_cache_length = past_key_values.get_seq_length() + if sequence_length + current_cache_length > self.config.sliding_window: + target_length = sequence_length + self.config.sliding_window - 1 + else: + target_length = current_cache_length + sequence_length + # Slice the causal mask to get only relevant part of the same shape as the keys/values + causal_mask = causal_mask[:, :, :, -target_length:] + if ( self.config._attn_implementation == "sdpa" and attention_mask is not None diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e10a741a0363..53ac08e4a1be 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,13 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import ( + DynamicCache, + DynamicSlidingWindowCache, + EncoderDecoderCache, + QuantoQuantizedCache, + StaticCache, + ) from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -2024,8 +2030,7 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - - @parameterized.expand([({"do_sample": False},), ({'do_sample': False, 'top_k': 2, 'penalty_alpha': 0.5},)]) + @parameterized.expand([({"do_sample": False},), ({"do_sample": False, "top_k": 2, "penalty_alpha": 0.5},)]) @pytest.mark.generate def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): """ @@ -2050,12 +2055,18 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic dynamic_cache = DynamicCache() dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) - - results_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) - results_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) + results_dynamic = model.generate( + input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache + ) + results_sliding_dynamic = model.generate( + input_ids, + attention_mask=attention_mask, + **all_generation_kwargs, + past_key_values=dynamic_sliding_cache, + ) + self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) @parameterized.expand([(False,), (True,)]) @pytest.mark.generate @@ -2082,12 +2093,22 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke dynamic_cache = DynamicCache() dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) - - out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) - out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + + out_dynamic = model.generate( + input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache + ) + out_sliding_dynamic = model.generate( + input_ids, + attention_mask=attention_mask, + **all_generation_kwargs, + past_key_values=dynamic_sliding_cache, + ) results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values - results_sliding_dynamic, dynamic_sliding_cache = out_sliding_dynamic.sequences, out_sliding_dynamic.past_key_values + results_sliding_dynamic, dynamic_sliding_cache = ( + out_sliding_dynamic.sequences, + out_sliding_dynamic.past_key_values, + ) self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) @@ -2096,8 +2117,10 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke added_tokens = ids_tensor((bs, num_added_tokens), vocab_size=config.vocab_size) input_ids = torch.cat([results_dynamic, added_tokens], dim=-1) - out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) - out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + out_dynamic = model.generate(input_ids, **all_generation_kwargs, past_key_values=dynamic_cache) + out_sliding_dynamic = model.generate( + input_ids, **all_generation_kwargs, past_key_values=dynamic_sliding_cache + ) self.assertListEqual(out_dynamic.sequences.tolist(), out_sliding_dynamic.sequences.tolist()) From e091f4de58306d9d88d20b1a1e13976f47ef1ca0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 3 Oct 2024 16:46:00 +0200 Subject: [PATCH 27/48] Improve tests --- src/transformers/cache_utils.py | 1 + tests/generation/test_utils.py | 33 +++++++++++++++++---------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c6ba9d1d55e0..2f0becf33cb9 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -551,6 +551,7 @@ def batch_select_indices(self, indices: torch.Tensor): self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] +# TODO: (cyril) Make this the default for models with sliding window once `generate` no longer returns Cache as tuples class DynamicSlidingWindowCache(DynamicCache): """ A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 53ac08e4a1be..5810b0d29c44 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2030,7 +2030,9 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - @parameterized.expand([({"do_sample": False},), ({"do_sample": False, "top_k": 2, "penalty_alpha": 0.5},)]) + @parameterized.expand( + [({"do_sample": False},), ({"do_sample": False, "top_k": 2, "penalty_alpha": 0.5, "low_memory": True},)] + ) @pytest.mark.generate def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): """ @@ -2068,20 +2070,25 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) - @parameterized.expand([(False,), (True,)]) + @parameterized.expand([(3, 1), (3, 4), (14, 5)]) @pytest.mark.generate - def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_tokens_than_window: bool): + def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_window: int, additional_tokens: int): """ - Tests if we can correctly continue generation with DynamicSlidingWindowCache, even after the cache is "full" (bigger than sliding - window), and we provide more than 1 new token to add to the cache. + Tests if we can correctly continue generation with DynamicSlidingWindowCache. + - First case tests that we can continue if the cache is already full, and we add less tokens than the sliding window + - Second case tests that we can continue if the cache is already full, and we add more tokens that the sliding window + - Third case tests that we can continue if the cache is not full, and we add tokens so that the new input is bigger than the sliding window """ for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, _, _, _ = self._get_input_ids_and_config() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") + # We need to be sure to always have shape (2, 7) for the different test assumptions to hold + input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) + # Make sure we will go beyond the sliding window - config.sliding_window = 3 + config.sliding_window = sliding_window model = model_class(config).to(torch_device).eval() all_generation_kwargs = { "do_sample": False, @@ -2094,14 +2101,9 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke dynamic_cache = DynamicCache() dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) - out_dynamic = model.generate( - input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache - ) + out_dynamic = model.generate(input_ids, **all_generation_kwargs, past_key_values=dynamic_cache) out_sliding_dynamic = model.generate( - input_ids, - attention_mask=attention_mask, - **all_generation_kwargs, - past_key_values=dynamic_sliding_cache, + input_ids, **all_generation_kwargs, past_key_values=dynamic_sliding_cache ) results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values @@ -2113,8 +2115,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) bs = results_dynamic.shape[0] - num_added_tokens = 2 if not add_more_tokens_than_window else 4 - added_tokens = ids_tensor((bs, num_added_tokens), vocab_size=config.vocab_size) + added_tokens = ids_tensor((bs, additional_tokens), vocab_size=config.vocab_size) input_ids = torch.cat([results_dynamic, added_tokens], dim=-1) out_dynamic = model.generate(input_ids, **all_generation_kwargs, past_key_values=dynamic_cache) From 9a30ad414ad3c13a0a9ac06b56fd1d671f4aaaa7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:31:44 +0200 Subject: [PATCH 28/48] improve cache --- src/transformers/cache_utils.py | 74 +++++++++++++-------------------- tests/generation/test_utils.py | 11 ++--- 2 files changed, 33 insertions(+), 52 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2f0becf33cb9..bc6f1e5ad8b6 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -558,7 +558,11 @@ class DynamicSlidingWindowCache(DynamicCache): This will be the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. + `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window-1, head_dim]` if seq_len >= sliding_window-1. + + Note: Since we only keep maximum `sliding_window-1` tokens in the cache, once this value is reached the cache can no + longer be roll-backed to previous states without losing information. For this reason, it should not be used with assisted decoding + (or contrastive search when using `low_memory=True`). Example: @@ -578,11 +582,11 @@ class DynamicSlidingWindowCache(DynamicCache): ``` """ - def __init__(self, sliding_window: int, num_hidden_layers: Optional[int] = None) -> None: - super().__init__(num_hidden_layers) + def __init__(self, sliding_window: int) -> None: + super().__init__() self.sliding_window = sliding_window # We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update` - self._seen_tokens = [0] * num_hidden_layers if num_hidden_layers is not None else [] + self._seen_tokens = [] def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """This needs to be overriden because the number of processed tokens may be larger than the cache length.""" @@ -602,10 +606,6 @@ def update( Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous tokens according to the sliding window if needed. - Note: we always keep `sliding_window` tokens in the cache if it is full, instead of the `sliding_window - 1` tokens that - are strictly necesary. This allows to roll back one token in the past with `cache.crop(-1)` in contrastive search. - Assisted decoding would need to roll back additional tokens, and is therefore not supported with this Cache class. - Parameters: key_states (`torch.Tensor`): The new key states to cache. @@ -623,46 +623,26 @@ def update( # Update the number of seen tokens self._seen_tokens.append(key_states.shape[-2]) # Add only up to sliding window size if larger - self.key_cache.append(key_states[..., -self.sliding_window :, :]) - self.value_cache.append(value_states[..., -self.sliding_window :, :]) - # We should return full states during prefill even though we only save up to sliding window - return key_states, value_states - # In case we initialized empty lists - elif self.key_cache[layer_idx] == []: - self._seen_tokens[layer_idx] += key_states.shape[-2] - self.key_cache[layer_idx] = key_states[..., -self.sliding_window :, :] - self.value_cache[layer_idx] = value_states[..., -self.sliding_window :, :] - # We should return full states during prefill even though we only save up to sliding window + self.key_cache.append(key_states[..., -self.sliding_window+1 :, :]) + self.value_cache.append(value_states[..., -self.sliding_window+1 :, :]) + # We should return full states during prefill even though we only save up to sliding window-1 return key_states, value_states else: self._seen_tokens[layer_idx] += key_states.shape[-2] - new_seq_len = key_states.shape[-2] - current_seq_len = self.get_seq_length(layer_idx) - if new_seq_len + current_seq_len > self.sliding_window: - # We may need to return longer states (e.g. to continue generation with previous cache, with added tokens), but we only keep - # the last `sliding_window` states in the cache for next forward - full_key_states = torch.cat( - [self.key_cache[layer_idx][..., -(self.sliding_window - 1) :, :], key_states], dim=-2 - ) - full_value_states = torch.cat( - [self.value_cache[layer_idx][..., -(self.sliding_window - 1) :, :], value_states], dim=-2 - ) - self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window :, :] - self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window :, :] - return full_key_states, full_value_states - else: - # Similar to DynamicCache - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]: + # We may need to return longer states (e.g. to continue generation with previous cache, with added tokens), but we only keep + # the last `sliding_window-1` states in the cache for next forward + full_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + full_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window+1 :, :] + self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window+1 :, :] + return full_key_states, full_value_states + + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by `_split_model_inputs()` in `generation.utils`""" out = [] for i in range(0, full_batch_size, split_size): - current_split = DynamicSlidingWindowCache(self.sliding_window, num_hidden_layers) + current_split = DynamicSlidingWindowCache(self.sliding_window) current_split._seen_tokens = self._seen_tokens current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] @@ -671,11 +651,10 @@ def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: @classmethod def from_batch_splits( - cls, splits: List["DynamicSlidingWindowCache"], num_hidden_layers: int - ) -> "DynamicSlidingWindowCache": + cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" - cache = cls(splits[0].sliding_window, num_hidden_layers) + cache = cls(splits[0].sliding_window) for idx in range(len(splits[0])): key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] @@ -687,6 +666,13 @@ def from_batch_splits( # We need this because _seen_tokens may be bigger than what will be automatically set with `update` (if cache > sliding_window) cache._seen_tokens = splits[0]._seen_tokens return cache + + def crop(self, max_length: int): + + if self.get_past_seen_tokens() >= self.sliding_window - 1: + raise RuntimeError(f"The current DynamicSlidingWindowCache is full. It cannot be cropped as this would mean losing past states.") + else: + super().crop(max_length) from_legacy_cache = None to_legacy_cache = None diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5810b0d29c44..41e6f914a95e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2030,15 +2030,10 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - @parameterized.expand( - [({"do_sample": False},), ({"do_sample": False, "top_k": 2, "penalty_alpha": 0.5, "low_memory": True},)] - ) @pytest.mark.generate - def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): + def test_generate_with_dynamic_sliding_window_cache(self): """ - Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. The first expand - is for greedy, and the other is for contrasting search, as contrastive search needs to correctly roll back 1 token - of the cache even with DynamicSlidingWindowCache. + Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. """ for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() @@ -2049,7 +2044,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic config.sliding_window = 3 model = model_class(config).to(torch_device).eval() all_generation_kwargs = { - **generation_kwargs, + "do_sample": False, "max_new_tokens": 20, "min_new_tokens": 20, "use_cache": True, From 8202a19feb68b68809e4446d192af73cee682bdb Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:43:24 +0200 Subject: [PATCH 29/48] add exceptions --- src/transformers/generation/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3bb8615cef4c..75f56bebdaca 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,6 +28,7 @@ from ..cache_utils import ( Cache, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig, @@ -2130,6 +2131,8 @@ def generate( raise ValueError( f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" ) + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache): + raise ValueError("DynamicSlidingWindowCache cannot be used in assisted generation.") # 11. Get the candidate generator, given the parameterization candidate_generator = self._get_candidate_generator( @@ -2179,6 +2182,8 @@ def generate( raise ValueError( f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" ) + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr(model_kwargs, "low_memory", False): + raise ValueError("DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`.") result = self._contrastive_search( input_ids, From 55a39a6546b75a63b353382ba61854db065b41f6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:47:47 +0200 Subject: [PATCH 30/48] Update utils.py --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 75f56bebdaca..c4ba43fa8929 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2182,7 +2182,7 @@ def generate( raise ValueError( f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" ) - if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr(model_kwargs, "low_memory", False): + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr(generation_config, "low_memory", False): raise ValueError("DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`.") result = self._contrastive_search( From 9caf947f7e8531e22a3839a7975359cbecc8ee22 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:54:24 +0200 Subject: [PATCH 31/48] Update test_utils.py --- tests/generation/test_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 41e6f914a95e..4070cce7d65d 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2030,16 +2030,26 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + @parameterized.expand([(False,), (True,)]) @pytest.mark.generate - def test_generate_with_dynamic_sliding_window_cache(self): + def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): """ Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. """ for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, _, _, inputs_dict = self._get_input_ids_and_config() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") + input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) + if left_padding: + attention_mask = torch.tensor([ + [0,0,0,0,1,1,1], + [1,1,1,1,1,1,1], + ], device=input_ids.device, dtype=int) + else: + attention_mask = torch.ones_like(input_ids) + # Make sure we will go beyond the sliding window config.sliding_window = 3 model = model_class(config).to(torch_device).eval() From 1404cece8b15d232e7836b27c99e4e4ac597f622 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:59:13 +0200 Subject: [PATCH 32/48] Update test_utils.py --- tests/generation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 4070cce7d65d..d1c7f66b5ebd 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2037,7 +2037,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. """ for model_class in self.all_generative_model_classes: - config, _, _, inputs_dict = self._get_input_ids_and_config() + config, _ = self.prepare_config_and_inputs_for_generate() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") @@ -2085,7 +2085,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_windo - Third case tests that we can continue if the cache is not full, and we add tokens so that the new input is bigger than the sliding window """ for model_class in self.all_generative_model_classes: - config, _, _, _ = self._get_input_ids_and_config() + config, _ = self.prepare_config_and_inputs_for_generate() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") From 4f3ba863d8904852763b58f024e4d3b24dd7e9dc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 19:00:27 +0200 Subject: [PATCH 33/48] Update test_utils.py --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d1c7f66b5ebd..6bae77d175be 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2044,7 +2044,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: attention_mask = torch.tensor([ - [0,0,0,0,1,1,1], + [0,0,0,0,0,1,1], [1,1,1,1,1,1,1], ], device=input_ids.device, dtype=int) else: From 44331f107f8c3587c069b868df43a511ffb6b754 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 19:03:04 +0200 Subject: [PATCH 34/48] Update test_utils.py --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6bae77d175be..756429070a82 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2044,7 +2044,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: attention_mask = torch.tensor([ - [0,0,0,0,0,1,1], + [0,0,0,1,1,1,1], [1,1,1,1,1,1,1], ], device=input_ids.device, dtype=int) else: From b5ebae2aadd06ae85ef733061bae8470538476dc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 9 Oct 2024 11:31:29 +0200 Subject: [PATCH 35/48] Update test_utils.py --- tests/generation/test_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 756429070a82..fa196b128242 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2040,6 +2040,8 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): config, _ = self.prepare_config_and_inputs_for_generate() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") + if "qwen2" in str(model_class).lower(): + self.skipTest(reason="Sliding window attention is not implemented for sdpa in Qwen2 models.") input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: @@ -2088,6 +2090,8 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_windo config, _ = self.prepare_config_and_inputs_for_generate() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") + if "qwen2" in str(model_class).lower(): + self.skipTest(reason="Sliding window attention is not implemented for sdpa in Qwen2 models.") # We need to be sure to always have shape (2, 7) for the different test assumptions to hold input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) From 7e78258d5fc1fbe59001de7e33e63ad9c8dbd8f6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 15:41:24 +0200 Subject: [PATCH 36/48] Update 4d mask creation in Mistral --- src/transformers/cache_utils.py | 18 +++++----- src/transformers/generation/utils.py | 8 +++-- .../models/mistral/modeling_mistral.py | 34 ++++++++++++------- tests/generation/test_utils.py | 12 ++++--- 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index bc6f1e5ad8b6..ddc9db0d02d0 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -623,8 +623,8 @@ def update( # Update the number of seen tokens self._seen_tokens.append(key_states.shape[-2]) # Add only up to sliding window size if larger - self.key_cache.append(key_states[..., -self.sliding_window+1 :, :]) - self.value_cache.append(value_states[..., -self.sliding_window+1 :, :]) + self.key_cache.append(key_states[..., -self.sliding_window + 1 :, :]) + self.value_cache.append(value_states[..., -self.sliding_window + 1 :, :]) # We should return full states during prefill even though we only save up to sliding window-1 return key_states, value_states else: @@ -633,8 +633,8 @@ def update( # the last `sliding_window-1` states in the cache for next forward full_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) full_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window+1 :, :] - self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window+1 :, :] + self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window + 1 :, :] + self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window + 1 :, :] return full_key_states, full_value_states def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: @@ -650,8 +650,7 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCac return out @classmethod - def from_batch_splits( - cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": + def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" cache = cls(splits[0].sliding_window) @@ -666,11 +665,12 @@ def from_batch_splits( # We need this because _seen_tokens may be bigger than what will be automatically set with `update` (if cache > sliding_window) cache._seen_tokens = splits[0]._seen_tokens return cache - - def crop(self, max_length: int): + def crop(self, max_length: int): if self.get_past_seen_tokens() >= self.sliding_window - 1: - raise RuntimeError(f"The current DynamicSlidingWindowCache is full. It cannot be cropped as this would mean losing past states.") + raise RuntimeError( + "The current DynamicSlidingWindowCache is full. It cannot be cropped as this would mean losing past states." + ) else: super().crop(max_length) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c4ba43fa8929..83a21e590df9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2182,8 +2182,12 @@ def generate( raise ValueError( f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" ) - if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr(generation_config, "low_memory", False): - raise ValueError("DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`.") + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr( + generation_config, "low_memory", False + ): + raise ValueError( + "DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`." + ) result = self._contrastive_search( input_ids, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 5d596eaaf725..4cffd6baabb5 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -904,12 +904,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -918,15 +924,6 @@ def _update_causal_mask( past_key_values=past_key_values, ) - if isinstance(past_key_values, DynamicSlidingWindowCache): - current_cache_length = past_key_values.get_seq_length() - if sequence_length + current_cache_length > self.config.sliding_window: - target_length = sequence_length + self.config.sliding_window - 1 - else: - target_length = current_cache_length + sequence_length - # Slice the causal mask to get only relevant part of the same shape as the keys/values - causal_mask = causal_mask[:, :, :, -target_length:] - if ( self.config._attn_implementation == "sdpa" and attention_mask is not None @@ -945,6 +942,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -963,6 +961,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -982,14 +983,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1000,7 +1006,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index fa196b128242..704b4a44b8f0 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2045,10 +2045,14 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: - attention_mask = torch.tensor([ - [0,0,0,1,1,1,1], - [1,1,1,1,1,1,1], - ], device=input_ids.device, dtype=int) + attention_mask = torch.tensor( + [ + [0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + ], + device=input_ids.device, + dtype=int, + ) else: attention_mask = torch.ones_like(input_ids) From 301f7f2d7c7ab64dc47d9b3b7b73a19e824b591d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 15:48:47 +0200 Subject: [PATCH 37/48] fix missed conflict --- src/transformers/models/idefics2/modeling_idefics2.py | 7 +------ src/transformers/models/idefics3/modeling_idefics3.py | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 9ca404d3cb45..6f3c2ffbbb5a 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1664,13 +1664,8 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore -<<<<<<< HEAD - past_length = past_key_values.get_seq_length() - max_cache_length = past_key_values.get_max_cache_shape() -======= past_length = past_key_values.get_past_seen_tokens() - max_cache_length = past_key_values.get_max_length() ->>>>>>> 0c098e35c (Modify all current .get_seq_length names) + max_cache_length = past_key_values.get_max_cache_shape() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index f30c83e7ade3..9c65b2e01b76 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -1252,13 +1252,8 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore -<<<<<<< HEAD - past_length = past_key_values.get_seq_length() - max_cache_length = past_key_values.get_max_cache_shape() -======= past_length = past_key_values.get_past_seen_tokens() - max_cache_length = past_key_values.get_max_length() ->>>>>>> 0c098e35c (Modify all current .get_seq_length names) + max_cache_length = past_key_values.get_max_cache_shape() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where From be18801446b5d70083c9a9c2069c3463e3c25603 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 15:57:55 +0200 Subject: [PATCH 38/48] Apply to other models --- src/transformers/models/mimi/modeling_mimi.py | 27 +++++++++++++---- .../models/mixtral/modeling_mixtral.py | 27 +++++++++++++---- src/transformers/models/phi3/modeling_phi3.py | 27 +++++++++++++---- .../models/phimoe/modeling_phimoe.py | 29 +++++++++++++++---- .../models/qwen2/modeling_qwen2.py | 27 +++++++++++++---- .../models/qwen2_moe/modeling_qwen2_moe.py | 27 +++++++++++++---- .../models/qwen2_vl/modeling_qwen2_vl.py | 27 +++++++++++++---- .../models/starcoder2/modeling_starcoder2.py | 27 +++++++++++++---- 8 files changed, 177 insertions(+), 41 deletions(-) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 49d83cf800e0..985ca1fe275a 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -23,7 +23,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel @@ -1076,12 +1076,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1109,6 +1115,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1127,6 +1134,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1146,14 +1156,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1164,7 +1179,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 2619d5792f0c..8af557153dce 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1117,12 +1117,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1150,6 +1156,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1168,6 +1175,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1187,14 +1197,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1205,7 +1220,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 225e590c58a1..aa7aff4c31f1 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1084,12 +1084,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1117,6 +1123,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1135,6 +1142,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1154,14 +1164,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1172,7 +1187,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index d1705f04ddb7..fca68092b0af 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1220,7 +1220,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1252,12 +1252,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1285,6 +1291,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1303,6 +1310,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1322,14 +1332,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1340,7 +1355,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 6f4b4fd61524..78dda718d8cf 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1003,12 +1003,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1036,6 +1042,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1054,6 +1061,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1073,14 +1083,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1091,7 +1106,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 4123f833090c..e512f2beeab5 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1184,12 +1184,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1217,6 +1223,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1235,6 +1242,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1254,14 +1264,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1272,7 +1287,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 52379787ab30..242eaa01f3f1 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -30,7 +30,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN -from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -1266,12 +1266,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1299,6 +1305,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1317,6 +1324,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1336,14 +1346,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1354,7 +1369,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 9688acc27497..42821a3a0ea9 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -977,12 +977,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1010,6 +1016,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1028,6 +1035,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1047,14 +1057,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1065,7 +1080,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype From 734e3fe75eb0072496a93ec2e336ae07e0045b4d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 16:09:06 +0200 Subject: [PATCH 39/48] Add required arg in prepare_inoput --- src/transformers/models/mistral/modeling_mistral.py | 1 + src/transformers/models/mixtral/modeling_mixtral.py | 1 + src/transformers/models/phi3/modeling_phi3.py | 1 + src/transformers/models/phimoe/modeling_phimoe.py | 1 + src/transformers/models/qwen2/modeling_qwen2.py | 1 + src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 1 + src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 1 + src/transformers/models/starcoder2/modeling_starcoder2.py | 1 + 8 files changed, 8 insertions(+) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 4cffd6baabb5..5f1cadca2716 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1195,6 +1195,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 8af557153dce..a78286ea570c 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1430,6 +1430,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index aa7aff4c31f1..140fd90031e0 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1406,6 +1406,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index fca68092b0af..33dd603fdb5e 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -1596,6 +1596,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 78dda718d8cf..fe8cd476965f 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1296,6 +1296,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index e512f2beeab5..d361c421fed2 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1500,6 +1500,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 242eaa01f3f1..dc0d3ae9e91e 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1855,6 +1855,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 42821a3a0ea9..180f560856fe 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1272,6 +1272,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, From 106c4100b151300097f131288082cb824bf83e65 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 16:30:04 +0200 Subject: [PATCH 40/48] Update test_utils.py --- tests/generation/test_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 704b4a44b8f0..ddf40acfe8fd 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2038,10 +2038,10 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): """ for model_class in self.all_generative_model_classes: config, _ = self.prepare_config_and_inputs_for_generate() - if getattr(config, "sliding_window", None) is None: + if not hasattr(config, "sliding_window"): self.skipTest(reason="This model does not support sliding window.") - if "qwen2" in str(model_class).lower(): - self.skipTest(reason="Sliding window attention is not implemented for sdpa in Qwen2 models.") + if hasattr(config, "cache_implementation"): + self.skipTest(reason="This model uses a specific cache format.") input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: @@ -2092,10 +2092,10 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_windo """ for model_class in self.all_generative_model_classes: config, _ = self.prepare_config_and_inputs_for_generate() - if getattr(config, "sliding_window", None) is None: + if not hasattr(config, "sliding_window"): self.skipTest(reason="This model does not support sliding window.") - if "qwen2" in str(model_class).lower(): - self.skipTest(reason="Sliding window attention is not implemented for sdpa in Qwen2 models.") + if hasattr(config, "cache_implementation"): + self.skipTest(reason="This model uses a specific cache format.") # We need to be sure to always have shape (2, 7) for the different test assumptions to hold input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) From 0d8e9acf4505fa6bcf20838fd2737939fa9d8dce Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 16:53:39 +0200 Subject: [PATCH 41/48] Update test_utils.py --- tests/generation/test_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ddf40acfe8fd..2aa8d154cd0e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2042,6 +2042,8 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): self.skipTest(reason="This model does not support sliding window.") if hasattr(config, "cache_implementation"): self.skipTest(reason="This model uses a specific cache format.") + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support Cache classes") input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: @@ -2096,6 +2098,8 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_windo self.skipTest(reason="This model does not support sliding window.") if hasattr(config, "cache_implementation"): self.skipTest(reason="This model uses a specific cache format.") + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support Cache classes") # We need to be sure to always have shape (2, 7) for the different test assumptions to hold input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) From 85090534a6d5f70745fdd8b11c356f8f2110b4a0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 22:47:57 +0200 Subject: [PATCH 42/48] Fix kv_seq_length and rotary_seq_length --- .../models/mistral/modeling_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 29 ++++++++++++------- src/transformers/models/phi3/modeling_phi3.py | 26 ++--------------- .../models/qwen2/modeling_qwen2.py | 3 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 3 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 7 ----- .../models/starcoder2/modeling_starcoder2.py | 3 +- 7 files changed, 29 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 5f1cadca2716..5a3dd65809f8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -313,7 +313,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += cache_position[0] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index a78286ea570c..e9e795566f1f 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -329,6 +329,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_length = kv_seq_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -337,7 +338,12 @@ def forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -412,6 +418,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_length = kv_seq_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -420,13 +427,12 @@ def forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = ( - max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len - ) - - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -558,10 +564,13 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + rotary_seq_length = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 140fd90031e0..58f4aa46f76f 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -377,17 +377,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -482,13 +472,7 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = ( - max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len - ) - - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids) - + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -625,11 +609,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index fe8cd476965f..8c2ee4ec8828 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -395,7 +395,8 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] + kv_seq_len = key_states.shape[-2] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index d361c421fed2..a4a4d3491331 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -483,7 +483,8 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] + kv_seq_len = key_states.shape[-2] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index dc0d3ae9e91e..2346ede2c4c8 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -549,10 +549,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += cache_position[0] + 1 - if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -809,9 +805,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 180f560856fe..b7d548029bb9 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -375,7 +375,8 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] + kv_seq_len = key_states.shape[-2] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window From 2ae645fba95861449d7162419602c35ad144b405 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 00:44:33 +0200 Subject: [PATCH 43/48] up --- .../models/mixtral/modeling_mixtral.py | 22 +++---------------- src/transformers/models/phi3/modeling_phi3.py | 6 ++--- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index e9e795566f1f..54ecec1a78ab 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -329,7 +329,6 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] - rotary_seq_length = kv_seq_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -338,12 +337,8 @@ def forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() - else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) + cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -418,7 +413,6 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] - rotary_seq_length = kv_seq_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -427,12 +421,8 @@ def forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() - else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) + cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -564,13 +554,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - rotary_seq_length = key_states.shape[-2] - if past_key_value is not None: - if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() - else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) + cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 58f4aa46f76f..707bd1e8621b 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -377,7 +377,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -472,7 +472,7 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids=position_ids) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -609,7 +609,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From 8d539e6a87eff3804c475c44cf28d78b53e0de1d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 01:17:55 +0200 Subject: [PATCH 44/48] up --- .../models/mixtral/modeling_mixtral.py | 15 ++++++++++++--- src/transformers/models/phi3/modeling_phi3.py | 15 ++++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 54ecec1a78ab..415fdcac8505 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -338,7 +338,10 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -422,7 +425,10 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -554,7 +560,10 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 707bd1e8621b..7746f418f8be 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -377,7 +377,10 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -472,7 +475,10 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -609,7 +615,10 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From e808fa53e2ba2fd17e5b3fe47b75357815cddd46 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 01:31:31 +0200 Subject: [PATCH 45/48] up --- .../models/mixtral/modeling_mixtral.py | 36 ++++++++----------- src/transformers/models/phi3/modeling_phi3.py | 26 ++++++++------ 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 415fdcac8505..1b76aca214ed 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -329,18 +329,14 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - rotary_seq_len = key_states.shape[-2] - if past_key_value is not None: - rotary_seq_len += cache_position[0] cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -416,20 +412,15 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - rotary_seq_len = key_states.shape[-2] - if past_key_value is not None: - rotary_seq_len += cache_position[0] cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -562,9 +553,12 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - rotary_seq_len += cache_position[0] - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 7746f418f8be..be6192cd8aae 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -379,8 +379,12 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - rotary_seq_len += cache_position[0] + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -466,18 +470,14 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - rotary_seq_len = key_states.shape[-2] - if past_key_value is not None: - rotary_seq_len += cache_position[0] cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -617,8 +617,12 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - rotary_seq_len += cache_position[0] + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From 8499f942158dc6cdb600b37217571b88bc5018c0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 01:35:08 +0200 Subject: [PATCH 46/48] up --- .../models/mixtral/modeling_mixtral.py | 12 ++++++------ src/transformers/models/phi3/modeling_phi3.py | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1b76aca214ed..5a1112a60214 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -333,9 +333,9 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -416,9 +416,9 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -554,9 +554,9 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index be6192cd8aae..e79e91b059ee 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -380,9 +380,9 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -474,9 +474,9 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -618,11 +618,11 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From 687986696a258551add56deab329b0d89201fcef Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 10:32:04 +0200 Subject: [PATCH 47/48] CIs From fe8a625a1c0545772fd19b193e6042d2414b9444 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 11:51:45 +0200 Subject: [PATCH 48/48] improve sdpa is_causal escape --- src/transformers/modeling_attn_mask_utils.py | 6 +++--- src/transformers/models/mimi/modeling_mimi.py | 3 ++- src/transformers/models/mistral/modeling_mistral.py | 3 ++- src/transformers/models/mixtral/modeling_mixtral.py | 3 ++- src/transformers/models/phi3/modeling_phi3.py | 3 ++- src/transformers/models/phimoe/modeling_phimoe.py | 3 ++- src/transformers/models/qwen2/modeling_qwen2.py | 3 ++- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 3 ++- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 3 ++- src/transformers/models/starcoder2/modeling_starcoder2.py | 3 ++- 10 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 4319c021cb2b..64b64e889962 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -275,13 +275,13 @@ def _ignore_causal_mask_sdpa( if ( (is_training or not is_tracing) and (query_length == 1 or key_value_length == query_length) - and (sliding_window is None or key_value_length < sliding_window) + and (sliding_window is None or key_value_length <= sliding_window) ): ignore_causal_mask = True - elif sliding_window is None or key_value_length < sliding_window: + elif sliding_window is None or key_value_length <= sliding_window: if len(attention_mask.shape) == 4: return False - elif not is_tracing and torch.all(attention_mask == 1): + elif not is_tracing and torch.all(attention_mask[:, -key_value_length:] == 1): if query_length == 1 or key_value_length == query_length: # For query_length == 1, causal attention and bi-directional attention are the same. ignore_causal_mask = True diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 985ca1fe275a..6120e2576d68 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1044,6 +1044,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1057,7 +1058,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 5a3dd65809f8..464172172ae8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -872,6 +872,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -885,7 +886,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 5a1112a60214..05db01c83f3b 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1081,6 +1081,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1094,7 +1095,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index e79e91b059ee..3dd799f59177 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1045,6 +1045,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1058,7 +1059,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 33dd603fdb5e..298777ecbcca 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -1220,6 +1220,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1233,7 +1234,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8c2ee4ec8828..505a35b8b249 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -972,6 +972,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -985,7 +986,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index a4a4d3491331..2b5aff242375 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1153,6 +1153,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1166,7 +1167,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 2346ede2c4c8..86e6bfbb5472 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1227,6 +1227,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1240,7 +1241,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index b7d548029bb9..3e6fba51861d 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -946,6 +946,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -959,7 +960,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ):