
vulkan: add mul_mat variant for embedded gpus #15800

Open

rmatif wants to merge 4 commits into ggml-org:master from rmatif:vk-mulmat-embed

Conversation

@rmatif
Collaborator

@rmatif rmatif commented Sep 4, 2025

This PR adds a mul_mat variant designed for embedded GPUs, as the current shaders perform poorly on them. Essentially, I reused the approach from my OpenCL implementation (#14535).

It has currently been tested only on a Mali GPU, but I believe it should suit others well too.

ggml_vulkan: 0 = Mali-G715 (Mali-G715) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 16 | shared memory: 32768 | int dot: 0 | matrix cores: KHR_coopmat

| Model | Test | Master | PR | Speedup |
| --- | --- | --- | --- | --- |
| Qwen2 1.5B Q4_0 | pp512 | 2.81 ± 0.03 | 66.26 ± 0.06 | 23.58x |
| Llama 1B F16 | pp512 | 5.76 ± 0.08 | 95.83 ± 0.14 | 16.63x |

Master:

  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                 2 runs - 634214.00 us/run -  60.13 GFLOP/run -  8.92 GFLOPS

PR:

  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                 2 runs - 634214.00 us/run -  60.13 GFLOP/run -  150.01 GFLOPS

I started messing with coopmat, but didn’t have enough time to include it in this PR. If we go this route for embedded gpus, I plan to add variants for conv2d, mul_vec, and maybe fa in the future

@rmatif rmatif requested a review from 0cc4m as a code owner September 4, 2025 16:25
@rmatif
Collaborator Author

rmatif commented Sep 4, 2025

I got two failures with iq2_s and iq3_s. Does anyone know why these two in particular are failing?

[MUL_MAT] NMSE = 0.363474692 > 0.000500000   MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): FAIL
[MUL_MAT] NMSE = 2.417624850 > 0.000500000   MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): FAIL

@jeffbolznv
Contributor

Does anyone know why these two in particular are failing?

I think there's a pre-existing bug in the dequant functions for these types. I ran into this once before when working on get_rows (I think?), and the bug didn't seem to be happening in any code paths that were getting hit in practice, and the bug wasn't obvious, so I didn't pursue it.

Can you explain what it is about your matmul shader that makes it faster for mobile? At first glance it doesn't seem fundamentally different from what the scalar path is doing, except that you're dequantizing the matrix as a separate pass, and maybe some minor differences to tile size. I'll leave some comments on the code in a little while.

@github-actions github-actions bot added the Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels Sep 4, 2025
@rmatif
Collaborator Author

rmatif commented Sep 4, 2025

Can you explain what it is about your matmul shader that makes it faster for mobile?

The main difference I'd say is the very low register pressure and avoiding spilling. Mali provides only 64 registers per thread, so the entire design was built around that constraint. I tried fine-tuning the existing shaders some time ago but without success. I also believe that due to the simplicity, even outdated drivers on low-end hardware should handle them more easily
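
To make that concrete, below is a minimal sketch of the kind of small-tile, low-register-pressure matmul being described, assuming matrix A has already been dequantized to fp16 in a separate pass. It is not the shader from this PR; the tile constants, bindings, and data layouts are illustrative only.

```glsl
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_control_flow_attributes : require

// Illustrative sketch only, not this PR's shader: a small-tile matmul where
// each thread keeps just TM x TN accumulators live, which is what keeps
// register pressure low on GPUs that expose ~64 registers per thread.
layout(local_size_x = 64) in;

layout(binding = 0, std430) readonly  buffer MatA { float16_t data_a[]; }; // M x K, row-major, already dequantized
layout(binding = 1, std430) readonly  buffer MatB { float     data_b[]; }; // K x N, column j contiguous with length K
layout(binding = 2, std430) writeonly buffer MatD { float     data_d[]; }; // M x N, column j contiguous with length M

layout(push_constant) uniform PC { uint M; uint N; uint K; } p;

const uint BM = 32, BN = 16, BK = 8;  // workgroup tile
const uint TM = 4,  TN = 2;           // per-thread tile: only 8 accumulators

shared float sh_a[BM][BK];
shared float sh_b[BK][BN];

void main() {
    const uint tid = gl_LocalInvocationID.x;      // 0..63 = (BM/TM)*(BN/TN) threads
    const uint lr  = (tid % (BM / TM)) * TM;      // this thread's rows within the tile
    const uint lc  = (tid / (BM / TM)) * TN;      // this thread's cols within the tile

    float acc[TM][TN];
    [[unroll]] for (uint i = 0; i < TM; i++)
        [[unroll]] for (uint j = 0; j < TN; j++)
            acc[i][j] = 0.0;

    const uint num_k_tiles = (p.K + BK - 1) / BK;
    for (uint t = 0; t < num_k_tiles; t++) {
        const uint k0 = t * BK;
        // Cooperative, bounds-checked loads of the A and B tiles into shared memory.
        for (uint i = tid; i < BM * BK; i += 64) {
            const uint r = i / BK, c = i % BK;
            const uint gr = gl_WorkGroupID.x * BM + r, gk = k0 + c;
            sh_a[r][c] = (gr < p.M && gk < p.K) ? float(data_a[gr * p.K + gk]) : 0.0;
        }
        for (uint i = tid; i < BK * BN; i += 64) {
            const uint r = i / BN, c = i % BN;
            const uint gk = k0 + r, gc = gl_WorkGroupID.y * BN + c;
            sh_b[r][c] = (gk < p.K && gc < p.N) ? data_b[gc * p.K + gk] : 0.0;
        }
        barrier();
        [[unroll]] for (uint k = 0; k < BK; k++)
            [[unroll]] for (uint i = 0; i < TM; i++)
                [[unroll]] for (uint j = 0; j < TN; j++)
                    acc[i][j] = fma(sh_a[lr + i][k], sh_b[k][lc + j], acc[i][j]);
        barrier();
    }

    [[unroll]] for (uint i = 0; i < TM; i++)
        [[unroll]] for (uint j = 0; j < TN; j++) {
            const uint row = gl_WorkGroupID.x * BM + lr + i;
            const uint col = gl_WorkGroupID.y * BN + lc + j;
            if (row < p.M && col < p.N) data_d[col * p.M + row] = acc[i][j];
        }
}
```

The point being that each thread only ever holds TM × TN accumulators plus a couple of shared-memory operands, so there is little for the compiler to spill even with a 64-register budget.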

@jeffbolznv
Contributor

But you can reduce the register usage by changing the tile size via spec constants. The other big difference I see is you're using vec4s everywhere, I wonder if that's somehow related.
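
For reference, a rough sketch of what that looks like on the shader side; the constant IDs and names here are illustrative, not necessarily the ones the existing mul_mm shaders use:

```glsl
// Illustrative only: tile sizes fed in as specialization constants, so the
// host can shrink the tile (and with it the per-thread register footprint)
// per device without recompiling the GLSL source.
layout(constant_id = 0) const uint BM = 64;
layout(constant_id = 1) const uint BN = 64;
layout(constant_id = 2) const uint BK = 16;
layout(constant_id = 3) const uint TM = 4;   // per-thread rows
layout(constant_id = 4) const uint TN = 4;   // per-thread cols

// e.g. a Mali-oriented pipeline could be specialized to BM=32, BN=32, BK=8,
// TM=2, TN=2, cutting the number of live accumulators from 16 to 4.
```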

@rmatif
Collaborator Author

rmatif commented Sep 4, 2025

But you can reduce the register usage by changing the tile size via spec constants. The other big difference I see is you're using vec4s everywhere, I wonder if that's somehow related.

I know, but for some reason that wasn't enough (though I was testing on an older and less powerful device, so I should give it another try). On Adreno, vec4 is much faster, but on the latest-gen Mali it shouldn't make much difference; according to Arm's documentation it should perform similarly to scalar, so I kept it in case another device requires it.

Comment thread ggml/src/ggml-vulkan/ggml-vulkan.cpp Outdated
Comment thread ggml/src/ggml-vulkan/ggml-vulkan.cpp Outdated
Comment thread ggml/src/ggml-vulkan/ggml-vulkan.cpp
Comment thread ggml/src/ggml-vulkan/ggml-vulkan.cpp Outdated
Comment thread ggml/src/ggml-vulkan/ggml-vulkan.cpp
Comment thread ggml/src/ggml-vulkan/ggml-vulkan.cpp Outdated
for (uint t = 0; t < num_k_tiles; t++) {
const uint k_tile_start = t * BK;

#pragma unroll
Contributor

[[unroll]] is preferred.

Collaborator Author

Isn't #pragma unroll better for old compilers?

Contributor

We use [[unroll]] in a bunch of shaders and haven't had any problems.
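
For context, the attribute comes from GL_EXT_control_flow_attributes and is used like this (illustrative snippet, not code from this PR):

```glsl
#extension GL_EXT_control_flow_attributes : require

// Illustrative: same intent as "#pragma unroll", expressed with the
// standardized attribute syntax used elsewhere in the Vulkan shaders.
[[unroll]] for (uint t = 0; t < num_k_tiles; t++) {
    const uint k_tile_start = t * BK;
    // ... tile load + multiply ...
}
```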

Comment thread ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp Outdated
Comment thread ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp Outdated
Contributor

@0cc4m 0cc4m left a comment

I don't have time right now for a full review, but I wanted to add this to Jeff's comments.

Edit: impressive results, nice work. It's great to add support to more kinds of devices.

Comment thread ggml/src/ggml-vulkan/ggml-vulkan.cpp Outdated
}
}

if (device->vendor_id == VK_VENDOR_ID_ARM) {
Contributor

I would prefer if this used the same codepath as the main shader, since it's duplicating quite a bit. If the push constants are the same, is there a reason not to just add an ARM path to the shader selection function and leave the dequant to the existing logic?

@netrunnereve
Collaborator

If we go this route for embedded gpus, I plan to add variants for conv2d, mul_vec, and maybe fa in the future

If we go that route we'll need a better way of testing these shaders. As a start we can have an environment variable or compile flag to enable the embedded path, and while I don't see anything Arm-specific in the code, we should try it out on some desktop GPUs to see if it runs fine there.

@netrunnereve
Collaborator

Also I'm finding it a bit hard to believe that it's faster to dequant everything, write that to regular memory, and then read all that back into shared memory for the actual multiplication. That's basically limiting everything to your memory speed, like inference.

I don't think having fewer registers matters in this case, since there's not much difference in register count between writing the dequantized weights to regular memory versus shared memory. The mul mat shader can be treated as two sections, where one does the dequantization to shared memory and the other does the actual multiplication, and the register counts for each can be more or less independent. I'd be curious to see how long the dequantization takes compared to the multiplication if you run with GGML_VK_PERF_LOGGER.

And as a side note our dequant functions are like triplicated across dequant_funcs.comp, mul_mm.comp, and dequant_q*.comp and they really should be merged if possible.

@rmatif
Collaborator Author

rmatif commented Sep 5, 2025

I'd be curious to see how long the dequantization takes compared to the multiplication if you run with GGML_VK_PERF_LOGGER

I will try to take a look

I did a quick test by fine-tuning and matching the tile sizes of the current scalar shaders to this one, and I observed roughly the same results (a bit faster due to coopmat). Marking this as a draft for now, as I still need to test it on older devices and other vendors to see if it does any good

@rmatif rmatif marked this pull request as draft September 5, 2025 06:33
@0cc4m
Contributor

0cc4m commented Sep 5, 2025

The perf logger does not show the dequant and matmul dispatches separately, you'd need a profiler for that.

@rmatif
Collaborator Author

rmatif commented Sep 5, 2025

The perf logger does not show the dequant and matmul dispatches separately, you'd need a profiler for that.

I don’t know if Arm offers that. I will reach out to some of their engineers and ask them

@0cc4m
Contributor

0cc4m commented Sep 5, 2025

I've done most of my work without profilers; just logically, it should be better to dequant directly to shared memory, because you avoid the need for an intermediate buffer in VRAM and you avoid the global reads/writes. But if the regular mul_mm shader works with smaller tiles already, then you don't need to implement that, luckily.
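
As a rough illustration of the fused approach being suggested here, a hypothetical q4_0 load helper would dequantize straight into the shared A tile instead of going through an intermediate buffer. This is a fragment only, not mul_mm.comp; it assumes tile constants, sh_a, and push constants like those sketched earlier in the thread:

```glsl
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

// Hypothetical fragment: quantized A is read directly and dequantized into
// the shared tile, so no intermediate fp16 buffer is written to VRAM.
struct block_q4_0 {
    float16_t d;       // per-block scale
    uint8_t   qs[16];  // 32 packed 4-bit quants
};
layout(binding = 0, std430) readonly buffer MatA { block_q4_0 blocks[]; };

void load_a_tile_dequant(uint row0, uint k0) {
    for (uint i = gl_LocalInvocationID.x; i < BM * BK; i += gl_WorkGroupSize.x) {
        const uint r = i / BK, c = i % BK;
        const uint gr = row0 + r, gk = k0 + c;
        float v = 0.0;
        if (gr < p.M && gk < p.K) {
            const uint idx = gr * p.K + gk;
            const uint ib  = idx / 32;   // q4_0 block index
            const uint iqs = idx % 32;   // element index inside the block
            // ggml q4_0 packing: low nibbles are elements 0..15, high nibbles 16..31
            const uint q = (iqs < 16) ? (uint(blocks[ib].qs[iqs]) & 0xFu)
                                      : (uint(blocks[ib].qs[iqs - 16]) >> 4);
            v = float(blocks[ib].d) * (float(q) - 8.0);
        }
        sh_a[r][c] = v;  // sh_a assumed float here; a real fused path would likely keep fp16
    }
    barrier();
}
```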

@rmatif
Collaborator Author

rmatif commented Sep 5, 2025

I've done most of my work without profilers; just logically, it should be better to dequant directly to shared memory, because you avoid the need for an intermediate buffer in VRAM and you avoid the global reads/writes. But if the regular mul_mm shader works with smaller tiles already, then you don't need to implement that, luckily.

I spent quite a bit of time trying to do the dequant directly in shared memory, but it turned out to be too much work for me, so I dropped it and, out of "laziness", adopted this two-stage approach. That said, I do agree that the optimal solution is to do it directly.

@rmatif rmatif marked this pull request as ready for review September 5, 2025 21:27
@rmatif rmatif requested a review from jeffbolznv September 5, 2025 21:54
@rmatif
Collaborator Author

rmatif commented Sep 5, 2025

I’ve unlocked a massive improvement that goes far beyond the fine-tuned existing shaders, which would justify the existence of this one. I’ve updated the performance numbers in the OP and will soon look into reports on older devices

Turns out the compiler has some specific quirks. It generates much faster code from explicit, manually unrolled mads, so I've flattened the inner loops. While further unrolling is tempting, it drastically increases register pressure and may cause instability. I stopped here; I'm a bit short on time anyway
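
For illustration, this is the kind of flattening meant here, with hypothetical names rather than the exact PR code:

```glsl
// Before: a tiny inner loop the driver's compiler handles poorly.
[[unroll]] for (uint k = 0; k < 4; k++) {
    acc = fma(sh_a[row][k_base + k], sh_b[k_base + k][col], acc);
}

// After: the same four multiply-adds written out by hand, which this
// compiler turns into noticeably faster code.
acc = fma(sh_a[row][k_base + 0], sh_b[k_base + 0][col], acc);
acc = fma(sh_a[row][k_base + 1], sh_b[k_base + 1][col], acc);
acc = fma(sh_a[row][k_base + 2], sh_b[k_base + 2][col], acc);
acc = fma(sh_a[row][k_base + 3], sh_b[k_base + 3][col], acc);
```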

@jeffbolznv If you have a chance, could you please review again? I believe I’ve addressed your comments but let me know if I missed something

As a start we can have an environment variable

I think using an env var is better. I don’t see why it wouldn’t work on dGPU given the extreme simplicity, and honestly I don’t see the usefulness of running it there. The compilers and architectures are so different that I don’t think we’d gain much information from it

Also I'm finding it a bit hard to believe that it's faster to dequant everything, write that to regular memory, and then read all that back into shared memory for the actual multiplication. That's basically limiting everything to your memory speed, like inference

I haven’t touched the dequant part, it’s still the same. I’m writing to an f16 buffer, and the f16xf32/f32xf32 path is now much faster, hence the speedup. As I mentioned, I started experimenting with dequant in shared memory, but it turned out to be too much work for a first step. We can leave that as a future plan

If someone has a Raspberry Pi lying around, I’d be very curious to see the performance gains on that kind of device

}
}

const uint num_k_tiles = (p.K + BK - 1) / BK;
Collaborator Author

I think this is not robust enough and might be wrong for the Adreno case, but it passes the tests in test-backend-ops when I feel like it shouldn't

Contributor

Shouldn't be hard to add a case or two with odd K. I suggest having relatively small M,N to avoid the error being hidden.

Collaborator Author

I misquoted; I was thinking more about the Adreno case:

BM = 32, BK = 8 -> VEC_K = 2
WG_SIZE = 128
A_LOADS_PER_THREAD = (32 * 2) / 128 = 64 / 128 = 0

So theoretically it shouldn’t be able to load matrix A regardless of the dimensions, but the tests are passing so I’m a bit confused
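
Spelling the integer arithmetic out (constants as quoted above, with VEC_K assumed to mean vec4 loads per tile row):

```glsl
// Hypothetical Adreno configuration from the quote above:
const uint BM = 32, BK = 8;
const uint VEC_K = BK / 4;       // 2 vec4 loads cover one row of the tile
const uint WG_SIZE = 128;
// Integer division: (32 * 2) / 128 == 0, so a load loop of the form
//   for (uint i = 0; i < A_LOADS_PER_THREAD; i++) { ... }
// would run zero iterations and never fill the A tile.
const uint A_LOADS_PER_THREAD = (BM * VEC_K) / WG_SIZE;
```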

@jeffbolznv
Contributor

 I think using an env var is better. I don’t see why it wouldn’t work on dGPU given the extreme simplicity, and honestly I don’t see the usefulness of running it there. The compilers and architectures are so different that I don’t think we’d gain much information from it

It's useful to be able to test it for correctness.

Contributor

@jeffbolznv jeffbolznv left a comment

I'd still like to see this, at the least, using the same push constant and spec constant interface as the rest of the matmul shaders, and running through the same code paths in ggml_vk_mul_mat_q_f16. I think it would be nice if it could be folded into mul_mm, but I'm not sure we understand what specifically is causing the better perf.

return;
}

const std::vector<uint32_t> pc = { (uint32_t)M, (uint32_t)K, (uint32_t)K, (uint32_t)K, (uint32_t)(ggml_nelements(src0)) };
Contributor

I think this path is not handling noncontiguous src0. Like @0cc4m said, it'll be better to let this run through the existing code paths rather than having this separate code path.

}
}

const uint num_k_tiles = (p.K + BK - 1) / BK;
Contributor

Shouldn't be hard to add a case or two with odd K. I suggest having relatively small M,N to avoid the error being hidden.

@jeffbolznv
Contributor

I've made a PR to fix the failing dequant shaders.

@netrunnereve
Collaborator

Turns out the compiler has some specific quirks. It generates much faster code from explicit, manually unrolled mads, so I've flattened the inner loops. While further unrolling is tempting, it drastically increases register pressure and may cause instability.

At this point I think it's time to start looking into the Arm dev tools and assembly dumps, if they even have them 😉

@rmatif
Collaborator Author

rmatif commented Sep 8, 2025

Turns out the compiler has some specific quirks. It generates much faster code from explicit, manually unrolled mads, so I've flattened the inner loops. While further unrolling is tempting, it drastically increases register pressure and may cause instability.

At this point I think it's time to start looking into the Arm dev tools and assembly dumps, if they even have them 😉

I’ve tried reaching out to some Arm devs, so hopefully they’ll take a look at it

Speaking of compilers, the Adreno compilers crash when running the existing mul_mat shaders, although they work fine with this one, so I think we’ll need this variant anyway (performance is better than OpenCL, but I suspect a bug so I won't jump to conclusions too quickly)

Sorry, I don’t have much time to address all the concerns right now. I’ll try, when I have time, to at least implement an env var to run the shaders on dGPUs

EDIT: Just got confirmation from an Arm dev: they don’t have a public disassembler, only a profiler

const uint64_t ne12 = src1->ne[2];
const uint64_t ne13 = src1->ne[3];

if ((ctx->device->vendor_id == VK_VENDOR_ID_ARM || ctx->device->vendor_id == VK_VENDOR_ID_QUALCOMM) &&
Contributor

It also seems to have an impact with Intel integrated GPUs in some cases:

before:

| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | Vulkan     | 999 |           pp512 |         60.12 ± 0.00 |
| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | Vulkan     | 999 |           tg128 |          8.02 ± 0.00 |

| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     | 999 |           pp512 |        109.99 ± 0.00 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     | 999 |           tg128 |          8.12 ± 0.00 |

after:

| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | Vulkan     | 999 |           pp512 |         86.34 ± 0.00 |
| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | Vulkan     | 999 |           tg128 |          8.30 ± 0.00 |

| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     | 999 |           pp512 |        101.52 ± 0.00 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     | 999 |           tg128 |          8.01 ± 0.00 |

@Benjamin-Wegener

Benjamin-Wegener commented Oct 28, 2025

I managed to compile it under Termux on a Mali G52 MC2 with Vulkan 1.1 without errors. Before, I had complaints about the Vulkan version not being 1.3, or seg faults. So good work so far!

Since I am relatively new to Vulkan coding, what would be the next steps to make it faster on my device? I am getting the same speed on Qwen3 4B as on the CPU alone, offloading all 37 layers to the GPU, using no extra parameters.

EDIT:
Using my other phone with a Mali G76 MC4, I get 6 t/s with the GPU on Qwen3 1.7B Q4_0 and 4.9 t/s on the CPU.

@System64fumo

Just tested this on an RK3588 machine using PanVK. Unfortunately it crashed with a simple "Hello" prompt:

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Mali-G610 (panvk) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 16 | shared memory: 32768 | int dot: 1 | matrix cores: none
build: 1 (850b9bf) with cc (GCC) 15.2.1 20251112 for aarch64-unknown-linux-gnu
system info: n_threads = 8, n_threads_batch = 8, total_threads = 8

system_info: n_threads = 8 (n_threads_batch = 8) / 8 | CPU : NEON = 1 | ARM_FMA = 1 | FP16_VA = 1 | DOTPROD = 1 | LLAMAFILE = 1 | OPENMP = 1 | KLEIDIAI = 1 | REPACK = 1 | 

main: binding port with default address family
main: HTTP server is listening, hostname: 192.168.2.1, port: 8080, http threads: 7
main: loading model
srv    load_model: loading model '/mnt/nas/Llama-3.2-1B-Instruct-Q4_1.gguf'
llama_model_load_from_file_impl: using device Vulkan0 (Mali-G610) - 11859 MiB free
llama_model_loader: loaded meta data with 36 key-value pairs and 147 tensors from /mnt/nas/Llama-3.2-1B-Instruct-Q4_1.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Llama-3.2-1B-Instruct
llama_model_loader: - kv   3:                           general.finetune str              = Instruct
llama_model_loader: - kv   4:                           general.basename str              = Llama-3.2-1B-Instruct
llama_model_loader: - kv   5:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   6:                         general.size_label str              = 1B
llama_model_loader: - kv   7:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv   8:                          llama.block_count u32              = 16
llama_model_loader: - kv   9:                       llama.context_length u32              = 131072
llama_model_loader: - kv  10:                     llama.embedding_length u32              = 2048
llama_model_loader: - kv  11:                  llama.feed_forward_length u32              = 8192
llama_model_loader: - kv  12:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv  13:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv  14:                       llama.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv  15:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  16:                 llama.attention.key_length u32              = 64
llama_model_loader: - kv  17:               llama.attention.value_length u32              = 64
llama_model_loader: - kv  18:                           llama.vocab_size u32              = 128256
llama_model_loader: - kv  19:                 llama.rope.dimension_count u32              = 64
llama_model_loader: - kv  20:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  21:                         tokenizer.ggml.pre str              = llama-bpe
llama_model_loader: - kv  22:                      tokenizer.ggml.tokens arr[str,128256]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  23:                  tokenizer.ggml.token_type arr[i32,128256]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  24:                      tokenizer.ggml.merges arr[str,280147]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  25:                tokenizer.ggml.bos_token_id u32              = 128000
llama_model_loader: - kv  26:                tokenizer.ggml.eos_token_id u32              = 128009
llama_model_loader: - kv  27:            tokenizer.ggml.padding_token_id u32              = 128004
llama_model_loader: - kv  28:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  29:                    tokenizer.chat_template str              = {{- bos_token }}\n{%- if custom_tools ...
llama_model_loader: - kv  30:               general.quantization_version u32              = 2
llama_model_loader: - kv  31:                          general.file_type u32              = 3
llama_model_loader: - kv  32:                      quantize.imatrix.file str              = Llama-3.2-1B-Instruct-GGUF/imatrix_un...
llama_model_loader: - kv  33:                   quantize.imatrix.dataset str              = unsloth_calibration_Llama-3.2-1B-Inst...
llama_model_loader: - kv  34:             quantize.imatrix.entries_count i32              = 112
llama_model_loader: - kv  35:              quantize.imatrix.chunks_count i32              = 689
llama_model_loader: - type  f32:   34 tensors
llama_model_loader: - type q4_1:  112 tensors
llama_model_loader: - type q6_K:    1 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_1
print_info: file size   = 785.75 MiB (5.33 BPW) 
load: printing all EOG tokens:
load:   - 128001 ('<|end_of_text|>')
load:   - 128008 ('<|eom_id|>')
load:   - 128009 ('<|eot_id|>')
load: special tokens cache size = 256
load: token to piece cache size = 0.7999 MB
print_info: arch             = llama
print_info: vocab_only       = 0
print_info: n_ctx_train      = 131072
print_info: n_embd           = 2048
print_info: n_layer          = 16
print_info: n_head           = 32
print_info: n_head_kv        = 8
print_info: n_rot            = 64
print_info: n_swa            = 0
print_info: is_swa_any       = 0
print_info: n_embd_head_k    = 64
print_info: n_embd_head_v    = 64
print_info: n_gqa            = 4
print_info: n_embd_k_gqa     = 512
print_info: n_embd_v_gqa     = 512
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-05
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 8192
print_info: n_expert         = 0
print_info: n_expert_used    = 0
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 0
print_info: rope scaling     = linear
print_info: freq_base_train  = 500000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 131072
print_info: rope_finetuned   = unknown
print_info: model type       = 1B
print_info: model params     = 1.24 B
print_info: general.name     = Llama-3.2-1B-Instruct
print_info: vocab type       = BPE
print_info: n_vocab          = 128256
print_info: n_merges         = 280147
print_info: BOS token        = 128000 '<|begin_of_text|>'
print_info: EOS token        = 128009 '<|eot_id|>'
print_info: EOT token        = 128009 '<|eot_id|>'
print_info: EOM token        = 128008 '<|eom_id|>'
print_info: PAD token        = 128004 '<|finetune_right_pad_id|>'
print_info: LF token         = 198 'Ċ'
print_info: EOG token        = 128001 '<|end_of_text|>'
print_info: EOG token        = 128008 '<|eom_id|>'
print_info: EOG token        = 128009 '<|eot_id|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 16 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 17/17 layers to GPU
load_tensors:      Vulkan0 model buffer size =   785.75 MiB
load_tensors:   CPU_Mapped model buffer size =   205.49 MiB
...........................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = enabled
llama_context: kv_unified    = false
llama_context: freq_base     = 500000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_context: Vulkan_Host  output buffer size =     0.49 MiB
llama_kv_cache:    Vulkan0 KV buffer size =   128.00 MiB
llama_kv_cache: size =  128.00 MiB (  4096 cells,  16 layers,  1/1 seqs), K (f16):   64.00 MiB, V (f16):   64.00 MiB
llama_context:    Vulkan0 compute buffer size =   254.50 MiB
llama_context: Vulkan_Host compute buffer size =    12.02 MiB
llama_context: graph nodes  = 503
llama_context: graph splits = 2
common_init_from_params: added <|end_of_text|> logit bias = -inf
common_init_from_params: added <|eom_id|> logit bias = -inf
common_init_from_params: added <|eot_id|> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
srv          init: initializing slots, n_slots = 1
slot         init: id  0 | task -1 | new slot n_ctx_slot = 4096
main: model loaded
main: chat template, chat_template: {{- bos_token }}
{%- if custom_tools is defined %}
    {%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
    {%- set tools_in_user_message = true %}
{%- endif %}
{%- if not date_string is defined %}
    {%- if strftime_now is defined %}
        {%- set date_string = strftime_now("%d %b %Y") %}
    {%- else %}
        {%- set date_string = "26 Jul 2024" %}
    {%- endif %}
{%- endif %}
{%- if not tools is defined %}
    {%- set tools = none %}
{%- endif %}

{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content']|trim %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "" %}
{%- endif %}

{#- System message #}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if tools is not none %}
    {{- "Environment: ipython\n" }}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n" }}
{{- "Today Date: " + date_string + "\n\n" }}
{%- if tools is not none and not tools_in_user_message %}
    {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
    {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
    {{- "Do not use variables.\n\n" }}
    {%- for t in tools %}
        {{- t | tojson(indent=4) }}
        {{- "\n\n" }}
    {%- endfor %}
{%- endif %}
{{- system_message }}
{{- "<|eot_id|>" }}

{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and not tools is none %}
    {#- Extract the first user message so we can plug it in here #}
    {%- if messages | length != 0 %}
        {%- set first_user_message = messages[0]['content']|trim %}
        {%- set messages = messages[1:] %}
    {%- else %}
        {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
    {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
    {{- "Given the following functions, please respond with a JSON for a function call " }}
    {{- "with its proper arguments that best answers the given prompt.\n\n" }}
    {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
    {{- "Do not use variables.\n\n" }}
    {%- for t in tools %}
        {{- t | tojson(indent=4) }}
        {{- "\n\n" }}
    {%- endfor %}
    {{- first_user_message + "<|eot_id|>"}}
{%- endif %}

{%- for message in messages %}
    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
    {%- elif 'tool_calls' in message %}
        {%- if not message.tool_calls|length == 1 %}
            {{- raise_exception("This model only supports single tool-calls at once!") }}
        {%- endif %}
        {%- set tool_call = message.tool_calls[0].function %}
        {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
        {{- '{"name": "' + tool_call.name + '", ' }}
        {{- '"parameters": ' }}
        {{- tool_call.arguments | tojson }}
        {{- "}" }}
        {{- "<|eot_id|>" }}
    {%- elif message.role == "tool" or message.role == "ipython" %}
        {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
        {%- if message.content is mapping or message.content is iterable %}
            {{- message.content | tojson }}
        {%- else %}
            {{- message.content }}
        {%- endif %}
        {{- "<|eot_id|>" }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
, example_format: '<|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Hi there<|eot_id|><|start_header_id|>user<|end_header_id|>

How are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

'
main: server is listening on http://192.168.2.1:8080 - starting the main loop
srv  update_slots: all slots are idle
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 4096, n_keep = 0, n_prompt_tokens = 11
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 11, n_tokens = 11, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 11, n_tokens = 11
[New LWP 8512]
[New LWP 8511]
[New LWP 8510]
[New LWP 8509]
[New LWP 8508]
[New LWP 8507]
[New LWP 8506]
[New LWP 8505]
[New LWP 8504]
[New LWP 8503]
[New LWP 8502]
[New LWP 8501]

This GDB supports auto-downloading debuginfo from the following URLs:
  <https://debuginfod.artixlinux.org>
Enable debuginfod for this session? (y or [n]) [answered N; input not from terminal]
Debuginfod has been disabled.
To make this setting permanent, add 'set debuginfod enabled off' to .gdbinit.
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib/libthread_db.so.1".
0x0000ffff88d03d68 in ?? () from /usr/lib/libc.so.6
#0  0x0000ffff88d03d68 in ?? () from /usr/lib/libc.so.6
#1  0x0000ffff88cf6bd4 in ?? () from /usr/lib/libc.so.6
#2  0x0000ffff88cf6c18 in ?? () from /usr/lib/libc.so.6
#3  0x0000ffff88d4eda0 in wait4 () from /usr/lib/libc.so.6
#4  0x0000ffff89252910 in ggml_print_backtrace () from /mnt/nas/llama.cpp/build/bin/libggml-base.so
#5  0x0000ffff8926406c in ggml_uncaught_exception() () from /mnt/nas/llama.cpp/build/bin/libggml-base.so
#6  0x0000ffff8900f1ac in __cxxabiv1::__terminate (handler=<optimized out>) at /usr/src/debug/gcc/gcc/libstdc++-v3/libsupc++/eh_terminate.cc:48
warning: 48	/usr/src/debug/gcc/gcc/libstdc++-v3/libsupc++/eh_terminate.cc: No such file or directory
#7  0x0000ffff8900366c in std::terminate () at /usr/src/debug/gcc/gcc/libstdc++-v3/libsupc++/eh_terminate.cc:58
58	in /usr/src/debug/gcc/gcc/libstdc++-v3/libsupc++/eh_terminate.cc
#8  0x0000ffff8900f5c8 in __cxxabiv1::__cxa_throw (obj=<optimized out>, tinfo=0xffff8946c128 <typeinfo for vk::DeviceLostError>, dest=0xffff8941a7e0 <vk::DeviceLostError::~DeviceLostError()>) at /usr/src/debug/gcc/gcc/libstdc++-v3/libsupc++/eh_throw.cc:98
warning: 98	/usr/src/debug/gcc/gcc/libstdc++-v3/libsupc++/eh_throw.cc: No such file or directory
#9  0x0000ffff893433f4 in ggml_vk_wait_for_fence(ggml_backend_vk_context*) () from /mnt/nas/llama.cpp/build/bin/libggml-vulkan.so
#10 0x0000ffff894072d8 in ggml_vk_build_graph(ggml_backend_vk_context*, ggml_cgraph*, int, ggml_tensor*, int, bool, bool, bool, bool) () from /mnt/nas/llama.cpp/build/bin/libggml-vulkan.so
#11 0x0000ffff89408614 in ggml_backend_vk_graph_compute(ggml_backend*, ggml_cgraph*) () from /mnt/nas/llama.cpp/build/bin/libggml-vulkan.so
#12 0x0000ffff8926c6bc in ggml_backend_sched_graph_compute_async () from /mnt/nas/llama.cpp/build/bin/libggml-base.so
#13 0x0000ffff8bfaa6e0 in llama_context::graph_compute(ggml_cgraph*, bool) () from /mnt/nas/llama.cpp/build/bin/libllama.so
#14 0x0000ffff8bfac008 in llama_context::process_ubatch(llama_ubatch const&, llm_graph_type, llama_memory_context_i*, ggml_status&) () from /mnt/nas/llama.cpp/build/bin/libllama.so
#15 0x0000ffff8bfb1064 in llama_context::decode(llama_batch const&) () from /mnt/nas/llama.cpp/build/bin/libllama.so
#16 0x0000ffff8bfb1f0c in llama_decode () from /mnt/nas/llama.cpp/build/bin/libllama.so
#17 0x0000aaaad975db54 in server_context::update_slots() ()
#18 0x0000aaaad972f994 in server_queue::start_loop() ()
#19 0x0000aaaad96e83f4 in main ()

@SuperPauly

@rmatif Any updates?

@rmatif
Collaborator Author

rmatif commented Jan 12, 2026

@rmatif Any updates?

Sorry, I got super busy these days, but I’ll definitely come back to this at some point, just not sure when

@SuperPauly

@rmatif Any updates?

Sorry, I got super busy these days, but I’ll definitely come back to this at some point, just not sure when

Ah, it's okay 😀 we all have our own lives to lead. I wish I was smart enough to continue your work; maybe someone else will pick it up and continue, you never know!

I've hurt my back and been bed-bound for a couple of years, and I can't use a laptop or PC since sitting and standing become painful after a few minutes, so I'm in bed trying to do what I normally would, like coding and DevOps stuff, but only using a Pixel 8 Pro phone.

So playing with on-device LLMs and watching them improve over the past two years has been great! I love staying up to date with Vulkan/Mali 720 speed improvements like this, as Snapdragon seems to get all the attention!

I was wondering about the status of this PR because I was curious about the TPS increase from this PR and PR #18493.

I'm curious, what's left to do to complete this PR? Was the only remaining issue the compiler's quirky behaviour?

@Gong-Mi

Gong-Mi commented Jan 19, 2026

Hello! I saw this PR is targeting embedded GPUs.

I wanted to share some baseline benchmark data from an Adreno 840 (Snapdragon 8 Elite) using Gong-Mi/vkpeak. This fork includes vec2 and dual variant metrics which are useful for identifying peak capabilities on modern high-end embedded Vulkan devices.

Device Info:

  • Device: Adreno (TM) 840 (Snapdragon 8 Elite)
  • Driver: Qualcomm Technologies Inc. Adreno Vulkan Driver (Build: f4e1c113f2)

Key Metrics:

  • FP32: ~2784 GFLOPS (Scalar) / ~3193 GFLOPS (Vec8)
  • FP16: ~3148 GFLOPS (Vec2)
  • INT8: ~25676 GIOPS (Vec4 Arithmetic) / ~6771 GIOPS (DotProduct)
  • Bandwidth (D2D): ~26.34 GBPS
Full vkpeak Output
device       = Adreno (TM) 840
driver       = Qualcomm Technologies Inc. Adreno Vulkan Driver / Driver Build: f4e1c113f2, I2a5dcc3b1e, 1765377043
Date: 12/10/25
Compiler Version: E031.50.19.13

fp32-scalar  = 2784.41 GFLOPS
fp32-dual    = 2776.76 GFLOPS (Scalar x2)
fp32-vec2    = 3116.94 GFLOPS
fp32-vec2-x2 = 2167.96 GFLOPS (Vec2 x2)
fp32-vec4    = 2201.77 GFLOPS
fp32-vec8    = 3193.24 GFLOPS

fp16-scalar  = 1832.25 GFLOPS
fp16-dual    = 2784.80 GFLOPS (Scalar x2)
fp16-vec2    = 3148.28 GFLOPS
fp16-vec2-x2 = 3146.84 GFLOPS (Vec2 x2)
fp16-vec4    = 1880.60 GFLOPS
fp16-vec8    = 2505.88 GFLOPS

int32-scalar =  800.21 GIOPS
int32-dual   =  799.40 GIOPS (Scalar x2)
int32-vec2   =  664.36 GIOPS
int32-vec2-x2=  624.49 GIOPS (Vec2 x2)
int32-vec4   =  772.72 GIOPS
int32-vec8   =  626.75 GIOPS

int16-scalar = 2784.34 GIOPS
int16-dual   = 2790.33 GIOPS (Scalar x2)
int16-vec2   = 3097.40 GIOPS
int16-vec2-x2= 2844.31 GIOPS (Vec2 x2)
int16-vec4   =  987.83 GIOPS
int16-vec8   = 1597.55 GIOPS

int64-scalar =  203.34 GIOPS
int64-dual   =  165.06 GIOPS (Scalar x2)
int64-vec2   =  178.29 GIOPS
int64-vec2-x2=  160.14 GIOPS (Vec2 x2)
int64-vec4   =  156.03 GIOPS
int64-vec8   =  146.77 GIOPS

int8-scalar  =  976.09 GIOPS
int8-dual    =  976.54 GIOPS (Scalar x2)
int8-vec2    =  986.47 GIOPS
int8-vec2-x2 =  987.18 GIOPS (Vec2 x2)
int8-vec4    = 25676.46 GIOPS (Arithmetic)
int8-dotprod = 6771.81 GIOPS
int8-vec8    = 11582.07 GIOPS

copy-h2h     = 15.97 GBPS
copy-h2d     = 18.10 GBPS
copy-d2h     = 19.80 GBPS
copy-d2d     = 26.34 GBPS

Is this kind of peak capability data useful for your embedded GPU optimization efforts? Just wanted to share in case it helps!

@Gong-Mi

Gong-Mi commented Jan 19, 2026

Quick note regarding the Layout[] Node debug dumps in the output:

These artifacts suggest that the Adreno 840 Vulkan Driver (on the Snapdragon 8 Elite) is struggling with certain advanced shader features (likely cooperative matrix or specific extensions), despite claiming support.

It would be advisable to avoid using these specific Vulkan extensions/paths on this hardware for now to ensure stability. While the scalar and vector metrics are solid, the driver's handling of matrix operations seems immature and might lead to crashes or undefined behavior in a real production workload like llama.cpp.

@Benjamin-Wegener

These are Tensor G4 (Pixel 9a) results:

device = Mali-G715
driver = Mali-G715 / v1.r54p1-11eac0.22cd999d6ad92369434ede032544ba86

fp32-scalar = 1436.36 GFLOPS
fp32-vec4 = 1484.59 GFLOPS

fp16-scalar = 1419.75 GFLOPS
fp16-vec4 = 2861.68 GFLOPS
fp16-matrix = 1636.29 GFLOPS

fp64-scalar = 0.00 GFLOPS
fp64-vec4 = 0.00 GFLOPS

int32-scalar = 207.53 GIOPS
int32-vec4 = 198.12 GIOPS

int16-scalar = 207.61 GIOPS
int16-vec4 = 414.74 GIOPS

int64-scalar = 49.36 GIOPS
int64-vec4 = 49.31 GIOPS

int8-dotprod = 6300.69 GIOPS
int8-matrix = 6476.82 GIOPS

bf16-dotprod = 0.00 GFLOPS
bf16-matrix = 0.00 GFLOPS

fp8-matrix = 0.00 GFLOPS
bf8-matrix = 0.00 GFLOPS

copy-h2h = 13.09 GBPS
copy-h2d = 14.61 GBPS
copy-d2h = 1.23 GBPS
copy-d2d = 19.25 GBPS

---- gong-mi version ----

vkpeak 20240505
[Disclaimer] This tool measures peak throughput. Results are NOT verified for correctness.

fp32-scalar = 1425.40 GFLOPS
fp32-dual = 1422.02 GFLOPS (Scalar x2)
fp32-vec2 = 1416.61 GFLOPS
fp32-vec2-x2 = 1334.48 GFLOPS (Vec2 x2)
fp32-vec4 = 1483.97 GFLOPS
fp32-vec8 = 1480.85 GFLOPS

fp16-scalar = 1417.65 GFLOPS
fp16-dual = 1414.00 GFLOPS (Scalar x2)
fp16-vec2 = 2841.17 GFLOPS
fp16-vec2-x2 = 2842.97 GFLOPS (Vec2 x2)
fp16-vec4 = 2869.99 GFLOPS
fp16-vec8 = 2897.77 GFLOPS
fp16-matrix = 1635.78 GFLOPS

int32-scalar = 208.09 GIOPS
int32-vec4 = 198.28 GIOPS
int32-vec8 = 207.85 GIOPS

int16-scalar = 207.49 GIOPS
int16-vec2 = 413.16 GIOPS
int16-vec8 = 416.22 GIOPS

int64-scalar = 49.18 GIOPS
int64-vec4 = 49.38 GIOPS

int8-scalar = 207.11 GIOPS
int8-general = 587.79 GIOPS (General ALU, no dotprod)
int8-dotprod = 6291.52 GIOPS
int8-matrix = 6477.57 GIOPS

copy-h2h = 12.07 GBPS
copy-h2d = 13.62 GBPS
copy-d2h = 1.87 GBPS
copy-d2d = 19.09 GBPS

@Gong-Mi

Gong-Mi commented Feb 4, 2026

Try to enable the zero-copy feature of Android; you may be able to reduce memory usage.

@Benjamin-Wegener

Benjamin-Wegener commented Feb 5, 2026

The build process properly utilized the Android-specific Vulkan extensions including VK_ANDROID_external_memory_android_hardware_buffer which enables zero-copy functionality on Android systems.

The memory copy performance metrics show:
 - Host to host (copy-h2h): 16.94 GBPS
 - Host to device (copy-h2d): 12.54 GBPS
 - Device to host (copy-d2h): 1.24 GBPS
 - Device to device (copy-d2d): 20.41 GBPS
