Skip to content

model : refactor QKV into common build_qkv and create_tensor_qkv helpers#21245

Merged
CISC merged 2 commits intoggml-org:masterfrom
JoursBleu:refactor/build-qkv-helper
Apr 16, 2026
Merged

model : refactor QKV into common build_qkv and create_tensor_qkv helpers#21245
CISC merged 2 commits intoggml-org:masterfrom
JoursBleu:refactor/build-qkv-helper

Conversation

@JoursBleu
Copy link
Copy Markdown
Contributor

@JoursBleu JoursBleu commented Apr 1, 2026

Overview

Currently llama.cpp supports 112 model files in src/models/.

We modified the 85 applicable model files. Our changes abstract the duplicated
Q/K/V tensors' loading and graph-building code into two reusable helpers,
following the create_tensor_gate_up_exps pattern (#19139).

create_tensor_qkv (llama-model.cpp): tries fused wqkv/bqkv first (TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL), falls back to separate wq/wk/wv. Supports adding biases.

build_qkv (llama-graph.h/cpp): returns {Qcur, Kcur, Vcur} as 3D tensors. Fused case: single fused qkv matmul + ggml_view_3d split. Separate case: 3 separate matmuls + ggml_reshape_3d.

Test: test-llama-archs — all OK, 0 FAIL. Zero diff on llama-arch.cpp.

The remaining 27 models are not modified for the following reasons:

Reason Count Models
Non-attention (SSM/linear/RNN) 10 mamba, mamba-base, rwkv6, rwkv6-base, rwkv6qwen2, rwkv7, rwkv7-base, arwkv7, delta-net-base, wavtokenizer-dec
MLA attention 4 deepseek2, minicpm3, minimax-m2, plm
Graph directly uses layer.wqkv (non-standard layout) 3 cogvlm, openelm, plamo2
Q+gate joint projection 4 qwen35, qwen35moe, qwen3next, plamo3
n_embd_head_k != n_embd_head_v 2 step35-iswa, mimo2-iswa
No fused wqkv_enc 1 t5-enc
Other special architectures 3 olmo2, olmoe, kimi-linear

Additional information

Basing on the discussion in #20628 (@am17an, @ngxson). The plan is:

  1. This PR: This PR does not modify any logic, it simply extracts the redundant code into
    the two functions above, and adds handling for the fused qkv case.
  2. Future PR: add --fuse-qkv to convert_hf_to_gguf.py.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES - used as a translation tool for translating the PR description

@github-actions github-actions Bot added the model Model specific label Apr 1, 2026
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch 3 times, most recently from da129d5 to 26e72e0 Compare April 1, 2026 04:18
@JoursBleu JoursBleu marked this pull request as ready for review April 1, 2026 06:01
@JoursBleu JoursBleu requested a review from CISC as a code owner April 1, 2026 06:01
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp
Comment thread src/llama-model.cpp
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from 26e72e0 to bcc69fd Compare April 1, 2026 09:05
@JoursBleu
Copy link
Copy Markdown
Contributor Author

hi @CISC,

  • Removed the has_bias flag.
  • Bias tensors are now always created with TENSOR_NOT_REQUIRED.
  • Fixed the incomplete conversions and typos mentioned above.

Comment thread src/llama-model.cpp
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp
@JoursBleu
Copy link
Copy Markdown
Contributor Author

@CISC Done:

  • Remove unnecessary comments/restore comments that should be retained.
  • JAIS2 restores manually created bias tensors.

Comment thread src/models/afmoe.cpp
@JoursBleu
Copy link
Copy Markdown
Contributor Author

@CISC Done:

  • Remove the Vcur reshape in afmoe.cpp

Copy link
Copy Markdown
Member

@CISC CISC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OP is inaccurate, there's nothing special about these:

  • nemotron-h: just add build_qkv in llm_build_nemotron_h::build_attention_layer
  • granite-hybrid: just add build_qkv in lm_build_granite_hybrid::build_attention_layer
  • olmo/mpt/dbrx: use build_qkv, add clamping
  • gemma3n-iswa: just do build_qkv
  • t5-dec/t5-enc: do build_qkv on normal self-attention
  • bert: use build_qkv
  • lfm2: do build_qkv in build_attn_block

Comment thread src/llama-graph.cpp Outdated
@JoursBleu JoursBleu marked this pull request as draft April 2, 2026 12:56
@CISC
Copy link
Copy Markdown
Member

CISC commented Apr 4, 2026

I meant move the clamping to build_qkv.

@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from 09d8066 to 04506d4 Compare April 6, 2026 01:27
Comment thread src/llama-graph.cpp Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch 2 times, most recently from 050b5a9 to 623ed29 Compare April 9, 2026 01:17
@JoursBleu JoursBleu marked this pull request as ready for review April 9, 2026 01:17
@JoursBleu
Copy link
Copy Markdown
Contributor Author

@CISC Done:

  • Extended build_qkv to bert, mpt, dbrx, olmo, lfm2, nemotron-h, granite-hybrid, gemma3n-iswa, t5-dec, t5-enc;
  • Clamping handled internally in build_qkv using hparams.f_clamp_kqv.

Comment thread src/models/gemma3n-iswa.cpp Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from 623ed29 to ccd1f60 Compare April 10, 2026 13:39
Comment thread src/llama-graph.h Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from ccd1f60 to 67a8492 Compare April 11, 2026 05:29
@CISC CISC requested review from ggerganov and ngxson April 11, 2026 09:36
Comment thread src/models/mimo2-iswa.cpp Outdated
Comment thread src/models/openai-moe-iswa.cpp Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from 67a8492 to d8bf733 Compare April 12, 2026 07:18
@JoursBleu
Copy link
Copy Markdown
Contributor Author

JoursBleu commented Apr 13, 2026

@ngxson @am17an @ggerganov This PR is ready. Could you take a look when you have time?

Copy link
Copy Markdown
Contributor

@am17an am17an left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job!

Comment thread src/llama-graph.cpp Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from d8bf733 to 51dbd8c Compare April 16, 2026 09:05
@ggerganov ggerganov added the merge ready A maintainer can use this label to indicate that they consider the changes final and ready to merge. label Apr 16, 2026
@CISC CISC merged commit 9db77a0 into ggml-org:master Apr 16, 2026
48 of 50 checks passed
cnsiva pushed a commit to saas-home/llama.cpp that referenced this pull request Apr 17, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
samuraieng pushed a commit to samuraieng/llama.cpp that referenced this pull request Apr 19, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
jspadgett pushed a commit to jspadgett/llama.cpp that referenced this pull request Apr 20, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
mengqin pushed a commit to mengqin/llama.cpp that referenced this pull request Apr 20, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Apr 21, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Apr 23, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
rsenthilkumar6 pushed a commit to rsenthilkumar6/llama.cpp that referenced this pull request May 1, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
jimbothigpen pushed a commit to jimbothigpen/frankenturbo2 that referenced this pull request May 2, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

merge ready A maintainer can use this label to indicate that they consider the changes final and ready to merge. model Model specific

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants