From c4d27aeeeee6bbac21057688cdd47bcabc4eeef4 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 24 Jul 2025 15:40:39 +0000 Subject: [PATCH 1/9] mllama outputs refactor --- .../models/mllama/modeling_mllama.py | 318 ++++-------------- 1 file changed, 57 insertions(+), 261 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 266a916cef63..589897e0715f 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -27,12 +27,15 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig +from ...utils.generic import OutputRecorder, check_model_inputs + if is_torch_flex_attn_available(): @@ -235,7 +238,6 @@ def forward( self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: query = self.q_proj(hidden_state) @@ -252,13 +254,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -274,9 +270,6 @@ def forward( attn_output = attn_output.reshape(batch_size, q_seq_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights @@ -303,7 +296,6 @@ def forward( self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, ): # Self Attention residual = hidden_state @@ -321,12 +313,7 @@ def forward( hidden_state = self.gate_ffn.tanh() * hidden_state hidden_state = residual + hidden_state - outputs = (hidden_state,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return (hidden_state,) class MllamaVisionEncoder(nn.Module): @@ -349,10 +336,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutput]: + ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -366,54 +350,16 @@ def forward( - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None + """ for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_state=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - + layer_outputs = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + ) hidden_states = layer_outputs[0] - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) + return BaseModelOutput(last_hidden_state=hidden_states) # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText @@ -470,7 +416,6 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -507,13 +452,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -529,9 +468,6 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights @@ -595,7 +531,6 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, - output_attentions: bool = False, use_cache: bool = False, past_key_value=None, cache_position=None, @@ -619,16 +554,11 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -644,9 +574,6 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights @@ -669,7 +596,7 @@ def forward(self, x): # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer -class MllamaSelfAttentionDecoderLayer(nn.Module): +class MllamaSelfAttentionDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MllamaTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -691,7 +618,6 @@ def forward( full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -703,9 +629,7 @@ def forward( attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. + use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
@@ -729,7 +653,6 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -743,15 +666,10 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) + return (hidden_states,) - return outputs - -class MllamaCrossAttentionDecoderLayer(torch.nn.Module): +class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): """Cross-attention transformer block with tanh-gated attention and feedforward.""" def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: @@ -775,7 +693,6 @@ def forward( full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, @@ -789,7 +706,6 @@ def forward( attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, past_key_value=past_key_value, - output_attentions=output_attentions, cache_position=cache_position, **kwargs, ) @@ -802,12 +718,7 @@ def forward( hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return (hidden_states,) class MllamaRotaryEmbedding(nn.Module): @@ -849,12 +760,15 @@ class MllamaPreTrainedModel(PreTrainedModel): "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer", ] - _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": MllamaSelfAttentionDecoderLayer, + "attentions": MllamaTextSelfAttention, + } def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -894,7 +808,6 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): @@ -912,7 +825,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -946,7 +859,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -1071,16 +983,14 @@ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = torch.cat([class_embedding, hidden_state], dim=1) return hidden_state + @check_model_inputs @auto_docstring def forward( self, pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, aspect_ratio_mask: torch.Tensor, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]: + ) -> BaseModelOutput: r""" aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*): Aspect ratio ids used to select the appropriate precomputed tile embeddings based on the aspect ratio of each input image. @@ -1121,12 +1031,6 @@ def forward( torch.Size([1, 1, 4, 1025, 7680]) ``` """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) @@ -1176,10 +1080,8 @@ def forward( output = self.transformer( hidden_state, attention_mask=attention_mask, - output_hidden_states=True, - output_attentions=output_attentions, ) - hidden_state = output[0] + hidden_state = output.last_hidden_state hidden_state = self.layernorm_post(hidden_state) @@ -1194,10 +1096,8 @@ def forward( global_output = self.global_transformer( hidden_state, attention_mask=attention_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, ) - hidden_state = global_output[0] + hidden_state = global_output.last_hidden_state # Remove padding form hidden state hidden_state = hidden_state.reshape( @@ -1207,7 +1107,7 @@ def forward( hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) # Collect intermediate layer outputs from encoder output - all_intermediate_hidden_states = [output[1][i] for i in self.intermediate_layers_indices] + all_intermediate_hidden_states = [output.last_hidden_state for _ in self.intermediate_layers_indices] intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) # Remove padding from intermediate hidden states @@ -1222,26 +1122,7 @@ def forward( # Concatenate final hidden state and intermediate hidden states hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) - if output_hidden_states: - hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1]) - else: - hidden_states = None - - if output_attentions: - # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range - global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) - attentions = tuple(output[2]) + global_attn - else: - attentions = None - - if not return_dict: - return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) - - return BaseModelOutput( - last_hidden_state=hidden_state, - hidden_states=hidden_states, - attentions=attentions, - ) + return 
BaseModelOutput(last_hidden_state=hidden_state) @auto_docstring( @@ -1273,6 +1154,8 @@ def __init__(self, config: MllamaTextConfig): self.gradient_checkpointing = False self.post_init() + @check_model_inputs + @can_return_tuple @auto_docstring def forward( self, @@ -1285,12 +1168,9 @@ def forward( past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: r""" cross_attention_states (`torch.FloatTensor`, *optional*): Output of the vision model, used for cross-attention. This tensor contains the processed image features that @@ -1330,22 +1210,11 @@ def forward( torch.Size([1, 13, 4096]) ``` """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1363,19 +1232,14 @@ def forward( position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + attention_mask, inputs_embeds, cache_position, past_key_values ) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) # For text-only path we should skip cross attention layers. 
# Let's check if the layer is cross attention layer and if we have cross attention states @@ -1388,57 +1252,29 @@ def forward( if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - cross_attention_states, - cross_attention_mask, - causal_mask, - full_text_row_masked_out_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - attention_mask=causal_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + attention_mask=causal_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None - ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, + hidden_states=None, + attentions=None, ) @@ -1468,6 +1304,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @can_return_tuple @auto_docstring def forward( self, @@ -1481,9 +1318,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], @@ -1532,12 +1366,6 @@ def forward( I love the idea of snowflakes gently falling, each one ``` """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -1549,14 +1377,11 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state 
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]).float() @@ -1564,10 +1389,6 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1614,6 +1435,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.language_model + @check_model_inputs @can_return_tuple @auto_docstring def forward( @@ -1629,12 +1451,9 @@ def forward( past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, CausalLMOutputWithPast]: + ) -> BaseModelOutputWithPast: r""" aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): Mask to avoid performing attention on padding tiles. Mask values selected in `[0, 1]`: @@ -1664,12 +1483,6 @@ def forward( Output of the vision model, used for cross-attention. This tensor contains the processed image features that the language model will attend to. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1684,11 +1497,8 @@ def forward( pixel_values=pixel_values, aspect_ratio_ids=aspect_ratio_ids, aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, ) - cross_attention_states = vision_outputs[0] + cross_attention_states = vision_outputs.last_hidden_state cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( -1, cross_attention_states.shape[-2], self.hidden_size ) @@ -1716,9 +1526,6 @@ def forward( past_key_values=past_key_values, use_cache=use_cache, inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=True, cache_position=cache_position, **kwargs, ) @@ -1772,6 +1579,7 @@ def language_model(self): def vision_model(self): return self.model.vision_model + @check_model_inputs @can_return_tuple @auto_docstring def forward( @@ -1788,9 +1596,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], @@ -1855,12 +1660,6 @@ def forward( [', it would be:.\\nA stop sign in Chinatown.\\n'] ``` """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, @@ -1873,14 +1672,11 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) From ed1c799eded6271e9109a142df703b31de85f70b Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 25 Jul 2025 13:46:20 +0000 Subject: [PATCH 2/9] forgot kwargs --- src/transformers/models/mllama/modeling_mllama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 589897e0715f..6c5ad146e56a 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -554,7 +554,6 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -767,7 +766,8 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_attention_backend = True _can_record_outputs = { "hidden_states": MllamaSelfAttentionDecoderLayer, - "attentions": MllamaTextSelfAttention, + "attentions": OutputRecorder(MllamaTextSelfAttention, index=1), + "cross_attentions": OutputRecorder(MllamaTextCrossAttention, index=1), } def _init_weights(self, module): @@ -990,6 +990,7 @@ def forward( pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, aspect_ratio_mask: torch.Tensor, + **kwargs ) -> BaseModelOutput: r""" aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*): From 767bce114021bf64c4022991f790e3222c871a45 Mon Sep 17 00:00:00 2001 From: itazap Date: Fri, 25 Jul 2025 17:20:37 +0200 Subject: [PATCH 3/9] fix output --- src/transformers/models/mllama/modeling_mllama.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 6c5ad146e56a..778c28efda49 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -665,7 +665,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return (hidden_states,) + return (hidden_states, self_attn_weights) class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): @@ -717,7 +717,7 @@ def forward( hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - return (hidden_states,) + return (hidden_states, attn_weights) class MllamaRotaryEmbedding(nn.Module): @@ -765,9 +765,9 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_flex_attn = True 
_supports_attention_backend = True _can_record_outputs = { - "hidden_states": MllamaSelfAttentionDecoderLayer, - "attentions": OutputRecorder(MllamaTextSelfAttention, index=1), - "cross_attentions": OutputRecorder(MllamaTextCrossAttention, index=1), + "hidden_states": [OutputRecorder(MllamaSelfAttentionDecoderLayer, index=0), OutputRecorder(MllamaCrossAttentionDecoderLayer, index=0)], + "attentions": [MllamaTextSelfAttention, MllamaTextCrossAttention], + # "cross_attentions": OutputRecorder(MllamaCrossAttentionDecoderLayer, index=1), } def _init_weights(self, module): From 1fc04b129ea6e44d0611595ae9a83258503f57d8 Mon Sep 17 00:00:00 2001 From: itazap Date: Sun, 27 Jul 2025 23:48:31 +0200 Subject: [PATCH 4/9] add can_record_outputs --- .../models/mllama/modeling_mllama.py | 356 +++++++++--------- 1 file changed, 179 insertions(+), 177 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 778c28efda49..d7e61d6cc8b6 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -36,21 +36,18 @@ from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig from ...utils.generic import OutputRecorder, check_model_inputs - - if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask from ...integrations.flex_attention import make_flex_block_causal_mask - logger = logging.get_logger(__name__) def _prepare_cross_attention_mask( - cross_attention_mask: torch.Tensor, - num_vision_tokens: int, - dtype: str, + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, ) -> tuple[torch.Tensor, torch.Tensor]: # reshape so it can be used by attn module batch_size, text_total_length, *_ = cross_attention_mask.shape @@ -76,10 +73,10 @@ def _prepare_cross_attention_mask( def _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask: torch.Tensor, - num_patches: int, - target_length: int, - dtype: torch.dtype, + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, ) -> torch.Tensor: # Expand aspect ratio mask to target_length batch_size, max_num_tiles = aspect_ratio_mask.shape @@ -132,7 +129,7 @@ def __init__(self, config: MllamaVisionConfig): self.max_aspect_ratio_id = config.max_aspect_ratio_id self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 self.hidden_size = config.hidden_size - self.scale = config.hidden_size**-0.5 + self.scale = config.hidden_size ** -0.5 self.gate = nn.Parameter(torch.zeros(1)) @@ -193,14 +190,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: # Copied from transformers.models.llama.modeling_llama.eager_attention_forward def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -226,7 +223,7 @@ def __init__(self, config: MllamaVisionConfig): self.embed_dim = config.hidden_size self.num_heads = config.attention_heads self.head_dim = config.hidden_size // config.attention_heads - 
self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.num_key_value_groups = 1 self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) @@ -235,10 +232,10 @@ def __init__(self, config: MllamaVisionConfig): self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=False) def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: query = self.q_proj(hidden_state) key = self.k_proj(hidden_state) @@ -293,9 +290,9 @@ def __init__(self, config: MllamaVisionConfig, is_gated: bool = False): self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, ): # Self Attention residual = hidden_state @@ -333,9 +330,9 @@ def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False): self.config = config def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, ) -> BaseModelOutput: r""" Args: @@ -387,9 +384,9 @@ class MllamaTextCrossAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( - self, - config: Optional[MllamaTextConfig] = None, - layer_idx: Optional[int] = None, + self, + config: Optional[MllamaTextConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.config = config @@ -400,7 +397,7 @@ def __init__( self.head_dim = config.hidden_size // self.num_heads self.layer_idx = layer_idx self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -411,14 +408,14 @@ def __init__( self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() @@ -475,7 +472,7 @@ def forward( def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] + x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) @@ -517,7 +514,7 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - 
self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.rope_theta = config.rope_theta self.layer_idx = layer_idx @@ -527,14 +524,14 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int): self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -609,18 +606,18 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int): self.layer_idx = layer_idx def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -684,18 +681,18 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: torch.Tensor, - cross_attention_mask: torch.Tensor, - attention_mask: torch.Tensor, - full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor], - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[torch.Tensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + attention_mask: torch.Tensor, + full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -765,9 +762,14 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_flex_attn = 
True _supports_attention_backend = True _can_record_outputs = { - "hidden_states": [OutputRecorder(MllamaSelfAttentionDecoderLayer, index=0), OutputRecorder(MllamaCrossAttentionDecoderLayer, index=0)], - "attentions": [MllamaTextSelfAttention, MllamaTextCrossAttention], - # "cross_attentions": OutputRecorder(MllamaCrossAttentionDecoderLayer, index=1), + "hidden_states": [ + OutputRecorder(MllamaTextSelfAttention, index=0), + OutputRecorder(MllamaTextCrossAttention, index=0)], + + "attentions": [ + OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="self_attn"), + OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="cross_attn"), + OutputRecorder(MllamaTextCrossAttention, index=1, layer_name="cross_attn")] } def _init_weights(self, module): @@ -803,11 +805,11 @@ def _init_weights(self, module): # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): @@ -827,10 +829,10 @@ def _update_causal_mask( # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, ): return None @@ -856,9 +858,9 @@ def _update_causal_mask( ) if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
@@ -871,13 +873,13 @@ def _update_causal_mask( @staticmethod # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -944,7 +946,7 @@ def __init__(self, config: MllamaVisionConfig): self.intermediate_layers_indices = config.intermediate_layers_indices self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 - self.scale = config.hidden_size**-0.5 + self.scale = config.hidden_size ** -0.5 self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, @@ -986,11 +988,11 @@ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: @check_model_inputs @auto_docstring def forward( - self, - pixel_values: torch.Tensor, - aspect_ratio_ids: torch.Tensor, - aspect_ratio_mask: torch.Tensor, - **kwargs + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + **kwargs ) -> BaseModelOutput: r""" aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*): @@ -1159,18 +1161,18 @@ def __init__(self, config: MllamaTextConfig): @can_return_tuple @auto_docstring def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.FloatTensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.FloatTensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: r""" cross_attention_states (`torch.FloatTensor`, *optional*): @@ -1247,7 +1249,7 @@ def forward( # or cached cross attention states. 
is_cross_attention_layer = idx in self.cross_attention_layers is_cross_attention_cache_empty = past_key_values is None or ( - past_key_values is not None and past_key_values.get_seq_length(idx) == 0 + past_key_values is not None and past_key_values.get_seq_length(idx) == 0 ) if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: @@ -1274,8 +1276,8 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, - hidden_states=None, - attentions=None, + # hidden_states=None, + # attentions=None, ) @@ -1308,20 +1310,20 @@ def get_decoder(self): @can_return_tuple @auto_docstring def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[TransformersKwargs], + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" cross_attention_states (`torch.FloatTensor`, *optional*): @@ -1440,20 +1442,20 @@ def get_decoder(self): @can_return_tuple @auto_docstring def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: 
Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: r""" aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): @@ -1584,22 +1586,22 @@ def vision_model(self): @can_return_tuple @auto_docstring def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[TransformersKwargs], + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): @@ -1695,20 +1697,20 @@ def forward( ) def prepare_inputs_for_generation( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - pixel_values=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, - past_key_values=None, - use_cache=False, - cache_position=None, - logits_to_keep=None, - **kwargs, + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + pixel_values=None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, + past_key_values=None, + use_cache=False, + cache_position=None, + logits_to_keep=None, + **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model From be90868808628f524b810987901b00a8e32ac439 Mon Sep 17 00:00:00 2001 From: itazap Date: Mon, 28 Jul 2025 09:49:45 +0200 Subject: [PATCH 5/9] correct @check_model_inputs placement --- src/transformers/models/mllama/modeling_mllama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index d7e61d6cc8b6..1cd2a646ed37 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1582,7 +1582,6 @@ def language_model(self): def vision_model(self): return self.model.vision_model - @check_model_inputs @can_return_tuple @auto_docstring def forward( From 70c3986c5c8d5d3dcf971fa62e955475111cf94c Mon Sep 17 00:00:00 2001 From: 
itazap Date: Mon, 28 Jul 2025 09:52:21 +0200 Subject: [PATCH 6/9] ruff and copies --- .../models/mllama/modeling_mllama.py | 353 +++++++++--------- 1 file changed, 175 insertions(+), 178 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 1cd2a646ed37..cfd1ebfd87f7 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -33,8 +33,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging -from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig + if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask @@ -45,9 +46,9 @@ def _prepare_cross_attention_mask( - cross_attention_mask: torch.Tensor, - num_vision_tokens: int, - dtype: str, + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, ) -> tuple[torch.Tensor, torch.Tensor]: # reshape so it can be used by attn module batch_size, text_total_length, *_ = cross_attention_mask.shape @@ -73,10 +74,10 @@ def _prepare_cross_attention_mask( def _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask: torch.Tensor, - num_patches: int, - target_length: int, - dtype: torch.dtype, + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, ) -> torch.Tensor: # Expand aspect ratio mask to target_length batch_size, max_num_tiles = aspect_ratio_mask.shape @@ -129,7 +130,7 @@ def __init__(self, config: MllamaVisionConfig): self.max_aspect_ratio_id = config.max_aspect_ratio_id self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 self.hidden_size = config.hidden_size - self.scale = config.hidden_size ** -0.5 + self.scale = config.hidden_size**-0.5 self.gate = nn.Parameter(torch.zeros(1)) @@ -190,14 +191,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: # Copied from transformers.models.llama.modeling_llama.eager_attention_forward def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -223,7 +224,7 @@ def __init__(self, config: MllamaVisionConfig): self.embed_dim = config.hidden_size self.num_heads = config.attention_heads self.head_dim = config.hidden_size // config.attention_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.num_key_value_groups = 1 self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) @@ -232,10 +233,10 @@ def __init__(self, config: MllamaVisionConfig): self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=False) def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - 
**kwargs, + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: query = self.q_proj(hidden_state) key = self.k_proj(hidden_state) @@ -290,9 +291,9 @@ def __init__(self, config: MllamaVisionConfig, is_gated: bool = False): self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, ): # Self Attention residual = hidden_state @@ -330,9 +331,9 @@ def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False): self.config = config def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, ) -> BaseModelOutput: r""" Args: @@ -384,9 +385,9 @@ class MllamaTextCrossAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( - self, - config: Optional[MllamaTextConfig] = None, - layer_idx: Optional[int] = None, + self, + config: Optional[MllamaTextConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.config = config @@ -397,7 +398,7 @@ def __init__( self.head_dim = config.hidden_size // self.num_heads self.layer_idx = layer_idx self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -408,14 +409,14 @@ def __init__( self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() @@ -472,7 +473,7 @@ def forward( def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -514,7 +515,7 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.rope_theta = config.rope_theta self.layer_idx = layer_idx @@ -524,14 +525,14 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int): self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - use_cache: 
bool = False, - past_key_value=None, - cache_position=None, - **kwargs, + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -606,18 +607,18 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int): self.layer_idx = layer_idx def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -681,18 +682,18 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: torch.Tensor, - cross_attention_mask: torch.Tensor, - attention_mask: torch.Tensor, - full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor], - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[torch.Tensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + attention_mask: torch.Tensor, + full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -764,12 +765,13 @@ class MllamaPreTrainedModel(PreTrainedModel): _can_record_outputs = { "hidden_states": [ OutputRecorder(MllamaTextSelfAttention, index=0), - OutputRecorder(MllamaTextCrossAttention, index=0)], - + OutputRecorder(MllamaTextCrossAttention, index=0), + ], "attentions": [ OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="self_attn"), OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="cross_attn"), - OutputRecorder(MllamaTextCrossAttention, index=1, layer_name="cross_attn")] + 
OutputRecorder(MllamaTextCrossAttention, index=1, layer_name="cross_attn"), + ], } def _init_weights(self, module): @@ -805,11 +807,12 @@ def _init_weights(self, module): # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): @@ -827,12 +830,12 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, ): return None @@ -858,9 +861,10 @@ def _update_causal_mask( ) if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
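Note: the `_can_record_outputs` table above is consumed by `check_model_inputs`, which records the requested tensors (here, element 1 of each attention module's output tuple) from the outside instead of threading `output_attentions` through every layer's `forward`. Below is a minimal, self-contained sketch of that recording idea using forward hooks — toy classes and names of my own, not the actual transformers implementation:

import torch
from torch import nn

class ToyAttention(nn.Module):
    # stand-in attention block: returns (hidden_states, attn_weights),
    # analogous to an attention module returning a 2-tuple
    def forward(self, hidden_states):
        attn_weights = torch.softmax(hidden_states @ hidden_states.transpose(-1, -2), dim=-1)
        return attn_weights @ hidden_states, attn_weights

class ToyModel(nn.Module):
    def __init__(self, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(ToyAttention() for _ in range(num_layers))

    def forward(self, hidden_states, collect_attentions=False):
        recorded, hooks = [], []
        if collect_attentions:
            # mirror OutputRecorder(..., index=1): capture element 1 of each
            # layer's output tuple via a forward hook
            for layer in self.layers:
                hooks.append(layer.register_forward_hook(lambda m, args, out: recorded.append(out[1])))
        try:
            for layer in self.layers:
                hidden_states, _ = layer(hidden_states)
        finally:
            for h in hooks:
                h.remove()
        return (hidden_states, tuple(recorded)) if collect_attentions else hidden_states

hidden, attns = ToyModel()(torch.randn(1, 4, 8), collect_attentions=True)  # len(attns) == 2

The layers always emit their weights in a fixed tuple position; whether anything is kept is decided once, at the model boundary, by the recorder.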
@@ -873,13 +877,13 @@ def _update_causal_mask( @staticmethod # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -946,7 +950,7 @@ def __init__(self, config: MllamaVisionConfig): self.intermediate_layers_indices = config.intermediate_layers_indices self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 - self.scale = config.hidden_size ** -0.5 + self.scale = config.hidden_size**-0.5 self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, @@ -988,11 +992,7 @@ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: @check_model_inputs @auto_docstring def forward( - self, - pixel_values: torch.Tensor, - aspect_ratio_ids: torch.Tensor, - aspect_ratio_mask: torch.Tensor, - **kwargs + self, pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, aspect_ratio_mask: torch.Tensor, **kwargs ) -> BaseModelOutput: r""" aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*): @@ -1161,18 +1161,18 @@ def __init__(self, config: MllamaTextConfig): @can_return_tuple @auto_docstring def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.FloatTensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.FloatTensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: r""" cross_attention_states (`torch.FloatTensor`, *optional*): @@ -1234,22 +1234,19 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers for idx, decoder_layer in enumerate(self.layers): - # For text-only path we should skip cross attention layers. 
# Let's check if the layer is a cross attention layer and if we have cross attention states # or cached cross attention states. is_cross_attention_layer = idx in self.cross_attention_layers is_cross_attention_cache_empty = past_key_values is None or ( - past_key_values is not None and past_key_values.get_seq_length(idx) == 0 + past_key_values is not None and past_key_values.get_seq_length(idx) == 0 ) if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue @@ -1310,20 +1307,20 @@ def get_decoder(self): @can_return_tuple @auto_docstring def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[TransformersKwargs], + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" cross_attention_states (`torch.FloatTensor`, *optional*): @@ -1442,20 +1439,20 @@ def get_decoder(self): @can_return_tuple @auto_docstring def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: r"""
aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): @@ -1585,22 +1582,22 @@ def vision_model(self): @can_return_tuple @auto_docstring def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[TransformersKwargs], + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): @@ -1696,20 +1693,20 @@ def forward( ) def prepare_inputs_for_generation( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - pixel_values=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, - past_key_values=None, - use_cache=False, - cache_position=None, - logits_to_keep=None, - **kwargs, + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + pixel_values=None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, + past_key_values=None, + use_cache=False, + cache_position=None, + logits_to_keep=None, + **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model From f13fa4c6873e04fc2a53bc6f363c85c42e299b70 Mon Sep 17 00:00:00 2001 From: itazap Date: Mon, 28 Jul 2025 13:46:34 +0200 Subject: [PATCH 7/9] rebase --- src/transformers/models/mllama/modeling_mllama.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index cfd1ebfd87f7..7bf651552621 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -764,8 +764,8 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_attention_backend = True _can_record_outputs = { "hidden_states": [ - OutputRecorder(MllamaTextSelfAttention, index=0), - OutputRecorder(MllamaTextCrossAttention, index=0), + OutputRecorder(MllamaSelfAttentionDecoderLayer, index=0), + OutputRecorder(MllamaCrossAttentionDecoderLayer, index=0), 
], "attentions": [ OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="self_attn"), @@ -1273,8 +1273,6 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, - # hidden_states=None, - # attentions=None, ) From eda7d1994ffad6e7e8cb511c4b157b9369791eae Mon Sep 17 00:00:00 2001 From: itazap Date: Mon, 28 Jul 2025 14:28:47 +0200 Subject: [PATCH 8/9] feedback --- src/transformers/models/mllama/modeling_mllama.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 7bf651552621..5a158a592aba 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -311,7 +311,7 @@ def forward( hidden_state = self.gate_ffn.tanh() * hidden_state hidden_state = residual + hidden_state - return (hidden_state,) + return hidden_state class MllamaVisionEncoder(nn.Module): @@ -351,11 +351,10 @@ def forward( """ for encoder_layer in self.layers: - layer_outputs = encoder_layer( + hidden_states = encoder_layer( hidden_state=hidden_states, attention_mask=attention_mask, ) - hidden_states = layer_outputs[0] return BaseModelOutput(last_hidden_state=hidden_states) @@ -663,7 +662,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return (hidden_states, self_attn_weights) + return hidden_states, self_attn_weights class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): @@ -715,7 +714,7 @@ def forward( hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - return (hidden_states, attn_weights) + return hidden_states, attn_weights class MllamaRotaryEmbedding(nn.Module): @@ -763,10 +762,7 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True _can_record_outputs = { - "hidden_states": [ - OutputRecorder(MllamaSelfAttentionDecoderLayer, index=0), - OutputRecorder(MllamaCrossAttentionDecoderLayer, index=0), - ], + "hidden_states": [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer], "attentions": [ OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="self_attn"), OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="cross_attn"), From fa870ff00080aa732ee07201375c8a55078d7120 Mon Sep 17 00:00:00 2001 From: itazap Date: Mon, 28 Jul 2025 15:28:52 +0200 Subject: [PATCH 9/9] only return hidden_states --- src/transformers/models/mllama/modeling_mllama.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 5a158a592aba..5a0bcb55d42a 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -662,7 +662,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return hidden_states, self_attn_weights + return hidden_states class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): @@ -714,7 +714,7 @@ def forward( hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - return hidden_states, attn_weights + return hidden_states class MllamaRotaryEmbedding(nn.Module): @@ -1248,7 +1248,7 @@ 
def forward( if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, @@ -1262,8 +1262,6 @@ def forward( **kwargs, ) - hidden_states = layer_outputs[0] - hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast(
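Note: taken together, the series moves output collection out of the layers — decoder layers now return bare `hidden_states`, and `hidden_states`/`attentions` are recorded through the `_can_record_outputs` table. Call sites should be unchanged. A usage sketch follows; the checkpoint id and prompt format are assumptions (the Meta checkpoints are gated — any Mllama checkpoint should behave the same):

from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision"  # assumed/gated; substitute any Mllama checkpoint
model = MllamaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.new("RGB", (448, 448))  # dummy image so the sketch runs offline
inputs = processor(text="<|image|>Describe this image.", images=image, return_tensors="pt")

# output_attentions / output_hidden_states are honored by the recorders;
# the layers themselves no longer return these tensors conditionally
outputs = model(**inputs, output_attentions=True, output_hidden_states=True)
print(len(outputs.attentions), len(outputs.hidden_states))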