CUDA: add small-k optimization for mul-mat-id bs 2-4 #20885

Closed
am17an wants to merge 2 commits into ggml-org:master from am17an:cuda_small_k_mul_mat_id

Conversation

@am17an
Contributor

@am17an am17an commented Mar 23, 2026

Follow-up to #20635.

I did a basic test on an RTX 5090; there is no compilation slow-down.

| Model | Microbatch size | Test | t/s master | t/s cuda_small_k_mul_mat_id | Speedup |
| --- | --- | --- | --- | --- | --- |
| qwen35moe 35B.A3B Q4_K_S | 1 | pp512 | 232.06 | 241.96 | 1.04 |
| qwen35moe 35B.A3B Q4_K_S | 2 | pp512 | 331.25 | 351.08 | 1.06 |
| qwen35moe 35B.A3B Q4_K_S | 3 | pp512 | 448.62 | 486.85 | 1.09 |
| qwen35moe 35B.A3B Q4_K_S | 4 | pp512 | 500.31 | 550.46 | 1.10 |

@am17an am17an requested a review from a team as a code owner March 23, 2026 04:42
@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels Mar 23, 2026
Contributor

@JohannesGaessler JohannesGaessler left a comment

I think you also need to edit the switch cases for 2-4 columns for some of the GPUs.

Three review comment threads on ggml/src/ggml-cuda/mmvq.cu (now outdated).
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
@am17an
Contributor Author

am17an commented Mar 23, 2026

> I think you also need to edit the switch cases for 2-4 columns for some of the GPUs.

MUL_MAT_ID takes the ncols_dst=1 path, which we added in #18958.
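
For readers unfamiliar with this part of the code, the exchange above is about how the MMVQ launcher picks a kernel instantiation per number of destination columns. The following is only a rough, hypothetical sketch of that pattern, assuming made-up names (`mul_mat_vec_q_sketch`, `launch_mmvq_ncols`) and plain float data instead of quantized blocks; it is not the actual code in ggml/src/ggml-cuda/mmvq.cu.

```cpp
// Hypothetical sketch only -- not the actual ggml-cuda/mmvq.cu code.
#include <cuda_runtime.h>

template <int ncols_dst>
__global__ void mul_mat_vec_q_sketch(const float * x, const float * y, float * dst,
                                     const int ncols_x, const int nrows_x) {
    const int row  = blockIdx.x;   // one row of x per thread block (one warp)
    const int lane = threadIdx.x;

    float sum[ncols_dst] = {0.0f};
    for (int k = lane; k < ncols_x; k += warpSize) {
        const float xk = x[(size_t) row * ncols_x + k];
        for (int j = 0; j < ncols_dst; ++j) {
            sum[j] += xk * y[(size_t) j * ncols_x + k];   // one dot product per destination column
        }
    }
    for (int j = 0; j < ncols_dst; ++j) {
        // warp-level reduction of the partial sums
        for (int offset = 16; offset > 0; offset >>= 1) {
            sum[j] += __shfl_down_sync(0xffffffff, sum[j], offset);
        }
        if (lane == 0) {
            dst[(size_t) j * nrows_x + row] = sum[j];
        }
    }
}

// The point of the review comment: the launcher switches over ncols_dst, so supporting
// batch sizes 2-4 on a given GPU means the corresponding cases must actually be wired up.
static void launch_mmvq_ncols(const float * x, const float * y, float * dst,
                              const int ncols_x, const int nrows_x, const int ncols_dst,
                              cudaStream_t stream) {
    const dim3 block(32, 1, 1);
    const dim3 grid(nrows_x, 1, 1);
    switch (ncols_dst) {
        case 1: mul_mat_vec_q_sketch<1><<<grid, block, 0, stream>>>(x, y, dst, ncols_x, nrows_x); break;
        case 2: mul_mat_vec_q_sketch<2><<<grid, block, 0, stream>>>(x, y, dst, ncols_x, nrows_x); break;
        case 3: mul_mat_vec_q_sketch<3><<<grid, block, 0, stream>>>(x, y, dst, ncols_x, nrows_x); break;
        case 4: mul_mat_vec_q_sketch<4><<<grid, block, 0, stream>>>(x, y, dst, ncols_x, nrows_x); break;
        default: break;  // larger batches would take a different (MMQ/cuBLAS) path
    }
}
```

In a setup like this, the reply above means the MUL_MAT_ID path routes each expert through the ncols_dst=1 instantiation, which is why the 2-4 column cases did not need to be edited in this PR.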

@gaugarg-nv
Contributor

gaugarg-nv commented Mar 23, 2026

Sorry, I didn't realize you had created a PR for the same issue. I filed one just now: #20905.

The implementation is very different, though, so it would be good to do a perf comparison.

@gaugarg-nv
Contributor

Checked the perf data. PR #20905 shows better perf than #20885 across models and GPUs.

Performance
| gpu_info | model_type | n_ubatch | n_prompt | Master-f93c09e26-avg_ts | 20885-avg_ts | 20905-avg_ts | 20905/20885 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 5090 | qwen3moe 30B.A3B Q4_0 | 2 | 512 | 450.37331 | 489.00997 | 492.9244 | 1.01 |
| NVIDIA GeForce RTX 5090 | qwen3moe 30B.A3B Q4_0 | 3 | 512 | 521.34689 | 577.29033 | 604.3901 | 1.05 |
| NVIDIA GeForce RTX 5090 | qwen3moe 30B.A3B Q4_0 | 4 | 512 | 598.14532 | 675.02906 | 726.8059 | 1.08 |
| NVIDIA GeForce RTX 5090 | qwen3moe 30B.A3B Q4_K - Medium | 2 | 512 | 383.56361 | 406.58554 | 414.1365 | 1.02 |
| NVIDIA GeForce RTX 5090 | qwen3moe 30B.A3B Q4_K - Medium | 3 | 512 | 496.52874 | 536.76288 | 575.0555 | 1.07 |
| NVIDIA GeForce RTX 5090 | qwen3moe 30B.A3B Q4_K - Medium | 4 | 512 | 550.21397 | 602.64762 | 665.2835 | 1.10 |
| NVIDIA GeForce RTX 5090 | qwen3moe 30B.A3B Q8_0 | 2 | 512 | 333.9944 | 344.84384 | 391.5877 | 1.14 |
| NVIDIA GeForce RTX 5090 | qwen3moe 30B.A3B Q8_0 | 3 | 512 | 430.29557 | 456.19274 | 476.4898 | 1.04 |
| NVIDIA GeForce RTX 5090 | qwen3moe 30B.A3B Q8_0 | 4 | 512 | 494.92561 | 532.66282 | 573.6039 | 1.08 |
| NVIDIA GeForce RTX 5090 | qwen35moe 35B.A3B Q4_K - Medium | 2 | 512 | 336.54982 | 352.34231 | 358.5733 | 1.02 |
| NVIDIA GeForce RTX 5090 | qwen35moe 35B.A3B Q4_K - Medium | 3 | 512 | 455.54097 | 486.40261 | 500.7537 | 1.03 |
| NVIDIA GeForce RTX 5090 | qwen35moe 35B.A3B Q4_K - Medium | 4 | 512 | 540.15397 | 585.58987 | 614.7323 | 1.05 |
| NVIDIA GeForce RTX 5090 | gpt-oss 20B MXFP4 MoE | 2 | 512 | 539.64431 | 539.81394 | 630.7859 | 1.17 |
| NVIDIA GeForce RTX 5090 | gpt-oss 20B MXFP4 MoE | 3 | 512 | 694.81228 | 694.11404 | 773.4004 | 1.11 |
| NVIDIA GeForce RTX 5090 | gpt-oss 20B MXFP4 MoE | 4 | 512 | 789.33624 | 789.5108 | 908.4758 | 1.15 |
| NVIDIA GeForce RTX 5090 | gpt-oss 20B Q4_K - Medium | 2 | 512 | 564.72012 | 564.23564 | 600.5675 | 1.06 |
| NVIDIA GeForce RTX 5090 | gpt-oss 20B Q4_K - Medium | 3 | 512 | 721.69967 | 721.1223 | 804.7416 | 1.12 |
| NVIDIA GeForce RTX 5090 | gpt-oss 20B Q4_K - Medium | 4 | 512 | 802.83533 | 802.49769 | 922.9511 | 1.15 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3moe 30B.A3B Q4_0 | 2 | 512 | 447.2068 | 482.4597 | 485.0051 | 1.01 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3moe 30B.A3B Q4_0 | 3 | 512 | 521.3682 | 576.4494 | 595.0763 | 1.03 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3moe 30B.A3B Q4_0 | 4 | 512 | 596.4848 | 673.9139 | 720.1008 | 1.07 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3moe 30B.A3B Q4_K - Medium | 2 | 512 | 382.9618 | 403.6524 | 462.0411 | 1.14 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3moe 30B.A3B Q4_K - Medium | 3 | 512 | 503.1637 | 542.0355 | 571.6991 | 1.05 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3moe 30B.A3B Q4_K - Medium | 4 | 512 | 562.423 | 615.7635 | 669.4282 | 1.09 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3moe 30B.A3B Q8_0 | 2 | 512 | 332.2923 | 341.0055 | 382.9107 | 1.12 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3moe 30B.A3B Q8_0 | 3 | 512 | 434.0147 | 456.8183 | 471.7925 | 1.03 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3moe 30B.A3B Q8_0 | 4 | 512 | 499.0083 | 534.8668 | 564.7199 | 1.06 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen35moe 35B.A3B Q4_K - Medium | 2 | 512 | 333.2131 | 348.4453 | 400.1267 | 1.15 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen35moe 35B.A3B Q4_K - Medium | 3 | 512 | 452.5427 | 484.1747 | 493.7494 | 1.02 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen35moe 35B.A3B Q4_K - Medium | 4 | 512 | 540.4372 | 587.0452 | 608.4603 | 1.04 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | gpt-oss 20B MXFP4 MoE | 2 | 512 | 537.8328 | 538.5858 | 623.5488 | 1.16 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | gpt-oss 20B MXFP4 MoE | 3 | 512 | 696.4054 | 697.833 | 767.9747 | 1.10 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | gpt-oss 20B MXFP4 MoE | 4 | 512 | 798.8649 | 799.0684 | 910.6224 | 1.14 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | gpt-oss 20B Q4_K - Medium | 2 | 512 | 560.4776 | 561.4026 | 655.734 | 1.17 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | gpt-oss 20B Q4_K - Medium | 3 | 512 | 724.9005 | 725.0865 | 805.8596 | 1.11 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | gpt-oss 20B Q4_K - Medium | 4 | 512 | 817.6949 | 817.3278 | 932.1244 | 1.14 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3next 80B.A3B Q4_K - Medium | 2 | 512 | 258.839 | 273.9867 | 316.9603 | 1.16 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3next 80B.A3B Q4_K - Medium | 3 | 512 | 349.9941 | 378.9441 | 389.8006 | 1.03 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | qwen3next 80B.A3B Q4_K - Medium | 4 | 512 | 406.3511 | 449.5354 | 471.7863 | 1.05 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | gpt-oss 120B MXFP4 MoE | 2 | 512 | 347.8694 | 348.4328 | 397.0758 | 1.14 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | gpt-oss 120B MXFP4 MoE | 3 | 512 | 449.5139 | 449.8659 | 488.3505 | 1.09 |
| NVIDIA RTX PRO 6000 Blackwell Workstation Edition | gpt-oss 120B MXFP4 MoE | 4 | 512 | 513.5035 | 514.2759 | 574.5022 | 1.12 |
| NVIDIA GeForce RTX 3090 | qwen3moe 30B.A3B Q4_0 | 2 | 512 | 256.0022 | 282.6422 | 300.2996 | 1.06 |
| NVIDIA GeForce RTX 3090 | qwen3moe 30B.A3B Q4_0 | 3 | 512 | 310.1992 | 352.6244 | 394.1881 | 1.12 |
| NVIDIA GeForce RTX 3090 | qwen3moe 30B.A3B Q4_0 | 4 | 512 | 337.4377 | 389.6713 | 450.5619 | 1.16 |
| NVIDIA GeForce RTX 3090 | qwen3moe 30B.A3B Q4_K - Medium | 2 | 512 | 242.4454 | 263.1537 | 278.222 | 1.06 |
| NVIDIA GeForce RTX 3090 | qwen3moe 30B.A3B Q4_K - Medium | 3 | 512 | 287.2251 | 317.4232 | 353.3931 | 1.11 |
| NVIDIA GeForce RTX 3090 | qwen3moe 30B.A3B Q4_K - Medium | 4 | 512 | 309.4449 | 344.5603 | 392.9521 | 1.14 |
| NVIDIA GeForce RTX 3090 | qwen35moe 35B.A3B Q4_K - Medium | 2 | 512 | 222.2873 | 231.1396 | 242.5083 | 1.05 |
| NVIDIA GeForce RTX 3090 | qwen35moe 35B.A3B Q4_K - Medium | 3 | 512 | 286.8951 | 303.804 | 331.7216 | 1.09 |
| NVIDIA GeForce RTX 3090 | qwen35moe 35B.A3B Q4_K - Medium | 4 | 512 | 330.5946 | 355.0937 | 398.2023 | 1.12 |
| NVIDIA GeForce RTX 3090 | gpt-oss 20B MXFP4 MoE | 2 | 512 | 316.2192 | 317.2192 | 346.3659 | 1.09 |
| NVIDIA GeForce RTX 3090 | gpt-oss 20B MXFP4 MoE | 3 | 512 | 363.7077 | 365.5939 | 423.9369 | 1.16 |
| NVIDIA GeForce RTX 3090 | gpt-oss 20B MXFP4 MoE | 4 | 512 | 385.6871 | 386.7337 | 466.7459 | 1.21 |
| NVIDIA GeForce RTX 3090 | gpt-oss 20B Q4_K - Medium | 2 | 512 | 328.897 | 329.5281 | 363.2213 | 1.10 |
| NVIDIA GeForce RTX 3090 | gpt-oss 20B Q4_K - Medium | 3 | 512 | 371.8909 | 373.1746 | 435.7728 | 1.17 |
| NVIDIA GeForce RTX 3090 | gpt-oss 20B Q4_K - Medium | 4 | 512 | 389.532 | 390.3885 | 472.3913 | 1.21 |

gaugarg-nv pushed a commit to gaugarg-nv/llama.cpp that referenced this pull request Mar 25, 2026
…optimization only for cases where it benefits

Increase max batch size for MMVQ kernels for MUL_MAT_ID to 8
@am17an
Contributor Author

am17an commented Mar 26, 2026

Superseded by #20905

@am17an am17an closed this Mar 26, 2026
@am17an am17an deleted the cuda_small_k_mul_mat_id branch March 26, 2026 02:16
JohannesGaessler pushed a commit that referenced this pull request Mar 29, 2026
* Optimize MOE GEMV kernel for BS > 1.

The previous MoE kernel for BS > 1 launched too many thread blocks (nrows_x, nchannels_dst, ncols_dst), with very little work per block: a (32, 4) block computed the inner dot product for a single row.

The new mul_mat_vec_q_moe kernel is dedicated to the multi-token MoE case, with grid (ceil(nrows_x/rpb), nchannels_dst) and block (warp_size, ncols_dst). Each warp handles two rows independently with warp-level reduction only (no shared-memory sync).

This change doesn't increase compilation time, as only a single template instance is needed per type. It also simplifies the original GEMV kernel and gets rid of the `is_multi_token_id` specialization.

* Remove em-dashes

* Cherry-pick changes from @am17an PR #20885 to enable small_k optimization only for cases where it benefits

Increase max batch size for MMVQ kernels for MUL_MAT_ID to 8

* Make the max batch size for MOE GEMV kernel configurable based on GPU arch and datatype

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
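
To make the launch geometry in the commit message above concrete, here is a minimal, hypothetical CUDA sketch of a warp-per-column, two-rows-per-block GEMV with warp-level reduction only. The names (`moe_gemv_sketch`, `ROWS_PER_BLOCK`) and the use of plain float data are assumptions made for illustration; the real kernel operates on quantized blocks and handles the MUL_MAT_ID expert indices, which are omitted here.

```cpp
// Hypothetical sketch of the described launch geometry -- not the actual ggml-cuda kernel.
// grid:  (ceil(nrows_x / ROWS_PER_BLOCK), nchannels_dst)
// block: (WARP_SIZE, ncols_dst)
#include <cuda_runtime.h>

constexpr int WARP_SIZE      = 32;
constexpr int ROWS_PER_BLOCK = 2;   // "rpb" in the commit message (assumed value)

template <int ncols_dst>
__global__ void moe_gemv_sketch(const float * x, const float * y, float * dst,
                                const int ncols_x, const int nrows_x) {
    const int channel = blockIdx.y;                  // expert / channel; offsets omitted for brevity
    const int row0    = blockIdx.x * ROWS_PER_BLOCK;
    const int col     = threadIdx.y;                 // destination column owned by this warp
    const int lane    = threadIdx.x;
    (void) channel;

    for (int r = 0; r < ROWS_PER_BLOCK; ++r) {
        const int row = row0 + r;
        if (row >= nrows_x) {
            break;
        }

        // Each lane accumulates a strided slice of the dot product for (row, col).
        float sum = 0.0f;
        for (int k = lane; k < ncols_x; k += WARP_SIZE) {
            sum += x[(size_t) row * ncols_x + k] * y[(size_t) col * ncols_x + k];
        }

        // Warp-level reduction only: no shared memory, no __syncthreads().
        for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
            sum += __shfl_down_sync(0xffffffff, sum, offset);
        }

        if (lane == 0) {
            dst[(size_t) col * nrows_x + row] = sum;
        }
    }
}

template <int ncols_dst>
static void launch_moe_gemv_sketch(const float * x, const float * y, float * dst,
                                   const int ncols_x, const int nrows_x,
                                   const int nchannels_dst, cudaStream_t stream) {
    const dim3 block(WARP_SIZE, ncols_dst, 1);
    const dim3 grid((nrows_x + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK, nchannels_dst, 1);
    moe_gemv_sketch<ncols_dst><<<grid, block, 0, stream>>>(x, y, dst, ncols_x, nrows_x);
}
```

The design point from the commit message is that each block now does enough work (ROWS_PER_BLOCK rows times ncols_dst columns) to avoid the earlier over-subscription, while reductions stay inside a warp so no shared-memory synchronization is needed.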
ggerganov pushed a commit to ggml-org/ggml that referenced this pull request Apr 1, 2026
slartibardfast pushed a commit to slartibardfast/llama.cpp that referenced this pull request Apr 12, 2026
Seunghhon pushed a commit to Seunghhon/llama.cpp that referenced this pull request Apr 26, 2026
pull Bot pushed a commit to sh1970/whisper.cpp that referenced this pull request May 1, 2026
rsenthilkumar6 pushed a commit to rsenthilkumar6/llama.cpp that referenced this pull request May 1, 2026

Labels

ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (Issues specific to Nvidia GPUs)

3 participants