ggml-webgpu: Add fused RMS_NORM + MUL#21983

Merged
reeselevine merged 5 commits into ggml-org:master from yomaytk:fused-rms-norm-mul
Apr 22, 2026

Conversation

@yomaytk
Contributor

@yomaytk yomaytk commented Apr 16, 2026

Overview

This PR adds the first kernel fusion to the WebGPU backend, fusing RMS_NORM + MUL (similar to #14800).
The performance on the major models on my device (M2, Metal 4) is shown below; unfortunately, end-to-end performance is almost unchanged with this implementation.
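
For reference, the fused op normalizes each row and multiplies by the weight in a single pass. A minimal scalar sketch of the math (illustrative only; the PR implements this as a WGSL compute shader, and the weight tensor may be broadcast):

```cpp
#include <cmath>
#include <cstddef>

// out[i] = (x[i] / sqrt(mean(x^2) + eps)) * w[i], computed in one pass over the row
static void rms_norm_mul_row(const float * x, const float * w, float * out,
                             size_t n, float eps) {
    double sum_sq = 0.0;
    for (size_t i = 0; i < n; ++i) {
        sum_sq += (double) x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt((float) (sum_sq / (double) n) + eps);
    for (size_t i = 0; i < n; ++i) {
        out[i] = x[i] * scale * w[i];   // RMS_NORM and MUL fused: no intermediate tensor written
    }
}
```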

The command is like this:
llama-bench -m Llama-3.2-3B-Instruct-Q4_K_M.gguf -fa 1 -p 512 -n 0

| Model | Test | t/s master (a620695) | t/s yomaytk/fused-rms-norm-mul | Speedup |
|---|---|---|---|---|
| gemma4 E4B Q4_K_M | pp1 | 32.09 | 31.94 | 0.995 |
| gemma4 E4B Q4_K_M | pp512 | 464.23 | 467.15 | 1.006 |
| gemma4 E4B Q4_K_M | tg1 | 31.99 | 32.46 | 1.015 |
| gemma4 E4B Q4_K_M | tg128 | 32.50 | 32.85 | 1.011 |
| qwen35 4B Q4_K_M | pp1 | 30.88 | 30.84 | 0.999 |
| qwen35 4B Q4_K_M | pp512 | 479.07 | 480.75 | 1.004 |
| qwen35 4B Q4_K_M | tg1 | 30.99 | 30.90 | 0.997 |
| qwen35 4B Q4_K_M | tg128 | 31.12 | 31.35 | 1.007 |
| llama3.2 3B Q4_K_M | pp1 | 52.51 | 53.57 | 1.020 |
| llama3.2 3B Q4_K_M | pp512 | 670.86 | 676.70 | 1.009 |
| llama3.2 3B Q4_K_M | tg1 | 53.40 | 53.83 | 1.008 |
| llama3.2 3B Q4_K_M | tg128 | 54.03 | 54.71 | 1.013 |

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, I used AI to investigate the kernel fusion in the Vulkan and CUDA backends and to help analyze the profiling data.

@yomaytk yomaytk requested a review from a team as a code owner April 16, 2026 04:53
@github-actions github-actions bot added the ggml (changes relating to the ggml tensor library for machine learning) and WebGPU labels Apr 16, 2026
@yomaytk
Contributor Author

yomaytk commented Apr 16, 2026

@reeselevine This PR will likely conflict with #21873 (especially in ggml_backend_webgpu_graph_compute), so I will update it accordingly after that PR is merged.

@yomaytk yomaytk changed the title ggml-webgpu: Add the support of fused RMS_NORM + MUL ggml-webgpu: Add fused RMS_NORM + MUL Apr 16, 2026
@yomaytk yomaytk force-pushed the fused-rms-norm-mul branch from 8fd976c to 5fef017 Compare April 17, 2026 02:22
Contributor

@reeselevine reeselevine left a comment


Thanks for starting to work on fusion, this is a big step!

The performance not changing too much is a little disappointing, but also not a blocker. Once we get the fusion format working we can optimize. I wonder if the reason for the lack of performance is just that the current RMS_NORM is not very well-optimized? So the reduction in bandwidth ends up being hidden because RMS_NORM is too slow.

Do you know if this fusion path leads to significant performance gains in other backends?

Comment thread: ggml/src/ggml-webgpu/ggml-webgpu.cpp (outdated)
size_t memset_bytes_per_thread;

bool disable_fusion;
uint32_t num_additional_fused_ops;
Contributor


this general structure comes from the vulkan backend right? I haven't looked into it too closely, but my first thought is that it seems too general, at least based on this PR, because you end up having to check which ops you are actually fusing, and this doesn't encode that at all.

Contributor Author


Yes, it is almost the same as the vulkan backend, but I agree that this is too general. I removed this as I explain below.

Comment thread: ggml/src/ggml-webgpu/ggml-webgpu.cpp (outdated)

static bool ggml_webgpu_can_fuse_check(webgpu_context & ctx, const struct ggml_cgraph * cgraph, int node_idx) {
// RMS_NORM + MUL
if (ggml_webgpu_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
Contributor


under the hood can_fuse ends up repeating the if condition on RMS_NORM + MUL, so really should we have separate functions for each set of potential fusions?

Contributor Author


You're right, it looks better to have a specific can_fuse function for each fusion.
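
For illustration, such a dedicated check could look roughly like the sketch below, assuming the generic ggml_can_fuse() helper from ggml-impl.h; the exact conditions in the PR may differ:

```cpp
// Rough sketch only: extra conditions layered on top of the generic graph check.
static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) {
    if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
        return false;
    }
    const struct ggml_tensor * rms = cgraph->nodes[node_idx];
    const struct ggml_tensor * mul = cgraph->nodes[node_idx + 1];
    // The MUL must consume the RMS_NORM output directly, and everything stays F32.
    const bool mul_uses_rms = (mul->src[0] == rms) || (mul->src[1] == rms);
    return mul_uses_rms && rms->type == GGML_TYPE_F32 && mul->type == GGML_TYPE_F32;
}
```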

Comment thread: ggml/src/ggml-webgpu/ggml-webgpu.cpp (outdated)
if (!ctx->disable_fusion) {
ggml_webgpu_can_fuse_check(ctx, cgraph, i);
}
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes, i)) {
Contributor


Right now we're hiding whether encode_node encodes based on whether the next node will be fused.

Instead, what do you think of just calling encode node (maybe the function name should just change to encode, to reflect that it might encode multiple nodes) and updating i based on the number of fused operations? So for the new RMS_NORM + MUL, we end up updating i by 2. That avoids hiding the fusion in the additional_fused_ops variable. That seems cleaner to me for now, but maybe I'm missing something that doesn't translate well to future fusions?

Contributor Author

@yomaytk yomaytk Apr 20, 2026


Sounds good. It looks like whether the ops can be fused can basically be determined from cgraph->nodes, at least from my understanding. So I moved the can_fuse logic into ggml_webgpu_encode, which removes num_additional_fused_ops and also eliminates the double-checking for fusion (ggml_webgpu_can_fuse_check in ggml_backend_webgpu_graph_compute and ctx->num_additional_fused_ops > 0). I think this addresses your comment. How about this?
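
Conceptually, the compute loop then advances by however many nodes each encode call consumed. A rough sketch of the idea (names simplified, not the exact code in this PR):

```cpp
// encode() reports how many graph nodes it consumed; the loop skips past them.
for (int i = 0; i < cgraph->n_nodes;) {
    int num_encoded_ops = 1;
    if (auto cmd = ggml_webgpu_encode(ctx, cgraph, i, num_encoded_ops)) {
        commands.push_back(*cmd);
    }
    i += num_encoded_ops;   // 2 when RMS_NORM + MUL were fused into a single dispatch
}
```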

Comment thread: ggml/src/ggml-webgpu/ggml-webgpu.cpp (outdated)
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
*(uint32_t *) rn_dst->op_params // epsilon, treated as f32 in the shader
Contributor

@reeselevine reeselevine Apr 17, 2026


this leads to compiler warnings and will fail when the new ggml-webgpu-nvidia-ci is enabled; I moved to a new format: https://github.com/ggml-org/llama.cpp/blob/master/ggml/src/ggml-webgpu/ggml-webgpu.cpp#L1910

Contributor Author


Got it, I’ll update after this PR is rebased because this fix depends on the new function ggml_webgpu_u32_from_f32.
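
For context, the underlying pattern is to pass the f32 epsilon through the u32 params buffer by reinterpreting its bits and bitcasting back to f32 in the shader. A minimal standalone sketch of that reinterpretation (the backend uses its own helper, ggml_webgpu_u32_from_f32, for this):

```cpp
#include <cstdint>
#include <cstring>

// Reinterpret the bits of an f32 as a u32 without violating strict aliasing.
static uint32_t u32_from_f32_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));   // well-defined, unlike *(uint32_t *) &f
    return u;
}
```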

@yomaytk
Contributor Author

yomaytk commented Apr 20, 2026

I re-ran the benchmarks on the same three models with -r 10 and measured op-level GPU time for RMS_NORM + MUL using the profiling build (GGML_WEBGPU_CPU_PROFILE / GGML_WEBGPU_GPU_PROFILE).

As in the first test, end-to-end t/s is nearly unchanged, but the GPU profiling time for the RMS_NORM + MUL path is clearly reduced by the fusion. Only the Qwen3.5 pp512 case shows an op-level regression.

The command is like this:
$ llama-bench -m gemma-4-E4B-it-Q4_K_M.gguf -fa 1 -p 0,1,512 -n 0,1,128 -r 10

Llama-3.2-3B-Instruct Q4_K_M

| Test | No-fused (t/s) | Fused (t/s) | Speedup | Op-level time (no-fused → fused) | Occupancy (% of GPU) |
|---|---|---|---|---|---|
| pp1 | 52.42 | 52.72 | 1.006x | 1.46 → 0.83 ms (-43%) | 1.79% → 1.03% |
| pp512 | 671.16 | 670.74 | 0.999x | 16.13 → 6.98 ms (-57%) | 0.82% → 0.36% |
| tg1 † | 51.61 | 51.56 | 0.999x | 1.54 → 0.88 ms (-43%) | 1.82% → 1.04% |
| tg128 | 52.73 | 52.98 | 1.005x | 180.68 → 101.81 ms (-44%) | 1.90% → 1.08% |

Qwen3.5-4B Q4_K_M

| Test | No-fused (t/s) | Fused (t/s) | Speedup | Op-level time (no-fused → fused) | Occupancy (% of GPU) |
|---|---|---|---|---|---|
| pp1 | 30.59 | 30.17 | 0.986x | 1.97 → 1.22 ms (-38%) | 1.78% → 1.10% |
| pp512 | 480.00 | 472.25 | 0.984x | 28.22 → 47.99 ms (+70%) | 1.17% → 1.93% |
| tg1 | 30.44 | 30.45 | 1.000x | 2.01 → 1.21 ms (-40%) | 1.81% → 1.09% |
| tg128 | 30.44 | 30.14 | 0.990x | 238.62 → 154.62 ms (-35%) | 1.83% → 1.20% |

Gemma-4-E4B-it Q4_K_M

| Test | No-fused (t/s) | Fused (t/s) | Speedup | Op-level time (no-fused → fused) | Occupancy (% of GPU) |
|---|---|---|---|---|---|
| pp1 | 31.88 | 31.94 | 1.002x | 4.91 → 2.70 ms (-45%) | 4.19% → 2.35% |
| pp512 | 465.46 | 466.11 | 1.001x | 51.48 → 37.71 ms (-27%) | 2.04% → 1.50% |
| tg1 | 31.77 | 32.01 | 1.008x | 4.93 → 2.79 ms (-43%) | 4.17% → 2.41% |
| tg128 | 32.03 | 32.19 | 1.005x | 609.67 → 336.46 ms (-45%) | 4.39% → 2.48% |

Using GGML_WEBGPU_GPU_PROFILE, I confirmed that the ops related to RMS_NORM and MUL in these three models are as follows.

  • no-fused: rms_norm, rms_norm_inplace, MUL_f32_inplace
  • fused: RMS_NORM_MUL, MUL_F32_inplace

Notes on the tables:

  • Op-level time = GPU profile time of the ops that get fused.
    • no-fused: rms_norm + rms_norm_inplace + (MUL_f32_inplace(no-fused) − MUL_f32_inplace(fused))
    • fused: RMS_NORM_MUL
  • Occupancy = (op-level time) / (total GPU time).

> I wonder if the reason for the lack of performance is just that the current RMS_NORM is not very well-optimized? So the reduction in bandwidth ends up being hidden because RMS_NORM is too slow.

That's plausible, yes. In addition, the RMS_NORM + MUL path accounts for only a small share of total GPU time on WebGPU/M2 (roughly 1-4% depending on the model), so even this op-level reduction moves end-to-end t/s by only around 1% at most. I expect this fusion's benefit to become more visible once RMS_NORM and the surrounding ops are further optimized.
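
As a rough Amdahl's-law sanity check using the Gemma tg128 numbers above (op share about 4.4% of GPU time, op time cut by about 45%), and assuming GPU time dominates:

$$\text{speedup} \le \frac{1}{1 - 0.044 \times 0.45} \approx 1.02$$

so at most about a 2% end-to-end gain from this fusion alone, consistent with the roughly 1% observed.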

> Do you know if this fusion path leads to significant performance gains in other backends?

I haven't measured other backends in my environment, but the corresponding CUDA and Metal fusion PRs (#14800, #14596) report ~5–10% improvements on pp/tg tests, which I'd expect here too once the rest of the WebGPU implementation is optimized.

Contributor

@reeselevine reeselevine left a comment


Needs a merge or rebase but otherwise this is looking good!

static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
ggml_cgraph * cgraph,
int node_idx,
int & num_encoded_ops) {
Contributor


I think there might be a cleaner way of returning the number of encoded ops than passing in the extra parameter by reference, but I also realize that returning the optional encoded_op makes it slightly more complex. One option would be to just put num_encoded_ops as a field in webgpu_encoded_op. But I think this is also OK for now; I might think about it a bit more.
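
For illustration, the alternative mentioned above would carry the consumed-node count inside the encoded-op struct rather than an out-parameter; the fields here are made up for the sketch:

```cpp
// Sketch only: not the PR's actual struct layout.
struct webgpu_encoded_op {
    // ... pipeline, bind group, dispatch dimensions, etc. ...
    int num_encoded_ops = 1;   // how many graph nodes this command covers (2 for RMS_NORM + MUL)
};
```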

Contributor Author


Got it, thanks. I'll leave the current form for now, then.

Comment thread: ggml/src/ggml-webgpu/ggml-webgpu.cpp (outdated)
case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
if (ggml_webgpu_can_fuse_rms_norm_mul(ctx, cgraph, node_idx)) {
Contributor


I think this is a nice way of doing the check, thanks!

Comment thread: ggml/src/ggml-webgpu/ggml-webgpu.cpp (outdated)
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->disable_fusion = getenv("GGML_WEBGPU_DISABLE_FUSION") != nullptr;
Contributor


I'm hesitant to include environment variables in the WebGPU backend, ideally all the logic for why we should/should not do something can exist in the backend code itself. Is there a specific reason for including disabling fusion as an option like this?

Contributor Author

@yomaytk yomaytk Apr 21, 2026


This is because I followed the other backends (Vulkan, CUDA, and OpenCL have this env var), so I hadn't given it much thought. I confirmed that the PR (#14545) added the same env var to the Vulkan backend because an execution error was reported after the RMS_NORM + MUL fusion PR was merged. However, the error was actually due to bugs in the fusion implementation, so it seems they added the env var only as a temporary guard. I also realize we should be cautious about adding more external dependencies, and I'm not sure disabling fusion is necessary for now, so I'm starting to think we should remove this env var and disable_fusion. What do you think?

Contributor


I agree, I think we can remove it, and if we do run into issues due to fusion, we can try to fix forward or try to add guards within the WebGPU backend itself to disable it.

@yomaytk yomaytk force-pushed the fused-rms-norm-mul branch from 8d925e6 to ac69ca8 Compare April 21, 2026 15:22
key.overlap = context.overlap;
key.src_overlap = context.src_overlap;
ggml_webgpu_binary_pipeline_key key = {
.type = context.dst->type,
Contributor


I'm trying to move away from using the C++20 initializers; they lead to compiler warnings right now, and when I enable the ggml-ci-webgpu-nvidia node they will cause errors, as that CI is stricter.
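
For illustration, the style change amounts to replacing C++20 designated initializers with zero-initialization plus member-by-member assignment; the struct and fields below are invented for the sketch:

```cpp
struct pipeline_key { int type; bool overlap; };

static pipeline_key make_key(int dst_type, bool dst_overlap) {
    // Avoided: C++20 designated initializers, which warn under the project's current settings:
    //   pipeline_key key = { .type = dst_type, .overlap = dst_overlap };
    // Used instead: zero-init, then assign each member explicitly.
    pipeline_key key = {};
    key.type    = dst_type;
    key.overlap = dst_overlap;
    return key;
}
```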

Contributor Author

@yomaytk yomaytk Apr 22, 2026


Got it, I updated accordingly.

@yomaytk yomaytk force-pushed the fused-rms-norm-mul branch from be84398 to 89cc0bb Compare April 22, 2026 00:52
@reeselevine reeselevine requested review from CISC and ggerganov April 22, 2026 03:20
@reeselevine reeselevine merged commit 6da7168 into ggml-org:master Apr 22, 2026
46 of 49 checks passed
@yomaytk yomaytk deleted the fused-rms-norm-mul branch April 22, 2026 22:15
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Apr 23, 2026
* fused rms_norm_mul + mul

* Add GGML_WEBGPU_DISABLE_FUSION for being able to disable kernel fusion.

* Decouple num_fused_ops from webgpu_context; misc cleanup

* Fix eps handling and remove disable_fusion.

* Fix not to use c++20 initializers.
rsenthilkumar6 pushed a commit to rsenthilkumar6/llama.cpp that referenced this pull request May 1, 2026
* fused rms_norm_mul + mul

* Add GGML_WEBGPU_DISABLE_FUSION for being able to disable kernel fusion.

* Decouple num_fused_ops from webgpu_context; misc cleanup

* Fix eps handling and remove disable_fusion.

* Fix not to use c++20 initializers.

Labels

ggml (changes relating to the ggml tensor library for machine learning), WebGPU

3 participants