feat: MTP support for dense Qwen 3.5 with FastMTP vocabulary trimming#20700
itigges22 wants to merge 1 commit into ggml-org:master
Conversation
As of March 18th, 2026, this is just a draft! Still a work in progress :) If you have any questions, feel free to let me know!

March 19th: making headway, but wow... this is not an easy one.
Investigation Update: Speculative Framework Crash on DeltaNet

After enabling MTP speculative decoding (bypassing the seq_rm compat check), the server initializes correctly but segfaults during the first speculative draft/verify cycle on DeltaNet hybrid models. What works:
What crashes:
Root cause hypothesis:

Added fallback (in latest push):
Next steps needed:
Tested on: RTX 5060 Ti 16GB, Linux, Qwen3.5-9B-MTP-Q4_K_M.gguf
Breakthrough: 95% MTP acceptance rate with cooldown fix

Root cause found and fixed. After draft rejection, MTP logits are read from the DRAFT token's position (last in the [sampled, draft] batch). These logits predict what comes after the rejected draft — which is wrong. The next proposal uses these stale logits, producing a cascade of bad drafts (13% acceptance rate → garbled output).

Fix: added a cooldown flag.

Results:
The remaining speed being similar to non-MTP (16.7 vs 16.7 tok/s) is because the cooldown means proposals happen only every other step. The theoretical max with cooldown is ~1.33x (not 2x). Removing the cooldown for accepted drafts (cooling down only on rejection) should increase the effective speedup. Output quality is almost correct, with minor degradation in docstrings.
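The proposed refinement (cooling down only after a rejection) can be sketched as a tiny scheduler. This is a simplified simulation with hypothetical names, not the PR's actual code:

```python
def schedule_proposals(accepts):
    """Decide, step by step, whether an MTP draft may be proposed.

    After a rejected draft, the freshest MTP logits belong to the
    rejected token's position, so the next step skips proposing and
    lets a normal decode produce fresh logits first.
    """
    cooldown = False
    proposed = []
    for accepted in accepts:
        if cooldown:
            proposed.append(False)  # stale logits: skip this step
            cooldown = False        # fresh logits exist afterwards
        else:
            proposed.append(True)
            if not accepted:        # a rejection taints the logits
                cooldown = True
    return proposed

# A run of accepted drafts proposes every step; a rejection only
# costs the single following step:
print(schedule_proposals([True, True, True]))   # [True, True, True]
print(schedule_proposals([False, True, True]))  # [True, False, True]
```

Under this variant the ~1.33x ceiling of the always-cooldown version applies only to rejection-heavy runs.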
Status Update — MTP Attention + FastMTP

Major rework since last update. Squashed all intermediate commits into a single clean commit. What changed:
Key finding: DeltaNet + speculative decode is fundamentally hard

The recurrent state in DeltaNet accumulates all previous tokens. Unlike a KV cache, you can't partially roll it back. The two-phase decode approach (decode sampled → verify → decode draft only if accepted) produces correct output but halves throughput, since each accepted step requires 2 decode calls.

Open questions for reviewers:
Happy to split this into smaller PRs if that helps review (e.g., separate the recurrent state fixes from the MTP graph builder).
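The two-phase order (decode the sampled token, verify, then decode the draft only if accepted) can be sketched as a control loop. The list-based `state` and the `verify` callback are stand-ins for llama.cpp's decode and verification machinery, not real API:

```python
def two_phase_step(state, sampled_tok, draft_tok, verify):
    """One speculative step that never feeds a rejected draft into
    the recurrent state (which, unlike a KV cache, cannot be rolled
    back once a token has been absorbed).

    Phase 1: decode only the committed (sampled) token.
    Phase 2: decode the draft only after verification accepts it.
    """
    state.append(sampled_tok)        # phase 1: always safe
    if verify(draft_tok):
        state.append(draft_tok)      # phase 2: draft accepted
        return 2                     # two tokens, two decode calls
    return 1                         # draft dropped, state untouched

state = []
print(two_phase_step(state, 11, 42, lambda t: True))   # 2
print(two_phase_step(state, 7, 99, lambda t: False))   # 1
print(state)  # [11, 42, 7] -- the rejected 99 never touched the state
```

The throughput cost is visible here: an accepted step returns 2 tokens but needed 2 decode calls, which is why this mode trades speed for guaranteed-clean recurrent state.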
Add Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 dense models (0.8B-27B). The MTP head uses a full transformer block (attention + FFN) to predict the next-next token, enabling ~28 tok/s on an RTX 5060 Ti.

Key changes:
- Model loading: Qwen3.5 MTP layer tensors (nextn.eh_proj, attention weights, FFN) loaded into layers[n_layer-1]
- Graph builder: full MTP head with self-attention, gated RoPE, FFN, and vocabulary projection. Unfiltered hidden state passed for proper KV cache population during prompt processing.
- FastMTP: vocabulary trimming from 248K to 32K tokens via ggml_view_2d on the lm_head. Reduces draft generation from 22ms to 6ms (3.7x).
- Speculative framework: MTP auto-detection for hybrid models, fuzzy seq_rm checkpoint matching for DeltaNet rollback.
- Server: two-phase decode option for hybrid/recurrent models to avoid DeltaNet state corruption from rejected drafts.
- Recurrent state: fixed copy_cell (ggml_view_1d takes an element count, not bytes) and buffer assignment for no_alloc views.

Results on Qwen3.5-9B Q4_K_M (RTX 5060 Ti 16GB):
- 28.1 tok/s with 82% acceptance rate (temp=0)
- 92% acceptance with two-phase decode (correct output, 15 tok/s)
- Draft generation: 6.1ms with FastMTP (vs 22.4ms full vocab)
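A side note on why the FastMTP trim is essentially free: lm_head rows are stored contiguously per vocab id, so restricting logits to the first 32K ids is a bounds change rather than a copy, which is what the ggml_view_2d view provides. A plain-Python stand-in with toy sizes (4-wide embeddings, 8-token vocab, trimmed to 3):

```python
# lm_head laid out row-major: one n_embd-wide row per vocab token.
N_EMBD, N_VOCAB, K = 4, 8, 3            # toy sizes (real: 248K -> 32K)
w = [float(i) for i in range(N_VOCAB * N_EMBD)]  # fake weights

def logits(hidden, n_vocab):
    """Project the hidden state against the first n_vocab rows only;
    the 'trimmed' head is the same buffer with a smaller bound."""
    return [sum(w[v * N_EMBD + j] * hidden[j] for j in range(N_EMBD))
            for v in range(n_vocab)]

h = [1.0, 0.0, 0.0, 0.0]
full = logits(h, N_VOCAB)
trim = logits(h, K)
print(trim == full[:K])  # True: token ids < K keep identical logits
```

Since the trimmed logits for ids below the cutoff match the full head exactly, any draft token drawn from the trimmed vocab can still be verified against the full head without translation.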
Relationship to #18886 (MTP API) and #15225 (GLM MTP)

This PR takes a different approach from @ngxson's MTP API design in #18886 — we use a single context with the MTP head inline in the main compute graph rather than separate contexts. This was necessary because Qwen3.5's hybrid DeltaNet architecture makes state copying between contexts complex (the recurrent state can't be trivially transferred). However, the core contributions here are architecture-independent and would benefit any MTP implementation for hybrid/recurrent models:
Happy to refactor to align with the #18886 API once it stabilizes. The DeltaNet-specific handling would need to be integrated regardless of API design — it's a fundamental requirement for speculative decoding on any recurrent/hybrid architecture (Mamba, RWKV, DeltaNet, etc.). If it helps, I can split the recurrent state fixes into a separate PR that's useful independently of MTP.
Please note: this PR is still very much a WIP, and I do not expect it to be merged any time soon. What I do hope is that it provides some detail on the direction llama.cpp should take in adding MTP support. Since I do not have the resources to fully test, I have no doubt that there are bugs, that it needs refactoring, and so on. What I sincerely ask of the reviewers and the community is to take a look at the work done here and see whether you can find a better, more suitable solution. This work is motivated by ATLAS — MTP support would enable meaningful speedups for local setups running the Qwen3.5 family of models. Best, Isaac :)
Hi @itigges22, thanks for this PR — I've been testing it on a Mac M4 (Metal) and an RTX 3090 (CUDA) with Qwen3.5-9B Q4_K_M. The MTP implementation works well: no crashes, deterministic output, correct multi-turn behavior. However, I'm seeing a 63.5% draft acceptance rate at temperature=0, compared to the 82% reported in the PR description. My measurement is consistent across multiple requests.

To help reproduce your 82% number, could you share:
My setup:
Thanks!
Actually, re-reading the PR more carefully, I see:
So you used mixed precision — Q4_K_M for the base model with F16 for the MTP head weights. My conversion quantized everything uniformly to Q4_K_M (including the MTP head), which would explain the lower acceptance rate. Could you share how you produced the mixed-precision GGUF? Specifically:
This would help reproduce the 82% number and would also be useful context for anyone else testing MTP. |
@petter-b Apologies for the delay! To answer your question: I converted the model myself from the raw HuggingFace weights using the standard convert_hf_to_gguf.py script that comes with llama.cpp. This script reads the HF model files and outputs a GGUF, and I then ran llama-quantize on it to get Q4_K_M. No special steps were taken to preserve the MTP head in F16; the entire model, including the MTP layers, was quantized uniformly to Q4_K_M. The F32 tensors you'd see in the GGUF are just norm/embedding layers that llama-quantize leaves in F32 by default. I should also mention that after more testing, the acceptance rate has been inconsistent — I've seen it range anywhere from 40% to 82%, and replicating the higher numbers consistently has proven difficult. It's been one of the harder things to track down on my end, so I wouldn't treat the 82% as a reliable baseline just yet. Also note that I do not have the compute to attempt any of this in FP16, so my methodology and approach were grounded in the limited compute that I have!
The PR ggml-org#20075 / ggml-org#20700 cherry-picks brought in three diagnostic fprintf sites that print on every recurrent find_slot and every spec-decode draft verification. They flood stderr (~8K lines per 8K-fill bench) and would dominate the spec-decode hot path.

Stripped:
- find_slot entry print (src/llama-memory-recurrent.cpp)
- spec verify per-token print (common/sampling.cpp)
- spec verify bonus-token print (common/sampling.cpp)

Signed-off-by: David Connolly <david@connol.ly>
@itigges22 Very solid work, and well architected. I did a plow-through version myself a few days ago and was checking whether there was an open issue, so I'll piggyback on yours. I can vouch for the acceptance rate (it matches my tests) as well as the variability, which I think I've mostly solved. There's an additional mechanism I have working in some configurations where you can apply some further techniques. I did run into a number of upstream bugs that just haven't been hit before; I really don't know if anyone cares to fix them, but I'll post what I have, and anyone should feel free to follow up with requests, feedback, corrections, etc.
Here's the work: https://github.com/quivent/qwen-mtp-llamacpp

Additional optimizations:
This one is the killer: adaptive chained MTP prediction. After a prediction, compute the confidence score; this on its own can reach up to 96% confidence and justify a second pass with a higher probability of an accurate prediction, bumping the performance numbers to about 2x TPS with no quality loss. I'm actually still exploring optimizations here, including:
This is the first time I'm posting open-source optimization work, mostly because life is hard and optimization is less difficult for me, and it took me until last week to realize that I should be publishing these for others. I also don't mind requests; I enjoy this quite a bit.

Some Qwen lore uncovered while hunting for the missing tensors:

Please don't hammer me if I've done anything "out of policy": this is my first contribution, it has multiple optimizations, and it is absolutely done out of protocol. I hope you all can let the one offset the other.
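The adaptive chained prediction described above can be sketched as follows; using the max softmax probability as the confidence score, the 0.9 threshold, and all names here are my assumptions, not the linked repo's actual code:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def chained_draft(mtp_logits, first_logits, max_chain=2, threshold=0.9):
    """Greedily draft tokens, chaining an extra MTP pass only while
    the head's confidence (max softmax probability) clears the
    threshold; low confidence stops the chain and defers to verify."""
    draft = []
    logits = first_logits
    for _ in range(max_chain):
        probs = softmax(logits)
        conf = max(probs)
        draft.append(probs.index(conf))  # greedy draft token
        if conf < threshold:
            break                        # not confident: no chaining
        logits = mtp_logits(draft[-1])   # run the MTP head again
    return draft

# A confident head chains to the cap; an unsure one stops at 1 token.
print(chained_draft(lambda t: [0.0, 9.0, 0.0], [0.0, 9.0, 0.0]))  # [1, 1]
print(chained_draft(lambda t: [0.0, 9.0, 0.0], [0.1, 0.2, 0.1]))  # [1]
```

The design idea is that longer chains are only worth the extra verify cost when the head's own distribution is sharply peaked, which is what keeps quality lossless.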
Report: 27B output corruption on this branch (M3 Ultra, Metal)

Built this PR branch. Symptom: output is word salad, with word ordering corrupted — not a sampling issue; the tokens themselves are in the wrong positions. Example:

Observed in both modes:
Plain llama.cpp (main branch) presumably wouldn't have this. MLX-LM 4-bit on the same merged BF16 weights generates clean, coherent output at 13.3 tok/s, so the weights themselves are fine. I see validation has only been done on 9B, so this looks like it needs 27B-specific handling (different MTP head shape? different vocabulary trim?). Happy to help reproduce / test fixes — the full pipeline + Q6_K GGUF repro is at https://github.com/AImindPalace/mac-studio-mlx-serving. The model is a private fine-tune, but swapping to base
Feel free to dive deeper into it to see what the solution may be, or into any documentation on bugs related to what you are seeing. I am also unable to fully test with larger models due to my compute constraints, so any help with expanded testing and development is welcome. @quivent, would you want to take a look?
Change >= to > in the batch position check to allow re-evaluation starting at the same position as the last stored position. This is needed for speculative decoding, where rejected draft tokens cause re-eval from the last accepted position. Matches the fix in upstream PR ggml-org#20700.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
Adds Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 dense models (0.8B-27B). These hybrid DeltaNet/attention models have a built-in MTP head that predicts the next-next token, enabling speculative decoding without a separate draft model.
Key features:
- copy_cell fix (element count vs byte count in ggml_view_1d), fuzzy seq_rm checkpoint matching

Results (Qwen3.5-9B Q4_K_M, RTX 5060 Ti 16GB VRAM):
With two-phase decode (guaranteed correct output):
Architecture
Qwen3.5 uses a repeating pattern of 3 DeltaNet (linear attention) + 1 full attention layers. The MTP head is a single full-attention transformer block that:
- eh_proj

The DeltaNet recurrent state cannot be partially rolled back (unlike a KV cache), so rejected drafts corrupt the state. The two-phase decode option handles this by only decoding accepted drafts.
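The 3:1 interleaving described above can be written as a one-line predicate. Whether layer index 0 starts a DeltaNet run is my assumption for illustration (the real model reads per-layer types from the GGUF metadata), so treat the phase as a sketch:

```python
def is_deltanet_layer(il: int) -> bool:
    """Under a repeating [DeltaNet, DeltaNet, DeltaNet, attention]
    pattern, every 4th layer (indices 3, 7, 11, ...) is full
    attention; the rest are DeltaNet (linear attention)."""
    return il % 4 != 3

pattern = "".join("D" if is_deltanet_layer(i) else "A" for i in range(8))
print(pattern)  # DDDADDDA
```

This is why the graph builder has to branch per layer: only the attention layers (and the MTP head, itself a full-attention block) touch the KV cache, while the DeltaNet layers update the recurrent state.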
Files changed:
- src/models/qwen35.cpp — MTP head graph builder with attention + FastMTP
- src/llama-memory-recurrent.cpp — copy_cell fix, seq_rm fuzzy matching
- src/llama-model.cpp — MTP tensor loading, rs_size config
- src/llama-context.cpp/h — MTP logits extraction, reduced-vocab tracking
- common/speculative.cpp — MTP state machine, FastMTP vocab support
- tools/server/server-context.cpp — two-phase decode for hybrid models
- include/llama.h — llama_get_mtp_n_vocab() API
- convert_hf_to_gguf.py — Qwen3.5 MTP tensor conversion support

Testing
Tested on:
- --jinja --embeddings --parallel 1

What's needed for merge: