From 5c4a210cd17d8922d4d54d7202f5832726e29b39 Mon Sep 17 00:00:00 2001
From: James Braza
Date: Wed, 22 Apr 2026 17:53:23 -0700
Subject: [PATCH] Guard `s_aux` cast in `flash_attention_forward` for sink-less models

`flash_attention_forward` unconditionally called `s_aux.to(query.dtype)`,
which crashed with `AttributeError: 'NoneType' object has no attribute 'to'`
for models that don't use attention sinks (e.g. Gemma).

Mirrors the parallel guard added in #40434 for `flash_paged.py`.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 src/transformers/integrations/flash_attention.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py
index 9f89bfe5778d..4e953aa016f4 100644
--- a/src/transformers/integrations/flash_attention.py
+++ b/src/transformers/integrations/flash_attention.py
@@ -81,7 +81,11 @@ def flash_attention_forward(
         target_dtype=target_dtype,
         attn_implementation=module.config._attn_implementation,
         layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
-        s_aux=s_aux.to(query.dtype),  # FA only accepts half precision
+        s_aux=(
+            s_aux.to(query.dtype)  # FA only accepts half precision
+            if s_aux is not None
+            else None
+        ),
         **kwargs,
     )
 
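
A minimal sketch of the failure mode and the guard, isolated from the actual
`transformers` call site. The helper names and tensor shapes below are
illustrative only and are not part of the library's API:

```python
import torch

# `s_aux` is the optional attention-sink logits tensor; models without
# attention sinks (e.g. Gemma) pass None for it.

def cast_s_aux_unguarded(s_aux, query):
    # Pre-patch behavior: raises AttributeError when s_aux is None.
    return s_aux.to(query.dtype)

def cast_s_aux_guarded(s_aux, query):
    # Post-patch behavior: only cast when a sink tensor is actually provided.
    return s_aux.to(query.dtype) if s_aux is not None else None

query = torch.randn(1, 8, 4, 16, dtype=torch.float16)

# Sink-less model: the guarded version simply forwards None.
assert cast_s_aux_guarded(None, query) is None

# Model with sinks: s_aux is still cast to the query dtype as before.
sinks = torch.randn(8, dtype=torch.float32)
assert cast_s_aux_guarded(sinks, query).dtype == torch.float16
```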