diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py
index 9f89bfe5778d..4e953aa016f4 100644
--- a/src/transformers/integrations/flash_attention.py
+++ b/src/transformers/integrations/flash_attention.py
@@ -81,7 +81,11 @@ def flash_attention_forward(
         target_dtype=target_dtype,
         attn_implementation=module.config._attn_implementation,
         layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
-        s_aux=s_aux.to(query.dtype),  # FA only accepts half precision
+        s_aux=(
+            s_aux.to(query.dtype)  # FA only accepts half precision
+            if s_aux is not None
+            else None
+        ),
         **kwargs,
     )
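
For context, a minimal standalone sketch of the guard this hunk introduces (a hypothetical helper, not the transformers API): the optional s_aux tensor is cast to the query dtype only when it is actually provided, so callers that do not pass it no longer hit an AttributeError on None.

import torch

def cast_s_aux(s_aux, query):
    # Hypothetical helper mirroring the patched expression: flash attention
    # only accepts half precision, so cast s_aux to the query dtype when it
    # is present and pass None through untouched.
    return s_aux.to(query.dtype) if s_aux is not None else None

query = torch.randn(2, 8, 4, 64, dtype=torch.float16)
print(cast_s_aux(torch.randn(8), query).dtype)  # torch.float16
print(cast_s_aux(None, query))                  # None (the pre-patch code would raise AttributeError here)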