From 9ffaf569152fbba521015e4121cb6d47c6773399 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 3 Feb 2026 16:26:51 -0600 Subject: [PATCH 1/3] vulkan: make FA mask/softcap enables spec constants --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 61 ++++++++++--------- .../vulkan-shaders/flash_attn.comp | 8 +-- .../vulkan-shaders/flash_attn_base.glsl | 15 ++--- .../vulkan-shaders/flash_attn_cm1.comp | 8 +-- .../vulkan-shaders/flash_attn_cm2.comp | 8 +-- 5 files changed, 52 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4357da24d42..6f8a5abee7a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -402,19 +402,19 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt) - : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {} + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) + : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} uint32_t HSK, HSV; bool small_rows, small_cache; FaCodePath path; bool aligned; bool f32acc; - bool use_mask_opt; + uint32_t flags; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) < - std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt); + return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) < + std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); } }; @@ -1047,7 +1047,7 @@ struct vk_flash_attn_push_constants { float max_bias; float logit_softcap; - uint32_t mask_n_head_log2; + uint32_t n_head_log2; float m0; float m1; @@ -3193,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device) { return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. @@ -3225,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // AMD prefers loading K directly from global memory const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt}; + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags}; }; #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ @@ -3237,19 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) { FaCodePath path = fa.first.path; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ - bool use_mask_opt = fa.first.use_mask_opt; \ + uint32_t flags = fa.first.flags; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -8595,10 +8595,27 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt); + uint32_t flags = (use_mask_opt ? 1 : 0) | + (mask != nullptr ? 2 : 0) | + (logit_softcap != 0 ? 4 : 0) | + (sinks != nullptr ? 8 : 0); + + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags); vk_pipeline pipeline = nullptr; @@ -8678,18 +8695,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); - - if (logit_softcap != 0) { - scale /= logit_softcap; - } - const uint32_t n_head_kv = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -8703,8 +8708,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf; vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf; - uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; - if (use_mask_opt) { const vk_op_flash_attn_mask_opt_push_constants opt_pc = { @@ -8735,7 +8738,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx k_stride, (uint32_t)nbk2, (uint32_t)nbk3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3, scale, max_bias, logit_softcap, - mask_n_head_log2, m0, m1, + n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; if (split_k > 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 49a3c530cb6..ea66699bfec 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -127,7 +127,7 @@ void main() { continue; } // Only load if the block is not all zeros - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) { + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { @@ -181,7 +181,7 @@ void main() { } } - if (p.logit_softcap != 0.0f) { + if (LOGIT_SOFTCAP) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); @@ -189,7 +189,7 @@ void main() { } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) { + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { float mvf = masksh[c * cols_per_iter + col_tid][r]; @@ -345,7 +345,7 @@ void main() { return; } - if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + if (SINK_ENABLE) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 252451101ab..9468778fafc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -10,7 +10,12 @@ layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; layout (constant_id = 7) const uint32_t SubGroupSize = 32; layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; -layout (constant_id = 9) const bool USE_MASK_OPT = false; +layout (constant_id = 9) const uint32_t Flags = 0; + +const bool USE_MASK_OPT = (Flags & 1) != 0; +const bool MASK_ENABLE = (Flags & 2) != 0; +const bool LOGIT_SOFTCAP = (Flags & 4) != 0; +const bool SINK_ENABLE = (Flags & 8) != 0; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -50,7 +55,7 @@ layout (push_constant) uniform parameter { float max_bias; float logit_softcap; - uint32_t mask_n_head_log2; + uint32_t n_head_log2; float m0; float m1; @@ -59,10 +64,6 @@ layout (push_constant) uniform parameter { uint32_t k_num; } p; -#define SINK_ENABLE_BIT (1<<24) -#define MASK_ENABLE_BIT (1<<16) -#define N_LOG2_MASK 0xFFFF - layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; @@ -160,7 +161,7 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i { const uint32_t h = iq2 + (r % p.gqa_ratio); - uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; + uint32_t n_head_log2 = p.n_head_log2; const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 89af3697e1d..ad23066d895 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -160,7 +160,7 @@ void main() { mask_cache[idx] = f16vec4(0); } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { if (USE_MASK_OPT && mask_opt_idx != j / 16) { mask_opt_idx = j / 16; @@ -303,7 +303,7 @@ void main() { coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); barrier(); - if (p.logit_softcap != 0.0f) { + if (LOGIT_SOFTCAP) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) / (Br / 4); uint32_t r = (idx + tid) % (Br / 4); @@ -314,7 +314,7 @@ void main() { barrier(); } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) / (Br / 4); uint32_t r = (idx + tid) % (Br / 4); @@ -512,7 +512,7 @@ void main() { return; } - if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + if (SINK_ENABLE) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 47b110621b7..6ad98418ca8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -155,7 +155,7 @@ void main() { for (uint32_t j = start_j; j < end_j; ++j) { coopmat mv = coopmat(0); - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { if (USE_MASK_OPT && mask_opt_idx != j / 16) { mask_opt_idx = j / 16; @@ -197,14 +197,14 @@ void main() { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); S = coopMatMulAdd(Qf16, K_T, S); - if (p.logit_softcap != 0.0f) { + if (LOGIT_SOFTCAP) { [[unroll]] for (int k = 0; k < S.length(); ++k) { S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { S += slopeMat*coopmat(mv); } @@ -292,7 +292,7 @@ void main() { // resize L by using smear/reduce coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); - if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + if (SINK_ENABLE) { coopmat S; coopMatPerElementNV(S, S, perElemOpGetSink, iq2); From 3aad36d68999c20d4d716c9819173df91e7adbd3 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 5 Feb 2026 13:42:55 -0600 Subject: [PATCH 2/3] don't specialize for sinks --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 9 +++++---- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl | 8 +++++--- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp | 2 +- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6f8a5abee7a..72097ffd0ff 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1047,7 +1047,7 @@ struct vk_flash_attn_push_constants { float max_bias; float logit_softcap; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -8612,8 +8612,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t flags = (use_mask_opt ? 1 : 0) | (mask != nullptr ? 2 : 0) | - (logit_softcap != 0 ? 4 : 0) | - (sinks != nullptr ? 8 : 0); + (logit_softcap != 0 ? 4 : 0); vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags); @@ -8708,6 +8707,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf; vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf; + uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2; + if (use_mask_opt) { const vk_op_flash_attn_mask_opt_push_constants opt_pc = { @@ -8738,7 +8739,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx k_stride, (uint32_t)nbk2, (uint32_t)nbk3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3, scale, max_bias, logit_softcap, - n_head_log2, m0, m1, + mask_n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; if (split_k > 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ea66699bfec..914f131c965 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -345,7 +345,7 @@ void main() { return; } - if (SINK_ENABLE) { + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 9468778fafc..74005cffb3f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -15,7 +15,6 @@ layout (constant_id = 9) const uint32_t Flags = 0; const bool USE_MASK_OPT = (Flags & 1) != 0; const bool MASK_ENABLE = (Flags & 2) != 0; const bool LOGIT_SOFTCAP = (Flags & 4) != 0; -const bool SINK_ENABLE = (Flags & 8) != 0; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -55,7 +54,7 @@ layout (push_constant) uniform parameter { float max_bias; float logit_softcap; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -64,6 +63,9 @@ layout (push_constant) uniform parameter { uint32_t k_num; } p; +#define SINK_ENABLE_BIT (1<<24) +#define N_LOG2_MASK 0xFFFF + layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; @@ -161,7 +163,7 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i { const uint32_t h = iq2 + (r % p.gqa_ratio); - uint32_t n_head_log2 = p.n_head_log2; + uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index ad23066d895..b3177738234 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -512,7 +512,7 @@ void main() { return; } - if (SINK_ENABLE) { + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 6ad98418ca8..b07c21f6e55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -292,7 +292,7 @@ void main() { // resize L by using smear/reduce coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); - if (SINK_ENABLE) { + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { coopmat S; coopMatPerElementNV(S, S, perElemOpGetSink, iq2); From 0f27654388d5d0a308b528cbf9e8d7fb73f21336 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 5 Feb 2026 18:08:13 -0600 Subject: [PATCH 3/3] bump timeout a little bit --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8ce679bd9ab..51a3dc76e9e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -468,7 +468,7 @@ jobs: export GGML_VK_VISIBLE_DEVICES=0 export GGML_VK_DISABLE_F16=1 # This is using llvmpipe and runs slower than other backends - ctest -L main --verbose --timeout 4200 + ctest -L main --verbose --timeout 4800 ubuntu-24-cmake-webgpu: runs-on: ubuntu-24.04