[cuda] Fix MMQ/MMA path #1
Merged
khosravipasha merged 1 commit into prism on Mar 19, 2026
Conversation
Pull request overview
Enables the CUDA MMQ (MMA) path for Q1_0 and Q1_0_g128 to avoid the slow cuBLAS fallback during prompt processing.
Changes:
- Add/enable MMQ template instantiations and runtime dispatch for Q1_0 and Q1_0_g128 (a dispatch sketch follows this list).
- Implement/adjust Q1_0 and Q1_0_g128 MMA tile loading logic and explicitly disable the DP4A path for these types.
- Update model ftype display strings and extend the CUDA template-instance generator to include Q1_0 variants.
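A minimal sketch of what the enabled dispatch looks like, modeled on the `ggml_cuda_mul_mat_q_switch_type` switch shown in the diff further down; the `mul_mat_q_case` helper and its exact signature follow upstream llama.cpp conventions and are assumptions for this fork:

```cpp
// Fragment of ggml/src/ggml-cuda/mmq.cu (sketch, not the verbatim change).
// mul_mat_q_case<T> is the per-type entry point that the template-instance
// .cu files provide; its signature here is assumed from upstream llama.cpp.
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx,
        const mmq_args & args, cudaStream_t stream) {
    switch (args.type_x) {
        // ... existing quant types elided ...
        case GGML_TYPE_Q1_0:
            mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
            break;
        case GGML_TYPE_Q1_0_g128:
            mul_mat_q_case<GGML_TYPE_Q1_0_g128>(ctx, args, stream);
            break;
        default:
            GGML_ABORT("fatal error"); // unreachable if eligibility was checked
            break;
    }
}
```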
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/llama-model-loader.cpp | Simplifies displayed ftype names for Q1_0 and Q1_0_g128. |
| ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu | Adds MMQ instantiation for GGML_TYPE_Q1_0. |
| ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0_g128.cu | Adds MMQ instantiation for GGML_TYPE_Q1_0_g128. |
| ggml/src/ggml-cuda/template-instances/generate_cu_files.py | Generates MMQ instance files for Q1_0 and Q1_0_g128. |
| ggml/src/ggml-cuda/quantize.cu | Minor formatting-only change. |
| ggml/src/ggml-cuda/mmq.cuh | Adds MMA tile loaders for Q1_0/Q1_0_g128 and disables their DP4A vec-dot path. |
| ggml/src/ggml-cuda/mmq.cu | Enables MMQ dispatch/eligibility for Q1_0/Q1_0_g128 with a Turing+ MMA guard. |
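For the last row, a hedged sketch of what the Turing+ MMA guard could look like; `ggml_cuda_should_use_mmq` and `GGML_CUDA_CC_TURING` are modeled on upstream llama.cpp names, and the exact logic is an assumption:

```cpp
// Sketch: Q1_0/Q1_0_g128 have no DP4A vec-dot path, so MMQ is only eligible
// when the device supports tensor-core MMA (Turing, compute capability 7.5+).
// Older devices keep the cuBLAS fallback.
bool ggml_cuda_should_use_mmq(const ggml_type type, const int cc, const int64_t ne11) {
    switch (type) {
        case GGML_TYPE_Q1_0:
        case GGML_TYPE_Q1_0_g128:
            return cc >= GGML_CUDA_CC_TURING; // MMA-only: no DP4A variant exists
        default:
            // ... existing per-type eligibility logic elided ...
            break;
    }
    GGML_UNUSED(ne11);
    return true;
}
```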
Comment on lines 277 to +279:

```diff
@@ -6,12 +6,12 @@
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
-        // TODO: Q1_0/Q1_0_g128 MMQ disabled due to accuracy issues; for now commenting these to use cuBLAS fallback
-        // TODO: Q1_0 and Q1_0_g128 MMQ implementation exists but is currently disabled due to accuracy issues
-        // case GGML_TYPE_Q1_0:
-        // case GGML_TYPE_Q1_0_g128:
+        case GGML_TYPE_Q1_0:
+        case GGML_TYPE_Q1_0_g128:
```
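For reference, the added template-instance files are typically one-line macro invocations; a sketch of `mmq-instance-q1_0.cu`, assuming the `DECL_MMQ_CASE` pattern that upstream llama.cpp uses for its generated instance files:

```cpp
// ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu (sketch)
// Generated by generate_cu_files.py; compiling one instantiation per file
// keeps per-translation-unit build time and memory bounded.
#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_Q1_0);
```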
bricklc pushed a commit to bricklc/prism-ml-llama.cpp that referenced this pull request on Apr 25, 2026
Codex post-commit review found:
1. TURBO_D was QK_TURBO3 (now 32), which broke turbo4 C array sizes
2. SET_ROWS kernel is turbo3-specific but was instantiated for turbo4
3. Tail blocks dropped for non-128 head dims

Fixed PrismML-Eng#3 (TURBO_D). PrismML-Eng#1 and PrismML-Eng#2 don't affect the turbo3+dk128 path.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
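A hedged illustration of finding 1 above: sizing an accumulator with a constant aliased from another quant format. `TURBO_D` and `QK_TURBO3` come from the commit message; everything else here is hypothetical:

```cpp
// Hypothetical reconstruction of the TURBO_D bug. QK_TURBO3 changed to 32,
// and TURBO_D was an alias for it, so turbo4 kernels sized their C arrays
// for 32-wide turbo3 blocks:
//
//   #define TURBO_D QK_TURBO3   // buggy: ties turbo4 array sizes to turbo3
//   float C[TURBO_D];           // too small once turbo4 blocks are wider
//
// One fix shape: make the accumulator width a template parameter so each
// quant variant gets its own correctly sized array.
template <int block_width>
__device__ void turbo_accumulate(float * dst) {
    float C[block_width] = {0.0f}; // sized per instantiation, not via an alias
    for (int i = 0; i < block_width; ++i) {
        dst[i] = C[i];             // placeholder for the real accumulation
    }
}
```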
bricklc pushed a commit to bricklc/prism-ml-llama.cpp that referenced this pull request on Apr 25, 2026
Complete experiment log:
- PrismML-Eng#1 4-mag LUT: 15.1 at 8K (BEST, +38%)
- PrismML-Eng#2 Batched extract: 13.7 (+25%)
- PrismML-Eng#3 Inline FA block: 13.5 (I-cache pressure)
- PrismML-Eng#4 Deferred norm: 12.9 (loses ILP)
- PrismML-Eng#5 2-pair half2: 12.0 (ternary overhead)
- PrismML-Eng#6 Select chain: 11.9 (branches kill)
- PrismML-Eng#7 Bit-arithmetic: 11.6 (ALU too heavy)
- PrismML-Eng#8 FMA branchless: 11.4 (ALU still too heavy)
- PrismML-Eng#9 Named-reg ternary: 10.3 (branches worst)
- PrismML-Eng#10 Main (8-LUT): 10.95 (baseline)
- PrismML-Eng#11 Non-vec FA: 10.2 (wrong kernel)

Ceiling: 24.5 (no dequant)

Apple8 hardware truth:
- 1 divergent constant read < 7 ALU ops (even with fma)
- Branches cost MORE than divergent constant reads
- Array indexing ALWAYS spills on Metal; 4 constant addresses is the sweet spot

The 4-mag LUT is the dequant-level ceiling on Apple Silicon.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
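The experiments above targeted Metal on Apple Silicon; to keep one code language in this thread, here is a CUDA-style sketch of the shape of the winning 4-magnitude LUT (variant 1) versus the select-chain variant (variant 6). All values and names are hypothetical:

```cpp
#include <cstdint>

// 4-entry magnitude LUT in constant memory: one (possibly divergent) constant
// read per weight, which the log above found cheaper than the ALU or branch
// alternatives on Apple8-class hardware.
__constant__ float mag_lut[4] = {-1.0f, -0.5f, 0.5f, 1.0f};

__device__ float dequant_lut(uint32_t bits) {
    return mag_lut[bits & 0x3u]; // variant 1: best, +38% over baseline
}

__device__ float dequant_select_chain(uint32_t bits) {
    // Variant 6: branches were measured as strictly worse than the LUT.
    switch (bits & 0x3u) {
        case 0:  return -1.0f;
        case 1:  return -0.5f;
        case 2:  return  0.5f;
        default: return  1.0f;
    }
}
```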
Fixes the prompt-processing MMQ kernels for Q1_0 and Q1_0_g128; previously these types fell back to cuBLAS, which is much slower.