From 1385fe3ff5588743e79e7cccc325232f1dcce150 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 3 Feb 2026 07:59:55 +0100 Subject: [PATCH 01/46] vulkan: allow using fp16 in coopmat1 flash attention shader --- .../vulkan-shaders/flash_attn.comp | 85 ++++++++++--------- .../vulkan-shaders/flash_attn_cm1.comp | 6 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 82 ++++++++++-------- 3 files changed, 96 insertions(+), 77 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0735f678549..b7d232eb66a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -3,9 +3,13 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_subgroup_extended_types_float16 : require +#endif + #extension GL_KHR_shader_subgroup_shuffle : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -29,18 +33,18 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];}; // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); return elem; } -shared FLOAT_TYPE tmpsh[WorkGroupSize]; +shared float tmpsh[WorkGroupSize]; shared vec4 tmpshv4[WorkGroupSize]; -shared float masksh[Bc][Br]; -shared vec4 Qf[Br][HSK / 4]; +shared FLOAT_TYPE masksh[Bc][Br]; +shared FLOAT_TYPEV4 Qf[Br][HSK / 4]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -60,15 +64,15 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + Qf[r][d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - vec4 Of[Br][HSV_per_thread / 4]; + ACC_TYPEV4 Of[Br][HSV_per_thread / 4]; [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] = vec4(0.0); + Of[r][d] = ACC_TYPEV4(0.0); } } @@ -82,9 +86,9 @@ void main() { Mf[r] = NEG_FLT_MAX_OVER_2; } - float slope[Br]; + ACC_TYPE slope[Br]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - slope[r] = 1.0; + slope[r] = ACC_TYPE(1.0); } // ALiBi @@ -136,11 +140,11 @@ void main() { uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); masksh[c][r] = m; - max_mask = max(max_mask, m); + max_mask = max(max_mask, float(m)); } else { - masksh[c][r] = float(0); + masksh[c][r] = FLOAT_TYPE(0); } } } @@ -159,10 +163,10 @@ void main() { } } - float Sf[Br][cols_per_thread]; + ACC_TYPE Sf[Br][cols_per_thread]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Sf[r][c] = 0.0; + Sf[r][c] = ACC_TYPE(0.0); } } @@ -176,12 +180,12 @@ void main() { uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); #else - vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + Sf[r][c] += ACC_TYPE(dot(Qf[r][d * D_split + d_tid], K_Tf)); } } } @@ -198,7 +202,7 @@ void main() { if (LOGIT_SOFTCAP) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c])); } } } @@ -206,7 +210,7 @@ void main() { if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float mvf = masksh[c * cols_per_iter + col_tid][r]; + FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][r]; Sf[r][c] += slope[r]*mvf; } @@ -214,41 +218,42 @@ void main() { barrier(); } - float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; + FLOAT_TYPE Pf[Br][cols_per_thread]; + float eMf[Br]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - rowmaxf[r] = NEG_FLT_MAX_OVER_2; + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + rowmaxf = max(rowmaxf, float(Sf[r][c])); } - Moldf[r] = Mf[r]; + float Moldf = Mf[r]; // M = max(rowmax, Mold) // P = e^(S - M) // eM = e^(Mold - M) - Mf[r] = max(rowmaxf[r], Moldf[r]); + Mf[r] = max(rowmaxf, Moldf); [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Pf[r][c] = exp(Sf[r][c] - Mf[r]); + Pf[r][c] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r])); } - eMf[r] = exp(Moldf[r] - Mf[r]); + eMf[r] = exp(Moldf - Mf[r]); // Compute sum across row of P - rowsumf[r] = 0.0; + float rowsumf = 0.0; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - rowsumf[r] += Pf[r][c]; + rowsumf += Pf[r][c]; } - Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + Lf[r] = eMf[r]*Lf[r] + rowsumf; } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] = eMf[r] * Of[r][d]; + Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; } } @@ -261,12 +266,12 @@ void main() { uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); #else - vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); + FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] += Pf[r][c] * Vf; + Of[r][d] += ACC_TYPE(Pf[r][c] * Vf); } } } @@ -318,18 +323,18 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] = eMf * Of[r][d]; + Of[r][d] = ACC_TYPE(eMf) * Of[r][d]; tmpshv4[tid] = Of[r][d]; barrier(); [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { if (tid < s) { - Of[r][d] += tmpshv4[tid + s]; + Of[r][d] += ACC_TYPEV4(tmpshv4[tid + s]); tmpshv4[tid] = Of[r][d]; } barrier(); } - Of[r][d] = tmpshv4[d_tid]; + Of[r][d] = ACC_TYPEV4(tmpshv4[d_tid]); barrier(); } } @@ -373,7 +378,7 @@ void main() { ms = exp(Mf[r] - sink); [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] *= ms; + Of[r][d] *= ACC_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -390,9 +395,9 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] *= Lfrcp[r]; + Of[r][d] *= ACC_TYPE(Lfrcp[r]); #if defined(ACC_TYPE_MAX) - Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); + Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); #endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 19630972daf..49db2b45a42 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -35,7 +35,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];}; // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); @@ -510,7 +510,7 @@ void main() { if (d >= HSV/4) break; const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); + perElemOpGqaStore(tile_row(r), 4 * d + comp, Of[r][d_local][comp], o_offset, iq2, N); } } } @@ -574,7 +574,7 @@ void main() { if (d >= HSV / 4) break; const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); + perElemOpGqaStore(tile_row(r), 4 * d + comp, Of[r][d_local][comp], o_offset, iq2, N); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 42ebc21e2a6..b244f9fa5f7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } void process_shaders() { - std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; - // matmul for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { // No coopmats @@ -622,49 +620,65 @@ void process_shaders() { } } - // flash attention - for (const auto& f16acc : {false, true}) { - std::map fa_base_dict = base_dict; - fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; - fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; - if (f16acc) { - fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; + for (const bool& fp16 : {false, true}) { + std::map base_dict; + if (fp16) { + base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}}; + } else { + base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}}; } - for (const auto& tname : type_names) { - if (tname == "bf16") continue; + // flash attention + for (const bool& f16acc : {false, true}) { + if (!fp16 && f16acc) continue; -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); - } else { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + std::map fa_base_dict = base_dict; + fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; + if (f16acc) { + fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } + + for (const auto& tname : type_names) { + if (tname == "bf16") continue; + + if (fp16) { +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, false, true, f16acc); + } else { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc); + } #endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); + } #endif - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } + + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); + } } } } + std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; + for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); From 3419b4136aeb309af6fcd8a669cc2bde2bc9bbc1 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 3 Feb 2026 09:08:35 +0100 Subject: [PATCH 02/46] split rows inside of subgroups for faster synchronization --- .../vulkan-shaders/flash_attn.comp | 100 ++++++++++-------- 1 file changed, 56 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index b7d232eb66a..af0b371cfa0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -19,7 +19,9 @@ const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; -const uint32_t cols_per_iter = WorkGroupSize / D_split; +const uint32_t row_split = 4; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; @@ -44,7 +46,9 @@ shared float tmpsh[WorkGroupSize]; shared vec4 tmpshv4[WorkGroupSize]; shared FLOAT_TYPE masksh[Bc][Br]; -shared FLOAT_TYPEV4 Qf[Br][HSK / 4]; + +const uint qfstride = HSK / 4 + 1; +shared FLOAT_TYPEV4 Qf[Br * qfstride]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -54,8 +58,12 @@ void main() { init_indices(); const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; const uint32_t d_tid = gl_LocalInvocationIndex % D_split; - const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + +#define tile_row(r) (row_tid * rows_per_thread + (r)) uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; @@ -64,37 +72,37 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r][d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + Qf[r * qfstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - ACC_TYPEV4 Of[Br][HSV_per_thread / 4]; + ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] = ACC_TYPEV4(0.0); } } - float Lf[Br], Mf[Br]; + float Lf[rows_per_thread], Mf[rows_per_thread]; // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lf[r] = 0; Mf[r] = NEG_FLT_MAX_OVER_2; } - ACC_TYPE slope[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + ACC_TYPE slope[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { slope[r] = ACC_TYPE(1.0); } // ALiBi if (p.max_bias > 0.0f) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2); } } @@ -163,8 +171,8 @@ void main() { } } - ACC_TYPE Sf[Br][cols_per_thread]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + ACC_TYPE Sf[rows_per_thread][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { Sf[r][c] = ACC_TYPE(0.0); } @@ -184,8 +192,8 @@ void main() { #else FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Sf[r][c] += ACC_TYPE(dot(Qf[r][d * D_split + d_tid], K_Tf)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qfstride + d * D_split + d_tid], K_Tf)); } } } @@ -193,14 +201,14 @@ void main() { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { // Compute sum across the D_split [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); } } } if (LOGIT_SOFTCAP) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c])); } @@ -209,8 +217,8 @@ void main() { if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][r]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][tile_row(r)]; Sf[r][c] += slope[r]*mvf; } @@ -218,9 +226,9 @@ void main() { barrier(); } - FLOAT_TYPE Pf[Br][cols_per_thread]; - float eMf[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + FLOAT_TYPE Pf[rows_per_thread][cols_per_thread]; + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { @@ -252,7 +260,7 @@ void main() { } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; } } @@ -270,7 +278,7 @@ void main() { #else FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] += ACC_TYPE(Pf[r][c] * Vf); } } @@ -284,7 +292,7 @@ void main() { // reduce across threads - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { float rowmaxf, eMf; tmpsh[tid] = Mf[r]; @@ -346,21 +354,23 @@ void main() { // note: O and Q have swapped coord 1,2. uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + perElemOpGqaStore(row, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); } } } } o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { - perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); - perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { + perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); } } @@ -368,8 +378,8 @@ void main() { } if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); float ms = 1.0f; float vs = 1.0f; @@ -388,13 +398,13 @@ void main() { } } - float Lfrcp[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] *= ACC_TYPE(Lfrcp[r]); #if defined(ACC_TYPE_MAX) Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); @@ -405,21 +415,23 @@ void main() { uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + perElemOpGqaStore(row, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); } } } } } else { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (i * Br + r < N) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (i * Br + row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + data_o[o_offset + iq2 * HSV + (i * Br + row) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); } } } From e924abe23d3c6a9d28def42fafaeefe2f3731a95 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 5 Feb 2026 12:51:59 +0100 Subject: [PATCH 03/46] use row_split when Br >= 4, change reductions to use shared memory if row_split == 1 --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 12 +-- .../vulkan-shaders/flash_attn.comp | 73 +++++++++++-------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 8 +- 3 files changed, 52 insertions(+), 41 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a8840a0773b..41de4c36fe0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3215,7 +3215,7 @@ static void ggml_vk_load_shaders(vk_device& device) { wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc break; default: - wg_size = scalar_flash_attention_workgroup_size; + wg_size = device->subgroup_size * 4; break; } @@ -3245,15 +3245,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -16094,7 +16094,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * ggml_vk_print_graph_origin(tensor, done); } - if (avg_err > 0.5 || std::isnan(avg_err)) { + if (avg_err > 0.01 || std::isnan(avg_err)) { std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; if (src0 != nullptr) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index af0b371cfa0..49d50ed8548 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -19,10 +19,11 @@ const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; -const uint32_t row_split = 4; +const uint32_t row_split = (Br < 4) ? 1 : 4; const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; +const uint32_t num_subgroups = WorkGroupSize / SubGroupSize; layout (binding = 0) readonly buffer Q {float data_q[];}; @@ -42,8 +43,10 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_ return elem; } -shared float tmpsh[WorkGroupSize]; -shared vec4 tmpshv4[WorkGroupSize]; +const uint32_t tmpsh_reduction_size = row_split == 1 ? num_subgroups * D_split : 0; +const uint32_t tmpsh_size = tmpsh_reduction_size > 4 ? tmpsh_reduction_size : 4; +shared float tmpsh[tmpsh_size]; +shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; shared FLOAT_TYPE masksh[Bc][Br]; @@ -279,7 +282,7 @@ void main() { FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] += ACC_TYPE(Pf[r][c] * Vf); + Of[r][d] += ACC_TYPEV4(Pf[r][c] * Vf); } } } @@ -293,57 +296,67 @@ void main() { // reduce across threads [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - float rowmaxf, eMf; + float rowmaxf = Mf[r]; - tmpsh[tid] = Mf[r]; // Compute max across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s)); + } + if (row_split == 1) { + // Reduce inside workgroup with shmem + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf; } barrier(); + rowmaxf = max(max(max(tmpsh[0 * D_split + d_tid], + tmpsh[1 * D_split + d_tid]), + tmpsh[2 * D_split + d_tid]), + tmpsh[3 * D_split + d_tid]); } - rowmaxf = tmpsh[d_tid]; - barrier(); float Moldf = Mf[r]; // M = max(rowmax, Mold) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); - eMf = exp(Moldf - Mf[r]); + float eMf = exp(Moldf - Mf[r]); Lf[r] = eMf*Lf[r]; - tmpsh[tid] = Lf[r]; - // Compute sum across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + Lf[r] += subgroupShuffleXor(Lf[r], s); + } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r]; } barrier(); + Lf[r] = tmpsh[0 * D_split + d_tid] + + tmpsh[1 * D_split + d_tid] + + tmpsh[2 * D_split + d_tid] + + tmpsh[3 * D_split + d_tid]; } - Lf[r] = tmpsh[d_tid]; - barrier(); [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] = ACC_TYPE(eMf) * Of[r][d]; - tmpshv4[tid] = Of[r][d]; - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - Of[r][d] += ACC_TYPEV4(tmpshv4[tid + s]); - tmpshv4[tid] = Of[r][d]; + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + Of[r][d] += subgroupShuffleXor(Of[r][d], s); + } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; } barrier(); + Of[r][d] = tmpsh_accv4[0 * D_split + d_tid] + + tmpsh_accv4[1 * D_split + d_tid] + + tmpsh_accv4[2 * D_split + d_tid] + + tmpsh_accv4[3 * D_split + d_tid]; } - Of[r][d] = ACC_TYPEV4(tmpshv4[d_tid]); - barrier(); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b244f9fa5f7..7fbe45d33f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -630,12 +630,10 @@ void process_shaders() { // flash attention for (const bool& f16acc : {false, true}) { - if (!fp16 && f16acc) continue; - std::map fa_base_dict = base_dict; - fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; - fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; - if (f16acc) { + fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4"; + if (fp16 && f16acc) { fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } From 29751e28087d68526e89bdbd471f2de3ceaa85c9 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 5 Feb 2026 14:30:36 +0100 Subject: [PATCH 04/46] use f32 scalar FA if f16 is not supported by device --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 41de4c36fe0..1f622189eff 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3259,10 +3259,17 @@ static void ggml_vk_load_shaders(vk_device& device) { } \ } - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + if (device->fp16) { + CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + } else { + CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) From 65124f199ae98e8e561d8774511ac69667b294ca Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 5 Feb 2026 17:17:04 +0100 Subject: [PATCH 05/46] fix amd workgroup size issue --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++--------- .../vulkan-shaders/flash_attn.comp | 24 +++++++++---------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1f622189eff..cbd499adb0d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2761,11 +2761,11 @@ static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { if (hsv >= 192) { - return 2; + return 8; } else if ((hsv | hsk) & 8 || small_cache) { - return 4; - } else { return 8; + } else { + return 16; } } @@ -2791,13 +2791,7 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { - if ((hsv | hsk) & 8) { - // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter - // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; - } else { - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32}; - } + return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; } } @@ -3215,7 +3209,11 @@ static void ggml_vk_load_shaders(vk_device& device) { wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc break; default: - wg_size = device->subgroup_size * 4; + if (device->subgroup_size > 32 && rows_cols[0] < 4) { + wg_size = device->subgroup_size * 2; + } else { + wg_size = device->subgroup_size * 4; + } break; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 49d50ed8548..223b58d8ef9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -309,10 +309,10 @@ void main() { tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf; } barrier(); - rowmaxf = max(max(max(tmpsh[0 * D_split + d_tid], - tmpsh[1 * D_split + d_tid]), - tmpsh[2 * D_split + d_tid]), - tmpsh[3 * D_split + d_tid]); + rowmaxf = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]); + } } float Moldf = Mf[r]; @@ -334,10 +334,10 @@ void main() { tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r]; } barrier(); - Lf[r] = tmpsh[0 * D_split + d_tid] + - tmpsh[1 * D_split + d_tid] + - tmpsh[2 * D_split + d_tid] + - tmpsh[3 * D_split + d_tid]; + Lf[r] = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Lf[r] += tmpsh[s * D_split + d_tid]; + } } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { @@ -352,10 +352,10 @@ void main() { tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; } barrier(); - Of[r][d] = tmpsh_accv4[0 * D_split + d_tid] + - tmpsh_accv4[1 * D_split + d_tid] + - tmpsh_accv4[2 * D_split + d_tid] + - tmpsh_accv4[3 * D_split + d_tid]; + Of[r][d] = tmpsh_accv4[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Of[r][d] += tmpsh_accv4[s * D_split + d_tid]; + } } } } From ee170505622bbca5be0f6afdef348990354573c2 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 6 Feb 2026 13:32:33 +0100 Subject: [PATCH 06/46] optimize masksh use --- .../vulkan-shaders/flash_attn.comp | 85 +++++++++---------- .../vulkan-shaders/flash_attn_cm1.comp | 8 +- 2 files changed, 45 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 223b58d8ef9..66c892591a5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -43,8 +43,7 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_ return elem; } -const uint32_t tmpsh_reduction_size = row_split == 1 ? num_subgroups * D_split : 0; -const uint32_t tmpsh_size = tmpsh_reduction_size > 4 ? tmpsh_reduction_size : 4; +const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1; shared float tmpsh[tmpsh_size]; shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; @@ -128,50 +127,52 @@ void main() { uint32_t mask_opt = 0; uint32_t mask_opt_idx = ~0; + uint32_t mask_opt_bits = 0; [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - - if (USE_MASK_OPT && mask_opt_idx != j / 16) { - mask_opt_idx = j / 16; - mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; - } - uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; - if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { - // skip this block - continue; - } - // Only load if the block is not all zeros - if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); - masksh[c][r] = m; - max_mask = max(max_mask, float(m)); - } else { - masksh[c][r] = FLOAT_TYPE(0); - } - } - } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + if (MASK_ENABLE) { + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; } - barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); - } - if (max_mask <= NEG_FLT_MAX_OVER_2) { + mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block continue; } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + float max_mask = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c][r] = m; + max_mask = max(max_mask, float(m)); + } else { + masksh[c][r] = FLOAT_TYPE(0); + } + } + } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } + } } ACC_TYPE Sf[rows_per_thread][cols_per_thread]; @@ -181,7 +182,6 @@ void main() { } } - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; @@ -226,7 +226,6 @@ void main() { Sf[r][c] += slope[r]*mvf; } } - barrier(); } FLOAT_TYPE Pf[rows_per_thread][cols_per_thread]; @@ -286,8 +285,6 @@ void main() { } } } - - barrier(); } // prevent race on tmpsh diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 49db2b45a42..68bef90e48a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -153,22 +153,22 @@ void main() { uint32_t mask_opt = 0; uint32_t mask_opt_idx = ~0; + uint32_t mask_opt_bits = 0; + f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) { mask_cache[idx] = f16vec4(0); } if (MASK_ENABLE) { - if (USE_MASK_OPT && mask_opt_idx != j / 16) { mask_opt_idx = j / 16; mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; } - uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { // skip this block continue; @@ -329,7 +329,7 @@ void main() { barrier(); } - if (MASK_ENABLE) { + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) / (Br / 4); uint32_t r = (idx + tid) % (Br / 4); From 21868388a31a80b482ac99027ca8d73724ec3d40 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 6 Feb 2026 14:36:31 +0100 Subject: [PATCH 07/46] add medium rows FA shader Br size --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 115 ++++++++++++++------------- 1 file changed, 58 insertions(+), 57 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index cbd499adb0d..69534ee99b8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -401,21 +401,27 @@ enum FaCodePath { FA_COOPMAT1, FA_COOPMAT2, }; +enum FaRows { + FA_ROWS_1, + FA_ROWS_SMALL, + FA_ROWS_LARGE, +}; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) - : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, FaRows rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) + : HSK(HSK), HSV(HSV), rows(rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} uint32_t HSK, HSV; - bool small_rows, small_cache; + FaRows rows; + bool small_cache; FaCodePath path; bool aligned; bool f32acc; uint32_t flags; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) < - std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); + return std::tie(HSK, HSV, rows, small_cache, path, aligned, f32acc, flags) < + std::tie(b.HSK, b.HSV, b.rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); } }; @@ -2757,16 +2763,21 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; -static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { +static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { + if (rows == FA_ROWS_1) { + return 1; + } else if (rows == FA_ROWS_SMALL) { + return 4; + } + if (hsv >= 192) { return 8; } else if ((hsv | hsk) & 8 || small_cache) { return 8; - } else { - return 16; } + + return 16; } // The FA coopmat1 shader assumes 16x16x16 matrix multiply support. @@ -2776,36 +2787,20 @@ static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; static constexpr uint32_t scalar_flash_attention_Bc = 64; static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; -static uint32_t get_fa_num_small_rows(FaCodePath path) { - if (path == FA_COOPMAT2) { - return flash_attention_num_small_rows; - } else { - return scalar_flash_attention_num_small_rows; - } -} - -static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) { +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) { GGML_UNUSED(clamp); if (path == FA_SCALAR) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, 64}; - } else { - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; - } + return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 64}; } if (path == FA_COOPMAT1) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; - } else { - return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; - } + return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; } // small rows, large cols - if (small_rows) { - return {get_fa_num_small_rows(FA_COOPMAT2), 32}; + if (rows != FA_ROWS_LARGE) { + return {flash_attention_num_small_rows, 32}; } // small cols to reduce register count @@ -2819,8 +2814,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 return {64, 64}; } -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) { - return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1]; +static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) { + return fa_rows_cols(path, hsk, hsv, 0, type, rows, small_cache)[1]; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3187,23 +3182,23 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array { - return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. // For scalar, use 128 (arbitrary) // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. const uint32_t D = (hsk|hsv); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache); uint32_t wg_size; switch (path) { case FA_COOPMAT2: - wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128); + wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); break; case FA_COOPMAT1: wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc @@ -3234,7 +3229,7 @@ static void ggml_vk_load_shaders(vk_device& device) { for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ uint32_t HSK = fa.first.HSK; \ uint32_t HSV = fa.first.HSV; \ - bool small_rows = fa.first.small_rows; \ + FaRows rows = fa.first.rows; \ bool small_cache = fa.first.small_cache; \ FaCodePath path = fa.first.path; \ bool aligned = fa.first.aligned; \ @@ -3243,15 +3238,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -8424,11 +8419,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache); + const uint32_t Br = get_fa_scalar_num_rows(hsk, hsv, rows, small_cache); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -8449,7 +8444,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); - const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false); + const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; @@ -8575,10 +8570,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache); + max_gqa = get_fa_scalar_num_rows(HSK, HSV, FA_ROWS_LARGE, small_cache); break; case FA_COOPMAT2: - max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + max_gqa = flash_attention_num_small_rows; break; default: GGML_ASSERT(0); @@ -8594,23 +8589,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } - bool small_rows = N <= get_fa_num_small_rows(path); + FaRows rows; + if (N == 1) { + rows = FA_ROWS_1; + } else if (N <= 8) { + rows = FA_ROWS_SMALL; + } else { + rows = FA_ROWS_LARGE; + } // coopmat1 does not actually support "small rows" (it needs 16 rows). // So use scalar instead. - if (small_rows && path == FA_COOPMAT1) { + if (rows != FA_ROWS_LARGE && path == FA_COOPMAT1) { path = FA_SCALAR; } // scalar is faster than coopmat2 when N==1 - if (N == 1 && path == FA_COOPMAT2) { + if (rows == FA_ROWS_1 && path == FA_COOPMAT2) { path = FA_SCALAR; } - // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory - if (path == FA_SCALAR && - !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) { - small_rows = true; + // with large hsk/hsv, scalar path may need to use small rows to fit in shared memory + if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache)) { + rows = FA_ROWS_SMALL; } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); @@ -8625,7 +8626,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache); + uint32_t alignment = fa_align(path, HSK, HSV, k->type, rows, small_cache); bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; @@ -8656,7 +8657,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx (mask != nullptr ? 2 : 0) | (logit_softcap != 0 ? 4 : 0); - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags); + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, rows, small_cache, path, aligned, f32acc, flags); vk_pipeline pipeline = nullptr; @@ -8707,7 +8708,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_preallocate_buffers(ctx, subctx); } - auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache); + auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; From b4f1f643aa2036dd7882efcc01a2de9b7a5a499a Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 7 Feb 2026 07:26:19 +0100 Subject: [PATCH 08/46] fixes --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 69534ee99b8..6b5f9fb3d84 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2768,7 +2768,7 @@ static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, if (rows == FA_ROWS_1) { return 1; } else if (rows == FA_ROWS_SMALL) { - return 4; + return 8; } if (hsv >= 192) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 66c892591a5..ec4a831fd60 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -146,6 +146,7 @@ void main() { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; float max_mask = NEG_FLT_MAX_OVER_2; + barrier(); [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; From b9f155dd187febe89be4364a6633ae9731ba949a Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 7 Feb 2026 07:50:56 +0100 Subject: [PATCH 09/46] add padding to mask shmem buffer --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ec4a831fd60..24589dfe7ce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -47,7 +47,8 @@ const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1; shared float tmpsh[tmpsh_size]; shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; -shared FLOAT_TYPE masksh[Bc][Br]; +const uint32_t masksh_stride = Br + 1; +shared FLOAT_TYPE masksh[Bc * masksh_stride]; const uint qfstride = HSK / 4 + 1; shared FLOAT_TYPEV4 Qf[Br * qfstride]; @@ -153,10 +154,10 @@ void main() { if (idx + tid < Bc * Br) { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); - masksh[c][r] = m; + masksh[c * masksh_stride + r] = m; max_mask = max(max_mask, float(m)); } else { - masksh[c][r] = FLOAT_TYPE(0); + masksh[c * masksh_stride + r] = FLOAT_TYPE(0); } } } @@ -222,7 +223,7 @@ void main() { if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][tile_row(r)]; + FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)]; Sf[r][c] += slope[r]*mvf; } From c932e0d769d3d42e4df92b9e1faaa141dd305311 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 7 Feb 2026 17:15:55 +0100 Subject: [PATCH 10/46] cache q values into registers for KQ --- .../ggml-vulkan/vulkan-shaders/flash_attn.comp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 24589dfe7ce..e6a1de3f705 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -184,11 +184,17 @@ void main() { } } - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 Q_cache[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Q_cache[r] = Qf[tile_row(r) * qfstride + d * D_split + d_tid]; } - [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -198,7 +204,7 @@ void main() { FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qfstride + d * D_split + d_tid], K_Tf)); + Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); } } } From d2f428c61ef019cfb618c91d246fa1542e568928 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 8 Feb 2026 08:32:00 +0100 Subject: [PATCH 11/46] fuse lf accumulation, pf and v accumulation into a loop --- .../vulkan-shaders/flash_attn.comp | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e6a1de3f705..e641debe3c5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -236,7 +236,6 @@ void main() { } } - FLOAT_TYPE Pf[rows_per_thread][cols_per_thread]; float eMf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { float rowmaxf = NEG_FLT_MAX_OVER_2; @@ -252,21 +251,8 @@ void main() { // P = e^(S - M) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Pf[r][c] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r])); - } eMf[r] = exp(Moldf - Mf[r]); - - // Compute sum across row of P - float rowsumf = 0.0; - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - rowsumf += Pf[r][c]; - } - - Lf[r] = eMf[r]*Lf[r] + rowsumf; + Lf[r] = eMf[r]*Lf[r]; } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { @@ -279,6 +265,13 @@ void main() { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } + + FLOAT_TYPE Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r])); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); @@ -289,7 +282,7 @@ void main() { FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] += ACC_TYPEV4(Pf[r][c] * Vf); + Of[r][d] += ACC_TYPEV4(Pf[r] * Vf); } } } From 2b3ed40d33756b5118edc84bd237d6258ada5e5c Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 8 Feb 2026 10:08:20 +0100 Subject: [PATCH 12/46] stage K loads through shmem --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- .../vulkan-shaders/flash_attn.comp | 50 +++++++++++++++---- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6b5f9fb3d84..bf52ae78441 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3220,7 +3220,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // Nvidia prefers shared memory use to load large tiles of K. // Switch to loading from global memory when it would use too much shared memory. // AMD prefers loading K directly from global memory - const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; + const uint32_t k_load_shmem = 1; // device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags}; }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e641debe3c5..e4ca125eb40 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -50,8 +50,12 @@ shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; const uint32_t masksh_stride = Br + 1; shared FLOAT_TYPE masksh[Bc * masksh_stride]; -const uint qfstride = HSK / 4 + 1; -shared FLOAT_TYPEV4 Qf[Br * qfstride]; +const uint32_t qf_stride = HSK / 4 + 1; +shared FLOAT_TYPEV4 Qf[Br * qf_stride]; + +const uint32_t D = HSK > HSV ? HSK : HSV; +const uint32_t kvsh_stride = D / 4 + 1; +shared FLOAT_TYPEV4 kvsh[K_LOAD_SHMEM != 0 ? Bc * kvsh_stride : 1]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -75,7 +79,7 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r * qfstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); @@ -184,10 +188,33 @@ void main() { } } + if (K_LOAD_SHMEM != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); + if (!KV_bounds_check || j * Bc + c < KV) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + } + + kvsh[c * kvsh_stride + d] = K_Tf; + } + } + barrier(); + } + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { FLOAT_TYPEV4 Q_cache[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Q_cache[r] = Qf[tile_row(r) * qfstride + d * D_split + d_tid]; + Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid]; } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { @@ -195,14 +222,19 @@ void main() { continue; } + FLOAT_TYPEV4 K_Tf; + if (K_LOAD_SHMEM != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); #else - FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif + } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); } From a4200ef459a71c61c0913632b1882efa11c8e6e8 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 8 Feb 2026 10:27:10 +0100 Subject: [PATCH 13/46] stage V loads through shmem --- .../vulkan-shaders/flash_attn.comp | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e4ca125eb40..6d85212d44b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -189,6 +189,7 @@ void main() { } if (K_LOAD_SHMEM != 0) { + barrier(); [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4); @@ -293,6 +294,30 @@ void main() { } } + if (K_LOAD_SHMEM != 0) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSV / 4); + uint32_t c = (idx + tid) / (HSV / 4); + if (c < Bc) { + FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); + if (!KV_bounds_check || j * Bc + c < KV) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + V_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); +#else + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); +#endif + } + + kvsh[c * kvsh_stride + d] = V_Tf; + } + } + barrier(); + } + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; @@ -305,14 +330,19 @@ void main() { } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + FLOAT_TYPEV4 Vf; + if (K_LOAD_SHMEM != 0) { + Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); #else - FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); + Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif + } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] += ACC_TYPEV4(Pf[r] * Vf); } From 7a6e76280d74d0412f38e3e2ffa8f96c5827e145 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 8 Feb 2026 10:41:50 +0100 Subject: [PATCH 14/46] only stage through shmem on Nvidia --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bf52ae78441..34f90cbae2c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3220,7 +3220,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // Nvidia prefers shared memory use to load large tiles of K. // Switch to loading from global memory when it would use too much shared memory. // AMD prefers loading K directly from global memory - const uint32_t k_load_shmem = 1; // device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; + const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0; return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags}; }; From e1ac2d950d258740bbd88c216bc7857d8cd370fb Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 8 Feb 2026 11:25:54 +0100 Subject: [PATCH 15/46] default to Bc 32 --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 34f90cbae2c..a1e2ca701f4 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2791,7 +2791,14 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 GGML_UNUSED(clamp); if (path == FA_SCALAR) { - return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 64}; + if (rows == FA_ROWS_1 && ((hsk|hsv) & 8)) { + // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter + // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. + // But this only applies to row_split=1, meaning FA_ROWS_1 + return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 64}; + } else { + return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 32}; + } } if (path == FA_COOPMAT1) { From 26ad71458c7c96fc1da58cefb97241251061811f Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 8 Feb 2026 12:42:53 +0100 Subject: [PATCH 16/46] also stage V through shmem when this is done for K --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 +- .../vulkan-shaders/flash_attn.comp | 12 +- .../vulkan-shaders/flash_attn_base.glsl | 2 +- .../vulkan-shaders/flash_attn_cm1.comp | 114 ++++++++++++------ 4 files changed, 84 insertions(+), 50 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a1e2ca701f4..034e1c38a88 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3224,12 +3224,12 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t D_lsb = D ^ (D & (D-1)); uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - // Nvidia prefers shared memory use to load large tiles of K. + // Nvidia prefers shared memory use to load large tiles of K/V. // Switch to loading from global memory when it would use too much shared memory. // AMD prefers loading K directly from global memory - const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0; + const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0; - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags}; + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, shmem_staging, flags}; }; #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 6d85212d44b..7324d770a0f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -55,7 +55,7 @@ shared FLOAT_TYPEV4 Qf[Br * qf_stride]; const uint32_t D = HSK > HSV ? HSK : HSV; const uint32_t kvsh_stride = D / 4 + 1; -shared FLOAT_TYPEV4 kvsh[K_LOAD_SHMEM != 0 ? Bc * kvsh_stride : 1]; +shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -188,12 +188,12 @@ void main() { } } - if (K_LOAD_SHMEM != 0) { + if (SHMEM_STAGING != 0) { barrier(); [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4); - if (c < Bc && d < HSK / 4) { + if (c < Bc) { FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { #if BLOCK_SIZE > 1 @@ -224,7 +224,7 @@ void main() { } FLOAT_TYPEV4 K_Tf; - if (K_LOAD_SHMEM != 0) { + if (SHMEM_STAGING != 0) { K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; } else { #if BLOCK_SIZE > 1 @@ -294,7 +294,7 @@ void main() { } } - if (K_LOAD_SHMEM != 0) { + if (SHMEM_STAGING != 0) { barrier(); [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSV / 4); @@ -331,7 +331,7 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { FLOAT_TYPEV4 Vf; - if (K_LOAD_SHMEM != 0) { + if (SHMEM_STAGING != 0) { Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; } else { #if BLOCK_SIZE > 1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 4142c1e6eaa..0a077f68763 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -9,7 +9,7 @@ layout (constant_id = 4) const uint32_t HSV = 32; layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; layout (constant_id = 7) const uint32_t SubGroupSize = 32; -layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; +layout (constant_id = 8) const uint32_t SHMEM_STAGING = 0; layout (constant_id = 9) const uint32_t Flags = 0; const bool USE_MASK_OPT = (Flags & 1) != 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 68bef90e48a..4776c5e0e2f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -54,10 +54,11 @@ shared f16vec4 Psh[Bc * psh_stride]; const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; shared ACC_TYPEV4 sfsh[Bc * sfshstride]; -const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad; +const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4 const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups const uint vsh_stride = v_cols; -shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)]; +shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; shared ACC_TYPE slope[Br]; @@ -78,15 +79,15 @@ void main() { #define tile_row(r) (row_tid * rows_per_thread + (r)) // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK). - if ((HSK % 16) != 0) { + if ((HSK % 16) != 0 || (HSV % 16) != 0) { [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { if (i + tid < Br * qstride) { Qf[i + tid] = f16vec4(0); } } - [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) { - if (i + tid < Bc * kshstride) { - ksh[i + tid] = f16vec4(0); + [[unroll]] for (uint i = 0; i < Bc * kvsh_stride; i += gl_WorkGroupSize.x) { + if (i + tid < Bc * kvsh_stride) { + kvsh[i + tid] = f16vec4(0); } } barrier(); @@ -231,13 +232,13 @@ void main() { } } - if (K_LOAD_SHMEM != 0) { - [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (HSK / 4); - uint32_t c = (idx + tid) / (HSK / 4); - if (c < Bc && d < HSK / 4) { + if (SHMEM_STAGING != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK_pad / 4); + uint32_t c = (idx + tid) / (HSK_pad / 4); + if (c < Bc) { f16vec4 K_Tf = f16vec4(0); - if (!KV_bounds_check || j * Bc + c < KV) { + if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; @@ -248,7 +249,7 @@ void main() { #endif } - ksh[c * kshstride + d] = K_Tf; + kvsh[c * kvsh_stride + d] = K_Tf; } } barrier(); @@ -262,7 +263,7 @@ void main() { coopmat QMat; [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { - if (K_LOAD_SHMEM == 0) { + if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 if (KV_bounds_check || d * 16 + 16 > HSK) { #endif @@ -283,7 +284,7 @@ void main() { #endif } - ksh[row * kshstride + col_vec] = K_Tf; + kvsh[row * kvsh_stride + col_vec] = K_Tf; } } barrier(); @@ -295,8 +296,8 @@ void main() { if (KV_bounds_check || d * 16 + 16 > HSK) #endif { - uint coord = (gl_SubgroupID * MatBc) * kshstride; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + uint coord = (gl_SubgroupID * MatBc) * kvsh_stride; + coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); } #if BLOCK_SIZE == 1 else { @@ -305,8 +306,8 @@ void main() { } #endif } else { - uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4; + coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); } coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); @@ -397,6 +398,29 @@ void main() { } } + if (SHMEM_STAGING != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSV_pad / 4); + uint32_t c = (idx + tid) / (HSV_pad / 4); + if (c < Bc) { + f16vec4 V_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + V_Tf = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); +#else + V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); +#endif + } + + kvsh[c * kvsh_stride + d] = V_Tf; + } + } + } + barrier(); + const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up // Each subgroup handles HSV/4 columns @@ -410,6 +434,7 @@ void main() { const uint v_total = v_rows * v_cols; const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; + if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 // For f16, only preload if not aligned if (KV_bounds_check) { @@ -428,44 +453,53 @@ void main() { if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { #if BLOCK_SIZE > 1 - ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + kvsh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); #else - ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; + kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; #endif } else { - ksh[row * vsh_stride + col] = f16vec4(0.0f); + kvsh[row * vsh_stride + col] = f16vec4(0.0f); } } + #if BLOCK_SIZE == 1 } #endif - + } barrier(); - [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { - coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); + const uint osh_stride = row_split * MatBc / 4; + const uint o_offset = gl_SubgroupID * MatBc / 4; + + if (hsv_offset < HSV_pad) { + [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { + coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); + if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 - if (!KV_bounds_check) { - // F16 values can be loaded directly from global memory - const uint v_tile_row = j * Bc + bc_chunk * MatBc; - const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; - coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); - } else + if (!KV_bounds_check) { + // F16 values can be loaded directly from global memory + const uint v_tile_row = j * Bc + bc_chunk * MatBc; + const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; + coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } else #endif - { - const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); - coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + { + const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); + coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } else { + const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4); + coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); } - SfMat = coopMatMulAdd(KMat, QMat, SfMat); + // Store SfMat to sfsh and load into Of + coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); } - // Store SfMat to sfsh and load into Of - const uint osh_stride = row_split * MatBc / 4; - const uint o_offset = gl_SubgroupID * MatBc / 4; - coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); - barrier(); const uint hsv_per_tile = row_split * MatBc; From 53acdd04ca10651d3dd7d7b4cc34ee79a6128439 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 8 Feb 2026 15:41:32 +0100 Subject: [PATCH 17/46] dynamic subgroups for intel --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 23 ++-- .../vulkan-shaders/flash_attn.comp | 106 ++++++++++++------ .../vulkan-shaders/flash_attn_base.glsl | 2 + 3 files changed, 92 insertions(+), 39 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 034e1c38a88..5f04f0533b1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3193,6 +3193,8 @@ static void ggml_vk_load_shaders(vk_device& device) { return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1}; }; + const bool disable_subgroups = device->vendor_id == VK_VENDOR_ID_INTEL; + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we @@ -3208,10 +3210,16 @@ static void ggml_vk_load_shaders(vk_device& device) { wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); break; case FA_COOPMAT1: - wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc + if (disable_subgroups) { + wg_size = 128; + } else { + wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc + } break; default: - if (device->subgroup_size > 32 && rows_cols[0] < 4) { + if (disable_subgroups) { + wg_size = 128; + } else if (device->subgroup_size > 32 && rows_cols[0] < 4) { wg_size = device->subgroup_size * 2; } else { wg_size = device->subgroup_size * 4; @@ -3229,7 +3237,8 @@ static void ggml_vk_load_shaders(vk_device& device) { // AMD prefers loading K directly from global memory const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0; - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, shmem_staging, flags}; + const uint32_t subgroup_size = disable_subgroups ? 0xFFFFFFFF : device->subgroup_size; + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, subgroup_size, shmem_staging, flags}; }; #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ @@ -3245,15 +3254,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ } \ } \ } \ diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 7324d770a0f..0ed6a390c59 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -43,7 +43,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_ return elem; } -const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1; +// If SubGroupSize is set to 0xFFFFFFFF then only use shmem reductions +const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : 1) : WorkGroupSize; shared float tmpsh[tmpsh_size]; shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; @@ -67,6 +68,7 @@ void main() { const uint32_t tid = gl_LocalInvocationIndex; const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup; const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; @@ -359,20 +361,33 @@ void main() { float rowmaxf = Mf[r]; // Compute max across the row - [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { - rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s)); - } - if (row_split == 1) { - // Reduce inside workgroup with shmem - barrier(); - if (gl_SubgroupInvocationID == d_tid) { - tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf; + if (SubGroupSize != SUBGROUPS_DISABLED) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s)); + } + if (row_split == 1) { + // Reduce inside workgroup with shmem + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf; + } + barrier(); + rowmaxf = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]); + } } + } else { + barrier(); + tmpsh[tid] = rowmaxf; barrier(); - rowmaxf = tmpsh[d_tid]; - [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { - rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]); + } + barrier(); } + rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid]; } float Moldf = Mf[r]; @@ -385,37 +400,64 @@ void main() { Lf[r] = eMf*Lf[r]; // Compute sum across the row - [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { - Lf[r] += subgroupShuffleXor(Lf[r], s); - } - if (row_split == 1) { - barrier(); - if (gl_SubgroupInvocationID == d_tid) { - tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r]; + if (SubGroupSize != SUBGROUPS_DISABLED) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + Lf[r] += subgroupShuffleXor(Lf[r], s); + } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r]; + } + barrier(); + Lf[r] = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Lf[r] += tmpsh[s * D_split + d_tid]; + } } + } else { barrier(); - Lf[r] = tmpsh[d_tid]; - [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { - Lf[r] += tmpsh[s * D_split + d_tid]; + tmpsh[tid] = Lf[r]; + barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s]; + } + barrier(); } + Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid]; } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { Of[r][d] = ACC_TYPE(eMf) * Of[r][d]; - [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { - Of[r][d] += subgroupShuffleXor(Of[r][d], s); - } - if (row_split == 1) { - barrier(); - if (gl_SubgroupInvocationID == d_tid) { - tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; + if (SubGroupSize != SUBGROUPS_DISABLED) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + Of[r][d] += subgroupShuffleXor(Of[r][d], s); + } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; + } + barrier(); + Of[r][d] = tmpsh_accv4[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Of[r][d] += tmpsh_accv4[s * D_split + d_tid]; + } } + } else { barrier(); - Of[r][d] = tmpsh_accv4[d_tid]; - [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { - Of[r][d] += tmpsh_accv4[s * D_split + d_tid]; + tmpsh_accv4[tid] = Of[r][d]; + barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + Of[r][d] += tmpsh_accv4[tid ^ s]; + tmpsh_accv4[tid] = Of[r][d]; + } + barrier(); } + Of[r][d] = tmpsh_accv4[row_tid * threads_per_rowgroup + d_tid]; } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 0a077f68763..857e9faff4c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -66,6 +66,8 @@ layout (push_constant) uniform parameter { #define SINK_ENABLE_BIT (1<<24) #define N_LOG2_MASK 0xFFFF +#define SUBGROUPS_DISABLED 0xFFFFFFFF + layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; From 8a28404f19aa2182325f25735bab223592f6efe7 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 9 Feb 2026 08:14:57 +0100 Subject: [PATCH 18/46] use vectorized stores --- .../vulkan-shaders/flash_attn.comp | 25 ++++--------------- .../vulkan-shaders/flash_attn_base.glsl | 9 +++++++ .../vulkan-shaders/flash_attn_cm1.comp | 25 ++++--------------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 12 ++++----- 4 files changed, 25 insertions(+), 46 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0ed6a390c59..4d95ef58f99 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -34,15 +34,6 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -// Store the output when doing grouped query attention. -// Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - uint32_t offset = (iq2 + r) * HSV + c; - data_o[o_offset + offset] = D_TYPE(elem); - return elem; -} - // If SubGroupSize is set to 0xFFFFFFFF then only use shmem reductions const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : 1) : WorkGroupSize; shared float tmpsh[tmpsh_size]; @@ -467,15 +458,13 @@ void main() { // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { const uint row = tile_row(r); if (row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(row, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); - } + gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); } } } @@ -527,16 +516,14 @@ void main() { } } - uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { const uint row = tile_row(r); if (row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(row, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); - } + gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); } } } @@ -545,9 +532,7 @@ void main() { const uint row = tile_row(r); if (i * Br + row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + row) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); - } + data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 857e9faff4c..5147b0236c5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -71,6 +71,7 @@ layout (push_constant) uniform parameter { layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; +layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];}; layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; @@ -246,3 +247,11 @@ void init_indices() // Bias applied to softmax to stay in fp16 range. // Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606 const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +void gqaStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * HSV / 4 + c; + data_ov4[o_offset + offset] = D_TYPEV4(elems); +} \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 4776c5e0e2f..3768b23053d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -33,15 +33,6 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -// Store the output when doing grouped query attention. -// Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - uint32_t offset = (iq2 + r) * HSV + c; - data_o[o_offset + offset] = D_TYPE(elem); - return elem; -} - shared float tmpsh[row_split]; const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 @@ -535,7 +526,7 @@ void main() { // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { @@ -543,9 +534,7 @@ void main() { const uint d = d0 + col_tid; if (d >= HSV/4) break; const uint d_local = d0 / threads_per_rowgroup; - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4 * d + comp, Of[r][d_local][comp], o_offset, iq2, N); - } + gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); } } } @@ -598,7 +587,7 @@ void main() { } } - uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { @@ -607,9 +596,7 @@ void main() { const uint d = d0 + col_tid; if (d >= HSV / 4) break; const uint d_local = d0 / threads_per_rowgroup; - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4 * d + comp, Of[r][d_local][comp], o_offset, iq2, N); - } + gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); } } } @@ -620,9 +607,7 @@ void main() { const uint d = d0 + col_tid; if (d >= HSV / 4) break; const uint d_local = d0 / threads_per_rowgroup; - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]); - } + data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 7fbe45d33f6..e083ee490e2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -644,32 +644,32 @@ void process_shaders() { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, false, true, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); } else { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc); } #endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); } #endif } if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, false, false, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); } } } From 1fd08bb5f71d209a76b228c44beff8cb04afbfa3 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 9 Feb 2026 08:23:16 +0100 Subject: [PATCH 19/46] use float_type for dequantize4 functions --- .../ggml-vulkan/vulkan-shaders/flash_attn.comp | 8 ++++---- .../vulkan-shaders/flash_attn_base.glsl | 18 +++++++++--------- .../vulkan-shaders/flash_attn_cm1.comp | 8 ++++---- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 4d95ef58f99..0a425ce75fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -193,7 +193,7 @@ void main() { uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif @@ -224,7 +224,7 @@ void main() { uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif @@ -299,7 +299,7 @@ void main() { uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - V_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); #endif @@ -331,7 +331,7 @@ void main() { uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 5147b0236c5..88b6d65beeb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -97,12 +97,12 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16 #define BLOCK_SIZE 4 #define BLOCK_BYTE_SIZE 16 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { // iqs is currently always zero in the flash attention shaders if (binding_idx == BINDING_IDX_K) { - return k_packed.k_data_packed[a_offset + ib]; + return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]); } else { - return v_packed.v_data_packed[a_offset + ib]; + return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]); } } #endif @@ -110,7 +110,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -118,7 +118,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); } else { uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -126,24 +126,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); } } #endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); } else { const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); } } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 3768b23053d..ebb55c95040 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -234,7 +234,7 @@ void main() { uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif @@ -269,7 +269,7 @@ void main() { uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); #endif @@ -400,7 +400,7 @@ void main() { uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - V_Tf = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); #endif @@ -444,7 +444,7 @@ void main() { if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { #if BLOCK_SIZE > 1 - kvsh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; #endif From 45c07756ceaf1e690e3a970a08153a057513540a Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 10 Feb 2026 18:23:43 +0100 Subject: [PATCH 20/46] use smaller scalar rows size for smaller rows count --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5f04f0533b1..8f2852ea068 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2771,9 +2771,7 @@ static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, return 8; } - if (hsv >= 192) { - return 8; - } else if ((hsv | hsk) & 8 || small_cache) { + if (hsv >= 192 || (hsv | hsk) & 8 || small_cache || rows == FA_ROWS_2 || rows == FA_ROWS_4 || rows == FA_ROWS_8) { return 8; } From 6c7d10e852e9d1d58b805b33a55ae58dc5866b82 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 10 Feb 2026 19:50:24 +0100 Subject: [PATCH 21/46] relax flash attention split_k condition to allow non-gqa use --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 49 ++++++++++------- .../vulkan-shaders/flash_attn.comp | 50 +++++++++++------ .../vulkan-shaders/flash_attn_base.glsl | 14 +++-- .../vulkan-shaders/flash_attn_cm1.comp | 53 +++++++++++++------ .../vulkan-shaders/flash_attn_cm2.comp | 42 ++++++++++++--- 5 files changed, 149 insertions(+), 59 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8f2852ea068..74f119673f1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8696,19 +8696,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache); + const uint32_t Br = rows_cols[0]; + const uint32_t Bc = rows_cols[1]; + + GGML_ASSERT(Br == pipeline->wg_denoms[0]); + const uint32_t Tr = CEIL_DIV(N, Br); + // Try to use split_k when KV is large enough to be worth the overhead. - // Must either be a single batch or be using gqa, we can't mix the two. - if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) { - // Try to run two workgroups per SM. + if (gqa_ratio > 1 && workgroups_x <= Br) { split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); - if (split_k > 1) { - // Try to evenly split KV into split_k chunks, but it needs to be a multiple - // of "align", so recompute split_k based on that. - split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); - split_k = CEIL_DIV(KV, split_kv); + } else if (gqa_ratio <= 1) { + uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z; + if (total_wgs_no_split < shader_core_count * 2) { + split_k = shader_core_count * 2 / total_wgs_no_split; } } + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); + split_k = CEIL_DIV(KV, split_kv); + } + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3]. @@ -8722,10 +8733,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_preallocate_buffers(ctx, subctx); } - auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache); - const uint32_t Br = rows_cols[0]; - const uint32_t Bc = rows_cols[1]; - const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc); const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3; @@ -8805,15 +8812,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - workgroups_x *= pipeline->wg_denoms[0]; + + // We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. + uint32_t dispatch_x; + if (gqa_ratio > 1) { + workgroups_x *= pipeline->wg_denoms[0]; + dispatch_x = split_k * workgroups_x; + } else { + dispatch_x = Tr * split_k * pipeline->wg_denoms[0]; + } + vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf}, - // We only use split_k when group query attention is enabled, which means - // there's no more than one tile of rows (i.e. workgroups_x would have been - // one). We reuse workgroups_x to mean the number of splits, so we need to - // cancel out the divide by wg_denoms[0]. - pc, { split_k * workgroups_x, workgroups_y, workgroups_z }); + pc, { dispatch_x, workgroups_y, workgroups_z }); ggml_vk_sync_buffers(ctx, subctx); const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0a425ce75fd..563a2bcbdc6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -457,27 +457,47 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - const uint row = tile_row(r); - if (row < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); + } } } - } - o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - const uint row = tile_row(r); - if (row < N) { - perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); - perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { + perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } } - } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + const uint global_row = i * Br + row; + + if (global_row < N) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]); + } + } + + if (global_row < N && d_tid == 0 && col_tid == 0) { + uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_offset + iq2] = D_TYPE(Lf[r]); + data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]); + } + } + } return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 88b6d65beeb..40f529dc238 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -192,10 +192,16 @@ void init_indices() KV = p.KV; if (p.k_num > 1) { - i = 0; - // batch and split_k share gl_WorkGroupID.x - gqa_iq1 = gl_WorkGroupID.x / p.k_num; - split_k_index = gl_WorkGroupID.x % p.k_num; + if (p.gqa_ratio > 1) { + i = 0; + // batch and split_k share gl_WorkGroupID.x + gqa_iq1 = gl_WorkGroupID.x / p.k_num; + split_k_index = gl_WorkGroupID.x % p.k_num; + } else { + gqa_iq1 = 0; + split_k_index = gl_WorkGroupID.x % p.k_num; + i = gl_WorkGroupID.x / p.k_num; + } } else if (p.gqa_ratio > 1) { i = 0; gqa_iq1 = gl_WorkGroupID.x; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index ebb55c95040..e35d4666cd6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -525,25 +525,48 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - if (tile_row(r) < N) { - [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { - const uint d = d0 + col_tid; - if (d >= HSV/4) break; - const uint d_local = d0 / threads_per_rowgroup; - gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + const uint d_local = d0 / threads_per_rowgroup; + gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); + } } } - } - o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - if (tile_row(r) < N) { - perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); - perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + const uint global_row = i * Br + row; + + if (global_row < N) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4; + + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]); + } + } + + if (global_row < N && col_tid == 0) { + uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_offset + iq2] = D_TYPE(Lf[r]); + data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]); + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 853f17fa16e..0ea181342ce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } +// Store O values for non-GQA split_k. Rows are tokens, not heads. +D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) { + uint32_t global_row = i * Br + r; + if (global_row < N && c < HSV) { + uint32_t o_off = HSV * p.ne1 + * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[o_off + iq2 * HSV + c] = D_TYPE(elem); + } + return elem; +} + +// Store L/M values for non-GQA split_k. +ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) { + uint32_t global_row = i * Br + r; + if (global_row < N && c == 0) { + uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_off + lm_base + iq2] = D_TYPE(elem); + } + return elem; +} + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -290,13 +312,19 @@ void main() { if (p.k_num > 1) { coopmat O_D = coopmat(O); - // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - - o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); - coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + } else { + coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N); + coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N); + coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N); + } return; } From 9d79f3f89daa6be6991e397f104a6000ad4c7a51 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Wed, 11 Feb 2026 00:41:14 +0100 Subject: [PATCH 22/46] use minimal subgroup size on Intel --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 74f119673f1..6ef7ff413df 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2785,6 +2785,26 @@ static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; static constexpr uint32_t scalar_flash_attention_Bc = 64; static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; +static bool fa_disable_subgroups(const vk_device& device, FaCodePath path) { + return device->vendor_id == VK_VENDOR_ID_INTEL && path == FA_SCALAR; +} + +static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) { + if (fa_disable_subgroups(device, path)) { + return 0xFFFFFFFF; + } + + if (path == FA_VECTOR) { + if (device->vendor_id == VK_VENDOR_ID_AMD && device->subgroup_min_size <= 32 && device->subgroup_max_size >= 32) { + return 32; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) { + return device->subgroup_min_size; + } + } + + return device->subgroup_size; +} + static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) { GGML_UNUSED(clamp); @@ -3225,17 +3245,18 @@ static void ggml_vk_load_shaders(vk_device& device) { break; } + const uint32_t subgroup_size = fa_subgroup_size(device, path); + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. const uint32_t D_lsb = D ^ (D & (D-1)); - uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); + uint32_t D_split = std::min(std::min(subgroup_size, 8u), D_lsb / 4); // Nvidia prefers shared memory use to load large tiles of K/V. // Switch to loading from global memory when it would use too much shared memory. // AMD prefers loading K directly from global memory const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0; - const uint32_t subgroup_size = disable_subgroups ? 0xFFFFFFFF : device->subgroup_size; return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, subgroup_size, shmem_staging, flags}; }; From 0d7ed798fe877712ab130a8c5e8437c34ea37bea Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 12 Feb 2026 09:02:30 +0100 Subject: [PATCH 23/46] fix shmem support function --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 77 +++++++++++++++------------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6ef7ff413df..0eb6b102217 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2805,6 +2805,26 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) { return device->subgroup_size; } +static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) { + const uint32_t D = std::max(hsk, hsv); + switch (path) { + case FA_COOPMAT2: + return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); + case FA_COOPMAT1: + return (Bc / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc + case FA_VECTOR: + return device->vendor_id == VK_VENDOR_ID_AMD ? 256 : 128; + default: + if (device->vendor_id == VK_VENDOR_ID_INTEL) { + return 128; + } else if (device->subgroup_size > 32 && Br < 4) { + return device->subgroup_size * 2; + } else { + return device->subgroup_size * 4; + } + } +} + static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) { GGML_UNUSED(clamp); @@ -3222,29 +3242,7 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t D = (hsk|hsv); auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache); - uint32_t wg_size; - switch (path) { - case FA_COOPMAT2: - wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); - break; - case FA_COOPMAT1: - if (disable_subgroups) { - wg_size = 128; - } else { - wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc - } - break; - default: - if (disable_subgroups) { - wg_size = 128; - } else if (device->subgroup_size > 32 && rows_cols[0] < 4) { - wg_size = device->subgroup_size * 2; - } else { - wg_size = device->subgroup_size * 4; - } - break; - } - + const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]); const uint32_t subgroup_size = fa_subgroup_size(device, path); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. @@ -8454,21 +8452,29 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) { // Needs to be kept up to date on shader changes - GGML_UNUSED(hsv); - const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_rows(hsk, hsv, rows, small_cache); - const uint32_t Bc = scalar_flash_attention_Bc; + const std::array rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache); + const uint32_t Br = rows_cols[0]; + const uint32_t Bc = rows_cols[1]; + const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc); + + const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t acc_type_size = !fp32acc ? sizeof(ggml_fp16_t) : sizeof(float); + // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); - const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * acc_type_size; - const uint32_t masksh = Bc * Br * sizeof(float); + const uint32_t masksh = Bc * (Br + 1) * float_type_size; - const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); + const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; - const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; + const uint32_t D = std::max(hsk, hsv); + const bool shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256; + const uint32_t kvsh = shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); @@ -8597,6 +8603,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_z = (uint32_t)neq3; const bool small_cache = nek1 < 1024; + const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). @@ -8645,8 +8652,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } // with large hsk/hsv, scalar path may need to use small rows to fit in shared memory - if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache)) { - rows = FA_ROWS_SMALL; + if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, 0, k->type, FA_ROWS_LARGE, small_cache, f32acc)) { + rows = FA_ROWS_8; } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); @@ -8671,8 +8678,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx aligned = false; } - bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - float scale = 1.0f; float max_bias = 0.0f; float logit_softcap = 0.0f; From 3820695c426ed4744b9b6399fbccd622dc1e70f7 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 12 Feb 2026 11:39:28 +0100 Subject: [PATCH 24/46] fix rebase issues --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0eb6b102217..1015a704473 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2768,10 +2768,10 @@ static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, if (rows == FA_ROWS_1) { return 1; } else if (rows == FA_ROWS_SMALL) { - return 8; + return 4; } - if (hsv >= 192 || (hsv | hsk) & 8 || small_cache || rows == FA_ROWS_2 || rows == FA_ROWS_4 || rows == FA_ROWS_8) { + if (hsv >= 192 || (hsv | hsk) & 8 || small_cache) { return 8; } @@ -2794,14 +2794,6 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) { return 0xFFFFFFFF; } - if (path == FA_VECTOR) { - if (device->vendor_id == VK_VENDOR_ID_AMD && device->subgroup_min_size <= 32 && device->subgroup_max_size >= 32) { - return 32; - } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) { - return device->subgroup_min_size; - } - } - return device->subgroup_size; } @@ -2812,8 +2804,6 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); case FA_COOPMAT1: return (Bc / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc - case FA_VECTOR: - return device->vendor_id == VK_VENDOR_ID_AMD ? 256 : 128; default: if (device->vendor_id == VK_VENDOR_ID_INTEL) { return 128; @@ -8454,7 +8444,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) { // Needs to be kept up to date on shader changes - const std::array rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache); + const std::array rows_cols = fa_rows_cols(FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc); @@ -8653,7 +8643,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // with large hsk/hsv, scalar path may need to use small rows to fit in shared memory if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, 0, k->type, FA_ROWS_LARGE, small_cache, f32acc)) { - rows = FA_ROWS_8; + rows = FA_ROWS_SMALL; } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); @@ -8722,7 +8712,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; - auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache); + auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; From 638028fdde28b671528ad89b726fe9b96db479d8 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 12 Feb 2026 13:58:44 +0100 Subject: [PATCH 25/46] fixes --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1015a704473..9f5118e62f2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3258,18 +3258,19 @@ static void ggml_vk_load_shaders(vk_device& device) { bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ uint32_t flags = fa.first.flags; \ + bool fa_ds = path == FA_SCALAR && disable_subgroups; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -8624,7 +8625,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx FaRows rows; if (N == 1) { rows = FA_ROWS_1; - } else if (N <= 8) { + } else if (N <= (path == FA_COOPMAT2 ? flash_attention_num_small_rows : 8)) { rows = FA_ROWS_SMALL; } else { rows = FA_ROWS_LARGE; From 4f9360106c50976272da227332fd85e83c028c15 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 12 Feb 2026 16:38:57 +0100 Subject: [PATCH 26/46] Bc 4 for scalar FA is not a valid configuration --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9f5118e62f2..137b18f0ac7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2767,11 +2767,9 @@ static constexpr uint32_t flash_attention_num_small_rows = 32; static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { if (rows == FA_ROWS_1) { return 1; - } else if (rows == FA_ROWS_SMALL) { - return 4; } - if (hsv >= 192 || (hsv | hsk) & 8 || small_cache) { + if (rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache) { return 8; } From 52c8b6751dff00ebece00865820b676713d8d56b Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 12 Feb 2026 18:44:26 +0100 Subject: [PATCH 27/46] Use wave32 on AMD RDNA for scalar FA --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 137b18f0ac7..1c8f0b5fe42 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2792,23 +2792,28 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) { return 0xFFFFFFFF; } + if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { + return 32; + } + return device->subgroup_size; } static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) { const uint32_t D = std::max(hsk, hsv); + const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path); switch (path) { case FA_COOPMAT2: return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); case FA_COOPMAT1: - return (Bc / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc + return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc default: if (device->vendor_id == VK_VENDOR_ID_INTEL) { return 128; - } else if (device->subgroup_size > 32 && Br < 4) { - return device->subgroup_size * 2; + } else if (subgroup_size > 32 && Br < 4) { + return subgroup_size * 2; } else { - return device->subgroup_size * 4; + return subgroup_size * 4; } } } @@ -2817,7 +2822,7 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 GGML_UNUSED(clamp); if (path == FA_SCALAR) { - if (rows == FA_ROWS_1 && ((hsk|hsv) & 8)) { + if (rows == FA_ROWS_1 || ((hsk|hsv) & 8)) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. // But this only applies to row_split=1, meaning FA_ROWS_1 @@ -3257,18 +3262,19 @@ static void ggml_vk_load_shaders(vk_device& device) { bool f32acc = fa.first.f32acc; \ uint32_t flags = fa.first.flags; \ bool fa_ds = path == FA_SCALAR && disable_subgroups; \ + uint32_t fa_sgs = fa_subgroup_size(device, path); \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } \ } \ From fb4114885078bfc240d26ad1afbb00fb99a9b0b4 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 13 Feb 2026 07:02:31 +0100 Subject: [PATCH 28/46] add Intel shader core count lookup-table --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 48 +++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1c8f0b5fe42..8e9d52dde35 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4561,6 +4561,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev); static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -4777,6 +4778,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_core_count = sm_props.shaderSMCount; } else if (amd_shader_core_properties2) { device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) { + device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device); } else { device->shader_core_count = 0; } @@ -8714,8 +8717,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t split_kv = KV; uint32_t split_k = 1; + // Intel Alchemist prefers more workgroups + const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1; + // Use a placeholder core count if one isn't available. split_k is a big help for perf. - const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16; auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache); const uint32_t Br = rows_cols[0]; @@ -15472,6 +15478,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope } } +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) { + VkPhysicalDeviceProperties2 props = vkdev.getProperties2(); + + if (props.properties.vendorID != VK_VENDOR_ID_INTEL) { + return 0; + } + + const uint32_t device_id = props.properties.deviceID; + + switch (device_id) { + case 0x56A6: // A310 + return 6; + case 0x5693: // A370M + case 0x56A5: // A380 + case 0x56B1: // Pro A40/A50 + return 8; + case 0x5697: // A530M + return 12; + case 0x5692: // A550M + case 0x56B3: // Pro A60 + return 16; + case 0x56A2: // A580 + return 24; + case 0x5691: // A730M + case 0x56A1: // A750 + return 28; + case 0x56A0: // A770 + case 0x5690: // A770M + return 32; + case 0xE212: // Pro B50 + return 16; + case 0xE20C: // B570 + return 18; + case 0xE20B: // B580 + return 20; + default: + return 0; + } +} + // checks #ifdef GGML_VULKAN_CHECK_RESULTS From a9d3f126122ac42b15be9a9c99e8b0bb8c929919 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 14 Feb 2026 06:45:58 +0100 Subject: [PATCH 29/46] fix regressions --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +-- .../vulkan-shaders/flash_attn.comp | 65 ++++++++++++++----- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8e9d52dde35..b81132622ec 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2808,9 +2808,7 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint case FA_COOPMAT1: return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc default: - if (device->vendor_id == VK_VENDOR_ID_INTEL) { - return 128; - } else if (subgroup_size > 32 && Br < 4) { + if (subgroup_size > 32 && Br < 4) { return subgroup_size * 2; } else { return subgroup_size * 4; @@ -3224,8 +3222,6 @@ static void ggml_vk_load_shaders(vk_device& device) { return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1}; }; - const bool disable_subgroups = device->vendor_id == VK_VENDOR_ID_INTEL; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we @@ -3261,7 +3257,7 @@ static void ggml_vk_load_shaders(vk_device& device) { bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ uint32_t flags = fa.first.flags; \ - bool fa_ds = path == FA_SCALAR && disable_subgroups; \ + bool fa_ds = fa_disable_subgroups(device, path); \ uint32_t fa_sgs = fa_subgroup_size(device, path); \ if (path == FAPATH) { \ if (aligned) { \ diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 563a2bcbdc6..8974593e9f1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -19,7 +19,7 @@ const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; -const uint32_t row_split = (Br < 4) ? 1 : 4; +const uint32_t row_split = (Br < 4 || HSK <= 64) ? 1 : 4; const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; @@ -205,32 +205,61 @@ void main() { barrier(); } - [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { - FLOAT_TYPEV4 Q_cache[rows_per_thread]; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid]; - } + // More d iterations means Q register caching becomes relevant + // Few iterations means the additional registers needed are worse than the speed-up from caching + if (HSK_per_thread / 4 > 4) { + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 Q_cache[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid]; + } + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + + FLOAT_TYPEV4 K_Tf; + if (SHMEM_STAGING != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); + } + } + } + } else { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - FLOAT_TYPEV4 K_Tf; - if (SHMEM_STAGING != 0) { - K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 K_Tf; + if (SHMEM_STAGING != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else - K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif - } - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf)); + } } } } From 2db2f214d39f26c714c2c8022fd7278880e42fc9 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 14 Feb 2026 09:05:15 +0100 Subject: [PATCH 30/46] device tuning --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 56 ++++++++++++++-------------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b81132622ec..01097fadbc9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2764,12 +2764,16 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; -static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { +static uint32_t get_fa_scalar_num_rows(const vk_device& device, uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { if (rows == FA_ROWS_1) { return 1; } - if (rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache) { + if ( + rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache || + (device->architecture == AMD_GCN && hsk <= 64) || + (device->vendor_id == VK_VENDOR_ID_INTEL) + ) { return 8; } @@ -2787,12 +2791,12 @@ static bool fa_disable_subgroups(const vk_device& device, FaCodePath path) { return device->vendor_id == VK_VENDOR_ID_INTEL && path == FA_SCALAR; } -static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) { +static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path, FaRows rows) { if (fa_disable_subgroups(device, path)) { return 0xFFFFFFFF; } - if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { + if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && rows == FA_ROWS_1) { return 32; } @@ -2801,14 +2805,14 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) { static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) { const uint32_t D = std::max(hsk, hsv); - const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path); + const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path, rows); switch (path) { case FA_COOPMAT2: return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); case FA_COOPMAT1: return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc default: - if (subgroup_size > 32 && Br < 4) { + if (subgroup_size > 32 && (Br < 4 || hsk < 64)) { return subgroup_size * 2; } else { return subgroup_size * 4; @@ -2816,7 +2820,7 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint } } -static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) { +static std::array fa_rows_cols(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) { GGML_UNUSED(clamp); if (path == FA_SCALAR) { @@ -2824,9 +2828,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. // But this only applies to row_split=1, meaning FA_ROWS_1 - return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 64}; + return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 64}; } else { - return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 32}; + return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 32}; } } @@ -2850,8 +2854,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 return {64, 64}; } -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) { - return fa_rows_cols(path, hsk, hsv, 0, type, rows, small_cache)[1]; +static uint32_t fa_align(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) { + return fa_rows_cols(device, path, hsk, hsv, 0, type, rows, small_cache)[1]; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3219,7 +3223,7 @@ static void ggml_vk_load_shaders(vk_device& device) { }; auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) -> std::array { - return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1}; + return {fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1}; }; auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector { @@ -3229,10 +3233,10 @@ static void ggml_vk_load_shaders(vk_device& device) { // For scalar, use 128 (arbitrary) // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. const uint32_t D = (hsk|hsv); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache); + auto rows_cols = fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache); const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]); - const uint32_t subgroup_size = fa_subgroup_size(device, path); + const uint32_t subgroup_size = fa_subgroup_size(device, path, rows); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. @@ -3258,13 +3262,13 @@ static void ggml_vk_load_shaders(vk_device& device) { bool f32acc = fa.first.f32acc; \ uint32_t flags = fa.first.flags; \ bool fa_ds = fa_disable_subgroups(device, path); \ - uint32_t fa_sgs = fa_subgroup_size(device, path); \ + uint32_t fa_sgs = fa_subgroup_size(device, path, rows); \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } else { \ if (f32acc) { \ @@ -8448,7 +8452,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) { // Needs to be kept up to date on shader changes - const std::array rows_cols = fa_rows_cols(FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache); + const std::array rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc); @@ -8479,7 +8483,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); - const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false); + const auto rows_cols = fa_rows_cols(device, FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; @@ -8606,7 +8610,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_rows(HSK, HSV, FA_ROWS_LARGE, small_cache); + max_gqa = get_fa_scalar_num_rows(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache); break; case FA_COOPMAT2: max_gqa = flash_attention_num_small_rows; @@ -8634,14 +8638,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx rows = FA_ROWS_LARGE; } - // coopmat1 does not actually support "small rows" (it needs 16 rows). - // So use scalar instead. - if (rows != FA_ROWS_LARGE && path == FA_COOPMAT1) { - path = FA_SCALAR; - } - // scalar is faster than coopmat2 when N==1 - if (rows == FA_ROWS_1 && path == FA_COOPMAT2) { + if (rows == FA_ROWS_1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) { path = FA_SCALAR; } @@ -8662,7 +8660,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, rows, small_cache); + uint32_t alignment = fa_align(ctx->device, path, HSK, HSV, k->type, rows, small_cache); bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; @@ -8719,7 +8717,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16; - auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache); + auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; From 93ae001f18fc8e1a8bf2f57392bdb4d642cb8438 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 14 Feb 2026 11:43:31 +0100 Subject: [PATCH 31/46] tmpsh size fix --- .../vulkan-shaders/flash_attn.comp | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 8974593e9f1..d8c4f4134d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -35,7 +35,7 @@ layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; // If SubGroupSize is set to 0xFFFFFFFF then only use shmem reductions -const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : 1) : WorkGroupSize; +const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize; shared float tmpsh[tmpsh_size]; shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; @@ -143,7 +143,7 @@ void main() { if (mask_opt_bits != MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - float max_mask = NEG_FLT_MAX_OVER_2; + float max_mask = NEG_FLT_MAX_OVER_2; barrier(); [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; @@ -152,25 +152,25 @@ void main() { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); masksh[c * masksh_stride + r] = m; - max_mask = max(max_mask, float(m)); + max_mask = max(max_mask, float(m)); } else { masksh[c * masksh_stride + r] = FLOAT_TYPE(0); } } } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } - barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); - } - if (max_mask <= NEG_FLT_MAX_OVER_2) { - continue; - } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } } From 29752441c32e5ed0d74140b2d75944761176bea8 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 14 Feb 2026 13:02:32 +0100 Subject: [PATCH 32/46] fix editorconfig --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 40f529dc238..c8b392e9b5b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -260,4 +260,4 @@ void gqaStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPEV4 elem { uint32_t offset = (iq2 + r) * HSV / 4 + c; data_ov4[o_offset + offset] = D_TYPEV4(elems); -} \ No newline at end of file +} From f37679523b21e4aff16da9b52e85b729f496fb00 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Feb 2026 09:20:28 +0100 Subject: [PATCH 33/46] refactor fa tuning logic into a single place --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 384 +++++++++--------- .../vulkan-shaders/flash_attn.comp | 13 +- .../vulkan-shaders/flash_attn_base.glsl | 23 +- .../vulkan-shaders/flash_attn_cm1.comp | 1 - 4 files changed, 209 insertions(+), 212 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 01097fadbc9..f8fc74d6ebd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -408,20 +408,19 @@ enum FaRows { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, FaRows rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) - : HSK(HSK), HSV(HSV), rows(rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} - uint32_t HSK, HSV; - FaRows rows; - bool small_cache; + uint32_t Br, Bc; + uint32_t D_split, row_split; + bool shmem_staging; FaCodePath path; + uint32_t workgroup_size, subgroup_size; bool aligned; bool f32acc; uint32_t flags; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, rows, small_cache, path, aligned, f32acc, flags) < - std::tie(b.HSK, b.HSV, b.rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); + return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags) < + std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags); } }; @@ -2761,101 +2760,194 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ); } -// number of rows/cols for flash attention shader -static constexpr uint32_t flash_attention_num_small_rows = 32; +struct vk_fa_tuning_params { + FaCodePath path; + uint32_t workgroup_size; + uint32_t subgroup_size; + uint32_t block_rows; + uint32_t block_cols; + uint32_t d_split; + uint32_t row_split; + bool shmem_staging; + bool disable_subgroups; + + void print() const { + std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size << + " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split << + " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups << std::endl; + } +}; -static uint32_t get_fa_scalar_num_rows(const vk_device& device, uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { - if (rows == FA_ROWS_1) { - return 1; +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); + +static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + GGML_UNUSED(kv_type); + + vk_fa_tuning_params result{}; + result.path = FA_SCALAR; + + if (device->vendor_id == VK_VENDOR_ID_INTEL) { + // Disable subgroup use due to performance issues when enforcing subgroup sizes + result.subgroup_size = 32; + result.disable_subgroups = true; + } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { + result.subgroup_size = n_rows == 1 ? 32 : device->subgroup_size; + } else { + result.subgroup_size = device->subgroup_size; } - if ( - rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache || - (device->architecture == AMD_GCN && hsk <= 64) || - (device->vendor_id == VK_VENDOR_ID_INTEL) - ) { - return 8; + if (result.subgroup_size > 32 && (n_rows < 4 || hsk < 64)) { + result.workgroup_size = result.subgroup_size * 2; + } else { + result.workgroup_size = result.subgroup_size * 4; } - return 16; -} + // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers + result.row_split = (n_rows < 4 || hsk <= 64) ? 1 : 4; -// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. -// 128 threads split into four subgroups, each subgroup does 1/4 -// of the Bc dimension. -static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; -static constexpr uint32_t scalar_flash_attention_Bc = 64; -static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + const uint32_t D = hsk | hsv; -static bool fa_disable_subgroups(const vk_device& device, FaCodePath path) { - return device->vendor_id == VK_VENDOR_ID_INTEL && path == FA_SCALAR; -} + const bool reduce_block_rows = hsv >= 192 || D & 8 || n_kv < 1024 || + (device->architecture == AMD_GCN && hsk <= 64) || + device->vendor_id == VK_VENDOR_ID_INTEL; -static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path, FaRows rows) { - if (fa_disable_subgroups(device, path)) { - return 0xFFFFFFFF; + if (n_rows == 1) { + result.block_rows = 1; + result.block_cols = 64; + } else { + // row_split 1 means higher register use per row, so block size has to be adjusted + if (result.row_split == 1) { + result.block_rows = reduce_block_rows ? 4 : 8; + } else { + result.block_rows = reduce_block_rows ? 8 : 16; + } + + result.block_cols = (D & 8) ? 64 : 32; } - if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && rows == FA_ROWS_1) { - return 32; + const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit + + result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4); + + result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; + + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) { + result.block_rows /= 2; } - return device->subgroup_size; + return result; } -static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) { - const uint32_t D = std::max(hsk, hsv); - const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path, rows); - switch (path) { - case FA_COOPMAT2: - return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); - case FA_COOPMAT1: - return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc - default: - if (subgroup_size > 32 && (Br < 4 || hsk < 64)) { - return subgroup_size * 2; - } else { - return subgroup_size * 4; - } +static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + GGML_UNUSED(n_rows); + GGML_UNUSED(n_kv); + GGML_UNUSED(kv_type); + GGML_UNUSED(f32acc); + + vk_fa_tuning_params result; + result.path = FA_COOPMAT1; + + const uint32_t D = hsk | hsv; + + const uint32_t coopmat_block_rows = 16; + const uint32_t coopmat_block_cols = 16; + + const uint32_t num_subgroups = 4; + + result.block_rows = coopmat_block_rows; + result.block_cols = coopmat_block_cols * num_subgroups; + result.row_split = num_subgroups; + result.subgroup_size = device->subgroup_size; + result.workgroup_size = num_subgroups * result.subgroup_size; + + const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit + result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4); + + result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; + + return result; +} + +static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + GGML_UNUSED(n_kv); + GGML_UNUSED(f32acc); + + vk_fa_tuning_params result; + result.path = FA_COOPMAT2; + + const uint32_t D = hsk | hsv; + + const bool small_rows = n_rows < 32; + + if (small_rows) { + result.block_rows = 32; + result.block_cols = 32; + } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) { + result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64; + result.block_cols = 32; + } else { + result.block_rows = 64; + result.block_cols = 64; } + + result.subgroup_size = device->subgroup_size; + result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128; + + return result; } -static std::array fa_rows_cols(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) { - GGML_UNUSED(clamp); +static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : + device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; - if (path == FA_SCALAR) { - if (rows == FA_ROWS_1 || ((hsk|hsv) & 8)) { - // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter - // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - // But this only applies to row_split=1, meaning FA_ROWS_1 - return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 64}; - } else { - return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 32}; - } + if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) { + // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 + path = FA_SCALAR; } if (path == FA_COOPMAT1) { - return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || + (!f32acc && device->coopmat_support_16x16x16_f16acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); + + if (!shape_ok || !shmem_ok) { + path = FA_SCALAR; + } } - // small rows, large cols - if (rows != FA_ROWS_LARGE) { - return {flash_attention_num_small_rows, 32}; + // scalar is faster than coopmat when N==1 + if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) { + path = FA_SCALAR; } - // small cols to reduce register count - if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) { - if (hsk >= 512 || hsv >= 512) { - return {32, 32}; - } else { - return {64, 32}; - } + switch (path) { + case FA_SCALAR: + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + case FA_COOPMAT1: + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + case FA_COOPMAT2: + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + default: + throw std::runtime_error("unsupported FaCodePath"); } - return {64, 64}; } -static uint32_t fa_align(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) { - return fa_rows_cols(device, path, hsk, hsv, 0, type, rows, small_cache)[1]; +static vk_fa_pipeline_state get_fa_pipeline_state(const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, + bool use_mask, bool use_mask_opt, bool use_logit_softcap) { + uint32_t flags = (use_mask_opt ? 1 : 0) | + (use_mask ? 2 : 0) | + (use_logit_softcap ? 4 : 0); + + const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size; + + return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags}; +} + +static std::vector get_fa_spec_constants(const vk_fa_pipeline_state& state) { + return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split, + state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags}; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3222,59 +3314,27 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) -> std::array { - return {fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1}; - }; - - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector { - // For large number of rows, 128 invocations seems to work best. - // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we - // can't use 256 for D==80. - // For scalar, use 128 (arbitrary) - // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. - const uint32_t D = (hsk|hsv); - auto rows_cols = fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache); - - const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]); - const uint32_t subgroup_size = fa_subgroup_size(device, path, rows); - - // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. - // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. - const uint32_t D_lsb = D ^ (D & (D-1)); - uint32_t D_split = std::min(std::min(subgroup_size, 8u), D_lsb / 4); - - // Nvidia prefers shared memory use to load large tiles of K/V. - // Switch to loading from global memory when it would use too much shared memory. - // AMD prefers loading K directly from global memory - const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0; - - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, subgroup_size, shmem_staging, flags}; - }; - #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ - uint32_t HSK = fa.first.HSK; \ - uint32_t HSV = fa.first.HSV; \ - FaRows rows = fa.first.rows; \ - bool small_cache = fa.first.small_cache; \ FaCodePath path = fa.first.path; \ + uint32_t Br = fa.first.Br; \ + uint32_t Bc = fa.first.Bc; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ - uint32_t flags = fa.first.flags; \ - bool fa_ds = fa_disable_subgroups(device, path); \ - uint32_t fa_sgs = fa_subgroup_size(device, path, rows); \ + uint32_t fa_sgs = fa.first.subgroup_size; \ + bool fa_ds = fa.first.subgroup_size == 0; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } \ } \ @@ -4999,11 +5059,7 @@ static vk_device ggml_vk_get_device(size_t idx) { #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; - - // coopmat1 fa shader currently assumes 32 invocations per subgroup - device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && - device->subgroup_size_control && device->subgroup_min_size <= 32 && - device->subgroup_max_size >= 32; + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support; #endif if (coopmat2_support) { @@ -8450,15 +8506,14 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { // Needs to be kept up to date on shader changes - const std::array rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache); - const uint32_t Br = rows_cols[0]; - const uint32_t Bc = rows_cols[1]; - const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc); + const uint32_t wg_size = params.workgroup_size; + const uint32_t Br = params.block_rows; + const uint32_t Bc = params.block_cols; const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); - const uint32_t acc_type_size = !fp32acc ? sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t acc_type_size = !f32acc ? sizeof(ggml_fp16_t) : sizeof(float); // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); @@ -8480,12 +8535,11 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); - const auto rows_cols = fa_rows_cols(device, FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false); - const uint32_t Br = rows_cols[0]; - const uint32_t Bc = rows_cols[1]; + const uint32_t Br = params.block_rows; + const uint32_t Bc = params.block_cols; const uint32_t MatBr = 16, MatBc = 16; @@ -8575,49 +8629,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); - FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : - ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; - - if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) { - // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 - path = FA_SCALAR; - } - - if (path == FA_COOPMAT1) { - const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || - (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); - - const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type); - - if (!coopmat_shape_supported || !coopmat_shmem_supported) { - path = FA_SCALAR; - } - } - uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; uint32_t workgroups_x = (uint32_t)neq1; uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool small_cache = nek1 < 1024; const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; + const vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc); + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). - uint32_t max_gqa; - switch (path) { - case FA_SCALAR: - case FA_COOPMAT1: - // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_rows(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache); - break; - case FA_COOPMAT2: - max_gqa = flash_attention_num_small_rows; - break; - default: - GGML_ASSERT(0); - } + const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u); if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { @@ -8629,25 +8653,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } - FaRows rows; - if (N == 1) { - rows = FA_ROWS_1; - } else if (N <= (path == FA_COOPMAT2 ? flash_attention_num_small_rows : 8)) { - rows = FA_ROWS_SMALL; - } else { - rows = FA_ROWS_LARGE; - } - - // scalar is faster than coopmat2 when N==1 - if (rows == FA_ROWS_1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) { - path = FA_SCALAR; - } - - // with large hsk/hsv, scalar path may need to use small rows to fit in shared memory - if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, 0, k->type, FA_ROWS_LARGE, small_cache, f32acc)) { - rows = FA_ROWS_SMALL; - } - const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); @@ -8660,13 +8665,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(ctx->device, path, HSK, HSV, k->type, rows, small_cache); + const uint32_t alignment = tuning_params.block_cols; bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned. - if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { + if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) { aligned = false; } @@ -8684,12 +8689,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; - - uint32_t flags = (use_mask_opt ? 1 : 0) | - (mask != nullptr ? 2 : 0) | - (logit_softcap != 0 ? 4 : 0); - - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, rows, small_cache, path, aligned, f32acc, flags); + vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(tuning_params, HSK, HSV, aligned, f32acc, + mask != nullptr, use_mask_opt, logit_softcap != 0); vk_pipeline pipeline = nullptr; @@ -8717,9 +8718,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16; - auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache); - const uint32_t Br = rows_cols[0]; - const uint32_t Bc = rows_cols[1]; + const uint32_t Br = fa_pipeline_state.Br; + const uint32_t Bc = fa_pipeline_state.Bc; GGML_ASSERT(Br == pipeline->wg_denoms[0]); const uint32_t Tr = CEIL_DIV(N, Br); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index d8c4f4134d8..4f1d352e0ae 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -19,11 +19,10 @@ const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; -const uint32_t row_split = (Br < 4 || HSK <= 64) ? 1 : 4; const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; -const uint32_t num_subgroups = WorkGroupSize / SubGroupSize; +const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize; layout (binding = 0) readonly buffer Q {float data_q[];}; @@ -34,8 +33,8 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -// If SubGroupSize is set to 0xFFFFFFFF then only use shmem reductions -const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize; +// If SubGroupSize is set to 0 then only use shmem reductions +const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize; shared float tmpsh[tmpsh_size]; shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; @@ -381,7 +380,7 @@ void main() { float rowmaxf = Mf[r]; // Compute max across the row - if (SubGroupSize != SUBGROUPS_DISABLED) { + if (SubGroupSize > 0) { [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s)); } @@ -420,7 +419,7 @@ void main() { Lf[r] = eMf*Lf[r]; // Compute sum across the row - if (SubGroupSize != SUBGROUPS_DISABLED) { + if (SubGroupSize > 0) { [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { Lf[r] += subgroupShuffleXor(Lf[r], s); } @@ -451,7 +450,7 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { Of[r][d] = ACC_TYPE(eMf) * Of[r][d]; - if (SubGroupSize != SUBGROUPS_DISABLED) { + if (SubGroupSize > 0) { [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { Of[r][d] += subgroupShuffleXor(Of[r][d], s); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index c8b392e9b5b..c7c7701a980 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -1,16 +1,17 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint32_t WorkGroupSize = 128; -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t HSK = 32; -layout (constant_id = 4) const uint32_t HSV = 32; -layout (constant_id = 5) const uint32_t Clamp = 0; -layout (constant_id = 6) const uint32_t D_split = 16; -layout (constant_id = 7) const uint32_t SubGroupSize = 32; -layout (constant_id = 8) const uint32_t SHMEM_STAGING = 0; -layout (constant_id = 9) const uint32_t Flags = 0; +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t HSK = 32; +layout (constant_id = 4) const uint32_t HSV = 32; +layout (constant_id = 5) const uint32_t Clamp = 0; +layout (constant_id = 6) const uint32_t D_split = 16; +layout (constant_id = 7) const uint32_t row_split = 1; +layout (constant_id = 8) const uint32_t SubGroupSize = 32; +layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; +layout (constant_id = 10) const uint32_t Flags = 0; const bool USE_MASK_OPT = (Flags & 1) != 0; const bool MASK_ENABLE = (Flags & 2) != 0; @@ -66,8 +67,6 @@ layout (push_constant) uniform parameter { #define SINK_ENABLE_BIT (1<<24) #define N_LOG2_MASK 0xFFFF -#define SUBGROUPS_DISABLED 0xFFFFFFFF - layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index e35d4666cd6..69f22e7f69a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -19,7 +19,6 @@ const uint32_t MatBr = 16; const uint32_t MatBc = 16; -const uint32_t row_split = Bc / MatBc; const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; From 05d228350db5aa5af8c5b2823c0d72ca6eea5493 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Feb 2026 10:15:45 +0100 Subject: [PATCH 34/46] fix gqa opt logic --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f8fc74d6ebd..92c20305a40 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8637,10 +8637,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; - const vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc); - // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). + vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc); const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u); if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && @@ -8653,6 +8652,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } + tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc); + const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); From 851e8325acb8e798f745321f225c4b2c0d32ac4b Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Feb 2026 11:35:47 +0100 Subject: [PATCH 35/46] fix block_rows with small n_rows --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 92c20305a40..365bd88cf5a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2808,7 +2808,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, const uint32_t D = hsk | hsv; - const bool reduce_block_rows = hsv >= 192 || D & 8 || n_kv < 1024 || + const bool reduce_block_rows = n_rows <= 8 || hsv >= 192 || D & 8 || n_kv < 1024 || (device->architecture == AMD_GCN && hsk <= 64) || device->vendor_id == VK_VENDOR_ID_INTEL; From 9746ae1367e9293ad56ddd9d76ea4b69179a4e1e Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Feb 2026 14:49:05 +0100 Subject: [PATCH 36/46] amd tuning --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 365bd88cf5a..0eb6e28dc48 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2797,19 +2797,22 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.subgroup_size = device->subgroup_size; } - if (result.subgroup_size > 32 && (n_rows < 4 || hsk < 64)) { + if (result.subgroup_size > 32 && (n_rows == 1 || hsk < 64)) { result.workgroup_size = result.subgroup_size * 2; } else { result.workgroup_size = result.subgroup_size * 4; } // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers - result.row_split = (n_rows < 4 || hsk <= 64) ? 1 : 4; + uint32_t row_split_max_hsk = 64; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { + row_split_max_hsk = 256; + } + result.row_split = (n_rows == 1 || hsk <= row_split_max_hsk) ? 1 : 4; const uint32_t D = hsk | hsv; const bool reduce_block_rows = n_rows <= 8 || hsv >= 192 || D & 8 || n_kv < 1024 || - (device->architecture == AMD_GCN && hsk <= 64) || device->vendor_id == VK_VENDOR_ID_INTEL; if (n_rows == 1) { From d7c934c004463c75014064a8e40bc17af386228c Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Feb 2026 14:57:43 +0100 Subject: [PATCH 37/46] fix hsk=72/80 issue --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0eb6e28dc48..86c20e6bea6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2797,12 +2797,6 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.subgroup_size = device->subgroup_size; } - if (result.subgroup_size > 32 && (n_rows == 1 || hsk < 64)) { - result.workgroup_size = result.subgroup_size * 2; - } else { - result.workgroup_size = result.subgroup_size * 4; - } - // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers uint32_t row_split_max_hsk = 64; if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { @@ -2810,6 +2804,12 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, } result.row_split = (n_rows == 1 || hsk <= row_split_max_hsk) ? 1 : 4; + if (result.subgroup_size > 32 && (n_rows == 1 || hsk < (result.row_split == 1 ? 128 : 64))) { + result.workgroup_size = result.subgroup_size * 2; + } else { + result.workgroup_size = result.subgroup_size * 4; + } + const uint32_t D = hsk | hsv; const bool reduce_block_rows = n_rows <= 8 || hsv >= 192 || D & 8 || n_kv < 1024 || From 497c3e7d3a007f4c30142eb0d36d14daec7172e6 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Wed, 18 Feb 2026 09:14:22 +0100 Subject: [PATCH 38/46] tuning --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 86c20e6bea6..eb0f4987a65 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2799,10 +2799,10 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers uint32_t row_split_max_hsk = 64; - if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { - row_split_max_hsk = 256; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) { + row_split_max_hsk = n_rows <= 8 ? 64 : 128; } - result.row_split = (n_rows == 1 || hsk <= row_split_max_hsk) ? 1 : 4; + result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4; if (result.subgroup_size > 32 && (n_rows == 1 || hsk < (result.row_split == 1 ? 128 : 64))) { result.workgroup_size = result.subgroup_size * 2; @@ -2812,7 +2812,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, const uint32_t D = hsk | hsv; - const bool reduce_block_rows = n_rows <= 8 || hsv >= 192 || D & 8 || n_kv < 1024 || + const bool reduce_block_rows = hsv > 256 || D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL; if (n_rows == 1) { @@ -2821,9 +2821,9 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, } else { // row_split 1 means higher register use per row, so block size has to be adjusted if (result.row_split == 1) { - result.block_rows = reduce_block_rows ? 4 : 8; + result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8); } else { - result.block_rows = reduce_block_rows ? 8 : 16; + result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16); } result.block_cols = (D & 8) ? 64 : 32; @@ -2848,7 +2848,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device GGML_UNUSED(kv_type); GGML_UNUSED(f32acc); - vk_fa_tuning_params result; + vk_fa_tuning_params result{}; result.path = FA_COOPMAT1; const uint32_t D = hsk | hsv; @@ -2876,7 +2876,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device GGML_UNUSED(n_kv); GGML_UNUSED(f32acc); - vk_fa_tuning_params result; + vk_fa_tuning_params result{}; result.path = FA_COOPMAT2; const uint32_t D = hsk | hsv; From 1cce6cd16aacb413781ff5aeda5e5d44308dc1af Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 19 Feb 2026 13:46:18 +0100 Subject: [PATCH 39/46] allow condition skipping for column check --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 4f1d352e0ae..3b734bcf1a7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -185,7 +185,7 @@ void main() { [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4); - if (c < Bc) { + if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) { FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { #if BLOCK_SIZE > 1 @@ -320,7 +320,7 @@ void main() { [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSV / 4); uint32_t c = (idx + tid) / (HSV / 4); - if (c < Bc) { + if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) { FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { #if BLOCK_SIZE > 1 From ad37f1266071fd43d5f2001e09836bd0d8fb1cc6 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 19 Feb 2026 14:28:59 +0100 Subject: [PATCH 40/46] use float16 for Of if available --- .../vulkan-shaders/flash_attn.comp | 34 +++++++++---------- .../vulkan-shaders/flash_attn_base.glsl | 2 +- .../vulkan-shaders/flash_attn_cm1.comp | 32 +++++++++-------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 +- 4 files changed, 36 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 3b734bcf1a7..1703b5cf412 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -36,7 +36,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];}; // If SubGroupSize is set to 0 then only use shmem reductions const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize; shared float tmpsh[tmpsh_size]; -shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; +shared FLOAT_TYPEV4 tmpshv4[tmpsh_size]; const uint32_t masksh_stride = Br + 1; shared FLOAT_TYPE masksh[Bc * masksh_stride]; @@ -76,10 +76,10 @@ void main() { } barrier(); - ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; + FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] = ACC_TYPEV4(0.0); + Of[r][d] = FLOAT_TYPEV4(0.0); } } @@ -311,7 +311,7 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; + Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d]; } } @@ -365,7 +365,7 @@ void main() { #endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] += ACC_TYPEV4(Pf[r] * Vf); + Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf); } } } @@ -448,7 +448,7 @@ void main() { } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] = ACC_TYPE(eMf) * Of[r][d]; + Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d]; if (SubGroupSize > 0) { [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { @@ -457,26 +457,26 @@ void main() { if (row_split == 1) { barrier(); if (gl_SubgroupInvocationID == d_tid) { - tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; + tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; } barrier(); - Of[r][d] = tmpsh_accv4[d_tid]; + Of[r][d] = tmpshv4[d_tid]; [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { - Of[r][d] += tmpsh_accv4[s * D_split + d_tid]; + Of[r][d] += tmpshv4[s * D_split + d_tid]; } } } else { barrier(); - tmpsh_accv4[tid] = Of[r][d]; + tmpshv4[tid] = Of[r][d]; barrier(); [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { if (rowgroup_tid < s) { - Of[r][d] += tmpsh_accv4[tid ^ s]; - tmpsh_accv4[tid] = Of[r][d]; + Of[r][d] += tmpshv4[tid ^ s]; + tmpshv4[tid] = Of[r][d]; } barrier(); } - Of[r][d] = tmpsh_accv4[row_tid * threads_per_rowgroup + d_tid]; + Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid]; } } } @@ -540,7 +540,7 @@ void main() { ms = exp(Mf[r] - sink); [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] *= ACC_TYPE(ms); + Of[r][d] *= FLOAT_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -557,9 +557,9 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] *= ACC_TYPE(Lfrcp[r]); -#if defined(ACC_TYPE_MAX) - Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); + Of[r][d] *= FLOAT_TYPE(Lfrcp[r]); +#if defined(FLOAT_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index c7c7701a980..fe2c6f17970 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -255,7 +255,7 @@ const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. -void gqaStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { uint32_t offset = (iq2 + r) * HSV / 4 + c; data_ov4[o_offset + offset] = D_TYPEV4(elems); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 69f22e7f69a..675589ca71f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -50,6 +50,9 @@ const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * nu const uint vsh_stride = v_cols; shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; +const uint32_t osh_stride = row_split * MatBc / 4; +shared f16vec4 pvsh[MatBc * osh_stride]; + shared ACC_TYPE slope[Br]; void main() { @@ -95,10 +98,10 @@ void main() { } barrier(); - ACC_TYPEV4 Of[rows_per_thread][d_per_thread]; + f16vec4 Of[rows_per_thread][d_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { - Of[r][d] = ACC_TYPEV4(0.0); + Of[r][d] = f16vec4(0.0); } } @@ -226,7 +229,7 @@ void main() { [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK_pad / 4); uint32_t c = (idx + tid) / (HSK_pad / 4); - if (c < Bc) { + if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { f16vec4 K_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { #if BLOCK_SIZE > 1 @@ -365,7 +368,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local]; + Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local]; } } @@ -392,7 +395,7 @@ void main() { [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSV_pad / 4); uint32_t c = (idx + tid) / (HSV_pad / 4); - if (c < Bc) { + if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { f16vec4 V_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { #if BLOCK_SIZE > 1 @@ -417,7 +420,7 @@ void main() { [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; - SfMat = coopmat(0); + coopmat PVMat = coopmat(0); // Preload V tiles for [Bc, 16 * num subgroups] const uint v_rows = Bc; @@ -458,7 +461,6 @@ void main() { } barrier(); - const uint osh_stride = row_split * MatBc / 4; const uint o_offset = gl_SubgroupID * MatBc / 4; if (hsv_offset < HSV_pad) { @@ -483,11 +485,11 @@ void main() { coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); } - SfMat = coopMatMulAdd(KMat, QMat, SfMat); + PVMat = coopMatMulAdd(KMat, QMat, PVMat); } - // Store SfMat to sfsh and load into Of - coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); + // Store PVMat to pvsh and load into Of + coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); } barrier(); @@ -508,7 +510,7 @@ void main() { if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) { const uint local_hsv = (hsv_col - hsv_base) / 4; - Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]); + Of[r][d_local] += pvsh[row * osh_stride + local_hsv]; } } } @@ -584,7 +586,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; - Of[r][d_local] *= ACC_TYPE(ms); + Of[r][d_local] *= float16_t(ms); } } else { vs = exp(sink - Mf[r]); @@ -602,9 +604,9 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] *= ACC_TYPE(Lfrcp[r]); -#if defined(ACC_TYPE_MAX) - Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX); + Of[r][d_local] *= float16_t(Lfrcp[r]); +#if defined(FLOAT_TYPE_MAX) + Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e083ee490e2..85455988c57 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -623,7 +623,7 @@ void process_shaders() { for (const bool& fp16 : {false, true}) { std::map base_dict; if (fp16) { - base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}}; + base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; } else { base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}}; } From 6f2dacd1858d6f91b95dc2e4f82c950d1f7e9c83 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 20 Feb 2026 07:14:31 +0100 Subject: [PATCH 41/46] address feedback --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 11 ++++++++--- .../vulkan-shaders/flash_attn_cm1.comp | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index eb0f4987a65..9d686cc8318 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2812,8 +2812,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, const uint32_t D = hsk | hsv; - const bool reduce_block_rows = hsv > 256 || D & 8 || n_kv < 1024 || - device->vendor_id == VK_VENDOR_ID_INTEL; + const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL; if (n_rows == 1) { result.block_rows = 1; @@ -8515,6 +8514,9 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; + const uint32_t MatBr = 16; + const uint32_t MatBc = 16; + const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); const uint32_t acc_type_size = !f32acc ? sizeof(ggml_fp16_t) : sizeof(float); @@ -8530,7 +8532,10 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const bool shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256; const uint32_t kvsh = shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh; + const uint32_t osh_stride = params.row_split * MatBr / 4; + const uint32_t pvsh = MatBc * osh_stride; + + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + pvsh; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 675589ca71f..526e8da384e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -50,7 +50,7 @@ const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * nu const uint vsh_stride = v_cols; shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; -const uint32_t osh_stride = row_split * MatBc / 4; +const uint32_t osh_stride = row_split * MatBr / 4; shared f16vec4 pvsh[MatBc * osh_stride]; shared ACC_TYPE slope[Br]; @@ -72,17 +72,12 @@ void main() { #define tile_row(r) (row_tid * rows_per_thread + (r)) // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK). - if ((HSK % 16) != 0 || (HSV % 16) != 0) { + if ((HSK % 16) != 0) { [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { if (i + tid < Br * qstride) { Qf[i + tid] = f16vec4(0); } } - [[unroll]] for (uint i = 0; i < Bc * kvsh_stride; i += gl_WorkGroupSize.x) { - if (i + tid < Bc * kvsh_stride) { - kvsh[i + tid] = f16vec4(0); - } - } barrier(); } @@ -256,6 +251,10 @@ void main() { coopmat QMat; [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { + // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem + // If not, f16 K is loaded directly from global memory if aligned, otherwise + // staged through a Bc * MatBr size staging buffer. + // If K is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 if (KV_bounds_check || d * 16 + 16 > HSK) { @@ -427,6 +426,10 @@ void main() { const uint v_total = v_rows * v_cols; const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; + // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem. + // If not, f16 V is loaded directly from global memory if aligned, otherwise + // staged through a Bc * MatBr size staging buffer. + // If V is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 // For f16, only preload if not aligned @@ -461,7 +464,7 @@ void main() { } barrier(); - const uint o_offset = gl_SubgroupID * MatBc / 4; + const uint o_offset = gl_SubgroupID * MatBr / 4; if (hsv_offset < HSV_pad) { [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { From 87e6f1bf9e7e546aeaeb82d97499c4d81ade32cb Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 21 Feb 2026 20:33:59 +0100 Subject: [PATCH 42/46] fix bad RDNA performance on head size <= 128 by limiting occupancy --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 26 ++++++++++++++----- .../vulkan-shaders/flash_attn.comp | 13 ++++++++++ .../vulkan-shaders/flash_attn_base.glsl | 1 + 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9d686cc8318..9433386695a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -417,10 +417,11 @@ struct vk_fa_pipeline_state { bool aligned; bool f32acc; uint32_t flags; + uint32_t limit_occupancy_shmem; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags) < - std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags); + return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) < + std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem); } }; @@ -2770,11 +2771,13 @@ struct vk_fa_tuning_params { uint32_t row_split; bool shmem_staging; bool disable_subgroups; + uint32_t limit_occupancy_shmem; void print() const { std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size << " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split << - " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups << std::endl; + " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups << + " limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl; } }; @@ -2792,7 +2795,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.subgroup_size = 32; result.disable_subgroups = true; } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { - result.subgroup_size = n_rows == 1 ? 32 : device->subgroup_size; + result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size; } else { result.subgroup_size = device->subgroup_size; } @@ -2804,7 +2807,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, } result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4; - if (result.subgroup_size > 32 && (n_rows == 1 || hsk < (result.row_split == 1 ? 128 : 64))) { + if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) { result.workgroup_size = result.subgroup_size * 2; } else { result.workgroup_size = result.subgroup_size * 4; @@ -2838,6 +2841,15 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.block_rows /= 2; } + // On AMD RDNA, for small head sizes the shader uses few registers, so too many subgroups get scheduled + // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy. + // This targets an occupancy of 4 subgroups per SIMD. + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && device->properties.limits.maxComputeSharedMemorySize == 65536 && n_rows >= 64 && hsk <= 128) { + // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size + // Values are guessed, tested on RDNA2 + result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4; + } + return result; } @@ -2944,12 +2956,12 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_fa_tuning_params& par const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size; - return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags}; + return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem}; } static std::vector get_fa_spec_constants(const vk_fa_pipeline_state& state) { return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split, - state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags}; + state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem}; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 1703b5cf412..135ab1ad625 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -48,6 +48,8 @@ const uint32_t D = HSK > HSV ? HSK : HSV; const uint32_t kvsh_stride = D / 4 + 1; shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1]; +shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1]; + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -62,6 +64,17 @@ void main() { const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + if (LIMIT_OCCUPANCY_SHMEM > 0) { + // This just exists to avoid the occupancy_limiter array getting optimized out + occupancy_limiter[tid] = vec4(tid); + + barrier(); + + if (occupancy_limiter[tid] == vec4(99999.0)) { + data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]); + } + } + #define tile_row(r) (row_tid * rows_per_thread + (r)) uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index fe2c6f17970..d444542b533 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -12,6 +12,7 @@ layout (constant_id = 7) const uint32_t row_split = 1; layout (constant_id = 8) const uint32_t SubGroupSize = 32; layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; layout (constant_id = 10) const uint32_t Flags = 0; +layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0; const bool USE_MASK_OPT = (Flags & 1) != 0; const bool MASK_ENABLE = (Flags & 2) != 0; From a740402df0c0ae71dc3a2702ab971e781193b995 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 22 Feb 2026 08:22:50 +0100 Subject: [PATCH 43/46] allow printing pipeline stats --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 30 ++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9433386695a..afe9a9f8b49 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1662,6 +1662,7 @@ static bool vk_perf_logger_concurrent = false; static bool vk_enable_sync_logger = false; // number of calls between perf logger prints static uint32_t vk_perf_logger_frequency = 1; +static std::string vk_pipeline_stats_filter; class vk_perf_logger { public: @@ -2178,7 +2179,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin executableInfo.pipeline = pipeline->pipeline; auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo); + + bool print_stats = !vk_pipeline_stats_filter.empty() && + pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos; + if (print_stats) { + std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl; + } + for (auto & s : statistics) { + if (print_stats) { + std::cerr << "ggml_vulkan: " << s.name.data() << ": "; + switch (s.format) { + case vk::PipelineExecutableStatisticFormatKHR::eBool32: + std::cerr << (s.value.b32 ? "true" : "false"); + break; + case vk::PipelineExecutableStatisticFormatKHR::eInt64: + std::cerr << s.value.i64; + break; + case vk::PipelineExecutableStatisticFormatKHR::eUint64: + std::cerr << s.value.u64; + break; + case vk::PipelineExecutableStatisticFormatKHR::eFloat64: + std::cerr << s.value.f64; + break; + } + std::cerr << std::endl; + } // "Register Count" is reported by NVIDIA drivers. if (strcmp(s.name, "Register Count") == 0) { VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers"); @@ -5641,6 +5667,10 @@ static void ggml_vk_instance_init() { vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr; vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr; vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr; + const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS"); + if (GGML_VK_PIPELINE_STATS != nullptr) { + vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS; + } const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY"); if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) { From b28bfea9f76ede48458a9f5f786aa1dc45713dde Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 22 Feb 2026 08:36:20 +0100 Subject: [PATCH 44/46] cleanup and fixes --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 34 ++++++++++------------------ 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index afe9a9f8b49..06790912bab 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -401,11 +401,6 @@ enum FaCodePath { FA_COOPMAT1, FA_COOPMAT2, }; -enum FaRows { - FA_ROWS_1, - FA_ROWS_SMALL, - FA_ROWS_LARGE, -}; struct vk_fa_pipeline_state { uint32_t HSK, HSV; @@ -8551,33 +8546,26 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { + GGML_UNUSED(f32acc); // Needs to be kept up to date on shader changes const uint32_t wg_size = params.workgroup_size; const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; - const uint32_t MatBr = 16; - const uint32_t MatBc = 16; - const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); - const uint32_t acc_type_size = !f32acc ? sizeof(ggml_fp16_t) : sizeof(float); // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); - const uint32_t tmpshv4 = wg_size * 4 * acc_type_size; + const uint32_t tmpshv4 = wg_size * 4 * float_type_size; const uint32_t masksh = Bc * (Br + 1) * float_type_size; const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; const uint32_t D = std::max(hsk, hsv); - const bool shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256; - const uint32_t kvsh = shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - const uint32_t osh_stride = params.row_split * MatBr / 4; - const uint32_t pvsh = MatBc * osh_stride; - - const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + pvsh; + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); @@ -8587,7 +8575,6 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { // Needs to be kept up to date on shader changes - GGML_UNUSED(hsv); const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; @@ -8596,6 +8583,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t row_split = Bc / MatBc; const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); + const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16); const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; @@ -8611,17 +8599,19 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256; - const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2; + const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2; const uint32_t vsh_stride = MatBc / 4 * row_split; - const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4; + const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4; + + const uint32_t osh_stride = params.row_split * MatBr / 4; + const uint32_t pvsh = MatBc * osh_stride * f16vec4; const uint32_t slope = Br * acctype; - const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope; + const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); return supported; } From c73e128a69d07a009aeac935f952a22856dab282 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 22 Feb 2026 09:22:16 +0100 Subject: [PATCH 45/46] limit occupancy for GCN for small batch FA with large HSK --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 06790912bab..0923dddb91b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2862,13 +2862,20 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.block_rows /= 2; } - // On AMD RDNA, for small head sizes the shader uses few registers, so too many subgroups get scheduled + // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy. // This targets an occupancy of 4 subgroups per SIMD. - if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && device->properties.limits.maxComputeSharedMemorySize == 65536 && n_rows >= 64 && hsk <= 128) { - // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size - // Values are guessed, tested on RDNA2 - result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) { + if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) { + // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size + // Values are guessed, tested on RDNA2 + result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4; + } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) { + // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD. + // Here low-batch FA with large head size is affected. + // n_rows < 4 switch because workgroup size switches from 128 to 256 there. + result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4; + } } return result; From ae849d33ce2a440874b20df7dcd667cb5fa368f1 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 23 Feb 2026 07:33:12 +0100 Subject: [PATCH 46/46] disable f16 FA for GCN AMD GPUs on the proprietary driver --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0923dddb91b..11698e9af1d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -624,6 +624,8 @@ struct vk_device_struct { // floor(log2(maxComputeWorkGroupInvocations)) uint32_t max_workgroup_size_log2 {}; + bool flash_attention_fp16; + bool coopmat_support; bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; @@ -3382,7 +3384,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } \ } - if (device->fp16) { + if (device->flash_attention_fp16) { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) @@ -5419,6 +5421,10 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mmvq_mode = 1; } + // Driver issues with older AMD GPUs on Windows, see https://github.com/ggml-org/llama.cpp/pull/19625#issuecomment-3940840613 + const bool is_amd_proprietary_gcn = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture == AMD_GCN && device->driver_id == vk::DriverId::eAmdProprietary; + device->flash_attention_fp16 = device->fp16 && !is_amd_proprietary_gcn; + return device; } @@ -8559,7 +8565,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; - const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t float_type_size = device->flash_attention_fp16 ? sizeof(ggml_fp16_t) : sizeof(float); // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); @@ -8682,7 +8688,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; + const bool f32acc = !ctx->device->flash_attention_fp16 || dst->op_params[3] == GGML_PREC_F32; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).