I want to use bidirectional attention in the Qwen3 model to train an embedding model, so I passed `is_causal=False` to the model's `forward` (I manually added an `is_causal` argument to every relevant `forward` method, such as those of `Qwen3Model` and `Qwen3Attention` in `modeling_qwen3.py`):
```python
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    ...
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        is_causal: Optional[bool] = True,  # I add is_causal here
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        ...
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            # and is_causal from the argument is passed to the attention_interface
            # (e.g. `flash_attention_2`, `sdpa_attention_forward`)
            is_causal=is_causal,
            **kwargs,
        )
```
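For context, `is_causal` only decides whether a query position may attend to later positions. Below is a minimal pure-Python sketch of single-head scaled dot-product attention with such a flag; the `attention` helper is hypothetical and not part of transformers. With `is_causal=True`, position *i* sees only positions up to *i*; with `is_causal=False`, every position sees the whole sequence, which is what an embedding model needs.

```python
import math
from typing import List

Matrix = List[List[float]]


def attention(q: Matrix, k: Matrix, v: Matrix, is_causal: bool) -> Matrix:
    """Naive single-head scaled dot-product attention (hypothetical helper).

    is_causal=True: query i may only attend to keys j <= i (decoder LM style).
    is_causal=False: every query attends to every key (bidirectional).
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            if is_causal and j > i:
                scores.append(float("-inf"))  # mask out the future position
            else:
                scores.append(scale * sum(a * b for a, b in zip(qi, kj)))
        m = max(scores)  # finite, since key 0 is never masked
        weights = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
causal = attention(q, k, v, is_causal=True)
bidir = attention(q, k, v, is_causal=False)
```

In this toy example the first token's output is just `v[0]` under causal masking but mixes in later tokens once masking is removed, while the last token's output is identical in both modes because it already sees the full sequence.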
With this change I can successfully control the causality of the attention in `sdpa_attention_forward`. However, it does not change the causality of the attention in `flash_attention_forward`. After digging into the implementation of `flash_attention_forward` in `transformers/integrations/flash_attention.py`, I found the reason:
```python
def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    ...
    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)
    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=module.is_causal,  # here module is `Qwen3Attention`
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        **kwargs,
    )
```
As you can see, the `is_causal` argument is popped from `kwargs`, and the `is_causal` attribute of the module (here `Qwen3Attention`) is used instead. Since `Qwen3Attention.is_causal` is never changed and defaults to `True`, the `is_causal` passed to `_flash_attention_forward` is always `True`, regardless of what I pass to `forward`.
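The mechanism can be reproduced without any model. The minimal pure-Python sketch below mimics the quoted pattern; `ToyAttention`, `toy_flash_forward`, and `toy_sdpa_forward` are hypothetical names. `toy_sdpa_forward` is only an assumed contrast that honors the call-site argument, consistent with the SDPA behavior described above; it is not the real `sdpa_attention_forward`.

```python
from typing import Any, Dict, Optional


class ToyAttention:
    """Hypothetical stand-in for Qwen3Attention; only is_causal matters here."""

    def __init__(self) -> None:
        self.is_causal = True  # default that forward() never updates


def toy_flash_forward(module: ToyAttention, **kwargs: Any) -> Dict[str, Any]:
    """Mimics the quoted flash_attention_forward: the kwarg is dropped."""
    kwargs.pop("is_causal", None)  # the caller's value is discarded here
    return {"is_causal": module.is_causal, **kwargs}


def toy_sdpa_forward(
    module: ToyAttention, is_causal: Optional[bool] = None, **kwargs: Any
) -> Dict[str, Any]:
    """Hypothetical contrast: use the call-site value, fall back to the module."""
    if is_causal is None:
        is_causal = module.is_causal
    return {"is_causal": is_causal, **kwargs}


attn = ToyAttention()
print(toy_flash_forward(attn, is_causal=False)["is_causal"])  # True
print(toy_sdpa_forward(attn, is_causal=False)["is_causal"])   # False
```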
After I added one line to update the attribute, i.e. `self.is_causal = is_causal` before calling `attention_interface`, the causality of `flash_attention_forward` changed as expected. So is this a feature or a bug? Thank you!
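For reference, the workaround described above can be sketched with toy objects. `PatchedAttention` and `set_is_causal` are hypothetical; the second helper is an assumed alternative that flips the attribute once on unmodified modules (with a real model you could pass `model.modules()`, assuming each attention layer exposes `is_causal` as `Qwen3Attention` does). Either way, this only changes the flag that `flash_attention_forward` reads.

```python
from typing import Iterable


class PatchedAttention:
    """Hypothetical stand-in for Qwen3Attention with the one-line workaround."""

    def __init__(self) -> None:
        self.is_causal = True  # same default as the real module

    def forward(self, is_causal: bool = True) -> bool:
        # The workaround: sync the attribute with the argument *before* a
        # backend reads module.is_causal (as flash_attention_forward does).
        self.is_causal = is_causal
        return self.is_causal  # the value such a backend would now see


def set_is_causal(modules: Iterable[object], value: bool) -> int:
    """Assumed alternative: set is_causal once on every module that has it.

    Returns how many modules were changed; objects without the attribute are skipped.
    """
    changed = 0
    for m in modules:
        if hasattr(m, "is_causal"):
            m.is_causal = value
            changed += 1
    return changed
```

Because the patched `forward` writes the attribute on every call, a call that omits the argument (default `True`) switches the layer back to causal attention. Flipping the attribute once on stock modules avoids that per-call coupling, but it makes every layer bidirectional for all subsequent calls.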