
llama.cpp-dgx

Fork of ggml-org/llama.cpp optimized for NVIDIA DGX Spark / GB10 (Blackwell, SM 12.1).


Why this fork

llama.cpp-dgx is a runtime for hybrid Qwen3.5/3.6 / Qwopus 27B-class models on a single GB10 (DGX Spark, 128 GB unified memory). It composes five upstream-or-near-upstream tracks that do not yet land together in ggml-org/llama.cpp, plus a small number of Blackwell-specific tweaks. Verified against upstream/master at 0adede866 (re-merge cadence: weekly).

The five tracks (illustrative sketches of each follow the list):

  1. TurboQuant on weights — TQ3_0 / TQ3_4S / TQ3_1S 3-bit weight quantization with Lloyd-Max codebooks. Imported from @turbo-tan / llama.cpp-tq3 (62eb27dce baseline) — see ggml/src/ggml-turbo-quant.c. Used here to ship Qwopus3.6-27B-v1-Abliterated-preview at ~14 GiB / ~3.5 bpw with PPL parity to Q3_K_S.

  2. TurboQuant on KV cache — turbo2_0 / turbo3_0 / turbo4_0 and turbo3_tcq / turbo2_tcq (Trellis-Coded Quant) types for the K/V cache, with FWHT (Fast Walsh–Hadamard Transform) rotation matrices baked into the FA kernels. Imported from @spiritbuun / buun-llama-cpp — see ggml/src/ggml-cuda/fattn-common.cuh and the d_turbo_centroids_*_fattn codebooks. tq3_0 K+V on the standard llama path lands ~22% smaller KV vs Q4_0 with no measurable decode regression on GB10 (matches the upstream PR numbers).

  3. NVFP4 (FP4 tensor cores) inference — native NVFP4 matmul + per-tensor scale2 application after the kernel, tracking the WIP upstream PRs (#21089, #20977). Loader path supports plain NVFP4 (NVIDIA ModelOpt NVFP4_DEFAULT_CFG); the AWQ variant (NVFP4_AWQ_LITE_CFG) is intentionally not used because llama.cpp does not apply the AWQ .pre_quant_scale channel-wise factor at inference and therefore returns garbage tokens when the model is exported with the AWQ recipe. The dflash custom target graph (see below) also applies the per-tensor scale2 after every ggml_mul_mat so NVFP4 + speculative decoding work end-to-end. The matmul itself uses Blackwell's native mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64 PTX (see ggml/src/ggml-cuda/mma.cuh's mma_block_scaled_fp4) — no dequantize-to-bf16 round-trip — so on AEON-XS NVFP4 the prefill rate is essentially VRAM-bandwidth-bound, not compute-bound.

  4. DFlash MTP speculative decoding — block-diffusion draft + DDtree verify integration, ported from Luce-Org / lucebox-hub (tools/dflash-cli/). Wired into llama-server so that the dflash custom target graph runs in place of llama_decode for text-only requests, while mmproj (vision) requests fall back to the standard path. Includes causal sliding-window-attention support for the Qwen3.6-27B-DFlash draft (4 SWA layers + 1 full-attention layer), and a borrow path that lets dflash share the host llama_model's on-GPU weight tensors instead of re-uploading them — saves ~18.5 GiB of VRAM when --mmproj is set.

  5. Chunk-fused GatedDeltaNet kernel for prefill — from-scratch, GB10-tiled (sm_120/121) replacement for FlashQLA on the GDN forward path. FlashQLA is Hopper-targeted and needs 192 KB shared memory per CTA, exceeding the 99 KB sm_121a opt-in cap; this fork therefore ships a 4-kernel pipeline (cumsum → kkt_solve → prepare_h → fused_fwd) sized for the Blackwell consumer budget, all wmma 16×16×16 bf16 with fp32 accumulators on Blackwell tensor cores. Active for prefill ubatches with n_tokens >= 64 (chain mode, S_v = 128); falls back to the per-token kernel for decode, tree mode, or KDA. Bit-equivalent output to the per-token recurrence on greedy decode (verified by sending the same prompt through both paths with temperature=0 and getting byte-identical 40-token continuations). Microbench (B=1 T=192 H=16): 1537 → 162 us per call (9.5×); end-to-end prefill rate is at parity, since GDN is only ~1-2% of the total on AEON-XS NVFP4 (FFN is the limiter, see point 3). Quality benefit observed in production: the small token-level hallucinations that the buggy pre-fix scalar path emitted on long structured outputs are gone, because a fix made during tuning corrected a latent ~3% per-element drift in the original FlashQLA-style A_sol formulation. See docs/rfc-gdn-chunk-kernel.md. Env-var escape hatch: GGML_GDN_CHUNK_DISABLE=1 forces per-token.
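Illustrative sketches of the five tracks, in order. These are minimal, self-contained C++ sketches of the underlying techniques, not excerpts from the fork; every function, struct, and parameter name in them is invented for illustration.

Track 1, Lloyd-Max codebooks: at its core this is a 1-D k-means fit per weight block: assign each weight to its nearest codebook level, recompute each level as its cluster mean, iterate. The real TQ3 formats add block scales and packed 3-bit storage on top (see ggml/src/ggml-turbo-quant.c).

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <vector>

struct Codebook3b { std::array<float, 8> c; };  // 2^3 = 8 levels for 3-bit codes

static size_t nearest(const Codebook3b &cb, float x) {
    size_t best = 0;
    for (size_t i = 1; i < cb.c.size(); ++i)
        if (std::fabs(x - cb.c[i]) < std::fabs(x - cb.c[best])) best = i;
    return best;
}

// Fit the codebook to one block of weights, then emit a 3-bit index per weight.
std::vector<uint8_t> quantize_block(const std::vector<float> &w, Codebook3b &cb) {
    float lo = *std::min_element(w.begin(), w.end());
    float hi = *std::max_element(w.begin(), w.end());
    for (size_t i = 0; i < 8; ++i)           // init: spread levels over the range
        cb.c[i] = lo + (hi - lo) * (i + 0.5f) / 8.0f;
    for (int iter = 0; iter < 16; ++iter) {  // Lloyd iterations: assign, re-center
        std::array<double, 8> sum{};
        std::array<int, 8>    cnt{};
        for (float x : w) { size_t k = nearest(cb, x); sum[k] += x; cnt[k]++; }
        for (size_t k = 0; k < 8; ++k)
            if (cnt[k]) cb.c[k] = (float) (sum[k] / cnt[k]);
    }
    std::vector<uint8_t> idx(w.size());
    for (size_t i = 0; i < w.size(); ++i) idx[i] = (uint8_t) nearest(cb, w[i]);
    return idx;
}

Track 2, FWHT rotation: an orthonormal Hadamard rotation spreads outlier energy evenly across a vector before quantization, and with orthonormal scaling it is its own inverse, which is why the FA kernels can store K/V rotated and un-rotate on the read path.

#include <cmath>
#include <cstddef>

// In-place Fast Walsh-Hadamard Transform; n must be a power of two.
void fwht(float *v, size_t n) {
    for (size_t h = 1; h < n; h <<= 1)
        for (size_t i = 0; i < n; i += h << 1)
            for (size_t j = i; j < i + h; ++j) {
                float a = v[j], b = v[j + h];
                v[j]     = a + b;  // butterfly
                v[j + h] = a - b;
            }
    const float norm = 1.0f / std::sqrt((float) n);
    for (size_t i = 0; i < n; ++i) v[i] *= norm;  // orthonormal scaling
    // with this scaling, calling fwht() twice returns the original vector
}

Track 3, per-tensor scale2: the shape of the post-kernel correction is one ggml_scale node after each matmul. How the fork loads and stores the scale is not shown here; scale2 is just a float argument.

#include "ggml.h"

struct ggml_tensor * mul_mat_nvfp4(struct ggml_context * ctx,
                                   struct ggml_tensor  * w,        // NVFP4 weight
                                   struct ggml_tensor  * x,        // activations
                                   float                 scale2) { // per-tensor scale
    struct ggml_tensor * y = ggml_mul_mat(ctx, w, x);  // FP4 tensor-core matmul
    return ggml_scale(ctx, y, scale2);                 // post-kernel correction
}

Track 4, draft/verify: block speculative decoding in its simplest flat, greedy form commits the longest prefix of the draft block that matches the target's argmax, plus the target's own token at the first mismatch. DDtree verify generalizes this to a token tree under the --dflash-budget node budget.

#include <cstdint>
#include <vector>

typedef int32_t token_t;

// draft: proposed block; target_argmax: the target's greedy pick at each of
// the draft.size()+1 positions scored in one verify pass.
size_t verify_block(const std::vector<token_t> &draft,
                    const std::vector<token_t> &target_argmax,
                    std::vector<token_t>       &committed) {
    size_t n = 0;
    while (n < draft.size() && draft[n] == target_argmax[n]) {
        committed.push_back(draft[n]);      // accepted draft token
        ++n;
    }
    committed.push_back(target_argmax[n]);  // one free token from the target
    return n + 1;                           // "commits/step" in the benchmarks
}

Track 5, GDN dispatch: the chunked pipeline only runs for chain-mode prefill ubatches of at least 64 tokens, with GGML_GDN_CHUNK_DISABLE=1 as the escape hatch.

#include <cstdlib>

static bool gdn_use_chunked(int n_tokens, bool tree_mode, bool is_kda) {
    if (std::getenv("GGML_GDN_CHUNK_DISABLE")) return false;  // forced per-token
    if (tree_mode || is_kda) return false;  // unsupported by the chunked path
    return n_tokens >= 64;                  // decode ubatches fall through
}
// gdn_use_chunked() ? cumsum -> kkt_solve -> prepare_h -> fused_fwd
//                   : legacy per-token recurrence kernel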

Blackwell / GB10 specifics (custom vs upstream)

Models we ship & test against

  • Target: croll83/Qwopus3.6-27B-v1-Abliterated-preview — abliterated derivative of Jackrong/Qwopus3.6-27B-v1-preview, itself a Claude-distilled SFT on Qwen/Qwen3.6-27B (qwen35 hybrid arch: 16 full-attention + 48 GatedDeltaNet layers, ~28B params, 262K context). Repo ships BF16 safetensors, mmproj F16, and GGUFs (Q4_K_M, TQ3_4S, NVFP4-plain).
  • Draft (DFlash): z-lab/Qwen3.6-27B-DFlash — block-diffusion drafter, 5 layers (4 SWA + 1 full), block_size=16, BF16 safetensors. Required when running with --dflash.
  • Older draft: z-lab/Qwen3.5-27B-DFlash — non-causal, full-attention. Slightly lower accept rate vs the Qwen3.6 draft on Qwen3.6 targets but works without SWA support in the inference engine.

Install / build

Same as upstream — see docs/build.md. Quickstart for GB10:

git clone https://github.com/croll83/llama.cpp-dgx.git
cd llama.cpp-dgx
cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=121a-real
cmake --build build --target llama-server llama-cli llama-quantize llama-dflash llama-dflash-server -j 8

121a-real targets GB10 specifically. Use 120a-real for RTX 50-series consumer Blackwell, or native to auto-detect.

DGX-only flags reference

These flags / env vars exist only in this fork (or have changed semantics vs upstream):

| Flag / env | Where | What it does |
| --- | --- | --- |
| -ctk tq3_0 / -ctv tq3_0 | llama-server, llama-cli | Use TQ3_0 (3.5 bpw) for the standard llama_kv_cache K/V. Saves ~22% vs Q4_0; the fattn vec kernel handles the 256-stride alignment. K=TQ3_0 is also wired (fork commit 7b5f82569). |
| -ctk turbo3 etc. | (planned) | spiritbuun TurboQuant types for the KV cache. The type names are exposed in ggml.h (GGML_TYPE_TURBO2_0, TURBO3_0, TURBO4_0, TURBO3_TCQ, TURBO2_TCQ) but the F32→TURBO* CPY/set-rows wiring is still TODO — see docs/dflash_kv_quant_status.md. |
| --dflash | llama-server | Enable DFlash MTP speculative decoding. Replaces llama_decode with the dflash custom target graph for text-only requests; mmproj requests fall through to the standard path. |
| --dflash-draft PATH | llama-server | Path to the DFlash draft model. Accepts .safetensors (BF16, ~3.3 GiB, default) or .gguf (community Q8_0 quants like spiritbuun/Qwen3.6-27B-DFlash-GGUF, ~1.8 GiB). Q8_0 GGUF saves ~1.5 GiB VRAM but trades ~16% decode throughput and ~30% accept rate vs BF16 in our dflash MTP rollback path — keep BF16 unless VRAM-constrained. |
| --dflash-budget N | llama-server | DDtree node budget per draft step. Default 22; sweep summary: 22 is balanced, 32 wins on JSON (+25%), 64+ saturates. |
| --dflash-max-ctx N | llama-server | Per-slot dflash KV ring size. Default --ctx-size / n_parallel. |
| --dflash-prefill-ubatch N | llama-server | dflash prefill ubatch size. Default 192 on GB10. |
| --dflash-fa-window N | llama-server | Sliding-window FA on full-attn layers (port of Luce-Org/lucebox-hub#26). Default 0 (off). When set (recommended 2048), limits FA to the last N KV positions per query: cuts FA cost from O(kv_len) to O(N) at long contexts, +26% overall throughput on hermes agent traffic in our bench. WARNING — agent-incompatible on its own: the window drops attention to early KV positions, including the system prompt. On agents whose identity / tool list is in the system prompt (Dark Jarvis SOUL, etc.), the model loses context once kv_start > window and falls back to vanilla refusals. Safe for raw continuation / story-style workloads; for agents, pair it with --dflash-fa-sink, which pins the first N KV positions alongside the window (see the recommended config below). Production default is 0. |
| DFLASH27B_KV_V=tq3_0 | env | dflash V-cache type override (q8_0 default). Do not use in production on either Qwopus3.6 or AEON-7 XS. On Qwopus the drift pathology is loud (token loop, 21-min hang). On AEON-7 XS — even with linear_attn.conv1d preserved in BF16 — the drift is more subtle: short prose stress tests pass cleanly, but at >25K accumulated context with structured tool_call generation the model emits malformed JSON (path truncation, missing arguments key). The conv1d-BF16 preservation reduces but does not eliminate the cumulative attention-score noise that TQ3 V introduces across thousands of DDtree verify reads — the SSM in_proj_a/b/qkv/z projection matmuls are the next-most-likely culprits, and they are FP4-quantised on the XS body. The standard llama path -ctv tq3_0 is unaffected on either body. |
| DFLASH27B_KV_K=tq3_0 | env | Experimental — do NOT use in production. Boots and runs short prompts cleanly (commit 6858a4192 fixed the SIGSEGV by forcing the VEC fattn kernel) but hits a long-generation token-loop pathology on agent workloads (~78K committed tokens before it converges on a single repeated token id). Suspected cause: cumulative attention-score degradation from 3.5 bpw K compounded over thousands of decode steps in the dflash custom graph. The standard llama path -ctk tq3_0 is unaffected and remains the recommended K-quant shortcut. |
| DFLASH27B_KV_TQ3=1 | env | Both dflash K and V to TQ3_0 in one shot. Do NOT use in production — combines both pathologies above. |
| DFLASH27B_KV_F16=1 | env | Force dflash KV back to f16 (regression baseline). |
| DFLASH27B_SHARE_KV=1 | env | Share the standard llama_kv_cache K/V buffers with the dflash session (target body only — not the drafter). The dflash custom graph builds non-owning ggml views into llama_kv_cache.layers[il].k/v instead of allocating its own per-layer K/V (a view-construction sketch follows this table). Saves ~9 GiB resident at np=2 with --dflash-max-ctx 131072 and -c 262144, by removing the duplicate K/V the two paths otherwise each carry. The view layout uses non-standard ggml strides (nb2 < nb1, since llama_kv_cache packs heads onto axis 0); every kernel touched (FA-vec direct vec_dot, FA-MMA-F16 via to_fp16_nc dequant, cpy_q_q, set_rows<block_q8_0>) handles arbitrary strides via byte-offset math (no kernel patches needed — see docs/rfc-unified-target-cache.md). The K/V types come from -ctk/-ctv on this path; DFLASH27B_KV_K/_V are ignored. Default 0 (off) until production validation completes. Caveat: when both share-kv and --mmproj are on and the slot processes a vision request followed by a text request, the dflash session detects the kv_end desync and forces a full re-prefill (correct, but costs one prefill round) — the standard path's writes are preserved bit-for-bit since both paths produce the same K/V projections from the same target weights. |
| TURBO_LAYER_ADAPTIVE=N | env | Layer-adaptive Turbo KV quant (1–11 strategies; 0 = uniform, default). |
| GGML_GDN_CHUNK_DISABLE=1 | env | Forces the per-token GDN kernel even for prefill ubatches >= 64 tokens. The chunked GDN path (4-kernel pipeline, wmma 16×16×16, 9.5× microbench speedup over the scalar fallback) is on by default; this escape hatch reverts to the legacy per-token recurrence kernel for A/B comparison or as a safety bypass. Greedy decoding output is bit-identical between the two paths after the math fix in commit 3c66666df, so the production effect is performance-only. See docs/rfc-gdn-chunk-kernel.md §9 for the full tuning log. |
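A minimal sketch of the DFLASH27B_SHARE_KV view construction described above, using the public ggml view API. The member path llama_kv_cache.layers[il].k follows the description in the table; everything else (function name, parameters) is illustrative, not the fork's code.

#include "ggml.h"

// Build a non-owning 3-D view into the standard cache's K tensor instead of
// allocating a duplicate. Passing the source tensor's own byte strides keeps
// the view valid even when nb2 < nb1 (heads packed onto axis 0); consumers
// must then index by byte offsets rather than assume row-major strides.
struct ggml_tensor * borrow_kv_layer(struct ggml_context * ctx,
                                     struct ggml_tensor  * cache_k,  // llama_kv_cache.layers[il].k
                                     int64_t head_dim, int64_t n_head_kv, int64_t n_kv,
                                     size_t nb1, size_t nb2) {
    return ggml_view_3d(ctx, cache_k, head_dim, n_head_kv, n_kv,
                        nb1, nb2, /*offset bytes*/ 0);
}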

Recommended runtime config (GB10, 128 GB unified memory)

For 262K context with -np 2 (two persistent slots, e.g. agent + memory writer) on the Qwopus3.6 27B target:

./build/bin/llama-server \
  -m /path/to/Qwopus-27B-NVFP4-plain.gguf \
  --mmproj /path/to/mmproj-Abliterated-F16.gguf \
  --dflash --dflash-draft /path/to/qwopus36-dflash-v4/model.safetensors \
  --dflash-budget 22 --dflash-max-ctx 131072 --dflash-prefill-ubatch 192 \
  --host 0.0.0.0 --port 30000 -c 262144 -np 2 -ngl 99 \
  -ctk q8_0 -ctv q8_0 \
  --dflash-fa-window 16384 --dflash-fa-sink 16384 \
  --slot-prompt-similarity 0.5 --cache-reuse 256 \
  --jinja --reasoning auto --alias dark-opus --no-webui --no-warmup

This gives ~40 GiB resident on GPU (NVFP4 weights borrowed from llama_model + standard KV K=Q8_0 V=Q8_0 + dflash V=Q8_0 ring, dflash K stays Q8_0, np=2). The TQ3_0 KV path would save another ~6 GiB but is unsafe on agent / coding workloads — both DFLASH27B_KV_V=tq3_0 and DFLASH27B_KV_K=tq3_0 cause a long-generation token-loop pathology after ~15-30K accumulated context. The standard llama path -ctk tq3_0 -ctv tq3_0 is unaffected and remains a valid VRAM-saver if you can run without --dflash. The --dflash-fa-sink 16384 --dflash-fa-window 16384 pair caps FA cost at long context while preserving the system prompt (attention-sink port, see docs/dflash_kv_quant_status.md).

Troubleshooting

  • Server dies silently right after the "DFlash run:" log line — no GGML_ASSERT or CUDA error in the output. Pre-6858a4192 symptom of K=TQ3_0 selecting the MMA fattn kernel — which has no TQ3_0 dequant entry — and crashing on a NULL function pointer inside launch_fattn<...>(). Fixed by extending the force-VEC predicate in ggml-cuda/fattn.cu to include GGML_TYPE_TQ3_0 (a sketch of the predicate follows this list). Re-pull and rebuild; if you still see this, get a backtrace with gdb -batch -ex run llama-dflash <args> to confirm.

  • Agent stalls at "No response from provider for 180s", server log shows endless [step N] committed=… last_tok=X next=X with the same X for thousands of iterations. Token-loop pathology on the dflash custom decode graph: the model produces a self-reinforcing prediction on a single token after enough accumulated context. Two known triggers, both fixed by switching the dflash KV cache off TQ3_0:

    • DFLASH27B_KV_K=tq3_0 — onset around 78K committed tokens; symptom last_tok=328 next=328 style.
    • DFLASH27B_KV_V=tq3_0 — onset varies by body but always present in long structured-output workloads. Qwopus3.6: loud drift at 15-25K context, commit/step blows up to 14+, output is a wall of 0 tokens or Chinese filler, 21-min slot hang. AEON-7 XS (conv1d preserved BF16): subtle drift starting around 25K context — short prose passes cleanly but tool_call generation emits malformed JSON (truncated paths, missing arguments key). The conv1d-BF16 preservation helps but the SSM projection matmuls (in_proj_a/b/qkv/z) are still FP4-quantised on the XS body and contribute residual noise. TL;DR: keep V=Q8 on dflash regardless of body.

    Stop the server, drop both env overrides (or set them to q8_0), restart. Standard path -ctk tq3_0 -ctv tq3_0 (without --dflash) is unaffected. Tracking in docs/dflash_kv_quant_status.md.

  • GGML_ASSERT(buf != NULL && "tensor buffer not set") from ggml_backend_tensor_set during dflash prefill. Pre-fix symptom on DFLASH27B_KV_V=q8_0 (or any non-TQ3_0 V-type): build_full_attn_block only references the kv_idxs input on the set_rows path (TQ3_0-only), so on the cpy fallback path ggml_gallocr dead-code-eliminates the input and sg.kv_idxs->buffer stays NULL. Fixed in tools/dflash-cli/session.cpp by gating the upload on sg.kv_idxs->buffer != nullptr. If you see this on a build past that commit, your build tree is stale — cmake --build build --target llama-server -j and relaunch.

  • "speculative decoding not supported by this context" log line on init. Expected with --dflash: this is the legacy speculative-decoding path's compat probe failing because the dflash session takes over. The DFlash session itself is unaffected.

  • "cache_reuse is not supported by multimodal" log line. Expected with --mmproj + --cache-reuse. The prompt cache stays effective for slot persistence; only the cross-request prefix-match path is disabled.

  • fattn vec kernel aborts with "K%256 != 0". Either the cache type is one of the TQ3_0 / TURBO* family on the standard path (the fork bumps fattn_stride to 256 automatically — make sure you're on origin/feature/dflash-integration or later) or you set --dflash-max-ctx to a non-256-aligned value.

  • Failed to parse input at pos 41: 不休ief粟… Output garbage on NVFP4 means the per-tensor .scale tensors did not load. Re-export the model with NVFP4 plain (NVFP4_DEFAULT_CFG), not AWQ (NVFP4_AWQ_LITE_CFG); see tools/dflash-cli/quantize_nvfp4_plain.py.

  • OOM during model load with --mmproj + --dflash. The borrow path is auto-enabled in this configuration; if you see two ~15 GiB "CUDA0 model buffer" log lines instead of one, re-pull and rebuild — the patch is in commit 87102e46b.
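For the first bullet above (the silent death after the "DFlash run:" log line), the shape of the fix is a kernel-selection guard: cache types that have a vec_dot entry but no MMA dequant entry must be routed to the VEC fattn kernel before launch_fattn dereferences a NULL function pointer. A hedged sketch follows; the real predicate lives in ggml-cuda/fattn.cu and covers more cases.

#include "ggml.h"

// Route cache types without an MMA dequant entry to the VEC kernel.
static bool fattn_must_use_vec(enum ggml_type type_k, enum ggml_type type_v) {
    // GGML_TYPE_TQ3_0 (a fork-only type) has a vec path but no MMA dequant
    // entry; letting the MMA kernel pick it up crashes on a NULL fn pointer.
    return type_k == GGML_TYPE_TQ3_0 || type_v == GGML_TYPE_TQ3_0;
}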

Memory savings stack

Two production-stable end-states, depending on target body. Both use np=2, mmproj loaded, dflash + dflash drafter v4. The dflash agent path uses --dflash-max-ctx 131072 per slot independently of -c; -c only sizes the standard llama_kv_cache that the mmproj/vision path uses. The two contexts are separate.

Qwopus3.6 (FP4 conv1d, drift-prone with TQ3 V):

| Stage | GPU resident | Δ |
| --- | --- | --- |
| Baseline (weights duplicated) | 58.4 GiB | |
| + borrow llama_model weights (commit 87102e46b) | 39.9 GiB | −18.5 GiB |
| + standard -ctk tq3_0 -ctv tq3_0 (commit 7b5f82569) | 37.5 GiB | −2.4 GiB |
| (extra) dflash DFLASH27B_KV_K=tq3_0 + force-VEC fix (commit 6858a4192) | 34.9 GiB | −2.6 GiB, but unstable: long-generation token loop |
| stable end-state (V cache stuck at Q8 to dodge drift, -c 262144) | ~64-72 GiB | |

AEON-7 Qwen3.6-27B-AEON-Ultimate-Uncensored-Multimodal-NVFP4-MTP-XS (conv1d preserved BF16, less but not zero drift on TQ3 V):

| Stage | GPU resident | dflash agent ctx/slot | mmproj ctx/slot | Δ |
| --- | --- | --- | --- | --- |
| Baseline (Qwopus state) | 64 GiB | 131K | 131K | |
| + AEON-7 XS body swap (linear_attn.conv1d BF16) | 64 GiB | 131K | 131K | 0 |
| + halve DFLASH_ANCHOR_SLOTS 4 → 2 (commit ba400dcee) | 62 GiB | 131K | 131K | −2 GiB |
| + DFLASH_ANCHOR_SLOTS = 1 | 61 GiB | 131K | 131K | −1 GiB |
| + mmproj F16 → Q8 (mmproj-AEON-XS-Q8.gguf, 928 MB → 629 MB) | 60.7 GiB | 131K | 131K | −0.3 GiB |
| + DFLASH27B_SHARE_KV=1 (Phase 2.2 unified target K/V cache, RFC) | ~54 GiB | 131K | 131K | −6.5 GiB |
| stable end-state with V=Q8 on dflash, full 131K vision | ~54 GiB | 131K | 131K | −10 GiB vs Qwopus baseline, −47 GiB vs vLLM equivalent (103 GiB) |

We attempted a DFLASH27B_KV_V=tq3_0 re-enable on AEON XS (would have saved ~4 GiB extra) but rolled it back: conv1d-BF16 preservation reduces the drift signature vs Qwopus (no token loop, no wall of zeros) but does NOT eliminate it on long-context structured output — at >25K context the model emits malformed JSON tool_calls (truncated paths, missing keys). The SSM projection matmuls (in_proj_a/b/qkv/z) are still FP4-quantised on XS and contribute residual noise. Production stays at V=Q8 on dflash for both Qwopus and AEON.

The current -c 131072 config trades half the vision context (65K/slot vs 131K) for 3 GiB. The agent path is unaffected — dflash sizes its own KV ring via --dflash-max-ctx. For image-heavy multimodal workflows where you want full 131K vision capacity, set -c 262144 -np 2 → ~62 GiB resident.

The (extra) DFLASH27B_KV_K=tq3_0 row in the Qwopus table is left in the codebase as a documented experimental knob — see the flags reference. DFLASH27B_KV_K=tq3_0 boots and passes short-prompt sanity, but loses coherence on agent-style multi-turn / long-decode workloads (we saw the model degenerate to a single-token loop after ~78K committed tokens). The standard path -ctk tq3_0 is unaffected.

Benchmarks (GB10, NVFP4 + mmproj, np=2, c=262144)

Decode throughput on Qwopus3.6-27B-v1-Abliterated-preview with the Qwen3.6-27B-DFlash draft, --reasoning auto (thinking on by default), per-request enable_thinking overrides as noted:

| Workload | tok/s | accept | commits/step | thinking |
| --- | --- | --- | --- | --- |
| JSON 1024 (color names) | 68.7 | 65% | 10.5 | on |
| MATH 256 (algebra step-by-step) | 45.7 | 46% | 7.3 | on |
| CODE 512 (heapsort + tests) | 38.3 | 47% | 7.5 | on |
| LongCode 2048 | 38.0 | 43% | 6.9 | on |
| PROSE 400 (free essay) | 27.1 | 29% | 4.7 | on |
| PROSE 400 (same prompt) | 18.7 | 20% | 3.2 | off |

Memory footprint at this config (idle, after first warmup pass):

| Component | Size |
| --- | --- |
| NVFP4 target weights (borrowed) | 15.5 GiB |
| Standard llama_kv_cache (K=TQ3_0 + V=TQ3_0, 16 attn layers × 131K × 2 seqs) | 3.6 GiB |
| DFlash per-slot ring (K=Q8_0 + V=TQ3_0 + SSM + target_feat, ×2 slots) | ~13 GiB |
| Standard compute buffer + recurrent state | 2.1 GiB |
| mmproj vision encoder | 0.9 GiB |
| Draft model (Qwen3.6-27B-DFlash) | 0.9 GiB |
| Prompt cache (server-side, lazy, capped 8 GiB) | up to 8 GiB |
| CUDA runtime + libraries | ~3 GiB |
| Total resident on GB10 | ~37.7 GiB (stable; 34.9 GiB unstable with K=TQ3_0 on dflash) |

For comparison on the same workload, lucebox-hub's standalone llama-dflash-server (no --mmproj, no prompt cache, single slot, Q4_K_M target) runs at ~25–50 tok/s and 26.6 GiB resident. The ~13 GiB delta is the price of --mmproj + -np 2 + the prompt cache; remove any of those to recover most of it.

Verifying against upstream

This fork is meant to stay rebase-able onto upstream/master. To audit the diff:

git remote add upstream https://github.com/ggml-org/llama.cpp.git
git fetch upstream
git log --oneline upstream/master..HEAD                # commits unique to the fork
git diff --stat upstream/master..HEAD -- ggml/        # ggml-side delta
git diff --stat upstream/master..HEAD -- tools/       # tools / dflash-cli delta

Most of the fork lives in:

  • ggml/src/ggml-cuda/cpy.cu, set-rows.cu, fattn*.cuh, turbo-wht.cu — TQ3_0 / Turbo* CUDA kernels
  • ggml/src/ggml-turbo-quant.c — CPU TurboQuant reference (stub on most types; CUDA kernels are the load-bearing path)
  • src/llama-kv-cache.cpp, src/llama-graph.cpp — TQ3_0 / Turbo* dispatch in the standard llama path
  • src/llama-model.cpp — NVFP4 .scale / .input_scale per-tensor loading and tensors_by_name map
  • tools/dflash-cli/ — DFlash custom target graph + draft graph + session
  • tools/dflash-server/ — standalone dflash HTTP server (llama-dflash-server)
  • tools/server/server-context.cpp — --dflash dispatch + mmproj coexistence + weight borrow

Credits


llama.cpp (upstream)

(everything below is the upstream README from ggml-org/llama.cpp@upstream/master, kept for parity)

llama.cpp


Manifesto / ggml / ops

LLM inference in C/C++

Recent API changes

Hot topics


Quick start

Getting started with llama.cpp is straightforward. Here are several ways to install it on your machine:

Once installed, you'll need a model to work with. Head to the Obtaining and quantizing models section to learn more.

Example command:

# Use a local model file
llama-cli -m my_model.gguf

# Or download and run a model directly from Hugging Face
llama-cli -hf ggml-org/gemma-3-1b-it-GGUF

# Launch OpenAI-compatible API server
llama-server -hf ggml-org/gemma-3-1b-it-GGUF

Description

The main goal of llama.cpp is to enable LLM inference with minimal setup and state-of-the-art performance on a wide range of hardware - locally and in the cloud.

  • Plain C/C++ implementation without any dependencies
  • Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
  • AVX, AVX2, AVX512 and AMX support for x86 architectures
  • RVV, ZVFH, ZFH, ZICBOP and ZIHINTPAUSE support for RISC-V architectures
  • 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
  • Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
  • Vulkan and SYCL backend support
  • CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity

The llama.cpp project is the main playground for developing new features for the ggml library.

Models

Typically finetunes of the base models below are supported as well.

Instructions for adding support for new models: HOWTO-add-model.md

Text-only

Multimodal

Bindings
UIs

(to have a project listed here, it should clearly state that it depends on llama.cpp)

Tools
  • akx/ggify – download PyTorch models from Hugging Face Hub and convert them to GGML
  • akx/ollama-dl – download models from the Ollama library to be used directly with llama.cpp
  • crashr/gppm – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
  • gpustack/gguf-parser - review/check the GGUF file and estimate the memory usage
  • Styled Lines (proprietary licensed, async wrapper of inference part for game development in Unity3d with pre-built Mobile and Web platform wrappers and a model example)
  • unslothai/unsloth – 🦥 exports/saves fine-tuned and trained models to GGUF (Apache-2.0)
Infrastructure
  • Paddler - Open-source LLMOps platform for hosting and scaling AI in your own infrastructure
  • GPUStack - Manage GPU clusters for running LLMs
  • llama_cpp_canister - llama.cpp as a smart contract on the Internet Computer, using WebAssembly
  • llama-swap - transparent proxy that adds automatic model switching with llama-server
  • Kalavai - Crowdsource end to end LLM deployment at any scale
  • llmaz - ☸️ Easy, advanced inference platform for large language models on Kubernetes.
  • LLMKube - Kubernetes operator for llama.cpp with multi-GPU and Apple Silicon Metal support
Games
  • Lucy's Labyrinth - A simple maze game where agents controlled by an AI model will try to trick you.

Supported backends

| Backend | Target devices |
| --- | --- |
| Metal | Apple Silicon |
| BLAS | All |
| BLIS | All |
| SYCL | Intel and Nvidia GPU |
| OpenVINO [In Progress] | Intel CPUs, GPUs, and NPUs |
| MUSA | Moore Threads GPU |
| CUDA | Nvidia GPU |
| HIP | AMD GPU |
| ZenDNN | AMD CPU |
| Vulkan | GPU |
| CANN | Ascend NPU |
| OpenCL | Adreno GPU |
| IBM zDNN | IBM Z & LinuxONE |
| WebGPU [In Progress] | All |
| RPC | All |
| Hexagon [In Progress] | Snapdragon |
| VirtGPU | VirtGPU APIR |

Obtaining and quantizing models

The Hugging Face platform hosts a number of LLMs compatible with llama.cpp:

You can either manually download the GGUF file or directly use any llama.cpp-compatible models from Hugging Face or other model hosting sites, by using this CLI argument: -hf <user>/<model>[:quant]. For example:

llama-cli -hf ggml-org/gemma-3-1b-it-GGUF

By default, the CLI downloads from Hugging Face; you can switch to other endpoints with the environment variable MODEL_ENDPOINT. The MODEL_ENDPOINT must point to a Hugging Face-compatible API endpoint.

After downloading a model, use the CLI tools to run it locally - see below.

llama.cpp requires the model to be stored in the GGUF file format. Models in other data formats can be converted to GGUF using the convert_*.py Python scripts in this repo.

The Hugging Face platform provides a variety of online tools for converting, quantizing and hosting models with llama.cpp:

To learn more about model quantization, read this documentation

llama-cli

A CLI tool for accessing and experimenting with most of llama.cpp's functionality.

  • Run in conversation mode

    Models with a built-in chat template will automatically activate conversation mode. If this doesn't occur, you can manually enable it by adding -cnv and specifying a suitable chat template with --chat-template NAME

    llama-cli -m model.gguf
    
    # > hi, who are you?
    # Hi there! I'm your helpful assistant! I'm an AI-powered chatbot designed to assist and provide information to users like you. I'm here to help answer your questions, provide guidance, and offer support on a wide range of topics. I'm a friendly and knowledgeable AI, and I'm always happy to help with anything you need. What's on your mind, and how can I assist you today?
    #
    # > what is 1+1?
    # Easy peasy! The answer to 1+1 is... 2!
  • Run in conversation mode with custom chat template
    # use the "chatml" template (use -h to see the list of supported templates)
    llama-cli -m model.gguf -cnv --chat-template chatml
    
    # use a custom template
    llama-cli -m model.gguf -cnv --in-prefix 'User: ' --reverse-prompt 'User:'
  • Constrain the output with a custom grammar
    llama-cli -m model.gguf -n 256 --grammar-file grammars/json.gbnf -p 'Request: schedule a call at 8pm; Command:'
    
    # {"appointmentTime": "8pm", "appointmentDetails": "schedule a a call"}

    The grammars/ folder contains a handful of sample grammars. To write your own, check out the GBNF Guide.

    For authoring more complex JSON grammars, check out https://grammar.intrinsiclabs.ai/

llama-server

A lightweight, OpenAI API compatible, HTTP server for serving LLMs.

  • Start a local HTTP server with default configuration on port 8080
    llama-server -m model.gguf --port 8080
    
    # Basic web UI can be accessed via browser: http://localhost:8080
    # Chat completion endpoint: http://localhost:8080/v1/chat/completions
  • Support multiple users and parallel decoding
    # up to 4 concurrent requests, each with 4096 max context
    llama-server -m model.gguf -c 16384 -np 4
  • Enable speculative decoding
    # the draft.gguf model should be a small variant of the target model.gguf
    llama-server -m model.gguf -md draft.gguf
  • Serve an embedding model
    # use the /embedding endpoint
    llama-server -m model.gguf --embedding --pooling cls -ub 8192
  • Serve a reranking model
    # use the /reranking endpoint
    llama-server -m model.gguf --reranking
  • Constrain all outputs with a grammar
    # custom grammar
    llama-server -m model.gguf --grammar-file grammar.gbnf
    
    # JSON
    llama-server -m model.gguf --grammar-file grammars/json.gbnf

llama-perplexity

A tool for measuring the perplexity 1 (and other quality metrics) of a model over a given text.

  • Measure the perplexity over a text file
    llama-perplexity -m model.gguf -f file.txt
    
    # [1]15.2701,[2]5.4007,[3]5.3073,[4]6.2965,[5]5.8940,[6]5.6096,[7]5.7942,[8]4.9297, ...
    # Final estimate: PPL = 5.4007 +/- 0.67339
  • Measure KL divergence
    # TODO

llama-bench

Benchmark the performance of the inference for various parameters.

  • Run default benchmark
    llama-bench -m model.gguf
    
    # Output:
    # | model               |       size |     params | backend    | threads |          test |                  t/s |
    # | ------------------- | ---------: | ---------: | ---------- | ------: | ------------: | -------------------: |
    # | qwen2 1.5B Q4_0     | 885.97 MiB |     1.54 B | Metal,BLAS |      16 |         pp512 |      5765.41 ± 20.55 |
    # | qwen2 1.5B Q4_0     | 885.97 MiB |     1.54 B | Metal,BLAS |      16 |         tg128 |        197.71 ± 0.81 |
    #
    # build: 3e0ba0e60 (4229)

llama-simple

A minimal example for implementing apps with llama.cpp. Useful for developers.

  • Basic text completion
    llama-simple -m model.gguf
    
    # Hello my name is Kaitlyn and I am a 16 year old girl. I am a junior in high school and I am currently taking a class called "The Art of

Contributing

  • Contributors can open PRs
  • Collaborators will be invited based on contributions
  • Maintainers can push to branches in the llama.cpp repo and merge PRs into the master branch
  • Any help with managing issues, PRs and projects is very appreciated!
  • See good first issues for tasks suitable for first contributions
  • Read the CONTRIBUTING.md for more information
  • Make sure to read this: Inference at the edge
  • A bit of backstory for those who are interested: Changelog podcast

Other documentation

Development documentation

Seminal papers and background on the models

If your issue is with model generation quality, then please at least scan the following links and papers to understand the limitations of LLaMA models. This is especially important when choosing an appropriate model size and appreciating both the significant and subtle differences between LLaMA models and ChatGPT:

XCFramework

The XCFramework is a precompiled version of the library for iOS, visionOS, tvOS, and macOS. It can be used in Swift projects without the need to compile the library from source. For example:

// swift-tools-version: 5.10
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
    name: "MyLlamaPackage",
    targets: [
        .executableTarget(
            name: "MyLlamaPackage",
            dependencies: [
                "LlamaFramework"
            ]),
        .binaryTarget(
            name: "LlamaFramework",
            url: "https://github.com/ggml-org/llama.cpp/releases/download/b5046/llama-b5046-xcframework.zip",
            checksum: "c19be78b5f00d8d29a25da41042cb7afa094cbf6280a225abe614b03b20029ab"
        )
    ]
)

The above example is using an intermediate build b5046 of the library. This can be modified to use a different version by changing the URL and checksum.

Completions

Command-line completion is available for some environments.

Bash Completion

$ build/bin/llama-cli --completion-bash > ~/.llama-completion.bash
$ source ~/.llama-completion.bash

Optionally this can be added to your .bashrc or .bash_profile to load it automatically. For example:

$ echo "source ~/.llama-completion.bash" >> ~/.bashrc

Dependencies

  • yhirose/cpp-httplib - Single-header HTTP server, used by llama-server - MIT license
  • stb-image - Single-header image format decoder, used by multimodal subsystem - Public domain
  • nlohmann/json - Single-header JSON library, used by various tools/examples - MIT License
  • miniaudio.h - Single-header audio format decoder, used by multimodal subsystem - Public domain
  • subprocess.h - Single-header process launching solution for C and C++ - Public domain

Footnotes

  1. https://huggingface.co/docs/transformers/perplexity
