From 9ab68d0d9f8db06e4bb7e709f7499549955d52a2 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 10:13:00 +0200 Subject: [PATCH 01/18] Use cache_info --- src/transformers/cache_utils.py | 13 ++++- .../models/gemma/modeling_gemma.py | 58 +++++++++++-------- 2 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 1f5a164815aa..c4f7578be4d0 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -23,6 +23,13 @@ logger = logging.get_logger(__name__) +class CacheInfo: + + def __init__(self, position, length): + self.position = position + self._length = length + + @dataclass class Cache: """ @@ -854,7 +861,7 @@ def update( Return: A tuple containing the updated key and value states. """ - cache_position = cache_kwargs.get("cache_position") + cache_info = cache_kwargs.get("cache_info") k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] @@ -862,8 +869,8 @@ def update( k_out.copy_(key_states) v_out.copy_(value_states) else: - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states + k_out[:, :, cache_info.position] = key_states + v_out[:, :, cache_info.position] = value_states return k_out, v_out diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c0da2530fe2c..21b3ebc9d94a 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, CacheInfo, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -266,7 +266,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -283,7 +283,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -340,7 +340,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -368,7 +368,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash 
Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -531,7 +531,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -546,7 +546,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, + cache_info=cache_info, ) bsz, q_len, _ = hidden_states.size() @@ -564,7 +564,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -585,6 +585,11 @@ def forward( # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False + if cache_info._length > 0: + key_states = key_states[:, :, :cache_info._length, :] + value_states = value_states[:, :, :cache_info._length, :] + causal_mask = causal_mask[:, :, :, :cache_info._length] if causal_mask is not None else causal_mask + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, @@ -628,7 +633,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -662,7 +667,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, + cache_info=cache_info, ) hidden_states = residual + hidden_states @@ -850,7 +855,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -878,17 +883,18 @@ def forward( return_legacy_cache = True # noqa: F841 past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if cache_position is None: + if cache_info is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = cache_info.position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, 
past_key_values, output_attentions + attention_mask, inputs_embeds, cache_info, past_key_values, output_attentions ) # embed positions @@ -925,7 +931,7 @@ def forward( past_key_values, output_attentions, use_cache, - cache_position, + cache_info, ) else: layer_outputs = decoder_layer( @@ -935,7 +941,7 @@ def forward( past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, + cache_info=cache_info, ) hidden_states = layer_outputs[0] @@ -969,7 +975,7 @@ def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, - cache_position: torch.Tensor, + cache_info: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): @@ -1022,7 +1028,7 @@ def _update_causal_mask( ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=device) > cache_info.position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit @@ -1090,7 +1096,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1134,7 +1140,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, + cache_info=cache_info, ) hidden_states = outputs[0] @@ -1171,14 +1177,14 @@ def prepare_inputs_for_generation( past_key_values=None, attention_mask=None, inputs_embeds=None, - cache_position=None, + cache_info=None, use_cache=True, **kwargs, ): past_length = 0 if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_info.position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) if past_key_values.get_max_length() is not None @@ -1223,15 +1229,17 @@ def prepare_inputs_for_generation( model_inputs = {"input_ids": input_ids.contiguous()} input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: + if cache_info is None: cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1) elif use_cache: - cache_position = cache_position[-input_length:] + cache_position = cache_info.position[-input_length:] + cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1) model_inputs.update( { "position_ids": position_ids, - "cache_position": cache_position, + "cache_info": cache_info, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, From 81e795a600d41323a38393ab9c6f7829c1627cb7 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 10:13:26 +0200 Subject: [PATCH 02/18] [run-slow] gemma From c1098f976667c440b50ab17eb9e1f0a543d7ab1e Mon Sep 17 00:00:00 2001 From: 
ydshieh Date: Thu, 4 Jul 2024 10:40:20 +0200 Subject: [PATCH 03/18] [run-slow] gemma --- src/transformers/models/gemma/modeling_gemma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 21b3ebc9d94a..0dbe4d9f5721 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1184,7 +1184,7 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_info.position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_info.position[0] if cache_info is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) if past_key_values.get_max_length() is not None From b7eaf50cd5e4c021afe099b5966bbc45ac1ee958 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 13:57:33 +0200 Subject: [PATCH 04/18] [run-slow] gemma --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c4f7578be4d0..dc095c8b080d 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -865,7 +865,7 @@ def update( k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] - if cache_position is None: + if cache_info is None: k_out.copy_(key_states) v_out.copy_(value_states) else: From 888a2c00075afab3ae90a61f1fe0dcb588173b9e Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 14:18:17 +0200 Subject: [PATCH 05/18] reconstruct cache_position from _length: 0001 --- src/transformers/cache_utils.py | 8 ++++---- src/transformers/models/gemma/modeling_gemma.py | 9 ++++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index dc095c8b080d..f400739f1e89 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -861,16 +861,16 @@ def update( Return: A tuple containing the updated key and value states. 
""" - cache_info = cache_kwargs.get("cache_info") + cache_position = cache_kwargs.get("cache_position") k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] - if cache_info is None: + if cache_position is None: k_out.copy_(key_states) v_out.copy_(value_states) else: - k_out[:, :, cache_info.position] = key_states - v_out[:, :, cache_info.position] = value_states + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states return k_out, v_out diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 0dbe4d9f5721..339b381b174e 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -281,9 +281,16 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if q_len > 1: + # prefill + cache_position = torch.arange(cache_info._length, dtype=torch.int32, device=hidden_states.device) + else: + # decoding + cache_position = torch.tensor([cache_info._length - 1], dtype=torch.int32, device=hidden_states.device) + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info} + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) From 01cd35fb9751be94c4707d7405059e9170ad5705 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 15:32:59 +0200 Subject: [PATCH 06/18] reconstruct cache_position from _length: 0002 --- src/transformers/models/gemma/modeling_gemma.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 339b381b174e..05db3470e56e 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -569,9 +569,16 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if q_len > 1: + # prefill + cache_position = torch.arange(cache_info._length, dtype=torch.int32, device=hidden_states.device) + else: + # decoding + cache_position = torch.tensor([cache_info._length - 1], dtype=torch.int32, device=hidden_states.device) + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info} + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) From f93b239ed7ccac3633ddbad9a946e11ae183dbdd Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 16:05:33 +0200 Subject: [PATCH 07/18] reconstruct cache_position from _length: 0003 --- .../models/gemma/modeling_gemma.py | 75 ++++++++++--------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 05db3470e56e..34deb932e87d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ 
b/src/transformers/models/gemma/modeling_gemma.py @@ -266,7 +266,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_info: Optional[torch.LongTensor] = None, + cache_length: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -283,10 +283,10 @@ def forward( if q_len > 1: # prefill - cache_position = torch.arange(cache_info._length, dtype=torch.int32, device=hidden_states.device) + cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device) else: # decoding - cache_position = torch.tensor([cache_info._length - 1], dtype=torch.int32, device=hidden_states.device) + cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -347,7 +347,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_info: Optional[torch.LongTensor] = None, + cache_length: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -373,9 +373,16 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if q_len > 1: + # prefill + cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device) + else: + # decoding + cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device) + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info} + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -538,7 +545,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_info: Optional[torch.LongTensor] = None, + cache_length: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
@@ -553,7 +560,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_info=cache_info, + cache_length=cache_length, ) bsz, q_len, _ = hidden_states.size() @@ -571,10 +578,10 @@ def forward( if q_len > 1: # prefill - cache_position = torch.arange(cache_info._length, dtype=torch.int32, device=hidden_states.device) + cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device) else: # decoding - cache_position = torch.tensor([cache_info._length - 1], dtype=torch.int32, device=hidden_states.device) + cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -599,10 +606,10 @@ def forward( # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False - if cache_info._length > 0: - key_states = key_states[:, :, :cache_info._length, :] - value_states = value_states[:, :, :cache_info._length, :] - causal_mask = causal_mask[:, :, :, :cache_info._length] if causal_mask is not None else causal_mask + if cache_length > 0: + key_states = key_states[:, :, :cache_length, :] + value_states = value_states[:, :, :cache_length, :] + causal_mask = causal_mask[:, :, :, :cache_length] if causal_mask is not None else causal_mask attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, @@ -647,7 +654,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - cache_info: Optional[torch.LongTensor] = None, + cache_length: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -681,7 +688,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_info=cache_info, + cache_length=cache_length, ) hidden_states = residual + hidden_states @@ -869,7 +876,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_info: Optional[torch.LongTensor] = None, + cache_length: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -897,18 +904,18 @@ def forward( return_legacy_cache = True # noqa: F841 past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if cache_info is None: + if cache_length is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_length = past_seen_tokens + inputs_embeds.shape[1] cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, cache_length, device=inputs_embeds.device ) - cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1) if position_ids is None: - position_ids = cache_info.position.unsqueeze(0) + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_info, past_key_values, output_attentions + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) # embed 
positions @@ -945,7 +952,7 @@ def forward( past_key_values, output_attentions, use_cache, - cache_info, + cache_length, ) else: layer_outputs = decoder_layer( @@ -955,7 +962,7 @@ def forward( past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, - cache_info=cache_info, + cache_length=cache_length, ) hidden_states = layer_outputs[0] @@ -989,7 +996,7 @@ def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, - cache_info: torch.Tensor, + cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): @@ -1042,7 +1049,7 @@ def _update_causal_mask( ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_info.position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit @@ -1110,7 +1117,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_info: Optional[torch.LongTensor] = None, + cache_length: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1154,7 +1161,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_info=cache_info, + cache_length=cache_length, ) hidden_states = outputs[0] @@ -1191,14 +1198,14 @@ def prepare_inputs_for_generation( past_key_values=None, attention_mask=None, inputs_embeds=None, - cache_info=None, + cache_length=None, use_cache=True, **kwargs, ): past_length = 0 if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_info.position[0] if cache_info is not None else past_key_values.get_seq_length() + past_length = past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) if past_key_values.get_max_length() is not None @@ -1243,17 +1250,13 @@ def prepare_inputs_for_generation( model_inputs = {"input_ids": input_ids.contiguous()} input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_info is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1) - elif use_cache: - cache_position = cache_info.position[-input_length:] - cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1) + if cache_length is None: + cache_length = past_length + input_length model_inputs.update( { "position_ids": position_ids, - "cache_info": cache_info, + "cache_length": cache_length, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, From 8029b7fc2ab5b84445629762c52c414235f0859b Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:05:33 +0200 Subject: [PATCH 08/18] [run-slow] gemma From b6f30f537655e5f32fd752eb995b68b2abe537ee Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:27:38 +0200 Subject: [PATCH 09/18] fix --- src/transformers/models/gemma/modeling_gemma.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git 
a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 34deb932e87d..a0e8614344a4 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -907,9 +907,10 @@ def forward( if cache_length is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_length = past_seen_tokens + inputs_embeds.shape[1] - cache_position = torch.arange( - past_seen_tokens, cache_length, device=inputs_embeds.device - ) + + cache_position = torch.arange( + past_seen_tokens, cache_length, device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) From 33ef0b14ead8b2de6ffe2cd1136d45572e8b1c63 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:27:56 +0200 Subject: [PATCH 10/18] [run-slow] gemma From 3ca52cf3ab1a0a3cb2b045bffc20247d3c424cba Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:31:24 +0200 Subject: [PATCH 11/18] remove cache info --- src/transformers/cache_utils.py | 7 ------- src/transformers/models/gemma/modeling_gemma.py | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index f400739f1e89..1f5a164815aa 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -23,13 +23,6 @@ logger = logging.get_logger(__name__) -class CacheInfo: - - def __init__(self, position, length): - self.position = position - self._length = length - - @dataclass class Cache: """ diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index a0e8614344a4..277b0dbc811d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, CacheInfo, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, From dcd292bab6f1ab1bb232518bb72121d77c7e24bd Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:34:49 +0200 Subject: [PATCH 12/18] fix --- src/transformers/models/gemma/modeling_gemma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 277b0dbc811d..0b4bce41aa5d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -904,8 +904,8 @@ def forward( return_legacy_cache = True # noqa: F841 past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_length is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_length = past_seen_tokens + inputs_embeds.shape[1] cache_position = torch.arange( From 843876bdb7c1728d08fdf7f781291a289ac4a691 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:34:53 +0200 Subject: [PATCH 13/18] [run-slow] gemma From f47f4a8998a64c1866ce7a6a10306277f72e28b6 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:49:38 +0200 Subject: [PATCH 14/18] fix --- src/transformers/models/gemma/modeling_gemma.py | 8 ++++---- 1 
file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 0b4bce41aa5d..070f824b22b9 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1199,7 +1199,7 @@ def prepare_inputs_for_generation( past_key_values=None, attention_mask=None, inputs_embeds=None, - cache_length=None, + cached_length=None, use_cache=True, **kwargs, ): @@ -1251,13 +1251,13 @@ def prepare_inputs_for_generation( model_inputs = {"input_ids": input_ids.contiguous()} input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_length is None: - cache_length = past_length + input_length + if cached_length is None: + cached_length = past_length + input_length model_inputs.update( { "position_ids": position_ids, - "cache_length": cache_length, + "cache_length": cached_length, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, From 8d4e17b311f6e473ad7a685b72e0a113ea15281f Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:49:43 +0200 Subject: [PATCH 15/18] [run-slow] gemma From 84e694d4ec7f08559aff294a2b4dd2d18c1d7670 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:54:37 +0200 Subject: [PATCH 16/18] fix --- src/transformers/models/gemma/modeling_gemma.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 070f824b22b9..1c6b885a3575 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1252,7 +1252,8 @@ def prepare_inputs_for_generation( input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cached_length is None: - cached_length = past_length + input_length + # It must be a python int + cached_length = int(past_length + input_length) model_inputs.update( { From 4df64e62e915930a3f20926061645651af61fe50 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 4 Jul 2024 17:54:42 +0200 Subject: [PATCH 17/18] [run-slow] gemma From 450b1d26e42d4232f463a91b6add16050e987db8 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Fri, 5 Jul 2024 17:54:05 +0200 Subject: [PATCH 18/18] fix --- .../models/gemma/modeling_gemma.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 1c6b885a3575..a5aff5e0dc7d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -908,15 +908,14 @@ def forward( if cache_length is None: cache_length = past_seen_tokens + inputs_embeds.shape[1] - cache_position = torch.arange( - past_seen_tokens, cache_length, device=inputs_embeds.device - ) - if position_ids is None: + cache_position = torch.arange( + past_seen_tokens, cache_length, device=inputs_embeds.device + ) position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + attention_mask, inputs_embeds, cache_length, past_key_values, output_attentions ) # embed positions @@ -997,7 +996,7 @@ def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, - cache_position: torch.Tensor, + cache_length: torch.Tensor, past_key_values: Cache, 
output_attentions: bool, ): @@ -1045,12 +1044,12 @@ def _update_causal_mask( raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + else: + # This masking is only needed when `sequence_length == 1` with a static cache: slots at or beyond `cache_length` are padding and must stay masked. + causal_mask *= torch.arange(target_length, device=device) >= cache_length causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
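Taken together, the series replaces the per-token `cache_position` tensor that previously flowed through `GemmaModel.forward` with a single `cache_length` value, kept as a plain Python int (patch 16 forces `int(past_length + input_length)`, presumably so the value stays out of the traced graph under `torch.compile`). Each attention layer then rebuilds the positions it needs locally: a full `arange` during prefill (the patches assume prefill starts from an empty cache, so `cache_length == q_len`) and a one-element tensor pointing at slot `cache_length - 1` during decoding, while the SDPA path additionally trims the static key/value buffers and the causal mask to the first `cache_length` slots. Below is a minimal sketch of that idea, written outside the transformers API (the helper names are illustrative, not real library functions):

import torch

def rebuild_cache_position(cache_length: int, q_len: int, device: torch.device) -> torch.Tensor:
    # Illustrative helper: recover the cache slots touched by the current
    # forward pass from the scalar cache_length alone.
    if q_len > 1:
        # Prefill: the patches assume the cache starts empty, so the q_len new
        # tokens fill slots [0, cache_length) and cache_length == q_len.
        return torch.arange(cache_length, dtype=torch.int32, device=device)
    # Decoding: the single new token goes into slot cache_length - 1.
    return torch.tensor([cache_length - 1], dtype=torch.int32, device=device)

def trim_static_cache(k_out, v_out, causal_mask, cache_length: int):
    # Illustrative helper: a StaticCache holds [batch, heads, max_cache_len, head_dim]
    # buffers, so scaled_dot_product_attention only needs the first cache_length
    # slots and the matching slice of the 4D mask, not the padded tail.
    keys = k_out[:, :, :cache_length, :]
    values = v_out[:, :, :cache_length, :]
    mask = causal_mask[..., :cache_length] if causal_mask is not None else None
    return keys, values, mask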