diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index 540b6136bb63..563615d217b9 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -268,53 +268,62 @@ def forward(
         **kwargs,
     ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
         is_cross_attention = encoder_hidden_states is not None
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+
         if is_cross_attention:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
                     "Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
                 )
             query_states = self.q_attn(hidden_states)
-            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
+
+            # Try to get key/value states from cache if possible
+            if past_key_value is not None and is_updated:
+                key_states = curr_past_key_value.layers[self.layer_idx].keys
+                value_states = curr_past_key_value.layers[self.layer_idx].values
+            else:
+                key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+                shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+                key_states = key_states.view(shape_kv).transpose(1, 2)
+                value_states = value_states.view(shape_kv).transpose(1, 2)
         else:
             query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+            shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+            key_states = key_states.view(shape_kv).transpose(1, 2)
+            value_states = value_states.view(shape_kv).transpose(1, 2)

         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
-        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
-
         query_states = query_states.view(shape_q).transpose(1, 2)
-        key_states = key_states.view(shape_kv).transpose(1, 2)
-        value_states = value_states.view(shape_kv).transpose(1, 2)

-        if past_key_value is not None:
-            if isinstance(past_key_value, EncoderDecoderCache):
-                if is_cross_attention:
-                    past_key_value = past_key_value.cross_attention_cache
-                else:
-                    past_key_value = past_key_value.self_attention_cache
-            cache_kwargs = {"cache_position": cache_position}
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
+        if (past_key_value is not None and not is_cross_attention) or (
+            past_key_value is not None and is_cross_attention and not is_updated
+        ):
+            # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+            cache_position = cache_position if not is_cross_attention else None
+            key_states, value_states = curr_past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
             )
+            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+            if is_cross_attention:
+                past_key_value.is_updated[self.layer_idx] = True

         is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

         using_eager = self.config._attn_implementation == "eager"
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
-                using_eager = True
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
-                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
-                # not necessarily to eager (if mentioned options are provided).
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         if using_eager and self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(
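Note on the attention hunk above: it moves GPT2-style attention from ad-hoc `past_key_value` branching onto the shared `EncoderDecoderCache` contract. Cross-attention key/value projections are computed once, written to the cross-attention sub-cache, and re-read on every later step via the per-layer `is_updated` flag; self-attention keys/values keep being appended using `cache_position`. A minimal sketch of that contract (my illustration, not part of the patch, assuming the post-refactor `DynamicCache`/`EncoderDecoderCache` API with `cache.layers[i].keys`/`.values` as referenced in the diff):

```python
# Sketch: the caching behavior the refactored forward relies on.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx = 0
keys = torch.randn(1, 12, 7, 64)    # (batch, heads, encoder_len, head_dim)
values = torch.randn(1, 12, 7, 64)

# First decoding step: cross-attention projections are computed once and stored.
if not cache.is_updated.get(layer_idx):
    cache.cross_attention_cache.update(keys, values, layer_idx, {"cache_position": None})
    cache.is_updated[layer_idx] = True

# Every later step: re-read the stored projections instead of re-projecting
# encoder_hidden_states through c_attn.
cached_keys = cache.cross_attention_cache.layers[layer_idx].keys
cached_values = cache.cross_attention_cache.layers[layer_idx].values
assert torch.equal(cached_keys, keys) and torch.equal(cached_values, values)
```

This is also why the update call passes `cache_position=None` for cross-attention: those entries are written once at full encoder length rather than appended token by token.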
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index c3155b17eae3..5974b8024e41 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -27,9 +27,10 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN, get_activation
-from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -278,53 +279,62 @@ def forward(
         **kwargs,
     ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
         is_cross_attention = encoder_hidden_states is not None
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+
         if is_cross_attention:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                 )
             query_states = self.q_attn(hidden_states)
-            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
+
+            # Try to get key/value states from cache if possible
+            if past_key_value is not None and is_updated:
+                key_states = curr_past_key_value.layers[self.layer_idx].keys
+                value_states = curr_past_key_value.layers[self.layer_idx].values
+            else:
+                key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+                shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+                key_states = key_states.view(shape_kv).transpose(1, 2)
+                value_states = value_states.view(shape_kv).transpose(1, 2)
         else:
             query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+            shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+            key_states = key_states.view(shape_kv).transpose(1, 2)
+            value_states = value_states.view(shape_kv).transpose(1, 2)

         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
-        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
-
         query_states = query_states.view(shape_q).transpose(1, 2)
-        key_states = key_states.view(shape_kv).transpose(1, 2)
-        value_states = value_states.view(shape_kv).transpose(1, 2)

-        if past_key_value is not None:
-            if isinstance(past_key_value, EncoderDecoderCache):
-                if is_cross_attention:
-                    past_key_value = past_key_value.cross_attention_cache
-                else:
-                    past_key_value = past_key_value.self_attention_cache
-            cache_kwargs = {"cache_position": cache_position}
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
+        if (past_key_value is not None and not is_cross_attention) or (
+            past_key_value is not None and is_cross_attention and not is_updated
+        ):
+            # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+            cache_position = cache_position if not is_cross_attention else None
+            key_states, value_states = curr_past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
             )
+            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+            if is_cross_attention:
+                past_key_value.is_updated[self.layer_idx] = True

         is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

         using_eager = self.config._attn_implementation == "eager"
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
-                using_eager = True
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
-                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
-                # not necessarily to eager (if mentioned options are provided).
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         if using_eager and self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(
@@ -861,8 +871,14 @@ def forward(
         # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
         if attention_mask is not None and attention_mask.ndim < 4:
             attention_mask = attention_mask.view(batch_size, -1)
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         # If a 2D or 3D attention mask is provided for the cross-attention
@@ -903,9 +919,6 @@ def forward(
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
                 if isinstance(head_mask, torch.Tensor):
                     head_mask = head_mask.to(hidden_states.device)
             if output_hidden_states:
@@ -966,123 +979,6 @@ def forward(
             cross_attentions=all_cross_attentions,
         )

-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype = input_tensor.dtype
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
-    @staticmethod
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        """
-        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-        Args:
-            attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
-            sequence_length (`int`):
-                The sequence length being processed.
-            target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
-            dtype (`torch.dtype`):
-                The dtype to use for the 4D attention mask.
-            cache_position (`torch.Tensor`):
-                Indices depicting the position of the input sequence tokens in the sequence.
-            batch_size (`torch.Tensor`):
-                Batch size.
-        """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-
     @auto_docstring(
         custom_intro="""
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
index 96d8228ea53d..ab2bbf8d3063 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -449,6 +449,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
@@ -561,6 +562,7 @@ def forward(
             use_cache=use_cache,
             past_key_values=past_key_values,
             return_dict=return_dict,
+            cache_position=cache_position,
             **kwargs_decoder,
         )
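The vision-encoder-decoder change is plumbing only: `cache_position` is now accepted and forwarded to the decoder, which (after the gpt2 hunks above) needs it both for `create_causal_mask` and for cache writes. As a rough illustration of the values involved (hypothetical lengths, not from the patch; the semantics follow the deleted docstring, "indices depicting the position of the input sequence tokens in the sequence"):

```python
import torch

# Prefill: one position index per prompt token.
prompt_len = 5
cache_position = torch.arange(prompt_len)           # tensor([0, 1, 2, 3, 4])

# Auto-regressive decoding: a single index for the newly generated token.
step = 0
cache_position = torch.tensor([prompt_len + step])  # tensor([5])
```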
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 9454f4b4e522..2b4b93e6476f 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -1775,6 +1775,7 @@ def test_head_pruning(self):
             model = model_class(config=config)
             model.to(torch_device)
             model.eval()
+            model.set_attn_implementation("eager")
             heads_to_prune = {
                 0: list(range(1, self.model_tester.num_attention_heads)),
                 -1: [0],
@@ -1808,6 +1809,7 @@ def test_head_pruning_save_load_from_pretrained(self):
             model = model_class(config=config)
             model.to(torch_device)
             model.eval()
+            model.set_attn_implementation("eager")
             heads_to_prune = {
                 0: list(range(1, self.model_tester.num_attention_heads)),
                 -1: [0],
@@ -1816,7 +1818,7 @@

             with tempfile.TemporaryDirectory() as temp_dir_name:
                 model.save_pretrained(temp_dir_name)
-                model = model_class.from_pretrained(temp_dir_name)
+                model = model_class.from_pretrained(temp_dir_name, attn_implementation="eager")
                 model.to(torch_device)

             with torch.no_grad():
@@ -1852,6 +1854,7 @@ def test_head_pruning_save_load_from_config_init(self):
             model = model_class(config=config)
             model.to(torch_device)
             model.eval()
+            model.set_attn_implementation("eager")

             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
@@ -1884,6 +1887,7 @@ def test_head_pruning_integration(self):
             model = model_class(config=config)
             model.to(torch_device)
             model.eval()
+            model.set_attn_implementation("eager")

             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
@@ -1894,7 +1898,7 @@

             with tempfile.TemporaryDirectory() as temp_dir_name:
                 model.save_pretrained(temp_dir_name)
-                model = model_class.from_pretrained(temp_dir_name)
+                model = model_class.from_pretrained(temp_dir_name, attn_implementation="eager")
                 model.to(torch_device)

             with torch.no_grad():
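Finally, the common head-pruning tests pin the attention backend to eager, both on freshly built models and on reload; presumably the dispatched sdpa/flash kernels are not exercised with pruned (uneven) head layouts, while the eager path still supports them. A sketch of the two call patterns the tests now use (my illustration, with `gpt2` as a stand-in checkpoint):

```python
from transformers import GPT2Model

# Pin the backend at load time...
model = GPT2Model.from_pretrained("gpt2", attn_implementation="eager")

# ...or switch an already-instantiated model over.
model.set_attn_implementation("eager")

# Head pruning as exercised by the tests: prune heads 0 and 1 of layer 0.
model.prune_heads({0: [0, 1]})
```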