From a97ca9f6f10a369e74c174b87ed593220c77c7f4 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 12:49:28 +0000 Subject: [PATCH 001/146] first batch (4) --- .../modeling_audio_spectrogram_transformer.py | 14 +-- .../models/autoformer/modeling_autoformer.py | 75 +++++-------- .../models/clipseg/modeling_clipseg.py | 25 ++--- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 104 ++++++------------ 4 files changed, 70 insertions(+), 148 deletions(-) diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index d3ccf24153b7..602de3ff72b5 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -22,6 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -282,7 +283,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST -class ASTLayer(nn.Module): +class ASTLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ASTConfig) -> None: @@ -349,16 +350,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) - + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 6db63b4945f7..46dde3f146c5 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -30,6 +30,7 @@ _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput from ...modeling_utils import PreTrainedModel from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput @@ -670,7 +671,7 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -class AutoformerEncoderLayer(nn.Module): +class AutoformerEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: AutoformerConfig): super().__init__() self.embed_dim = config.d_model @@ -744,7 +745,7 @@ def forward( return outputs -class AutoformerDecoderLayer(nn.Module): +class AutoformerDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: AutoformerConfig): super().__init__() self.embed_dim = config.d_model @@ -1042,21 +1043,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if 
self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1186,6 +1178,12 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + input_shape = inputs_embeds.size()[:-1] # expand encoder attention mask @@ -1228,38 +1226,17 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) (hidden_states, residual_trend) = layer_outputs[0] trend = trend + residual_trend diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index c68404cb66c8..cff0471c8137 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -374,7 +375,7 @@ def forward(self, hidden_states: torch.Tensor) -> 
torch.Tensor: # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg -class CLIPSegEncoderLayer(nn.Module): +class CLIPSegEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: CLIPSegConfig): super().__init__() self.embed_dim = config.hidden_size @@ -539,22 +540,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) - + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index c4e151e9ce16..c01ae5d83be9 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -727,7 +727,7 @@ def forward( } -class Qwen2_5OmniAudioEncoderLayer(nn.Module): +class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): super().__init__() self.embed_dim = config.d_model @@ -890,17 +890,10 @@ def forward( ).to(torch.int32) for idx, encoder_layer in enumerate(self.layers): - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - cu_seqlens, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - cu_seqlens, - ) + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + ) hidden_states = layer_outputs[0] @@ -1107,7 +1100,7 @@ def forward(self, hidden_state): } -class Qwen2_5OmniVisionBlock(nn.Module): +class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer): def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) @@ -1324,16 +1317,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. 
cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb - ) - else: - hidden_states = blk( - hidden_states, - cu_seqlens=cu_seqlens_now, - rotary_pos_emb=rotary_pos_emb, - ) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) hidden_states = hidden_states[reverse_indices, :] @@ -1760,30 +1748,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] @@ -2325,30 +2300,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] From 2627646147e1decb8cf0a8769f6a8783756a1174 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 12:57:56 +0000 Subject: [PATCH 002/146] align --- .../models/align/modeling_align.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 952fe0bdc9e4..e10ea94e8f84 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -29,6 +29,7 @@ BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPoolingAndNoAttention, ) +from ...modeling_layers import GradientCheckpointingLayer from 
...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging @@ -827,7 +828,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText -class AlignTextLayer(nn.Module): +class AlignTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -953,27 +954,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From c8926f7cff257882a8ddedc62d7c4c919cd783f9 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 12:59:39 +0000 Subject: [PATCH 003/146] altclip --- .../models/align/modeling_align.py | 2 +- .../models/altclip/modeling_altclip.py | 56 ++++++------------- 2 files changed, 19 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index e10ea94e8f84..55ce72472264 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -23,13 +23,13 @@ from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithNoAttention, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPoolingAndNoAttention, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 3e917940809c..76fdf62a3517 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -418,7 +419,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta -class AltRobertaLayer(nn.Module): +class AltRobertaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -544,27 
+545,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: @@ -732,7 +721,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class AltCLIPEncoderLayer(nn.Module): +class AltCLIPEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: AltCLIPConfig): super().__init__() self.embed_dim = config.hidden_size @@ -848,21 +837,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From ae1b29a7ff51faa8f428cf8297e4fbcc70a9ba7f Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:00:29 +0000 Subject: [PATCH 004/146] beit --- src/transformers/models/beit/modeling_beit.py | 30 +++++++------------ 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 086e62561fd0..78b815d98a2e 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -34,6 +34,7 @@ MaskedLMOutput, SemanticSegmenterOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -497,7 +498,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class BeitLayer(nn.Module): +class BeitLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None: @@ -695,25 +696,14 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - relative_position_bias, - interpolate_pos_encoding, - resolution, - ) - else: - layer_outputs = layer_module( - hidden_states, - layer_head_mask, - output_attentions, - 
relative_position_bias, - interpolate_pos_encoding, - resolution, - ) + layer_outputs = layer_module( + hidden_states, + layer_head_mask, + output_attentions, + relative_position_bias, + interpolate_pos_encoding, + resolution, + ) hidden_states = layer_outputs[0] From 4dff0763d5c5a49487f6f544b47d919b9df38846 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:01:39 +0000 Subject: [PATCH 005/146] bert --- src/transformers/models/beit/modeling_beit.py | 2 +- src/transformers/models/bert/modeling_bert.py | 33 +++++++------------ 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 78b815d98a2e..e830e66285a5 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BackboneOutput, BaseModelOutput, @@ -34,7 +35,6 @@ MaskedLMOutput, SemanticSegmenterOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 12080dfff6ff..743fba8a0902 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -522,7 +523,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class BertLayer(nn.Module): +class BertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -647,27 +648,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From 0d0d8c7e9e02349324fb2fdd84a068d9ae3ee379 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:02:51 +0000 Subject: [PATCH 006/146] yolos --- src/transformers/models/yolos/modeling_yolos.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git 
a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index f1a7f4fab822..01115516683f 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging @@ -403,7 +404,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS -class YolosLayer(nn.Module): +class YolosLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: YolosConfig) -> None: @@ -492,15 +493,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] From 8b664282fb0034d20e0b4e5113bde180725fde85 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:08:22 +0000 Subject: [PATCH 007/146] dino, pvt_v2 --- src/transformers/models/dinov2/modeling_dinov2.py | 13 +++---------- src/transformers/models/pvt_v2/modeling_pvt_v2.py | 8 +++----- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index bd35abb941ac..7c0cbd6b28cb 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -382,7 +383,7 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: return self.weights_out(hidden) -class Dinov2Layer(nn.Module): +class Dinov2Layer(GradientCheckpointingLayer): """This corresponds to the Block class in the original implementation.""" def __init__(self, config: Dinov2Config) -> None: @@ -458,15 +459,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/pvt_v2/modeling_pvt_v2.py b/src/transformers/models/pvt_v2/modeling_pvt_v2.py index 7c2f48bd5806..fd5e5d89bc42 100644 --- 
a/src/transformers/models/pvt_v2/modeling_pvt_v2.py +++ b/src/transformers/models/pvt_v2/modeling_pvt_v2.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput, BaseModelOutput, ImageClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -300,7 +301,7 @@ def forward(self, hidden_states: torch.Tensor, height: int, width: int, output_a return outputs -class PvtV2EncoderLayer(nn.Module): +class PvtV2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: PvtV2Config, layer_idx: int): super().__init__() self.patch_embedding = PvtV2OverlapPatchEmbeddings( @@ -367,10 +368,7 @@ def forward( batch_size = pixel_values.shape[0] hidden_states = pixel_values for idx, layer in enumerate(self.layers): - if self.gradient_checkpointing and self.training: - layer_output = self._gradient_checkpointing_func(layer.__call__, hidden_states, output_attentions) - else: - layer_output = layer(hidden_states, output_attentions) + layer_output = layer(hidden_states, output_attentions) outputs, height, width = layer_output hidden_states = outputs[0] if output_attentions: From 0d387eba70b3f5574a672101574d56d43d2cebc6 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:10:16 +0000 Subject: [PATCH 008/146] bark, bart, bert_generation --- src/transformers/models/bark/modeling_bark.py | 30 +++----- src/transformers/models/bart/modeling_bart.py | 69 +++++++------------ .../modeling_bert_generation.py | 33 +++------ 3 files changed, 44 insertions(+), 88 deletions(-) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 4ee608d9aec6..4598552615e5 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel, get_parameter_device from ...utils import ( auto_docstring, @@ -309,7 +310,7 @@ def forward(self, hidden_states): return hidden_states -class BarkBlock(nn.Module): +class BarkBlock(GradientCheckpointingLayer): def __init__(self, config, is_causal=False): super().__init__() @@ -606,25 +607,14 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - attention_mask, - head_mask[i], - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - past_key_values=past_layer_key_values, - attention_mask=attention_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - ) + outputs = block( + hidden_states, + past_key_values=past_layer_key_values, + attention_mask=attention_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = outputs[0] diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index f0adc76924fe..72b705ef5e34 100755 --- 
a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -42,6 +42,7 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -270,7 +271,7 @@ def forward( return attn_output, attn_weights, past_key_value -class BartEncoderLayer(nn.Module): +class BartEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -341,7 +342,7 @@ def forward( return outputs -class BartDecoderLayer(nn.Module): +class BartDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -875,21 +876,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1137,35 +1129,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 959a3cce077d..7bb73bf33320 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ 
-24,6 +24,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -275,7 +276,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->BertGeneration -class BertGenerationLayer(nn.Module): +class BertGenerationLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -401,27 +402,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From 6faee3fbe4e03f40f059ba8246ef27a2e6592d07 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:13:19 +0000 Subject: [PATCH 009/146] big_bird, biogpt --- .../models/big_bird/modeling_big_bird.py | 45 ++++------ .../modeling_bigbird_pegasus.py | 84 +++++++------------ .../models/biogpt/modeling_biogpt.py | 38 +++------ 3 files changed, 56 insertions(+), 111 deletions(-) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index e06c8f87d5f8..5f2c83acb045 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -36,6 +36,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ModelOutput, auto_docstring, logging @@ -1419,7 +1420,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class BigBirdLayer(nn.Module): +class BigBirdLayer(GradientCheckpointingLayer): def __init__(self, config, seed=None): super().__init__() self.config = config @@ -1593,35 +1594,19 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - band_mask, - from_mask, - to_mask, - blocked_encoder_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, 
- encoder_hidden_states, - encoder_attention_mask, - band_mask, - from_mask, - to_mask, - blocked_encoder_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + band_mask, + from_mask, + to_mask, + blocked_encoder_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index bc72d16bf547..2e392a2842cd 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -41,6 +41,7 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging @@ -1333,7 +1334,7 @@ def forward( return attn_output, attn_weights, past_key_value -class BigBirdPegasusEncoderLayer(nn.Module): +class BigBirdPegasusEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BigBirdPegasusConfig, seed=None): super().__init__() self.attention_type = config.attention_type @@ -1420,7 +1421,7 @@ def set_attention_type(self, value: str): self.self_attn.set_attention_type(value) -class BigBirdPegasusDecoderLayer(nn.Module): +class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -1947,31 +1948,17 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - band_mask, - from_mask, - to_mask, - blocked_encoder_mask, - blocked_encoder_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - band_mask=band_mask, - from_mask=from_mask, - to_mask=to_mask, - from_blocked_mask=blocked_encoder_mask, - to_blocked_mask=blocked_encoder_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + from_blocked_mask=blocked_encoder_mask, + to_blocked_mask=blocked_encoder_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -2297,35 +2284,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, 
- encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index a1fba008841f..4387a2cd6af5 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -38,6 +38,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, is_torch_flex_attn_available, logging @@ -248,7 +249,7 @@ def forward( return attn_output, attn_weights, past_key_value -class BioGptDecoderLayer(nn.Module): +class BioGptDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.hidden_size @@ -646,30 +647,17 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - head_mask[idx] if head_mask is not None else None, - None, - output_attentions, - use_cache, - position_ids, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - position_ids=position_ids, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] From 3f346063a11b1b316ef96ca1daf53a72645b65d1 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:14:52 +0000 Subject: [PATCH 010/146] blnderbot, bloom --- .../models/blenderbot/modeling_blenderbot.py | 69 +++++++------------ .../modeling_blenderbot_small.py | 69 +++++++------------ .../models/bloom/modeling_bloom.py | 36 ++++------ 3 files changed, 58 insertions(+), 116 deletions(-) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 0d699d01b59c..9ccf140c31c2 100755 --- 
a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -41,6 +41,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -270,7 +271,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT -class BlenderbotEncoderLayer(nn.Module): +class BlenderbotEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model @@ -339,7 +340,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT -class BlenderbotDecoderLayer(nn.Module): +class BlenderbotDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -825,21 +826,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1090,35 +1082,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 99666356d091..1ca3353e5627 100755 
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -39,6 +39,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -254,7 +255,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL -class BlenderbotSmallEncoderLayer(nn.Module): +class BlenderbotSmallEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -326,7 +327,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL -class BlenderbotSmallDecoderLayer(nn.Module): +class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -812,21 +813,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1073,35 +1065,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/bloom/modeling_bloom.py 
b/src/transformers/models/bloom/modeling_bloom.py index 66dfc0c1fa8c..4ae1668c011a 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -34,6 +34,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, @@ -366,7 +367,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch. return output -class BloomBlock(nn.Module): +class BloomBlock(GradientCheckpointingLayer): def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None): super().__init__() hidden_size = config.hidden_size @@ -605,29 +606,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - alibi, - causal_mask, - past_key_values, - head_mask[i], - use_cache, - output_attentions, - cache_position, - ) - else: - outputs = block( - hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - cache_position=cache_position, - ) + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + cache_position=cache_position, + ) hidden_states = outputs[0] if use_cache: From 3bb70d911f75fe8a569d39c70d30f08a2a544383 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:15:44 +0000 Subject: [PATCH 011/146] bridgetower --- .../bridgetower/modeling_bridgetower.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 10db0bbb62fb..a0a811edfae5 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -32,6 +32,7 @@ ModelOutput, SequenceClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -662,7 +663,7 @@ def feed_forward_chunk(self, attention_output): return layer_output -class BridgeTowerTextLayer(nn.Module): +class BridgeTowerTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -788,27 +789,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + 
layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From 5757f3eede074c0b294dcf41109daf06f8476dfd Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:20:19 +0000 Subject: [PATCH 012/146] camambert, canine, chameleon --- .../models/camembert/modeling_camembert.py | 33 ++++++----------- .../models/canine/modeling_canine.py | 14 ++----- .../models/chameleon/modeling_chameleon.py | 37 +++++++------------ 3 files changed, 27 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 1b4a52295f28..d4ce40d2d5ed 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -37,6 +37,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging @@ -478,7 +479,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->Camembert -class CamembertLayer(nn.Module): +class CamembertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -604,27 +605,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index d55c600d05de..75435799ab7b 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -34,6 +34,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -672,7 +673,7 @@ def forward(self, hidden_states: tuple[torch.FloatTensor], input_tensor: torch.F return hidden_states -class CanineLayer(nn.Module): +class CanineLayer(GradientCheckpointingLayer): def __init__( self, config, @@ -779,16 +780,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - 
attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index d6575a8751d0..784d97adf9aa 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -28,6 +28,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -387,7 +388,7 @@ def forward( # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON -class ChameleonDecoderLayer(nn.Module): +class ChameleonDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ChameleonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -462,7 +463,7 @@ def forward( return outputs -class ChameleonSwinDecoderLayer(nn.Module): +class ChameleonSwinDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ChameleonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1015,28 +1016,16 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) hidden_states = layer_outputs[0] From c59a7d56dd34fe659cbb090145ec48dd42883e71 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:22:41 +0000 Subject: [PATCH 013/146] chinese clip, clap, clip --- .../chinese_clip/modeling_chinese_clip.py | 50 ++++++------------- src/transformers/models/clap/modeling_clap.py | 46 ++++++----------- src/transformers/models/clip/modeling_clip.py | 24 +++------ 3 files changed, 39 insertions(+), 81 deletions(-) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 5de98397cad0..540ba8a4e596 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -29,6 +29,7 @@ BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from 
...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -577,7 +578,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText -class ChineseCLIPTextLayer(nn.Module): +class ChineseCLIPTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -663,7 +664,7 @@ def feed_forward_chunk(self, attention_output): return layer_output -class ChineseCLIPVisionLayer(nn.Module): +class ChineseCLIPVisionLayer(GradientCheckpointingLayer): def __init__(self, config: ChineseCLIPConfig): super().__init__() self.embed_dim = config.hidden_size @@ -816,27 +817,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: @@ -920,17 +909,10 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 6a44e36ade43..4142d75b1a8e 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -29,6 +29,7 @@ BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -691,7 +692,7 @@ def forward( # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->ClapAudio -class ClapAudioStage(nn.Module): +class ClapAudioStage(GradientCheckpointingLayer): def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() self.config = config @@ -928,14 +929,9 @@ def forward( input_dimensions = self.input_resolutions[i] - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions - ) - else: - layer_outputs = 
layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition - ) + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) hidden_states = layer_outputs[0] @@ -1355,7 +1351,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText -class ClapTextLayer(nn.Module): +class ClapTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -1481,27 +1477,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index b93f63bcea97..e3c0b36710c1 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig @@ -393,7 +394,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class CLIPEncoderLayer(nn.Module): +class CLIPEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]): super().__init__() self.embed_dim = config.hidden_size @@ -575,21 +576,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From d7cb7954a2e68fb7c53ec406459928da858aed44 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:24:05 +0000 Subject: [PATCH 014/146] codegen, conditional detr, convbert --- .../models/codegen/modeling_codegen.py | 36 
++++++----------- .../modeling_conditional_detr.py | 39 +++++++------------ .../models/convbert/modeling_convbert.py | 30 +++++--------- 3 files changed, 35 insertions(+), 70 deletions(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 6a99d0fa390a..ad8f6221ddcc 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -25,6 +25,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, @@ -245,7 +246,7 @@ def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTens # Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen -class CodeGenBlock(nn.Module): +class CodeGenBlock(GradientCheckpointingLayer): # Ignore copy def __init__(self, config, layer_idx=None): super().__init__() @@ -437,29 +438,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - causal_mask, - position_ids, - head_mask[i], - use_cache, - output_attentions, - cache_position, - ) - else: - outputs = block( - hidden_states=hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + outputs = block( + hidden_states=hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 19b1439302eb..547c96b57c2e 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_timm_available, logging, requires_backends from ...utils.backbone_utils import load_backbone @@ -827,7 +828,7 @@ def forward( return outputs -class ConditionalDetrDecoderLayer(nn.Module): +class ConditionalDetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ConditionalDetrConfig): super().__init__() self.embed_dim = config.d_model @@ -1297,31 +1298,17 @@ def forward( pos_transformation = self.query_scale(hidden_states) # apply transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - None, - object_queries, - 
query_position_embeddings, - query_sine_embed, - encoder_hidden_states, - encoder_attention_mask, - None, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - query_sine_embed=query_sine_embed, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - is_first=(idx == 0), - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=None, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + query_sine_embed=query_sine_embed, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + is_first=(idx == 0), + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 1a443c575ab0..d13bbfa14e27 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -33,6 +33,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -532,7 +533,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class ConvBertLayer(nn.Module): +class ConvBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -620,25 +621,14 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) From 39784f72af078b97ec63355eb1fd674aceab5ce7 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:26:20 +0000 Subject: [PATCH 015/146] dab_detr, data2vec --- .../models/dab_detr/modeling_dab_detr.py | 59 ++++++------------- .../data2vec/modeling_data2vec_audio.py | 27 +++------ .../models/data2vec/modeling_data2vec_text.py | 33 ++++------- .../data2vec/modeling_data2vec_vision.py | 30 ++++------ 4 files changed, 47 insertions(+), 102 deletions(-) diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index c977f4b923b8..36f86ca3ba3e 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, 
Seq2SeqModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -702,7 +703,7 @@ def forward(self, hidden_states: torch.Tensor): # Modified from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->DabDetrEncoderLayer,DetrConfig->DabDetrConfig -class DabDetrEncoderLayer(nn.Module): +class DabDetrEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: DabDetrConfig): super().__init__() self.hidden_size = config.hidden_size @@ -764,7 +765,7 @@ def forward( # Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderLayer with ConditionalDetr->DabDetr -class DabDetrDecoderLayer(nn.Module): +class DabDetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DabDetrConfig, is_first: bool = False): super().__init__() self.self_attn = DabDetrDecoderLayerSelfAttention(config) @@ -976,21 +977,12 @@ def forward( # we add object_queries * pos_scaler as extra input to the encoder_layer scaled_object_queries = object_queries * pos_scales - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - scaled_object_queries, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - object_queries=scaled_object_queries, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + object_queries=scaled_object_queries, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1138,29 +1130,16 @@ def forward( reference_anchor_size[..., 1] / obj_center[..., 3] ).unsqueeze(-1) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - None, - object_queries, - query_pos, - query_sine_embed, - encoder_hidden_states, - memory_key_padding_mask, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_pos, - query_sine_embed=query_sine_embed, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=memory_key_padding_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=None, + object_queries=object_queries, + query_position_embeddings=query_pos, + query_sine_embed=query_sine_embed, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=memory_key_padding_mask, + output_attentions=output_attentions, + ) # iter update hidden_states = layer_outputs[0] diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 940af308cd78..1e7183a05ae9 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -41,6 +41,7 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available @@ -51,7 +52,7 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -class 
Data2VecAudioConvLayer(nn.Module): +class Data2VecAudioConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -155,13 +156,7 @@ def forward(self, input_values): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -357,7 +352,7 @@ def forward(self, hidden_states): return hidden_states -class Data2VecAudioEncoderLayer(nn.Module): +class Data2VecAudioEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = Data2VecAudioAttention( @@ -441,17 +436,9 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 0d7e85134907..b42f9f56d5ba 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -34,6 +34,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -375,7 +376,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText -class Data2VecTextLayer(nn.Module): +class Data2VecTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -501,27 +502,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git 
a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index c48782d24771..0769b455ddf3 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -32,6 +32,7 @@ ImageClassifierOutput, SemanticSegmenterOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -497,7 +498,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision -class Data2VecVisionLayer(nn.Module): +class Data2VecVisionLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__( @@ -699,25 +700,14 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - relative_position_bias, - interpolate_pos_encoding, - resolution, - ) - else: - layer_outputs = layer_module( - hidden_states, - layer_head_mask, - output_attentions, - relative_position_bias, - interpolate_pos_encoding, - resolution, - ) + layer_outputs = layer_module( + hidden_states, + layer_head_mask, + output_attentions, + relative_position_bias, + interpolate_pos_encoding, + resolution, + ) hidden_states = layer_outputs[0] From 203348db52f53b6fb8e283f4f001501eee2d1ebc Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 13:26:47 +0000 Subject: [PATCH 016/146] dbrx, deberta --- src/transformers/models/dbrx/modeling_dbrx.py | 36 +++++++------------ .../models/deberta/modeling_deberta.py | 30 ++++++---------- 2 files changed, 22 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 097370a46cd5..12b7b00bc93e 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -27,6 +27,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_dbrx import DbrxConfig @@ -719,7 +720,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return out, weights -class DbrxBlock(nn.Module): +class DbrxBlock(GradientCheckpointingLayer): def __init__(self, config: DbrxConfig, block_idx: int): super().__init__() self.hidden_size = config.d_model @@ -942,29 +943,16 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - block_outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - ) - else: - block_outputs = block( - hidden_states, - 
attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - ) + block_outputs = block( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = block_outputs[0] diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index eef11f7ec34b..6cf829714e01 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -29,6 +29,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_deberta import DebertaConfig @@ -492,7 +493,7 @@ def forward(self, hidden_states, input_tensor): return hidden_states -class DebertaLayer(nn.Module): +class DebertaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = DebertaAttention(config) @@ -580,25 +581,14 @@ def forward( rel_embeddings = self.get_rel_embedding() for i, layer_module in enumerate(self.layer): - if self.gradient_checkpointing and self.training: - hidden_states, att_m = self._gradient_checkpointing_func( - layer_module.__call__, - next_kv, - attention_mask, - query_states, - relative_pos, - rel_embeddings, - output_attentions, - ) - else: - hidden_states, att_m = layer_module( - next_kv, - attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - ) + hidden_states, att_m = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) From b2719f338a1ce0434116ceb3128dc1771090f32a Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:03:53 +0000 Subject: [PATCH 017/146] deberta, decicion_tranformer, deformable_detr --- .../models/deberta_v2/modeling_deberta_v2.py | 30 +++----- .../modeling_decision_transformer.py | 39 ++++------ .../modeling_deformable_detr.py | 74 ++++++------------- 3 files changed, 47 insertions(+), 96 deletions(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 5073d1de5263..bd341eb213fd 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -31,6 +31,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_deberta_v2 import DebertaV2Config @@ -418,7 +419,7 @@ def forward(self, hidden_states, input_tensor): # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 -class DebertaV2Layer(nn.Module): +class DebertaV2Layer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = DebertaV2Attention(config) @@ -655,25 +656,14 
@@ def forward( next_kv = hidden_states rel_embeddings = self.get_rel_embedding() for i, layer_module in enumerate(self.layer): - if self.gradient_checkpointing and self.training: - output_states, attn_weights = self._gradient_checkpointing_func( - layer_module.__call__, - next_kv, - attention_mask, - query_states, - relative_pos, - rel_embeddings, - output_attentions, - ) - else: - output_states, attn_weights = layer_module( - next_kv, - attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - ) + output_states, attn_weights = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) if output_attentions: all_attentions = all_attentions + (attn_weights,) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index d34e989986d1..aceef4ae41fb 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( @@ -360,7 +361,7 @@ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.Fl # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2 -class DecisionTransformerGPT2Block(nn.Module): +class DecisionTransformerGPT2Block(GradientCheckpointingLayer): # Ignore copy def __init__(self, config, layer_idx=None): super().__init__() @@ -654,31 +655,17 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - past_key_value=past_key_values, - cache_position=cache_position, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) + outputs = block( + hidden_states, + past_key_value=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = outputs[0] diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index e36da6da89bc..43908c7548d4 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -28,6 +28,7 @@ from ...integrations import 
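From the training-loop side nothing changes for any of these models: checkpointing is still toggled once on the model, and the per-layer call sites stay identical. A short usage sketch follows; the checkpoint name is only an example, and any architecture converted in this series behaves the same way.

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("gpt2")   # e.g. the GPT-2 blocks mirrored in DecisionTransformer above
model.gradient_checkpointing_enable()       # enable activation checkpointing for training
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
out = model(input_ids=input_ids)
out.last_hidden_state.mean().backward()     # activations are recomputed during this backward pass

The point of the refactor is that this behaviour now lives in the layer base class rather than being re-implemented in every encoder and decoder forward loop.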
use_kernel_forward_from_hub from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid from ...utils import ( @@ -759,7 +760,7 @@ def forward( return attn_output, attn_weights_reshaped -class DeformableDetrEncoderLayer(nn.Module): +class DeformableDetrEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: DeformableDetrConfig): super().__init__() self.embed_dim = config.d_model @@ -848,7 +849,7 @@ def forward( return outputs -class DeformableDetrDecoderLayer(nn.Module): +class DeformableDetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DeformableDetrConfig): super().__init__() self.embed_dim = config.d_model @@ -1126,29 +1127,16 @@ def forward( for i, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - position_embeddings, - reference_points, - spatial_shapes, - spatial_shapes_list, - level_start_index, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - position_embeddings=position_embeddings, - reference_points=reference_points, - spatial_shapes=spatial_shapes, - spatial_shapes_list=spatial_shapes_list, - level_start_index=level_start_index, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1273,31 +1261,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - position_embeddings, - reference_points_input, - spatial_shapes, - spatial_shapes_list, - level_start_index, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - encoder_hidden_states=encoder_hidden_states, - reference_points=reference_points_input, - spatial_shapes=spatial_shapes, - spatial_shapes_list=spatial_shapes_list, - level_start_index=level_start_index, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From 2ed2c5bb213ee60c4864cda6955fa7f85f8b14ba Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:05:57 +0000 Subject: [PATCH 018/146] deit, deta, mctct --- src/transformers/models/deit/modeling_deit.py | 13 ++----- .../models/deprecated/deta/modeling_deta.py | 36 +++++++------------ 
.../models/deprecated/mctct/modeling_mctct.py | 22 ++++-------- 3 files changed, 22 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 6b6284995fc7..53f870b3d008 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -30,6 +30,7 @@ ImageClassifierOutput, MaskedImageModelingOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -347,7 +348,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT -class DeiTLayer(nn.Module): +class DeiTLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: DeiTConfig) -> None: @@ -414,15 +415,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/deta/modeling_deta.py b/src/transformers/models/deprecated/deta/modeling_deta.py index ae44b5e1b12f..214a41204a01 100644 --- a/src/transformers/models/deprecated/deta/modeling_deta.py +++ b/src/transformers/models/deprecated/deta/modeling_deta.py @@ -41,6 +41,7 @@ from ....modeling_attn_mask_utils import _prepare_4d_attention_mask from ....modeling_outputs import BaseModelOutput from ....modeling_utils import PreTrainedModel +from ....modeling_layers import GradientCheckpointingLayer from ....pytorch_utils import meshgrid from ....utils import is_accelerate_available, is_ninja_available, is_torchvision_available, logging, requires_backends from ....utils.backbone_utils import load_backbone @@ -909,7 +910,7 @@ def forward( return outputs -class DetaDecoderLayer(nn.Module): +class DetaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DetaConfig): super().__init__() self.embed_dim = config.d_model @@ -1341,29 +1342,16 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - position_embeddings, - reference_points_input, - spatial_shapes, - level_start_index, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - encoder_hidden_states=encoder_hidden_states, - reference_points=reference_points_input, - spatial_shapes=spatial_shapes, - level_start_index=level_start_index, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + 
level_start_index=level_start_index, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index e4852ff78f87..139326d44c9d 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -27,6 +27,7 @@ from ....integrations.fsdp import is_fsdp_managed_module from ....modeling_attn_mask_utils import _prepare_4d_attention_mask from ....modeling_outputs import BaseModelOutput, CausalLMOutput +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, @@ -377,7 +378,7 @@ def forward(self, hidden_states, input_tensor): return hidden_states -class MCTCTLayer(nn.Module): +class MCTCTLayer(GradientCheckpointingLayer): def __init__(self, config: MCTCTConfig): super().__init__() @@ -591,20 +592,11 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From 87704a7f080f9feb35119e398fe314bff5c6dc2b Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:09:15 +0000 Subject: [PATCH 019/146] detr, dinov2, distilbert --- src/transformers/models/detr/modeling_detr.py | 31 +++++++------------ .../modeling_dinov2_with_registers.py | 13 ++------ .../models/distilbert/modeling_distilbert.py | 24 +++++--------- 3 files changed, 22 insertions(+), 46 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 9f8ea167ab7b..21e84354d08b 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -677,7 +678,7 @@ def forward( return outputs -class DetrDecoderLayer(nn.Module): +class DetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DetrConfig): super().__init__() self.embed_dim = config.d_model @@ -1045,25 +1046,15 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - combined_attention_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - object_queries=object_queries, - 
query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index c2eeb197021b..a4b844665e12 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -399,7 +400,7 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: return self.weights_out(hidden) -class Dinov2WithRegistersLayer(nn.Module): +class Dinov2WithRegistersLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the original implementation.""" def __init__(self, config: Dinov2WithRegistersConfig) -> None: @@ -476,15 +477,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 1a84544bee3f..bcaeaa5d7372 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -39,6 +39,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ( apply_chunking_to_forward, @@ -436,7 +437,7 @@ def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: } -class TransformerBlock(nn.Module): +class TransformerBlock(GradientCheckpointingLayer): def __init__(self, config: PretrainedConfig): super().__init__() @@ -532,21 +533,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_state,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_state, - attn_mask, - head_mask[i], - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_state, - attn_mask, - head_mask[i], - output_attentions, - ) + layer_outputs = layer_module( + hidden_state, + attn_mask, + head_mask[i], + output_attentions, + ) hidden_state = layer_outputs[-1] From 
cd69033250f17ddaeb88181199bc08ae76b9bcac Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:09:54 +0000 Subject: [PATCH 020/146] donut, dpt, electra --- .../models/donut/modeling_donut_swin.py | 19 +++-------- src/transformers/models/dpt/modeling_dpt.py | 13 ++------ .../models/electra/modeling_electra.py | 33 +++++++------------ 3 files changed, 19 insertions(+), 46 deletions(-) diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index a63b0d3f0f57..603acec77829 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -27,6 +27,7 @@ from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -706,7 +707,7 @@ def forward( # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin -class DonutSwinStage(nn.Module): +class DonutSwinStage(GradientCheckpointingLayer): def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() self.config = config @@ -816,19 +817,9 @@ def forward( for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - always_partition, - ) - else: - layer_outputs = layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition - ) + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) hidden_states = layer_outputs[0] hidden_states_before_downsampling = layer_outputs[1] diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index d3c1703b9ee7..5c61e911b281 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -469,7 +470,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput -class DPTViTLayer(nn.Module): +class DPTViTLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: DPTConfig) -> None: @@ -536,15 +537,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = 
layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index dfe5849a5a7d..dd23f766d029 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -36,6 +36,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -436,7 +437,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra -class ElectraLayer(nn.Module): +class ElectraLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -562,27 +563,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From 9a54ad1c118627ff665662151488b1b199f139d9 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:10:34 +0000 Subject: [PATCH 021/146] ernie, esm, falcon --- .../models/ernie/modeling_ernie.py | 33 +++++---------- src/transformers/models/esm/modeling_esm.py | 33 +++++---------- .../models/falcon/modeling_falcon.py | 42 +++++++------------ 3 files changed, 36 insertions(+), 72 deletions(-) diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 21a79a3fabc0..f364a2c20632 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -37,6 +37,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging @@ -361,7 +362,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie -class ErnieLayer(nn.Module): +class ErnieLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -487,27 +488,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - 
if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index dbf260fb2160..ea9458e560da 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -31,6 +31,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging from .configuration_esm import EsmConfig @@ -594,7 +595,7 @@ def forward(self, hidden_states, input_tensor): return hidden_states -class EsmLayer(nn.Module): +class EsmLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -720,27 +721,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 455afbb21577..c2d82af752ff 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -38,6 +38,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, @@ -551,7 +552,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: } -class FalconDecoderLayer(nn.Module): +class FalconDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: FalconConfig, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -831,33 +832,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - alibi, - causal_mask, - position_ids, - head_mask[i], - past_key_values, - use_cache, - 
output_attentions, - cache_position, - position_embeddings, - ) - else: - outputs = block( - hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = outputs[0] if use_cache is True: From 6855515bf86ac8315f2a425510f0eb4243cc109d Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:14:07 +0000 Subject: [PATCH 022/146] flava, fnet, falcon_mamba --- .../falcon_mamba/modeling_falcon_mamba.py | 20 ++++++++----------- .../models/flava/modeling_flava.py | 14 +++---------- src/transformers/models/fnet/modeling_fnet.py | 8 +++----- 3 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 257f2c50cd93..2df27e390ea6 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import MambaCache from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available @@ -405,7 +406,7 @@ def forward(self, hidden_states): # Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache -class FalconMambaBlock(nn.Module): +class FalconMambaBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.config = config @@ -620,17 +621,12 @@ def forward( hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask - ) - else: - hidden_states = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - ) + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 3bd7b45d0dc8..b0e5f89193b5 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int from 
.configuration_flava import ( @@ -577,7 +578,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class FlavaLayer(nn.Module): +class FlavaLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: FlavaPossibleConfigs) -> None: @@ -648,16 +649,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 619d6c9c5add..858a0d8474c5 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -42,6 +42,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import logging @@ -235,7 +236,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class FNetLayer(nn.Module): +class FNetLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -276,10 +277,7 @@ def forward(self, hidden_states, output_hidden_states=False, return_dict=True): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states) - else: - layer_outputs = layer_module(hidden_states) + layer_outputs = layer_module(hidden_states) hidden_states = layer_outputs[0] From f4f8319c39db8fd494c8ef726b84c5e65b1d15e1 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:16:36 +0000 Subject: [PATCH 023/146] focalnet, git, gpt2 --- .../models/focalnet/modeling_focalnet.py | 12 ++--- src/transformers/models/git/modeling_git.py | 52 ++++++------------- src/transformers/models/gpt2/modeling_gpt2.py | 41 +++++---------- 3 files changed, 34 insertions(+), 71 deletions(-) diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 232f1e6ed1fa..6d2fbc6069c3 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from ...utils.backbone_utils import BackboneMixin @@ -455,7 +456,7 @@ def forward(self, hidden_state, input_dimensions): return hidden_state -class FocalNetStage(nn.Module): +class FocalNetStage(GradientCheckpointingLayer): def __init__(self, config, index, input_resolution): super().__init__() @@ -560,14 +561,7 @@ def forward( all_reshaped_hidden_states += (reshaped_hidden_state,) for i, stage_module in 
enumerate(self.stages): - if self.gradient_checkpointing and self.training: - stage_outputs = self._gradient_checkpointing_func( - stage_module.__call__, - hidden_states, - input_dimensions, - ) - else: - stage_outputs = stage_module(hidden_states, input_dimensions) + stage_outputs = stage_module(hidden_states, input_dimensions) hidden_states = stage_outputs[0] hidden_states_before_downsampling = stage_outputs[1] diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 6068ce169da3..17fa701a2127 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -33,6 +33,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -343,7 +344,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class GitLayer(nn.Module): +class GitLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -441,24 +442,14 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - past_key_values, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - past_key_values, - output_attentions, - pixel_values_present, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + past_key_values, + output_attentions, + pixel_values_present, + ) hidden_states = layer_outputs[0] if use_cache: @@ -723,7 +714,7 @@ def forward( # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision -class GitVisionEncoderLayer(nn.Module): +class GitVisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: GitVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -840,21 +831,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 584d21c41087..537554c1b18a 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -37,6 +37,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import Conv1D, 
find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( @@ -368,7 +369,7 @@ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -class GPT2Block(nn.Module): +class GPT2Block(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -922,32 +923,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - past_key_values, - cache_position, - causal_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - past_key_value=past_key_values, - cache_position=cache_position, - attention_mask=causal_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs, - ) + outputs = block( + hidden_states, + past_key_value=past_key_values, + cache_position=cache_position, + attention_mask=causal_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) hidden_states = outputs[0] From b8f4ecf3983d208428aa16ce30fe5863320696dc Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:17:12 +0000 Subject: [PATCH 024/146] gpt - bigcode, neo, neox --- .../gpt_bigcode/modeling_gpt_bigcode.py | 36 ++++++---------- .../models/gpt_neo/modeling_gpt_neo.py | 33 +++++---------- .../models/gpt_neox/modeling_gpt_neox.py | 41 +++++++------------ 3 files changed, 37 insertions(+), 73 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 297c30cb06a1..063ba82233e3 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -31,6 +31,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import is_torch_greater_or_equal_than_2_2 from ...utils import ( @@ -553,7 +554,7 @@ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.Fl } -class GPTBigCodeBlock(nn.Module): +class GPTBigCodeBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -886,29 +887,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + 
encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = outputs[0] if use_cache: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 417eb11ab0ff..bbf493bc80a4 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -36,6 +36,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, @@ -426,7 +427,7 @@ def forward(self, hidden_states): return hidden_states -class GPTNeoBlock(nn.Module): +class GPTNeoBlock(GradientCheckpointingLayer): def __init__(self, config, layer_id=None): super().__init__() hidden_size = config.hidden_size @@ -630,27 +631,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - causal_mask, - head_mask[i], - use_cache, - output_attentions, - cache_position, - ) - else: - outputs = block( - hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) hidden_states = outputs[0] if use_cache: diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 6b08c27a3061..369e65793e86 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -22,6 +22,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging @@ -190,7 +191,7 @@ def forward( return attn_output, attn_weights -class GPTNeoXLayer(nn.Module): +class GPTNeoXLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -415,32 +416,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - causal_mask, - position_ids, - head_mask[i], - use_cache, - past_key_values, - output_attentions, - cache_position, - position_embeddings, - ) - else: - outputs = layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - layer_past=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + outputs = layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + 
layer_past=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = outputs[0] if output_attentions: From d844b1258e441a7dfea389b4a2e83344ac09f169 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:18:11 +0000 Subject: [PATCH 025/146] gptj, groupvit --- src/transformers/models/gptj/modeling_gptj.py | 36 +++++++------------ .../models/groupvit/modeling_groupvit.py | 24 +++++-------- 2 files changed, 20 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 1e9c0ef5332f..2b501c9b54b3 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -34,6 +34,7 @@ QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -429,7 +430,7 @@ def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTens return hidden_states -class GPTJBlock(nn.Module): +class GPTJBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd @@ -733,29 +734,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - causal_mask, - position_ids, - head_mask[i], - use_cache, - output_attentions, - cache_position, - ) - else: - outputs = block( - hidden_states=hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + outputs = block( + hidden_states=hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 100fd2dd85f5..7f39f56ded61 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging, torch_int from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig @@ -692,7 +693,7 @@ def forward( # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT -class GroupViTEncoderLayer(nn.Module): +class GroupViTEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: GroupViTConfig): super().__init__() self.embed_dim = config.hidden_size @@ -906,21 +907,12 @@ def forward( for idx, encoder_layer in 
enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From 700d20d63b1e8c48c127c82301631feefec14782 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:22:59 +0000 Subject: [PATCH 026/146] idefics2, idefics3 --- .../models/idefics2/modeling_idefics2.py | 21 +++++++------------ .../models/idefics3/modeling_idefics3.py | 21 +++++++------------ 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index d9b5d5e68336..704fa2785b0c 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -27,6 +27,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging @@ -339,7 +340,7 @@ def forward(self, hidden_state): return hidden_state[:, 0] -class Idefics2EncoderLayer(nn.Module): +class Idefics2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Idefics2VisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -448,19 +449,11 @@ def forward( for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 53b3cc2e304e..64f82e2d4e59 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -27,6 +27,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging @@ -300,7 +301,7 @@ def forward(self, x): # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3 -class Idefics3EncoderLayer(nn.Module): +class 
Idefics3EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Idefics3VisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -409,19 +410,11 @@ def forward( for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From 0b3ffba5d94201b6e4eeb154f01fdaf8f177ca3c Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:23:51 +0000 Subject: [PATCH 027/146] ijepa, imagegpt, internvl --- .../models/ijepa/modeling_ijepa.py | 13 ++----- .../models/imagegpt/modeling_imagegpt.py | 36 +++++++------------ .../models/internvl/modeling_internvl.py | 10 ++---- 3 files changed, 18 insertions(+), 41 deletions(-) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index c7ce7c29b4cf..44fe5c7b083f 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -13,6 +13,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -357,7 +358,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class IJepaLayer(nn.Module): +class IJepaLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: IJepaConfig) -> None: @@ -423,15 +424,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index db5ae763aadd..58f43ede5225 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -32,6 +32,7 @@ CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( @@ -402,7 +403,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class ImageGPTBlock(nn.Module): +class ImageGPTBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -720,29 +721,16 @@ def forward( if 
output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 485adea83630..ea94d8b917c9 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -32,6 +32,7 @@ from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -383,7 +384,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm} -class InternVLVisionLayer(nn.Module): +class InternVLVisionLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: InternVLVisionConfig) -> None: @@ -452,12 +453,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, hidden_states, output_attentions - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions) hidden_states = layer_outputs[0] From 9ed27ef25c4e44f296ba900ec27f63f567c93080 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:26:35 +0000 Subject: [PATCH 028/146] jetmoe, kosmos2, layoutlm --- .../models/jetmoe/modeling_jetmoe.py | 34 +++------- .../models/kosmos2/modeling_kosmos2.py | 68 +++++++------------ .../models/layoutlm/modeling_layoutlm.py | 33 +++------ 3 files changed, 45 insertions(+), 90 deletions(-) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index ab843099c54b..42480dbe4bfd 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -29,6 +29,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import 
GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_jetmoe import JetMoeConfig @@ -758,7 +759,7 @@ def forward( } -class JetMoeBlock(nn.Module): +class JetMoeBlock(GradientCheckpointingLayer): def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): """ Initialize the JetMoeBlock module. @@ -962,28 +963,15 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - position_ids, - past_key_values, - causal_mask, - output_attentions, - output_router_logits, - use_cache, - use_reentrant=False, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 34dc0848b79c..a3144c0ed67b 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -31,6 +31,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithCrossAttentions, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, logging, torch_int @@ -404,7 +405,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision -class Kosmos2VisionEncoderLayer(nn.Module): +class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -521,21 +522,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -840,7 +832,7 @@ def forward(self, hidden_states): return hidden_states -class Kosmos2TextBlock(nn.Module): +class Kosmos2TextBlock(GradientCheckpointingLayer): def __init__(self, config: Kosmos2TextConfig): super().__init__() self.embed_dim = config.embed_dim @@ -1138,34 +1130,20 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = 
self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 6ce2e7e2dcf2..f3bf7092c290 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -31,6 +31,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -358,7 +359,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM -class LayoutLMLayer(nn.Module): +class LayoutLMLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -484,27 +485,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From 6d3ecbca40f7ab0e4cc9a908b6d6f832f281301f Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:27:24 +0000 Subject: [PATCH 029/146] layoutlm2-3, led --- .../models/layoutlmv2/modeling_layoutlmv2.py | 30 +++----- .../models/layoutlmv3/modeling_layoutlmv3.py | 30 +++----- src/transformers/models/led/modeling_led.py | 75 +++++++------------ 3 files changed, 45 
insertions(+), 90 deletions(-) diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index fa89c6c45b41..f9ec7df94f8f 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -30,6 +30,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import auto_docstring, is_detectron2_available, logging, requires_backends @@ -261,7 +262,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class LayoutLMv2Layer(nn.Module): +class LayoutLMv2Layer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -436,25 +437,14 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - rel_pos=rel_pos, - rel_2d_pos=rel_2d_pos, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - rel_pos=rel_pos, - rel_2d_pos=rel_2d_pos, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 83f87ec5281b..0261e14c77ff 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -31,6 +31,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( @@ -358,7 +359,7 @@ def forward( # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3 -class LayoutLMv3Layer(nn.Module): +class LayoutLMv3Layer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -514,25 +515,14 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - rel_pos, - rel_2d_pos, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - rel_pos=rel_pos, - rel_2d_pos=rel_2d_pos, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index bc5738613218..7141d9a02d61 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -28,6 +28,7 @@ 
from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from .configuration_led import LEDConfig @@ -900,7 +901,7 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -class LEDEncoderLayer(nn.Module): +class LEDEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: LEDConfig, layer_id: int): super().__init__() self.embed_dim = config.d_model @@ -962,7 +963,7 @@ def forward( return (hidden_states,) + attn_outputs[1:] -class LEDDecoderLayer(nn.Module): +class LEDDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: LEDConfig): super().__init__() self.embed_dim = config.d_model @@ -1680,27 +1681,15 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - is_index_masked, - is_index_global_attn, - is_global_attn, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - is_index_masked=is_index_masked, - is_index_global_attn=is_index_global_attn, - is_global_attn=is_global_attn, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -1943,33 +1932,19 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - combined_attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = 
layer_outputs[0] From e398d8e87344c22c342e164910b1dfdf1e1a1564 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:28:35 +0000 Subject: [PATCH 030/146] lilt, longformer, longt5, luke --- src/transformers/models/lilt/modeling_lilt.py | 27 ++++------ .../models/longformer/modeling_longformer.py | 33 ++++-------- .../models/longt5/modeling_longt5.py | 51 +++++++------------ src/transformers/models/luke/modeling_luke.py | 27 ++++------ 4 files changed, 46 insertions(+), 92 deletions(-) diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 0c76a25a6e3d..1d27e9a0d292 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -30,6 +30,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -419,7 +420,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class LiltLayer(nn.Module): +class LiltLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -506,23 +507,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layout_inputs, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - layout_inputs, - attention_mask, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + layout_inputs, + attention_mask, + layer_head_mask, + output_attentions, + ) hidden_states = layer_outputs[0][0] layout_inputs = layer_outputs[0][1] diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index a40d5bb0e2c9..c6b16492c8d1 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging @@ -1205,7 +1206,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class LongformerLayer(nn.Module): +class LongformerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.attention = LongformerAttention(config, layer_id) @@ -1284,27 +1285,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - is_index_masked, - is_index_global_attn, - is_global_attn, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - 
attention_mask=attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - is_index_masked=is_index_masked, - is_index_global_attn=is_index_global_attn, - is_global_attn=is_global_attn, - output_attentions=output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 2ddcc4f3d4d1..d740b35d774c 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -33,6 +33,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -1145,7 +1146,7 @@ def forward( return outputs -class LongT5Block(nn.Module): +class LongT5Block(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -1503,39 +1504,21 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - return_dict, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index c01bb7453f9a..1268048c0872 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -25,6 +25,7 @@ 
from ...activations import ACT2FN, gelu from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ModelOutput, auto_docstring, logging @@ -695,7 +696,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class LukeLayer(nn.Module): +class LukeLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -774,23 +775,13 @@ def forward( all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - word_hidden_states, - entity_hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module( - word_hidden_states, - entity_hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module( + word_hidden_states, + entity_hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + ) word_hidden_states = layer_outputs[0] From 43631566cae37ea4637c8a8e96f12e823aac8bc2 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:29:41 +0000 Subject: [PATCH 031/146] m2m, mamba1-2 --- .../models/m2m_100/modeling_m2m_100.py | 69 +++++++------------ .../models/mamba/modeling_mamba.py | 20 +++--- .../models/mamba2/modeling_mamba2.py | 20 +++--- 3 files changed, 39 insertions(+), 70 deletions(-) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 25ed975e4d52..c617d5f47607 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -40,6 +40,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging @@ -335,7 +336,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 -class M2M100EncoderLayer(nn.Module): +class M2M100EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model @@ -404,7 +405,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 -class M2M100DecoderLayer(nn.Module): +class M2M100DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: M2M100Config, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -883,21 +884,12 @@ def forward( if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not 
None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1142,35 +1134,20 @@ def forward( if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 0f6dfab81124..d771494486f2 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import MambaCache from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -343,7 +344,7 @@ def extra_repr(self): return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" -class MambaBlock(nn.Module): +class MambaBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.config = config @@ -561,17 +562,12 @@ def forward( hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask - ) - else: - hidden_states = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - ) + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 17925c5acc03..7dd6ecc92d42 100644 --- 
a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -682,7 +683,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -class Mamba2Block(nn.Module): +class Mamba2Block(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.config = config @@ -901,17 +902,12 @@ def forward( hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask - ) - else: - hidden_states = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - ) + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) From dde58de6d1c54aa08cd3109cdeaa8cd6bf87cf49 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:31:25 +0000 Subject: [PATCH 032/146] marian, markuplm, mask2former --- .../models/marian/modeling_marian.py | 69 +++++++------------ .../models/markuplm/modeling_markuplm.py | 33 +++------ .../mask2former/modeling_mask2former.py | 61 +++++++--------- 3 files changed, 59 insertions(+), 104 deletions(-) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 5630f916ee38..5b5beed6c880 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -40,6 +40,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -270,7 +271,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN -class MarianEncoderLayer(nn.Module): +class MarianEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -342,7 +343,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN -class MarianDecoderLayer(nn.Module): +class MarianDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -831,21 +832,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + 
attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1087,35 +1079,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 8ce6b5ed5ec9..9b07b3339f7d 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -32,6 +32,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, @@ -518,7 +519,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->MarkupLM -class MarkupLMLayer(nn.Module): +class MarkupLMLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -644,27 +645,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 3eb559dfcdb4..50d8496df74d 100644 --- 
a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...file_utils import ModelOutput, is_scipy_available, requires_backends from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_accelerate_available, logging from ...utils.backbone_utils import load_backbone @@ -1535,7 +1536,7 @@ def forward( return attn_output, attn_weights_reshaped -class Mask2FormerMaskedAttentionDecoderLayer(nn.Module): +class Mask2FormerMaskedAttentionDecoderLayer(GradientCheckpointingLayer): """ The Mask2FormerMaskedAttentionDecoderLayer is made up of self-attention, cross (masked) attention as well as FFN blocks. The cross attention block used as part of `Mask2FormerMaskedAttentionDecoderLayer` is actually a `masked @@ -1858,46 +1859,34 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - None, - None, - output_attentions, - ) + level_index = idx % self.num_feature_levels - else: - level_index = idx % self.num_feature_levels - - where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype) - # Multiply the attention mask instead of indexing to avoid issue in torch.export. - attention_mask = attention_mask * where.unsqueeze(-1) - - layer_outputs = decoder_layer( - hidden_states, - level_index=level_index, - position_embeddings=multi_stage_positional_embeddings, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - ) + where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype) + # Multiply the attention mask instead of indexing to avoid issue in torch.export. 
+ attention_mask = attention_mask * where.unsqueeze(-1) + + layer_outputs = decoder_layer( + hidden_states, + level_index=level_index, + position_embeddings=multi_stage_positional_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + ) - intermediate_hidden_states = self.layernorm(layer_outputs[0]) + intermediate_hidden_states = self.layernorm(layer_outputs[0]) - predicted_mask, attention_mask = self.mask_predictor( - intermediate_hidden_states, - pixel_embeddings, - feature_size_list[(idx + 1) % self.num_feature_levels], - ) + predicted_mask, attention_mask = self.mask_predictor( + intermediate_hidden_states, + pixel_embeddings, + feature_size_list[(idx + 1) % self.num_feature_levels], + ) - intermediate_mask_predictions += (predicted_mask,) + intermediate_mask_predictions += (predicted_mask,) - # add intermediate hidden states with layer norm applied which will be used for predicting class logits - intermediate += (intermediate_hidden_states,) + # add intermediate hidden states with layer norm applied which will be used for predicting class logits + intermediate += (intermediate_hidden_states,) hidden_states = layer_outputs[0] From 69b2cf894b61ad661c707d601801e610bc625c49 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:33:09 +0000 Subject: [PATCH 033/146] maskformer --- .../models/maskformer/modeling_maskformer.py | 32 +++++++------------ .../maskformer/modeling_maskformer_swin.py | 25 ++++++--------- 2 files changed, 20 insertions(+), 37 deletions(-) diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 02f9848a8fba..f13c00b045e6 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutputWithCrossAttentions +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -529,7 +530,7 @@ def forward( # Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer -class DetrDecoderLayer(nn.Module): +class DetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DetrConfig): super().__init__() self.embed_dim = config.d_model @@ -742,26 +743,15 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - None, - encoder_hidden_states, - encoder_attention_mask, - None, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=None, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git 
a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index b7505aa6748e..47f80eddd0d9 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...file_utils import ModelOutput from ...modeling_outputs import BackboneOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import torch_int @@ -629,7 +630,7 @@ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attent return outputs -class MaskFormerSwinStage(nn.Module): +class MaskFormerSwinStage(GradientCheckpointingLayer): # Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() @@ -729,21 +730,13 @@ def forward( for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_hidden_states, output_dimensions, layer_all_hidden_states = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - output_hidden_states, - ) + layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + output_hidden_states, + ) input_dimensions = (output_dimensions[-2], output_dimensions[-1]) all_input_dimensions += (input_dimensions,) From d4ccb793034ae63508cec16c2f113261ef62ebbd Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:34:26 +0000 Subject: [PATCH 034/146] mbart, megatron_bert, mimi --- .../models/mbart/modeling_mbart.py | 69 +++++++------------ .../megatron_bert/modeling_megatron_bert.py | 33 +++------ src/transformers/models/mimi/modeling_mimi.py | 33 +++------ 3 files changed, 45 insertions(+), 90 deletions(-) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 18ad34026f41..5fdc21c3126b 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -43,6 +43,7 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -279,7 +280,7 @@ def forward( return attn_output, attn_weights, past_key_value -class MBartEncoderLayer(nn.Module): +class MBartEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model @@ -347,7 +348,7 @@ def forward( return outputs -class MBartDecoderLayer(nn.Module): +class MBartDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -866,21 +867,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if 
self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1130,35 +1122,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index d22b1536081b..97ee1f05c99e 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -39,6 +39,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging @@ -405,7 +406,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Based on transformers.models.bert.modeling_bert.BertLayer. Added LayerNorm. 
-class MegatronBertLayer(nn.Module): +class MegatronBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -535,27 +536,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) # Because we moved the layer-norm at the end of the hidden layer, we have non-normali- # zed data here. If that's really needed, we must apply LN to match Transformer's BERT. diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 9023c93433ab..f200ecb9b868 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -28,6 +28,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging from .configuration_mimi import MimiConfig @@ -799,7 +800,7 @@ def forward( } -class MimiTransformerLayer(nn.Module): +class MimiTransformerLayer(GradientCheckpointingLayer): def __init__(self, config: MimiConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1014,27 +1015,15 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] From ab213dad76d0409dbf811cd0ef53dc7ccb0e1dcd Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:34:54 +0000 Subject: [PATCH 035/146] mixtral, mlcd --- .../models/mixtral/modeling_mixtral.py | 41 +++++++------------ src/transformers/models/mlcd/modeling_mlcd.py | 24 ++++------- 2 files changed, 22 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 
26007b7b18ac..a3a9aa1c6502 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -46,6 +46,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging @@ -295,7 +296,7 @@ def forward( return attn_output, attn_weights -class MixtralDecoderLayer(nn.Module): +class MixtralDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -535,32 +536,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index b20e9f107e09..e1f23ca84b9d 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, torch_int @@ -299,7 +300,7 @@ def forward( return attn_output, attn_weights -class MLCDEncoderLayer(nn.Module): +class MLCDEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MLCDVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -416,21 +417,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - position_embeddings, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states=hidden_states, + 
position_embeddings=position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From cb9091621930f8acd5173437e270385a3560c1c4 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:37:41 +0000 Subject: [PATCH 036/146] mobilevit1-2, modernbert --- .../models/mobilevit/modeling_mobilevit.py | 11 ++----- .../mobilevitv2/modeling_mobilevitv2.py | 11 ++----- .../models/modernbert/modeling_modernbert.py | 33 +++++++------------ 3 files changed, 17 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 149eb9261ef1..eb16584579a3 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -31,6 +31,7 @@ ImageClassifierOutputWithNoAttention, SemanticSegmenterOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -350,7 +351,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class MobileViTLayer(nn.Module): +class MobileViTLayer(GradientCheckpointingLayer): """ MobileViT block: https://huggingface.co/papers/2110.02178 """ @@ -603,13 +604,7 @@ def forward( all_hidden_states = () if output_hidden_states else None for i, layer_module in enumerate(self.layer): - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - hidden_states = layer_module(hidden_states) + hidden_states = layer_module(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 868c595dbace..fdf0a3261bec 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -30,6 +30,7 @@ ImageClassifierOutputWithNoAttention, SemanticSegmenterOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_mobilevitv2 import MobileViTV2Config @@ -351,7 +352,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class MobileViTV2Layer(nn.Module): +class MobileViTV2Layer(GradientCheckpointingLayer): """ MobileViTV2 layer: https://huggingface.co/papers/2206.02680 """ @@ -556,13 +557,7 @@ def forward( all_hidden_states = () if output_hidden_states else None for i, layer_module in enumerate(self.layer): - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - hidden_states = layer_module(hidden_states) + hidden_states = layer_module(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index d984c523d0c8..f2aae73bde4e 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -38,6 +38,7 @@ 
TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_flash_attn_2_available, logging from ...utils.import_utils import is_triton_available @@ -508,7 +509,7 @@ def forward( return (hidden_states,) + attn_outputs[1:] # add attentions if outputted -class ModernBertEncoderLayer(nn.Module): +class ModernBertEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config @@ -864,27 +865,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - sliding_window_mask, - position_ids, - cu_seqlens, - max_seqlen, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - sliding_window_mask=sliding_window_mask, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions and len(layer_outputs) > 1: all_self_attentions = all_self_attentions + (layer_outputs[1],) From c2d3cbc2f24cbcf2c1dda4045c939f7374269125 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:39:15 +0000 Subject: [PATCH 037/146] moshi, mpt, mra --- .../models/moshi/modeling_moshi.py | 63 ++++++------------- src/transformers/models/mpt/modeling_mpt.py | 30 +++------ src/transformers/models/mra/modeling_mra.py | 12 +--- 3 files changed, 33 insertions(+), 72 deletions(-) diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index c0ef9c001471..cd80766ed23e 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -30,6 +30,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput, Seq2SeqLMOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging @@ -737,7 +738,7 @@ def forward( } -class MoshiDecoderLayer(nn.Module): +class MoshiDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool, use_rope=True): super().__init__() self.hidden_size = config.hidden_size @@ -1025,27 +1026,15 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - 
layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] @@ -1342,27 +1331,15 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 79ec42e2b8d8..6fa1de4e9a76 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -32,6 +32,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_mpt import MptConfig @@ -160,7 +161,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch. 
return output -class MptBlock(nn.Module): +class MptBlock(GradientCheckpointingLayer): def __init__(self, config: MptConfig): super().__init__() hidden_size = config.hidden_size @@ -388,25 +389,14 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - alibi, - causal_mask, - layer_past, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - use_cache=use_cache, - output_attentions=output_attentions, - position_bias=alibi, - ) + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + use_cache=use_cache, + output_attentions=output_attentions, + position_bias=alibi, + ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 7501c4d83062..c5d63ba7705a 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -33,6 +33,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, is_cuda_platform, is_ninja_available, is_torch_cuda_available, logging @@ -688,7 +689,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class MraLayer(nn.Module): +class MraLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -738,14 +739,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask) + layer_outputs = layer_module(hidden_states, attention_mask) hidden_states = layer_outputs[0] From 80bcd7c71e885adeda4d0973066ae93e8531f220 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:41:29 +0000 Subject: [PATCH 038/146] mt5, musicgen --- src/transformers/models/mt5/modeling_mt5.py | 51 +++++++------------ .../models/musicgen/modeling_musicgen.py | 43 ++++++---------- .../modeling_musicgen_melody.py | 30 ++++------- 3 files changed, 42 insertions(+), 82 deletions(-) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 8596fbeb4f9d..4eb6fa1ec7fc 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -37,6 +37,7 @@ Seq2SeqSequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -523,7 +524,7 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5 -class MT5Block(nn.Module): +class MT5Block(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = 
config.is_decoder @@ -1088,39 +1089,21 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - return_dict, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index a7ead0a51461..beff1b6560f5 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -50,6 +50,7 @@ ModelOutput, Seq2SeqLMOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging @@ -304,7 +305,7 @@ def forward( return attn_output, attn_weights, past_key_value -class MusicgenDecoderLayer(nn.Module): +class MusicgenDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MusicgenDecoderConfig): super().__init__() self.embed_dim = config.hidden_size @@ -619,33 +620,19 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.forward, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if 
cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index a57955a7a70b..ea9621b459b4 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -42,6 +42,7 @@ FlashAttentionKwargs, ) from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging @@ -320,7 +321,7 @@ def forward( return attn_output, attn_weights, past_key_value -class MusicgenMelodyDecoderLayer(nn.Module): +class MusicgenMelodyDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__() self.embed_dim = config.hidden_size @@ -596,25 +597,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.forward, - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: From 825e2b120637afa2cc733c69c08a3d27858d5394 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:42:12 +0000 Subject: [PATCH 039/146] mvp, nemotron --- src/transformers/models/mvp/modeling_mvp.py | 75 +++++++------------ .../models/nemotron/modeling_nemotron.py | 36 +++------ 2 files changed, 37 insertions(+), 74 deletions(-) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 739ecd8f015e..3410acd2a7d4 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -38,6 +38,7 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_mvp import MvpConfig @@ -244,7 +245,7 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -class 
MvpEncoderLayer(nn.Module): +class MvpEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MvpConfig): super().__init__() self.embed_dim = config.d_model @@ -316,7 +317,7 @@ def forward( return outputs -class MvpDecoderLayer(nn.Module): +class MvpDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MvpConfig): super().__init__() self.embed_dim = config.d_model @@ -682,23 +683,13 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - (self_attn_prompt[idx] if self.use_prompt else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -935,37 +926,21 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - self_attn_prompt[idx] if self.use_prompt else None, - cross_attn_prompt[idx] if self.use_prompt else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), - cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), + cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 6b5cf370c9da..bac137fd38b4 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -36,6 +36,7 @@ TokenClassifierOutput, ) from 
...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging @@ -487,7 +488,7 @@ def forward( # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron # no longer copied after attention refactors -class NemotronDecoderLayer(nn.Module): +class NemotronDecoderLayer(GradientCheckpointingLayer): # Ignore copy def __init__(self, config: NemotronConfig, layer_idx: int): super().__init__() @@ -692,29 +693,16 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] From 8f6a8fb0811b45489900d4a99533c936b3241476 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:43:39 +0000 Subject: [PATCH 040/146] nllb_moe --- .../models/nllb_moe/modeling_nllb_moe.py | 71 ++++++------------- 1 file changed, 22 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 15c87871649b..3920ed34b029 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -38,6 +38,7 @@ Seq2SeqMoEModelOutput, Seq2SeqMoEOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging @@ -625,7 +626,7 @@ def forward( return attn_output, attn_weights, past_key_value -class NllbMoeEncoderLayer(nn.Module): +class NllbMoeEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): super().__init__() self.embed_dim = config.d_model @@ -707,7 +708,7 @@ def forward( return outputs -class NllbMoeDecoderLayer(nn.Module): +class NllbMoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): super().__init__() self.embed_dim = config.d_model @@ -1018,22 +1019,13 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = 
encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - output_router_logits=output_router_logits, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + output_router_logits=output_router_logits, + ) hidden_states = layer_outputs[0] @@ -1296,37 +1288,18 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.forward, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + ) hidden_states = layer_outputs[0] From 6253d785253dd6b8380a87eb6ee57070b9f60a56 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:45:06 +0000 Subject: [PATCH 041/146] nystromformer, omdet_turbo --- .../nystromformer/modeling_nystromformer.py | 13 ++--- .../omdet_turbo/modeling_omdet_turbo.py | 47 ++++++------------- 2 files changed, 18 insertions(+), 42 deletions(-) diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 17a3319de3ac..d000d94f1c34 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -31,6 +31,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -311,7 +312,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class NystromformerLayer(nn.Module): +class NystromformerLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -363,15 +364,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = 
self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index 3380118dd9ef..1007be135ed7 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -31,6 +31,7 @@ ) from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from ...utils.backbone_utils import load_backbone @@ -879,7 +880,7 @@ def forward(self, x): return x -class OmDetTurboDeformableTransformerDecoderLayer(nn.Module): +class OmDetTurboDeformableTransformerDecoderLayer(GradientCheckpointingLayer): """ A single layer of the Deformable Transformer Decoder. """ @@ -1376,37 +1377,19 @@ def forward( last_refined_bbox = None reference_points = reference_points.sigmoid() for i, layer in enumerate(self.layers): - if self.gradient_checkpointing and self.training: - predicted_class_features, task_features, self_attention, cross_attention = ( - self._gradient_checkpointing_func( - layer.__call__, - predicted_class_features, - task_features, - reference_points, - vision_features, - vision_shapes, - vision_shapes_list, - level_start_index=level_start_index, - attention_mask=attention_mask, - query_position=self.query_position_head(reference_points), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - ) - else: - predicted_class_features, task_features, self_attention, cross_attention = layer( - predicted_class_features, - task_features, - reference_points, - vision_features, - vision_shapes, - vision_shapes_list, - level_start_index=level_start_index, - attention_mask=attention_mask, - query_position=self.query_position_head(reference_points), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) + predicted_class_features, task_features, self_attention, cross_attention = layer( + predicted_class_features, + task_features, + reference_points, + vision_features, + vision_shapes, + vision_shapes_list, + level_start_index=level_start_index, + attention_mask=attention_mask, + query_position=self.query_position_head(reference_points), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) if output_attentions: all_self_attns = all_self_attns + (self_attention,) all_cross_attns = all_cross_attns + (cross_attention,) From ab136ef193ac4d4d4b53a1b478853867d2da337e Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:46:13 +0000 Subject: [PATCH 042/146] opt, owlvit, owlv2 --- src/transformers/models/opt/modeling_opt.py | 38 +++++++------------ .../models/owlv2/modeling_owlv2.py | 24 ++++-------- .../models/owlvit/modeling_owlvit.py | 24 ++++-------- 3 files changed, 29 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index fd22722b69cf..39f460432f3f 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ 
b/src/transformers/models/opt/modeling_opt.py @@ -32,6 +32,7 @@ QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging @@ -204,7 +205,7 @@ def forward( return attn_output, attn_weights, past_key_value -class OPTDecoderLayer(nn.Module): +class OPTDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OPTConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.hidden_size @@ -672,30 +673,17 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - head_mask[idx] if head_mask is not None else None, - None, - output_attentions, - use_cache, - position_ids, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 4e838c90a9f8..9932c0e3dca7 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_vision_available, logging, torch_int from .configuration_owlv2 import Owlv2Config, Owlv2TextConfig, Owlv2VisionConfig @@ -497,7 +498,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Owlv2 -class Owlv2EncoderLayer(nn.Module): +class Owlv2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Owlv2Config): super().__init__() self.embed_dim = config.hidden_size @@ -655,21 +656,12 @@ def forward( for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) 
hidden_states = layer_outputs[0] diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index d487706611d1..b4aa07c2cdd5 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_vision_available, logging, torch_int from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig @@ -485,7 +486,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->OwlViT -class OwlViTEncoderLayer(nn.Module): +class OwlViTEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: OwlViTConfig): super().__init__() self.embed_dim = config.hidden_size @@ -641,21 +642,12 @@ def forward( for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From 3fb64a9f0bfd674be4e38aef0709e65feaee6312 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:47:18 +0000 Subject: [PATCH 043/146] pegasus, pegasus_x, presimmon --- .../models/pegasus/modeling_pegasus.py | 69 +++++++------------ .../models/pegasus_x/modeling_pegasus_x.py | 59 +++++----------- .../models/persimmon/modeling_persimmon.py | 38 ++++------ 3 files changed, 55 insertions(+), 111 deletions(-) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 6922d3c815bf..63c513112487 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -40,6 +40,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -269,7 +270,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS -class PegasusEncoderLayer(nn.Module): +class PegasusEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model @@ -338,7 +339,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS -class PegasusDecoderLayer(nn.Module): +class PegasusDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PegasusConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -845,21 +846,12 @@ def forward( if to_drop: 
layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1135,35 +1127,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 0c48aa614125..b929c74ec7ca 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -39,6 +39,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -528,7 +529,7 @@ def compute_local_attention_representations( return attn_output, attn_probs -class PegasusXEncoderLayer(nn.Module): +class PegasusXEncoderLayer(GradientCheckpointingLayer): def __init__(self, stagger_blocks_this_layer: bool, config: PegasusXConfig): super().__init__() self.embed_dim = config.d_model @@ -643,7 +644,7 @@ def unpad_local_tokens(cls, padded_hidden_states, block_size): return padded_hidden_states[:, pad_size:-pad_size, :] -class PegasusXDecoderLayer(nn.Module): +class PegasusXDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PegasusXConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -1148,21 +1149,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing 
and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - global_hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - global_hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + global_hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] global_hidden_states = layer_outputs[1] @@ -1388,29 +1380,16 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index ce142c8d6d22..08423e838b81 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -37,6 +37,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging @@ -299,7 +300,7 @@ def forward( return attn_output, attn_weights, past_key_value -class PersimmonDecoderLayer(nn.Module): +class PersimmonDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PersimmonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -517,30 +518,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] 
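For reference, the pattern these commits migrate to can be sketched roughly as follows. The class below is an illustrative approximation only: the actual GradientCheckpointingLayer lives in src/transformers/modeling_layers.py and routes through a configurable checkpointing function set by gradient_checkpointing_enable rather than calling torch.utils.checkpoint directly.

# Minimal sketch of a layer base class that applies activation checkpointing
# inside __call__, so call sites can drop their explicit
# `if self.gradient_checkpointing and self.training:` branches, which is the
# simplification every hunk in these patches performs.
from functools import partial

import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class GradientCheckpointingLayerSketch(nn.Module):
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Recompute activations during the backward pass instead of storing
            # them; use_reentrant=False tolerates keyword arguments and unused
            # inputs, so the layer keeps its normal calling convention.
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)
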
From 32b2876745bcfd6f003b418ceac88bba78247bed Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:52:32 +0000 Subject: [PATCH 044/146] phimoe, pix2struct, pixtral --- .../models/phimoe/modeling_phimoe.py | 39 ++++------- .../models/pix2struct/modeling_pix2struct.py | 70 ++++++------------- .../models/pixtral/modeling_pixtral.py | 26 +++---- 3 files changed, 44 insertions(+), 91 deletions(-) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 2a9240783399..785e5deb477b 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -29,6 +29,7 @@ from ...modeling_flash_attention_utils import is_flash_attn_available from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_phimoe import PhimoeConfig @@ -790,7 +791,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -class PhimoeDecoderLayer(nn.Module): +class PhimoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PhimoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1013,31 +1014,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 54c956601a51..f3c85e7ad538 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -32,6 +32,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( @@ -262,7 +263,7 @@ def forward(self, hidden_states): return hidden_states -class Pix2StructVisionLayer(nn.Module): +class Pix2StructVisionLayer(GradientCheckpointingLayer): def __init__(self, config: Pix2StructConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -330,16 +331,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else 
None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] @@ -928,7 +920,7 @@ def forward( return outputs -class Pix2StructTextBlock(nn.Module): +class Pix2StructTextBlock(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() @@ -1151,6 +1143,10 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training and use_cache: + logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: @@ -1244,42 +1240,20 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 998124a8da66..5fb4ff26885e 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -25,6 +25,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput from 
...modeling_rope_utils import dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging @@ -272,7 +273,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class PixtralAttentionLayer(nn.Module): +class PixtralAttentionLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) @@ -374,22 +375,13 @@ def forward( for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - position_embeddings, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - position_embeddings=position_embeddings, - output_attentions=output_attentions, - **kwargs, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + **kwargs, + ) hidden_states = layer_outputs[0] From 942f7a496b4d5793f2044352482466002971df37 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 14:57:15 +0000 Subject: [PATCH 045/146] plbart, pop2piano, prophetnet --- .../models/plbart/modeling_plbart.py | 69 ++++++---------- .../models/pop2piano/modeling_pop2piano.py | 48 ++++-------- .../models/prophetnet/modeling_prophetnet.py | 78 +++++++------------ 3 files changed, 65 insertions(+), 130 deletions(-) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index a192fe70e238..b34acbbca588 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -44,6 +44,7 @@ Seq2SeqModelOutput, Seq2SeqSequenceClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging @@ -465,7 +466,7 @@ def forward( return attn_output, attn_weights, past_key_value -class PLBartEncoderLayer(nn.Module): +class PLBartEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -683,21 +684,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -714,7 +706,7 @@ def forward( ) -class PLBartDecoderLayer(nn.Module): +class PLBartDecoderLayer(GradientCheckpointingLayer): def __init__(self, 
config: PLBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -1064,35 +1056,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index b5e52e7c3971..d20194b8bbff 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -29,6 +29,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling, logging @@ -469,7 +470,7 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano -class Pop2PianoBlock(nn.Module): +class Pop2PianoBlock(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -816,37 +817,20 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - 
encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index d7783c48e0aa..e0ec38f0d1ba 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from .configuration_prophetnet import ProphetNetConfig @@ -956,7 +957,7 @@ def get_predict_relative_pos_embeddings( return predict_relative_pos_embeddings -class ProphetNetEncoderLayer(nn.Module): +class ProphetNetEncoderLayer(GradientCheckpointingLayer): """ Encoder block for Prophetnet """ @@ -999,7 +1000,7 @@ def forward( return outputs -class ProphetNetDecoderLayer(nn.Module): +class ProphetNetDecoderLayer(GradientCheckpointingLayer): """ Decoder block for Prophetnet """ @@ -1183,21 +1184,12 @@ def forward( if output_hidden_states: encoder_hidden_states = encoder_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - extended_attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=extended_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1395,41 +1387,23 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - extended_attention_mask, - encoder_hidden_states, - extended_encoder_attention_mask, - (head_mask[idx] if head_mask is not None else None), - (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - extended_predict_attention_mask, - main_relative_position_buckets, - 
predict_relative_position_buckets, - position_ids, - None, - use_cache, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=extended_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attn_mask=extended_encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - extended_predict_attention_mask=extended_predict_attention_mask, - main_relative_position_buckets=main_relative_position_buckets, - predict_relative_position_buckets=predict_relative_position_buckets, - position_ids=position_ids, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attn_mask=extended_encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From b083c86011f91091e99deaa300f106b1da2dec86 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:00:50 +0000 Subject: [PATCH 046/146] qwen2* --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 45 ++++++------------- .../qwen2_audio/modeling_qwen2_audio.py | 24 ++++------ .../models/qwen2_moe/modeling_qwen2_moe.py | 39 ++++++---------- 3 files changed, 35 insertions(+), 73 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 0001236c343a..acdf046c0423 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -40,6 +40,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging @@ -314,7 +315,7 @@ def forward( } -class Qwen2_5_VLVisionBlock(nn.Module): +class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) @@ -516,12 +517,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. 
cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) @@ -991,30 +987,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 6569d78674e9..b36cc2eff726 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import auto_docstring, logging from ..auto import AutoModel, AutoModelForCausalLM @@ -200,7 +201,7 @@ def forward( # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO -class Qwen2AudioEncoderLayer(nn.Module): +class Qwen2AudioEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2AudioConfig): super().__init__() self.embed_dim = config.d_model @@ -436,21 +437,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py 
b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 9882ca447d7e..da4af7c89ecb 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -40,6 +40,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_qwen2_moe import Qwen2MoeConfig @@ -634,7 +635,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -class Qwen2MoeDecoderLayer(nn.Module): +class Qwen2MoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2MoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -860,31 +861,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] From 429ba1171259e5d14499c6626c09f7a6704fcc69 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:01:49 +0000 Subject: [PATCH 047/146] qwen2, qwen3 moe, rec gemma --- .../models/qwen2_vl/modeling_qwen2_vl.py | 45 ++++++------------- .../models/qwen3_moe/modeling_qwen3_moe.py | 41 ++++++----------- .../modeling_recurrent_gemma.py | 10 ++--- 3 files changed, 31 insertions(+), 65 deletions(-) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 95687334eb0f..0e4a521169b8 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -37,6 +37,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -439,7 +440,7 @@ def forward( } -class Qwen2VLVisionBlock(nn.Module): +class Qwen2VLVisionBlock(GradientCheckpointingLayer): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = LayerNorm(config.embed_dim, eps=1e-6) @@ -837,12 +838,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. 
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) for blk in self.blocks: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) return self.merger(hidden_states) @@ -959,30 +955,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 67f21d1b8368..f332b16a7e99 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -41,6 +41,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging @@ -286,7 +287,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Qwen3MoeDecoderLayer(nn.Module): +class Qwen3MoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3MoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -541,32 +542,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + 
output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 47e79f34870a..8b6e37462d75 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -26,6 +26,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import auto_docstring, logging @@ -471,7 +472,7 @@ def forward(self, hidden_states): return self.down_proj(gate * self.up_proj(hidden_states)) -class RecurrentGemmaDecoderLayer(nn.Module): +class RecurrentGemmaDecoderLayer(GradientCheckpointingLayer): """Griffin and Hawk's residual block.""" def __init__(self, config, layer_idx): @@ -648,12 +649,7 @@ def forward( for i, residual_block in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - residual_block.__call__, hidden_states, position_ids, causal_mask, cache_position, use_cache - ) - else: - hidden_states = residual_block(hidden_states, position_ids, causal_mask, cache_position, use_cache) + hidden_states = residual_block(hidden_states, position_ids, causal_mask, cache_position, use_cache) hidden_states = self.final_norm(hidden_states) From cec0d3242f653c68747678d9a88a6e733c098b63 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:02:35 +0000 Subject: [PATCH 048/146] rembert --- .../models/rembert/modeling_rembert.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 22a3ba7aeb96..3197cca5fafe 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -35,6 +35,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -399,7 +400,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class RemBertLayer(nn.Module): +class RemBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -528,27 +529,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - 
layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From bec1fcd92a0d422ce32c44cb1121091dad9840b9 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:02:49 +0000 Subject: [PATCH 049/146] roberta --- .../models/roberta/modeling_roberta.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 9381c8f9ab01..e0440f3f48cf 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -37,6 +37,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging @@ -477,7 +478,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta -class RobertaLayer(nn.Module): +class RobertaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -603,27 +604,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From 254882fb2527c97026996f4baebfc83b67ad6d27 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:03:09 +0000 Subject: [PATCH 050/146] roberta prelayernorm --- .../modeling_roberta_prelayernorm.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 31d459e7d8b6..2a7b22627775 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -35,6 +35,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -365,7 +366,7 @@ 
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RobertaPreLayerNorm -class RobertaPreLayerNormLayer(nn.Module): +class RobertaPreLayerNormLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -491,27 +492,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From a1a7fdaa4a054a11b39617f1dafa0e4ad5cfdccc Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:03:53 +0000 Subject: [PATCH 051/146] roc_bert, roformer, rwkv --- .../models/roc_bert/modeling_roc_bert.py | 33 ++++++----------- .../models/roformer/modeling_roformer.py | 36 +++++++------------ src/transformers/models/rwkv/modeling_rwkv.py | 14 +++----- 3 files changed, 28 insertions(+), 55 deletions(-) diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 06721ae7719b..ca7a2e949c7e 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -35,6 +35,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -488,7 +489,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RoCBert -class RoCBertLayer(nn.Module): +class RoCBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -614,27 +615,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git 
a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 3f9b2875c207..6a5211de647f 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -35,6 +35,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -423,7 +424,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class RoFormerLayer(nn.Module): +class RoFormerLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -558,29 +559,16 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - sinusoidal_pos, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - sinusoidal_pos, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index b7362000b572..8f0453981026 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -25,6 +25,7 @@ from torch import nn from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -344,7 +345,7 @@ def forward(self, hidden, state=None): return receptance * value, state -class RwkvBlock(nn.Module): +class RwkvBlock(GradientCheckpointingLayer): def __init__(self, config, layer_id): super().__init__() self.config = config @@ -604,14 +605,9 @@ def forward( all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for idx, block in enumerate(self.blocks): - if self.gradient_checkpointing and self.training: - hidden_states, state, attentions = self._gradient_checkpointing_func( - block.__call__, hidden_states, state, use_cache, output_attentions - ) - else: - hidden_states, state, attentions = block( - hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions - ) + hidden_states, state, attentions = block( + hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions + ) if ( self.layers_are_rescaled From d497df9822f70daa404461f7b449a3061252f1be Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:05:10 +0000 Subject: [PATCH 052/146] sam, sam_hq --- src/transformers/models/sam/modeling_sam.py | 11 +++-------- src/transformers/models/sam_hq/modeling_sam_hq.py | 11 +++-------- 2 files 
changed, 6 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index a9088958a8f4..24d891e3746b 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -969,7 +970,7 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch } -class SamVisionLayer(nn.Module): +class SamVisionLayer(GradientCheckpointingLayer): def __init__(self, config, window_size): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -1145,13 +1146,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index 203911698556..bdde2599d697 100644 --- a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig @@ -364,7 +365,7 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch } -class SamHQVisionLayer(nn.Module): +class SamHQVisionLayer(GradientCheckpointingLayer): def __init__(self, config, window_size): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -543,13 +544,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] From 987a880897339bd177f39a981bf3b3e635fbf4b5 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:07:08 +0000 Subject: [PATCH 053/146] seggpt, smolvlm, speech_to_text --- .../models/seggpt/modeling_seggpt.py | 14 +--- .../models/smolvlm/modeling_smolvlm.py | 21 ++---- .../speech_to_text/modeling_speech_to_text.py | 66 +++++++------------ 3 files changed, 32 insertions(+), 69 deletions(-) diff --git a/src/transformers/models/seggpt/modeling_seggpt.py b/src/transformers/models/seggpt/modeling_seggpt.py index 69c5ce88f7a9..cf6b6db3f2a6 100644 --- a/src/transformers/models/seggpt/modeling_seggpt.py +++ b/src/transformers/models/seggpt/modeling_seggpt.py @@ -24,6 +24,7 @@ from torch.nn import functional as F from ...activations import 
ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging, torch_int from .configuration_seggpt import SegGptConfig @@ -395,7 +396,7 @@ def extra_repr(self) -> str: return f"p={self.drop_prob}" -class SegGptLayer(nn.Module): +class SegGptLayer(GradientCheckpointingLayer): def __init__(self, config: SegGptConfig, drop_path_rate: float) -> None: super().__init__() self.attention = SegGptAttention(config) @@ -470,16 +471,7 @@ def forward( # Condition to check if we have the appropriate number of prompts to ensemble ensemble_cond = 2 if self.config.merge_index > i else 1 - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ensemble_cond, - feature_ensemble, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, ensemble_cond, feature_ensemble, output_attentions) + layer_outputs = layer_module(hidden_states, ensemble_cond, feature_ensemble, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 383450aae1f9..0c39e64a5bb6 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -31,6 +31,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging @@ -239,7 +240,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class SmolVLMEncoderLayer(nn.Module): +class SmolVLMEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: SmolVLMVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -346,19 +347,11 @@ def forward( for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index d9fbf2faec91..0791976dcf66 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -36,6 +36,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -328,7 +329,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT -class Speech2TextEncoderLayer(nn.Module): +class 
Speech2TextEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model @@ -398,7 +399,7 @@ def forward( # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT # TODO: change copy when applying cache class -class Speech2TextDecoderLayer(nn.Module): +class Speech2TextDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model @@ -693,21 +694,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -941,33 +933,19 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: From 6ef90e11b456bd236e9bb13777cd28174e6a717b Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:07:58 +0000 Subject: [PATCH 054/146] splinter, stablelm, swin --- .../models/splinter/modeling_splinter.py | 33 ++++++----------- .../models/stablelm/modeling_stablelm.py | 36 +++++++------------ src/transformers/models/swin/modeling_swin.py | 19 +++------- 3 files changed, 28 insertions(+), 60 deletions(-) diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 06d2917b6a27..4332713b6549 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...modeling_outputs import 
BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -330,7 +331,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter -class SplinterLayer(nn.Module): +class SplinterLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -456,27 +457,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index df4e41bcd210..d98f40597bf9 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -38,6 +38,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_stablelm import StableLmConfig @@ -519,7 +520,7 @@ def forward( } -class StableLmDecoderLayer(nn.Module): +class StableLmDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: StableLmConfig, layer_idx: int): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -744,29 +745,16 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/swin/modeling_swin.py 
b/src/transformers/models/swin/modeling_swin.py index a8c29e84785a..68daf506f628 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -739,7 +740,7 @@ def forward( return layer_outputs -class SwinStage(nn.Module): +class SwinStage(GradientCheckpointingLayer): def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() self.config = config @@ -848,19 +849,9 @@ def forward( for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - always_partition, - ) - else: - layer_outputs = layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition - ) + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) hidden_states = layer_outputs[0] hidden_states_before_downsampling = layer_outputs[1] From 1b7cc3f688e5edffa3c5d5b6f09526c9151778c1 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:13:04 +0000 Subject: [PATCH 055/146] swin2sr, switch_transformer, t5, table_transformer --- .../models/swin2sr/modeling_swin2sr.py | 10 ++-- .../modeling_switch_transformers.py | 54 +++++++------------ src/transformers/models/t5/modeling_t5.py | 51 ++++++------------ .../modeling_table_transformer.py | 31 ++++------- 4 files changed, 49 insertions(+), 97 deletions(-) diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index c63579a014f5..e083b28a6dd0 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging @@ -592,7 +593,7 @@ def forward( return layer_outputs -class Swin2SRStage(nn.Module): +class Swin2SRStage(GradientCheckpointingLayer): """ This corresponds to the Residual Swin Transformer Block (RSTB) in the original implementation. 
""" @@ -705,12 +706,7 @@ def forward( for i, stage_module in enumerate(self.stages): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions - ) - else: - layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] output_dimensions = layer_outputs[1] diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 60dc8fb7a55a..3a25906a19b2 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -33,6 +33,7 @@ Seq2SeqMoEModelOutput, Seq2SeqMoEOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -664,7 +665,7 @@ def forward( return outputs -class SwitchTransformersBlock(nn.Module): +class SwitchTransformersBlock(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, is_sparse=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -1027,41 +1028,22 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - output_router_logits, - return_dict, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - return_dict=return_dict, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) router_probs = layer_outputs[-1] layer_outputs = layer_outputs[:-1] diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 
5339e8e2d904..fcc528579a53 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -37,6 +37,7 @@ Seq2SeqSequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -647,7 +648,7 @@ def forward( return outputs -class T5Block(nn.Module): +class T5Block(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -1103,39 +1104,21 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - return_dict, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 4938ea378dfb..321d3f176d9e 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -618,7 +619,7 @@ def forward( return outputs -class TableTransformerDecoderLayer(nn.Module): +class TableTransformerDecoderLayer(GradientCheckpointingLayer): # Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer.__init__ with Detr->TableTransformer def __init__(self, config: 
TableTransformerConfig): super().__init__() @@ -989,25 +990,15 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - combined_attention_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From 5331bc2876116ccd59264486d32351abec6b40d2 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:15:31 +0000 Subject: [PATCH 056/146] tapas, time_series_tranformer, timesformer --- .../models/tapas/modeling_tapas.py | 33 +++------ .../modeling_time_series_transformer.py | 69 +++++++------------ .../timesformer/modeling_timesformer.py | 12 +--- 3 files changed, 37 insertions(+), 77 deletions(-) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index f6463d95a24d..d6dfde622ea9 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging @@ -475,7 +476,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class TapasLayer(nn.Module): +class TapasLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -591,27 +592,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_values, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_values, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_values, + output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 83b69af4a10b..261e5c4ff765 100644 --- 
a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -37,6 +37,7 @@ Seq2SeqTSModelOutput, Seq2SeqTSPredictionOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput @@ -435,7 +436,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER -class TimeSeriesTransformerEncoderLayer(nn.Module): +class TimeSeriesTransformerEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -507,7 +508,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER -class TimeSeriesTransformerDecoderLayer(nn.Module): +class TimeSeriesTransformerDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -857,21 +858,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1066,35 +1058,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + 
cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 00592039a920..9f4ee262430b 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, @@ -288,7 +289,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L89 -class TimesformerLayer(nn.Module): +class TimesformerLayer(GradientCheckpointingLayer): def __init__(self, config: TimesformerConfig, layer_index: int) -> None: super().__init__() @@ -432,14 +433,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions) hidden_states = layer_outputs[0] From dfe3d8d878e705c64150004b2e19eeeff9b5fc03 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:17:37 +0000 Subject: [PATCH 057/146] trocr, tvp, umt5 --- .../models/trocr/modeling_trocr.py | 43 +++++++----------- src/transformers/models/tvp/modeling_tvp.py | 14 ++---- src/transformers/models/umt5/modeling_umt5.py | 44 +++++++------------ 3 files changed, 33 insertions(+), 68 deletions(-) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 11b5b4a415f8..d9ce3726fb38 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -29,6 +29,7 @@ _prepare_4d_causal_attention_mask, ) from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_trocr import TrOCRConfig @@ -288,7 +289,7 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -class TrOCRDecoderLayer(nn.Module): +class TrOCRDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: TrOCRConfig): super().__init__() self.embed_dim = config.hidden_size @@ -643,33 +644,19 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if 
head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 01932573f01b..f1a9421117a4 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import prune_linear_layer from ...utils import auto_docstring, logging @@ -455,7 +456,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class TvpEncodeLayer(nn.Module): +class TvpEncodeLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = TvpAttention(config) @@ -511,16 +512,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - (head_mask[i] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], output_attentions) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index d5bb6718baf1..cc988e848972 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -35,6 +35,7 @@ Seq2SeqSequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( DUMMY_INPUTS, @@ -406,7 +407,7 @@ def forward( return outputs -class UMT5Block(nn.Module): +class UMT5Block(GradientCheckpointingLayer): def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -765,35 +766,20 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_extended_attention_mask, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - 
encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) - hidden_states = layer_outputs[0] + hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[1] From c001253210ed2335dfab7cde28bc2f8dc8c32b39 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:19:04 +0000 Subject: [PATCH 058/146] videomae, vilt, visual_bert --- .../models/videomae/modeling_videomae.py | 23 ++++--------------- src/transformers/models/vilt/modeling_vilt.py | 14 +++-------- .../visual_bert/modeling_visual_bert.py | 14 +++-------- 3 files changed, 10 insertions(+), 41 deletions(-) diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index c418a3b49c54..0bdaa35833c9 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -389,7 +390,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE -class VideoMAELayer(nn.Module): +class VideoMAELayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: VideoMAEConfig) -> None: @@ -456,15 +457,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] @@ -698,15 +691,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - None, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 6ce00d9f397c..a98fa55e0364 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ 
b/src/transformers/models/vilt/modeling_vilt.py @@ -33,6 +33,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import auto_docstring, logging @@ -456,7 +457,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class ViltLayer(nn.Module): +class ViltLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config): @@ -519,16 +520,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index ee228e5f4282..9da2ddf8f2a6 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -30,6 +30,7 @@ MultipleChoiceModelOutput, SequenceClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging @@ -330,7 +331,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class VisualBertLayer(nn.Module): +class VisualBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -394,16 +395,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: From 76dd7a54ec7c5c1e486faf35aaabb124a2c2825f Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:20:10 +0000 Subject: [PATCH 059/146] vit, vit_mae, vit_msn --- src/transformers/models/vit/modeling_vit.py | 13 +++-------- .../models/vit_mae/modeling_vit_mae.py | 23 ++++--------------- .../models/vit_msn/modeling_vit_msn.py | 13 +++-------- 3 files changed, 10 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 6e85320a0fa8..9e298dab1185 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -30,6 +30,7 @@ ImageClassifierOutput, MaskedImageModelingOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -346,7 +347,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class ViTLayer(nn.Module): +class ViTLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTConfig) -> None: @@ -412,15 +413,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 61cf29f8569a..de5246224de5 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -531,7 +532,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE -class ViTMAELayer(nn.Module): +class ViTMAELayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTMAEConfig) -> None: @@ -598,15 +599,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] @@ -864,15 +857,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - None, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 25efc1ac4de6..aade45a5f954 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -349,7 +350,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN -class ViTMSNLayer(nn.Module): +class ViTMSNLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTMSNConfig) -> None: @@ -416,15 +417,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] From 0bc5335371058fbbd32c11944e128b951f649b0f Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:21:09 +0000 Subject: [PATCH 060/146] vitpose_backbone, vits, vivit --- .../modeling_vitpose_backbone.py | 14 +++-------- src/transformers/models/vits/modeling_vits.py | 24 +++++++------------ .../models/vivit/modeling_vivit.py | 13 +++------- 3 files changed, 14 insertions(+), 37 deletions(-) diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py index 7579ecb9fbb1..6594dfa3a162 100644 --- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput, BaseModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -302,7 +303,7 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: return hidden_state -class VitPoseBackboneLayer(nn.Module): +class VitPoseBackboneLayer(GradientCheckpointingLayer): def __init__(self, config: VitPoseBackboneConfig) -> None: super().__init__() self.num_experts = config.num_experts @@ -377,16 +378,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - dataset_index, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 65b77a4ccf2f..2045397be372 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -28,6 +28,7 @@ from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_layers import 
GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_vits import VitsConfig @@ -1067,7 +1068,7 @@ def forward(self, hidden_states, padding_mask): return hidden_states -class VitsEncoderLayer(nn.Module): +class VitsEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: VitsConfig): super().__init__() self.attention = VitsAttention(config) @@ -1145,21 +1146,12 @@ def forward( skip_the_layer = self.training and (dropout_probability < self.layerdrop) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - padding_mask, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - padding_mask=padding_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + padding_mask=padding_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 0617e20de3b9..48b0fd852273 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int @@ -342,7 +343,7 @@ def forward(self, hidden_states, input_tensor): return hidden_states -class VivitLayer(nn.Module): +class VivitLayer(GradientCheckpointingLayer): """This corresponds to the EncoderBlock class in the scenic/vivit implementation.""" def __init__(self, config): @@ -405,15 +406,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] From 43992d9ab474022d90c9c5e6cb1ffee78c0c26fa Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:24:50 +0000 Subject: [PATCH 061/146] whisper. 
x_clip, xglm --- .../models/whisper/modeling_whisper.py | 67 ++++++------------- .../models/x_clip/modeling_x_clip.py | 47 +++++-------- src/transformers/models/xglm/modeling_xglm.py | 43 +++++------- 3 files changed, 52 insertions(+), 105 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 613f5fb45a72..8d9da86a297c 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -38,6 +38,7 @@ Seq2SeqModelOutput, SequenceClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging @@ -366,7 +367,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER -class WhisperEncoderLayer(nn.Module): +class WhisperEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model @@ -434,7 +435,7 @@ def forward( return outputs -class WhisperDecoderLayer(nn.Module): +class WhisperDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: WhisperConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -713,21 +714,12 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - None, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - None, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -964,34 +956,19 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - None, # encoder attention mask - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, # past_key_value - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values if use_cache else None, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values if use_cache else None, + output_attentions=output_attentions, + use_cache=use_cache, + 
cache_position=cache_position, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 41db6f5ce854..3632c67f93a5 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( ModelOutput, @@ -338,7 +339,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->XCLIP -class XCLIPEncoderLayer(nn.Module): +class XCLIPEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: XCLIPConfig): super().__init__() self.embed_dim = config.hidden_size @@ -424,7 +425,7 @@ def extra_repr(self) -> str: return f"p={self.drop_prob}" -class XCLIPVisionEncoderLayer(nn.Module): +class XCLIPVisionEncoderLayer(GradientCheckpointingLayer): """ This corresponds to the `CrossFramelAttentionBlock` class in the original implementation. """ @@ -625,21 +626,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -842,21 +834,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 08520e5d3abf..66589a48871c 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -25,6 +25,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_xglm import XGLMConfig @@ -253,7 +254,7 @@ def forward( return attn_output, 
attn_weights_reshaped, past_key_value -class XGLMDecoderLayer(nn.Module): +class XGLMDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: XGLMConfig): super().__init__() self.embed_dim = config.d_model @@ -547,33 +548,19 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: From 461961b4717fe4d01fe63da877f104ed9d40472f Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:25:37 +0000 Subject: [PATCH 062/146] xlm_roberta, xmod --- .../xlm_roberta/modeling_xlm_roberta.py | 33 ++++++----------- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 33 ++++++----------- src/transformers/models/xmod/modeling_xmod.py | 36 +++++++------------ 3 files changed, 34 insertions(+), 68 deletions(-) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 323e97a3e534..9c79c3b664bd 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -37,6 +37,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging @@ -478,7 +479,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->XLMRoberta -class XLMRobertaLayer(nn.Module): +class XLMRobertaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -604,27 +605,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, 
- encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index a0162b8252b6..a108a6a144a3 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -36,6 +36,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging @@ -469,7 +470,7 @@ def forward(self, hidden_states, input_tensor): return hidden_states -class XLMRobertaXLLayer(nn.Module): +class XLMRobertaXLLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -596,27 +597,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 2d794aa6580f..0f7ff562f1e1 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -34,6 +34,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging @@ -423,7 +424,7 @@ def lang_adapter(self, lang_ids: torch.Tensor, hidden_states: torch.Tensor): return hidden_states -class XmodLayer(nn.Module): +class XmodLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -560,29 +561,16 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - lang_ids, - 
attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - lang_ids, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + lang_ids, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From cf470fd04dd52592b4b3b926c74226f69221259c Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:26:22 +0000 Subject: [PATCH 063/146] yoso --- src/transformers/models/yoso/modeling_yoso.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index be705885988e..95a7163d4e0a 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -32,6 +32,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -507,7 +508,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class YosoLayer(nn.Module): +class YosoLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -559,15 +560,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: From 626dde009b3c701a761bf985139af24e8b999097 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:37:11 +0000 Subject: [PATCH 064/146] zamba --- .../models/zamba/modeling_zamba.py | 41 +++++++------------ .../models/zamba2/modeling_zamba2.py | 41 +++++++------------ 2 files changed, 28 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index ea832692b7b9..f2eca2a1f1e2 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -580,7 +581,7 @@ def forward(self, x): return down_proj -class ZambaAttentionDecoderLayer(nn.Module): +class ZambaAttentionDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None): 
super().__init__() self.self_attn = ZambaAttention(config, layer_idx) @@ -643,7 +644,7 @@ def forward( return outputs -class ZambaMambaDecoderLayer(nn.Module): +class ZambaMambaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ZambaConfig, layer_idx: int): super().__init__() self.mamba = ZambaMambaMixer(config=config, layer_idx=layer_idx) @@ -975,31 +976,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - original_hidden_states, - layer_idx, - attention_mask, - causal_mask, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = layer( - hidden_states, - original_hidden_states=original_hidden_states, - layer_idx=layer_idx, - attention_mask=attention_mask, - causal_mask=causal_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = layer( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + causal_mask=causal_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index ecd0abcb0263..64d2d256d7fc 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -35,6 +35,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging @@ -981,7 +982,7 @@ def forward(self, hidden_state, layer_idx=None): return output -class Zamba2AttentionDecoderLayer(nn.Module): +class Zamba2AttentionDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Zamba2Config, block_id: Optional[int] = None, layer_idx: Optional[int] = None): super().__init__() self.block_id = block_id @@ -1045,7 +1046,7 @@ def forward( return outputs -class Zamba2MambaDecoderLayer(nn.Module): +class Zamba2MambaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Zamba2Config, layer_idx: int): super().__init__() self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx) @@ -1349,31 +1350,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - original_hidden_states, - layer_idx, - attention_mask, - causal_mask, - past_key_values, - output_attentions, - use_cache, - position_embeddings, - ) - else: - layer_outputs = layer( - hidden_states, - original_hidden_states=original_hidden_states, - layer_idx=layer_idx, - attention_mask=attention_mask, - causal_mask=causal_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - position_embeddings=position_embeddings, - ) + layer_outputs = 
layer( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + causal_mask=causal_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] if output_attentions: From 59f8879fb35070869ff8017d9c5c5f042e91ccc5 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 15:40:08 +0000 Subject: [PATCH 065/146] vitdet, wav2vec2, wav2vec2_bert --- .../models/vitdet/modeling_vitdet.py | 13 ++---- .../models/wav2vec2/modeling_wav2vec2.py | 41 +++++-------------- .../wav2vec2_bert/modeling_wav2vec2_bert.py | 27 ++++-------- 3 files changed, 22 insertions(+), 59 deletions(-) diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index e13e36d08e29..0ba71fae8b96 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput, BaseModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from ...utils.backbone_utils import BackboneMixin @@ -439,7 +440,7 @@ def window_unpartition(windows, window_size, pad_height_width, height_width): return hidden_state -class VitDetLayer(nn.Module): +class VitDetLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the original implementation.""" def __init__( @@ -560,15 +561,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 9ed86d274e23..d234b024da1d 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -42,6 +42,7 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -434,13 +435,7 @@ def forward(self, input_values): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -648,7 +643,7 @@ def forward(self, hidden_states): return hidden_states -class Wav2Vec2EncoderLayer(nn.Module): +class Wav2Vec2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = Wav2Vec2Attention( @@ -684,7 +679,7 @@ def forward(self, hidden_states, attention_mask=None, output_attentions=False): return outputs -class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): +class 
Wav2Vec2EncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = Wav2Vec2Attention( @@ -778,17 +773,9 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -882,17 +869,9 @@ def forward( if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index b8a60f3d3d0e..4491c4dc40a3 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -25,6 +25,7 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_peft_available from .configuration_wav2vec2_bert import Wav2Vec2BertConfig @@ -394,7 +395,7 @@ def _apply_relative_embeddings(self, query, key, relative_position_embeddings): return scores -class Wav2Vec2BertEncoderLayer(nn.Module): +class Wav2Vec2BertEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" def __init__(self, config): @@ -520,23 +521,13 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - conv_attention_mask, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - conv_attention_mask=conv_attention_mask, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) hidden_states = layer_outputs[0] if skip_the_layer: From b89a5dbd8d601fa5f40073f850a6cd0b12ac2c32 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 
2025 16:50:05 +0000 Subject: [PATCH 066/146] unispeech, wav2vec_conformer --- .../models/unispeech/modeling_unispeech.py | 47 +++++-------------- .../unispeech_sat/modeling_unispeech_sat.py | 47 +++++-------------- .../modeling_wav2vec2_conformer.py | 38 +++++---------- 3 files changed, 38 insertions(+), 94 deletions(-) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 02fa44891f51..cd5df6f43853 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -41,6 +41,7 @@ SequenceClassifierOutput, Wav2Vec2BaseModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging @@ -146,7 +147,7 @@ def forward(self, hidden_states): return hidden_states -class UniSpeechNoLayerNormConvLayer(nn.Module): +class UniSpeechNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -167,7 +168,7 @@ def forward(self, hidden_states): return hidden_states -class UniSpeechLayerNormConvLayer(nn.Module): +class UniSpeechLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -194,7 +195,7 @@ def forward(self, hidden_states): return hidden_states -class UniSpeechGroupNormConvLayer(nn.Module): +class UniSpeechGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -254,13 +255,7 @@ def forward(self, input_values): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -456,7 +451,7 @@ def forward(self, hidden_states): return hidden_states -class UniSpeechEncoderLayer(nn.Module): +class UniSpeechEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = UniSpeechAttention( @@ -540,17 +535,9 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -618,7 +605,7 @@ def forward(self, hidden_states: torch.FloatTensor): return hidden_states -class UniSpeechEncoderLayerStableLayerNorm(nn.Module): +class UniSpeechEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = 
UniSpeechAttention( @@ -714,17 +701,9 @@ def forward( if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 72375f8e904f..d260cba8800b 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -43,6 +43,7 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging @@ -149,7 +150,7 @@ def forward(self, hidden_states): return hidden_states -class UniSpeechSatNoLayerNormConvLayer(nn.Module): +class UniSpeechSatNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -170,7 +171,7 @@ def forward(self, hidden_states): return hidden_states -class UniSpeechSatLayerNormConvLayer(nn.Module): +class UniSpeechSatLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -197,7 +198,7 @@ def forward(self, hidden_states): return hidden_states -class UniSpeechSatGroupNormConvLayer(nn.Module): +class UniSpeechSatGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -257,13 +258,7 @@ def forward(self, input_values): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -459,7 +454,7 @@ def forward(self, hidden_states): return hidden_states -class UniSpeechSatEncoderLayer(nn.Module): +class UniSpeechSatEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = UniSpeechSatAttention( @@ -543,17 +538,9 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, 
output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -621,7 +608,7 @@ def forward(self, hidden_states: torch.FloatTensor): return hidden_states -class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): +class UniSpeechSatEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = UniSpeechSatAttention( @@ -717,17 +704,9 @@ def forward( if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 70042f7b93b9..a01fa908f2fb 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -25,6 +25,7 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -216,7 +217,7 @@ def forward(self, hidden_states: torch.Tensor): return relative_position_embeddings -class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module): +class Wav2Vec2ConformerNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -237,7 +238,7 @@ def forward(self, hidden_states): return hidden_states -class Wav2Vec2ConformerLayerNormConvLayer(nn.Module): +class Wav2Vec2ConformerLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -264,7 +265,7 @@ def forward(self, hidden_states): return hidden_states -class Wav2Vec2ConformerGroupNormConvLayer(nn.Module): +class Wav2Vec2ConformerGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -324,13 +325,7 @@ def forward(self, input_values): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -582,7 +577,7 @@ def _apply_relative_embeddings(self, query, key, relative_position_embeddings): return scores -class Wav2Vec2ConformerEncoderLayer(nn.Module): +class Wav2Vec2ConformerEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" def 
__init__(self, config): @@ -709,21 +704,12 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if skip_the_layer: From db524cbefe4d9089a1e028de32a269b0886007a8 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 16:52:44 +0000 Subject: [PATCH 067/146] wavlm --- .../models/wavlm/modeling_wavlm.py | 63 ++++++------------- 1 file changed, 20 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index d718d4958a1b..3b9ec3a9cac3 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -25,6 +25,7 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_peft_available, logging from .configuration_wavlm import WavLMConfig @@ -294,7 +295,7 @@ def forward(self, hidden_states): return hidden_states -class WavLMEncoderLayer(nn.Module): +class WavLMEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True): super().__init__() self.attention = WavLMAttention( @@ -335,7 +336,7 @@ def forward(self, hidden_states, attention_mask=None, position_bias=None, output return outputs -class WavLMEncoderLayerStableLayerNorm(nn.Module): +class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True): super().__init__() self.attention = WavLMAttention( @@ -418,22 +419,13 @@ def forward( skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - position_bias, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - index=i, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + index=i, + ) hidden_states, position_bias = layer_outputs[:2] @@ -504,21 +496,12 @@ def forward( if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - 
attention_mask, - position_bias, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - position_bias=position_bias, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + position_bias=position_bias, + ) hidden_states, position_bias = layer_outputs[:2] if skip_the_layer: @@ -696,7 +679,7 @@ def _get_feature_vector_attention_mask( return attention_mask -class WavLMNoLayerNormConvLayer(nn.Module): +class WavLMNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -717,7 +700,7 @@ def forward(self, hidden_states): return hidden_states -class WavLMLayerNormConvLayer(nn.Module): +class WavLMLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -744,7 +727,7 @@ def forward(self, hidden_states): return hidden_states -class WavLMGroupNormConvLayer(nn.Module): +class WavLMGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -801,13 +784,7 @@ def forward(self, input_values): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states From 96db85e66e01b8b2085e57bcc95e8fb8eb3fc006 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 16:54:48 +0000 Subject: [PATCH 068/146] speecht5 --- .../models/speecht5/modeling_speecht5.py | 83 ++++++------------- 1 file changed, 27 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index c63426468d2d..92fa3c95c9e1 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -35,6 +35,7 @@ Seq2SeqModelOutput, Seq2SeqSpectrogramOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig @@ -207,7 +208,7 @@ def compute_num_masked_span(input_length): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SpeechT5 -class SpeechT5NoLayerNormConvLayer(nn.Module): +class SpeechT5NoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -229,7 +230,7 @@ def forward(self, hidden_states): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SpeechT5 -class SpeechT5LayerNormConvLayer(nn.Module): +class SpeechT5LayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -257,7 +258,7 @@ def forward(self, hidden_states): # Copied from 
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SpeechT5 -class SpeechT5GroupNormConvLayer(nn.Module): +class SpeechT5GroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -487,13 +488,7 @@ def forward(self, input_values): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -1032,7 +1027,7 @@ def forward(self, hidden_states): return hidden_states -class SpeechT5EncoderLayer(nn.Module): +class SpeechT5EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: SpeechT5Config): super().__init__() self.attention = SpeechT5Attention( @@ -1093,7 +1088,7 @@ def forward( return outputs -class SpeechT5DecoderLayer(nn.Module): +class SpeechT5DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: SpeechT5Config): super().__init__() self.self_attn = SpeechT5Attention( @@ -1338,23 +1333,13 @@ def forward( if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - position_bias, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -1636,33 +1621,19 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + 
output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: From 279041bb9ec0abd3be7f99746ff6e5234b287226 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 16:55:27 +0000 Subject: [PATCH 069/146] swinv2 --- .../models/swinv2/modeling_swinv2.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 050e8d3fd271..8cf738513759 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -787,7 +788,7 @@ def forward( return layer_outputs -class Swinv2Stage(nn.Module): +class Swinv2Stage(GradientCheckpointingLayer): def __init__( self, config, dim, input_resolution, depth, num_heads, drop_path, downsample, pretrained_window_size=0 ): @@ -902,17 +903,12 @@ def forward( for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, hidden_states, input_dimensions, layer_head_mask - ) - else: - layer_outputs = layer_module( - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + ) hidden_states = layer_outputs[0] hidden_states_before_downsampling = layer_outputs[1] From 5a3b5712cd873289118942774b32f45b61709ad3 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 16:57:49 +0000 Subject: [PATCH 070/146] sew / _d --- src/transformers/models/sew/modeling_sew.py | 31 ++++--------- .../models/sew_d/modeling_sew_d.py | 44 ++++++------------- 2 files changed, 23 insertions(+), 52 deletions(-) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 30949092bd0e..304c1bfd0982 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -33,6 +33,7 @@ from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging @@ -42,7 +43,7 @@ logger = logging.get_logger(__name__) -class SEWNoLayerNormConvLayer(nn.Module): +class SEWNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -63,7 +64,7 @@ def forward(self, hidden_states): return hidden_states -class SEWLayerNormConvLayer(nn.Module): +class SEWLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -90,7 +91,7 @@ def 
forward(self, hidden_states): return hidden_states -class SEWGroupNormConvLayer(nn.Module): +class SEWGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -223,13 +224,7 @@ def forward(self, input_values): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -410,7 +405,7 @@ def forward(self, hidden_states): return hidden_states -class SEWEncoderLayer(nn.Module): +class SEWEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = SEWAttention( @@ -521,17 +516,9 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index f2d682884c47..eb05ea9f59ab 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import softmax_backward_data from ...utils import auto_docstring, logging @@ -242,7 +243,7 @@ def get_mask(input, local_context): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEWD -class SEWDNoLayerNormConvLayer(nn.Module): +class SEWDNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -264,7 +265,7 @@ def forward(self, hidden_states): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEWD -class SEWDLayerNormConvLayer(nn.Module): +class SEWDLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -292,7 +293,7 @@ def forward(self, hidden_states): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEWD -class SEWDGroupNormConvLayer(nn.Module): +class SEWDGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -429,13 +430,7 @@ def forward(self, input_values): 
hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -930,7 +925,7 @@ def forward(self, hidden_states, input_tensor): return hidden_states -class SEWDLayer(nn.Module): +class SEWDLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = SEWDAttention(config) @@ -1087,25 +1082,14 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (output_states,) - if self.gradient_checkpointing and self.training: - output_states = self._gradient_checkpointing_func( - layer_module.__call__, - next_kv, - attention_mask, - query_states, - relative_pos, - rel_embeddings, - output_attentions, - ) - else: - output_states = layer_module( - next_kv, - attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - ) + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) if output_attentions: output_states, att_m = output_states From b1d78cde9885fa847aac98423927bc7f1defe3fc Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 17:01:43 +0000 Subject: [PATCH 071/146] seamless_mt4 / _v2 --- .../seamless_m4t/modeling_seamless_m4t.py | 79 +++++---------- .../modeling_seamless_m4t_v2.py | 99 ++++++------------- 2 files changed, 56 insertions(+), 122 deletions(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 24f02c3e6b4c..86f245e340ca 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -39,6 +39,7 @@ Seq2SeqModelOutput, Wav2Vec2BaseModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from .configuration_seamless_m4t import SeamlessM4TConfig @@ -610,7 +611,7 @@ def _apply_relative_embeddings(self, query, key, relative_position_embeddings): return scores -class SeamlessM4TConformerEncoderLayer(nn.Module): +class SeamlessM4TConformerEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerEncoderLayer.__init__ with Wav2Vec2->SeamlessM4T, attention_dropout->speech_encoder_dropout, torch.nn->nn @@ -743,23 +744,13 @@ def forward( ) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - conv_attention_mask, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - conv_attention_mask=conv_attention_mask, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + 
relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -1173,7 +1164,7 @@ def forward(self, hidden_states): return hidden_states -class SeamlessM4TEncoderLayer(nn.Module): +class SeamlessM4TEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4TConfig, encoder_ffn_dim=None, encoder_attention_heads=None): super().__init__() encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim @@ -1236,7 +1227,7 @@ def forward( return outputs -class SeamlessM4TDecoderLayer(nn.Module): +class SeamlessM4TDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim @@ -1691,19 +1682,11 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.forward, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1866,27 +1849,15 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 95c586bfd761..1fd255b31831 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -36,6 +36,7 @@ Seq2SeqModelOutput, Wav2Vec2BaseModelOutput, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from .configuration_seamless_m4t_v2 import SeamlessM4Tv2Config @@ -489,7 +490,7 @@ def forward( return attn_output, attn_weights -class SeamlessM4Tv2ConformerEncoderLayer(nn.Module): +class SeamlessM4Tv2ConformerEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerEncoderLayer.__init__ with Wav2Vec2->SeamlessM4Tv2, 
attention_dropout->speech_encoder_dropout, torch.nn->nn @@ -645,21 +646,12 @@ def forward( ) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - conv_attention_mask, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - conv_attention_mask=conv_attention_mask, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -1031,7 +1023,7 @@ def forward(self, hidden_states): # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TEncoderLayer with SeamlessM4T->SeamlessM4Tv2 -class SeamlessM4Tv2EncoderLayer(nn.Module): +class SeamlessM4Tv2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4Tv2Config, encoder_ffn_dim=None, encoder_attention_heads=None): super().__init__() encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim @@ -1095,7 +1087,7 @@ def forward( # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoderLayer with SeamlessM4T->SeamlessM4Tv2 -class SeamlessM4Tv2DecoderLayer(nn.Module): +class SeamlessM4Tv2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim @@ -1210,7 +1202,7 @@ def forward( return outputs -class SeamlessM4Tv2TextToUnitDecoderLayer(nn.Module): +class SeamlessM4Tv2TextToUnitDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim @@ -1760,19 +1752,11 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.forward, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1936,27 +1920,15 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + 
output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: @@ -2137,21 +2109,12 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - padding_mask, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - padding_mask=padding_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + padding_mask=padding_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: From 9a6d135a66b97cc965fe4f20c41f459cd3e943b3 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 17:08:37 +0000 Subject: [PATCH 072/146] deprecated models update --- .../models/deprecated/nezha/modeling_nezha.py | 33 +++----- .../open_llama/modeling_open_llama.py | 30 +++---- .../deprecated/qdqbert/modeling_qdqbert.py | 38 +++------ .../models/deprecated/realm/modeling_realm.py | 33 +++----- .../modeling_speech_to_text_2.py | 41 ++++------ .../modeling_trajectory_transformer.py | 14 +--- .../models/deprecated/tvlt/modeling_tvlt.py | 24 +----- .../vit_hybrid/modeling_vit_hybrid.py | 13 +--- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 78 +++++++------------ 9 files changed, 94 insertions(+), 210 deletions(-) diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index d1c3fd8dbaa5..28572381f8e5 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -36,6 +36,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ....utils import ( @@ -438,7 +439,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class NezhaLayer(nn.Module): +class NezhaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -563,27 +564,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 100d02e22855..5e99841df762 100644 --- 
a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -30,6 +30,7 @@ from ....activations import ACT2FN from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_open_llama import OpenLlamaConfig @@ -339,7 +340,7 @@ def forward( return attn_output, attn_weights, past_key_value -class OpenLlamaDecoderLayer(nn.Module): +class OpenLlamaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OpenLlamaConfig): super().__init__() self.hidden_size = config.hidden_size @@ -631,25 +632,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - None, - output_attentions, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index f30a07570095..61bc778e1349 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -37,6 +37,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ....utils import ( @@ -452,7 +453,7 @@ def forward(self, hidden_states, input_tensor): # Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert -class QDQBertLayer(nn.Module): +class QDQBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.seq_len_dim = 1 @@ -568,32 +569,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index 5714bf52a0ec..ec3df4672959 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -30,6 +30,7 @@ MaskedLMOutput, ModelOutput, ) +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings @@ -447,7 +448,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class RealmLayer(nn.Module): +class RealmLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -572,27 +573,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index ce4fdd1bb2a0..43bfc042155c 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -25,6 +25,7 @@ from ....activations import ACT2FN from ....modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, logging, replace_return_docstrings from .configuration_speech_to_text_2 import Speech2Text2Config @@ -263,7 +264,7 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -class 
Speech2Text2DecoderLayer(nn.Module): +class Speech2Text2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Speech2Text2Config): super().__init__() self.embed_dim = config.d_model @@ -612,31 +613,19 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index a06a52d9f3cc..fdfbecf7fe29 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import functional as F +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....utils import ( ModelOutput, @@ -346,7 +347,7 @@ def forward( return outputs -class Block(nn.Module): +class Block(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.ln1 = nn.LayerNorm(config.n_embd) @@ -540,16 +541,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - layer_past, - use_cache, - output_attentions, - ) - else: - outputs = block(hidden_states, layer_past, use_cache, output_attentions) + outputs = block(hidden_states, layer_past, use_cache, output_attentions) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py index d0083211bd30..61f38e215de7 100644 --- a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py +++ b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py @@ -27,6 +27,7 @@ from ....activations import ACT2FN from ....modeling_outputs import BaseModelOutput, SequenceClassifierOutput +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from 
....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ....utils import ( @@ -483,7 +484,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class TvltLayer(nn.Module): +class TvltLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config): @@ -546,16 +547,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] @@ -853,15 +845,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - None, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py index 553b0a7bb3bc..81af62c3a448 100644 --- a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py @@ -25,6 +25,7 @@ from ....activations import ACT2FN from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ....utils import ( @@ -390,7 +391,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to } -class ViTHybridLayer(nn.Module): +class ViTHybridLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTHybridConfig) -> None: @@ -457,15 +458,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py index e8b23b961e59..124525b7385f 100644 --- a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -27,6 +27,7 @@ from ....activations import ACT2FN from ....modeling_outputs import BaseModelOutput +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from 
....utils import ( ModelOutput, @@ -1090,7 +1091,7 @@ def get_predict_relative_pos_embeddings( return predict_relative_pos_embeddings -class XLMProphetNetEncoderLayer(nn.Module): +class XLMProphetNetEncoderLayer(GradientCheckpointingLayer): """ Encoder block for XLMProphetnet """ @@ -1133,7 +1134,7 @@ def forward( return outputs -class XLMProphetNetDecoderLayer(nn.Module): +class XLMProphetNetDecoderLayer(GradientCheckpointingLayer): """ Decoder block for XLMProphetnet """ @@ -1320,21 +1321,12 @@ def forward( if output_hidden_states: encoder_hidden_states = encoder_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - extended_attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=extended_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1554,41 +1546,23 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - extended_attention_mask, - encoder_hidden_states, - extended_encoder_attention_mask, - (head_mask[idx] if head_mask is not None else None), - (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - extended_predict_attention_mask, - main_relative_position_buckets, - predict_relative_position_buckets, - position_ids, - None, - use_cache, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=extended_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attn_mask=extended_encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - extended_predict_attention_mask=extended_predict_attention_mask, - main_relative_position_buckets=main_relative_position_buckets, - predict_relative_position_buckets=predict_relative_position_buckets, - position_ids=position_ids, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attn_mask=extended_encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From a18e2578de4ad69f18d309aa23b6f8a706384ed8 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 17:10:44 +0000 Subject: [PATCH 
073/146] bros --- src/transformers/models/bros/modeling_bros.py | 41 ++++++------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 94d3a9d985d9..d9c4d4e95e83 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -30,6 +30,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging from .configuration_bros import BrosConfig @@ -428,7 +429,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class BrosLayer(nn.Module): +class BrosLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -550,34 +551,16 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." - ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - bbox_pos_emb, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states=hidden_states, - bbox_pos_emb=bbox_pos_emb, - attention_mask=attention_mask, - head_mask=layer_head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - ) + layer_outputs = layer_module( + hidden_states=hidden_states, + bbox_pos_emb=bbox_pos_emb, + attention_mask=attention_mask, + head_mask=layer_head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: From 66d0a628c7fa944db7ef32370f1a152e6eddf700 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 17:12:44 +0000 Subject: [PATCH 074/146] gemma2, gemma3 --- .../models/gemma2/modeling_gemma2.py | 38 ++++++----------- .../models/gemma3/modeling_gemma3.py | 41 +++++++------------ 2 files changed, 27 insertions(+), 52 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 28a57ab9090b..c763a155b595 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -38,6 +38,7 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging from ...utils.deprecation import deprecate_kwarg @@ -238,7 +239,7 @@ def forward( return attn_output, attn_weights -class Gemma2DecoderLayer(nn.Module): +class Gemma2DecoderLayer(GradientCheckpointingLayer): 
def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -466,30 +467,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 8d78ca92b4ca..bd1315d9e974 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -33,6 +33,7 @@ from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update @@ -364,7 +365,7 @@ def forward( return attn_output, attn_weights -class Gemma3DecoderLayer(nn.Module): +class Gemma3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Gemma3TextConfig, layer_idx: int): super().__init__() self.config = config @@ -581,32 +582,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings_global, - position_embeddings_local, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + 
output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] From c0e5690a61597c3cd5a9b3b9d5147406ce0dddbc Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 17:29:37 +0000 Subject: [PATCH 075/146] got, hiera, hubert, llama4, mllama, oneformer, phi, olmoe, informer --- .../models/got_ocr2/modeling_got_ocr2.py | 11 +-- .../models/hiera/modeling_hiera.py | 10 +-- .../models/hubert/modeling_hubert.py | 47 +++-------- src/transformers/models/idefics/vision.py | 24 ++---- .../models/informer/modeling_informer.py | 80 +++++++------------ .../models/llama4/modeling_llama4.py | 63 +++++---------- .../models/mllama/modeling_mllama.py | 70 ++++++---------- .../models/olmoe/modeling_olmoe.py | 39 +++------ .../models/oneformer/modeling_oneformer.py | 8 +- src/transformers/models/phi/modeling_phi.py | 38 +++------ 10 files changed, 128 insertions(+), 262 deletions(-) diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index bcf30f585250..0a9b403053b8 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -33,6 +33,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple from ..auto import AutoModel @@ -192,7 +193,7 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch return outputs -class GotOcr2VisionLayer(nn.Module): +class GotOcr2VisionLayer(GradientCheckpointingLayer): def __init__(self, config, window_size): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -463,13 +464,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/hiera/modeling_hiera.py b/src/transformers/models/hiera/modeling_hiera.py index 2fadde33211e..a0664c80fcc5 100644 --- a/src/transformers/models/hiera/modeling_hiera.py +++ b/src/transformers/models/hiera/modeling_hiera.py @@ -32,6 +32,7 @@ ModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...utils import auto_docstring, logging, torch_int from ...utils.backbone_utils import BackboneMixin from .configuration_hiera import HieraConfig @@ -540,7 +541,7 @@ def forward( return (hidden_states, attn_weights) -class HieraStage(nn.Module): +class HieraStage(GradientCheckpointingLayer): def __init__( self, config, @@ -734,12 +735,7 @@ def forward( for i, stage_module in enumerate(self.stages): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - stage_module.__call__, hidden_states, layer_head_mask, output_attentions - ) - else: - layer_outputs = 
stage_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index faa0ff48c688..e51984c44900 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -34,6 +34,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_hubert import HubertConfig @@ -107,7 +108,7 @@ def forward(self, hidden_states): return hidden_states -class HubertNoLayerNormConvLayer(nn.Module): +class HubertNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -128,7 +129,7 @@ def forward(self, hidden_states): return hidden_states -class HubertLayerNormConvLayer(nn.Module): +class HubertLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -155,7 +156,7 @@ def forward(self, hidden_states): return hidden_states -class HubertGroupNormConvLayer(nn.Module): +class HubertGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -212,13 +213,7 @@ def forward(self, input_values): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -417,7 +412,7 @@ def forward(self, hidden_states): return hidden_states -class HubertEncoderLayer(nn.Module): +class HubertEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = HubertAttention( @@ -501,17 +496,9 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -579,7 +566,7 @@ def forward(self, hidden_states: torch.FloatTensor): return hidden_states -class HubertEncoderLayerStableLayerNorm(nn.Module): +class HubertEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = HubertAttention( @@ -675,17 +662,9 @@ def forward( if not 
skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index 815b902d3fba..098f4966b442 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import ( ModelOutput, @@ -283,7 +284,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision -class IdeficsVisionEncoderLayer(nn.Module): +class IdeficsVisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: IdeficsVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -400,21 +401,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 6f59e0ee14bd..0805082cd34d 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -42,6 +42,7 @@ Seq2SeqTSPredictionOutput, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, is_torch_flex_attn_available, logging @@ -744,7 +745,7 @@ def forward( # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py -class InformerConvLayer(nn.Module): +class InformerConvLayer(GradientCheckpointingLayer): def __init__(self, c_in): super().__init__() self.downConv = nn.Conv1d( @@ -767,7 +768,7 @@ def forward(self, x): return x -class InformerEncoderLayer(nn.Module): +class InformerEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: InformerConfig): super().__init__() self.embed_dim = config.d_model @@ -845,7 +846,7 @@ def forward( return outputs -class InformerDecoderLayer(nn.Module): +class 
InformerDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -1086,27 +1087,15 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - if conv_layer is not None: - output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0]) - layer_outputs = (output,) + layer_outputs[1:] - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) - if conv_layer is not None: - output = conv_layer(layer_outputs[0]) - layer_outputs = (output,) + layer_outputs[1:] + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] hidden_states = layer_outputs[0] @@ -1299,35 +1288,20 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 38b6fde10372..cfbf4edf0451 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -30,6 +30,7 @@ from ...masking_utils import create_causal_mask, create_chunked_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -360,7 +361,7 @@ def forward( return attn_output, attn_weights -class Llama4TextDecoderLayer(nn.Module): +class Llama4TextDecoderLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.hidden_size = config.hidden_size @@ -571,31 +572,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - False, # output_router_logits is False - use_cache, - cache_position, - freq_cis, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=freq_cis, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=freq_cis, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] @@ -930,7 +917,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class Llama4VisionEncoderLayer(nn.Module): +class Llama4VisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Llama4VisionConfig): super().__init__() self.hidden_size = config.hidden_size @@ -1033,21 +1020,13 @@ def forward( for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - freqs_ci, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_state=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - freqs_ci=freqs_ci, - ) + + layer_outputs = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + freqs_ci=freqs_ci, + ) if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 841674f0c7fb..c643951b411c 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -28,6 +28,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -280,7 +281,7 @@ def forward( return attn_output, attn_weights -class MllamaVisionEncoderLayer(nn.Module): +class MllamaVisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MllamaVisionConfig, is_gated: 
bool = False): super().__init__() @@ -387,19 +388,12 @@ def forward( for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_state=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) @@ -669,7 +663,7 @@ def forward(self, x): # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer -class MllamaSelfAttentionDecoderLayer(nn.Module): +class MllamaSelfAttentionDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MllamaTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -754,7 +748,7 @@ def forward( return outputs -class MllamaCrossAttentionDecoderLayer(torch.nn.Module): +class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): """Cross-attention transformer block with tanh-gated attention and feedforward.""" def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: @@ -1402,36 +1396,20 @@ def forward( if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - cross_attention_states, - cross_attention_mask, - causal_mask, - full_text_row_masked_out_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - attention_mask=causal_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + attention_mask=causal_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 6dc3c12c1ffb..3ea63f11cb31 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -27,6 +27,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import auto_docstring, logging from .configuration_olmoe import OlmoeConfig @@ 
-610,7 +611,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -class OlmoeDecoderLayer(nn.Module): +class OlmoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OlmoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -827,31 +828,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index d400e08cd18a..195d0c6c1b4a 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...utils import ( ModelOutput, auto_docstring, @@ -2563,7 +2564,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class OneFormerTextTransformerLayer(nn.Module): +class OneFormerTextTransformerLayer(GradientCheckpointingLayer): def __init__(self, width: int, heads: int, attn_mask: torch.Tensor, layer_norm_eps=1e-05): super().__init__() self.self_attn = nn.MultiheadAttention(width, heads) @@ -2617,10 +2618,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor): for layer in self.layers: - if self.use_checkpoint: - hidden_states = self._gradient_checkpointing_func(layer, hidden_states) - else: - hidden_states = layer(hidden_states) + hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 5edceb27c0da..c9a0ca85d6c8 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -23,6 +23,7 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_phi import PhiConfig @@ -206,7 +207,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class PhiDecoderLayer(nn.Module): +class PhiDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PhiConfig, layer_idx: int): 
super().__init__() self.self_attn = PhiAttention(config, layer_idx=layer_idx) @@ -410,30 +411,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] From 094275583a765743d16b7a63c7741318f85538ec Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 17:36:32 +0000 Subject: [PATCH 076/146] fixup --- src/transformers/models/bark/modeling_bark.py | 2 +- src/transformers/models/bart/modeling_bart.py | 6 ++---- .../models/bert_generation/modeling_bert_generation.py | 2 +- src/transformers/models/big_bird/modeling_big_bird.py | 2 +- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 6 ++---- src/transformers/models/biogpt/modeling_biogpt.py | 3 +-- src/transformers/models/blenderbot/modeling_blenderbot.py | 6 ++---- .../models/blenderbot_small/modeling_blenderbot_small.py | 6 ++---- src/transformers/models/bloom/modeling_bloom.py | 2 +- src/transformers/models/bridgetower/modeling_bridgetower.py | 2 +- src/transformers/models/bros/modeling_bros.py | 2 +- src/transformers/models/camembert/modeling_camembert.py | 2 +- src/transformers/models/canine/modeling_canine.py | 2 +- src/transformers/models/chameleon/modeling_chameleon.py | 2 +- .../models/chinese_clip/modeling_chinese_clip.py | 2 +- src/transformers/models/clap/modeling_clap.py | 2 +- src/transformers/models/clip/modeling_clip.py | 2 +- src/transformers/models/codegen/modeling_codegen.py | 2 +- .../models/conditional_detr/modeling_conditional_detr.py | 2 +- src/transformers/models/convbert/modeling_convbert.py | 2 +- src/transformers/models/dab_detr/modeling_dab_detr.py | 2 +- src/transformers/models/data2vec/modeling_data2vec_audio.py | 2 +- src/transformers/models/data2vec/modeling_data2vec_text.py | 2 +- .../models/data2vec/modeling_data2vec_vision.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 2 +- src/transformers/models/deberta/modeling_deberta.py | 2 +- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 2 +- .../decision_transformer/modeling_decision_transformer.py | 2 +- .../models/deformable_detr/modeling_deformable_detr.py | 2 +- src/transformers/models/deit/modeling_deit.py | 2 +- src/transformers/models/deprecated/deta/modeling_deta.py | 2 +- src/transformers/models/deprecated/mctct/modeling_mctct.py | 2 +- src/transformers/models/deprecated/nezha/modeling_nezha.py | 2 +- .../models/deprecated/open_llama/modeling_open_llama.py | 2 +- .../models/deprecated/qdqbert/modeling_qdqbert.py | 2 +- src/transformers/models/deprecated/realm/modeling_realm.py | 2 +- .../speech_to_text_2/modeling_speech_to_text_2.py | 6 ++---- 
src/transformers/models/deprecated/tvlt/modeling_tvlt.py | 2 +- .../models/deprecated/vit_hybrid/modeling_vit_hybrid.py | 2 +- .../deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py | 6 ++---- src/transformers/models/detr/modeling_detr.py | 2 +- src/transformers/models/dinov2/modeling_dinov2.py | 2 +- .../dinov2_with_registers/modeling_dinov2_with_registers.py | 2 +- src/transformers/models/distilbert/modeling_distilbert.py | 2 +- src/transformers/models/dpt/modeling_dpt.py | 2 +- src/transformers/models/electra/modeling_electra.py | 2 +- src/transformers/models/ernie/modeling_ernie.py | 2 +- src/transformers/models/esm/modeling_esm.py | 2 +- src/transformers/models/falcon/modeling_falcon.py | 2 +- src/transformers/models/flava/modeling_flava.py | 2 +- src/transformers/models/fnet/modeling_fnet.py | 2 +- src/transformers/models/focalnet/modeling_focalnet.py | 2 +- src/transformers/models/gemma2/modeling_gemma2.py | 3 +-- src/transformers/models/gemma3/modeling_gemma3.py | 3 +-- src/transformers/models/git/modeling_git.py | 2 +- src/transformers/models/got_ocr2/modeling_got_ocr2.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 +- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 2 +- src/transformers/models/groupvit/modeling_groupvit.py | 2 +- src/transformers/models/hiera/modeling_hiera.py | 2 +- src/transformers/models/hubert/modeling_hubert.py | 2 +- src/transformers/models/idefics/vision.py | 2 +- src/transformers/models/idefics2/modeling_idefics2.py | 2 +- src/transformers/models/idefics3/modeling_idefics3.py | 2 +- src/transformers/models/ijepa/modeling_ijepa.py | 2 +- src/transformers/models/imagegpt/modeling_imagegpt.py | 2 +- src/transformers/models/informer/modeling_informer.py | 6 ++---- src/transformers/models/internvl/modeling_internvl.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 2 +- src/transformers/models/kosmos2/modeling_kosmos2.py | 6 ++---- src/transformers/models/layoutlm/modeling_layoutlm.py | 2 +- src/transformers/models/layoutlmv2/modeling_layoutlmv2.py | 2 +- src/transformers/models/layoutlmv3/modeling_layoutlmv3.py | 2 +- src/transformers/models/led/modeling_led.py | 6 ++---- src/transformers/models/lilt/modeling_lilt.py | 2 +- src/transformers/models/llama4/modeling_llama4.py | 2 +- src/transformers/models/longt5/modeling_longt5.py | 2 +- src/transformers/models/luke/modeling_luke.py | 2 +- src/transformers/models/m2m_100/modeling_m2m_100.py | 2 +- src/transformers/models/marian/modeling_marian.py | 6 ++---- src/transformers/models/markuplm/modeling_markuplm.py | 2 +- src/transformers/models/mask2former/modeling_mask2former.py | 2 +- src/transformers/models/maskformer/modeling_maskformer.py | 2 +- .../models/maskformer/modeling_maskformer_swin.py | 2 +- src/transformers/models/mbart/modeling_mbart.py | 6 ++---- .../models/megatron_bert/modeling_megatron_bert.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 2 +- src/transformers/models/mixtral/modeling_mixtral.py | 3 +-- src/transformers/models/mlcd/modeling_mlcd.py | 2 +- src/transformers/models/mllama/modeling_mllama.py | 2 +- src/transformers/models/mobilevit/modeling_mobilevit.py | 2 +- src/transformers/models/mobilevitv2/modeling_mobilevitv2.py | 2 +- src/transformers/models/modernbert/modeling_modernbert.py | 2 +- src/transformers/models/moshi/modeling_moshi.py | 2 +- 
src/transformers/models/mpt/modeling_mpt.py | 2 +- src/transformers/models/mra/modeling_mra.py | 2 +- src/transformers/models/mt5/modeling_mt5.py | 2 +- src/transformers/models/musicgen/modeling_musicgen.py | 6 ++---- .../models/musicgen_melody/modeling_musicgen_melody.py | 2 +- src/transformers/models/mvp/modeling_mvp.py | 6 ++---- src/transformers/models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 2 +- .../models/nystromformer/modeling_nystromformer.py | 2 +- src/transformers/models/olmoe/modeling_olmoe.py | 2 +- src/transformers/models/oneformer/modeling_oneformer.py | 2 +- src/transformers/models/opt/modeling_opt.py | 2 +- src/transformers/models/owlv2/modeling_owlv2.py | 2 +- src/transformers/models/owlvit/modeling_owlvit.py | 2 +- src/transformers/models/pegasus/modeling_pegasus.py | 6 ++---- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 2 +- src/transformers/models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 3 +-- src/transformers/models/phimoe/modeling_phimoe.py | 2 +- src/transformers/models/pix2struct/modeling_pix2struct.py | 6 ++++-- src/transformers/models/pixtral/modeling_pixtral.py | 2 +- src/transformers/models/plbart/modeling_plbart.py | 6 ++---- src/transformers/models/pop2piano/modeling_pop2piano.py | 2 +- src/transformers/models/prophetnet/modeling_prophetnet.py | 6 ++---- src/transformers/models/pvt_v2/modeling_pvt_v2.py | 2 +- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1 - src/transformers/models/qwen2_audio/modeling_qwen2_audio.py | 2 +- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 1 - src/transformers/models/qwen3_moe/modeling_qwen3_moe.py | 3 +-- .../models/recurrent_gemma/modeling_recurrent_gemma.py | 2 +- src/transformers/models/rembert/modeling_rembert.py | 2 +- src/transformers/models/roberta/modeling_roberta.py | 2 +- .../roberta_prelayernorm/modeling_roberta_prelayernorm.py | 2 +- src/transformers/models/roc_bert/modeling_roc_bert.py | 2 +- src/transformers/models/roformer/modeling_roformer.py | 2 +- src/transformers/models/sam/modeling_sam.py | 2 +- src/transformers/models/sam_hq/modeling_sam_hq.py | 2 +- .../models/seamless_m4t/modeling_seamless_m4t.py | 2 +- .../models/seamless_m4t_v2/modeling_seamless_m4t_v2.py | 2 +- src/transformers/models/sew/modeling_sew.py | 2 +- src/transformers/models/sew_d/modeling_sew_d.py | 2 +- src/transformers/models/smolvlm/modeling_smolvlm.py | 2 +- .../models/speech_to_text/modeling_speech_to_text.py | 6 ++---- src/transformers/models/speecht5/modeling_speecht5.py | 6 ++---- src/transformers/models/splinter/modeling_splinter.py | 2 +- src/transformers/models/stablelm/modeling_stablelm.py | 2 +- src/transformers/models/swin/modeling_swin.py | 2 +- src/transformers/models/swin2sr/modeling_swin2sr.py | 2 +- src/transformers/models/swinv2/modeling_swinv2.py | 2 +- .../switch_transformers/modeling_switch_transformers.py | 2 +- src/transformers/models/t5/modeling_t5.py | 2 +- .../models/table_transformer/modeling_table_transformer.py | 2 +- src/transformers/models/tapas/modeling_tapas.py | 2 +- .../modeling_time_series_transformer.py | 6 ++---- src/transformers/models/timesformer/modeling_timesformer.py | 2 +- src/transformers/models/trocr/modeling_trocr.py | 6 ++---- src/transformers/models/tvp/modeling_tvp.py | 2 +- src/transformers/models/umt5/modeling_umt5.py | 2 +- src/transformers/models/unispeech/modeling_unispeech.py | 2 +- 
.../models/unispeech_sat/modeling_unispeech_sat.py | 2 +- src/transformers/models/videomae/modeling_videomae.py | 2 +- src/transformers/models/vilt/modeling_vilt.py | 2 +- src/transformers/models/visual_bert/modeling_visual_bert.py | 2 +- src/transformers/models/vit/modeling_vit.py | 2 +- src/transformers/models/vit_mae/modeling_vit_mae.py | 2 +- src/transformers/models/vit_msn/modeling_vit_msn.py | 2 +- src/transformers/models/vitdet/modeling_vitdet.py | 2 +- .../models/vitpose_backbone/modeling_vitpose_backbone.py | 2 +- src/transformers/models/vits/modeling_vits.py | 2 +- src/transformers/models/vivit/modeling_vivit.py | 2 +- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 2 +- .../models/wav2vec2_bert/modeling_wav2vec2_bert.py | 2 +- .../wav2vec2_conformer/modeling_wav2vec2_conformer.py | 2 +- src/transformers/models/wavlm/modeling_wavlm.py | 2 +- src/transformers/models/whisper/modeling_whisper.py | 6 ++---- src/transformers/models/x_clip/modeling_x_clip.py | 2 +- src/transformers/models/xglm/modeling_xglm.py | 6 ++---- src/transformers/models/xlm_roberta/modeling_xlm_roberta.py | 2 +- .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 +- src/transformers/models/xmod/modeling_xmod.py | 2 +- src/transformers/models/yolos/modeling_yolos.py | 2 +- src/transformers/models/yoso/modeling_yoso.py | 2 +- src/transformers/models/zamba/modeling_zamba.py | 2 +- src/transformers/models/zamba2/modeling_zamba2.py | 2 +- 182 files changed, 205 insertions(+), 255 deletions(-) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 4598552615e5..8ace5221c082 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -31,8 +31,8 @@ ) from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available -from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput from ...modeling_utils import PreTrainedModel, get_parameter_device from ...utils import ( auto_docstring, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 72b705ef5e34..54003f3aa3d9 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -33,6 +33,7 @@ _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -42,7 +43,6 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -1135,9 +1135,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, 
use_cache=use_cache, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 7bb73bf33320..64a3d8a6200d 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -23,8 +23,8 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 5f2c83acb045..a7b3c0647033 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -36,7 +37,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ModelOutput, auto_docstring, logging diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 2e392a2842cd..ad8440f3e3dd 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -32,6 +32,7 @@ _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -41,7 +42,6 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging @@ -2290,9 +2290,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 4387a2cd6af5..8a0c43eafd3f 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -20,7 +20,6 @@ # limitations under the License. 
import math -from functools import partial from typing import Callable, Optional, Union import torch @@ -32,13 +31,13 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, is_torch_flex_attn_available, logging diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 9ccf140c31c2..ff28abd0398b 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -34,6 +34,7 @@ _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -41,7 +42,6 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -1088,9 +1088,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 1ca3353e5627..16f1d3b5f197 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -32,6 +32,7 @@ _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,7 +40,6 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -1071,9 +1071,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/bloom/modeling_bloom.py 
b/src/transformers/models/bloom/modeling_bloom.py index 4ae1668c011a..414426844513 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -34,7 +35,6 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index a0a811edfae5..7af3bc5a6267 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN, QuickGELUActivation +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -32,7 +33,6 @@ ModelOutput, SequenceClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index d9c4d4e95e83..dfe36f7a353b 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -24,13 +24,13 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging from .configuration_bros import BrosConfig diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index d4ce40d2d5ed..5c49acc7a80a 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN, gelu from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -37,7 +38,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from 
...utils import auto_docstring, get_torch_version, logging diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 75435799ab7b..a5f5552b78e7 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, ModelOutput, @@ -34,7 +35,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 784d97adf9aa..8bb6b8bc2540 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -27,8 +27,8 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 540ba8a4e596..f73518458c22 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -23,13 +23,13 @@ from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 4142d75b1a8e..1a9e5a250af2 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -24,12 +24,12 @@ from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int diff --git a/src/transformers/models/clip/modeling_clip.py 
b/src/transformers/models/clip/modeling_clip.py index e3c0b36710c1..3e8a898b35df 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -23,8 +23,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index ad8f6221ddcc..2eb864715187 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -24,8 +24,8 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 547c96b57c2e..87ddbf7a3225 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -23,8 +23,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_timm_available, logging, requires_backends from ...utils.backbone_utils import load_backbone diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index d13bbfa14e27..bdac1fecc1c3 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, MaskedLMOutput, @@ -33,7 +34,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index 36f86ca3ba3e..5b177342472c 100644 --- 
a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -23,8 +23,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 1e7183a05ae9..4afed9fcea33 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -33,6 +33,7 @@ from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -41,7 +42,6 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index b42f9f56d5ba..03e286a89dfd 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN, gelu from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -34,7 +35,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 0769b455ddf3..3b593a834db0 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -26,13 +26,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, SemanticSegmenterOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 12b7b00bc93e..9446caa902d2 100644 --- 
a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -26,8 +26,8 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available -from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_dbrx import DbrxConfig diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 6cf829714e01..c6dd97736c57 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -22,6 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -29,7 +30,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_deberta import DebertaConfig diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index bd341eb213fd..9089fe1f6504 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -31,7 +32,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_deberta_v2 import DebertaV2Config diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index aceef4ae41fb..ca715532fb3c 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -25,8 +25,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 43908c7548d4..4afa79405203 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ 
b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -27,8 +27,8 @@ from ...activations import ACT2FN from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid from ...utils import ( diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 53f870b3d008..4250c1180bc6 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -24,13 +24,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedImageModelingOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int diff --git a/src/transformers/models/deprecated/deta/modeling_deta.py b/src/transformers/models/deprecated/deta/modeling_deta.py index 214a41204a01..64912b512ff9 100644 --- a/src/transformers/models/deprecated/deta/modeling_deta.py +++ b/src/transformers/models/deprecated/deta/modeling_deta.py @@ -39,9 +39,9 @@ replace_return_docstrings, ) from ....modeling_attn_mask_utils import _prepare_4d_attention_mask +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput from ....modeling_utils import PreTrainedModel -from ....modeling_layers import GradientCheckpointingLayer from ....pytorch_utils import meshgrid from ....utils import is_accelerate_available, is_ninja_available, is_torchvision_available, logging, requires_backends from ....utils.backbone_utils import load_backbone diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index 139326d44c9d..7bc835cf13df 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -26,8 +26,8 @@ from ....integrations.deepspeed import is_deepspeed_zero3_enabled from ....integrations.fsdp import is_fsdp_managed_module from ....modeling_attn_mask_utils import _prepare_4d_attention_mask -from ....modeling_outputs import BaseModelOutput, CausalLMOutput from ....modeling_layers import GradientCheckpointingLayer +from ....modeling_outputs import BaseModelOutput, CausalLMOutput from ....modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index 28572381f8e5..2ef4a560952e 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import ACT2FN +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, 
BaseModelOutputWithPoolingAndCrossAttentions, @@ -36,7 +37,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ....utils import ( diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 5e99841df762..848f3f971e05 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -29,8 +29,8 @@ from ....activations import ACT2FN from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ....modeling_layers import GradientCheckpointingLayer +from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_open_llama import OpenLlamaConfig diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index 61bc778e1349..df3fce3b5205 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import ACT2FN +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -37,7 +38,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ....utils import ( diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index ec3df4672959..e88a75bd1bf2 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -24,13 +24,13 @@ from torch.nn import CrossEntropyLoss from ....activations import ACT2FN +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput, ModelOutput, ) -from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index 43bfc042155c..0599c3b592f1 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ 
b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -24,8 +24,8 @@ from ....activations import ACT2FN from ....modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask -from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ....modeling_layers import GradientCheckpointingLayer +from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, logging, replace_return_docstrings from .configuration_speech_to_text_2 import Speech2Text2Config @@ -619,9 +619,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py index 61f38e215de7..5280248c59e1 100644 --- a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py +++ b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py @@ -26,8 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import ACT2FN -from ....modeling_outputs import BaseModelOutput, SequenceClassifierOutput from ....modeling_layers import GradientCheckpointingLayer +from ....modeling_outputs import BaseModelOutput, SequenceClassifierOutput from ....modeling_utils import PreTrainedModel from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ....utils import ( diff --git a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py index 81af62c3a448..03bcc24beb6b 100644 --- a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py @@ -24,8 +24,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import ACT2FN -from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ....modeling_layers import GradientCheckpointingLayer +from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ....modeling_utils import PreTrainedModel from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ....utils import ( diff --git a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py index 124525b7385f..57f2c2610e82 100644 --- a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -26,8 +26,8 @@ from torch.nn import LayerNorm from ....activations import ACT2FN -from ....modeling_outputs import BaseModelOutput from ....modeling_layers import GradientCheckpointingLayer +from ....modeling_outputs import BaseModelOutput from ....modeling_utils import PreTrainedModel from ....utils import ( 
ModelOutput, @@ -1552,9 +1552,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attn_mask=extended_encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 21e84354d08b..e01e629f3392 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -23,8 +23,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 7c0cbd6b28cb..7b023242a1cf 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -23,8 +23,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index a4b844665e12..adbb34c2fd45 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -28,8 +28,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index bcaeaa5d7372..8e2adfaf9faa 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -31,6 +31,7 @@ from 
...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -39,7 +40,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ( apply_chunking_to_forward, diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 5c61e911b281..82ff615afc25 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -29,8 +29,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index dd23f766d029..68ce41f32d27 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN, get_activation from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions, @@ -36,7 +37,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index f364a2c20632..5171bfc73b99 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -37,7 +38,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index ea9458e560da..d4af2eac2527 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, 
is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -31,7 +32,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging from .configuration_esm import EsmConfig diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index c2d82af752ff..32d60a582ef4 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -30,6 +30,7 @@ AttentionMaskConverter, ) from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -38,7 +39,6 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index b0e5f89193b5..1d1ce04c3563 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -25,8 +25,8 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int from .configuration_flava import ( diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 858a0d8474c5..fed31339da7c 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -31,6 +31,7 @@ from scipy import linalg from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -42,7 +43,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import logging diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 6d2fbc6069c3..47fa9d4f2eb7 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -25,8 +25,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BackboneOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from ...utils.backbone_utils import BackboneMixin 
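
Every hunk in this series follows the same recipe: a layer class switches its base from nn.Module to GradientCheckpointingLayer, the modeling_layers import is added (and, in this fixup, sorted into place), and the per-model `if self.gradient_checkpointing and self.training:` branches around `self._gradient_checkpointing_func` disappear because the layer's own call path now decides when to checkpoint. A minimal sketch of that idea follows; the class names GradientCheckpointingLayerSketch and ToyLayer are illustrative assumptions, not the actual implementation in transformers' modeling_layers.

    # Illustrative sketch only; the real GradientCheckpointingLayer may differ.
    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint


    class GradientCheckpointingLayerSketch(nn.Module):
        """Routes the forward pass through torch.utils.checkpoint when enabled.

        Because the branch lives in __call__, model code can simply call
        layer_module(hidden_states, ...) instead of checking
        `self.gradient_checkpointing and self.training` at every call site.
        """

        gradient_checkpointing = False

        def __call__(self, *args, **kwargs):
            if self.gradient_checkpointing and self.training:
                # Non-reentrant checkpointing; kwargs are forwarded to the layer.
                return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
            return super().__call__(*args, **kwargs)


    class ToyLayer(GradientCheckpointingLayerSketch):
        def __init__(self, hidden_size: int = 8):
            super().__init__()
            self.proj = nn.Linear(hidden_size, hidden_size)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            return self.proj(hidden_states).relu()


    layer = ToyLayer()
    layer.gradient_checkpointing = True
    layer.train()
    out = layer(torch.randn(2, 4, 8, requires_grad=True))
    out.sum().backward()  # the forward is recomputed inside checkpoint during backward

Centralizing the branch in one base class keeps every decoder loop identical whether or not checkpointing is active, which is what lets the call sites earlier in this series collapse to a plain `layer_module(...)` call.
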
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index c763a155b595..7008538c7ab0 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Union import torch @@ -30,6 +29,7 @@ from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -38,7 +38,6 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging from ...utils.deprecation import deprecate_kwarg diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index bd1315d9e974..a3a476ae617c 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -22,7 +22,6 @@ import copy from collections.abc import Callable from dataclasses import dataclass -from functools import partial from typing import Optional, Union import torch @@ -33,8 +32,8 @@ from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 17fa701a2127..8058c542e9d2 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -27,13 +27,13 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 0a9b403053b8..551e887c09a7 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -31,9 +31,9 @@ from ...activations 
import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple from ..auto import AutoModel diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 537554c1b18a..ea0c62ef11d8 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -30,6 +30,7 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -37,7 +38,6 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 063ba82233e3..7e098e1e498f 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -25,13 +25,13 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import is_torch_greater_or_equal_than_2_2 from ...utils import ( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index bbf493bc80a4..1c42d8b2dfaf 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -27,6 +27,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, @@ -36,7 +37,6 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 369e65793e86..d3c5141371b2 100755 --- 
a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -14,6 +14,7 @@ from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -22,7 +23,6 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 2b501c9b54b3..721bfd7b179e 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -28,13 +28,13 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 7f39f56ded61..362d170ffa81 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -25,8 +25,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging, torch_int from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig diff --git a/src/transformers/models/hiera/modeling_hiera.py b/src/transformers/models/hiera/modeling_hiera.py index a0664c80fcc5..e086b432ba50 100644 --- a/src/transformers/models/hiera/modeling_hiera.py +++ b/src/transformers/models/hiera/modeling_hiera.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BackboneOutput, BaseModelOutput, @@ -32,7 +33,6 @@ ModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...utils import auto_docstring, logging, torch_int from ...utils.backbone_utils import BackboneMixin from .configuration_hiera import HieraConfig diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index e51984c44900..0fab4184bfe3 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ 
-32,9 +32,9 @@ from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_hubert import HubertConfig diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index 098f4966b442..d75d61545ec2 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -23,8 +23,8 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import ( ModelOutput, diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 704fa2785b0c..1f3f96de6303 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -26,8 +26,8 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 64f82e2d4e59..56750bc5298c 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -26,8 +26,8 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 44fe5c7b083f..5568b4ebcc8b 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -12,8 +12,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import 
BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 58f43ede5225..b82e7ff04478 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -27,12 +27,12 @@ from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 0805082cd34d..c0f4eddb1a3c 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -34,6 +34,7 @@ _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -42,7 +43,6 @@ Seq2SeqTSPredictionOutput, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, is_torch_flex_attn_available, logging @@ -1294,9 +1294,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index ea94d8b917c9..329dd4319e50 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -31,8 +31,8 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 42480dbe4bfd..c0a1af1eccb3 100644 --- 
a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -27,9 +27,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_jetmoe import JetMoeConfig diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index a3144c0ed67b..74a62e99836f 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -25,13 +25,13 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, CausalLMOutputWithCrossAttentions, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, logging, torch_int @@ -1136,9 +1136,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index f3bf7092c290..97df2506eac0 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -31,7 +32,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index f9ec7df94f8f..69058a043711 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from 
...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -30,7 +31,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import auto_docstring, is_detectron2_available, logging, requires_backends diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 0261e14c77ff..1b6398a382d4 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -25,13 +25,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 7141d9a02d61..76cb3ced8c5d 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -27,8 +27,8 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from .configuration_led import LEDConfig @@ -1938,9 +1938,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 1d27e9a0d292..91664c32facb 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -30,7 +31,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index cfbf4edf0451..fb546ad68162 100644 --- 
a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -29,8 +29,8 @@ from ...integrations.hub_kernels import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_chunked_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index d740b35d774c..79e069c50a08 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -27,13 +27,13 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 1268048c0872..af01cf77be4f 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -24,8 +24,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ModelOutput, auto_docstring, logging diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index c617d5f47607..358ab822a325 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -34,13 +34,13 @@ from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 5b5beed6c880..0e3ee3c62415 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -33,6 +33,7 @@ 
_prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -40,7 +41,6 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -1085,9 +1085,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 9b07b3339f7d..3a1a9517e98e 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -32,7 +33,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 50d8496df74d..5827fcdb0e3f 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -25,8 +25,8 @@ from ...activations import ACT2FN from ...file_utils import ModelOutput, is_scipy_available, requires_backends -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_accelerate_available, logging from ...utils.backbone_utils import load_backbone diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index f13c00b045e6..b8c9caf9c989 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -25,8 +25,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutputWithCrossAttentions from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 47f80eddd0d9..68c291f1b108 100644 --- 
a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -26,8 +26,8 @@ from ...activations import ACT2FN from ...file_utils import ModelOutput -from ...modeling_outputs import BackboneOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import torch_int diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 5fdc21c3126b..4711af692558 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -34,6 +34,7 @@ from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -43,7 +44,6 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -1128,9 +1128,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 97ee1f05c99e..941d62b58696 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -39,7 +40,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index f200ecb9b868..130ce34e0e88 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -26,9 +26,9 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from 
...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging from .configuration_mimi import MimiConfig diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index a3a9aa1c6502..8f82b59e5e47 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -24,7 +24,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Union import torch @@ -37,6 +36,7 @@ from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, MoeCausalLMOutputWithPast, @@ -46,7 +46,6 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index e1f23ca84b9d..26a12cab8bac 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -25,8 +25,8 @@ from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, torch_int diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index c643951b411c..69a1e703ea04 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -27,8 +27,8 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index eb16584579a3..1b483fe958c0 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -25,13 +25,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( 
BaseModelOutputWithNoAttention, BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention, SemanticSegmenterOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index fdf0a3261bec..a52aedca7cf6 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -24,13 +24,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithNoAttention, BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention, SemanticSegmenterOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_mobilevitv2 import MobileViTV2Config diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index f2aae73bde4e..588496d1b964 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -38,7 +39,6 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_flash_attn_2_available, logging from ...utils.import_utils import is_triton_available diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index cd80766ed23e..a18856596de8 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -28,9 +28,9 @@ from ...generation import GenerationConfig, GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput, Seq2SeqLMOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 6fa1de4e9a76..b3005728e7b9 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -25,6 +25,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_layers 
import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -32,7 +33,6 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_mpt import MptConfig diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index c5d63ba7705a..a7fd783d848d 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -25,6 +25,7 @@ from torch.utils.cpp_extension import load from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, MaskedLMOutput, @@ -33,7 +34,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, is_cuda_platform, is_ninja_available, is_torch_cuda_available, logging diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 4eb6fa1ec7fc..b552ed3e88f4 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -37,7 +38,6 @@ Seq2SeqSequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index beff1b6560f5..15b664f6ad0c 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -43,6 +43,7 @@ from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -50,7 +51,6 @@ ModelOutput, Seq2SeqLMOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging @@ -626,9 +626,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, diff --git 
a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index ea9621b459b4..441c0e862b28 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -41,8 +41,8 @@ from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) -from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 3410acd2a7d4..9120912aeb1e 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -29,6 +29,7 @@ _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -38,7 +39,6 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_mvp import MvpConfig @@ -932,9 +932,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None), past_key_value=past_key_value, diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index bac137fd38b4..8fe7f3328f77 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -28,6 +28,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -36,7 +37,6 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 3920ed34b029..8e8526a773bd 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -32,13 +32,13 @@ _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils 
import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, Seq2SeqMoEModelOutput, Seq2SeqMoEOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index d000d94f1c34..f5b940157ded 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, MaskedLMOutput, @@ -31,7 +32,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 3ea63f11cb31..6319d0b5d37e 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -24,10 +24,10 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import auto_docstring, logging from .configuration_olmoe import OlmoeConfig diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 195d0c6c1b4a..99dde2fabba8 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -26,9 +26,9 @@ from torch.cuda.amp import autocast from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...utils import ( ModelOutput, auto_docstring, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 39f460432f3f..a619e46b4764 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -26,13 +26,13 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, 
CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 9932c0e3dca7..ee4fb714f549 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -24,8 +24,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_vision_available, logging, torch_int from .configuration_owlv2 import Owlv2Config, Owlv2TextConfig, Owlv2VisionConfig diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index b4aa07c2cdd5..4e269cd46599 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -24,8 +24,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_vision_available, logging, torch_int from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 63c513112487..a95da766eb0c 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -33,6 +33,7 @@ _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -40,7 +41,6 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -1133,9 +1133,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index b929c74ec7ca..d0fdd8beded0 100755 --- 
a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -33,13 +33,13 @@ _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 08423e838b81..f2bbef331a08 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -30,6 +30,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -37,7 +38,6 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index c9a0ca85d6c8..95164a5f5dbd 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_phi.py file directly. One of our CI enforces this. 
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -from functools import partial from typing import Callable, Optional, Union import torch @@ -15,6 +14,7 @@ from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -23,7 +23,6 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_phi import PhiConfig diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 785e5deb477b..02eb6955eafe 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -27,9 +27,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_flash_attention_utils import is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_phimoe import PhimoeConfig diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index f3c85e7ad538..254501f44cd3 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -25,6 +25,7 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -32,7 +33,6 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( @@ -1144,7 +1144,9 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.gradient_checkpointing and self.training and use_cache: - logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False if input_ids is not None and inputs_embeds is not None: diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 5fb4ff26885e..f1d5ab06d54b 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -23,9 +23,9 @@ from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_rope_utils import dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index b34acbbca588..606870417086 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -36,6 +36,7 @@ _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -44,7 +45,6 @@ Seq2SeqModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging @@ -1062,9 +1062,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index d20194b8bbff..41b2f7d04800 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -28,8 +28,8 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling, logging diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index e0ec38f0d1ba..739aeb68bf9d 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ 
b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -27,8 +27,8 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from .configuration_prophetnet import ProphetNetConfig @@ -1393,9 +1393,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attn_mask=extended_encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, diff --git a/src/transformers/models/pvt_v2/modeling_pvt_v2.py b/src/transformers/models/pvt_v2/modeling_pvt_v2.py index fd5e5d89bc42..b357cb5970a4 100644 --- a/src/transformers/models/pvt_v2/modeling_pvt_v2.py +++ b/src/transformers/models/pvt_v2/modeling_pvt_v2.py @@ -25,8 +25,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BackboneOutput, BaseModelOutput, ImageClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput, BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index acdf046c0423..62ad238cd87c 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -40,7 +40,6 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index b36cc2eff726..8e331e1bd0bd 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -25,8 +25,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import auto_docstring, logging from ..auto import AutoModel, AutoModelForCausalLM diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py 
b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index da4af7c89ecb..e616f1047f70 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -32,6 +32,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, @@ -40,7 +41,6 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_qwen2_moe import Qwen2MoeConfig diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 0e4a521169b8..5e921c5de970 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -37,7 +37,6 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index f332b16a7e99..1f74f5e55893 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from functools import partial from typing import Callable, Optional, Union import torch @@ -32,6 +31,7 @@ from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, MoeCausalLMOutputWithPast, @@ -41,7 +41,6 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 8b6e37462d75..5a9cbf99e1f2 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -25,8 +25,8 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import auto_docstring, logging diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 3197cca5fafe..774fc46ef7fc 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -35,7 +36,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index e0440f3f48cf..6eef0f0fa548 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN, gelu from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -37,7 +38,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, 
get_torch_version, logging diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 2a7b22627775..b63ebef24b99 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN, gelu from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -35,7 +36,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index ca7a2e949c7e..222f7597964c 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -35,7 +36,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 6a5211de647f..8439fed19cfc 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN, get_activation from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -35,7 +36,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 24d891e3746b..31cdec6d5f7e 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -25,8 +25,8 @@ from torch import Tensor, nn from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index bdde2599d697..a0ff9c309673 100644 --- 
a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -29,8 +29,8 @@ from torch import Tensor, nn from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 86f245e340ca..0e0a0312e6ce 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -32,6 +32,7 @@ _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,7 +40,6 @@ Seq2SeqModelOutput, Wav2Vec2BaseModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from .configuration_seamless_m4t import SeamlessM4TConfig diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 1fd255b31831..2245b795304a 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -29,6 +29,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -36,7 +37,6 @@ Seq2SeqModelOutput, Wav2Vec2BaseModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from .configuration_seamless_m4t_v2 import SeamlessM4Tv2Config diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 304c1bfd0982..da4a54b39fc9 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -32,8 +32,8 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index eb05ea9f59ab..5e00ddcd1f3b 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -27,8 
+27,8 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled -from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import softmax_backward_data from ...utils import auto_docstring, logging diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 0c39e64a5bb6..3cf16a992c9f 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -30,8 +30,8 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 0791976dcf66..a2adee4f0946 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -30,13 +30,13 @@ _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -939,9 +939,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 92fa3c95c9e1..e854980fd68a 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -28,6 +28,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -35,7 +36,6 @@ Seq2SeqModelOutput, Seq2SeqSpectrogramOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import 
auto_docstring, logging from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig @@ -1627,9 +1627,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 4332713b6549..ba579e7d6266 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -24,8 +24,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index d98f40597bf9..0dc1d00890ea 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -31,6 +31,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -38,7 +39,6 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_stablelm import StableLmConfig diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 68daf506f628..c62c2e4fc950 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -25,8 +25,8 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_outputs import BackboneOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index e083b28a6dd0..ae6e0a6e7952 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -24,8 +24,8 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, 
ImageSuperResolutionOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 8cf738513759..67657f0111a6 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -25,8 +25,8 @@ from torch import Tensor, nn from ...activations import ACT2FN -from ...modeling_outputs import BackboneOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 3a25906a19b2..43984b9293a0 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -27,13 +27,13 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, Seq2SeqMoEModelOutput, Seq2SeqMoEOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index fcc528579a53..34bc70fac7cc 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -37,7 +38,6 @@ Seq2SeqSequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 321d3f176d9e..7722c476c396 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -23,8 +23,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_layers import 
GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index d6dfde622ea9..6d88de473664 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -26,8 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 261e5c4ff765..5cf5993309e5 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -30,6 +30,7 @@ _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -37,7 +38,6 @@ Seq2SeqTSModelOutput, Seq2SeqTSPredictionOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput @@ -1064,9 +1064,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 9f4ee262430b..191a65f9b130 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -24,8 +24,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index d9ce3726fb38..c8617739d68e 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ 
b/src/transformers/models/trocr/modeling_trocr.py @@ -28,8 +28,8 @@ _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_trocr import TrOCRConfig @@ -650,9 +650,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index f1a9421117a4..cd6e88df846a 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -23,8 +23,8 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index cc988e848972..2c4733d0977d 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -35,7 +36,6 @@ Seq2SeqSequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( DUMMY_INPUTS, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index cd5df6f43853..af8b151ac6b2 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -34,6 +34,7 @@ from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -41,7 +42,6 @@ SequenceClassifierOutput, Wav2Vec2BaseModelOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging diff --git 
a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index d260cba8800b..35cadebd5ffa 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -34,6 +34,7 @@ from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -43,7 +44,6 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 0bdaa35833c9..a8278f0b892c 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -26,8 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index a98fa55e0364..d42ac12605cd 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -33,7 +34,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 9da2ddf8f2a6..34adc5dd74d4 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -24,13 +24,13 @@ from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, MultipleChoiceModelOutput, SequenceClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging diff --git a/src/transformers/models/vit/modeling_vit.py 
b/src/transformers/models/vit/modeling_vit.py index 9e298dab1185..dbad9ef41f3e 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -24,13 +24,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedImageModelingOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index de5246224de5..32b7151169ae 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -25,8 +25,8 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index aade45a5f954..11155d2d081c 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -23,8 +23,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index 0ba71fae8b96..b74bc1008f70 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -23,8 +23,8 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_outputs import BackboneOutput, BaseModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput, BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from ...utils.backbone_utils import BackboneMixin diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py index 6594dfa3a162..fb22d215996c 100644 --- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -27,8 +27,8 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_outputs import BackboneOutput, BaseModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput, 
BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 2045397be372..e202f98070c3 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -27,8 +27,8 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_vits import VitsConfig diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 48b0fd852273..7011552db822 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -22,8 +22,8 @@ from torch.nn import CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index d234b024da1d..13fd216c67ac 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -33,6 +33,7 @@ _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -42,7 +43,6 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index 4491c4dc40a3..5816d3ffed42 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -17,6 +17,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -25,7 +26,6 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_peft_available from .configuration_wav2vec2_bert import Wav2Vec2BertConfig diff --git 
a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index a01fa908f2fb..02ea35dffdf8 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -17,6 +17,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -25,7 +26,6 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 3b9ec3a9cac3..acc7443a5f34 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -17,6 +17,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -25,7 +26,6 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_peft_available, logging from .configuration_wavlm import WavLMConfig diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8d9da86a297c..14cbaafe47dc 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -30,6 +30,7 @@ from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -38,7 +39,6 @@ Seq2SeqModelOutput, SequenceClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging @@ -961,9 +961,7 @@ def forward( attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values if use_cache else None, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 3632c67f93a5..90f495719467 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -24,8 +24,8 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from 
...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( ModelOutput, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 66589a48871c..562821e7ec31 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -24,8 +24,8 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from .configuration_xglm import XGLMConfig @@ -554,9 +554,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 9c79c3b664bd..cee4fa3837c5 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN, gelu from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -37,7 +38,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index a108a6a144a3..7cbeaadb184c 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN, gelu from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -36,7 +37,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, 
find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 0f7ff562f1e1..84cf9f6d5349 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN, gelu from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -34,7 +35,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 01115516683f..c61ff8cb85af 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -23,8 +23,8 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 95a7163d4e0a..da35490c59ed 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, MaskedLMOutput, @@ -32,7 +33,6 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index f2eca2a1f1e2..04f7b94494a7 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -32,8 +32,8 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS diff --git a/src/transformers/models/zamba2/modeling_zamba2.py 
b/src/transformers/models/zamba2/modeling_zamba2.py index 64d2d256d7fc..b5067a5dd2e0 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -33,9 +33,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging From fe80395206385bcedc4745d7fd27369e1d530823 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 17:51:19 +0000 Subject: [PATCH 077/146] Add use_cache=False and past_key_value=None to GradientCheckpointingLayer --- src/transformers/modeling_layers.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/transformers/modeling_layers.py b/src/transformers/modeling_layers.py index 57be2d8e0d7d..b8807864e934 100644 --- a/src/transformers/modeling_layers.py +++ b/src/transformers/modeling_layers.py @@ -15,6 +15,9 @@ from functools import partial import torch.nn as nn +from transformers.utils import logging + +logger = logging.get_logger(__name__) class GradientCheckpointingLayer(nn.Module): @@ -44,5 +47,21 @@ class GradientCheckpointingLayer(nn.Module): def __call__(self, *args, **kwargs): if self.gradient_checkpointing and self.training: + + do_warn = False + if "use_cache" in kwargs and kwargs["use_cache"]: + kwargs["use_cache"] = False + do_warn = True + + if "past_key_value" in kwargs and kwargs["past_key_value"] is not None: + kwargs["past_key_value"] = None + do_warn = True + + if do_warn: + logger.warning( + "Caching is incompatible with gradient checkpointing. " + "Setting `use_cache=False` and `past_key_value=None`." 
+ ) + return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args) return super().__call__(*args, **kwargs) From d7963dcccc0277ff32753c98366b6a7da2d3b8c3 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 19 Jun 2025 17:51:29 +0000 Subject: [PATCH 078/146] fixup --- src/transformers/modeling_layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_layers.py b/src/transformers/modeling_layers.py index b8807864e934..0542797cddf1 100644 --- a/src/transformers/modeling_layers.py +++ b/src/transformers/modeling_layers.py @@ -15,8 +15,10 @@ from functools import partial import torch.nn as nn + from transformers.utils import logging + logger = logging.get_logger(__name__) @@ -47,7 +49,6 @@ class GradientCheckpointingLayer(nn.Module): def __call__(self, *args, **kwargs): if self.gradient_checkpointing and self.training: - do_warn = False if "use_cache" in kwargs and kwargs["use_cache"]: kwargs["use_cache"] = False From 73d56148648e228e69827d9a23bb3f41d7fa304e Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 11:47:13 +0000 Subject: [PATCH 079/146] fix prophetnet --- src/transformers/models/prophetnet/modeling_prophetnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 739aeb68bf9d..eb0b2e59471a 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1389,8 +1389,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=extended_attention_mask, - encoder_hidden_states=encoder_hidden_states, + extended_attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attn_mask=extended_encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From cd7a42636fb69ffd866f7bc8f3eda1a77bf7a221 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 11:51:07 +0000 Subject: [PATCH 080/146] fix bigbird_pegasus --- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index ad8440f3e3dd..465b94e13bee 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2286,8 +2286,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From 56cb34bf3660d69b8ca61afa469c91ec4bff59da Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 11:52:00 +0000 Subject: [PATCH 081/146] fix blenderbot --- src/transformers/models/blenderbot/modeling_blenderbot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py 
b/src/transformers/models/blenderbot/modeling_blenderbot.py index ff28abd0398b..9711737da8c9 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1084,8 +1084,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From 0347dde272fb4f62e3eaa0fb6ec2bf8c1e577d72 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 11:53:37 +0000 Subject: [PATCH 082/146] fix mbart --- src/transformers/models/mbart/modeling_mbart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 4711af692558..2585d91a3e3e 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1124,8 +1124,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From e83086d104656fc1ff3560945483282ed86f1edc Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 12:06:20 +0000 Subject: [PATCH 083/146] fix mvp --- src/transformers/models/mvp/modeling_mvp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 9120912aeb1e..9e5136d27ad0 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -928,8 +928,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From afbfd620778483d3bfe3b2fda812964c32e46e2b Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 12:13:33 +0000 Subject: [PATCH 084/146] fix zamba2 --- .../models/zamba2/modeling_zamba2.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index b5067a5dd2e0..ce6a0a2ffd5e 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -982,7 +982,7 @@ def forward(self, hidden_state, layer_idx=None): return output -class Zamba2AttentionDecoderLayer(GradientCheckpointingLayer): +class Zamba2AttentionDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, block_id: Optional[int] = None, layer_idx: Optional[int] = None): super().__init__() self.block_id = block_id 
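Illustrative note (not part of any patch): the recurring "# as a positional argument for gradient checkpointing" comments all stem from the same constraint. GradientCheckpointingLayer.__call__ binds keyword arguments into a functools.partial, so only positional arguments reach torch.utils.checkpoint as inputs; tensors that must stay on the autograd path, such as encoder_hidden_states, are therefore moved into positional slots at the call sites. A minimal, self-contained sketch of the two call styles, using a hypothetical toy layer rather than any real model class:

import torch
import torch.nn as nn
from functools import partial
from torch.utils.checkpoint import checkpoint

class ToyCrossAttnLayer(nn.Module):
    # Hypothetical layer, only for illustration.
    def forward(self, hidden_states, encoder_hidden_states=None):
        out = hidden_states
        if encoder_hidden_states is not None:
            out = out + encoder_hidden_states
        return out

layer = ToyCrossAttnLayer()
hidden = torch.randn(2, 4, requires_grad=True)
encoder = torch.randn(2, 4, requires_grad=True)

# Keyword style: encoder_hidden_states is captured in the partial's closure and is
# not among the arguments that checkpoint() records as inputs for recomputation.
out_kw = checkpoint(partial(layer, encoder_hidden_states=encoder), hidden, use_reentrant=False)

# Positional style (what these patches switch the call sites to): both tensors are
# genuine checkpoint() inputs.
out_pos = checkpoint(layer, hidden, encoder, use_reentrant=False)
out_pos.sum().backward()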
@@ -1114,7 +1114,7 @@ def forward( return outputs -class Zamba2HybridLayer(nn.Module): +class Zamba2HybridLayer(GradientCheckpointingLayer): def __init__( self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer ): @@ -1352,14 +1352,14 @@ def forward( layer_outputs = layer( hidden_states, - original_hidden_states=original_hidden_states, - layer_idx=layer_idx, - attention_mask=attention_mask, - causal_mask=causal_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - position_embeddings=position_embeddings, + original_hidden_states, + layer_idx, + attention_mask, + causal_mask, + past_key_values, + output_attentions, + use_cache, + position_embeddings, ) hidden_states = layer_outputs[0] From 68f317c74f8bffe34d3b85c6fa9b7e0c0687be72 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 12:15:11 +0000 Subject: [PATCH 085/146] fix bart --- src/transformers/models/bart/modeling_bart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 54003f3aa3d9..994bf9d85dca 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1131,8 +1131,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From 98fb67068899faea6e32b49238524d262ac19411 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 12:24:14 +0000 Subject: [PATCH 086/146] fix blenderbot_small --- .../models/blenderbot_small/modeling_blenderbot_small.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 16f1d3b5f197..550e51221929 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1067,8 +1067,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, + causal_mask, + encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From 2fd38a375896d970dd1e2ca8490e10f57725e8b3 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 12:43:48 +0000 Subject: [PATCH 087/146] fix codegen --- src/transformers/models/codegen/modeling_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 2eb864715187..d00528b14069 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -439,7 +439,7 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( - hidden_states=hidden_states, + hidden_states, 
layer_past=past_key_values, attention_mask=causal_mask, position_ids=position_ids, From 53477671005a3650040a25b512f5fc47a2fa6c76 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 12:44:21 +0000 Subject: [PATCH 088/146] Update gradient checkpointing layer to support more past_key_values arg names --- src/transformers/modeling_layers.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_layers.py b/src/transformers/modeling_layers.py index 0542797cddf1..12084a0590f6 100644 --- a/src/transformers/modeling_layers.py +++ b/src/transformers/modeling_layers.py @@ -50,19 +50,34 @@ class GradientCheckpointingLayer(nn.Module): def __call__(self, *args, **kwargs): if self.gradient_checkpointing and self.training: do_warn = False + layer_name = self.__class__.__name__ + message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting" + if "use_cache" in kwargs and kwargs["use_cache"]: kwargs["use_cache"] = False + message += " `use_cache=False`," do_warn = True + # different names for the same thing in different layers if "past_key_value" in kwargs and kwargs["past_key_value"] is not None: kwargs["past_key_value"] = None + message += " `past_key_value=None`," do_warn = True + if "past_key_values" in kwargs and kwargs["past_key_values"] is not None: + kwargs["past_key_values"] = None + message += " `past_key_values=None`," + do_warn = True + + if "layer_past" in kwargs and kwargs["layer_past"] is not None: + kwargs["layer_past"] = None + message += " `layer_past=None`," + do_warn = True + + # warn if anything was changed if do_warn: - logger.warning( - "Caching is incompatible with gradient checkpointing. " - "Setting `use_cache=False` and `past_key_value=None`." - ) + message = message.rstrip(",") + "." 
+ logger.warning(message) return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args) return super().__call__(*args, **kwargs) From 10f5fd1f08c1f949066af0c19cebb596691ee067 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 12:48:17 +0000 Subject: [PATCH 089/146] fix data2vec vision --- .../models/data2vec/modeling_data2vec_vision.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 3b593a834db0..381d354e3e95 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -528,7 +528,7 @@ def forward( output_attentions: bool = False, relative_position_bias: Optional[torch.Tensor] = None, interpolate_pos_encoding: bool = False, - resolution: Optional[tuple[int]] = None, + resolution: Optional[tuple[int, int]] = None, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention @@ -702,11 +702,11 @@ def forward( layer_outputs = layer_module( hidden_states, - layer_head_mask, - output_attentions, - relative_position_bias, - interpolate_pos_encoding, - resolution, + head_mask=layer_head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, ) hidden_states = layer_outputs[0] From 792cd7dad56c79b32fb3c84d7313ff4e151b2389 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 12:53:06 +0000 Subject: [PATCH 090/146] fix deformable_detr --- .../deformable_detr/modeling_deformable_detr.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 4afa79405203..63ec6086c96e 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1263,14 +1263,14 @@ def forward( layer_outputs = decoder_layer( hidden_states, - position_embeddings=position_embeddings, - encoder_hidden_states=encoder_hidden_states, - reference_points=reference_points_input, - spatial_shapes=spatial_shapes, - spatial_shapes_list=spatial_shapes_list, - level_start_index=level_start_index, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, + position_embeddings, + reference_points_input, + spatial_shapes, + spatial_shapes_list, + level_start_index, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask, + output_attentions, ) hidden_states = layer_outputs[0] From 36415c37828e159dcfb6dff9dcbb032033430c44 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 12:56:35 +0000 Subject: [PATCH 091/146] fix gptj --- src/transformers/models/gptj/modeling_gptj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 721bfd7b179e..8189acfe76cf 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -735,7 +735,7 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) 
outputs = block( - hidden_states=hidden_states, + hidden_states, layer_past=past_key_values, attention_mask=causal_mask, position_ids=position_ids, From ff802fede13a43a1b42ca96403d34e7058bd1e6c Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 13:11:17 +0000 Subject: [PATCH 092/146] fix led --- src/transformers/models/led/modeling_led.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 76cb3ced8c5d..ad095cbcd472 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1934,8 +1934,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, + combined_attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From fc140147e4b91b6ab78361c74b31640d97a21ec2 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 13:21:50 +0000 Subject: [PATCH 093/146] fix m2m_100 --- src/transformers/models/m2m_100/modeling_m2m_100.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 358ab822a325..f3f621338b67 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1136,8 +1136,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( From f2cc8652fcde9a88840b36ac2418597cf7a27a9e Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 13:22:21 +0000 Subject: [PATCH 094/146] add comment --- src/transformers/models/m2m_100/modeling_m2m_100.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index f3f621338b67..7d5a73667ee4 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1137,7 +1137,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, attention_mask, - encoder_hidden_states, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( From eab402d373797f8064101e0ed05673c4ec8830c6 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 13:35:15 +0000 Subject: [PATCH 095/146] fix nnlb_moe --- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 8e8526a773bd..f498cf743fcc 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1290,8 +1290,8 @@ def forward( # under fsdp or deepspeed zero3 
all gpus must run in sync layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, From aa1f574f8d119ae4296df681c444f33444c5210c Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 13:37:22 +0000 Subject: [PATCH 096/146] Fix pegasus_x --- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index d0fdd8beded0..842e365fb827 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1382,8 +1382,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_values, output_attentions=output_attentions, From 7c9d17d65153fb963c44e28b0143b591d9eea44c Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 13:38:37 +0000 Subject: [PATCH 097/146] fix plbart --- src/transformers/models/plbart/modeling_plbart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 606870417086..327b70b5ec73 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1058,8 +1058,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From 5da2216e424ec43e98aef69fcf6b52a30757c867 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 14:07:18 +0000 Subject: [PATCH 098/146] udop --- src/transformers/models/udop/modeling_udop.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 79c5c2ca399b..9e9300c90b01 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -39,6 +39,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, @@ -743,7 +744,7 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop -class UdopBlock(nn.Module): +class UdopBlock(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -1295,9 +1296,9 @@ def 
forward( layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=head_mask[i], From 999584c420b2cff4ee18b61704c1e9cfa0a73b96 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 14:08:56 +0000 Subject: [PATCH 099/146] fix-copies: beit, wav2vec2 --- src/transformers/models/beit/modeling_beit.py | 12 ++++++------ .../models/wav2vec2/modeling_wav2vec2.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index e830e66285a5..347471fc7f7a 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -526,7 +526,7 @@ def forward( output_attentions: bool = False, relative_position_bias: Optional[torch.Tensor] = None, interpolate_pos_encoding: bool = False, - resolution: Optional[tuple[int]] = None, + resolution: Optional[tuple[int, int]] = None, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention @@ -698,11 +698,11 @@ def forward( layer_outputs = layer_module( hidden_states, - layer_head_mask, - output_attentions, - relative_position_bias, - interpolate_pos_encoding, - resolution, + head_mask=layer_head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 13fd216c67ac..bec9f46a44fc 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -272,7 +272,7 @@ def _sample_negative_indices( return sampled_negative_indices -class Wav2Vec2NoLayerNormConvLayer(nn.Module): +class Wav2Vec2NoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -293,7 +293,7 @@ def forward(self, hidden_states): return hidden_states -class Wav2Vec2LayerNormConvLayer(nn.Module): +class Wav2Vec2LayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -320,7 +320,7 @@ def forward(self, hidden_states): return hidden_states -class Wav2Vec2GroupNormConvLayer(nn.Module): +class Wav2Vec2GroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 From ff33682659aff2b6c7c8f73a61a114f379c4a15f Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 14:13:34 +0000 Subject: [PATCH 100/146] fix gpt_bigcode --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 
7e098e1e498f..851d03727de9 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -755,6 +755,12 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: From c28e913f8398e54363855fafbdd2176bbefd6df1 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 14:13:55 +0000 Subject: [PATCH 101/146] fixup --- src/transformers/modeling_layers.py | 4 ++-- .../models/deformable_detr/modeling_deformable_detr.py | 2 +- src/transformers/models/udop/modeling_udop.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_layers.py b/src/transformers/modeling_layers.py index 12084a0590f6..5179cfa6571e 100644 --- a/src/transformers/modeling_layers.py +++ b/src/transformers/modeling_layers.py @@ -52,7 +52,7 @@ def __call__(self, *args, **kwargs): do_warn = False layer_name = self.__class__.__name__ message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting" - + if "use_cache" in kwargs and kwargs["use_cache"]: kwargs["use_cache"] = False message += " `use_cache=False`," @@ -73,7 +73,7 @@ def __call__(self, *args, **kwargs): kwargs["layer_past"] = None message += " `layer_past=None`," do_warn = True - + # warn if anything was changed if do_warn: message = message.rstrip(",") + "." 
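Illustrative note (not part of any patch): with the modeling_layers.py changes above in place, GradientCheckpointingLayer.__call__ resets use_cache and whichever cache argument the layer uses (past_key_value, past_key_values, or layer_past) before running the layer under checkpointing, and emits a single warning naming the layer. A minimal sketch of the resulting behaviour; ToyLayer and the manual attribute assignments below are stand-ins for what model.gradient_checkpointing_enable() normally wires up:

import torch
import torch.nn as nn
from functools import partial
from torch.utils.checkpoint import checkpoint
from transformers.modeling_layers import GradientCheckpointingLayer

class ToyLayer(GradientCheckpointingLayer):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, hidden_states, use_cache=False, past_key_value=None):
        # By the time forward runs under checkpointing, the wrapper has already
        # reset these kwargs.
        assert use_cache is False and past_key_value is None
        return self.proj(hidden_states)

layer = ToyLayer()
layer.gradient_checkpointing = True  # normally set by gradient_checkpointing_enable()
layer._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)
layer.train()

x = torch.randn(2, 4, requires_grad=True)
# use_cache=True is dropped with one warning ("Caching is incompatible with gradient
# checkpointing in ToyLayer. Setting `use_cache=False`.") and the layer still trains.
out = layer(x, use_cache=True)
out.sum().backward()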
diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 63ec6086c96e..26298a0f6b21 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1268,7 +1268,7 @@ def forward( spatial_shapes, spatial_shapes_list, level_start_index, - encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask, output_attentions, ) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 9e9300c90b01..a27e08fe8872 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -38,8 +38,8 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_utils import PreTrainedModel from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, From 8104bfb9f09fd7dbec73bb683be4cade172cfd52 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 14:42:59 +0000 Subject: [PATCH 102/146] fix t5 --- src/transformers/models/t5/modeling_t5.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 34bc70fac7cc..98e382d58a83 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1106,11 +1106,11 @@ def forward( layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_values, From f9a2db8121376cfd3e7177993953137ba1e7f643 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 14:45:19 +0000 Subject: [PATCH 103/146] fix switch_transformers --- .../modeling_switch_transformers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 43984b9293a0..a7a75d98d167 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1030,11 +1030,11 @@ def forward( layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, 
layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_values, From fe1133ecf0c62e26c7a612c24aa43105ad23e9c9 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 14:56:19 +0000 Subject: [PATCH 104/146] fix longt5 --- src/transformers/models/longt5/modeling_longt5.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 79e069c50a08..41f441eb6ed7 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1506,11 +1506,11 @@ def forward( layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_values, From e51772a19ee8d0c6515ed012f0412b4f7f71c5f2 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 14:58:45 +0000 Subject: [PATCH 105/146] fix mt5 --- src/transformers/models/mt5/modeling_mt5.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index b552ed3e88f4..5584b2ee8255 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1091,11 +1091,11 @@ def forward( layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_values, From 69a6a784a8f3abc7cd37ede87e94e5de97c6c0e1 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:01:21 +0000 Subject: [PATCH 106/146] update tapas --- src/transformers/models/tapas/modeling_tapas.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 6d88de473664..903ff66dac9a 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -594,12 +594,12 @@ def forward( layer_outputs = layer_module( hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_values, - output_attentions, + attention_mask=attention_mask, + head_mask=layer_head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: From eb20826e0b53e48190e2075556d193bdff97dce1 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:05:19 +0000 
Subject: [PATCH 107/146] fix blip2 --- src/transformers/models/blip_2/modeling_blip_2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 4382296969e4..78cd751bd9d3 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -992,11 +992,11 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + query_length=query_length, ) hidden_states = layer_outputs[0] From eba9a9acdd3362c7647e7e8e6001df2c0ec62618 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:06:03 +0000 Subject: [PATCH 108/146] update blip --- src/transformers/models/blip/modeling_blip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 3967f8bce478..e43b79595ca3 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -552,7 +552,7 @@ def forward( layer_outputs = encoder_layer( hidden_states, - attention_mask, + attention_mask=attention_mask, output_attentions=output_attentions, ) From aa713092d601fef2d94ac7922f0de73aae1a887c Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:09:43 +0000 Subject: [PATCH 109/146] fix musicgen --- src/transformers/models/musicgen/modeling_musicgen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 15b664f6ad0c..54fd0b31bb3b 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -622,8 +622,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From c0b30846c3e6d58cc86411f20bf09745e4d300d2 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:19:06 +0000 Subject: [PATCH 110/146] fix gpt2, trocr --- src/transformers/models/gpt2/modeling_gpt2.py | 10 +++++----- src/transformers/models/trocr/modeling_trocr.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index ea0c62ef11d8..0de1032342f0 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -925,11 +925,11 @@ def forward( outputs = block( hidden_states, - past_key_value=past_key_values, - cache_position=cache_position, - attention_mask=causal_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, + past_key_values if not (self.gradient_checkpointing and self.training) else None, + cache_position, + causal_mask, + head_mask[i], + 
encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index c8617739d68e..fdc0cae068a8 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -646,8 +646,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From b6ac147d303300449e0de99c59bfdb682f057837 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:34:36 +0000 Subject: [PATCH 111/146] fix copies --- src/transformers/models/blip_2/modeling_blip_2.py | 2 +- .../models/instructblip/modeling_instructblip.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 78cd751bd9d3..04b20b7513a3 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -531,7 +531,7 @@ def forward( layer_outputs = encoder_layer( hidden_states, - attention_mask, + attention_mask=attention_mask, output_attentions=output_attentions, ) diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 1708a86082bc..bf2c76cf9e59 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -427,7 +427,7 @@ def forward( layer_outputs = encoder_layer( hidden_states, - attention_mask, + attention_mask=attention_mask, output_attentions=output_attentions, ) @@ -889,11 +889,11 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + query_length=query_length, ) hidden_states = layer_outputs[0] From 481ae6f0806b993001d58c47cb9689d681944aa0 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:38:58 +0000 Subject: [PATCH 112/146] !!! 
Revert zamba, mllama --- .../models/mllama/modeling_mllama.py | 70 ++++++++++++------- .../models/zamba/modeling_zamba.py | 41 +++++++---- 2 files changed, 73 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 69a1e703ea04..841674f0c7fb 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -27,7 +27,6 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -281,7 +280,7 @@ def forward( return attn_output, attn_weights -class MllamaVisionEncoderLayer(GradientCheckpointingLayer): +class MllamaVisionEncoderLayer(nn.Module): def __init__(self, config: MllamaVisionConfig, is_gated: bool = False): super().__init__() @@ -388,12 +387,19 @@ def forward( for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - - layer_outputs = encoder_layer( - hidden_state=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) @@ -663,7 +669,7 @@ def forward(self, x): # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer -class MllamaSelfAttentionDecoderLayer(GradientCheckpointingLayer): +class MllamaSelfAttentionDecoderLayer(nn.Module): def __init__(self, config: MllamaTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -748,7 +754,7 @@ def forward( return outputs -class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): """Cross-attention transformer block with tanh-gated attention and feedforward.""" def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: @@ -1396,20 +1402,36 @@ def forward( if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue - layer_outputs = decoder_layer( - hidden_states, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - attention_mask=causal_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + cross_attention_states, + cross_attention_mask, + causal_mask, + full_text_row_masked_out_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) 
+ else: + layer_outputs = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + attention_mask=causal_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 04f7b94494a7..ea832692b7b9 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -32,7 +32,6 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -581,7 +580,7 @@ def forward(self, x): return down_proj -class ZambaAttentionDecoderLayer(GradientCheckpointingLayer): +class ZambaAttentionDecoderLayer(nn.Module): def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None): super().__init__() self.self_attn = ZambaAttention(config, layer_idx) @@ -644,7 +643,7 @@ def forward( return outputs -class ZambaMambaDecoderLayer(GradientCheckpointingLayer): +class ZambaMambaDecoderLayer(nn.Module): def __init__(self, config: ZambaConfig, layer_idx: int): super().__init__() self.mamba = ZambaMambaMixer(config=config, layer_idx=layer_idx) @@ -976,17 +975,31 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = layer( - hidden_states, - original_hidden_states=original_hidden_states, - layer_idx=layer_idx, - attention_mask=attention_mask, - causal_mask=causal_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + original_hidden_states, + layer_idx, + attention_mask, + causal_mask, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = layer( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + causal_mask=causal_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if output_attentions: From 07e0995728515089dc62d39b6aba51a968c5c827 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:45:59 +0000 Subject: [PATCH 113/146] update autoformer --- src/transformers/models/autoformer/modeling_autoformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 46dde3f146c5..be8c8c621a1c 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1228,8 +1228,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, 
- attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From a2c8bd6f18eeb77a2ce85a9e23739357ab0b419e Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:46:07 +0000 Subject: [PATCH 114/146] update bros --- src/transformers/models/bros/modeling_bros.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index dfe36f7a353b..965782ac08b3 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -552,11 +552,11 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states=hidden_states, - bbox_pos_emb=bbox_pos_emb, - attention_mask=attention_mask, - head_mask=layer_head_mask, - encoder_hidden_states=encoder_hidden_states, + hidden_states, + bbox_pos_emb, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, From 31607539a909064377bf4f66722921949c2387b4 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:51:57 +0000 Subject: [PATCH 115/146] update args / kwargs for BERT and copies --- src/transformers/models/align/modeling_align.py | 8 ++++---- src/transformers/models/bert/modeling_bert.py | 8 ++++---- .../models/bert_generation/modeling_bert_generation.py | 8 ++++---- .../models/chinese_clip/modeling_chinese_clip.py | 8 ++++---- src/transformers/models/clap/modeling_clap.py | 8 ++++---- .../models/data2vec/modeling_data2vec_text.py | 8 ++++---- src/transformers/models/electra/modeling_electra.py | 8 ++++---- src/transformers/models/ernie/modeling_ernie.py | 8 ++++---- src/transformers/models/layoutlm/modeling_layoutlm.py | 8 ++++---- src/transformers/models/markuplm/modeling_markuplm.py | 8 ++++---- src/transformers/models/roberta/modeling_roberta.py | 8 ++++---- .../roberta_prelayernorm/modeling_roberta_prelayernorm.py | 8 ++++---- src/transformers/models/roc_bert/modeling_roc_bert.py | 8 ++++---- src/transformers/models/splinter/modeling_splinter.py | 8 ++++---- 14 files changed, 56 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 55ce72472264..6ff99d6a4918 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -958,10 +958,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 743fba8a0902..e508a98614af 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py 
@@ -652,10 +652,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 64a3d8a6200d..bd65e88ae855 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -406,10 +406,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index f73518458c22..8d676fc50139 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -821,10 +821,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 1a9e5a250af2..973ae11fd37f 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1481,10 +1481,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 03e286a89dfd..97bca6d0d69e 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -506,10 +506,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 68ce41f32d27..81eb2d894d61 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -567,10 +567,10 @@ def 
forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 5171bfc73b99..dda93fb81c9d 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -492,10 +492,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 97df2506eac0..fffbe7b061a9 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -489,10 +489,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 3a1a9517e98e..80a6011a4eef 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -649,10 +649,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 6eef0f0fa548..ecf3a6cc5314 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -608,10 +608,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index b63ebef24b99..e8636281e399 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -496,10 +496,10 @@ def forward( hidden_states, 
attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 222f7597964c..3985e86e0b36 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -619,10 +619,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index ba579e7d6266..1d65ec5b954c 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -461,10 +461,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] From 1f3d7b0a24240e0f67d6d59310ba8ba6dc20abd9 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:53:24 +0000 Subject: [PATCH 116/146] 2nd round of updates --- src/transformers/models/altclip/modeling_altclip.py | 8 ++++---- .../models/bridgetower/modeling_bridgetower.py | 8 ++++---- src/transformers/models/camembert/modeling_camembert.py | 8 ++++---- .../models/xlm_roberta/modeling_xlm_roberta.py | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 76fdf62a3517..41d4c595c8d1 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -549,10 +549,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 7af3bc5a6267..bb9ac6c1bd37 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -793,10 +793,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + 
output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 5c49acc7a80a..3dff8f3b2cf0 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -609,10 +609,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index cee4fa3837c5..f66396b135c5 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -609,10 +609,10 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] From 32433aacc96076b842c88bb84f2f5548a3dfdb15 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 15:56:37 +0000 Subject: [PATCH 117/146] update conditional detr --- .../conditional_detr/modeling_conditional_detr.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 87ddbf7a3225..2042817a2107 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1298,13 +1298,14 @@ def forward( pos_transformation = self.query_scale(hidden_states) # apply transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation + layer_outputs = decoder_layer( hidden_states, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - query_sine_embed=query_sine_embed, - encoder_hidden_states=encoder_hidden_states, + None, # attention_mask + object_queries, + query_position_embeddings, + query_sine_embed, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, is_first=(idx == 0), From 95365b167bfceb786f7abd7c65d0757c5c45919a Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 17:48:27 +0000 Subject: [PATCH 118/146] Pass encoder_hidden_states as positional arg --- src/transformers/models/dab_detr/modeling_dab_detr.py | 10 +++++----- .../modeling_decision_transformer.py | 10 +++++----- src/transformers/models/detr/modeling_detr.py | 8 ++++---- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 8 ++++---- src/transformers/models/imagegpt/modeling_imagegpt.py | 8 ++++---- src/transformers/models/informer/modeling_informer.py | 4 ++-- src/transformers/models/kosmos2/modeling_kosmos2.py | 4 ++-- src/transformers/models/marian/modeling_marian.py | 4 ++-- 
.../models/mask2former/modeling_mask2former.py | 8 ++++---- .../models/maskformer/modeling_maskformer.py | 8 ++++---- .../models/moonshine/modeling_moonshine.py | 4 ++-- src/transformers/models/pegasus/modeling_pegasus.py | 4 ++-- .../models/pix2struct/modeling_pix2struct.py | 10 +++++----- .../models/pop2piano/modeling_pop2piano.py | 10 +++++----- .../models/seamless_m4t/modeling_seamless_m4t.py | 4 ++-- .../models/seamless_m4t_v2/modeling_seamless_m4t_v2.py | 4 ++-- .../models/speech_to_text/modeling_speech_to_text.py | 4 ++-- src/transformers/models/speecht5/modeling_speecht5.py | 4 ++-- .../table_transformer/modeling_table_transformer.py | 8 ++++---- src/transformers/models/tapas/modeling_tapas.py | 6 +++--- .../modeling_time_series_transformer.py | 4 ++-- src/transformers/models/umt5/modeling_umt5.py | 4 ++-- src/transformers/models/xglm/modeling_xglm.py | 4 ++-- 23 files changed, 71 insertions(+), 71 deletions(-) diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index 5b177342472c..3f28bb95767b 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -1132,11 +1132,11 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_pos, - query_sine_embed=query_sine_embed, - encoder_hidden_states=encoder_hidden_states, + None, # attention_mask + object_queries, + query_pos, + query_sine_embed, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=memory_key_padding_mask, output_attentions=output_attentions, ) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index ca715532fb3c..6555c03919e3 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -657,11 +657,11 @@ def forward( outputs = block( hidden_states, - past_key_value=past_key_values, - cache_position=cache_position, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, + past_key_values if not (self.gradient_checkpointing and self.training) else None, + cache_position, + attention_mask, + head_mask[i], + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index e01e629f3392..e52ab48cdf26 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1048,10 +1048,10 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, + combined_attention_mask, + object_queries, + query_position_embeddings, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 
851d03727de9..1f6e937542b1 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -895,10 +895,10 @@ def forward( outputs = block( hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, + layer_past, + attention_mask, + head_mask[i], + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index b82e7ff04478..f17fa319d9e4 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -723,10 +723,10 @@ def forward( outputs = block( hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, + layer_past, + attention_mask, + head_mask[i], + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index c0f4eddb1a3c..e78ada8780cb 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1290,8 +1290,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 74a62e99836f..a78ad47f2dd4 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1132,8 +1132,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 0e3ee3c62415..7319671b485e 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1081,8 +1081,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), diff --git a/src/transformers/models/mask2former/modeling_mask2former.py 
b/src/transformers/models/mask2former/modeling_mask2former.py index 5827fcdb0e3f..b0f5fe029700 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1867,10 +1867,10 @@ def forward( layer_outputs = decoder_layer( hidden_states, - level_index=level_index, - position_embeddings=multi_stage_positional_embeddings, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, + level_index, + multi_stage_positional_embeddings, + query_position_embeddings, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=attention_mask, output_attentions=output_attentions, ) diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index b8c9caf9c989..b842adba4156 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -745,10 +745,10 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, + None, # attention_mask + object_queries, + query_position_embeddings, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index d8f61fbf5021..2909fb386fb5 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -781,9 +781,9 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index a95da766eb0c..2ffb53ee9e01 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1129,8 +1129,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 254501f44cd3..f68ccb46c72d 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1244,11 +1244,11 @@ def forward( layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, 
+ causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_values, diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 41b2f7d04800..edf7458b7cbe 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -819,11 +819,11 @@ def forward( layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_values, diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 0e0a0312e6ce..65feeb2d2226 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1851,8 +1851,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 2245b795304a..7427f1dfab2d 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -1922,8 +1922,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index a2adee4f0946..95e9c887a51e 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -935,8 +935,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 
e854980fd68a..9dfb26538289 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1623,8 +1623,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 7722c476c396..d55a1be0c55a 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -992,10 +992,10 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, + combined_attention_mask, + object_queries, + query_position_embeddings, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 903ff66dac9a..c2660b6895a0 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -594,9 +594,9 @@ def forward( layer_outputs = layer_module( hidden_states, - attention_mask=attention_mask, - head_mask=layer_head_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_values, output_attentions=output_attentions, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 5cf5993309e5..8fbbb1fd444a 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -1060,8 +1060,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 2c4733d0977d..2b1f650c6789 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -768,8 +768,8 @@ def forward( layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing 
encoder_attention_mask=encoder_extended_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 562821e7ec31..65e3e6284046 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -550,8 +550,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), From aad2b9e109a8d53182edf0720b2e77029656d581 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 17:49:30 +0000 Subject: [PATCH 119/146] Update to pass encoder_decoder_position_bias as positional arg --- src/transformers/models/udop/modeling_udop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index a27e08fe8872..62c0b5db34f0 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1298,9 +1298,9 @@ def forward( hidden_states, causal_mask, position_bias, - encoder_hidden_states, # as a positional argument for gradient checkpointing - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=head_mask[i], past_key_value=past_key_values, use_cache=use_cache, From d34726db7c95d46d88d777c1ca6d8b242aa83b21 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 17:49:57 +0000 Subject: [PATCH 120/146] fixup --- src/transformers/models/dab_detr/modeling_dab_detr.py | 2 +- src/transformers/models/maskformer/modeling_maskformer.py | 2 +- src/transformers/models/udop/modeling_udop.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index 3f28bb95767b..47b67f4f7c93 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -1132,7 +1132,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, - None, # attention_mask + None, # attention_mask object_queries, query_pos, query_sine_embed, diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index b842adba4156..18d36427d921 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -745,7 +745,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, - None, # attention_mask + None, # attention_mask object_queries, query_position_embeddings, encoder_hidden_states, # as a positional argument for gradient checkpointing diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 62c0b5db34f0..f8dc9676693e 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ 
b/src/transformers/models/udop/modeling_udop.py @@ -1298,9 +1298,9 @@ def forward( hidden_states, causal_mask, position_bias, - encoder_hidden_states, + encoder_hidden_states, encoder_extended_attention_mask, - encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=head_mask[i], past_key_value=past_key_values, use_cache=use_cache, From 55011beacee748954b3f7770c15f3759a529ec11 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 17:56:18 +0000 Subject: [PATCH 121/146] biogpt modular --- .../models/biogpt/modular_biogpt.py | 35 ++++++------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index d639f44ffec5..3b18890b2cb6 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -473,30 +473,17 @@ def forward( if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - head_mask[idx] if head_mask is not None else None, - None, - output_attentions, - use_cache, - position_ids, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - position_ids=position_ids, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] From a71d201c3652d622bbda39d347f5f59e176248e7 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 17:57:39 +0000 Subject: [PATCH 122/146] modular gemma2 --- .../models/gemma2/modular_gemma2.py | 38 +++++++------------ 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 9890711d3c1f..2496b759f0a0 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -26,6 +26,7 @@ from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging @@ -303,7 +304,7 @@ def forward( return attn_output, attn_weights -class Gemma2DecoderLayer(nn.Module): +class Gemma2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -449,30 +450,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = 
self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] From 0d7285731f5ef70e4deaabfa523eeee38f6f3a1d Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 17:59:12 +0000 Subject: [PATCH 123/146] modular gemma3 --- .../models/gemma3/modular_gemma3.py | 41 +++++++------------ 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index d93d53c8e93c..2efbace683fe 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -28,6 +28,7 @@ from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack @@ -443,7 +444,7 @@ def forward( return attn_output, attn_weights -class Gemma3DecoderLayer(nn.Module): +class Gemma3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Gemma3TextConfig, layer_idx: int): super().__init__() self.config = config @@ -632,32 +633,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings_global, - position_embeddings_local, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + 
cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] From 522df437cc48663c3e98b93f30e503871cafa05f Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:00:47 +0000 Subject: [PATCH 124/146] modular gpt_neox --- .../models/gpt_neox/modular_gpt_neox.py | 41 +++++++------------ 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 4922a4e3b4c3..0dc1058e3885 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -17,6 +17,7 @@ TokenClassifierOutput, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from ..llama.modeling_llama import LlamaModel, LlamaPreTrainedModel, LlamaRotaryEmbedding, rotate_half @@ -177,7 +178,7 @@ def forward( return attn_output, attn_weights -class GPTNeoXLayer(nn.Module): +class GPTNeoXLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -362,32 +363,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - causal_mask, - position_ids, - head_mask[i], - use_cache, - past_key_values, - output_attentions, - cache_position, - position_embeddings, - ) - else: - outputs = layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - layer_past=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + outputs = layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + layer_past=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = outputs[0] if output_attentions: From 30a9a90b3745585b3d7de3fc34c6e16b86a3141a Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:03:00 +0000 Subject: [PATCH 125/146] modular informer --- .../models/informer/modeling_informer.py | 2 +- .../models/informer/modular_informer.py | 33 +++++++------------ 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index e78ada8780cb..c207a52df35c 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1291,7 +1291,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, attention_mask, - encoder_hidden_states, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 
755fcd68853a..6e8b5107d0b0 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -31,6 +31,7 @@ BaseModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import ( auto_docstring, @@ -433,7 +434,7 @@ def forward( # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py -class InformerConvLayer(nn.Module): +class InformerConvLayer(GradientCheckpointingLayer): def __init__(self, c_in): super().__init__() self.downConv = nn.Conv1d( @@ -610,27 +611,15 @@ def forward( if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - if conv_layer is not None: - output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0]) - layer_outputs = (output,) + layer_outputs[1:] - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) - if conv_layer is not None: - output = conv_layer(layer_outputs[0]) - layer_outputs = (output,) + layer_outputs[1:] + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] hidden_states = layer_outputs[0] From 89a2c6842a629899506002064bed4dcf505bb900 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:03:51 +0000 Subject: [PATCH 126/146] modular internvl --- src/transformers/models/internvl/modular_internvl.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 90576676b3cb..1fbdfbc6bf39 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -26,6 +26,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging, torch_int from ..clip.modeling_clip import CLIPMLP @@ -334,7 +335,7 @@ class InternVLVisionMLP(CLIPMLP): NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm} -class InternVLVisionLayer(nn.Module): +class InternVLVisionLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: InternVLVisionConfig) -> None: @@ -403,12 +404,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, hidden_states, output_attentions - ) - else: - layer_outputs = layer_module(hidden_states, 
output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions) hidden_states = layer_outputs[0] From 633378c3940b8e3bcdb845e86f6a60f298e80593 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:05:21 +0000 Subject: [PATCH 127/146] modular mixtral --- .../models/mixtral/modeling_mixtral.py | 1 - .../models/mixtral/modular_mixtral.py | 41 +++++++------------ 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 8f82b59e5e47..013c958da52f 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -23,7 +23,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Callable, Optional, Union import torch diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index bfc78597cfe7..abbec2f74906 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -32,6 +32,7 @@ from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import LossKwargs, logging from ..mistral.modeling_mistral import ( @@ -226,7 +227,7 @@ class MixtralAttention(MistralAttention): pass -class MixtralDecoderLayer(nn.Module): +class MixtralDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -386,32 +387,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] From 8493bada7174ddd26b536dd3c58f933c9a8af33c Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:07:19 +0000 Subject: [PATCH 128/146] modular mlcd --- src/transformers/models/mlcd/modular_mlcd.py | 21 ++++++-------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index d18b2346224a..412d34daa5f7 
100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -356,21 +356,12 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - position_embeddings, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] From adf5c60fc95372117293395ad66f8d6d6fcc3cb0 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:08:46 +0000 Subject: [PATCH 129/146] modular modernbert --- .../models/modernbert/modular_modernbert.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index ff46a523a6cc..b692ee6db875 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -34,6 +34,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...utils import auto_docstring, is_flash_attn_2_available, logging from ...utils.import_utils import is_triton_available from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb @@ -710,7 +711,7 @@ def forward( return (hidden_states,) + attn_outputs[1:] # add attentions if outputted -class ModernBertEncoderLayer(nn.Module): +class ModernBertEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config @@ -994,27 +995,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - sliding_window_mask, - position_ids, - cu_seqlens, - max_seqlen, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - sliding_window_mask=sliding_window_mask, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions and len(layer_outputs) > 1: all_self_attentions = all_self_attentions + (layer_outputs[1],) From 6270ff78e33d49268176f2c94c568fadb115f356 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:09:41 +0000 Subject: [PATCH 130/146] modular phi --- src/transformers/models/phi/modular_phi.py | 38 ++++++++-------------- 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/phi/modular_phi.py 
b/src/transformers/models/phi/modular_phi.py index c515e13e7231..e08c72532a34 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -11,6 +11,7 @@ BaseModelOutputWithPast, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import logging from ..clip.modeling_clip import CLIPMLP @@ -118,7 +119,7 @@ class PhiMLP(CLIPMLP): pass -class PhiDecoderLayer(nn.Module): +class PhiDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() self.self_attn = PhiAttention(config, layer_idx=layer_idx) @@ -261,30 +262,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] From 3ad1fa9e5ec7875798b4f52f76c9bc51514ff7ee Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:12:45 +0000 Subject: [PATCH 131/146] modular qwen2_5_omni --- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 8 ++--- .../qwen2_5_omni/modular_qwen2_5_omni.py | 30 +++++-------------- 2 files changed, 9 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index c01ae5d83be9..49d4805df9bf 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -889,12 +889,8 @@ def forward( ) ).to(torch.int32) - for idx, encoder_layer in enumerate(self.layers): - layer_outputs = encoder_layer( - hidden_states, - cu_seqlens, - ) - + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states, cu_seqlens) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 10edb4e6a439..78f15bf58733 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1900,19 +1900,8 @@ def forward( ) ).to(torch.int32) - for idx, encoder_layer in enumerate(self.layers): - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - cu_seqlens, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - cu_seqlens, - ) - + for encoder_layer in self.layers: + layer_outputs = 
encoder_layer(hidden_states, cu_seqlens) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) @@ -2166,16 +2155,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb - ) - else: - hidden_states = blk( - hidden_states, - cu_seqlens=cu_seqlens_now, - rotary_pos_emb=rotary_pos_emb, - ) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) hidden_states = hidden_states[reverse_indices, :] From 7626b313807ecfb47cf5d747d1070a431fc655f3 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:14:27 +0000 Subject: [PATCH 132/146] modular qwen2_5_vl --- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 71764b245633..a6dedefa0195 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -50,6 +50,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput from ...modeling_flash_attention_utils import is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import MultiModalData, ProcessingKwargs, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_torchdynamo_compiling, logging @@ -205,7 +206,7 @@ class Qwen2_5_VLVisionSdpaAttention(VisionSdpaAttention): } -class Qwen2_5_VLVisionBlock(nn.Module): +class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) @@ -395,12 +396,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. 
cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) From e3c61cef89e809f585696aff206a3d29b7857be6 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:15:27 +0000 Subject: [PATCH 133/146] modular sam_hq --- src/transformers/models/sam_hq/modeling_sam_hq.py | 4 +--- src/transformers/models/sam_hq/modular_sam_hq.py | 12 ++---------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index a0ff9c309673..982bbbb47e07 100644 --- a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -540,12 +540,10 @@ def forward( all_self_attentions = () if output_attentions else None intermediate_embeddings = [] - for i, layer_module in enumerate(self.layers): + for layer_module in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) - hidden_states = layer_outputs[0] # Collect embeddings from non-windowed blocks diff --git a/src/transformers/models/sam_hq/modular_sam_hq.py b/src/transformers/models/sam_hq/modular_sam_hq.py index a78ce712cc0d..55f475880cab 100644 --- a/src/transformers/models/sam_hq/modular_sam_hq.py +++ b/src/transformers/models/sam_hq/modular_sam_hq.py @@ -151,18 +151,10 @@ def forward( all_self_attentions = () if output_attentions else None intermediate_embeddings = [] - for i, layer_module in enumerate(self.layers): + for layer_module in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) - + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] # Collect embeddings from non-windowed blocks From 01934cf256e7c452e3646e4ad09ee0b62cfce8c6 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:16:39 +0000 Subject: [PATCH 134/146] modular sew --- src/transformers/models/sew/modular_sew.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index 0b151b05a2ce..2d56fea3bc6e 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -230,17 +230,9 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, 
attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: From 62683dc486f32426f0259c9d0451ad7d7328b88f Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:18:18 +0000 Subject: [PATCH 135/146] wav2vec2_bert --- .../wav2vec2_bert/modular_wav2vec2_bert.py | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py index 3427a01808a3..e9e7691f466c 100644 --- a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py @@ -18,6 +18,7 @@ XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...utils import auto_docstring, logging from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2FeedForward, Wav2Vec2ForSequenceClassification, Wav2Vec2Model from ..wav2vec2_conformer.modeling_wav2vec2_conformer import ( @@ -292,7 +293,7 @@ def forward( return hidden_states, probs -class Wav2Vec2BertEncoderLayer(nn.Module): +class Wav2Vec2BertEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" def __init__(self, config): @@ -418,23 +419,13 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - conv_attention_mask, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - conv_attention_mask=conv_attention_mask, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) hidden_states = layer_outputs[0] if skip_the_layer: From 28bd09c8d53b1ddca2fa0eb4995acf2c43ffca10 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:19:21 +0000 Subject: [PATCH 136/146] modular wav2vec2_conformer --- .../modular_wav2vec2_conformer.py | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py index fc3444a545f5..d99b9eb9bd84 100644 --- a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py @@ -10,6 +10,7 @@ from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...utils import ModelOutput, auto_docstring, logging from ..wav2vec2.modeling_wav2vec2 import ( Wav2Vec2Adapter, @@ -384,7 +385,7 @@ def _apply_relative_embeddings(self, query, key, relative_position_embeddings): 
return scores -class Wav2Vec2ConformerEncoderLayer(nn.Module): +class Wav2Vec2ConformerEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" def __init__(self, config): @@ -511,21 +512,12 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if skip_the_layer: From 5bc6525140050d2d12d5cfd6464e0ceaa6088505 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:21:38 +0000 Subject: [PATCH 137/146] modular wavlm --- .../models/wavlm/modular_wavlm.py | 49 ++++++------------- 1 file changed, 16 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py index 53d29edc0e7d..e6012c70a8a8 100644 --- a/src/transformers/models/wavlm/modular_wavlm.py +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -9,6 +9,7 @@ from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...modeling_layers import GradientCheckpointingLayer from ...utils import logging from ..wav2vec2.modeling_wav2vec2 import ( Wav2Vec2FeatureProjection, @@ -205,7 +206,7 @@ class WavLMFeedForward(Wav2Vec2FeedForward): pass -class WavLMEncoderLayer(nn.Module): +class WavLMEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True): super().__init__() self.attention = WavLMAttention( @@ -246,7 +247,7 @@ def forward(self, hidden_states, attention_mask=None, position_bias=None, output return outputs -class WavLMEncoderLayerStableLayerNorm(nn.Module): +class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True): super().__init__() self.attention = WavLMAttention( @@ -329,22 +330,13 @@ def forward( skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - position_bias, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - index=i, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + index=i, + ) hidden_states, position_bias = layer_outputs[:2] @@ -415,21 +407,12 @@ def forward( if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus 
must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - position_bias, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - position_bias=position_bias, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + position_bias=position_bias, + ) hidden_states, position_bias = layer_outputs[:2] if skip_the_layer: From 31dbec4d91dcf83af60f6f2387b9e2a9638222d6 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:22:05 +0000 Subject: [PATCH 138/146] fixup --- src/transformers/models/biogpt/modular_biogpt.py | 1 - src/transformers/models/gemma2/modular_gemma2.py | 3 +-- src/transformers/models/gemma3/modular_gemma3.py | 3 +-- src/transformers/models/gpt_neox/modular_gpt_neox.py | 2 +- src/transformers/models/informer/modular_informer.py | 2 +- src/transformers/models/internvl/modular_internvl.py | 2 +- src/transformers/models/mixtral/modular_mixtral.py | 3 +-- src/transformers/models/modernbert/modular_modernbert.py | 2 +- src/transformers/models/phi/modular_phi.py | 3 +-- src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py | 2 +- .../models/wav2vec2_conformer/modular_wav2vec2_conformer.py | 2 +- src/transformers/models/wavlm/modular_wavlm.py | 2 +- 12 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 3b18890b2cb6..938b1c9d8beb 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -15,7 +15,6 @@ """PyTorch BioGPT model.""" import math -from functools import partial from typing import Optional, Union import torch diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 2496b759f0a0..b317936c7766 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from functools import partial from typing import Callable, Optional, Union import torch @@ -25,8 +24,8 @@ from ...configuration_utils import PretrainedConfig, layer_type_validation from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 2efbace683fe..bc1db4b50a45 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -16,7 +16,6 @@ import copy from collections.abc import Callable from dataclasses import dataclass -from functools import partial from typing import Any, Optional, Union import torch @@ -27,8 +26,8 @@ from ...configuration_utils import PretrainedConfig, layer_type_validation from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 0dc1058e3885..fde2677b4e20 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -9,6 +9,7 @@ from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -17,7 +18,6 @@ TokenClassifierOutput, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from ..llama.modeling_llama import LlamaModel, LlamaPreTrainedModel, LlamaRotaryEmbedding, rotate_half diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 6e8b5107d0b0..3d46275bdc81 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -27,11 +27,11 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import ( auto_docstring, diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 
1fbdfbc6bf39..a71b9fbdad81 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -24,9 +24,9 @@ from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging, torch_int from ..clip.modeling_clip import CLIPMLP diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index abbec2f74906..c4e4a4296663 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -19,7 +19,6 @@ # limitations under the License. """PyTorch Mixtral model.""" -from functools import partial from typing import Optional, Union import torch @@ -31,8 +30,8 @@ from ...cache_utils import DynamicCache from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...processing_utils import Unpack from ...utils import LossKwargs, logging from ..mistral.modeling_mistral import ( diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index b692ee6db875..e5c0ce845d04 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...configuration_utils import PretrainedConfig from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -34,7 +35,6 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...utils import auto_docstring, is_flash_attn_2_available, logging from ...utils.import_utils import is_triton_available from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index e08c72532a34..46a367bbdb10 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -1,4 +1,3 @@ -from functools import partial from typing import Callable, Optional import torch @@ -7,11 +6,11 @@ from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import Unpack from ...utils import logging from ..clip.modeling_clip import CLIPMLP diff --git 
a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py index e9e7691f466c..d0f375332b33 100644 --- a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py @@ -9,6 +9,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -18,7 +19,6 @@ XVectorOutput, ) from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...utils import auto_docstring, logging from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2FeedForward, Wav2Vec2ForSequenceClassification, Wav2Vec2Model from ..wav2vec2_conformer.modeling_wav2vec2_conformer import ( diff --git a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py index d99b9eb9bd84..3436563c0db8 100644 --- a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py @@ -8,9 +8,9 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...utils import ModelOutput, auto_docstring, logging from ..wav2vec2.modeling_wav2vec2 import ( Wav2Vec2Adapter, diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py index e6012c70a8a8..aac25ff262bb 100644 --- a/src/transformers/models/wavlm/modular_wavlm.py +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -7,9 +7,9 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...modeling_layers import GradientCheckpointingLayer from ...utils import logging from ..wav2vec2.modeling_wav2vec2 import ( Wav2Vec2FeatureProjection, From b989ba65597dcfd6a4d6f207623c8ff6b183dd05 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:29:17 +0000 Subject: [PATCH 139/146] Update by modular instructblipvideo --- .../instructblipvideo/modeling_instructblipvideo.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 8c0d5c05a3b0..ee9cffd4f2e8 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -356,7 +356,7 @@ def forward( layer_outputs = encoder_layer( hidden_states, - attention_mask, + attention_mask=attention_mask, output_attentions=output_attentions, ) @@ -750,11 +750,11 @@ def forward( hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - 
encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + query_length=query_length, ) hidden_states = layer_outputs[0] From cdb4c7062f6e5effca31bced47bd5903d2c3c27c Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:32:17 +0000 Subject: [PATCH 140/146] modular data2vec_audio --- src/transformers/models/data2vec/modular_data2vec_audio.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 0b4695c1e28c..94a4d3e080af 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -20,6 +20,7 @@ from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel from ..wav2vec2.modeling_wav2vec2 import ( @@ -38,7 +39,7 @@ from .configuration_data2vec_audio import Data2VecAudioConfig -class Data2VecAudioConvLayer(nn.Module): +class Data2VecAudioConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 From d50dd86ca3e5d8e98e4830fab2ad9551ab0cb304 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:33:08 +0000 Subject: [PATCH 141/146] nit modular mistral --- src/transformers/models/mixtral/modeling_mixtral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 013c958da52f..8f82b59e5e47 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -23,6 +23,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from typing import Callable, Optional, Union import torch From 4c5aa0ba7d26eec5f0f93c474cf3e4ded6323391 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:35:44 +0000 Subject: [PATCH 142/146] apply modular minimax --- src/transformers/models/minimax/modeling_minimax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 1bf968e0361c..34e0b507f446 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -32,6 +32,7 @@ from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, MoeCausalLMOutputWithPast, @@ -485,7 +486,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -class MiniMaxDecoderLayer(nn.Module): +class MiniMaxDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MiniMaxConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size From 6585288e1aac3831d26c89828f2e8c7036745c81 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:38:35 +0000 Subject: [PATCH 143/146] fix modular moonshine --- src/transformers/models/moonshine/modular_moonshine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 4ee7cd81f772..500231f3b48b 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -787,9 +787,9 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, From 4ac7c961eb5ef9458533791d561a970d02f958de Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 18:44:15 +0000 Subject: [PATCH 144/146] revert zamba2 --- .../models/zamba2/modeling_zamba2.py | 41 ++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index ce6a0a2ffd5e..ecd0abcb0263 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -33,7 +33,6 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -1046,7 +1045,7 @@ def forward( return outputs -class Zamba2MambaDecoderLayer(GradientCheckpointingLayer): +class Zamba2MambaDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, layer_idx: int): 
super().__init__() self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx) @@ -1114,7 +1113,7 @@ def forward( return outputs -class Zamba2HybridLayer(GradientCheckpointingLayer): +class Zamba2HybridLayer(nn.Module): def __init__( self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer ): @@ -1350,17 +1349,31 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = layer( - hidden_states, - original_hidden_states, - layer_idx, - attention_mask, - causal_mask, - past_key_values, - output_attentions, - use_cache, - position_embeddings, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + original_hidden_states, + layer_idx, + attention_mask, + causal_mask, + past_key_values, + output_attentions, + use_cache, + position_embeddings, + ) + else: + layer_outputs = layer( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + causal_mask=causal_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] if output_attentions: From 58847e787aec0f75c28cd047412883ee9f3da265 Mon Sep 17 00:00:00 2001 From: qubvel Date: Fri, 20 Jun 2025 20:24:07 +0000 Subject: [PATCH 145/146] fix mask2former --- src/transformers/models/mask2former/modeling_mask2former.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index b0f5fe029700..5ab37ea53581 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1868,6 +1868,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, level_index, + None, # attention_mask multi_stage_positional_embeddings, query_position_embeddings, encoder_hidden_states, # as a positional argument for gradient checkpointing From 9b8e96553438ce407059617858b549afb6e39d82 Mon Sep 17 00:00:00 2001 From: qubvel Date: Mon, 23 Jun 2025 10:09:34 +0000 Subject: [PATCH 146/146] refactor idefics --- .../models/idefics/modeling_idefics.py | 100 ++++-------------- 1 file changed, 19 insertions(+), 81 deletions(-) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index d3bba25a5642..a5a868072a6a 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -32,6 +32,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel from ...processing_utils import Unpack @@ -668,7 +669,7 @@ def forward( # this was adapted from LlamaDecoderLayer -class IdeficsDecoderLayer(nn.Module): +class IdeficsDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size @@ -749,7 +750,7 @@ def forward( return outputs -class IdeficsGatedCrossAttentionLayer(nn.Module): +class 
IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer): def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size @@ -1185,95 +1186,32 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - def vblock( - main_block, - hidden_states, - attention_mask, - position_ids, - past_key_value, - image_hidden_states, - image_attention_mask, - cross_attention_gate, - output_attentions, - use_cache, - layer_idx, - cross_layer_interval, - gated_cross_attn_layers, - cache_position, - ): - # TODO(ls): Add cross attention values to respective lists - if layer_idx % cross_layer_interval == 0: - xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval] - outputs = xblock( - hidden_states, - attention_mask=attention_mask, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - cross_attention_gate=cross_attention_gate, - output_attentions=output_attentions, - use_cache=use_cache, - past_key_value=None, # not implemented - **kwargs, - ) - hidden_states = outputs[0] - - layer_outputs = main_block( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - return layer_outputs - - if self.gradient_checkpointing and self.training: - past_key_values = None - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - layer_outputs = self._gradient_checkpointing_func( - vblock, - decoder_layer, + # TODO(ls): Add cross attention values to respective lists + if idx % self.cross_layer_interval == 0: + cross_attn_block = self.gated_cross_attn_layers[idx // self.cross_layer_interval] + outputs = cross_attn_block( hidden_states, attention_mask, - position_ids, - past_key_values, image_hidden_states, - image_attention_mask, - cross_attention_gate, - output_attentions, - use_cache, - idx, - self.cross_layer_interval, - self.gated_cross_attn_layers, - cache_position, - ) - else: - layer_outputs = vblock( - decoder_layer, - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - image_hidden_states=image_hidden_states, image_attention_mask=image_attention_mask, cross_attention_gate=cross_attention_gate, output_attentions=output_attentions, use_cache=use_cache, - layer_idx=idx, - cross_layer_interval=self.cross_layer_interval, - gated_cross_attn_layers=self.gated_cross_attn_layers, - cache_position=cache_position, + past_key_value=None, # not implemented **kwargs, ) + hidden_states = outputs[0] + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) hidden_states = layer_outputs[0] if use_cache: