diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 266a916cef63..5a0bcb55d42a 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -27,11 +27,13 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig @@ -40,7 +42,6 @@ from ...integrations.flex_attention import make_flex_block_causal_mask - logger = logging.get_logger(__name__) @@ -235,7 +236,6 @@ def forward( self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: query = self.q_proj(hidden_state) @@ -252,13 +252,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -274,9 +268,6 @@ def forward( attn_output = attn_output.reshape(batch_size, q_seq_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights @@ -303,7 +294,6 @@ def forward( self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, ): # Self Attention residual = hidden_state @@ -321,12 +311,7 @@ def forward( hidden_state = self.gate_ffn.tanh() * hidden_state hidden_state = residual + hidden_state - outputs = (hidden_state,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_state class MllamaVisionEncoder(nn.Module): @@ -349,10 +334,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutput]: + ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -366,54 +348,15 @@ def forward( - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None + """ for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_state=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - hidden_states = layer_outputs[0] + hidden_states = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + ) - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) + return BaseModelOutput(last_hidden_state=hidden_states) # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText @@ -470,7 +413,6 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -507,13 +449,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -529,9 +465,6 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights @@ -595,7 +528,6 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, - output_attentions: bool = False, use_cache: bool = False, past_key_value=None, cache_position=None, @@ -622,13 +554,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -644,9 +570,6 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights @@ -669,7 +592,7 @@ def forward(self, x): # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer -class MllamaSelfAttentionDecoderLayer(nn.Module): +class MllamaSelfAttentionDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MllamaTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -691,7 +614,6 @@ def forward( full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -703,9 +625,7 @@ def forward( attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. + use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
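The decoder layers above now inherit from `GradientCheckpointingLayer`, which moves the checkpointing branch out of the model's layer loop (compare the removed `_gradient_checkpointing_func` plumbing further down in this diff). Below is a minimal, self-contained sketch of the idea in plain PyTorch; the class names are toy stand-ins, not the transformers implementation:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedLayer(nn.Module):
    """Toy stand-in for GradientCheckpointingLayer: when the flag is set and the
    module is in training mode, the forward pass runs under torch.utils.checkpoint,
    so activations are recomputed during backward instead of being stored."""

    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
        return super().__call__(*args, **kwargs)


class ToyDecoderLayer(CheckpointedLayer):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(hidden_states))


layer = ToyDecoderLayer()
layer.gradient_checkpointing = True  # normally toggled by gradient_checkpointing_enable()
hidden = torch.randn(2, 16, requires_grad=True)
layer(hidden).sum().backward()  # gradients flow through the recomputed forward
```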
@@ -729,7 +649,6 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -743,15 +662,10 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs + return hidden_states -class MllamaCrossAttentionDecoderLayer(torch.nn.Module): +class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): """Cross-attention transformer block with tanh-gated attention and feedforward.""" def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: @@ -775,7 +689,6 @@ def forward( full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, @@ -789,7 +702,6 @@ def forward( attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, past_key_value=past_key_value, - output_attentions=output_attentions, cache_position=cache_position, **kwargs, ) @@ -802,12 +714,7 @@ def forward( hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states class MllamaRotaryEmbedding(nn.Module): @@ -849,12 +756,19 @@ class MllamaPreTrainedModel(PreTrainedModel): "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer", ] - _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer], + "attentions": [ + OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="self_attn"), + OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="cross_attn"), + OutputRecorder(MllamaTextCrossAttention, index=1, layer_name="cross_attn"), + ], + } def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -1071,16 +985,11 @@ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = torch.cat([class_embedding, hidden_state], dim=1) return hidden_state + @check_model_inputs @auto_docstring def forward( - self, - pixel_values: torch.Tensor, - aspect_ratio_ids: torch.Tensor, - aspect_ratio_mask: torch.Tensor, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]: + self, pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, aspect_ratio_mask: torch.Tensor, **kwargs + ) -> BaseModelOutput: r""" aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*): Aspect ratio ids used to select the appropriate precomputed tile embeddings based on the aspect ratio of each input image. 
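The `_can_record_outputs` table introduced above is what lets the `check_model_inputs` decorator expose attentions and hidden states even though the flags were stripped from every layer signature: outputs of the listed module classes are captured as they are produced, with `OutputRecorder(..., index=1)` selecting the attention-weights slot of each attention module's return tuple. A rough sketch of that recording mechanism using plain forward hooks, with toy modules standing in for the real helpers:

```python
import torch
from torch import nn


class ToySelfAttention(nn.Module):
    """Returns (attn_output, attn_weights), like the attention modules above."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        weights = torch.softmax(hidden_states @ hidden_states.transpose(-1, -2), dim=-1)
        return self.o_proj(weights @ hidden_states), weights


class ToyModel(nn.Module):
    def __init__(self, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(ToySelfAttention() for _ in range(num_layers))

    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states, _ = layer(hidden_states)  # weights are dropped here ...
        return hidden_states


def forward_with_recorded_attentions(model, hidden_states):
    # ... and recovered here via forward hooks, keyed on the module class and the
    # index of the tensor in that module's output tuple (index 1 = attention weights).
    recorded, handles = [], []
    for module in model.modules():
        if isinstance(module, ToySelfAttention):
            handles.append(module.register_forward_hook(lambda mod, inputs, out: recorded.append(out[1])))
    try:
        last_hidden_state = model(hidden_states)
    finally:
        for handle in handles:
            handle.remove()
    return last_hidden_state, tuple(recorded)


out, attentions = forward_with_recorded_attentions(ToyModel(), torch.randn(1, 4, 8))
print(len(attentions), attentions[0].shape)  # 2 torch.Size([1, 4, 4])
```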
@@ -1121,12 +1030,6 @@ def forward( torch.Size([1, 1, 4, 1025, 7680]) ``` """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) @@ -1176,10 +1079,8 @@ def forward( output = self.transformer( hidden_state, attention_mask=attention_mask, - output_hidden_states=True, - output_attentions=output_attentions, ) - hidden_state = output[0] + hidden_state = output.last_hidden_state hidden_state = self.layernorm_post(hidden_state) @@ -1194,10 +1095,8 @@ def forward( global_output = self.global_transformer( hidden_state, attention_mask=attention_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, ) - hidden_state = global_output[0] + hidden_state = global_output.last_hidden_state # Remove padding form hidden state hidden_state = hidden_state.reshape( @@ -1207,7 +1106,7 @@ def forward( hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) # Collect intermediate layer outputs from encoder output - all_intermediate_hidden_states = [output[1][i] for i in self.intermediate_layers_indices] + all_intermediate_hidden_states = [output.last_hidden_state for _ in self.intermediate_layers_indices] intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) # Remove padding from intermediate hidden states @@ -1222,26 +1121,7 @@ def forward( # Concatenate final hidden state and intermediate hidden states hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) - if output_hidden_states: - hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1]) - else: - hidden_states = None - - if output_attentions: - # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range - global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) - attentions = tuple(output[2]) + global_attn - else: - attentions = None - - if not return_dict: - return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) - - return BaseModelOutput( - last_hidden_state=hidden_state, - hidden_states=hidden_states, - attentions=attentions, - ) + return BaseModelOutput(last_hidden_state=hidden_state) @auto_docstring( @@ -1273,6 +1153,8 @@ def __init__(self, config: MllamaTextConfig): self.gradient_checkpointing = False self.post_init() + @check_model_inputs + @can_return_tuple @auto_docstring def forward( self, @@ -1285,12 +1167,9 @@ def forward( past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: r""" cross_attention_states (`torch.FloatTensor`, *optional*): Output of the vision model, used for cross-attention. 
This tensor contains the processed image features that @@ -1330,22 +1209,11 @@ def forward( torch.Size([1, 13, 4096]) ``` """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1362,21 +1230,13 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - # For text-only path we should skip cross attention layers. # Let's check if the layer is cross attention layer and if we have cross attention states # or cached cross attention states. 
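Assuming `check_model_inputs` and `can_return_tuple` behave as the `_can_record_outputs` table above suggests, callers should still be able to request attentions, hidden states, or a legacy tuple via keyword arguments even though those parameters were removed from the `forward` signatures. A hedged usage sketch reusing the checkpoint from the existing doctest:

```python
from transformers import AutoProcessor, MllamaTextModel

checkpoint = "meta-llama/Llama-3.2-11B-Vision"  # same checkpoint as the doctest above
model = MllamaTextModel.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

text = "<|image|>If I had to write a haiku for this one"
inputs = processor(text=text, return_tensors="pt")

# output_attentions / output_hidden_states are no longer named parameters; they are
# expected to travel through **kwargs and be materialized by check_model_inputs.
outputs = model(**inputs, output_attentions=True, output_hidden_states=True)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 13, 4096]) per the doctest
print(len(outputs.hidden_states), outputs.attentions[0].shape)

# @can_return_tuple is expected to keep the legacy tuple interface available.
legacy = model(**inputs, return_dict=False)
print(isinstance(legacy, tuple))
```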
@@ -1388,57 +1248,25 @@ def forward( if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - cross_attention_states, - cross_attention_mask, - causal_mask, - full_text_row_masked_out_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - attention_mask=causal_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) + hidden_states = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + attention_mask=causal_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = self.norm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None - ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, ) @@ -1468,6 +1296,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @can_return_tuple @auto_docstring def forward( self, @@ -1481,9 +1310,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], @@ -1532,12 +1358,6 @@ def forward( I love the idea of snowflakes gently falling, each one ``` """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -1549,14 +1369,11 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = 
self.lm_head(hidden_states[:, slice_indices, :]).float() @@ -1564,10 +1381,6 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1614,6 +1427,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.language_model + @check_model_inputs @can_return_tuple @auto_docstring def forward( @@ -1629,12 +1443,9 @@ def forward( past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, CausalLMOutputWithPast]: + ) -> BaseModelOutputWithPast: r""" aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): Mask to avoid performing attention on padding tiles. Mask values selected in `[0, 1]`: @@ -1664,12 +1475,6 @@ def forward( Output of the vision model, used for cross-attention. This tensor contains the processed image features that the language model will attend to. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1684,11 +1489,8 @@ def forward( pixel_values=pixel_values, aspect_ratio_ids=aspect_ratio_ids, aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, ) - cross_attention_states = vision_outputs[0] + cross_attention_states = vision_outputs.last_hidden_state cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( -1, cross_attention_states.shape[-2], self.hidden_size ) @@ -1716,9 +1518,6 @@ def forward( past_key_values=past_key_values, use_cache=use_cache, inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=True, cache_position=cache_position, **kwargs, ) @@ -1788,9 +1587,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], @@ -1855,12 +1651,6 @@ def forward( [', it would be:.\\nA stop sign in Chinatown.\\n'] ``` """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, @@ -1873,14 +1663,11 @@ def forward( 
past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :])
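For reference, the `logits_to_keep` slicing that both language-model heads keep using above is easy to sanity-check in isolation; the shapes below are borrowed from the text-model doctest and are purely illustrative:

```python
import torch

hidden_states = torch.randn(1, 13, 4096)  # (batch, seq_len, hidden), as in the doctest

# An int keeps only the last N positions before the lm_head projection; the default
# of 0 degenerates to slice(0, None), i.e. every position is projected.
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([1, 1, 4096])

# A tensor of positions is used as an index directly.
keep = torch.tensor([0, 12])
slice_indices = slice(-keep, None) if isinstance(keep, int) else keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([1, 2, 4096])
```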