Merged

Changes from all commits (148 commits)
a97ca9f
first batch (4)
qubvel Jun 19, 2025
2627646
align
qubvel Jun 19, 2025
c8926f7
altclip
qubvel Jun 19, 2025
ae1b29a
beit
qubvel Jun 19, 2025
4dff076
bert
qubvel Jun 19, 2025
0d0d8c7
yolos
qubvel Jun 19, 2025
8b66428
dino, pvt_v2
qubvel Jun 19, 2025
0d387eb
bark, bart, bert_generation
qubvel Jun 19, 2025
6faee3f
big_bird, biogpt
qubvel Jun 19, 2025
3f34606
blnderbot, bloom
qubvel Jun 19, 2025
3bb70d9
bridgetower
qubvel Jun 19, 2025
5757f3e
camambert, canine, chameleon
qubvel Jun 19, 2025
c59a7d5
chinese clip, clap, clip
qubvel Jun 19, 2025
d7cb795
codegen, conditional detr, convbert
qubvel Jun 19, 2025
39784f7
dab_detr, data2vec
qubvel Jun 19, 2025
203348d
dbrx, deberta
qubvel Jun 19, 2025
b2719f3
deberta, decicion_tranformer, deformable_detr
qubvel Jun 19, 2025
2ed2c5b
deit, deta, mctct
qubvel Jun 19, 2025
87704a7
detr, dinov2, distilbert
qubvel Jun 19, 2025
cd69033
donut, dpt, electra
qubvel Jun 19, 2025
9a54ad1
ernie, esm, falcon
qubvel Jun 19, 2025
6855515
flava, fnet, falcon_mamba
qubvel Jun 19, 2025
f4f8319
focalnet, git, gpt2
qubvel Jun 19, 2025
b8f4ecf
gpt - bigcode, neo, neox
qubvel Jun 19, 2025
d844b12
gptj, groupvit
qubvel Jun 19, 2025
700d20d
idefics2, idefics3
qubvel Jun 19, 2025
0b3ffba
ijepa, imagegpt, internvl
qubvel Jun 19, 2025
9ed27ef
jetmoe, kosmos2, layoutlm
qubvel Jun 19, 2025
6d3ecbc
layoutlm2-3, led
qubvel Jun 19, 2025
e398d8e
lilt, longformer, longt5, luke
qubvel Jun 19, 2025
4363156
m2m, mamba1-2
qubvel Jun 19, 2025
dde58de
marian, markuplm, mask2former
qubvel Jun 19, 2025
69b2cf8
maskformer
qubvel Jun 19, 2025
d4ccb79
mbart, megatron_bert, mimi
qubvel Jun 19, 2025
ab213da
mixtral, mlcd
qubvel Jun 19, 2025
cb90916
mobilevit1-2, modernbert
qubvel Jun 19, 2025
c2d3cbc
moshi, mpt, mra
qubvel Jun 19, 2025
80bcd7c
mt5, musicgen
qubvel Jun 19, 2025
825e2b1
mvp, nemotron
qubvel Jun 19, 2025
8f6a8fb
nllb_moe
qubvel Jun 19, 2025
6253d78
nystromformer, omdet_turbo
qubvel Jun 19, 2025
ab136ef
opt, owlvit, owlv2
qubvel Jun 19, 2025
3fb64a9
pegasus, pegasus_x, presimmon
qubvel Jun 19, 2025
32b2876
phimoe, pix2struct, pixtral
qubvel Jun 19, 2025
942f7a4
plbart, pop2piano, prophetnet
qubvel Jun 19, 2025
b083c86
qwen2*
qubvel Jun 19, 2025
429ba11
qwen2, qwen3 moe, rec gemma
qubvel Jun 19, 2025
cec0d32
rembert
qubvel Jun 19, 2025
bec1fcd
roberta
qubvel Jun 19, 2025
254882f
roberta prelayernorm
qubvel Jun 19, 2025
a1a7fda
roc_bert, roformer, rwkv
qubvel Jun 19, 2025
d497df9
sam, sam_hq
qubvel Jun 19, 2025
987a880
seggpt, smolvlm, speech_to_text
qubvel Jun 19, 2025
6ef90e1
splinter, stablelm, swin
qubvel Jun 19, 2025
1b7cc3f
swin2sr, switch_transformer, t5, table_transformer
qubvel Jun 19, 2025
5331bc2
tapas, time_series_tranformer, timesformer
qubvel Jun 19, 2025
dfe3d8d
trocr, tvp, umt5
qubvel Jun 19, 2025
c001253
videomae, vilt, visual_bert
qubvel Jun 19, 2025
76dd7a5
vit, vit_mae, vit_msn
qubvel Jun 19, 2025
0bc5335
vitpose_backbone, vits, vivit
qubvel Jun 19, 2025
43992d9
whisper. x_clip, xglm
qubvel Jun 19, 2025
461961b
xlm_roberta, xmod
qubvel Jun 19, 2025
cf470fd
yoso
qubvel Jun 19, 2025
626dde0
zamba
qubvel Jun 19, 2025
59f8879
vitdet, wav2vec2, wav2vec2_bert
qubvel Jun 19, 2025
b89a5db
unispeech, wav2vec_conformer
qubvel Jun 19, 2025
db524cb
wavlm
qubvel Jun 19, 2025
96db85e
speecht5
qubvel Jun 19, 2025
279041b
swinv2
qubvel Jun 19, 2025
5a3b571
sew / _d
qubvel Jun 19, 2025
b1d78cd
seamless_mt4 / _v2
qubvel Jun 19, 2025
9a6d135
deprecated models update
qubvel Jun 19, 2025
a18e257
bros
qubvel Jun 19, 2025
66d0a62
gemma2, gemma3
qubvel Jun 19, 2025
c0e5690
got, hiera, hubert, llama4, mllama, oneformer, phi, olmoe, informer
qubvel Jun 19, 2025
0942755
fixup
qubvel Jun 19, 2025
fe80395
Add use_cache=False and past_key_value=None to GradientCheckpointing…
qubvel Jun 19, 2025
d7963dc
fixup
qubvel Jun 19, 2025
73d5614
fix prophetnet
qubvel Jun 20, 2025
cd7a426
fix bigbird_pegasus
qubvel Jun 20, 2025
56cb34b
fix blenderbot
qubvel Jun 20, 2025
0347dde
fix mbart
qubvel Jun 20, 2025
e83086d
fix mvp
qubvel Jun 20, 2025
afbfd62
fix zamba2
qubvel Jun 20, 2025
68f317c
fix bart
qubvel Jun 20, 2025
98fb670
fix blenderbot_small
qubvel Jun 20, 2025
2fd38a3
fix codegen
qubvel Jun 20, 2025
5347767
Update gradient checkpointing layer to support more past_key_values a…
qubvel Jun 20, 2025
10f5fd1
fix data2vec vision
qubvel Jun 20, 2025
792cd7d
fix deformable_detr
qubvel Jun 20, 2025
36415c3
fix gptj
qubvel Jun 20, 2025
ff802fe
fix led
qubvel Jun 20, 2025
fc14014
fix m2m_100
qubvel Jun 20, 2025
f2cc865
add comment
qubvel Jun 20, 2025
eab402d
fix nnlb_moe
qubvel Jun 20, 2025
aa1f574
Fix pegasus_x
qubvel Jun 20, 2025
7c9d17d
fix plbart
qubvel Jun 20, 2025
5da2216
udop
qubvel Jun 20, 2025
999584c
fix-copies: beit, wav2vec2
qubvel Jun 20, 2025
ff33682
fix gpt_bigcode
qubvel Jun 20, 2025
c28e913
fixup
qubvel Jun 20, 2025
8104bfb
fix t5
qubvel Jun 20, 2025
f9a2db8
fix switch_transformers
qubvel Jun 20, 2025
fe1133e
fix longt5
qubvel Jun 20, 2025
e51772a
fix mt5
qubvel Jun 20, 2025
69a6a78
update tapas
qubvel Jun 20, 2025
eb20826
fix blip2
qubvel Jun 20, 2025
eba9a9a
update blip
qubvel Jun 20, 2025
aa71309
fix musicgen
qubvel Jun 20, 2025
c0b3084
fix gpt2, trocr
qubvel Jun 20, 2025
b6ac147
fix copies
qubvel Jun 20, 2025
481ae6f
!!! Revert zamba, mllama
qubvel Jun 20, 2025
07e0995
update autoformer
qubvel Jun 20, 2025
a2c8bd6
update bros
qubvel Jun 20, 2025
3160753
update args / kwargs for BERT and copies
qubvel Jun 20, 2025
1f3d7b0
2nd round of updates
qubvel Jun 20, 2025
32433aa
update conditional detr
qubvel Jun 20, 2025
95365b1
Pass encoder_hidden_states as positional arg
qubvel Jun 20, 2025
aad2b9e
Update to pass encoder_decoder_position_bias as positional arg
qubvel Jun 20, 2025
d34726d
fixup
qubvel Jun 20, 2025
55011be
biogpt modular
qubvel Jun 20, 2025
a71d201
modular gemma2
qubvel Jun 20, 2025
0d72857
modular gemma3
qubvel Jun 20, 2025
522df43
modular gpt_neox
qubvel Jun 20, 2025
30a9a90
modular informer
qubvel Jun 20, 2025
89a2c68
modular internvl
qubvel Jun 20, 2025
633378c
modular mixtral
qubvel Jun 20, 2025
8493bad
modular mlcd
qubvel Jun 20, 2025
adf5c60
modular modernbert
qubvel Jun 20, 2025
6270ff7
modular phi
qubvel Jun 20, 2025
3ad1fa9
modular qwen2_5_omni
qubvel Jun 20, 2025
7626b31
modular qwen2_5_vl
qubvel Jun 20, 2025
e3c61ce
modular sam_hq
qubvel Jun 20, 2025
01934cf
modular sew
qubvel Jun 20, 2025
62683dc
wav2vec2_bert
qubvel Jun 20, 2025
28bd09c
modular wav2vec2_conformer
qubvel Jun 20, 2025
5bc6525
modular wavlm
qubvel Jun 20, 2025
31dbec4
fixup
qubvel Jun 20, 2025
b989ba6
Update by modular instructblipvideo
qubvel Jun 20, 2025
cdb4c70
modular data2vec_audio
qubvel Jun 20, 2025
d50dd86
nit modular mistral
qubvel Jun 20, 2025
4c5aa0b
apply modular minimax
qubvel Jun 20, 2025
6585288
fix modular moonshine
qubvel Jun 20, 2025
4ac7c96
revert zamba2
qubvel Jun 20, 2025
58847e7
fix mask2former
qubvel Jun 20, 2025
2e4e2b1
Merge branch 'main' into gradient-checkpointing-layer-propagation
qubvel Jun 21, 2025
9b8e965
refactor idefics
qubvel Jun 23, 2025
8a8898c
Merge branch 'main' into gradient-checkpointing-layer-propagation
qubvel Jun 23, 2025
35 changes: 35 additions & 0 deletions src/transformers/modeling_layers.py
@@ -16,6 +16,11 @@
 
 import torch.nn as nn
 
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 class GradientCheckpointingLayer(nn.Module):
     """Base class for layers with gradient checkpointing.
@@ -44,5 +49,35 @@ class GradientCheckpointingLayer(nn.Module):
 
     def __call__(self, *args, **kwargs):
         if self.gradient_checkpointing and self.training:
+            do_warn = False
+            layer_name = self.__class__.__name__
+            message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"
+
+            if "use_cache" in kwargs and kwargs["use_cache"]:
+                kwargs["use_cache"] = False
+                message += " `use_cache=False`,"
+                do_warn = True
+
+            # different names for the same thing in different layers
+            if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
+                kwargs["past_key_value"] = None
+                message += " `past_key_value=None`,"
+                do_warn = True
+
+            if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
+                kwargs["past_key_values"] = None
+                message += " `past_key_values=None`,"
+                do_warn = True
+
+            if "layer_past" in kwargs and kwargs["layer_past"] is not None:
+                kwargs["layer_past"] = None
+                message += " `layer_past=None`,"
+                do_warn = True
+
+            # warn if anything was changed
+            if do_warn:
+                message = message.rstrip(",") + "."
+                logger.warning(message)
+
Comment on lines +52 to +81 (Contributor Author):

update for GradientCheckpointingLayer
             return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
         return super().__call__(*args, **kwargs)
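In short, a layer now only needs to inherit from `GradientCheckpointingLayer`: during checkpointed training the base class neutralizes cache-related kwargs (`use_cache`, `past_key_value(s)`, `layer_past`), binds the remaining kwargs into the callable with `functools.partial`, and routes only positional tensors through the checkpoint function. A minimal runnable sketch of that mechanism follows; the layer and tensors below are toy stand-ins, not the actual transformers classes, and only the wrapping logic mirrors the diff:

```python
# Toy re-implementation of the base class above (warning logic elided).
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class GradientCheckpointingLayer(nn.Module):
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            if kwargs.get("use_cache"):
                kwargs["use_cache"] = False  # caching is incompatible with checkpointing
            # kwargs are bound into the callable; positional tensors go through checkpoint
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)


class ToyLayer(GradientCheckpointingLayer):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states, use_cache=False):
        # a real layer would also return a cache when use_cache=True
        return self.proj(hidden_states)


layer = ToyLayer()
layer.gradient_checkpointing = True
layer.train()
x = torch.randn(2, 8, requires_grad=True)
layer(x, use_cache=True).sum().backward()  # use_cache is silently dropped, checkpointing runs
print(x.grad.shape)  # torch.Size([2, 8])
```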
33 changes: 11 additions & 22 deletions src/transformers/models/align/modeling_align.py
@@ -23,6 +23,7 @@
 from torch import nn
 
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithNoAttention,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -827,7 +828,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
 
 
 # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
-class AlignTextLayer(nn.Module):
+class AlignTextLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -953,27 +954,15 @@ def forward(
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
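The inline comment above is the key to the new calling convention: the reentrant `torch.utils.checkpoint` API rejects tensor keyword arguments outright, and `GradientCheckpointingLayer` forwards kwargs by binding them into the callable with `functools.partial`, so a tensor that must carry gradients through the checkpoint (here `encoder_hidden_states`) has to travel positionally through `*args`. A small self-contained demonstration, with illustrative names rather than transformers code:

```python
# Why gradient-carrying tensors are passed positionally to checkpointed layers.
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def layer(hidden_states, encoder_hidden_states=None, output_attentions=False):
    out = torch.tanh(hidden_states)
    if encoder_hidden_states is not None:
        out = out + encoder_hidden_states
    return out


x = torch.randn(2, 4, requires_grad=True)
enc = torch.randn(2, 4, requires_grad=True)

# Reentrant checkpointing rejects tensor kwargs outright:
try:
    checkpoint(layer, x, encoder_hidden_states=enc, use_reentrant=True)
except ValueError as err:
    print(err)  # Unexpected keyword arguments: encoder_hidden_states

# The pattern from the diff: non-tensor kwargs are bound with partial,
# while tensors that need gradients go through *args.
fn = partial(layer, output_attentions=False)
out = checkpoint(fn, x, enc, use_reentrant=True)
out.sum().backward()
print(x.grad is not None, enc.grad is not None)  # True True
```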
56 changes: 18 additions & 38 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -23,6 +23,7 @@
 import torch.utils.checkpoint
 
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -418,7 +419,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
 
 
 # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta
-class AltRobertaLayer(nn.Module):
+class AltRobertaLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -544,27 +545,15 @@ def forward(
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -732,7 +721,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-class AltCLIPEncoderLayer(nn.Module):
+class AltCLIPEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: AltCLIPConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -848,21 +837,12 @@ def forward(
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
 
src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
@@ -22,6 +22,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -282,7 +283,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
 
 
 # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
-class ASTLayer(nn.Module):
+class ASTLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""
 
     def __init__(self, config: ASTConfig) -> None:
@@ -349,16 +350,7 @@ def forward(
 
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
-
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
             hidden_states = layer_outputs[0]
 
             if output_attentions:
75 changes: 26 additions & 49 deletions src/transformers/models/autoformer/modeling_autoformer.py
@@ -30,6 +30,7 @@
     _prepare_4d_attention_mask,
     _prepare_4d_attention_mask_for_sdpa,
 )
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput
 from ...modeling_utils import PreTrainedModel
 from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
@@ -670,7 +671,7 @@ def forward(
         return attn_output, attn_weights_reshaped, past_key_value
 
 
-class AutoformerEncoderLayer(nn.Module):
+class AutoformerEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: AutoformerConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -744,7 +745,7 @@ def forward(
         return outputs
 
 
-class AutoformerDecoderLayer(nn.Module):
+class AutoformerDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: AutoformerConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -1042,21 +1043,12 @@ def forward(
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
+                )
 
             hidden_states = layer_outputs[0]
 
@@ -1186,6 +1178,12 @@ def forward(
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
+
         input_shape = inputs_embeds.size()[:-1]
 
         # expand encoder attention mask
@@ -1228,38 +1226,17 @@ def forward(
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
             (hidden_states, residual_trend) = layer_outputs[0]
             trend = trend + residual_trend
 
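Autoformer's decoder also shows the companion change on the caller side: the `use_cache` guard is hoisted out of the per-layer loop (see the hunk at line 1178 above), so the loop body no longer branches on `self.gradient_checkpointing` at all. A runnable toy of the hoisted-guard shape, with hypothetical names rather than the Autoformer code:

```python
# Hypothetical minimal decoder: one up-front check normalizes use_cache,
# and the layer loop stays uniform with no checkpointing branch.
import logging

logging.basicConfig()
logger = logging.getLogger("toy_decoder")


class ToyDecoder:
    def __init__(self, num_layers=3):
        self.layers = [lambda h: h + 1 for _ in range(num_layers)]
        self.gradient_checkpointing = True
        self.training = True

    def forward(self, hidden_states, use_cache=True):
        # hoisted guard: checked once, before the layer loop
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False
        for layer in self.layers:  # uniform loop body
            hidden_states = layer(hidden_states)
        return hidden_states


print(ToyDecoder().forward(0))  # 3, with a single warning emitted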
30 changes: 10 additions & 20 deletions src/transformers/models/bark/modeling_bark.py
@@ -31,6 +31,7 @@
 )
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel, get_parameter_device
 from ...utils import (
@@ -309,7 +310,7 @@ def forward(self, hidden_states):
         return hidden_states
 
 
-class BarkBlock(nn.Module):
+class BarkBlock(GradientCheckpointingLayer):
     def __init__(self, config, is_causal=False):
         super().__init__()
 
@@ -606,25 +607,14 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    head_mask[i],
-                    use_cache,
-                    output_attentions,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    past_key_values=past_layer_key_values,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
+            outputs = block(
+                hidden_states,
+                past_key_values=past_layer_key_values,
+                attention_mask=attention_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = outputs[0]
 
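Taken together, the user-facing behavior should be unchanged apart from where the warning is emitted. A hedged usage sketch; the gpt2 checkpoint is only a convenient example of a model whose layers were migrated in this PR:

```python
# With layers derived from GradientCheckpointingLayer, a training forward
# that still passes use_cache=True no longer needs per-model special-casing:
# the cache kwargs are dropped with a logged warning.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

model.gradient_checkpointing_enable()  # existing transformers API
model.train()

inputs = tokenizer("hello world", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"], use_cache=True)
outputs.loss.backward()  # runs; caching is disabled with a logged warning
```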