ggml-webgpu: enable FLASH_ATTN_EXT on browser without subgroup matrix #22199
reeselevine merged 14 commits into ggml-org:master
Conversation
reeselevine
left a comment
Thanks for the work on this!
With the non-subgroup matrix version of FlashAttention being slower, does it still lead to end-to-end speedups? I can test on my M3 too if that helps.
|
Yea, I still see a performance improvement (~10%):
./build-webgpu/bin/llama-bench \
-m ./llama-3.2-1b-instruct-q4_0.gguf -fa 1
ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices
ggml_metal_library_init: using embedded metal library
ggml_metal_library_init: loaded in 0.018 sec
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name: MTL0
ggml_metal_device_init: GPU family: MTLGPUFamilyApple9 (1009)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal4 (5002)
ggml_metal_device_init: simdgroup reduction = true
ggml_metal_device_init: simdgroup matrix mul. = true
ggml_metal_device_init: has unified memory = true
ggml_metal_device_init: has bfloat = true
ggml_metal_device_init: has tensor = false
ggml_metal_device_init: use residency sets = true
ggml_metal_device_init: use shared buffers = true
ggml_metal_device_init: recommendedMaxWorkingSetSize = 40200.90 MB
ggml_webgpu: adapter_info: vendor_id: 4203 | vendor: apple | architecture: metal-3 | device_id: 0 | name: Apple M4 Pro | device_desc: Metal driver on macOS Version 26.2 (Build 25C56)
| model | size | params | backend | threads | fa | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -: | --------------: | -------------------: |
| llama 1B Q4_0 | 727.75 MiB | 1.24 B | MTL,WebGPU,BLAS | 8 | 1 | pp512 | 2325.05 ± 2.20 |
| llama 1B Q4_0 | 727.75 MiB | 1.24 B | MTL,WebGPU,BLAS | 8 | 1 | tg128 | 154.55 ± 0.18 |
build: 59510168a (8873)

./build-webgpu/bin/llama-bench \
-m ./llama-3.2-1b-instruct-q4_0.gguf
ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices
ggml_metal_library_init: using embedded metal library
ggml_metal_library_init: loaded in 0.013 sec
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name: MTL0
ggml_metal_device_init: GPU family: MTLGPUFamilyApple9 (1009)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal4 (5002)
ggml_metal_device_init: simdgroup reduction = true
ggml_metal_device_init: simdgroup matrix mul. = true
ggml_metal_device_init: has unified memory = true
ggml_metal_device_init: has bfloat = true
ggml_metal_device_init: has tensor = false
ggml_metal_device_init: use residency sets = true
ggml_metal_device_init: use shared buffers = true
ggml_metal_device_init: recommendedMaxWorkingSetSize = 40200.90 MB
ggml_webgpu: adapter_info: vendor_id: 4203 | vendor: apple | architecture: metal-3 | device_id: 0 | name: Apple M4 Pro | device_desc: Metal driver on macOS Version 26.2 (Build 25C56)
| model | size | params | backend | threads | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 1B Q4_0 | 727.75 MiB | 1.24 B | MTL,WebGPU,BLAS | 8 | pp512 | 2113.57 ± 1.07 |
| llama 1B Q4_0 | 727.75 MiB | 1.24 B | MTL,WebGPU,BLAS | 8 | tg128 | 143.10 ± 1.12 |
build: 59510168a (8873) |
|
Hey @ArberSephirotheca, as I mentioned in the linked PR, I realize that for FlashAttention to work in the browser, we'll also need to resolve some buffer aliasing in the FlashAttention shaders. I think that would be good to get merged with this PR; do you think it's something you can look into? Hopefully the patterns from the other shaders which need it would make it not too hard to add here. Removing the |
|
For reference, I pulled this PR into my branch and tested it locally in my inference engine in the browser. It did not report any issues or errors. I was using Qwen-3.5-0.8B-Q4_0 (text-only) and LFM2.5-VL-450M-F16 (vision). |
|
I looked into the code and it looks like the only overlap happens when nwg == 1:

ggml_webgpu_flash_attn: vec alias check nwg=1 kv_tile=32 same_buf=1 overlap=1 tmp=[50176,51200) dst=[50176,51200)
/Users/zheyuan/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp:3110: ggml_webgpu: Device error! Reason: 2, Message: Writable storage buffer binding aliasing found between [BindGroup "flash_attn_vec_f16_hsqk64_hsv64"] set at bind group index 0, binding index 3, and [BindGroup "flash_attn_vec_f16_hsqk64_hsv64"] set at bind group index 0, binding index 4, with overlapping ranges (offset: 50176, size: 1024) and (offset: 50176, size: 1024) in [Buffer "tensor_buf0"].
- While encoding [ComputePassEncoder (unlabeled)].DispatchWorkgroups(4, 1, 1).
- While finishing [CommandEncoder (unlabeled)].

Just fixed it. |
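For readers following along: the validation error above is Dawn rejecting two writable storage bindings whose byte ranges intersect in the same buffer. Below is a minimal sketch of how a host-side alias check along the lines of the `vec alias check` log could be computed; the struct and function names are hypothetical, not the actual ggml-webgpu.cpp API.

```cpp
// Hedged sketch, not the real ggml-webgpu code: detect whether two tensor
// views that would be bound as writable storage alias each other.
#include <cstdint>

struct webgpu_tensor_binding {
    const void * buffer;   // underlying WGPUBuffer (opaque here)
    uint64_t     offset;   // byte offset of this view into the buffer
    uint64_t     size;     // byte size of this view
};

// Two views alias if they live in the same buffer and their half-open byte
// ranges [offset, offset + size) intersect, e.g. tmp=[50176,51200) vs
// dst=[50176,51200) in the log above.
static bool bindings_overlap(const webgpu_tensor_binding & a,
                             const webgpu_tensor_binding & b) {
    if (a.buffer != b.buffer) {
        return false;
    }
    const uint64_t a_end = a.offset + a.size;
    const uint64_t b_end = b.offset + b.size;
    return a.offset < b_end && b.offset < a_end;
}
```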
|
Awesome! How about we try removing
|
Yes, sounds good! I already removed the |
|
Looks like the CI is complaining about the K and V buffers overlapping in some test cases when turning off |
|
Let's get #22266 merged and then see if the CI here passes with that change.
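As a rough illustration of the "merge binding when kv overlap" approach that later landed in this PR: when K and V are views into the same buffer with overlapping ranges, one option is to bind a single merged range and address each view by an offset relative to that range. The sketch below is an assumption-labeled illustration; all names are made up and are not the actual ggml-webgpu.cpp code.

```cpp
// Hedged sketch only: collapse two overlapping K/V views of one buffer into
// a single binding range plus per-view relative offsets.
#include <algorithm>
#include <cstdint>

struct view_range {
    uint64_t offset; // byte offset into the shared buffer
    uint64_t size;   // byte size of the view
};

struct merged_binding {
    uint64_t offset;    // start of the merged binding range
    uint64_t size;      // size of the merged binding range
    uint64_t k_rel_off; // K offset relative to the merged start
    uint64_t v_rel_off; // V offset relative to the merged start
};

static merged_binding merge_kv_views(const view_range & k, const view_range & v) {
    const uint64_t start = std::min(k.offset, v.offset);
    const uint64_t end   = std::max(k.offset + k.size, v.offset + v.size);
    // The shader would then index K and V via the relative offsets, so only
    // one (non-aliasing) binding is needed for the shared buffer.
    return { start, end - start, k.offset - start, v.offset - start };
}
```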
|
Hi @ArberSephirotheca, can I ask a follow-up question on this line? Why does this seq dimension (or prefill length, not sure if I misunderstood this) have to be < 20 and % 32? It seems to fall back to the CPU in most cases for use_vec? Thanks! |
|
Hi @Constannnnnt, I believe it comes from the heuristic of the Metal backend (see this). As I am mainly testing on my M4 Mac, I use the Metal backend as my primary reference. In the browser we expect vec to be used only for small query lengths (<20), while most F16 cases should go through the tile path instead. I think the immediate next step is to broaden the tile path so that longer non-F16 cases (for example Q4_0/Q8_0/F32 K/V) do not fall back off the WebGPU backend. |
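To make the discussion concrete, here is a hedged sketch of the kind of vec-vs-tile decision described above. The field names and exact conditions are assumptions drawn from the numbers mentioned in this thread, not copied from the PR's actual shader-library logic.

```cpp
// Hedged sketch only: an illustrative flash-attention path decision.
#include <cstdint>

enum class fa_path { VEC, TILE, CPU_FALLBACK };

struct fa_case {
    uint32_t n_q;       // query rows per head (sequence / prefill length) - assumed field
    bool     kv_is_f16; // whether K/V are stored as F16 - assumed field
};

static fa_path choose_fa_path(const fa_case & c) {
    // Small query batches (decode or a short prefill) favor the vec path.
    if (c.n_q < 20) {
        return fa_path::VEC;
    }
    // Longer prefills take the tile path; here we illustratively require F16
    // K/V and a query length that divides evenly into 32-row tiles.
    if (c.kv_is_f16 && c.n_q % 32 == 0) {
        return fa_path::TILE;
    }
    // Everything else falls back to another backend (e.g. the CPU).
    return fa_path::CPU_FALLBACK;
}
```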
|
Thank you @ArberSephirotheca! This makes sense. I was actually testing mixed loads (SISO and LILO) on my inference engine in the browser and found that most of the time it fell back to the CPU, which led me to this line. Thanks again for your help! |
ggml-webgpu: enable FLASH_ATTN_EXT on browser without subgroup matrix (ggml-org#22199)

* ggml-webgpu: add tile flash attention fallback
* ggml-webgpu: add new fields and discard usage of mnk for tile version
* ggml-webgpu: modify the vec path to discard the mnk parameter
* ggml-webgpu: enable flash attention vec and tile version for browser
* ggml-webgpu: staging KV for flash attention tile version
* formatting
* turn on subgroup uniformity check
* remove Q_TILE as it is always 1 for vec path
* make row_max and exp_sum local registers
* make different bindings with the same underlying buffer have the same usage flags
* move path selection into the shader library and have the host consume a single flash-attn decision object
* turn off skip_validation and address buffer overlapping when nwg==1
* formatting
* merge bindings when kv overlap
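One of the commits above keeps row_max and exp_sum in per-thread registers. As background, a scalar C++ sketch of the online-softmax recurrence that flash-attention kernels implement is shown below; it is a reference illustration under assumed layouts, not the WGSL shader from this PR.

```cpp
// Hedged reference sketch of online softmax for one query row:
// row_max and exp_sum are the running state that can live in registers.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// q: [d_k], K: [n_kv][d_k], V: [n_kv][d_v], out: [d_v]
static void flash_attn_row(const std::vector<float> & q,
                           const std::vector<std::vector<float>> & K,
                           const std::vector<std::vector<float>> & V,
                           float scale,
                           std::vector<float> & out) {
    const size_t d_v = V.empty() ? 0 : V[0].size();
    float row_max = -INFINITY;           // running maximum of the scores
    float exp_sum = 0.0f;                // running softmax denominator
    std::vector<float> acc(d_v, 0.0f);   // running weighted sum of V rows

    for (size_t j = 0; j < K.size(); ++j) {
        float s = 0.0f;
        for (size_t d = 0; d < q.size(); ++d) {
            s += q[d] * K[j][d];
        }
        s *= scale;

        const float new_max    = std::max(row_max, s);
        const float correction = std::exp(row_max - new_max); // rescale old state
        const float p          = std::exp(s - new_max);

        exp_sum = exp_sum * correction + p;
        for (size_t d = 0; d < d_v; ++d) {
            acc[d] = acc[d] * correction + p * V[j][d];
        }
        row_max = new_max;
    }

    out.assign(d_v, 0.0f);
    for (size_t d = 0; d < d_v; ++d) {
        out[d] = acc[d] / exp_sum;
    }
}
```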
Overview
This PR addresses a few things:
Additional information
Performance when running on my M4 Mac Mini Pro:
The performance is around half that of the subgroup matrix version.
Requirements
Yes, I used an AI agent to help me understand the flash attention code in other backends (e.g., Vulkan, CUDA, and Metal).