ggml-cuda: native bf16 flash attention for vec kernel #20525
JohannesGaessler merged 5 commits into ggml-org:master
Conversation
The mma kernel still converts BF16 to FP16 before launch; native mma BF16 support is still a TODO.
Pull request overview
Adds BF16 support for CUDA flash-attention (vec + tile paths) by extending template instantiations and adding BF16 load/convert + dispatch plumbing.
Changes:
- Extend template-instance generation to include GGML_TYPE_BF16 KV types and add new vec instance .cu files.
- Add BF16 cases to vec flash-attention dispatch and selection logic.
- Teach tile/vec kernels + launch_fattn to handle BF16 (including conversion buffers and BF16-aware loads).
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| ggml/src/ggml-cuda/template-instances/generate_cu_files.py | Adds BF16 to generated KV type matrix for template instances. |
| ggml/src/ggml-cuda/template-instances/fattn-vec-instance-*.cu | New autogenerated vec instantiations for BF16 combinations. |
| ggml/src/ggml-cuda/fattn.cu | Enables BF16 vec dispatch cases + kernel eligibility. |
| ggml/src/ggml-cuda/fattn-vec.cuh | Treats BF16 similarly to F16 for thread config; plumbs BF16 flags into launcher. |
| ggml/src/ggml-cuda/fattn-tile.cuh | Generalizes KV loads to support BF16 and adds BF16-aware tile kernel launcher wrapper. |
| ggml/src/ggml-cuda/fattn-common.cuh | Adds BF16 vec-dot/dequant, BF16 conversion buffers, and BF16 flags in launch_fattn. |
| ggml/src/ggml-cuda/CMakeLists.txt | Ensures BF16 vec template instance is compiled in the non-FA_ALL_QUANTS path. |
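
A rough, self-contained sketch of the dispatch pattern described above: the KV cache type selects which template instance of the vec launcher runs. Everything here (kv_type, launch_fattn_vec_sketch, dispatch_fattn_vec_sketch) is an invented placeholder rather than the actual fattn.cu code; only the idea of keying the template on half2 vs. nv_bfloat162 reflects this PR.

```cpp
// Hypothetical stand-in for the real dispatch in fattn.cu; all names are invented.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

enum class kv_type { F16, BF16 };

template <typename T_KV>
static void launch_fattn_vec_sketch() {
    // In the real code this would select one of the generated template-instance .cu files.
    std::printf("vec kernel instance for %zu-byte KV element pairs\n", sizeof(T_KV));
}

static void dispatch_fattn_vec_sketch(kv_type t) {
    switch (t) {
        case kv_type::F16:  launch_fattn_vec_sketch<half2>();        break; // existing FP16 path
        case kv_type::BF16: launch_fattn_vec_sketch<nv_bfloat162>(); break; // path added by this PR
    }
}

int main() {
    dispatch_fattn_vec_sketch(kv_type::BF16);
    return 0;
}
```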
    if constexpr (std::is_same_v<T_KV, half2>) {
        const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
        ggml_cuda_memcpy_1<cpy_nb>(
            tile_KV + i*(J/2 + J_padding) + j,
            !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
    } else {
        const __align__(16) T_KV zero[cpy_ne] = {};
        __align__(16) T_KV tmp[cpy_ne];
        ggml_cuda_memcpy_1<cpy_nb>(
            tmp, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
        __align__(16) half2 converted[cpy_ne];
    #pragma unroll
        for (int l = 0; l < cpy_ne; ++l) {
            const float2 f = __bfloat1622float2(tmp[l]);
            converted[l] = make_half2(f.x, f.y);
        }
        ggml_cuda_memcpy_1<sizeof(converted)>(tile_KV + i*(J/2 + J_padding) + j, converted);
    }
This one might be worth doing?
        __align__(16) T_KV tmp[cpy_ne];
        ggml_cuda_memcpy_1<cpy_nb>(
            tmp, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
        __align__(16) half2 converted[cpy_ne];
    #pragma unroll
        for (int l = 0; l < cpy_ne; ++l) {
            const float2 f = __bfloat1622float2(tmp[l]);
            converted[l] = make_half2(f.x, f.y);
        }
        ggml_cuda_memcpy_1<sizeof(converted)>(tile_KV + i*(J/2 + J_padding) + j, converted);
Pretty sure the compiler is just doing the right thing here, but I guess maybe?
I tested this PR on my RTX 6000 PRO Blackwell. It works great on Qwen 3.5 122B, which uses delta attention, and I have flash attention turned on. HUGE improvement vs. the CPU fallback that happened before. I hope this will go in soon...
    #ifdef V_DOT2_F32_F16_AVAILABLE
        const float2 bf16_f2 = __bfloat1622float2(tmp[k_KQ_1]);
        ggml_cuda_mad(sum, make_half2(bf16_f2.x, bf16_f2.y), ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
    #else
        ggml_cuda_mad(sum, __bfloat1622float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
    #endif // V_DOT2_F32_F16_AVAILABLE
Remove the V_DOT2_F32_F16_AVAILABLE path; that instruction does not have a BF16 equivalent. Also use ggml_cuda_cast defined in convert.cuh for the conversion to float2 (this is also what the half2 path on master should have used).
So I know you said to remove the V_DOT2 path from vec_dot_bf16 since there is no BF16 equivalent of the v_dot2 instruction. Agreed, but there is a downstream issue that is kind of forcing us to add it back.

On V_DOT2 platforms Q_reg is declared as half2 (fattn-vec.cuh line 133, selected by ifdef V_DOT2_F32_F16_AVAILABLE). The BF16 vec_dot casts Q_v to float2, which would reinterpret those half2 values as float2, producing garbage. The V_DOT2 path in vec_dot_bf16 converts the half2 Q values to float2 before the dot product so both operands are correctly typed.

Removing the V_DOT2 path entirely would also require changing how Q_reg is stored for BF16 K on V_DOT2 platforms, which means adding if constexpr guards around the Q_reg declaration and Q loading loops. Happy to do that if preferred, but it seemed like a bigger change than warranted.
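
For readers following along, a minimal sketch of the alternative being described, assuming the Q register type is derived from the K type at compile time so a BF16 K path always reads Q as float2 even on V_DOT2 platforms. q_reg_sketch is a hypothetical name; the real Q_reg declaration and loading loops in fattn-vec.cuh are laid out differently.

```cpp
// Sketch only: compile-time selection of the Q register type from the K type.
#include <type_traits>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename T_K>
struct q_reg_sketch {
#ifdef V_DOT2_F32_F16_AVAILABLE
    // With v_dot2, FP16 K wants half2 Q operands; BF16 K still needs float2.
    using type = std::conditional_t<std::is_same_v<T_K, half2>, half2, float2>;
#else
    // Without v_dot2 the dot product is accumulated in FP32 for every K type.
    using type = float2;
#endif // V_DOT2_F32_F16_AVAILABLE
};

// BF16 K must never see reinterpreted half2 Q values.
static_assert(std::is_same_v<q_reg_sketch<nv_bfloat162>::type, float2>, "BF16 Q regs must be float2");
```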
    if constexpr (std::is_same_v<T, half>) {
        static_assert(ne % 2 == 0, "bad ne");
        __align__(16) nv_bfloat162 tmp[ne/2];
        ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
        half2 * dst_h2 = (half2 *) dst;
    #pragma unroll
        for (int l = 0; l < ne/2; ++l) {
            const float2 f2 = __bfloat1622float2(tmp[l]);
            dst_h2[l] = make_half2(f2.x, f2.y);
        }
Remove this code path; it doesn't make sense to use.
        float2 * dst_f2 = (float2 *) dst;
    #pragma unroll
        for (int l = 0; l < ne/2; ++l) {
            dst_f2[l] = __bfloat1622float2(tmp[l]);
Suggested change:

    -            dst_f2[l] = __bfloat1622float2(tmp[l]);
    +            dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]);
        <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

    #define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \

Suggested change:

    -#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
    +#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
    -template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
    +template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check, typename T_KV>
     static __device__ __forceinline__ void flash_attn_tile_load_tile(
    -        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
    +        const T_KV * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
Revert the changes to this function. It doesn't make sense to convert BF16 to FP16, so just always use the FP32 path for BF16.
    if constexpr (std::is_same_v<T_KV, half2>) {
        tmp_f2[l] = __half22float2(tmp_kv[l]);
    } else {
        tmp_f2[l] = __bfloat1622float2(tmp_kv[l]);
    }
Suggested change:

    -if constexpr (std::is_same_v<T_KV, half2>) {
    -    tmp_f2[l] = __half22float2(tmp_kv[l]);
    -} else {
    -    tmp_f2[l] = __bfloat1622float2(tmp_kv[l]);
    -}
    +tmp_f2[l] = ggml_cuda_cast<float2>(tmp_kv[l]);

Add a new case to ggml_cuda_cast if necessary.
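
To make the suggestion concrete, here is a hedged, standalone sketch of what such a case could look like; cuda_cast_sketch and demo_cast_kernel are invented names, and the real ggml_cuda_cast in convert.cuh may be structured differently.

```cpp
// Sketch only: a cast helper with an explicit nv_bfloat162 -> float2 case,
// mirroring the existing half2 -> float2 one. Not the actual convert.cuh code.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename dst_t, typename src_t>
__device__ __forceinline__ dst_t cuda_cast_sketch(src_t x);

template <>
__device__ __forceinline__ float2 cuda_cast_sketch<float2, half2>(half2 x) {
    return __half22float2(x);     // FP16x2 -> FP32x2 (existing case)
}

template <>
__device__ __forceinline__ float2 cuda_cast_sketch<float2, nv_bfloat162>(nv_bfloat162 x) {
    return __bfloat1622float2(x); // BF16x2 -> FP32x2 (the new case discussed here)
}

__global__ void demo_cast_kernel() {
    const nv_bfloat162 kv = __floats2bfloat162_rn(1.5f, -2.0f);
    const float2 f = cuda_cast_sketch<float2>(kv);
    printf("%f %f\n", f.x, f.y);
}

int main() {
    demo_cast_kernel<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}
```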
     }

    -template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size
    +template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, typename T_KV = half2> // D == head size
Suggested change:

    -template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, typename T_KV = half2> // D == head size
    +template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, typename T_KV> // D == head size

I think the type should be explicitly required to avoid accidental misuse.
     static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
             T_vec_dot * const Q_tmp,
    -        const half2 * const __restrict__ K_h2,
    +        const T_KV * const __restrict__ K_kv,
Suggested change:

    -        const T_KV * const __restrict__ K_kv,
    +        const T_KV * const __restrict__ K_t2,

I think _kv is a poor choice for the suffix name. I don't particularly like _t2 either but I think it's less bad.
    const ggml_tensor * K = dst->src[1];
    const bool bf16_kv = K->type == GGML_TYPE_BF16;
Suggested change:

     const ggml_tensor * K = dst->src[1];
    +const ggml_tensor * V = dst->src[2];
    +GGML_ASSERT(K->type == V->type);
     const bool bf16_kv = K->type == GGML_TYPE_BF16;
Reverted tile kernel changes to avoid a larger refactor.

Addressed all review feedback.

Benchmarked on RTX PRO 6000 Blackwell with Qwen 3.5 27B BF16 up to 65k context, no regressions.
Working through the V_DOT2 path, I am starting to appreciate the layout of the CUDA kernels and their cross-arch compatibility. This CI setup is great.
Is there anything specific preventing this from going in before it gets stale?
@JohannesGaessler merge?
JohannesGaessler left a comment
There were AMD-specific changes after I had approved, so I wanted to re-review. The reason you would want to use BF16 vs. FP16 is the increased numerical range. So the AMD workaround where BF16 is cast to FP16 makes the data type essentially universally worse than FP16. I think the vector kernel needs to be refactored, though I cannot expect someone without access to AMD hardware to do this; I added a comment to the code indicating that this needs to be fixed (by me).
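
A tiny, hedged illustration of the range argument above (host-side, compiled with nvcc; exact rounded values may vary): a magnitude that overflows FP16 survives a BF16 round-trip, which is exactly what is lost if BF16 data is cast to FP16 on the way in.

```cpp
// Demo of FP16 vs. BF16 dynamic range; not part of the PR itself.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

int main() {
    const float big = 1.0e6f;  // far beyond the FP16 maximum of ~65504
    const float via_fp16 = __half2float(__float2half(big));          // overflows to inf
    const float via_bf16 = __bfloat162float(__float2bfloat16(big));  // stays finite, coarsely rounded
    std::printf("fp16 round-trip: %g\nbf16 round-trip: %g\n", via_fp16, via_bf16);
    return 0;
}
```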
@am17an my approval has become stale since, after pushing the comments, I am now the last pusher. Can you approve instead?
Looks like this warrants an extra guard:
I'm very interested in the optimizations for AMD which would let HIP use the new bfloat16 path without any casts to float16. I'd be willing to test on my 7900 XTX if you still plan to implement those paths and don't have access to AMD hardware. Thanks! Otherwise, looking forward to that.
* ggml-cuda: native bf16 flash attention for vec and tile kernels (mma kernel still converts bf16 to fp16 before launch, native mma bf16 todo)
* ggml-cuda: address code owner review feedback (reverted tile kernel changes to avoid larger refactor)
* fix ci failures on turing and hip
* fix bf16 vec kernel compile on hip v_dot2 platforms
* add comments

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
The mma kernel still converts BF16 to FP16 before launch; native mma BF16 is still a TODO. Only about a 1.2% improvement on sm120.