
cuda: Q1_0 initial backend #21629

Merged
JohannesGaessler merged 4 commits into ggml-org:master from PrismML-Eng:q1-cuda
Apr 15, 2026

Conversation

@khosravipasha
Contributor

@khosravipasha khosravipasha commented Apr 8, 2026

Overview

Follow-up to the merged Q1_0 CPU PR; this PR adds the corresponding CUDA backend.
It also seems to work on AMD in some cases, which was a nice surprise :)

See a live demo of Bonsai-8B using these CUDA kernels and llama-server on the Hugging Face Space prism-ml/Bonsai-demo, running on an L40S GPU at decent speeds. Each request runs on one GPU behind a naive load balancer (just for demo purposes).

Models:

Questions:

  • (fixed) I could not get DP4A working for these kernels and kept getting wrong results. Is DP4A support required, or is a cuBLAS fallback acceptable for it? It seems to target GPUs from a few generations ago.
  • I tried tuning the kernel a bit but am not sure it is fully optimized. Surprisingly, I get similar speeds on the 4090 and 5090.

llama-bench (-fa 1)

Device: NVIDIA RTX 5090 (32 GB), CUDA backend

Bonsai-1.7B (231.13 MiB, 1.72B params)

| model | size | params | backend | ngl | fa | test | t/s |
| --- | ---: | ---: | --- | --: | -: | ---: | ---: |
| qwen3 1.7B | 231.13 MiB | 1.72 B | CUDA | 99 | 1 | pp512 | 29249.58 ± 4403.05 |
| qwen3 1.7B | 231.13 MiB | 1.72 B | CUDA | 99 | 1 | tg128 | 626.18 ± 7.55 |

Bonsai-4B (540.09 MiB, 4.02B params)

| model | size | params | backend | ngl | fa | test | t/s |
| --- | ---: | ---: | --- | --: | -: | ---: | ---: |
| qwen3 4B | 540.09 MiB | 4.02 B | CUDA | 99 | 1 | pp512 | 18621.21 ± 1839.94 |
| qwen3 4B | 540.09 MiB | 4.02 B | CUDA | 99 | 1 | tg128 | 485.21 ± 2.35 |

Bonsai-8B (1.07 GiB, 8.19B params)

| model | size | params | backend | ngl | fa | test | t/s |
| --- | ---: | ---: | --- | --: | -: | ---: | ---: |
| qwen3 8B | 1.07 GiB | 8.19 B | CUDA | 99 | 1 | pp512 | 12287.47 ± 719.62 |
| qwen3 8B | 1.07 GiB | 8.19 B | CUDA | 99 | 1 | tg128 | 373.77 ± 2.01 |

End-to-end testing: KL Divergence (Q1_0 vs unpacked into FP16)

To test the accuracy of the CUDA backend, we compare the KL divergence of the Q1_0 model against the same model unpacked into FP16. Since the weights are equivalent, comparing logits gives a good indication of the backend's accuracy. Run on 20 chunks of wikitext-2-raw, ctx 512.

Each model is tested against its unpacked version from https://huggingface.co/collections/prism-ml/bonsai-auxiliary.
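This can be measured with llama-perplexity's KL-divergence mode; a sketch of the two-step procedure (model file names here are placeholders, not the actual artifact names):

```sh
# 1) Save base logits from the unpacked FP16 model (file names hypothetical):
./llama-perplexity -m bonsai-8b-f16.gguf -f wikitext-2-raw/wiki.test.raw \
    --kl-divergence-base logits-8b.kld
# 2) Score the Q1_0 model against the saved logits:
./llama-perplexity -m bonsai-8b-q1_0.gguf -f wikitext-2-raw/wiki.test.raw \
    --kl-divergence-base logits-8b.kld --kl-divergence
```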

| Model | Mean KLD | RMS Δp | Same top p | pp512 (t/s) | tg128 (t/s) |
| --- | ---: | ---: | ---: | ---: | ---: |
| 8B | 0.000514 | 0.635% | 98.706% | 12287 | 374 |
| 4B | 0.000429 | 0.593% | 98.510% | 18621 | 485 |
| 1.7B | 0.000419 | 0.555% | 98.941% | 29250 | 626 |

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: Yes, AI was used to help debug the initial kernels, add debugging prints, etc. That code is not included in this PR. The PR was validated with llama-bench and the KL divergence tests above to ensure correctness.

@khosravipasha khosravipasha requested a review from a team as a code owner April 8, 2026 14:50
Comment thread ggml/src/ggml-cuda/mmq.cu Outdated
}

// Q1_0 requires MMA — no DP4A fallback path
if (type == GGML_TYPE_Q1_0 && !turing_mma_available(cc) && !amd_mfma_available(cc) && !amd_wmma_available(cc)) {
Contributor Author


!amd_mfma_available(cc) && !amd_wmma_available(cc)
Not fully sure about the AMD part; the Copilot review suggested adding it to avoid the cuBLAS fallback on AMD GPUs (I don't have access to an AMD GPU to test myself).

Collaborator


It's wrong: you guard the kernel with defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) but then select it for both MFMA and WMMA GPUs. The guard should accept AMD_WMMA_AVAILABLE too.
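In other words, the kernel's compile-time guard would need to accept WMMA as well; a sketch of the corrected guard (placement hypothetical):

```cuda
// Guard must match the host-side selection: MFMA, WMMA, or Turing MMA.
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    // ... MMA-based Q1_0 kernel body ...
#else
    // ... unreachable / fallback path ...
#endif
```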

Contributor Author


Good catch, will fix. I'll first try to support DP4A; then this check probably won't be needed.

IMbackK
IMbackK previously requested changes Apr 8, 2026

@IMbackK
Collaborator

IMbackK commented Apr 8, 2026

Overall, not supporting the DP4A path is pretty undesirable btw, and not just for older GPUs.

@khosravipasha
Contributor Author

khosravipasha commented Apr 8, 2026

@IMbackK Fair enough, I can give the DP4A path another try. What's the best way to test it?
I only have access to cloud GPUs like the RTX 4090/5090 and some from Google Colab like the T4/L4/G4.

It runs decently on a T4 (this Google Colab demo):

ggml_cuda_init: found 1 CUDA devices:
  Device 0: Tesla T4, compute capability 7.5, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | CUDA       | 999 |  1 |           pp512 |       1346.04 ± 5.07 |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | CUDA       | 999 |  1 |           tg128 |         61.92 ± 0.44 |

Is this build correct for forcing the DP4A path?

  mkdir -p build-dp4a && cd build-dp4a
  cmake .. -DGGML_CUDA=ON -DGGML_CUDA_FA=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_CUDA_ARCHITECTURES=70
  cmake --build . -j$(nproc)

@JohannesGaessler
Contributor

If you are on CUDA 11 or 12 you can compile with -DCMAKE_CUDA_ARCHITECTURES=61-virtual in order to compile for Pascal; since the code is forwards-compatible, it should then do JIT compilation for any more recent GPU on the first run.
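Applied to the build above, that would look like:

```sh
mkdir -p build-dp4a && cd build-dp4a
cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_CUDA_ARCHITECTURES=61-virtual
cmake --build . -j$(nproc)
```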

@pl752
Contributor

pl752 commented Apr 9, 2026

Have you tried using a LUT in shared memory for unpacking the bits?
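One possible shape of that idea for Q1_0's 4-bit groups (a hedged sketch, not code from this PR; sign_lut is an illustrative name, and qs0/shift/unpacked_bytes refer to the mmq.cu snippet discussed below):

```cuda
// Hypothetical shared-memory LUT: entry n holds the four +/-1 bytes for nibble n.
__shared__ int sign_lut[16];
if (threadIdx.x < 16) {
    const int n  = threadIdx.x;
    const int b0 = (n & 0x01) ? 0x01 : 0xFF;
    const int b1 = (n & 0x02) ? 0x01 : 0xFF;
    const int b2 = (n & 0x04) ? 0x01 : 0xFF;
    const int b3 = (n & 0x08) ? 0x01 : 0xFF;
    sign_lut[n] = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24);
}
__syncthreads();

// Unpacking a 4-bit group then becomes a single lookup:
// unpacked_bytes[j] = sign_lut[(qs0 >> shift) & 0x0F];
```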

@khosravipasha
Contributor Author

khosravipasha commented Apr 9, 2026

@IMbackK I think DP4A should be good now; I ran the benchmarks and KL validation tests. I was also curious what happens if I force the cuBLAS fallback, so I tried that as well:

@JohannesGaessler Thanks for the suggestion, that was helpful. I had to switch to a 4090 since the JIT was crashing on the 5090, but it works on the 4090.

Tried the following build options:

| Build | CMake Flags |
| --- | --- |
| default | -DGGML_CUDA=ON (default arches incl. 89-real) |
| DP4A | -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=61-virtual |
| cuBLAS | -DGGML_CUDA=ON -DGGML_CUDA_FORCE_CUBLAS=ON |

RTX 4090 GPU Benchmarks — Prompt Processing (pp512, tokens/s)

| Model | default | DP4A | cuBLAS |
| --- | ---: | ---: | ---: |
| Bonsai-1.7B | 27,175 ± 3,498 | 19,236 ± 4,029 | 20,567 ± 4,917 |
| Bonsai-4B | 16,544 ± 1,309 | 10,015 ± 1,489 | 11,451 ± 947 |
| Bonsai-8B | 11,062 ± 712 | 6,223 ± 373 | 6,828 ± 446 |

RTX 4090 GPU Benchmarks — Token Generation (tg128, tokens/s)

| Model | default | DP4A | cuBLAS |
| --- | ---: | ---: | ---: |
| Bonsai-1.7B | 655 ± 22 | 488 ± 9 | 648 ± 19 |
| Bonsai-4B | 426 ± 5 | 326 ± 2 | 416 ± 11 |
| Bonsai-8B | 346 ± 2 | 349 ± 1 | 340 ± 4 |

RTX 4090 — KL Divergence Summary

| Build | Model | Mean KLD | Same Top Token | Status |
| --- | --- | ---: | ---: | --- |
| DP4A | 1.7B | 0.000372 ± 0.000011 | 98.82% | PASS |
| DP4A | 4B | 0.000396 ± 0.000015 | 98.59% | PASS |
| DP4A | 8B | 0.000414 ± 0.000013 | 99.14% | PASS |
| cuBLAS | 1.7B | 0.000369 ± 0.000011 | 98.82% | PASS |
| cuBLAS | 4B | 0.000418 ± 0.000017 | 98.59% | PASS |
| cuBLAS | 8B | 0.000401 ± 0.000012 | 99.14% | PASS |

@khosravipasha
Contributor Author

khosravipasha commented Apr 9, 2026

@pl752 I think I tried something similar, but it either didn't help or was slower (I don't remember the details).

@github-actions github-actions bot added the Nvidia GPU, python, and ggml labels Apr 9, 2026
@Green-Sky
Collaborator

Green-Sky commented Apr 10, 2026

I did not see the last commit with DP4A; not sure it would make a difference. But here are some numbers from my machine at 84ab75f:

$ result/bin/llama-bench -m /run/media/green/d20c801b-7aae-4042-85ec-bf2153257be8/green/workspace/llama.cpp/models/Bonsai-8B.gguf -b 64,128,256,512,1024,2048 -ctk q8_0 -ctv q8_0 -p 64,128,256,512,1024,2048 -n 128,512,4096 -fa 1
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 7777 MiB):
  Device 0: NVIDIA GeForce RTX 2070, compute capability 7.5, VMM: yes, VRAM: 7777 MiB
| model | size | params | backend | ngl | n_batch | type_k | type_v | fa | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --- | --- | -: | ---: | ---: |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 64 | q8_0 | q8_0 | 1 | pp64 | 1053.91 ± 45.72 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 64 | q8_0 | q8_0 | 1 | pp128 | 1005.84 ± 6.14 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 64 | q8_0 | q8_0 | 1 | pp256 | 1046.86 ± 5.46 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 64 | q8_0 | q8_0 | 1 | pp512 | 1038.24 ± 2.08 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 64 | q8_0 | q8_0 | 1 | pp1024 | 1022.94 ± 1.58 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 64 | q8_0 | q8_0 | 1 | pp2048 | 994.86 ± 1.35 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 64 | q8_0 | q8_0 | 1 | tg128 | 66.30 ± 0.89 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 64 | q8_0 | q8_0 | 1 | tg512 | 65.44 ± 0.82 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 64 | q8_0 | q8_0 | 1 | tg4096 | 57.70 ± 0.69 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 128 | q8_0 | q8_0 | 1 | pp64 | 1029.19 ± 6.17 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 128 | q8_0 | q8_0 | 1 | pp128 | 1213.07 ± 43.73 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 128 | q8_0 | q8_0 | 1 | pp256 | 1196.87 ± 14.17 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 128 | q8_0 | q8_0 | 1 | pp512 | 1185.92 ± 8.29 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 128 | q8_0 | q8_0 | 1 | pp1024 | 1176.62 ± 3.35 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 128 | q8_0 | q8_0 | 1 | pp2048 | 1147.23 ± 2.06 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 128 | q8_0 | q8_0 | 1 | tg128 | 63.94 ± 1.08 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 128 | q8_0 | q8_0 | 1 | tg512 | 63.51 ± 0.45 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 128 | q8_0 | q8_0 | 1 | tg4096 | 56.49 ± 0.98 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 256 | q8_0 | q8_0 | 1 | pp64 | 1025.42 ± 18.59 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 256 | q8_0 | q8_0 | 1 | pp128 | 1205.88 ± 11.11 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 256 | q8_0 | q8_0 | 1 | pp256 | 1266.83 ± 14.05 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 256 | q8_0 | q8_0 | 1 | pp512 | 1273.99 ± 3.51 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 256 | q8_0 | q8_0 | 1 | pp1024 | 1275.33 ± 4.07 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 256 | q8_0 | q8_0 | 1 | pp2048 | 1249.60 ± 1.99 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 256 | q8_0 | q8_0 | 1 | tg128 | 65.53 ± 0.18 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 256 | q8_0 | q8_0 | 1 | tg512 | 64.96 ± 0.18 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 256 | q8_0 | q8_0 | 1 | tg4096 | 57.41 ± 0.33 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 512 | q8_0 | q8_0 | 1 | pp64 | 1027.07 ± 15.39 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 512 | q8_0 | q8_0 | 1 | pp128 | 1219.57 ± 5.92 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 512 | q8_0 | q8_0 | 1 | pp256 | 1289.94 ± 4.33 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 512 | q8_0 | q8_0 | 1 | pp512 | 1322.18 ± 1.53 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 512 | q8_0 | q8_0 | 1 | pp1024 | 1348.64 ± 7.06 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 512 | q8_0 | q8_0 | 1 | pp2048 | 1287.65 ± 37.62 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 512 | q8_0 | q8_0 | 1 | tg128 | 61.71 ± 4.26 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 512 | q8_0 | q8_0 | 1 | tg512 | 63.72 ± 0.91 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 512 | q8_0 | q8_0 | 1 | tg4096 | 57.15 ± 0.88 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 1024 | q8_0 | q8_0 | 1 | pp64 | 1013.12 ± 16.05 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 1024 | q8_0 | q8_0 | 1 | pp128 | 1214.17 ± 2.26 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 1024 | q8_0 | q8_0 | 1 | pp256 | 1327.62 ± 3.05 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 1024 | q8_0 | q8_0 | 1 | pp512 | 1356.79 ± 2.48 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 1024 | q8_0 | q8_0 | 1 | pp1024 | 1360.01 ± 1.87 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 1024 | q8_0 | q8_0 | 1 | pp2048 | 1322.03 ± 6.76 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 1024 | q8_0 | q8_0 | 1 | tg128 | 64.03 ± 1.82 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 1024 | q8_0 | q8_0 | 1 | tg512 | 65.26 ± 0.06 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 1024 | q8_0 | q8_0 | 1 | tg4096 | 57.34 ± 0.82 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 2048 | q8_0 | q8_0 | 1 | pp64 | 1022.37 ± 28.61 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 2048 | q8_0 | q8_0 | 1 | pp128 | 1220.35 ± 4.13 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 2048 | q8_0 | q8_0 | 1 | pp256 | 1289.69 ± 4.02 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 2048 | q8_0 | q8_0 | 1 | pp512 | 1299.62 ± 13.23 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 2048 | q8_0 | q8_0 | 1 | pp1024 | 1323.46 ± 3.22 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 2048 | q8_0 | q8_0 | 1 | pp2048 | 1299.80 ± 28.36 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 2048 | q8_0 | q8_0 | 1 | tg128 | 65.33 ± 0.03 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 2048 | q8_0 | q8_0 | 1 | tg512 | 64.81 ± 0.11 |
| qwen3 8B Q1_0 | 1.07 GiB | 8.19 B | CUDA | 99 | 2048 | q8_0 | q8_0 | 1 | tg4096 | 58.05 ± 0.23 |

The numbers are pretty close to the T4 numbers, which makes sense (TU106 vs TU104).

Comment on lines +3 to +23
static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q1_0 * x = (const block_q1_0 *) vx;

const float d = x[ib].d;

const int bit_index_0 = iqs;
const int bit_index_1 = iqs + 1;

const int byte_index_0 = bit_index_0 / 8;
const int bit_offset_0 = bit_index_0 % 8;

const int byte_index_1 = bit_index_1 / 8;
const int bit_offset_1 = bit_index_1 % 8;

// Extract bits: 1 = +d, 0 = -d (branchless)
const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1;
const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1;

v.x = (2*bit_0 - 1) * d;
v.y = (2*bit_1 - 1) * d;
}
Contributor

@JohannesGaessler JohannesGaessler Apr 10, 2026


The condition iqs % 2 == 0 should always be true so you could potentially optimize this function (does not need to be in this PR).
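A minimal sketch of that simplification (hypothetical, not part of the PR): with even iqs, bits iqs and iqs+1 always fall in the same byte, so one load and one shift suffice.

```cuda
static __device__ __forceinline__ void dequantize_q1_0_even(const void * vx, const int64_t ib, const int iqs, float2 & v) {
    const block_q1_0 * x = (const block_q1_0 *) vx;
    const float d = x[ib].d;

    // iqs % 2 == 0, so both bits share byte iqs/8; bit 0 of 'bits' is iqs, bit 1 is iqs+1.
    const int bits = x[ib].qs[iqs / 8] >> (iqs % 8);
    v.x = (2*( bits       & 1) - 1) * d;
    v.y = (2*((bits >> 1) & 1) - 1) * d;
}
```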

Contributor Author


Oh interesting, will give it a try; it removes one division and one modulo, so it might be worth it.

Contributor Author


I tried it, and it did not seem to change the speed.

Comment on lines +343 to +344
const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
(bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
Contributor


Use the function get_int_b2 instead.
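That is, roughly (this assumes qs_offset is a multiple of 4, since get_int_b2 indexes in 32-bit words):

```cuda
const int qs0 = get_int_b2(bxi->qs, qs_offset / 4);
```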

Contributor Author

@khosravipasha khosravipasha Apr 11, 2026


I tried get_int_b2 in both places and with all 3 build combinations; there is a large speed degradation on the token generation path for the smaller models, see the table below.

That was kinda surprising, so I redid the builds twice to make sure there were no other changes in the builds:

get_int_b2 Benchmark Results — L40S (sm_89, CUDA 12.8)

| Build | get_int_b2 | Model | pp512 (t/s) | tg128 (t/s) | pp Δ | tg Δ |
| --- | --- | --- | ---: | ---: | ---: | ---: |
| default | No | 1.7B | 27,161 | 630 | | |
| default | Yes | 1.7B | 26,863 | 513 | -1.1% | -18.6% |
| default | No | 4B | 15,623 | 414 | | |
| default | Yes | 4B | 15,481 | 332 | -0.9% | -19.8% |
| default | No | 8B | 10,258 | 339 | | |
| default | Yes | 8B | 10,252 | 337 | -0.1% | -0.4% |
| cuBLAS | No | 1.7B | 20,447 | 622 | | |
| cuBLAS | Yes | 1.7B | 20,629 | 506 | +0.9% | -18.6% |
| cuBLAS | No | 4B | 10,860 | 409 | | |
| cuBLAS | Yes | 4B | 10,988 | 329 | +1.2% | -19.6% |
| cuBLAS | No | 8B | 6,534 | 336 | | |
| cuBLAS | Yes | 8B | 6,466 | 334 | -1.0% | -0.5% |
| DP4A | No | 1.7B | 19,767 | 492 | | |
| DP4A | Yes | 1.7B | 20,219 | 495 | +2.3% | +0.7% |
| DP4A | No | 4B | 10,016 | 316 | | |
| DP4A | Yes | 4B | 9,971 | 319 | -0.4% | +1.0% |
| DP4A | No | 8B | 6,151 | 337 | | |
| DP4A | Yes | 8B | 6,125 | 336 | -0.4% | -0.1% |

Comment on lines +350 to +355
const int bits4 = (qs0 >> shift) & 0x0F;
const int b0 = (bits4 & 0x01) ? 1 : -1;
const int b1 = (bits4 & 0x02) ? 1 : -1;
const int b2 = (bits4 & 0x04) ? 1 : -1;
const int b3 = (bits4 & 0x08) ? 1 : -1;
unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
Contributor


You should be able to optimize this using __vadd4 (does not need to be in this PR).
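A minimal sketch of what that could look like (hypothetical; same variables as the snippet above, and not what the PR ended up using):

```cuda
// Spread the four sign bits into one byte lane each (0 or 1 per byte).
const int bits4  = (qs0 >> shift) & 0x0F;
const int spread = ( bits4 & 0x01)        | ((bits4 & 0x02) <<  7)
                 | ((bits4 & 0x04) << 14) | ((bits4 & 0x08) << 21);
// Per-byte 2*b - 1 in one SIMD-in-word add: bytes become 0x01 (+1) or 0xFF (-1).
unpacked_bytes[j] = __vadd4(spread << 1, 0xFFFFFFFF);
```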

Contributor Author


Interesting, will try it. Will send another PR if it helps a lot.

Contributor Author


I tried this as well over the weekend, and it seems it did not make a difference. I will take another pass at tuning after this PR is merged.

Comment thread ggml/src/ggml-cuda/mmq.cuh Outdated
Comment thread ggml/src/ggml-cuda/mmq.cuh Outdated
Comment on lines +691 to +692
const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] << 8) |
(bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24);
Contributor


Use the function get_int_b2 instead.

Comment thread ggml/src/ggml-cuda/vecdotq.cuh Outdated
@khosravipasha
Contributor Author

Are there any more changes needed on our side? I see a few actions failed but I don't think it's due to this PR.

Failing tests:
  ROPE(type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0)
  Backend CUDA0: FAIL
Backend 2/2: CPU
  Skipping CPU backend
1/2 backends passed
FAIL

Contributor

@am17an am17an left a comment


The CI failures look a bit random, not sure what's causing them.

@JohannesGaessler JohannesGaessler merged commit 7e72b38 into ggml-org:master Apr 15, 2026
46 of 50 checks passed
@khosravipasha khosravipasha deleted the q1-cuda branch April 19, 2026 07:10
mengqin pushed a commit to mengqin/llama.cpp that referenced this pull request Apr 20, 2026
* [cuda] initial Q1_0 backend

* remove unused code, fix AMD MMA guard

* attempt to support dp4a

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Apr 21, 2026
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Apr 23, 2026
