
ggml-cuda: native bf16 flash attention for vec kernel #20525

Merged
JohannesGaessler merged 5 commits into ggml-org:master from eous:bf16-flash-attn on Mar 22, 2026

Conversation

@eous (Contributor) commented Mar 13, 2026

The mma kernel still converts bf16 to fp16 before launch; native mma bf16 is a TODO. Only about a 1.2% improvement on sm120.

@eous requested a review from JohannesGaessler as a code owner on March 13, 2026 20:27
Copilot AI review requested due to automatic review settings on March 13, 2026 20:27
github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes), and ggml (changes relating to the ggml tensor library for machine learning) labels on Mar 13, 2026

Copilot AI left a comment


Pull request overview

Adds BF16 support for CUDA flash-attention (vec + tile paths) by extending template instantiations and adding BF16 load/convert + dispatch plumbing.

Changes:

  • Extend template-instance generation to include GGML_TYPE_BF16 KV types and add new vec instance .cu files.
  • Add BF16 cases to vec flash-attention dispatch and selection logic.
  • Teach tile/vec kernels + launch_fattn to handle BF16 (including conversion buffers and BF16-aware loads).

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 6 comments.

Summary per file:

  • ggml/src/ggml-cuda/template-instances/generate_cu_files.py: Adds BF16 to the generated KV-type matrix for template instances.
  • ggml/src/ggml-cuda/template-instances/fattn-vec-instance-*.cu: New autogenerated vec instantiations for BF16 combinations.
  • ggml/src/ggml-cuda/fattn.cu: Enables BF16 vec dispatch cases and kernel eligibility (sketched below).
  • ggml/src/ggml-cuda/fattn-vec.cuh: Treats BF16 similarly to F16 for thread config; plumbs BF16 flags into the launcher.
  • ggml/src/ggml-cuda/fattn-tile.cuh: Generalizes KV loads to support BF16 and adds a BF16-aware tile kernel launcher wrapper.
  • ggml/src/ggml-cuda/fattn-common.cuh: Adds BF16 vec-dot/dequant, BF16 conversion buffers, and BF16 flags in launch_fattn.
  • ggml/src/ggml-cuda/CMakeLists.txt: Ensures the BF16 vec template instance is compiled in the non-FA_ALL_QUANTS path.
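
As a rough illustration of the fattn.cu dispatch item above: the change routes BF16 KV tensors to the new BF16 vec instances. This is a hedged sketch only; fattn_vec_case is a hypothetical stand-in for the real dispatch helper, whose name and surrounding logic differ.

    // Hypothetical sketch of the BF16 dispatch addition (names illustrative).
    switch (K->type) {
        case GGML_TYPE_F16:
            fattn_vec_case<D, GGML_TYPE_F16,  GGML_TYPE_F16 >(ctx, dst);
            break;
        case GGML_TYPE_BF16: // new: BF16 KV goes to the BF16 vec instances
            fattn_vec_case<D, GGML_TYPE_BF16, GGML_TYPE_BF16>(ctx, dst);
            break;
        default:
            GGML_ABORT("unsupported KV cache type");
    }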


Comment thread ggml/src/ggml-cuda/fattn-tile.cuh Outdated
Comment on lines +354 to +371
    if constexpr (std::is_same_v<T_KV, half2>) {
        const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
        ggml_cuda_memcpy_1<cpy_nb>(
            tile_KV + i*(J/2 + J_padding) + j,
            !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
    } else {
        const __align__(16) T_KV zero[cpy_ne] = {};
        __align__(16) T_KV tmp[cpy_ne];
        ggml_cuda_memcpy_1<cpy_nb>(
            tmp, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
        __align__(16) half2 converted[cpy_ne];
#pragma unroll
        for (int l = 0; l < cpy_ne; ++l) {
            const float2 f = __bfloat1622float2(tmp[l]);
            converted[l] = make_half2(f.x, f.y);
        }
        ggml_cuda_memcpy_1<sizeof(converted)>(tile_KV + i*(J/2 + J_padding) + j, converted);
    }
eous (Contributor, Author):
This one might be worth doing?

Comment thread ggml/src/ggml-cuda/fattn-tile.cuh Outdated
Comment on lines +361 to +370
    __align__(16) T_KV tmp[cpy_ne];
    ggml_cuda_memcpy_1<cpy_nb>(
        tmp, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
    __align__(16) half2 converted[cpy_ne];
#pragma unroll
    for (int l = 0; l < cpy_ne; ++l) {
        const float2 f = __bfloat1622float2(tmp[l]);
        converted[l] = make_half2(f.x, f.y);
    }
    ggml_cuda_memcpy_1<sizeof(converted)>(tile_KV + i*(J/2 + J_padding) + j, converted);
eous (Contributor, Author):
Pretty sure the compiler is just doing the right thing here, but I guess maybe?

@xkmire commented Mar 14, 2026

I tested this PR on my RTX PRO 6000 Blackwell. It works great on Qwen 3.5 122B, which uses delta attention, with flash attention turned on.
Even though some of the calculations look to be done in 32-bit precision in
    ggml_cuda_mad(sum, __bfloat1622float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
the performance in tokens/second is the same as with f16 on the KV cache (even if there are probably some small differences if measured in detail).
I confirmed that for the 120a-real architecture the compiler uses the 32-bit path.

HUGE improvement vs. the CPU fallback that happened before.

I hope this will go in soon...
Thanks! Made my day.
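
For context on the 32-bit path mentioned above: the float2 overload of ggml_cuda_mad boils down to two fp32 multiply-adds per packed pair. A hedged sketch of what such an overload typically looks like (the actual definition lives in the CUDA backend's common header and may differ):

    // Sketch: accumulate a two-element dot product into an fp32 accumulator.
    static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
        acc += v.x * u.x; // each lane is a separate fp32 FMA
        acc += v.y * u.y;
    }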

Comment thread ggml/src/ggml-cuda/fattn-common.cuh Outdated
Comment on lines +96 to +101
#ifdef V_DOT2_F32_F16_AVAILABLE
    const float2 bf16_f2 = __bfloat1622float2(tmp[k_KQ_1]);
    ggml_cuda_mad(sum, make_half2(bf16_f2.x, bf16_f2.y), ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#else
    ggml_cuda_mad(sum, __bfloat1622float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#endif // V_DOT2_F32_F16_AVAILABLE
Contributor:
Remove the V_DOT2_F32_F16_AVAILABLE path, that instruction does not have a BF16 equivalent. Also use ggml_cuda_cast defined in convert.cuh for the conversion to float2 (this is also what the half2 path on master should have used).
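
For reference, eous later added a ggml_cuda_cast overload for nv_bfloat162 to convert.cuh (see the follow-up summary below). A hedged sketch of what such an overload might look like; the actual definition may be structured differently:

    #include <cuda_bf16.h>

    // Possible ggml_cuda_cast specialization: widen both bf16 lanes to fp32.
    template <typename dst_t>
    static __device__ __forceinline__ dst_t ggml_cuda_cast(const nv_bfloat162 x);

    template <>
    __device__ __forceinline__ float2 ggml_cuda_cast<float2>(const nv_bfloat162 x) {
        return __bfloat1622float2(x);
    }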

eous (Contributor, Author):

I know you said to remove the V_DOT2 path from vec_dot_bf16 since there is no bf16 equivalent of the v_dot2 instruction. Agreed, but there is a downstream issue that is kind of forcing us to add it back.

On V_DOT2 platforms Q_reg is declared as half2 (fattn-vec.cuh line 133, selected by ifdef V_DOT2_F32_F16_AVAILABLE). The bf16 vec_dot casts Q_v as float2, which would reinterpret those half2 values as float2, producing garbage. The V_DOT2 path in vec_dot_bf16 converts the half2 Q values to float2 before the dot product so both operands are correctly typed.

Removing the V_DOT2 path entirely would also require changing how Q_reg is stored for bf16 K on V_DOT2 platforms, which means adding if constexpr guards around the Q_reg declaration and Q-loading loops. Happy to do that if preferred, but it seemed like a bigger change than warranted.
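
A simplified illustration of the hazard described above (names abbreviated; not the actual fattn-vec.cuh declarations):

    // On V_DOT2 platforms Q is stored packed as fp16, elsewhere as fp32.
    #ifdef V_DOT2_F32_F16_AVAILABLE
    half2  Q_reg; // 2x fp16 in 32 bits
    #else
    float2 Q_reg; // 2x fp32 in 64 bits
    #endif

    // Wrong on V_DOT2 platforms: reinterprets fp16 bit patterns as fp32.
    //     const float2 q = *(const float2 *) &Q_reg; // garbage
    // Right: convert the packed fp16 lanes to fp32 first.
    //     const float2 q = __half22float2(Q_reg);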

Comment thread ggml/src/ggml-cuda/fattn-common.cuh Outdated
Comment on lines +357 to +366
    if constexpr (std::is_same_v<T, half>) {
        static_assert(ne % 2 == 0, "bad ne");
        __align__(16) nv_bfloat162 tmp[ne/2];
        ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
        half2 * dst_h2 = (half2 *) dst;
#pragma unroll
        for (int l = 0; l < ne/2; ++l) {
            const float2 f2 = __bfloat1622float2(tmp[l]);
            dst_h2[l] = make_half2(f2.x, f2.y);
        }
Contributor:
Remove this code path, it doesn't make sense to use.

Comment thread ggml/src/ggml-cuda/fattn-common.cuh Outdated
    float2 * dst_f2 = (float2 *) dst;
#pragma unroll
    for (int l = 0; l < ne/2; ++l) {
        dst_f2[l] = __bfloat1622float2(tmp[l]);
Contributor:
Suggested change:
-        dst_f2[l] = __bfloat1622float2(tmp[l]);
+        dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]);

Comment thread ggml/src/ggml-cuda/fattn-vec.cuh Outdated
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

Contributor:

Suggested change (the two identical-looking lines differ only in trailing whitespace; see "fixed trailing whitespace" in eous's follow-up below):
-#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
+#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \

Comment thread ggml/src/ggml-cuda/fattn-tile.cuh Outdated
Comment on lines +324 to +326
-template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
+template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check, typename T_KV>
 static __device__ __forceinline__ void flash_attn_tile_load_tile(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
+        const T_KV * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
Contributor:
Revert the changes to this function. It doesn't make sense to convert BF16 to FP16, so just always use the FP32 path for BF16.

Comment thread ggml/src/ggml-cuda/fattn-tile.cuh Outdated
Comment on lines +426 to +430
    if constexpr (std::is_same_v<T_KV, half2>) {
        tmp_f2[l] = __half22float2(tmp_kv[l]);
    } else {
        tmp_f2[l] = __bfloat1622float2(tmp_kv[l]);
    }
Contributor:
Suggested change:
-    if constexpr (std::is_same_v<T_KV, half2>) {
-        tmp_f2[l] = __half22float2(tmp_kv[l]);
-    } else {
-        tmp_f2[l] = __bfloat1622float2(tmp_kv[l]);
-    }
+    tmp_f2[l] = ggml_cuda_cast<float2>(tmp_kv[l]);

Add a new case to ggml_cuda_cast if necessary.

Comment thread ggml/src/ggml-cuda/fattn-tile.cuh Outdated
}

-template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, typename T_KV = half2> // D == head size
Contributor:
Suggested change:
-template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, typename T_KV = half2> // D == head size
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, typename T_KV> // D == head size

I think the type should be explicitly required to avoid accidental misuse.

Comment thread ggml/src/ggml-cuda/fattn-tile.cuh Outdated
 static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
         T_vec_dot * const Q_tmp,
-        const half2 * const __restrict__ K_h2,
+        const T_KV * const __restrict__ K_kv,
Contributor:
Suggested change:
-        const T_KV * const __restrict__ K_kv,
+        const T_KV * const __restrict__ K_t2,

I think _kv is a poor choice for the suffix name. I don't particularly like _t2 either but I think it's less bad.

Comment thread ggml/src/ggml-cuda/fattn-tile.cuh Outdated
Comment on lines +1112 to +1113
    const ggml_tensor * K = dst->src[1];
    const bool bf16_kv = K->type == GGML_TYPE_BF16;
Contributor:
Suggested change:
-    const ggml_tensor * K = dst->src[1];
-    const bool bf16_kv = K->type == GGML_TYPE_BF16;
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+    GGML_ASSERT(K->type == V->type);
+    const bool bf16_kv = K->type == GGML_TYPE_BF16;

reverted tile kernel changes to avoid larger refactor
@eous (Contributor, Author) commented Mar 14, 2026

Addressed all review feedback:

  • removed the V_DOT2 path from the bf16 vec_dot; bf16 always goes through float
  • switched to ggml_cuda_cast for all bf16 conversions instead of raw intrinsics
  • removed the half dequantize path for bf16 V; float only, enforced with a static_assert (sketched below)
  • reverted all launch_fattn changes; no bf16 conversion params needed
  • reverted all tile kernel changes; the proper fix needs an if constexpr refactor of the fp16/fp32 path selection, which is a bigger change than makes sense here. tile and mma fall back to bf16→fp16 pre-conversion for now
  • CMake uses explicit filenames in one list() call instead of globs
  • added a ggml_cuda_cast(nv_bfloat162) overload to convert.cuh
  • fixed trailing whitespace in the EXTERN_DECL macro

Benchmarked on RTX PRO 6000 Blackwell with Qwen 3.5 27B BF16 up to 65k context; no regressions.
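
A hedged sketch of the float-only dequantize path mentioned in the third bullet (illustrative; the real code in fattn-common.cuh is shaped differently):

    #include <cuda_bf16.h>
    #include <type_traits>

    // BF16 V values are only ever dequantized to fp32; fp16 output is
    // rejected at compile time.
    template <typename T>
    static __device__ __forceinline__ T dequantize_bf16(const void * vx, const int64_t i) {
        static_assert(std::is_same_v<T, float>, "BF16 V must dequantize to float");
        return __bfloat162float(((const nv_bfloat16 *) vx)[i]);
    }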

@eous changed the title from "ggml-cuda: native bf16 flash attention for vec and tile kernels" to "ggml-cuda: native bf16 flash attention for vec" on Mar 15, 2026
@eous changed the title from "ggml-cuda: native bf16 flash attention for vec" to "ggml-cuda: native bf16 flash attention for vec kernel" on Mar 15, 2026
@eous (Contributor, Author) commented Mar 15, 2026

Working through the V_DOT2 path, I am starting to appreciate the layout of the CUDA kernels and their cross-arch compatibility. This CI setup is great.

@eous eous requested a review from a team as a code owner March 15, 2026 19:51
@xkmire commented Mar 19, 2026

Is there anything specific preventing this from going in before it gets stale?
It is important to have when you run Qwen 3.5 with deltanet attention on GPU.
Anything I can help with?

@am17an (Contributor) commented Mar 22, 2026

@JohannesGaessler merge?

@JohannesGaessler (Contributor) left a comment

There were AMD-specific changes after I had approved, so I wanted to re-review. The reason you would want to use BF16 vs. FP16 is the increased numerical range, so the AMD workaround where BF16 is cast to FP16 makes the data type essentially universally worse than FP16. I think the vector kernel needs to be refactored, though I cannot expect someone without access to AMD hardware to do this; I added a comment to the code indicating that this needs to be fixed (by me).
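
To make the range argument concrete (a hedged illustration, not code from this PR): BF16 keeps fp32's 8 exponent bits, so its maximum finite value is about 3.4e38, while FP16 has 5 exponent bits and tops out at 65504. Casting BF16 to FP16 therefore overflows for perfectly ordinary BF16 values:

    #include <cuda_bf16.h>
    #include <cuda_fp16.h>

    __device__ void bf16_vs_fp16_range() {
        const nv_bfloat16 big = __float2bfloat16(1.0e6f);             // representable in bf16
        const half        h   = __float2half(__bfloat162float(big)); // overflows to +inf in fp16
    }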

@JohannesGaessler (Contributor):
@am17an my approval has become stale since, after pushing the comments, I am now the last pusher. Can you approve instead?

@JohannesGaessler JohannesGaessler merged commit db9d8aa into ggml-org:master Mar 22, 2026
1 check passed
@CISC (Member) commented Mar 22, 2026

@Orcazephyr commented:

@JohannesGaessler

I'm very interested in the AMD optimizations that would let HIP use the new bfloat16 path without any casts to float16. I'd be willing to test on my 7900 XTX if you still plan to implement those paths and don't have access to AMD hardware. Thanks! Either way, looking forward to it.

@eous eous deleted the bf16-flash-attn branch April 5, 2026 21:47
Seunghhon pushed a commit to Seunghhon/llama.cpp that referenced this pull request Apr 26, 2026
* ggml-cuda: native bf16 flash attention for vec and tile kernels

mma kernel still converts bf16 to fp16 before launch, native mma bf16 todo

* ggml-cuda: address code owner review feedback

reverted tile kernel changes to avoid larger refactor

* fix ci failures on turing and hip

* fix bf16 vec kernel compile on hip v_dot2 platforms

* add comments

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
rsenthilkumar6 pushed a commit to rsenthilkumar6/llama.cpp that referenced this pull request May 1, 2026

Labels

  • ggml: changes relating to the ggml tensor library for machine learning
  • Nvidia GPU: Issues specific to Nvidia GPUs
  • python: python script changes

Projects

None yet


7 participants