From e2feb59b62124ad52a51616968f1465613bf8644 Mon Sep 17 00:00:00 2001
From: vasqu
Date: Thu, 4 Sep 2025 20:33:08 +0200
Subject: [PATCH] alternative gemma fix

---
 src/transformers/integrations/sdpa_attention.py   | 10 +++++-----
 src/transformers/models/gemma3/modeling_gemma3.py |  3 +--
 src/transformers/models/gemma3/modular_gemma3.py  |  2 +-
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index 1f04806cad09..f6c6f2785c3f 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -69,11 +69,11 @@ def sdpa_attention_forward(

     # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
     # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-    # NOTE: It is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
-    # NOTE: We give priority to the passed kwarg. Otherwise, we check for the module's set flag. This is especially important for models with
-    # mixed attentions such as encoder-decoder models (encoder, decoder, and encoder-decoder/cross attention).
-    is_causal = getattr(module, "is_causal", True) if is_causal is None else is_causal
-    is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
+    if is_causal is None:
+        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
+        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
+        is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)

     # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
     # We convert it to a bool for the SDPA kernel that only accepts bools.
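
For reference, a minimal sketch of how the restored `is_causal` resolution behaves, using a stand-in attention module. The `resolve_is_causal` helper and the `SimpleNamespace` modules below are illustrative only, not part of the patch:

# Sketch of the restored resolution logic from sdpa_attention_forward, assuming
# `module`, `query`, `attention_mask`, and `is_causal` as in the real signature.
import torch
from types import SimpleNamespace


def resolve_is_causal(module, query, attention_mask, is_causal=None):
    # Only fall back to the shape/mask checks and the module's flag when the
    # caller did not pass an explicit `is_causal` kwarg.
    if is_causal is None:
        is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
    return bool(is_causal)


decoder_attn = SimpleNamespace(is_causal=True)
bidirectional_attn = SimpleNamespace(is_causal=False)  # e.g. Gemma3 with use_bidirectional_attention

query = torch.randn(1, 8, 4, 16)  # (batch, num_heads, seq_len, head_dim)
print(resolve_is_causal(decoder_attn, query, attention_mask=None))                   # True
print(resolve_is_causal(bidirectional_attn, query, attention_mask=None))             # False
print(resolve_is_causal(decoder_attn, query, attention_mask=None, is_causal=False))  # explicit kwarg wins

The behavioral point is that the module-level flag only matters when no explicit `is_causal` kwarg is passed, which is what lets the Gemma3 layers below drop the per-call kwarg and set the flag once at init time.
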
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 54df7642553e..0c080a355788 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -279,7 +279,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = config.query_pre_attn_scalar**-0.5
         self.attention_dropout = self.config.attention_dropout
-        self.is_causal = True
+        self.is_causal = not self.config.use_bidirectional_attention

         self.q_proj = nn.Linear(
             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -581,7 +581,6 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            is_causal=not self.config.use_bidirectional_attention,
             **kwargs,
         )

diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index c594e3471021..a1f85c8aade7 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -404,6 +404,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
         super().__init__(config, layer_idx)
         self.sliding_window = config.sliding_window if self.is_sliding else None
+        self.is_causal = not self.config.use_bidirectional_attention

         self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
         self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
@@ -665,7 +666,6 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            is_causal=not self.config.use_bidirectional_attention,
             **kwargs,
         )
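
On the Gemma3 side, the bidirectionality information now lives on the attention module itself instead of being threaded through every decoder-layer forward call. A rough stand-in of that init-time flag (the `DummyConfig`/`DummyAttention` names are illustrative only; the real classes are Gemma3TextConfig and Gemma3Attention):

# Illustrative sketch of deriving the causal flag once at init time, assuming
# a config carrying `use_bidirectional_attention` as Gemma3TextConfig does.
from dataclasses import dataclass


@dataclass
class DummyConfig:
    use_bidirectional_attention: bool = False


class DummyAttention:
    def __init__(self, config: DummyConfig):
        self.config = config
        # The generic sdpa_attention_forward picks this flag up via
        # getattr(module, "is_causal", True), so the decoder layer no longer
        # needs to pass `is_causal=...` on every call.
        self.is_causal = not config.use_bidirectional_attention


print(DummyAttention(DummyConfig()).is_causal)                                   # True (causal decoder)
print(DummyAttention(DummyConfig(use_bidirectional_attention=True)).is_causal)   # False (bidirectional)
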