Skip to content

turboquant_kv_cache_updated_v10#32

Merged
InfernalDread merged 170 commits intoInfernalDread:turboquant_kv_cache_updated_v10from
Addy-ad:addyad-latest
Apr 21, 2026
Merged

turboquant_kv_cache_updated_v10#32
InfernalDread merged 170 commits intoInfernalDread:turboquant_kv_cache_updated_v10from
Addy-ad:addyad-latest

Conversation

@InfernalDread
Copy link
Copy Markdown
Owner

No description provided.

TheTom and others added 30 commits March 26, 2026 12:15
New types: GGML_TYPE_TURBO3_0 (3-bit) and GGML_TYPE_TURBO4_0 (4-bit)
Implements PolarQuant + QJL compression per the ICLR 2026 paper.

Block size = 128 (matching head_dim for optimal rotation Gaussianization)
turbo3: 52 bytes per 128 values = 3.25 bits/value (4.9× vs fp16)
turbo4: 68 bytes per 128 values = 4.25 bits/value (3.8× vs fp16)

Status:
- ✅ Type definitions in ggml.h
- ✅ Block structures in ggml-common.h
- ✅ Quantize/dequantize C implementation in ggml-turbo-quant.c
- ✅ Registered in ggml.c type traits
- ✅ Added to kv_cache_types in arg.cpp
- ✅ Builds successfully
- ✅ Shows in --help output
- ❌ Metal SET_ROWS kernel not implemented (blocks GPU inference)
- ❌ Needs Metal dequantize kernels for attention computation

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>


Added Metal shader implementations:
- quantize_turbo3_0 / quantize_turbo4_0 (per-block quantization)
- dequantize_turbo3_0 / dequantize_turbo4_0 (type4x4 and type4 variants)
- kernel_set_rows_turbo template (128-element block size)
- Flash attention instantiations for all dk/dv variants

Added TURBO3_0/TURBO4_0 to Metal device SET_ROWS validation.

Builds successfully. Testing with Qwen 3.5 35B-A3B MoE on M5 Max.

Note: Initial version uses simplified quantization (no rotation matrix)
for Metal compatibility. Full rotation requires custom kernel with extra
buffer bindings — tracked for follow-up.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…g#21

Embedded pre-computed 128×128 rotation and QJL matrices (256KB constant
memory) directly in the Metal shader. Both quantize and dequantize now
perform the full TurboQuant algorithm:

Quantize: normalize → rotate → codebook → inverse rotate → residual → QJL
Dequantize: codebook → inverse rotate → QJL correction → rescale

Previous version (no rotation) produced garbage. This should produce
meaningful output since the rotation Gaussianizes the KV distribution.

Note: dequantize does full 128-element rotation per chunk (8× work).
Optimization possible with caching or restructured kernel in follow-up.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ml-org#21

- Inlined turbo-matrices.h directly into ggml-metal.metal (256KB)
  to fix JIT compilation failure with #include
- Added C round-trip test (test-turbo-quant.c):
  turbo3 cosine=0.906, turbo4 cosine=0.966 — matches Python prototype
- Metal library loads successfully ("loaded in 5.9 sec")
- Model runs on Metal but output quality needs debugging
  (Metal quantize/dequantize may have a bug vs the working C version)

C round-trip PROVES the algorithm works in C. Metal shader needs
debugging — likely an issue with the dequantize chunk addressing
or the large constant arrays in thread-local memory.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…org#23

Codex review found:
1. Stale duplicate code in dequantize_turbo3_0_t4 (compile would fail)
2. thread static is risky/non-portable in MSL

Fixed: removed thread static caching, using plain thread locals.
Speed unchanged (2.4 tok/s) — the static caching wasn't actually working
on Metal. True optimization needs architectural change in flash attention
kernel to dequantize once per block, not per chunk.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…gml-org#26

Massive reduction in constant memory and compute:
- 256KB of dense matrices → 512 bytes of sign arrays
- O(d²) = 16,384 ops → O(d log d) = 896 ops per rotation
- Metal shader file: 1.5MB → 432KB

Speed: still 2.4 tok/s. WHT reduced per-rotation cost but the
bottleneck is redundant calls (8-32× per block from flash attention).
The dequantize function is called per 4/16-element chunk, each time
doing the full 128-element WHT. Need to modify the flash attention
kernel to dequantize once per block.

Quality: WHT+signs gives BETTER quality than dense QR on real KV
tensors (cosine 0.94 vs 0.79 at 2-bit). Sub-Gaussian distribution
(kurtosis 1.53) means fewer outliers hitting extreme centroids.

Reviewed by Codex: WHT butterfly correct, inverse order verified,
QJL correction matches reference C implementation.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…gml-org#23

Root cause analysis: 8-32× redundant full-block dequantize per block
from flash attention template. Four approaches documented with expected
speedups and risk levels.

Plan: D (reduce overhead) → A/B (eliminate redundant calls)
Target: 2.4 tok/s → 20-40 tok/s

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…-org#23

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…gml-org#23

No-op dequant test: even returning all zeros from dequantize, turbo3
runs at 2.4 tok/s (same as with full WHT rotation). The bottleneck is
NOT in the attention dequantize path.

New hypothesis: the SET_ROWS (quantize) path is the bottleneck. The
Metal quantize_turbo3_0 function does 3 WHT rotations per KV write,
totaling ~3200 ops per block × 224 blocks per token.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…rg#23

CRITICAL BUG: The #include "turbo-wht.h" caused Metal JIT compilation
to fail at runtime. The model silently fell back to CPU for ALL ops.
ALL previous benchmarks (2.4 tok/s) were measuring CPU, not Metal GPU.

After inlining the header:
- MoE gen: 2.4 → 10.7 tok/s (4.5× improvement, now actually on Metal)
- MoE prompt: 4.2 → 60.9 tok/s (14.5× improvement)

Remaining gap vs q8_0: 85 → 10.7 tok/s (8× slower, down from 35×)

This is the SAME bug we hit with turbo-matrices.h earlier.
Rule: NEVER use #include in ggml-metal.metal — always inline.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…org#23

Previous 2.4 tok/s was CPU fallback. Real Metal numbers:
MoE: 10.7 tok/s gen (8× slower than q8_0, was thought to be 35×)
Qwopus: 5.3 tok/s gen (3.3× slower than q8_0)

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…l-org#27

Full investigation log with all tests, results, and the root cause.
Upstream TurboQuant activity tracked in ggml-org#27.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…-org#28

Key findings from Dejan.ai, unixsysdev, and mudler:
1. QJL naively added back destroys quality (cosine 0.69)
2. Pre-rotate queries eliminates rotation from dequant path
3. WHT abandoned by everyone — dense QR or no rotation preferred
4. unixsysdev gets -0.8% speed loss with fused CUDA kernel
5. We're the only Metal implementation

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…in) ggml-org#23

Removing WHT rotation from dequant (quality broken, speed test only):
  gen: 10.7 → 49.1 tok/s (4.6× improvement, 57% of q8_0)
  prompt: 67.3 → 162.6 tok/s

Confirms pre-rotate-queries would deliver ~49 tok/s.
Remaining gap (49 vs 85) is block size + QJL overhead.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Speed ceiling confirmed: stripping rotation from dequant gives 49.1 tok/s
(vs 10.7 with rotation, vs 85.5 q8_0 baseline).

Implementation plan: store rotation matrix in KV cache, apply to Q in
graph builder, strip from Metal dequant. 6 files to modify.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…org#23

Instead of inverse-rotating every K during dequant, rotate Q once
before attention. Math: <q, R^T*c[idx]> = <R*q, c[idx]>.

Changes:
- Store rotation matrix (R^T) in KV cache, filled after buffer clear
- Apply ggml_mul_mat(R_T, q) in build_attn_mha after permute
- Strip turbo_rotate_inverse from Metal dequant
- Dynamic cast to access rotation from mctx

Results:
- MoE gen: 10.7 → 51.4 tok/s (4.8× speedup)
- MoE prompt: 67.3 → 160.3 tok/s (2.4× speedup)
- Now at 60% of q8_0 speed with 4.9× compression
- Model produces coherent output

Codex review: fixed buffer clear ordering (was zeroing rotation after init).
Verified: rotation point is correct (after 4d reshape + permute, ne[0]=128).

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…gml-org#23

Full investigation log documenting every test, every dead end, and every
breakthrough. 21× total improvement from CPU fallback to pre-rotate-queries.

Key lessons: no #include in Metal, no-op testing, pre-rotate-queries,
buffer clear ordering, codex+roast catch real bugs.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Validated on real Qwen3 KV tensors: cosine sim 0.9508 → 0.9831 (+3.2%)
MSE-only better on 99.3% of vectors including p1 tails.

3-bit index split: lower 2 bits in qs[], upper 1 bit in signs[].
No QJL stage in quantize or dequant.

Results:
- MoE gen: 51.4 → 62.2 tok/s (73% of q8_0, was 60%)
- MoE prompt: 160 → 200 tok/s (90% of q8_0)
- Qwopus gen: 14.6 → 15.5 tok/s (88% of q8_0, was 83%)
- Qwopus prompt: 67 → 83 tok/s (100% of q8_0!)

Codex verified: bit packing correct, quantize/dequant consistent.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Speed ceiling without Q rotation: 61.3 tok/s (vs 62.2 with it).
The 128×128 ggml_mul_mat adds <1% overhead on Metal.

