
ggml-webgpu: enable FLASH_ATTN_EXT on browser without subgroup matrix #22199

Merged
reeselevine merged 14 commits into ggml-org:master from ArberSephirotheca:browser on Apr 24, 2026

Conversation

@ArberSephirotheca (Contributor)

Overview

This PR addresses a few things:

  1. Clean up the vec path to remove the requirement for subgroup matrix support.
  2. Add a subgroup-based tile flash-attention kernel that works without subgroup-matrix operations (a scalar sketch of the accumulation it performs is shown after this list).
  3. Remove subgroup-matrix-specific parameters and naming from the vec path.
  4. Enable FLASH_ATTN_EXT on browsers for the non-subgroup-matrix paths (vec and tile).
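
A scalar C++ sketch of the per-row online-softmax accumulation a tile/vec flash-attention kernel performs, using only multiplies, adds, and a running max/sum instead of subgroup-matrix operations. Names and the flat row-major layout are illustrative, not the actual WGSL:

```cpp
// Scalar sketch (illustrative names, flat row-major layout) of the per-query
// online-softmax accumulation, with no subgroup-matrix operations.
#include <algorithm>
#include <cmath>
#include <vector>

void flash_attn_row(const float * q, const float * K, const float * V,
                    int kv_len, int head_dim, float scale, float * out) {
    float row_max = -INFINITY; // running maximum of the scaled scores
    float exp_sum = 0.0f;      // running sum of exp(score - row_max)
    std::vector<float> acc(head_dim, 0.0f);

    for (int j = 0; j < kv_len; ++j) {
        // score = scale * dot(q, K[j])
        float s = 0.0f;
        for (int d = 0; d < head_dim; ++d) {
            s += q[d] * K[j * head_dim + d];
        }
        s *= scale;

        const float new_max = std::max(row_max, s);
        const float corr    = std::exp(row_max - new_max); // rescale previous partials
        const float p       = std::exp(s - new_max);

        exp_sum = exp_sum * corr + p;
        for (int d = 0; d < head_dim; ++d) {
            acc[d] = acc[d] * corr + p * V[j * head_dim + d];
        }
        row_max = new_max;
    }

    for (int d = 0; d < head_dim; ++d) {
        out[d] = acc[d] / exp_sum;
    }
}
```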

Additional information

Performance when running on my Mac mini (M4 Pro):

ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices
ggml_metal_library_init: using embedded metal library
ggml_metal_library_init: loaded in 0.019 sec
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name:   MTL0
ggml_metal_device_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal4  (5002)
ggml_metal_device_init: simdgroup reduction   = true
ggml_metal_device_init: simdgroup matrix mul. = true
ggml_metal_device_init: has unified memory    = true
ggml_metal_device_init: has bfloat            = true
ggml_metal_device_init: has tensor            = false
ggml_metal_device_init: use residency sets    = true
ggml_metal_device_init: use shared buffers    = true
ggml_metal_device_init: recommendedMaxWorkingSetSize  = 40200.90 MB
ggml_webgpu: adapter_info: vendor_id: 4203 | vendor: apple | architecture: metal-3 | device_id: 0 | name: Apple M4 Pro | device_desc: Metal driver on macOS Version 26.2 (Build 25C56)
Testing 4 devices

Backend 1/4: MTL0
  Skipping
Backend 2/4: WebGPU
  Device description: WebGPU
  Device memory: 28753 MB (28753 MB free)

  FLASH_ATTN_EXT(hsk=72,hsv=72,nh=16,nr23=[1,1],kv=5776,nb=5776,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]): not supported
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                    1590 runs -  1170.24 us/run - 125.83 MFLOP/run - 107.52 GFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=4,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                     796 runs -  1321.89 us/run - 503.32 MFLOP/run - 380.76 GFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                    8190 runs -   562.59 us/run -   8.39 MFLOP/run -  14.91 GFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                    2981 runs -   568.56 us/run -  33.55 MFLOP/run -  59.02 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                  5961 runs -  1007.82 us/run -  16.78 MFLOP/run -  16.65 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                  1491 runs -   983.43 us/run -  67.11 MFLOP/run -  68.24 GFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                    5961 runs -  1268.33 us/run -  16.78 MFLOP/run -  13.23 GFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                    1491 runs -  1232.79 us/run -  67.11 MFLOP/run -  54.44 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                  2981 runs -  2101.64 us/run -  33.55 MFLOP/run -  15.97 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                   746 runs -  2144.23 us/run - 134.22 MFLOP/run -  62.59 GFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                   2981 runs -  2684.76 us/run -  33.55 MFLOP/run -  12.50 GFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                    746 runs -  2687.63 us/run - 134.22 MFLOP/run -  49.94 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 1491 runs -  4217.26 us/run -  67.11 MFLOP/run -  15.91 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                  373 runs -  4319.32 us/run - 268.44 MFLOP/run -  62.15 GFLOPS

Performance is around half that of the subgroup-matrix version.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure:
    Yes, I used an AI agent to help me understand the flash attention code on other backends (e.g., Vulkan, CUDA, and Metal).

@ArberSephirotheca ArberSephirotheca requested a review from a team as a code owner April 21, 2026 05:21
@ArberSephirotheca ArberSephirotheca changed the title Supp ggml-webgpu: enable FLASH_ATTN_EXT on browser without subgroup matrix Apr 21, 2026
@github-actions bot added the ggml (changes relating to the ggml tensor library for machine learning) and WebGPU labels on Apr 21, 2026
@reeselevine (Contributor) left a comment


Thanks for the work on this!

With the non-subgroup-matrix version of flash attention being slower, does it still lead to end-to-end speedups? I can test on my M3 too if that helps.

Review comment threads (outdated): ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl, ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl, ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@ArberSephirotheca (Contributor, Author)

Yes, I still see a performance improvement (~10%):

./build-webgpu/bin/llama-bench \
  -m ./llama-3.2-1b-instruct-q4_0.gguf  -fa 1

ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices
ggml_metal_library_init: using embedded metal library
ggml_metal_library_init: loaded in 0.018 sec
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name:   MTL0
ggml_metal_device_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal4  (5002)
ggml_metal_device_init: simdgroup reduction   = true
ggml_metal_device_init: simdgroup matrix mul. = true
ggml_metal_device_init: has unified memory    = true
ggml_metal_device_init: has bfloat            = true
ggml_metal_device_init: has tensor            = false
ggml_metal_device_init: use residency sets    = true
ggml_metal_device_init: use shared buffers    = true
ggml_metal_device_init: recommendedMaxWorkingSetSize  = 40200.90 MB
ggml_webgpu: adapter_info: vendor_id: 4203 | vendor: apple | architecture: metal-3 | device_id: 0 | name: Apple M4 Pro | device_desc: Metal driver on macOS Version 26.2 (Build 25C56)
| model                          |       size |     params | backend    | threads | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -: | --------------: | -------------------: |
| llama 1B Q4_0                  | 727.75 MiB |     1.24 B | MTL,WebGPU,BLAS |       8 |  1 |           pp512 |       2325.05 ± 2.20 |
| llama 1B Q4_0                  | 727.75 MiB |     1.24 B | MTL,WebGPU,BLAS |       8 |  1 |           tg128 |        154.55 ± 0.18 |
build: 59510168a (8873)
./build-webgpu/bin/llama-bench \
  -m ./llama-3.2-1b-instruct-q4_0.gguf       

ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices
ggml_metal_library_init: using embedded metal library
ggml_metal_library_init: loaded in 0.013 sec
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name:   MTL0
ggml_metal_device_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal4  (5002)
ggml_metal_device_init: simdgroup reduction   = true
ggml_metal_device_init: simdgroup matrix mul. = true
ggml_metal_device_init: has unified memory    = true
ggml_metal_device_init: has bfloat            = true
ggml_metal_device_init: has tensor            = false
ggml_metal_device_init: use residency sets    = true
ggml_metal_device_init: use shared buffers    = true
ggml_metal_device_init: recommendedMaxWorkingSetSize  = 40200.90 MB
ggml_webgpu: adapter_info: vendor_id: 4203 | vendor: apple | architecture: metal-3 | device_id: 0 | name: Apple M4 Pro | device_desc: Metal driver on macOS Version 26.2 (Build 25C56)
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 1B Q4_0                  | 727.75 MiB |     1.24 B | MTL,WebGPU,BLAS |       8 |           pp512 |       2113.57 ± 1.07 |
| llama 1B Q4_0                  | 727.75 MiB |     1.24 B | MTL,WebGPU,BLAS |       8 |           tg128 |        143.10 ± 1.12 |

build: 59510168a (8873)

@reeselevine (Contributor)

Hey @ArberSephirotheca, as I mentioned in the linked PR, I realize that for FlashAttention to work in the browser we'll also need to resolve some buffer aliasing in the flash attention shaders. I think that would be good to get merged with this PR; do you think it's something you can look into? Hopefully the patterns from the other shaders that need it will make it not too hard to add here. Removing the skip_validation flag in this PR will also ensure that future code handles aliasing properly.

@Constannnnnt (Contributor)

For reference, I pulled this PR into my branch and tested it locally in my inference engine in the browser. It did not report any issues or errors. I was using Qwen-3.5-0.8B-Q4_0 (text-only) and LFM2.5-VL-450M-F16 (vision).

@ArberSephirotheca (Contributor, Author)

I looked into the code, and it looks like the only overlap happens when nwg==1: the dst and tmp buffers end up overlapping:

ggml_webgpu_flash_attn: vec alias check nwg=1 kv_tile=32 same_buf=1 overlap=1 tmp=[50176,51200) dst=[50176,51200)
/Users/zheyuan/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp:3110: ggml_webgpu: Device error! Reason: 2, Message: Writable storage buffer binding aliasing found between [BindGroup "flash_attn_vec_f16_hsqk64_hsv64"] set at bind group index 0, binding index 3, and [BindGroup "flash_attn_vec_f16_hsqk64_hsv64"] set at bind group index 0, binding index 4, with overlapping ranges (offset: 50176, size: 1024) and (offset: 50176, size: 1024) in [Buffer "tensor_buf0"].
 - While encoding [ComputePassEncoder (unlabeled)].DispatchWorkgroups(4, 1, 1).
 - While finishing [CommandEncoder (unlabeled)].

Just fixed it.
I also turned off the skip_validation flag and tested several models in the browser using wllama, and I’m not seeing any validation errors.
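
For future readers, the aliasing condition above boils down to two bindings resolving to the same underlying buffer with intersecting byte ranges. A minimal sketch of that check, with hypothetical names rather than the actual ggml-webgpu helpers:

```cpp
// Hypothetical helper (not the actual ggml-webgpu code): two writable bindings
// alias when they resolve to the same underlying buffer and their
// [offset, offset + size) byte ranges intersect.
#include <cstdint>

struct binding_range {
    uint64_t buffer_id; // identity of the underlying wgpu buffer
    uint64_t offset;    // byte offset of the binding within that buffer
    uint64_t size;      // byte size of the binding
};

static bool bindings_alias(const binding_range & a, const binding_range & b) {
    if (a.buffer_id != b.buffer_id) {
        return false;
    }
    return a.offset < b.offset + b.size && b.offset < a.offset + a.size;
}
```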

@reeselevine (Contributor)

Awesome! How about we try removing skip_validation in this PR then :) If it works, we should be able to catch more issues sooner in the future, and if it breaks something unrelated, we'll leave it for now and fix that separately.

@ArberSephirotheca (Contributor, Author)

Yes, sounds good! I already removed skip_validation in my previous commit.

@ArberSephirotheca (Contributor, Author)

Looks like the CI is complaining about the K and V buffers overlapping in some test cases when skip_validation is turned off. I added a kv_overlap variant to flag whether K and V overlap. If they do, the host binds one merged KV range, and the shaders use the existing K/V offsets and strides to derive the logical K and V views from that single binding.
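
A minimal sketch of the merged-binding idea, assuming hypothetical names (not the actual ggml-webgpu code): when the K and V ranges in the same buffer overlap, bind one range that covers both and let the shader recover each view from its offset relative to the merged start:

```cpp
// Illustrative sketch (hypothetical names): merge overlapping K and V ranges
// into one binding and keep the per-view offsets relative to that binding.
#include <algorithm>
#include <cstdint>

struct kv_binding {
    uint64_t offset; // byte offset of the merged binding within the buffer
    uint64_t size;   // byte size of the merged binding
    uint64_t k_rel;  // K view offset relative to the merged binding start
    uint64_t v_rel;  // V view offset relative to the merged binding start
};

static kv_binding merge_kv_ranges(uint64_t k_off, uint64_t k_size,
                                  uint64_t v_off, uint64_t v_size) {
    const uint64_t lo = std::min(k_off, v_off);
    const uint64_t hi = std::max(k_off + k_size, v_off + v_size);
    return { lo, hi - lo, k_off - lo, v_off - lo };
}
```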

@reeselevine (Contributor)

Let’s get #22266 merged and then see if the CI here passes with that change.

@reeselevine reeselevine requested review from CISC and ggerganov April 24, 2026 16:54
@reeselevine reeselevine merged commit 13d36cf into ggml-org:master Apr 24, 2026
40 of 46 checks passed
@Constannnnnt (Contributor)

Hi @ArberSephirotheca, can I ask a follow-up question on this line? Why does this seq dimension (or prefill length, not sure if I misunderstood this) need to be < 20 and % 32? It seems to fall back to the CPU in most cases for use_vec. Thanks!

@ArberSephirotheca (Contributor, Author)

ArberSephirotheca commented Apr 27, 2026

Hi @Constannnnnt, I believe it comes from the heuristic in the Metal backend (see this). As I am mainly testing on my M4 Mac, I use the Metal backend as my primary reference. In the browser we expect vec to be used only for small query lengths (< 20), while most F16 cases should go through the tile path instead. I think the immediate next step is to broaden the tile path so that longer non-F16 cases (for example Q4_0/Q8_0/F32 K/V) do not fall off the WebGPU backend.
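
A loose sketch of the kind of selection this describes; the < 20 threshold comes from the comment above, while the names and the F16 gate are illustrative assumptions rather than the exact ggml-webgpu condition:

```cpp
// Loose sketch of flash-attention path selection (illustrative, not the exact
// ggml-webgpu condition): vec only for short query lengths, tile otherwise,
// and fall back when the K/V types are not yet covered.
enum class fa_path { vec, tile, fallback };

static fa_path choose_fa_path(int n_query, bool kv_is_f16) {
    if (!kv_is_f16) {
        return fa_path::fallback; // e.g. quantized or F32 K/V not yet covered by tile
    }
    return n_query < 20 ? fa_path::vec : fa_path::tile;
}
```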

@Constannnnnt
Copy link
Copy Markdown
Contributor

Thank you @ArberSephirotheca! This makes sense. I was actually testing mixed loads (SISO and LILO) in my inference engine in the browser and found that most of the time it fell back to the CPU, which led me to this line. Thanks again for your help!

IntelNav pushed a commit to IntelNav/llama.cpp that referenced this pull request Apr 29, 2026
…ggml-org#22199)

* ggml-webgpu: add tile flash attention fallback

* ggml-webgpu: add new fields and discard usage of mnk for tile version

* ggml-webgpu: modify the vec path to discard the mnk parameter

* ggml-webgpu: enable flash attention vec and tile version for broswer

* ggml-webgpu: stagging KV for flash attention tile version

* formatting

* turn on subgroup uniformity check

* remove Q_TILE as it is always 1 for vec path

* make row_max and exp_sum to local register

* make different bindings with same underlying buffer to have the same usage flags

* move path selection into the shader library and have the host consume a single flash-attn decision object.

* turn off skip_validation and address buffer overlapping when nwg==1

* formatting

* merge binding when kv overlap
IntelNav pushed a commit to IntelNav/llama.cpp that referenced this pull request Apr 29, 2026
rsenthilkumar6 pushed a commit to rsenthilkumar6/llama.cpp that referenced this pull request May 1, 2026