CUDA: fix FTZ in FA for Gemma 3 #13991
Conversation
ggerganov left a comment
This seems like a good solution, though I have some small remaining concerns that there might be something else going on. I tried the same approach with the Metal implementation (i.e. keep accumulating the output in F16 and FTZ the scores like in the CUDA code), and Gemma 3 27B keeps outputting garbage for large prompts. It's hard to say what the root cause is, as the Metal implementation does not provide many tools for debugging.
Anyway, this should be OK to merge since @mostlygeek confirmed that it is working, but we should keep an eye out for any remaining issues.
> I don't have multimodal Gemma 3 set up
Btw, you don't need multimodal Gemma to reproduce the issue. Just load the text-only model and ask it to summarize something around 100k tokens long (for example, server.cpp + llama-context.cpp).
Well, I hope not. If the CUDA code had to use FP32 for the accumulation of VKQ, that would be a pretty big headache for me due to register pressure. BF16 could partially solve the issue, but then the new issue would be that not all instructions are available on all GPUs.
Fixes #12433 (comment).
What I think is happening is that there is an underflow in the FlashAttention code when rescaling the FP16 VKQ accumulators. This PR flushes the scale to 0 if it is below 2.06e-9. I don't have multimodal Gemma 3 set up, so I did not reproduce the issue on my machine.
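For context, here is a minimal sketch of what such a flush-to-zero step can look like; it is not the actual kernel code. The function name, the array layout, and the SOFTMAX_FTZ_THRESHOLD constant are assumptions for illustration; the -20.0f cutoff is chosen because expf(-20) ≈ 2.06e-9, matching the threshold mentioned above.

```cpp
#include <cuda_fp16.h>

// Hypothetical cutoff: rescale factors smaller than expf(-20) ~= 2.06e-9
// are treated as zero instead of being applied to the FP16 accumulators.
#define SOFTMAX_FTZ_THRESHOLD -20.0f

// Rescale n half2 VKQ accumulators after the running softmax maximum changed.
// kq_max_diff = old_running_max - new_running_max (always <= 0).
__device__ void rescale_vkq_accumulators(half2 * VKQ, int n, float kq_max_diff) {
    float scale = expf(kq_max_diff);

    // For very negative diffs the scale becomes subnormal (or rounds
    // unpredictably) once converted to FP16; multiplying the accumulators by
    // it can underflow in inconsistent ways. Flush it to a clean zero so the
    // old contribution is discarded outright.
    if (kq_max_diff < SOFTMAX_FTZ_THRESHOLD) {
        scale = 0.0f;
    }

    const half2 scale_h2 = __float2half2_rn(scale);
    for (int i = 0; i < n; ++i) {
        VKQ[i] *= scale_h2; // accumulation stays in FP16
    }
}
```

The point of the check is only to replace an ill-behaved FP16 underflow with an explicit zero; the accumulation itself can stay in half precision.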