
fix: 8-bit dequant for MLX mixed-precision gate quantization#14

Open
userFRM wants to merge 1 commit into danveloper:main from userFRM:fix/8bit-gate-dequant

Conversation

userFRM commented Mar 23, 2026

Problem

MLX 4-bit quantized models use 8-bit precision for routing gates, specified per-tensor in config.json:

"quantization": {
    "bits": 4, "group_size": 64,
    "model.layers.0.mlp.gate": {"group_size": 64, "bits": 8},
    "model.layers.0.mlp.shared_expert_gate": {"group_size": 64, "bits": 8}
}

The 4-bit dequant kernel extracts 8 nibbles per uint32, but these tensors pack 4 bytes per uint32 (8-bit). This corrupts the routing gate scores, so the wrong experts are selected at every layer and the model produces nonsensical output.

Verification

Compared gate output against MLX Python reference for mlx-community/Qwen3-Coder-Next-4bit:

  • Without fix: gate scores have wrong magnitudes and signs (RMS 1.2 vs MLX 6.8)
  • With fix (CPU path): gate scores match MLX exactly (same top expert indices, same score range)

Forced full CPU computation (g_metal = NULL) confirmed coherent output: "2 + 2 = 4", correct code generation, proper EOS handling.

Changes

shaders.metal: Added dequant_matvec_8bit kernel — same tiled ROWS_PER_TG=8 structure as dequant_matvec_4bit_v3, but extracts 4 bytes per uint32 with & 0xFF instead of 8 nibbles with & 0xF. FMA-optimized with precomputed scale*x and bias*x.

infer.m:

  • Added int bits field to BatchMatvecSpec for per-tensor bit-width dispatch
  • Added matvec_8bit pipeline state to MetalCtx
  • Added cpu_dequant_matvec_8bit CPU fallback
  • Updated gpu_encode_batch_matvec and gpu_batch_matvec to select 8-bit kernel when bits == 8
  • Marked gate_w and seg_w (shared_expert_gate) as bits=8 in all 7 BatchMatvecSpec initialization sites

Impact

Affects any MLX quantized model with per-tensor bit-width overrides in the quantization config. This is standard for Qwen3 family models.

Fixes #10

Test plan

  • ./infer --prompt "Hello" --tokens 20 --k 4 produces coherent output
  • Gate scores match MLX reference (--timing shows reasonable routing)
  • No regression on 2-bit mode (--2bit)
  • Shader compiles on M1/M2/M3/M4

MLX 4-bit models quantize routing gates (mlp.gate, mlp.shared_expert_gate)
at 8-bit precision, specified per-tensor in config.json. The inference
engine treated all tensors as 4-bit, extracting 8 nibbles per uint32 from
data that actually packs 4 bytes per uint32. This corrupts routing scores,
selecting wrong experts and producing nonsensical output.

Changes:
- Add dequant_matvec_8bit Metal kernel (4 bytes/uint32, FMA-optimized)
- Add cpu_dequant_matvec_8bit CPU fallback
- Add BatchMatvecSpec.bits field for per-tensor bit-width dispatch
- Mark gate and shared_expert_gate as 8-bit in all dispatch sites

Fixes danveloper#10

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
rrr3try pushed a commit to Graf-RAGov/flash-moe-mlx that referenced this pull request Apr 17, 2026
…-e base

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
rrr3try pushed a commit to Graf-RAGov/flash-moe-mlx that referenced this pull request Apr 17, 2026
Upstream + fork + issue context compiled for the port effort: PR diffs
(danveloper#3 runtime config, danveloper#11 perf wins, danveloper#13 Qwen3-Coder-Next, danveloper#14 8-bit dequant),
fork summaries (nerds-odd-e, gorroai), issue captures (danveloper#15 setup gotchas,
danveloper#17 expert_index scope bug, danveloper#20 other Qwen models), target architecture
spec (qwen3.6-35b-a3b-arch.md), hardcoded-constants map of upstream
flash-moe, condensed port plan. Plus benchmark results, parallelism
exploration, 10x optimization ideas.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
AndrewFarley added a commit to AndrewFarley/flash-moe that referenced this pull request Apr 29, 2026
Applied PRs danveloper#5 and danveloper#14:

Dashboard:
- ncurses-based htop-style terminal monitor (dashboard.c)
- Reads /tmp/flash-moe-stats.json written by inference server
- Real-time status, progress bars, TTFT, tok/s, rolling averages

Serve loop improvements:
- SSE streaming with per-token delta events
- Non-streaming JSON response mode (stream: false)
- Tool call parsing from <tool_call> blocks in model output
- Full OpenAI messages array parsing for generic clients
- Dashboard stats reporting (server state, prefill progress, generation)
- GPU KV buffer increased to 32K pre-allocation
- CPU 2-bit expert forward path for fallback compute
- CMD1+CMD2 merge optimization for linear attention layers
- select() loop for idle stats updates, GET /stats endpoint

8-bit gate dequant (PR danveloper#14):
- dequant_matvec_8bit Metal kernel (FMA-optimized)
- cpu_dequant_matvec_8bit CPU fallback
- BatchMatvecSpec.bits field for per-tensor bit-width dispatch
- Auto-detection of gate quantization from config.json
- Gate bits applied dynamically (4-bit or 8-bit based on model)

Additional fixes:
- BPE byte marker decoding in SSE output (Ġ→space, Ċ→newline)


Development

Successfully merging this pull request may close these issues.

Nonsensical output on Apple M4 Pro (Mac Mini 64GB) — 14.5 tok/s but garbage generation
