From 6551632e6c81ef30490c43ca0176164be71a45fb Mon Sep 17 00:00:00 2001 From: Loveneet Nigam Date: Wed, 22 Apr 2026 18:18:34 +0530 Subject: [PATCH 1/7] ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32) Adds MMA-f16 and tile kernel configs, dispatch logic, template instances, and tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to ncols2=32 to support GQA ratio 32 only. --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 17 ++++++++ ggml/src/ggml-cuda/fattn-tile.cu | 4 ++ ggml/src/ggml-cuda/fattn-tile.cuh | 42 +++++++++++++++---- ggml/src/ggml-cuda/fattn.cu | 23 ++++++++++ ...ttn-mma-f16-instance-ncols1_1-ncols2_32.cu | 1 + ...ttn-mma-f16-instance-ncols1_2-ncols2_32.cu | 1 + .../fattn-tile-instance-dkq320-dv256.cu | 5 +++ .../template-instances/generate_cu_files.py | 13 ++++-- 8 files changed, 93 insertions(+), 13 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e185449d491..54f629d5bcb 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,6 +66,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + // Mistral Small 4 (DKQ=320, DV=256): tuned conservatively from the DKQ=256, DV=256 configs. + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); @@ -85,6 +89,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); @@ -118,6 +125,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); @@ -162,6 +172,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 2, 32, 160, 128, 128, 1, false); + // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy // compile-time static_asserts even though the kernel guard prevents runtime execution. // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility. @@ -1825,6 +1838,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); +// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build: +extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); + // For GLM 4.7 Flash extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 25b16e83cac..d60634cc0e9 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 320: { + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst); + } break; case 512: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 26721cc4c7d..44ac20b740b 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) @@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) @@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) @@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) @@ -1144,14 +1152,18 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm } } - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); - return; + // cols_per_block must be >= ncols2 so ncols1 = cols_per_block/ncols2 is never 0 (integer division). + // Without if constexpr, NVCC/MSVC still instantiate flash_attn_tile<..., 0, ncols2, ...> when ncols2 > 16. + if constexpr (16 >= ncols2) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } } if constexpr (ncols2 <= 8) { @@ -1210,6 +1222,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + if constexpr (DKQ == 320 && DV == 256) { + // Mistral Small 4: only build/dispatch ncols2=32 + if (use_gqa_opt && gqa_ratio % 32 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32"); + } + if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); @@ -1221,7 +1242,9 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DKQ <= 512) { + // (320, 256) is handled above (ncols2=32 only); DKQ==320 satisfies DKQ<=512 so exclude it here to avoid + // instantiating generic ncols2 ladders that have no kernel config / flash_attn_tile_iter specializations. + if constexpr (DKQ <= 512 && !(DKQ == 320 && DV == 256)) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1275,5 +1298,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(320, 256); extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ea6607cd337..dc411dd06fe 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -143,6 +143,24 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; + case 320: + // Go straight to the ncols1 switch (ncols2=32-only build). + GGML_ASSERT(V->ne[0] == 256); + { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + // The detailed alignment / KV-padding checks are handled centrally in ggml_cuda_get_best_fattn_kernel(). + // Here we only re-check the semantic preconditions for the GQA-optimized variant. + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 32 == 0); + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst); + } + break; case 512: GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); @@ -352,6 +370,11 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 320: + if (V->ne[0] != 256 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 512: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu index 1f554d81e5e..8fc3b17976e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu index 264751d65ec..abd2b21ce04 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu new file mode 100644 index 00000000000..c91f508079d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(320, 256); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 841059c15b5..828c160fd16 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,7 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576] TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] @@ -62,7 +62,7 @@ def get_short_name(long_quant_name): os.remove(filename) for head_size_kq in HEAD_SIZES_KQ: - head_size_v = head_size_kq if head_size_kq != 576 else 512 + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) @@ -84,13 +84,18 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue + if head_size_kq == 320: + # Mistral Small 4: only instantiate ncols2=32 variants for MMA f16. + if ncols2 != 32: + continue if head_size_kq == 512 and ncols2 not in (4, 8): continue - if head_size_kq != 576 and ncols2 in (16, 32): + # Only 576 (and 320 with ncols2=32 only, see above) use ncols2 in {16,32}; skip for all other KQ sizes. + if head_size_kq not in (320, 576) and ncols2 in (16, 32): continue if head_size_kq == 576 and ncols2 not in (4, 16, 32): continue - head_size_v = head_size_kq if head_size_kq != 576 else 512 + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: From aac22efc552fc23e24625ba16eb24616e425447e Mon Sep 17 00:00:00 2001 From: Loveneet Nigam Date: Fri, 24 Apr 2026 19:38:33 +0530 Subject: [PATCH 2/7] Adding check to return BEST_FATTN_KERNEL_NONE in case GQA!=32 --- ggml/src/ggml-cuda/fattn.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index dc411dd06fe..511d8dd93dc 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -374,6 +374,10 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != 256 || !gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } + // MMA/tile fast paths for 320/256 only implement the ncols2=32 GQA layout (gqa_ratio multiple of 32). + if (gqa_ratio % 32 != 0) { + return BEST_FATTN_KERNEL_NONE; + } break; case 512: if (V->ne[0] != K->ne[0]) { From 253b00a9e2ca13f1850827edd3df5c36c8b5c8af Mon Sep 17 00:00:00 2001 From: lnigam Date: Tue, 28 Apr 2026 12:46:19 +0530 Subject: [PATCH 3/7] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address review comments Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-tile.cuh | 11 +++-------- ggml/src/ggml-cuda/fattn.cu | 5 +---- .../ggml-cuda/template-instances/generate_cu_files.py | 11 +++++------ 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 44ac20b740b..fa66668285e 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -1152,9 +1152,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm } } - // cols_per_block must be >= ncols2 so ncols1 = cols_per_block/ncols2 is never 0 (integer division). - // Without if constexpr, NVCC/MSVC still instantiate flash_attn_tile<..., 0, ncols2, ...> when ncols2 > 16. - if constexpr (16 >= ncols2) { + if constexpr (ncols2 <= 16) { if (Q->ne[1] > 8/ncols2) { constexpr int cols_per_block = 16; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; @@ -1222,8 +1220,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; - if constexpr (DKQ == 320 && DV == 256) { - // Mistral Small 4: only build/dispatch ncols2=32 + if constexpr (DKQ == 320) { // Mistral Small 4 if (use_gqa_opt && gqa_ratio % 32 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1242,9 +1239,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - // (320, 256) is handled above (ncols2=32 only); DKQ==320 satisfies DKQ<=512 so exclude it here to avoid - // instantiating generic ncols2 ladders that have no kernel config / flash_attn_tile_iter specializations. - if constexpr (DKQ <= 512 && !(DKQ == 320 && DV == 256)) { + if constexpr (DKQ <= 512 && DKQ != 320) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 511d8dd93dc..8256591b21d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -144,14 +144,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; case 320: - // Go straight to the ncols1 switch (ncols2=32-only build). + // For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build). GGML_ASSERT(V->ne[0] == 256); { float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - // The detailed alignment / KV-padding checks are handled centrally in ggml_cuda_get_best_fattn_kernel(). - // Here we only re-check the semantic preconditions for the GQA-optimized variant. const bool use_gqa_opt = mask && max_bias == 0.0f; GGML_ASSERT(use_gqa_opt); GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); @@ -374,7 +372,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != 256 || !gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } - // MMA/tile fast paths for 320/256 only implement the ncols2=32 GQA layout (gqa_ratio multiple of 32). if (gqa_ratio % 32 != 0) { return BEST_FATTN_KERNEL_NONE; } diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 828c160fd16..49c12bc93b6 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -84,16 +84,15 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue - if head_size_kq == 320: - # Mistral Small 4: only instantiate ncols2=32 variants for MMA f16. + # Skip compilation of unused ncols2 values for niche head sizes: + if head_size_kq == 320: # Mistral Small 4 if ncols2 != 32: continue - if head_size_kq == 512 and ncols2 not in (4, 8): + if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue - # Only 576 (and 320 with ncols2=32 only, see above) use ncols2 in {16,32}; skip for all other KQ sizes. - if head_size_kq not in (320, 576) and ncols2 in (16, 32): + if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash continue - if head_size_kq == 576 and ncols2 not in (4, 16, 32): + if head_size_kq not in (320, 576) and ncols2 in (16, 32): continue head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) From 3cab359555755142414516e056a5b9ee285d0bdb Mon Sep 17 00:00:00 2001 From: Loveneet Nigam Date: Tue, 28 Apr 2026 14:03:51 +0530 Subject: [PATCH 4/7] Address review comments and making kernel config default to DQK=512, DV=512 instead of DQK=256,DV=256 --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 13 +++++-------- ggml/src/ggml-cuda/fattn-tile.cuh | 4 ++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 54f629d5bcb..95836d56e58 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,9 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); - // Mistral Small 4 (DKQ=320, DV=256): tuned conservatively from the DKQ=256, DV=256 configs. - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 2, false); + // Mistral Small 4 (DKQ=320, DV=256): inherit config style from DKQ=512, DV=512. + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); @@ -89,8 +89,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); @@ -172,9 +172,6 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 2, 32, 160, 128, 128, 1, false); - // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy // compile-time static_asserts even though the kernel guard prevents runtime execution. // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility. diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index fa66668285e..928b856f9d2 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -199,7 +199,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) @@ -270,7 +270,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) From 6e9ccb09eb9ba7bedfefd25caa6892f21763d45d Mon Sep 17 00:00:00 2001 From: Loveneet Nigam Date: Tue, 28 Apr 2026 15:04:47 +0530 Subject: [PATCH 5/7] Fixed bug with sinks=1, with ncols=32, there are two warp-groups created but sinks index is same(0,...,15) for both the groups hence with sinks=1, output is not matching with CPU output. Added sink_base which will be base index for each warp_group (threadIdx.y / np) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 95836d56e58..dda2a5e0cb6 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1227,7 +1227,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); + // jc is the (local) Q-head index within the current ncols tile. + // For wide layouts (e.g. cols_per_warp=16, ncols=32) multiple warp-groups cover disjoint jc ranges. + // The sinks vector is indexed by Q head, so include the warp-group base in jc. + const int jc_base = (threadIdx.y / np) * cols_per_warp; + const int jc = jc_base + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col)); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); From 59295da1ae8231c907a102a2ec9b775a1be7c259 Mon Sep 17 00:00:00 2001 From: lnigam Date: Tue, 28 Apr 2026 15:28:22 +0530 Subject: [PATCH 6/7] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 7 +------ ggml/src/ggml-cuda/template-instances/generate_cu_files.py | 5 ++--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index dda2a5e0cb6..3f01e858de7 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,7 +66,6 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); - // Mistral Small 4 (DKQ=320, DV=256): inherit config style from DKQ=512, DV=512. GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); @@ -1227,11 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - // jc is the (local) Q-head index within the current ncols tile. - // For wide layouts (e.g. cols_per_warp=16, ncols=32) multiple warp-groups cover disjoint jc ranges. - // The sinks vector is indexed by Q head, so include the warp-group base in jc. - const int jc_base = (threadIdx.y / np) * cols_per_warp; - const int jc = jc_base + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col)); + const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col)); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 49c12bc93b6..c6e22833a69 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -85,9 +85,8 @@ def get_short_name(long_quant_name): if head_size_kq == 72: continue # Skip compilation of unused ncols2 values for niche head sizes: - if head_size_kq == 320: # Mistral Small 4 - if ncols2 != 32: - continue + if head_size_kq == 320 && ncols2 != 32: # Mistral Small 4 + continue if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash From 91afffff9f7bb9eb566a0324c000d7158718ece2 Mon Sep 17 00:00:00 2001 From: lnigam Date: Tue, 28 Apr 2026 16:47:27 +0530 Subject: [PATCH 7/7] Update ggml/src/ggml-cuda/template-instances/generate_cu_files.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/template-instances/generate_cu_files.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index c6e22833a69..5e9a1cb2eb3 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -85,7 +85,7 @@ def get_short_name(long_quant_name): if head_size_kq == 72: continue # Skip compilation of unused ncols2 values for niche head sizes: - if head_size_kq == 320 && ncols2 != 32: # Mistral Small 4 + if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4 continue if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue