diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index c0da2530fe2c..a5aff5e0dc7d 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -266,7 +266,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
+        cache_length: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -281,6 +281,13 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+        if q_len > 1:
+            # prefill
+            cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
+        else:
+            # decoding
+            cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)
+
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -340,7 +347,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
+        cache_length: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
@@ -366,6 +373,13 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+        if q_len > 1:
+            # prefill
+            cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
+        else:
+            # decoding
+            cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)
+
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -531,7 +545,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
+        cache_length: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -546,7 +560,7 @@ def forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
-                cache_position=cache_position,
+                cache_length=cache_length,
             )
 
         bsz, q_len, _ = hidden_states.size()
@@ -562,6 +576,13 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+        if q_len > 1:
+            # prefill
+            cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
+        else:
+            # decoding
+            cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)
+
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -585,6 +606,11 @@ def forward(
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
 
+        if cache_length > 0:
+            key_states = key_states[:, :, :cache_length, :]
+            value_states = value_states[:, :, :cache_length, :]
+            causal_mask = causal_mask[:, :, :, :cache_length] if causal_mask is not None else causal_mask
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -628,7 +654,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
+        cache_length: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -662,7 +688,7 @@ def forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
-            cache_position=cache_position,
+            cache_length=cache_length,
         )
         hidden_states = residual + hidden_states
 
@@ -850,7 +876,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
+        cache_length: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -878,17 +904,18 @@ def forward(
             return_legacy_cache = True  # noqa: F841
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        if cache_length is None:
+            cache_length = past_seen_tokens + inputs_embeds.shape[1]
 
         if position_ids is None:
+            cache_position = torch.arange(
+                past_seen_tokens, cache_length, device=inputs_embeds.device
+            )
             position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+            attention_mask, inputs_embeds, cache_length, past_key_values, output_attentions
         )
 
         # embed positions
@@ -925,7 +952,7 @@ def forward(
                     past_key_values,
                     output_attentions,
                     use_cache,
-                    cache_position,
+                    cache_length,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -935,7 +962,7 @@ def forward(
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
-                    cache_position=cache_position,
+                    cache_length=cache_length,
                 )
 
             hidden_states = layer_outputs[0]
@@ -969,7 +996,7 @@ def _update_causal_mask(
         self,
         attention_mask: torch.Tensor,
         input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
+        cache_length: torch.Tensor,
         past_key_values: Cache,
         output_attentions: bool,
     ):
@@ -1017,12 +1044,12 @@ def _update_causal_mask(
                 raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
             causal_mask = attention_mask
         else:
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-            )
+            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            else:
+                # This computation is only required when `sequence_length = 1` in the case of static cache.
+                causal_mask *= torch.arange(target_length, device=device) > cache_length
             causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
@@ -1090,7 +1117,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
+        cache_length: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1134,7 +1161,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            cache_position=cache_position,
+            cache_length=cache_length,
         )
 
         hidden_states = outputs[0]
@@ -1171,14 +1198,14 @@ def prepare_inputs_for_generation(
         past_key_values=None,
         attention_mask=None,
         inputs_embeds=None,
-        cache_position=None,
+        cached_length=None,
         use_cache=True,
         **kwargs,
     ):
         past_length = 0
         if past_key_values is not None:
             # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
-            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+            past_length = past_key_values.get_seq_length()
             max_cache_length = (
                 torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                 if past_key_values.get_max_length() is not None
@@ -1223,15 +1250,14 @@ def prepare_inputs_for_generation(
             model_inputs = {"input_ids": input_ids.contiguous()}
 
         input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
-        if cache_position is None:
-            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-        elif use_cache:
-            cache_position = cache_position[-input_length:]
+        if cached_length is None:
+            # It must be a python int
+            cached_length = int(past_length + input_length)
 
         model_inputs.update(
             {
                 "position_ids": position_ids,
-                "cache_position": cache_position,
+                "cache_length": cached_length,
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
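
Note (not part of the patch): a minimal sketch of the cache_position derivation the diff repeats in each attention forward. The helper name derive_cache_position is purely illustrative; the assumption, taken from the diff, is that cache_length is the total number of valid cache slots after the current step, so prefill (q_len > 1) writes positions 0..cache_length-1 while single-token decoding writes only position cache_length - 1.

import torch

def derive_cache_position(cache_length: int, q_len: int, device: torch.device) -> torch.Tensor:
    # Illustrative helper mirroring the branch added in each attention forward.
    if q_len > 1:
        # prefill: every cache slot up to cache_length is written
        return torch.arange(cache_length, dtype=torch.int32, device=device)
    # decoding: only the newest slot, index cache_length - 1, is written
    return torch.tensor([cache_length - 1], dtype=torch.int32, device=device)

# Example: a 5-token prompt, then one decoded token.
cpu = torch.device("cpu")
print(derive_cache_position(cache_length=5, q_len=5, device=cpu))  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
print(derive_cache_position(cache_length=6, q_len=1, device=cpu))  # tensor([5], dtype=torch.int32)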