Remaining gap is structural (block size + dequant complexity).
Final: MoE 62.2 tok/s (73%), Qwopus 15.5 tok/s (88%).

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diagnostic benchmark proves the 26% gap is entirely from block size 128.
q4_0 (block 32, 4-bit quantization) runs at 84.2 tok/s = identical to q8_0.

Next: turbo3 with block size 32.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Changed QK_TURBO3 from 128 to 32 (storage block size).
Rotation still operates on 128-element groups (QK_TURBO3_GROUP=128).
SET_ROWS kernel processes 4 blocks per rotation group.
Flash attention nl_k changed from 32 to 8 (matching q4_0).

Block struct: 14 bytes per 32 values = 3.5 bits/val → 4.6× compression.

Results:
- MoE gen: 62.2 → 77.7 tok/s (91% of q8_0 at 85.5)
- MoE prompt: 200 → 218.5 tok/s (98% of q8_0)
- Qwopus gen: 15.5 → 17.0 tok/s (97% of q8_0 at 17.6)
- Qwopus prompt: 83 → 89.5 tok/s (108% of q8_0 — FASTER)

Target was 75+ tok/s. Exceeded.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Codex post-commit review found:
1. TURBO_D was QK_TURBO3 (now 32) — broke turbo4 C array sizes
2. SET_ROWS kernel turbo3-specific but instantiated for turbo4
3. Tail block drop for non-128 head dims

Fixed ggml-org#3 (TURBO_D). ggml-org#1 and ggml-org#2 don't affect turbo3+dk128 path.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…l-org#30

Perplexity benchmarking reveals catastrophic quality failure:
- f16: 6.121, q8_0: 6.111, q4_0: 6.142
- turbo3: 165.6 (27× worse)

Speed benchmarks were meaningless — fast garbage.
Root cause investigation needed before any quality claims.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1. V cache returns rotated-space values (cosine=0.02 vs correct 0.987)
2. dynamic_cast to llama_kv_cache_context fails for MoE models
   (uses llama_memory_hybrid_context, not kv_cache_context)
   → Q rotation and V inverse rotation NEVER executed

Fix: store rotation tensors in llm_graph_context, not KV cache.
Or access through hybrid memory interface.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…gml-org#31

Block 128: PPL=165.6 (same as block 32)
Disabled Q rotation: PPL=165.6 (same)
Root cause: dynamic_cast fails for MoE hybrid memory context.
Q rotation and V inverse rotation never execute.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ml-org#31 ggml-org#30

ROOT CAUSE: pre-rotate-queries never executed because:
1. Q ne[0]=256 (GQA concatenated heads), rotation matrix ne[0]=128
2. mctx dynamic_cast failed for MoE hybrid memory

FIX: put inverse WHT rotation back in dequantize_full_block.
This is slower (10.7 tok/s vs 77.7) but produces CORRECT results.

PERPLEXITY RESULTS:
- f16:     6.121
- q8_0:    6.111
- q4_0:    6.142
- turbo3:  6.194 (+1.2% vs q8_0) ✅

The speed optimization (pre-rotate-queries) needs to be reimplemented
to work with GQA head layout and hybrid memory types.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Quality confirmed: PPL 6.194 (+1.4% of q8_0)
Speed: 10.7 tok/s (inverse rotation in dequant, no pre-rotate-queries)
Previous speed claims (51-77 tok/s) were invalid — measured garbage output speed.

Key lessons documented for future reference.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Addy-ad added 20 commits April 10, 2026 02:40
feat: add Turbo/Iso/PlanarQuant with Windows MSVC compatibility

- Added TurboQuant, IsoQuant, PlanarQuant quantizations from:
  https://github.com/johndpope/llama-cpp-turboquant
- Integrated features from feature/planarquant-kv-cache branch.
- Added Windows/MSVC compatibility fixes:
    - Resolved M_PI definition issues.
    - Fixed InnerQ linker symbols.
    - Optimized std::transform loops for stable Windows builds.
- Reference: https://www.reddit.com/r/LocalLLaMA/comments/1sbdihw/gemma_4_31b_at_256k_full_context_on_a_single_rtx/
@InfernalDread InfernalDread merged commit 14ddbbf into InfernalDread:turboquant_kv_cache_updated_v10 Apr 21, 2026
16 of 57 checks passed
InfernalDread pushed a commit that referenced this pull request Apr 23, 2026
…n data

Part of #32: turbo3 prefill degrades relative to q8_0 with context length.

Changes so far:
- Skip ggml_cont when tensors already contiguous (+1%, minimal)
- Generated 32x32 rotation matrices (turbo-rotation-data-32.h) for
  reduced group size approach (16x less matmul compute)
- Fixed V un-rotation to check v->type not k->type

Next: update QK_TURBO3_GROUP, Metal WHT kernel, and KV cache for d=32.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants