diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py
index 9af6eba11f66..552d89bac2f6 100644
--- a/src/transformers/integrations/flash_attention.py
+++ b/src/transformers/integrations/flash_attention.py
@@ -58,8 +58,10 @@ def flash_attention_forward(
     else:
         target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
 
-    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
-    kwargs.pop("is_causal", None)
+    # Instead of relying only on the value set in the module, use the is_causal passed in kwargs if it is present
+    is_causal = kwargs.pop("is_causal", None)
+    if is_causal is None:
+        is_causal = module.is_causal
 
     attn_output = _flash_attention_forward(
         query,
@@ -67,7 +69,7 @@ def flash_attention_forward(
         value,
         attention_mask,
         query_length=seq_len,
-        is_causal=module.is_causal,
+        is_causal=is_causal,
         dropout=dropout,
         softmax_scale=scaling,
         sliding_window=sliding_window,
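
The change swaps an unconditional `kwargs.pop("is_causal", None)` (which discarded any caller-provided value and always used `module.is_causal`) for a kwarg-first lookup with the module attribute as fallback. Below is a minimal, self-contained sketch of that fallback pattern; the `DummyAttention` class and its arguments are hypothetical and only illustrate the precedence, they are not part of the patched file.

```python
import torch


class DummyAttention(torch.nn.Module):
    """Hypothetical layer illustrating the kwarg-overrides-module-default pattern."""

    def __init__(self):
        super().__init__()
        self.is_causal = True  # module-level default, analogous to module.is_causal

    def forward(self, hidden_states, **kwargs):
        # Prefer an explicitly passed `is_causal`; fall back to the module attribute.
        is_causal = kwargs.pop("is_causal", None)
        if is_causal is None:
            is_causal = self.is_causal
        return is_causal


layer = DummyAttention()
print(layer(torch.zeros(1)))                   # True  -> module default is used
print(layer(torch.zeros(1), is_causal=False))  # False -> kwarg takes precedence
```

Using `None` as the sentinel (rather than a truthiness check) keeps an explicit `is_causal=False` from being silently overridden by the module default.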