ggml: backend-agnostic tensor parallelism (experimental) #19378

Merged
JohannesGaessler merged 56 commits into ggml-org:master from JohannesGaessler:ggml-meta-backend-8 on Apr 9, 2026
Conversation

@JohannesGaessler
Contributor

@JohannesGaessler JohannesGaessler commented Feb 5, 2026

This PR adds initial support for tensor parallelism, enabled by specifying --split-mode tensor. This should be considered an experimental feature that is not yet production-ready. In principle the implementation is backend-agnostic; in practice, as of right now only the CUDA backend has received the extensions and performance optimizations necessary to make the performance better than --split-mode layer (in some cases).

The preexisting --split-mode row could already parallelize some matrix multiplications in the CUDA backend, but this required a synchronization after every single operation. As a consequence the overhead is so large that it is only really worthwhile for old and slow GPUs like P40s, where adding a bit of latency between operations makes relatively little difference to the overall runtime. The new implementation works by adding a new "meta" backend that internally wraps multiple conventional ggml backends. When given a compute graph, the meta backend automatically infers how the data is split and schedules synchronizations only at the necessary points. The external interface for a meta backend is the same as for any other ggml backend, so in practice the meta backend allows ggml to use multiple GPUs in the same way as a single GPU. Importantly, all of this is done at the ggml backend level and there are no hard dependencies on any extensions beyond what already exists on master (but without extensions the performance may be so bad that there is no point).
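For illustration, here is a caller-side sketch of what "same interface" means in practice. The constructor name below is hypothetical; the point is that the resulting object satisfies the ordinary ggml_backend_t interface, so the compute and free calls are the standard ggml API:

```c
// Caller-side sketch, not the PR's actual code. ggml_backend_meta_init is a
// hypothetical constructor; graph_compute and free are the standard ggml API.
#include "ggml-backend.h"

// hypothetical signature: wrap several simple devices into one meta backend
ggml_backend_t ggml_backend_meta_init(ggml_backend_dev_t * devs, int n_devs);

void compute_on_meta(ggml_backend_dev_t * simple_devs, int n_devs,
                     struct ggml_cgraph * graph) {
    ggml_backend_t meta = ggml_backend_meta_init(simple_devs, n_devs);

    // from here on the meta backend behaves like a single-GPU backend: split
    // states and synchronization points are inferred from the graph internally
    ggml_backend_graph_compute(meta, graph);

    ggml_backend_free(meta);
}
```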

Current Support Status

What currently works:

  • Multiple CUDA GPUs work.
  • For good performance, make sure that NCCL is installed.
  • Performance should become comparatively better for deeper contexts where there is more work for each individual GPU.
  • The "ROCm" backend works since it is just the CUDA code translated via HIP. On the hardware combinations that I have (RX 6800 + MI50 or RX 9060 XT + MI100) the performance is bad vs. the -sm layer baseline though.
  • Any number of GPUs is supported (limited to an arbitrary maximum of 16 GPUs for practical reasons, ping me if you have a use case that needs a higher limit).
  • The preexisting argument --tensor-split can be used to distribute arbitrary fractions of the data across GPUs.
  • In my testing --n-gpu-layers was working correctly in combination with -sm tensor, though I suspect that there are some scenarios where this combination is broken.
  • Most models should work correctly with --split-mode tensor, for more details see below.

What currently doesn't work:

  • There is currently an issue with a VRAM leak due to CUDA graphs being continually created but never destroyed (negligible for -sm layer but not for -sm tensor).
  • Vulkan technically works at short contexts, but the performance is bad; at long contexts there are also stability issues.
  • All other backends may work but should be assumed to be broken or unusable by default.
  • Going forward, parallelization across NUMA nodes for better CPU performance is planned; as of right now there is no support.
  • FlashAttention must be enabled and KV cache quantization must be disabled; otherwise the meta backend will be unable to correctly infer the data splits between GPUs.
  • Support for -fit is not implemented; you may need to set --ctx-size manually.
  • Support for backend sampling is not implemented; I'm not yet sure what the best way to go about it is.
  • Memory use in general is inefficient due to overallocations.

Model support

For quality control of this PR I added end-to-end tests in test-llama-archs to assert that --split-mode tensor produces results consistent with the CPU backend. Based on these tests most of the model architectures in llama.cpp are working correctly. However, the following model architectures are definitely broken with -sm tensor:

Details
  • bitnet
  • deepseek2
  • falcon_h1
  • gemma3n
  • glm_dsa
  • granite_hybrid
  • grok
  • jamba
  • kimi_linear
  • lfm2
  • lfm2moe
  • mamba
  • mamba2
  • minicpm
  • minimax_m2
  • mistral4
  • mpt
  • nemotron_h
  • nemotron_h_moe
  • olmo2
  • olmoe
  • plamo2
  • t5

Furthermore, for the following model architectures the tests themselves are broken and they probably also will not work:

Details
  • arwkv7
  • bert
  • deepseek2ocr
  • eurobert
  • embed
  • gemma_embedding
  • jina_bert_v2
  • jina_bert_v3
  • modern_bert
  • neo_bert
  • nomic_bert_moe
  • plm
  • rwkv6
  • rwkv6qwen2
  • rwkv7
  • t5encoder

Performance

Performance on 2x RTX 4090
| model | test | t/s -sm layer | t/s -sm tensor | Speedup |
| --- | --- | ---: | ---: | ---: |
| gemma4 26B A4B Q4_0 | pp512 | 8682.29 | 3428.43 | 0.39 |
| gemma4 26B A4B Q4_0 | pp2048 | 12917.31 | 3245.69 | 0.25 |
| gemma4 26B A4B Q4_0 | tg128 | 197.11 | 138.76 | 0.70 |
| gemma4 26B A4B Q4_0 | pp512 @ d65536 | 3892.61 | 2507.52 | 0.64 |
| gemma4 26B A4B Q4_0 | pp2048 @ d65536 | 6714.55 | 2575.19 | 0.38 |
| gemma4 26B A4B Q4_0 | tg128 @ d65536 | 146.11 | 118.64 | 0.81 |
| gemma4 31B Q4_K_M | pp512 | 2998.02 | 2846.09 | 0.95 |
| gemma4 31B Q4_K_M | pp2048 | 4475.39 | 2725.37 | 0.61 |
| gemma4 31B Q4_K_M | tg128 | 43.81 | 59.45 | 1.36 |
| gemma4 31B Q4_K_M | pp512 @ d65536 | 1308.09 | 1706.04 | 1.30 |
| gemma4 31B Q4_K_M | pp2048 @ d65536 | 2100.01 | 1760.90 | 0.84 |
| gemma4 31B Q4_K_M | tg128 @ d65536 | 34.03 | 48.80 | 1.43 |
| gemma4 31B Q8_0 | pp512 | 2736.66 | 2877.47 | 1.05 |
| gemma4 31B Q8_0 | pp2048 | 4117.76 | 2708.27 | 0.66 |
| gemma4 31B Q8_0 | tg128 | 26.74 | 40.98 | 1.53 |
| gemma4 31B Q8_0 | pp512 @ d65536 | 1259.75 | 1696.77 | 1.35 |
| gemma4 31B Q8_0 | pp2048 @ d65536 | 2013.56 | 1751.84 | 0.87 |
| gemma4 31B Q8_0 | tg128 @ d65536 | 22.75 | 35.72 | 1.57 |
| gpt-oss 20B MXFP4 MoE | pp512 | 9552.27 | 9298.15 | 0.97 |
| gpt-oss 20B MXFP4 MoE | pp2048 | 13987.72 | 8734.39 | 0.62 |
| gpt-oss 20B MXFP4 MoE | tg128 | 260.27 | 269.68 | 1.04 |
| gpt-oss 20B MXFP4 MoE | pp512 @ d65536 | 4331.62 | 5209.42 | 1.20 |
| gpt-oss 20B MXFP4 MoE | pp2048 @ d65536 | 7116.87 | 5563.52 | 0.78 |
| gpt-oss 20B MXFP4 MoE | tg128 @ d65536 | 172.17 | 204.50 | 1.19 |
| llama 1B F16 | pp512 | 47450.84 | 30410.70 | 0.64 |
| llama 1B F16 | pp2048 | 73060.68 | 28377.32 | 0.39 |
| llama 1B F16 | tg128 | 324.01 | 374.62 | 1.16 |
| llama 1B F16 | pp512 @ d65536 | 8604.25 | 9659.83 | 1.12 |
| llama 1B F16 | pp2048 @ d65536 | 15157.90 | 11101.89 | 0.73 |
| llama 1B F16 | tg128 @ d65536 | 181.99 | 248.22 | 1.36 |
| llama 8B F16 | pp512 | 10624.32 | 8119.05 | 0.76 |
| llama 8B F16 | pp2048 | 15854.74 | 7654.63 | 0.48 |
| llama 8B F16 | tg128 | 60.28 | 96.63 | 1.60 |
| llama 8B F16 | pp512 @ d65536 | 3043.24 | 3779.64 | 1.24 |
| llama 8B F16 | pp2048 @ d65536 | 4791.03 | 4032.86 | 0.84 |
| llama 8B F16 | tg128 @ d65536 | 38.81 | 66.63 | 1.72 |
| qwen35 27B Q4_0 | pp512 | 3221.77 | 2927.23 | 0.91 |
| qwen35 27B Q4_0 | pp2048 | 4286.78 | 2712.77 | 0.63 |
| qwen35 27B Q4_0 | tg128 | 51.06 | 61.82 | 1.21 |
| qwen35 27B Q4_0 | pp512 @ d65536 | 1843.62 | 2083.90 | 1.13 |
| qwen35 27B Q4_0 | pp2048 @ d65536 | 2486.88 | 2080.78 | 0.84 |
| qwen35 27B Q4_0 | tg128 @ d65536 | 41.25 | 53.48 | 1.30 |
| qwen35 27B Q8_0 | pp512 | 2767.40 | 2839.12 | 1.03 |
| qwen35 27B Q8_0 | pp2048 | 3727.66 | 2639.67 | 0.71 |
| qwen35 27B Q8_0 | tg128 | 30.61 | 43.03 | 1.41 |
| qwen35 27B Q8_0 | pp512 @ d65536 | 1692.06 | 2044.16 | 1.21 |
| qwen35 27B Q8_0 | pp2048 @ d65536 | 2295.14 | 2038.38 | 0.89 |
| qwen35 27B Q8_0 | tg128 @ d65536 | 26.78 | 38.54 | 1.44 |
| qwen35moe 35B.A3B Q4_0 | pp512 | 7347.60 | 7161.24 | 0.97 |
| qwen35moe 35B.A3B Q4_0 | pp2048 | 9872.26 | 6217.26 | 0.63 |
| qwen35moe 35B.A3B Q4_0 | tg128 | 205.74 | 141.41 | 0.69 |
| qwen35moe 35B.A3B Q4_0 | pp512 @ d65536 | 4139.60 | 4654.36 | 1.12 |
| qwen35moe 35B.A3B Q4_0 | pp2048 @ d65536 | 5500.91 | 4644.75 | 0.84 |
| qwen35moe 35B.A3B Q4_0 | tg128 @ d65536 | 154.43 | 125.49 | 0.81 |
| qwen35moe 35B.A3B Q8_0 | pp512 | 6947.57 | 6989.91 | 1.01 |
| qwen35moe 35B.A3B Q8_0 | pp2048 | 9451.57 | 6068.24 | 0.64 |
| qwen35moe 35B.A3B Q8_0 | tg128 | 158.23 | 126.58 | 0.80 |
| qwen35moe 35B.A3B Q8_0 | pp512 @ d65536 | 4019.37 | 4429.08 | 1.10 |
| qwen35moe 35B.A3B Q8_0 | pp2048 @ d65536 | 5318.73 | 4534.64 | 0.85 |
| qwen35moe 35B.A3B Q8_0 | tg128 @ d65536 | 126.66 | 107.34 | 0.85 |
Performance on 4x RTX 4090
| model | test | t/s -sm layer | t/s -sm tensor | Speedup |
| --- | --- | ---: | ---: | ---: |
| gpt-oss 120B MXFP4 MoE | pp512 | 3822.24 | 5446.21 | 1.42 |
| gpt-oss 120B MXFP4 MoE | pp8192 | 4675.65 | 5211.46 | 1.11 |
| gpt-oss 120B MXFP4 MoE | tg128 | 180.36 | 179.79 | 1.00 |
| gpt-oss 120B MXFP4 MoE | pp512 @ d65536 | 2308.57 | 3538.01 | 1.53 |
| gpt-oss 120B MXFP4 MoE | pp8192 @ d65536 | 3034.32 | 3762.05 | 1.24 |
| gpt-oss 120B MXFP4 MoE | tg128 @ d65536 | 118.65 | 147.89 | 1.25 |
| llama 1B F16 | pp512 | 44766.52 | 18753.39 | 0.42 |
| llama 1B F16 | pp8192 | 50398.93 | 17823.13 | 0.35 |
| llama 1B F16 | tg128 | 314.75 | 409.87 | 1.30 |
| llama 1B F16 | pp512 @ d65536 | 7528.70 | 7263.59 | 0.96 |
| llama 1B F16 | pp8192 @ d65536 | 11519.92 | 7805.30 | 0.68 |
| llama 1B F16 | tg128 @ d65536 | 178.13 | 307.56 | 1.73 |
| llama 8B F16 | pp512 | 10409.04 | 4262.81 | 0.41 |
| llama 8B F16 | pp8192 | 11816.43 | 4112.37 | 0.35 |
| llama 8B F16 | tg128 | 59.88 | 143.65 | 2.40 |
| llama 8B F16 | pp512 @ d65536 | 2884.76 | 2707.52 | 0.94 |
| llama 8B F16 | pp8192 @ d65536 | 3844.32 | 2915.65 | 0.76 |
| llama 8B F16 | tg128 @ d65536 | 38.65 | 104.49 | 2.70 |
| llama 70B Q4_K_M | pp512 | 1439.14 | 1313.79 | 0.91 |
| llama 70B Q4_K_M | pp8192 | 1701.36 | 1264.31 | 0.74 |
| llama 70B Q4_K_M | tg128 | 21.25 | 48.31 | 2.27 |
| llama 70B Q4_K_M | pp512 @ d65536 | 574.89 | 929.93 | 1.62 |
| llama 70B Q4_K_M | pp8192 @ d65536 | 726.52 | 922.87 | 1.27 |
| llama 70B Q4_K_M | tg128 @ d65536 | 14.26 | 36.68 | 2.57 |
| glm4moe 106B.A12B IQ4_XS | pp512 | 2664.37 | 2983.24 | 1.12 |
| glm4moe 106B.A12B IQ4_XS | pp8192 | 2931.28 | 2762.00 | 0.94 |
| glm4moe 106B.A12B IQ4_XS | tg128 | 98.64 | 115.14 | 1.17 |
| glm4moe 106B.A12B IQ4_XS | pp512 @ d65536 | 653.92 | 1400.18 | 2.14 |
| glm4moe 106B.A12B IQ4_XS | pp8192 @ d65536 | 799.05 | 1435.96 | 1.80 |
| glm4moe 106B.A12B IQ4_XS | tg128 @ d65536 | 42.64 | 79.20 | 1.86 |

All GPUs are connected via 16x PCIe 4.0. pp512 is representative of single-GPU performance since pipelining is impossible, pp2048 and pp8192 are representative of multi-GPU performance with pipelining.

By default models are run with --split-mode layer, which runs the GPUs sequentially. For the prompt this is efficient because the tokens can be pipelined with minimal synchronization overhead; however, when generating tokens for a single concurrent request there is no speedup because the GPUs have to run sequentially. --split-mode tensor works for all cases, but the synchronization overhead may be prohibitive. The key factor is the base runtime vs. the synchronization overhead: if the model runs slowly by default, adding a bit of overhead hurts less; if it is already fast, even a small amount of added overhead is really noticeable. So --split-mode tensor works comparatively better for slow GPUs with fast interconnects running large, dense models at large quantizations and deep contexts, and comparatively worse for fast GPUs with slow interconnects running small, sparse models at small quantizations and shallow contexts. Similarly, because the runtime per token is much shorter for the prompt than for generation, the speedup for generation is also much better than for the prompt.
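As a rough model of this trade-off (a back-of-envelope sketch, not a fit to the numbers above): with $N$ GPUs, a per-token base runtime $T$, and a fixed per-token synchronization cost $T_{\text{sync}}$,

$$
\text{speedup} \approx \frac{T}{T/N + T_{\text{sync}}},
$$

which approaches $N$ when $T \gg T_{\text{sync}}$ (slow GPUs, large dense models, deep contexts) and drops below 1 once $T_{\text{sync}} > T\,(1 - 1/N)$ (fast GPUs, small models, slow interconnect).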

Original PR description

Details

This PR adds support for backend-agnostic tensor parallelism, enabled via specifying `--split-mode tensor`. This is done by adding a new "meta" backend that internally wraps multiple "simple" backends but can be used in the same way as a regular ggml backend.

ggml Backend Interface Changes

This PR extends the ggml backend interface with some new functions for tensor copying:

  • set_tensor_2d_async/get_tensor_2d_async, which are equivalent to cudaMemcpy2DAsync in CUDA. This is not needed for the computation of the meta backend itself but rather for setting/getting weights or the output. Currently not implemented; as a workaround the one-dimensional version is used in a loop (see the sketch after this list).
  • shfl_tensor_async to allow two ggml backends to exchange two tensors and to synchronize on the completion of the exchange. As a fallback cpy_tensor_async can be used but this has a higher latency because the copy in one direction can only start once the one in the other direction has finished. Needed for a generic AllReduce between ggml backends. Implemented.
  • allreduce_tensor_async to allow ggml backends to specify a backend-specific way to do an AllReduce operation. Intended to be used for NCCL support in cooperation with @gaugarg-nv . Not yet implemented.
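As an illustration of the fallback for the first item, a minimal sketch of emulating the 2D copy with one 1-D async copy per row (`ggml_backend_tensor_set_async` is the existing ggml call; the wrapper name and parameters are illustrative, not the PR's actual code):

```c
// Minimal sketch: emulate set_tensor_2d_async by issuing one 1-D async copy
// per row via the existing ggml API.
#include "ggml-backend.h"

static void set_tensor_2d_async_fallback(
        ggml_backend_t backend, struct ggml_tensor * t,
        const void * src, size_t src_pitch,   // bytes between source rows
        size_t dst_offset, size_t dst_pitch,  // bytes between destination rows
        size_t row_bytes, int64_t n_rows) {
    for (int64_t i = 0; i < n_rows; i++) {
        // one call per row: correct, but with more launch overhead than a
        // true 2D copy, which is why a native implementation is desirable
        ggml_backend_tensor_set_async(backend, t,
                (const char *) src + i*src_pitch,
                dst_offset + i*dst_pitch, row_bytes);
    }
}
```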

@slaren please provide feedback regarding whether you agree that these operations should be in the ggml backend interface. For context, all of them are optional and can use existing operations as a fallback.

Meta Backend Implementation Details

The meta backend implements an entire ggml backend stack starting from a meta device. The meta device is created from multiple simple backend devices as well as a function to determine how the data should be split across devices ("split states"). Backend buffer types, buffers, and backends are created as per usual. When calling ggml_backend_graph_compute the code infers the split states of the nodes in the compute graph based on the split states assigned for the weights/kv cache. The basic pattern is to make all tensors mirrored by default. For the weight matrices, do a split in dimension 1, then a split in dimension 0, then an AllReduce. For a transformer this means two AllReduce operations, one after the attention and one after the FFN. The attention is effectively split by dimension 0, which equates to a split by attention head.
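In conventional matrix notation this is the familiar tensor-parallel factorization; for an FFN $Y = f(XA)\,B$ with an element-wise activation $f$, split across two devices:

$$
A = \begin{bmatrix} A_1 & A_2 \end{bmatrix}, \qquad
B = \begin{bmatrix} B_1 \\ B_2 \end{bmatrix}, \qquad
Y = f(XA_1)\,B_1 + f(XA_2)\,B_2 .
$$

Device $i$ computes its term $f(XA_i)\,B_i$ entirely locally; the single sum at the end is the AllReduce.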

A generic AllReduce operation is performed in the meta backend by splitting the graph into subgraphs. After a subgraph is executed, shfl_tensor_async is called to make the backends exchange partial results, and they then execute auxiliary graphs that contain only a GGML_ADD operation to combine the results.
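Schematically, for two wrapped backends (the signature of shfl_tensor_async is assumed here, and the helper graphs are illustrative, not the PR's actual code):

```c
// Sketch of the generic AllReduce flow described above, for two backends.
#include "ggml-backend.h"

// assumed shape of the new interface hook; the real signature may differ
void shfl_tensor_async(ggml_backend_t b0, struct ggml_tensor * send0, struct ggml_tensor * recv0,
                       ggml_backend_t b1, struct ggml_tensor * send1, struct ggml_tensor * recv1);

void meta_allreduce_step(ggml_backend_t b0, ggml_backend_t b1,
                         struct ggml_tensor * part0, struct ggml_tensor * part1,
                         struct ggml_tensor * recv0, struct ggml_tensor * recv1,
                         struct ggml_cgraph  * add0, struct ggml_cgraph  * add1) {
    // 1. both backends have finished their subgraph; part0/part1 hold the
    //    partial results of the tensor being reduced

    // 2. exchange the partials in both directions at once and synchronize on
    //    completion (fallback: two cpy_tensor_async calls, serialized)
    shfl_tensor_async(b0, part0, recv0, b1, part1, recv1);

    // 3. each backend runs an auxiliary graph with a single GGML_ADD,
    //    e.g. full0 = part0 + recv0, so both now hold the complete sum
    ggml_backend_graph_compute(b0, add0);
    ggml_backend_graph_compute(b1, add1);
}
```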

The memory allocation for the compute graph is rather tricky - the way I solved it is to allocate the memory for the meta backend as per usual and to then transplant the calculated addresses, relative to the backend buffer base pointer, to the underlying simple backends. Because the simple tensors only require a fraction of the full memory this yields correct results, though it does result in overallocation for the compute graphs. For the weights/KV cache the memory allocation for the meta backend is done via a new function ggml_backend_meta_alloc_ctx_tensors_from_buft to prevent duplicated weights (which are much larger in size). I'm not yet sure what the best approach will be long-term; I think the graph allocation code in ggml-alloc.c will need to be adjusted.
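Schematically (illustrative names; not the PR's actual code):

```c
// Sketch of the address transplant described above. The graph allocator runs
// once against the meta buffer; each simple tensor then reuses the same
// relative offset inside its own backend's buffer.
#include <stddef.h>
#include "ggml.h"

static void transplant_address(const struct ggml_tensor * meta_t, const char * meta_base,
                               struct ggml_tensor * simple_t, char * simple_base) {
    // offset assigned by ggml-alloc, relative to the meta buffer base pointer
    ptrdiff_t off = (const char *) meta_t->data - meta_base;

    // reusing it verbatim is correct because the simple tensor only needs a
    // fraction of the meta tensor's memory, but it overallocates the simple
    // compute buffers
    simple_t->data = simple_base + off;
}
```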

Current Issues/Limitations

  • Only 1 or 2 GPUs are supported. Note: 1 GPU is not actually any faster, this is only useful for testing whether the code works correctly.
  • All GPUs must have an equal share of the data, --tensor-split has no effect.
  • Only dense models are supported. The LLaMA 3 models seem to be working correctly, I have not yet tested others.
  • Support for llama_params_fit is not implemented so the context size has to be set manually.
  • Without FlashAttention the code will probably crash because some transition between split states is not yet implemented.
  • In principle all backends should work. CUDA does in my testing, Vulkan however does not. I think there may be some issues with deadlock between the GPUs. @jeffbolznv @0cc4m if you could take a look it would be appreciated.
  • Memory for the ggml contexts is being overallocated.
  • Performance is (presumably) still suboptimal vs. NCCL.
  • I'm currently using tensor names to determine how to split individual tensors. I think it would be preferable to use some sort of enum instead (which we already seem to have for loading tensors). This should also be used for llama_params_fit.
  • I'm currently setting ggml_tensor::data to dummy values since that is what is checked in ggml-alloc.c to determine whether or not a tensor is considered allocated. This dummy value should never actually be used for any computations but I don't consider this a good solution.
  • I'm not putting meta devices into the ggml backend device registry (which I think is correct).

Performance

LLaMA 3 on 2x RTX 4090
| model | test | t/s -sm layer | t/s -sm row | t/s -sm tensor |
| --- | --- | ---: | ---: | ---: |
| llama 8B Q4_0 | pp512 | 12550.97 | 2997.41 | 6305.68 |
| llama 8B Q4_0 | pp2048 | 18788.11 | 2970.83 | 6300.12 |
| llama 8B Q4_0 | tg128 | 175.43 | 67.00 | 101.98 |
| llama 8B Q4_0 | pp512 @ d32768 | 5099.65 | 2200.50 | 4137.73 |
| llama 8B Q4_0 | pp2048 @ d32768 | 7925.21 | 2242.84 | 4337.54 |
| llama 8B Q4_0 | tg128 @ d32768 | 96.78 | 49.76 | 102.81 |
| llama 8B Q4_0 | pp512 @ d65536 | 3154.69 | 1748.05 | 3139.40 |
| llama 8B Q4_0 | pp2048 @ d65536 | 4996.19 | 1806.82 | 3404.70 |
| llama 8B Q4_0 | tg128 @ d65536 | 67.11 | 40.27 | 83.62 |
| llama 8B Q4_0 | pp512 @ d131072 | 1800.72 | 1243.57 | 2152.01 |
| llama 8B Q4_0 | pp2048 @ d131072 | 2867.83 | 1294.61 | 2238.01 |
| llama 8B Q4_0 | tg128 @ d131072 | 41.66 | 29.19 | 59.53 |
| llama 8B F16 | pp512 | 10578.54 | 1591.55 | 5950.44 |
| llama 8B F16 | pp2048 | 15890.45 | 1581.04 | 6005.96 |
| llama 8B F16 | tg128 | 60.07 | 46.32 | 70.64 |
| llama 8B F16 | pp512 @ d32768 | 4745.19 | 1330.02 | 4032.19 |
| llama 8B F16 | pp2048 @ d32768 | **7329.28** | 1347.02 | 4279.02 |
| llama 8B F16 | tg128 @ d32768 | 47.03 | 38.02 | 64.63 |
| llama 8B F16 | pp512 @ d65536 | 3033.10 | 1154.54 | 3100.10 |
| llama 8B F16 | pp2048 @ d65536 | 4782.02 | 1176.88 | 3341.03 |
| llama 8B F16 | tg128 @ d65536 | 38.72 | 32.19 | 56.48 |
| llama 8B F16 | pp512 @ d131072 | 1735.63 | 905.39 | 2090.03 |
| llama 8B F16 | pp2048 @ d131072 | 2782.42 | 936.55 | 2288.65 |
| llama 8B F16 | tg128 @ d131072 | 28.60 | 24.79 | 43.32 |
| llama 70B Q3_K - Small | pp512 | 1287.50 | 582.81 | 1072.80 |
| llama 70B Q3_K - Small | pp2048 | 1954.83 | 590.45 | 1069.28 |
| llama 70B Q3_K - Small | tg128 | 27.97 | 20.36 | 29.65 |
| llama 70B Q3_K - Small | pp512 @ d32768 | 776.80 | 458.17 | 812.77 |
| llama 70B Q3_K - Small | pp2048 @ d32768 | 1185.82 | 459.43 | 824.99 |
| llama 70B Q3_K - Small | tg128 @ d32768 | 20.85 | 16.19 | 29.39 |

Generally speaking it can be observed that parallelizing larger models yields better performance than parallelizing smaller models. Similarly, parallelizing the model becomes more worthwhile as the context depth increases. This makes sense, as both of these result in a larger workload per GPU relative to the overhead from parallelization. Token generation benefits more from parallelization than prompt processing because the amount of data that needs to be transferred between GPUs is proportional to batch size - long-term it may make sense to implement support for FP16/BF16 compute types, which would cut the I/O in half vs. FP32. For pp512 pipeline parallelism is effectively disabled while for pp2048 it's enabled. With pipeline parallelism -sm layer is still faster than -sm tensor even at high context depths.
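To see why batch size dominates the transfer cost, a back-of-envelope estimate (assuming, per the factorization above, two AllReduces per layer): each AllReduce moves on the order of

$$
n_{\text{tokens}} \times n_{\text{embd}} \times \text{sizeof(element)}
$$

bytes per direction and layer, so a pp2048 graph transfers roughly 2048x the data of a single-token generation step, and an FP16/BF16 compute type would halve the 4 bytes per FP32 element.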

github-actions bot added the Nvidia GPU, Vulkan, examples, ggml, SYCL, Apple Metal, Ascend NPU, OpenCL, and IBM zDNN labels on Feb 5, 2026
@jeffbolznv
Contributor

> In principle all backends should work. CUDA does in my testing, Vulkan however does not. I think there may be some issues with deadlock between the GPUs. @jeffbolznv @0cc4m if you could take a look it would be appreciated.

I'm not seeing a deadlock, just a crash in the driver with an invalid descriptor. I ran `llama-bench.exe -fa 1 -p 512 -n 0 -m c:\models\llama-2-7b.Q4_0.gguf -sm tensor`

Validation Error: [ VUID-VkDescriptorBufferInfo-offset-00340 ] | MessageID = 0xc23dafe5
vkUpdateDescriptorSets(): pDescriptorWrites[0].pBufferInfo[2].offset (18362368) is greater than or equal to buffer size (8388608).
The Vulkan spec states: offset must be less than the size of buffer (https://docs.vulkan.org/spec/latest/chapters/descriptorsets.html#VUID-VkDescriptorBufferInfo-offset-00340)

-		dst_buf	{buffer=shared_ptr {buffer={m_buffer=0x0000be00000000be {...} } device_memory={m_deviceMemory=0x0000bf00000000bf {...} } ...} [0x00000003 strong refs] [make_shared] ...}	vk_subbuffer
+		buffer	shared_ptr {buffer={m_buffer=0x0000be00000000be {...} } device_memory={m_deviceMemory=0x0000bf00000000bf {...} } ...} [0x00000003 strong refs] [make_shared]	std::shared_ptr<vk_buffer_struct>
		offset	0x0000000001183000	unsigned __int64
		size	0x0000000000800000	unsigned __int64
-		dst	0x0000018f98716fd0 {type=GGML_TYPE_F32 (0x00000000) buffer=0x0000018f98385c80 {iface={free_buffer=0x00007ffbc6647c50 {ggml-vulkan.dll!ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer *)} ...} ...} ...}	ggml_tensor *
		type	GGML_TYPE_F32 (0x00000000)	ggml_type
+		buffer	0x0000018f98385c80 {iface={free_buffer=0x00007ffbc6647c50 {ggml-vulkan.dll!ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer *)} ...} ...}	ggml_backend_buffer *
+		ne	0x0000018f98716fe0 {0x0000000000001000, 0x0000000000000200, 0x0000000000000001, 0x0000000000000001}	__int64[0x00000004]
+		nb	0x0000018f98717000 {0x0000000000000004, 0x0000000000004000, 0x0000000000800000, 0x0000000000800000}	unsigned __int64[0x00000004]
		op	GGML_OP_ADD (0x00000002)	ggml_op
+		op_params	0x0000018f98717024 {0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, ...}	int[0x00000010]
		flags	0x00000010	int
+		src	0x0000018f98717068 {0x00000191a1b1d1b0 {type=GGML_TYPE_F32 (0x00000000) buffer=0x0000018f97d4f130 {iface=...} ...}, ...}	ggml_tensor *[0x0000000a]
-		view_src	0x00000191a1b1d1b0 {type=GGML_TYPE_F32 (0x00000000) buffer=0x0000018f97d4f130 {iface={free_buffer=0x00007ffbc6647c50 {ggml-vulkan.dll!ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer *)} ...} ...} ...}	ggml_tensor *
		type	GGML_TYPE_F32 (0x00000000)	ggml_type
+		buffer	0x0000018f97d4f130 {iface={free_buffer=0x00007ffbc6647c50 {ggml-vulkan.dll!ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer *)} ...} ...}	ggml_backend_buffer *
+		ne	0x00000191a1b1d1c0 {0x0000000000001000, 0x0000000000000200, 0x0000000000000001, 0x0000000000000001}	__int64[0x00000004]
+		nb	0x00000191a1b1d1e0 {0x0000000000000004, 0x0000000000004000, 0x0000000000800000, 0x0000000000800000}	unsigned __int64[0x00000004]
		op	GGML_OP_MUL_MAT (0x0000001d)	ggml_op
+		op_params	0x00000191a1b1d204 {0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, ...}	int[0x00000010]
		flags	0x00000010	int
+		src	0x00000191a1b1d248 {0x0000018f9f29d940 {type=GGML_TYPE_Q4_0 (0x00000002) buffer=0x0000018f97d4f670 {...} ...}, ...}	ggml_tensor *[0x0000000a]
+		view_src	0x0000000000000000 <NULL>	ggml_tensor *
		view_offs	0x0000000000000000	unsigned __int64
		data	0x0000000001184000	void *
+		name	0x00000191a1b1d2b0 "attn_out-0"	char[0x00000040]
		extra	0x0000000000000000	void *
+		padding	0x00000191a1b1d2f8 ""	char[0x00000008]
		view_offs	0x0000000000000000	unsigned __int64
		data	0x0000000001184000	void *
+		name	0x0000018f987170d0 "attn_out-0 (view)"	char[0x00000040]
		extra	0x0000000000000000	void *
+		padding	0x0000018f98717118 ""	char[0x00000008]

It seems like tensor->data (and tensor->view_src->data) are too large. I haven't debugged further.

@jacekpoplawski
Contributor

works for me on 2x3090 for llama 3 8B and Mistral Nemo 12B

on Devstral I have an OOM (expected because of the model size?)
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 30720.00 MiB on device 0: cudaMalloc failed: out of memory
alloc_tensor_range: failed to allocate CUDA0 buffer of size 32212254720

Thread 1 "llama-server" received signal SIGABRT, Aborted.
Download failed: Invalid argument.  Continuing without source file ./nptl/./nptl/pthread_kill.c.
__pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
warning: 44     ./nptl/pthread_kill.c: No such file or directory
(gdb) bt
#0  __pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (threadid=<optimized out>, signo=6) at ./nptl/pthread_kill.c:89
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:100
#3  0x00007fffedc4579e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffedc288cd in __GI_abort () at ./stdlib/abort.c:73
#5  0x00007ffff7719ec5 in ggml_abort (file=0x7ffff77bdaa8 "/home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp", line=119, fmt=0x7ffff77bd88f "GGML_ASSERT(%s) failed")
    at /home/jacek/git/llama.cpp/ggml/src/ggml.c:256
#6  0x00007ffff773565b in ggml_backend_buffer_get_size (buffer=0x0) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp:119
#7  0x00007ffff7743b0e in ggml_backend_meta_alloc_ctx_tensors_from_buft (ctx=0x555557dc41f0, buft=0x55555b032b18) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:638
#8  0x00007ffff7734a4a in ggml_backend_alloc_ctx_tensors_from_buft (ctx=0x555557dc41f0, buft=0x55555b032b18) at /home/jacek/git/llama.cpp/ggml/src/ggml-alloc.c:1245
#9  0x00007ffff6f3d6c2 in llama_kv_cache::llama_kv_cache (this=0x555556363770, model=..., type_k=GGML_TYPE_F16, type_v=GGML_TYPE_F16, v_trans=false, offload=true, unified=true, kv_size=393216,
    n_seq_max=4, n_pad=1, n_swa=0, swa_type=LLAMA_SWA_TYPE_NONE, filter=..., reuse=...) at /home/jacek/git/llama.cpp/src/llama-kv-cache.cpp:190
#10 0x00007ffff700faec in llama_model::create_memory (this=0x555556357ea0, params=..., cparams=...) at /home/jacek/git/llama.cpp/src/llama-model.cpp:7617
#11 0x00007ffff6ed008f in llama_context::llama_context (this=0x555557a65260, model=..., params=...) at /home/jacek/git/llama.cpp/src/llama-context.cpp:274
#12 0x00007ffff6edd53c in llama_init_from_model (model=0x555556357ea0, params=...) at /home/jacek/git/llama.cpp/src/llama-context.cpp:3046
#13 0x00005555558f4e24 in common_init_result::common_init_result (this=0x55555624dc70, params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1183
#14 0x00005555558f50d0 in common_init_from_params (params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1215
#15 0x0000555555710999 in server_context_impl::load_model (this=0x55555636eba0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:625
#16 0x00005555556e9ac8 in server_context::load_model (this=0x7fffffff79b0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:2856
#17 0x00005555556177bd in main (argc=7, argv=0x7fffffffe0c8) at /home/jacek/git/llama.cpp/tools/server/server.cpp:248
(gdb)
but on Ministral 14B too
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 20480.00 MiB on device 0: cudaMalloc failed: out of memory
alloc_tensor_range: failed to allocate CUDA0 buffer of size 21474836480

Thread 1 "llama-server" received signal SIGABRT, Aborted.
Download failed: Invalid argument.  Continuing without source file ./nptl/./nptl/pthread_kill.c.
__pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
warning: 44     ./nptl/pthread_kill.c: No such file or directory
(gdb) bt
#0  __pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (threadid=<optimized out>, signo=6) at ./nptl/pthread_kill.c:89
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:100
#3  0x00007fffedc4579e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffedc288cd in __GI_abort () at ./stdlib/abort.c:73
#5  0x00007ffff7719ec5 in ggml_abort (file=0x7ffff77bdaa8 "/home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp", line=119, fmt=0x7ffff77bd88f "GGML_ASSERT(%s) failed")
    at /home/jacek/git/llama.cpp/ggml/src/ggml.c:256
#6  0x00007ffff773565b in ggml_backend_buffer_get_size (buffer=0x0) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp:119
#7  0x00007ffff7743b0e in ggml_backend_meta_alloc_ctx_tensors_from_buft (ctx=0x555557e339e0, buft=0x55555abb40b8) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:638
#8  0x00007ffff7734a4a in ggml_backend_alloc_ctx_tensors_from_buft (ctx=0x555557e339e0, buft=0x55555abb40b8) at /home/jacek/git/llama.cpp/ggml/src/ggml-alloc.c:1245
#9  0x00007ffff6f3d6c2 in llama_kv_cache::llama_kv_cache (this=0x55555abb9470, model=..., type_k=GGML_TYPE_F16, type_v=GGML_TYPE_F16, v_trans=false, offload=true, unified=true, kv_size=262144,
    n_seq_max=4, n_pad=1, n_swa=0, swa_type=LLAMA_SWA_TYPE_NONE, filter=..., reuse=...) at /home/jacek/git/llama.cpp/src/llama-kv-cache.cpp:190
#10 0x00007ffff700faec in llama_model::create_memory (this=0x555556357e50, params=..., cparams=...) at /home/jacek/git/llama.cpp/src/llama-model.cpp:7617
#11 0x00007ffff6ed008f in llama_context::llama_context (this=0x555557a65140, model=..., params=...) at /home/jacek/git/llama.cpp/src/llama-context.cpp:274
#12 0x00007ffff6edd53c in llama_init_from_model (model=0x555556357e50, params=...) at /home/jacek/git/llama.cpp/src/llama-context.cpp:3046
#13 0x00005555558f4e24 in common_init_result::common_init_result (this=0x55555624dc70, params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1183
#14 0x00005555558f50d0 in common_init_from_params (params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1215
#15 0x0000555555710999 in server_context_impl::load_model (this=0x55555636eb30, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:625
#16 0x00005555556e9ac8 in server_context::load_model (this=0x7fffffff79b0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:2856
#17 0x00005555556177bd in main (argc=7, argv=0x7fffffffe0c8) at /home/jacek/git/llama.cpp/tools/server/server.cpp:248
(gdb)
qwen 4B has a different issue
/home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:386: GGML_ASSERT(ne[split_dim] % n_simple_bufs == 0) failed

Thread 1 "llama-server" received signal SIGABRT, Aborted.
Download failed: Invalid argument.  Continuing without source file ./nptl/./nptl/pthread_kill.c.
__pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
warning: 44     ./nptl/pthread_kill.c: No such file or directory
(gdb) bt
#0  __pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (threadid=<optimized out>, signo=6) at ./nptl/pthread_kill.c:89
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:100
#3  0x00007fffedc4579e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffedc288cd in __GI_abort () at ./stdlib/abort.c:73
#5  0x00007ffff7719ec5 in ggml_abort (file=0x7ffff77bedd8 "/home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp", line=386, fmt=0x7ffff77beb9f "GGML_ASSERT(%s) failed")
    at /home/jacek/git/llama.cpp/ggml/src/ggml.c:256
#6  0x00007ffff7741e84 in ggml_backend_meta_buffer_init_tensor (buffer=0x55555a4421d0, tensor=0x55555a46df70) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:386
#7  0x00007ffff7743a52 in ggml_backend_meta_alloc_ctx_tensors_from_buft (ctx=0x55555a4422f0, buft=0x55555a426188) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:632
#8  0x00007ffff7734a4a in ggml_backend_alloc_ctx_tensors_from_buft (ctx=0x55555a4422f0, buft=0x55555a426188) at /home/jacek/git/llama.cpp/ggml/src/ggml-alloc.c:1245
#9  0x00007ffff6fff452 in llama_model::load_tensors (this=0x555556357e50, ml=...) at /home/jacek/git/llama.cpp/src/llama-model.cpp:7055
#10 0x00007ffff6e557fc in llama_model_load (fname="/mnt/models2/Qwen/Qwen_Qwen3-4B-Q8_0.gguf", splits=std::vector of length 0, capacity 0, model=..., params=...)
    at /home/jacek/git/llama.cpp/src/llama.cpp:876
#11 0x00007ffff6e56ce3 in llama_model_load_from_file_impl (path_model="/mnt/models2/Qwen/Qwen_Qwen3-4B-Q8_0.gguf", splits=std::vector of length 0, capacity 0, params=...)
    at /home/jacek/git/llama.cpp/src/llama.cpp:1069
#12 0x00007ffff6e56fe3 in llama_model_load_from_file (path_model=0x555556367610 "/mnt/models2/Qwen/Qwen_Qwen3-4B-Q8_0.gguf", params=...) at /home/jacek/git/llama.cpp/src/llama.cpp:1096
#13 0x00005555558f46c9 in common_init_result::common_init_result (this=0x55555624dc70, params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1107
#14 0x00005555558f50d0 in common_init_from_params (params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1215
#15 0x0000555555710999 in server_context_impl::load_model (this=0x55555636ead0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:625
#16 0x00005555556e9ac8 in server_context::load_model (this=0x7fffffff79d0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:2856
#17 0x00005555556177bd in main (argc=7, argv=0x7fffffffe0e8) at /home/jacek/git/llama.cpp/tools/server/server.cpp:248
(gdb)

@DocShotgun
Contributor

Interesting. If the CPU backend is able to be virtualized into multiple devices as described here, would it be possible to allow multiple NUMA nodes to be parallelized?

@gopinath87607

will this pr solve if we use multiple rpc connected to gpu with cpu and when we use cpu moe flag ?

@JohannesGaessler
Contributor Author

@jacekpoplawski the combination of --split-mode tensor and llama_params_fit is not implemented so you'll have to set the context size manually if you didn't already.

@DocShotgun longer-term I intend to also enable this code for better NUMA support though I'm not yet sure what to do in terms of hardware for development. Originally I had intended to buy 1.5 TiB of DDR5 RAM and 2 EPYC CPUs but at the current prices that would be financially irresponsible of me to do.

@gopinath87607 I don't understand what you mean.

@FullstackSensei

> @DocShotgun longer-term I intend to also enable this code for better NUMA support though I'm not yet sure what to do in terms of hardware for development. Originally I had intended to buy 1.5 TiB of DDR5 RAM and 2 EPYC CPUs but at the current prices that would be financially irresponsible of me to do.

If you're looking for a DDR4 platform, I might be able to help with that, but unfortunately, not so much with RAM for the system. I'm also in Germany.

Wouldn't mind giving you access to my systems if you want. Have two dual Xeon systems, one with P40s and the other with Mi50s.

Would be very happy to help either way.

@ggerganov
Member

ggerganov commented Feb 6, 2026

Started doing some initial tests to get familiar with the changes. I'm using virtual Metal devices and things appear to be mostly working - e.g. seeing graph execution on both devices, llama-perplexity produces the same result with 2 devices.

However, the following command does not produce identical results on each run:

GGML_METAL_DEVICES=2 ./bin/llama-completion -m ~/models/llama-3.1-8b/ggml-model-f16.gguf -no-cnv -p "I believe the meaning of life is" -n 32 --top-k 1 -sm tensor

I think this either means that there could be a problem with the fallback for the missing backend API, or that I could have made an error in the implementation of the events and the async copy in Metal. Will investigate more.

> For context, all of them are optional and can use existing operations as a fallback.

In the meantime, @JohannesGaessler do you confirm that the command above has deterministic results on your end? Also, is it deterministic with the fallback calls?

@JohannesGaessler
Contributor Author

JohannesGaessler commented Feb 6, 2026

Generally speaking you cannot expect bit-for-bit identical results if you split the computation across multiple virtual devices: the order in which floats are summed up will be different, which in turn changes the rounding error. If I run llama-perplexity using LLaMA 3 8B F16 I get 6.2560 for -sm layer and 6.2556 for -sm tensor. It's of course still possible that there are bugs on top of that; I would just say changes to the results are expected. Long-term I think we should test this new code in the same way we test ggml op fusion in test-backend-ops.

If you use -sm tensor with only a single GPU the executed ops should be the exact same and the result should be bit-for-bit identical to -sm layer.
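A tiny demo of the summation-order point (illustrative, not from the PR):

```c
// Float addition is not associative, so a different summation order across
// devices changes the rounding error even when each device computes correctly.
#include <stdio.h>

int main(void) {
    float a = 1e8f, b = -1e8f, c = 1.0f;
    printf("(a + b) + c = %.1f\n", (a + b) + c); // 1.0
    printf("a + (b + c) = %.1f\n", a + (b + c)); // 0.0: c is absorbed into b
    return 0;
}
```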

@JohannesGaessler
Contributor Author

Sorry, I think I misread your post. If you are saying that the results are not deterministic with 2 virtual GPUs but they are with 1 GPU then that I think is indicative of a bug w.r.t. the synchronization.

@ggerganov
Member

Yes, I understand that 1 GPU vs. 2 GPUs will not be bit-for-bit identical. What I mean is that in my test, running the command with 2 GPUs several times produces non-deterministic results from one run to the next:

GGML_METAL_DEVICES=2 ./bin/llama-completion -m ~/models/llama-3.1-8b/ggml-model-f16.gguf -no-cnv -p "I believe the meaning of life is" -n 32 --top-k 1 -sm tensor

# run 1
I believe the meaning of life is to find your gift. The purpose of life is to give it away.
I believe that the meaning of life is to find your gift. The purpose of life

# run 2
I believe the meaning of life is to find your gift. The purpose of life is to give it away. To give it away, you have to find it. [end of text]

# run 3
I believe the meaning of life is to be happy. I believe that happiness is the only thing that matters. I believe that happiness is the only thing that matters. I believe that happiness is the

> Sorry, I think I misread your post. If you are saying that the results are not deterministic with 2 virtual GPUs but they are with 1 GPU then that I think is indicative of a bug w.r.t. the synchronization.

Yes, seems like a synchronization issue. I was wondering if you observe it on your end with and without the fallback backend API; this will give me an indication of where to look for the issue.

@JohannesGaessler
Contributor Author

With the command you posted (minus the GGML_METAL_DEVICES=2) I get deterministic results on 2x RTX 4090, both with and without the fallback for shfl_tensor_async. The output of llama-perplexity is bit-for-bit identical with vs. without the fallback. So presumably either there is a bug with Metal synchronization or I made assumptions about the behavior of ggml backends that are correct for CUDA but not universally.

@pwilkin
Member

pwilkin left a comment

Ack for arg.cpp changes ;) BTW, shouldn't we be deprecating --split-mode row with this?

@JohannesGaessler JohannesGaessler changed the title ggml: backend-agnostic tensor parallelism ggml: backend-agnostic tensor parallelism (experimental) Apr 9, 2026
@JohannesGaessler JohannesGaessler merged commit d6f3030 into ggml-org:master Apr 9, 2026
86 of 101 checks passed
@gopinath87607

Hope one day this --split-mode tensor will work over RPC too, btw nice work guys.

@pwilkin
Member

pwilkin commented Apr 9, 2026

Congratulations on this :)

Wanted to try, but alas:
llama_init_from_model: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented #ohwell

@IMbackK
Collaborator

IMbackK commented Apr 9, 2026

> Hope one day this --split-mode tensor will work over RPC too, btw nice work guys.

It's going to be pretty brutal across consumer RPC interconnects.

@FullstackSensei

> Hope one day this --split-mode tensor will work over RPC too, btw nice work guys.

Even a 40 Gb NIC will be too slow to handle this for any sizeable model.

@FullstackSensei

Congratulations everyone! This was a big effort!

@digitalscream

Nice work, all - this was the only feature that kept me looking sideways at vLLM.

Now, how many beers do I need to buy the Vulkan team to get them to optimise for it? :D

@nawoa

nawoa commented Apr 9, 2026

Would the 200G NCCL/RDMA link between DGX Spark nodes be compatible with this?

If I understand correctly this PR allows for arbitrary tensor splits. This might allow Sparks in a 3-node ring topology (100G interconnect between each unit) to actually be useful. Currently with vLLM the only options are 2/4/8.

I don't know what performance would look like in llama.cpp vs. vLLM, but 384 GB of RAM @ ~820 GB/s effective RAM bandwidth would be pretty nice. Saves the cost and latency of a CX7 switch.

Sorry if I'm off-base.

@Ph0rk0z

Ph0rk0z commented Apr 9, 2026

Bit of a hmm... I can test perplexity of the Gemma base at Q8_0 and everything is fine.

When I do it with the q8_0_xl of the IT version it always craps out at iteration 260/576. I tried removing the undervolt and some other things. When using CUDA devices 0,1,2,3 the GPU0 PCIe link is also taken down to 2x until I reboot, despite there being no lane errors or any other messages. When I flip it around to 3,2,1,0 the test still crashes in the same place, but I don't lose the link and I can keep re-running things.


1283.1084,[258]11271.1522,[259]11265.8000,[260]11286.2032,/home/supermicro/ai/llama.cpp.main/ggml/src/ggml-cuda/ggml-cuda.cu:97: CUDA error
CUDA error: an internal operation failed
  current device: 0, in function ggml_cuda_op_mul_mat_cublas at /home/supermicro/ai/llama.cpp.main/ggml/src/ggml-cuda/ggml-cuda.cu:1504
  cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, &alpha_f16, src0_ptr, CUDA_R_16F, ne00, src1_ptr, CUDA_R_16F, ne10, &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)
[New LWP 136334]
[New LWP 136348]
[New LWP 136349]
[New LWP 136350]
[New LWP 136351]
[New LWP 136352]
[New LWP 136353]
[New LWP 136354]
[New LWP 136355]
[New LWP 136356]
[New LWP 136357]
[New LWP 136363]
[New LWP 136364]
[New LWP 136365]
[New LWP 136366]
[New LWP 136367]
[New LWP 136368]
[New LWP 136369]
[New LWP 136370]
[New LWP 136371]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007f1df14ea42f in __GI___wait4 (pid=161186, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30	../sysdeps/unix/sysv/linux/wait4.c: No such file or directory.
#0  0x00007f1df14ea42f in __GI___wait4 (pid=161186, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30	in ../sysdeps/unix/sysv/linux/wait4.c
#1  0x00007f1df215b40f in ggml_print_backtrace () from /home/supermicro/ai/llama.cpp.main/bin/libggml-base.so.0
#2  0x00007f1df215b582 in ggml_abort () from /home/supermicro/ai/llama.cpp.main/bin/libggml-base.so.0
#3  0x00007f1de13a3203 in ggml_cuda_error(char const*, char const*, char const*, int, char const*) () from /home/supermicro/ai/llama.cpp.main/bin/libggml-cuda.so.0
#4  0x00007f1de13ab94b in ggml_cuda_op_mul_mat_cublas(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*, char const*, float const*, char const*, float*, long, long, long, long, CUstream_st*) () from /home/supermicro/ai/llama.cpp.main/bin/libggml-cuda.so.0
#5  0x00007f1de13b12e2 in ggml_cuda_op_mul_mat(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*, void (*)(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*, char const*, float const*, char const*, float*, long, long, long, long, CUstream_st*), void (*)(float const*, int const*, void*, ggml_type, long, long, long, long, long, long, long, long, CUstream_st*)) () from /home/supermicro/ai/llama.cpp.main/bin/libggml-cuda.so.0
#6  0x00007f1de13b4f36 in ggml_cuda_compute_forward(ggml_backend_cuda_context&, ggml_tensor*) () from /home/supermicro/ai/llama.cpp.main/bin/libggml-cuda.so.0
#7  0x00007f1de13ba351 in ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context*, ggml_cgraph*, bool, bool, void const*) () from /home/supermicro/ai/llama.cpp.main/bin/libggml-cuda.so.0
#8  0x00007f1de13bc182 in ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) () from /home/supermicro/ai/llama.cpp.main/bin/libggml-cuda.so.0
#9  0x00007f1df218408b in ggml_backend_meta_graph_compute(ggml_backend*, ggml_cgraph*) () from /home/supermicro/ai/llama.cpp.main/bin/libggml-base.so.0
#10 0x00007f1df2179a51 in ggml_backend_sched_graph_compute_async () from /home/supermicro/ai/llama.cpp.main/bin/libggml-base.so.0
#11 0x00007f1df229cee2 in llama_context::graph_compute(ggml_cgraph*, bool) () from /home/supermicro/ai/llama.cpp.main/bin/libllama.so.0
#12 0x00007f1df229f50f in llama_context::process_ubatch(llama_ubatch const&, llm_graph_type, llama_memory_context_i*, ggml_status&) () from /home/supermicro/ai/llama.cpp.main/bin/libllama.so.0
#13 0x00007f1df22a7b29 in llama_context::decode(llama_batch const&) () from /home/supermicro/ai/llama.cpp.main/bin/libllama.so.0
#14 0x00007f1df22a95ec in llama_decode () from /home/supermicro/ai/llama.cpp.main/bin/libllama.so.0
#15 0x000055d56d435aba in perplexity(llama_context*, common_params const&, int) ()
#16 0x000055d56d425a9a in main ()
[Inferior 1 (process 136333) detached]
Aborted (core dumped)

A sweep bench up to 32k is fine and other kinds of inference appear to work. It's just perplexity. My other NCCL programs don't have this problem despite pushing more watts, etc. Trying to determine if it's a me problem or a code problem.

XeonBloomfield added a commit to XeonBloomfield/llama.cpp that referenced this pull request Apr 11, 2026
* model, mtmd: fix gguf conversion for audio/vision mmproj (ggml-org#21309)

* fix gguf conversion for audio/vision mmproj

* fix test

* tests: allow exporting graph ops from HF file without downloading weights (ggml-org#21182)

* tests: allow exporting graph ops from HF file without downloading weights

* use unique_ptr for llama_context in HF metadata case

* fix missing non-required tensors falling back to type f32

* use unique pointers where possible

* use no_alloc instead of fixing f32 fallback

* fix missing space

* ggml-webgpu: add vectorized flash attention (ggml-org#20709)

* naive vectorized version

* add vectorized flash attention

* update vec version

* remove unused path and shader

* remove unused helper functions

* add comments

* remove pad path

* ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization

* change back to vec4

* enable multi split

* enable vec path when:
- Q->ne[1] < 20
- Q->ne[0] % 32 == 0
- V->ne[0] % 4 == 0
- K->type == f16

* update flast_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select

* enable vec path for q4 and q8

* flash-attn vec nwg=1 fast path (skip tmp/reduce staging)

* use packed f16 K loads in flash-attn vec split

* use packed f16 K loads in flash-attn vec split on host side

* tune flash-attn vec f16 VEC_NE by head dim

* cleanup

* cleanup

* keep host side clean

* cleanup host side

* change back to original host wait/submit behavior

* formatting

* reverted param-buffer pool r ecfactor

* add helper functions

* ggml-webgpu: move flash-attn vec pipeline caching back into shader lib

* ggml-webgpu: remove duplicate functions

* ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation

* ggml-webgpu: revert unrelated change

* ggml-webgpu: revert deleted comment

* disable uniformity check

* remove unnecessary change

* Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl

* Update ggml/src/ggml-webgpu/ggml-webgpu.cpp

---------

Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* tests : add unit test coverage for llama_tensor_get_type (ggml-org#20112)

* Add unit test coverage for llama_tensor_get_type

* Fix merge conflicts, add more schemas

* clang formatter changes

* Trailing whitespace

* Update name

* Start rebase

* Updating files with upstream changes prior to rebase

* Changes needed from rebase

* Update attn_qkv schema, change throw behaviour

* Fix merge conflicts

* White space

* Update with latest changes to state counters

* Revert accidental personal CLAUDE.md changes

* Change quotation mark

* Reuse metadata.name since we have it

* Move test-only stuff out of llama-quant.cpp

* Hide the regex functionality back in llama-quant.cpp, use a unique pointer to a new struct 'compiled_tensor_type_patterns' which contains the patterns

* cont : inital deslop guidelines

* Cleanup based on review comments

* Continue cleanup

* Small cleanup

* Manually set proper ordering of tensors, mostly applies to gemma

* Formatting

* Update tests/test-quant-type-selection.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Fix merge conflicts

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix: gemma 4 template (ggml-org#21326)

* [HIP] Bump ROCm version to 7.2.1 (ggml-org#21066)

Bump ROCm version on Linux from 7.2 to 7.2.1
Add gfx1102 target
Delete LLVM workaround since ROCm 7.2.1 has fix for ROCm 7.2 perf regression ROCm/rocm-systems#2865

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* ci : add AMD ZenDNN label to PR labeler (ggml-org#21345)

* ci : add AMD CPU label to PR labeler
Add automatic labeling for PRs that modify AMD CPU (ZenDNN) backend files

* ci : rename label AMD CPU to AMD ZenDNN in labeler config

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

---------

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

* (revert) kv-cache : do not quantize SWA KV cache (ggml-org#21332)

This reverts commit 17193cc.

* chat : avoid including json in chat.h (ggml-org#21306)

* rpc : reuse compute graph buffers (ggml-org#21299)

Reuse the buffer for the ggml context which is used for creating the
compute graph on the server side. This partially addresses a memory leak
created by the CUDA backend due to using buffer addresses as cache
keys.

ref: ggml-org#21265
ref: ggml-org#20315

* vocab: fix Gemma4 tokenizer (ggml-org#21343)

* seems to work

* fix case with new line

Co-authored-by: sayap <sokann@gmail.com>

* gemma 4: fix pre tok regex

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: sayap <sokann@gmail.com>

* ggml-zendnn : add MUL_MAT_ID op support for MoE models (ggml-org#21315)

* ggml-zendnn : add MUL_MAT_ID op support for MoE models
- Add MUL_MAT_ID op acceleration for Mixture-of-Experts models
- MUL_MAT_ID op fallback to CPU backend if total experts > 32
- Point ZenDNN lib to latest bits ZenDNN-2026-WW13

* ggml-zendnn : add braces to sgemm failure condition for consistency

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

---------

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

* fix: add openssl to nix dependencies (ggml-org#21353) (ggml-org#21355)

* HIP: build eatch ci build test for a different architecture (ggml-org#21337)

This helps improve our chances of finding build failures before the release workflow
builds for all architectures.

* fix: remove stale assert (ggml-org#21369)

* ci: add more binary checks (ggml-org#21349)

* jinja: coerce input for string-specific filters (ggml-org#21370)

* docs: Update build.md: HSA_OVERRIDE_GFX_VERSION clarification (ggml-org#21331)

The `HSA_OVERRIDE_GFX_VERSION` variable can be used in ROCm to override an unsupported target architecture with a similar but supported target architecture.

This does not and has never worked on Windows. I think the clarification could avoid driving Windows people towards this solution that does not work.

* docker : bump cuda12 to 12.9.1 (ggml-org#20920)

Co-authored-by: M1DNYT3 <m1dnyt3@MacBookPro.lan>
Co-authored-by: CISC <CISC@users.noreply.github.com>

* common : fix tool call type detection for nullable and enum schemas (ggml-org#21327)

* common : fix tool call type detection for nullable and enum schemas

* common, tests : fix grammar delegation for nullable/enum schemas and add tests

Fix enum type inference to scan all enum values (not just index 0) so
schemas like {"enum": [0, "celsius"]} correctly detect string type.

Fix schema_delegates in peg-parser to handle nullable type arrays
(["string", "null"]) and typeless enum schemas in raw mode, allowing
the tagged parser to use raw text instead of JSON-formatted strings.

Add test cases for Qwen3-Coder (TAG_WITH_TAGGED format):
- nullable string ["string", "null"]
- nullable string with null first ["null", "string"]
- nullable integer ["integer", "null"]
- enum without explicit type key

* common/parser: fix call ID detection (Mistral parser mostly) + atomicity for tag-json parsers (ggml-org#21230)

* Fix call ID detection (Mistral parser mostly) + atomicity for tag-json parsers

* Rename

* Update common/chat-auto-parser-generator.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* server: save and clear idle slots on new task (`--clear-idle`) (ggml-org#20993)

* server: clear idle slots KV from VRAM (LLAMA_KV_KEEP_ONLY_ACTIVE)

* server: move idle slot KV clearing to slot release

The save "cost" is now paid by the finishing request.

* server: add --kv-clear-idle flag, enable by default

* server: skip clearing last idle slot, clear on launch

* server: test --no-kv-clear-idle flag

* server: simplify on-release clearing loop

* server: remove on-release KV clearing, keep launch-only

* cont : clean-up

* tests: update log strings after --clear-idle rename

* tests: use debug tags instead of log message matching

* test: fix Windows CI by dropping temp log file unlink

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* ci: Add Windows Vulkan backend testing on Intel (ggml-org#21292)

* experimenting CI

* Experimenting CI fix for MinGW

* experimenting CI on Windows

* modified script for integration with VisualStudio

* added proxy handling

* adding python version for Windows execution

* fix iterator::end() dereference

* fixed proxy handling

* Fix errors occurring on Windows

* fixed ci script

* Reverted to master

* Stripping test items to simplify Windows test

* adjusting script for windows testing

* Changed shell

* Fixed shell

* Fixed shell

* Fix CI setting

* Fix CI setting

* Fix CI setting

* Experimenting ci fix

* Experimenting ci fix

* Experimenting ci fix

* Experimenting ci fix

* experimenting fix for unit test error

* Changed to use BUILD_LOW_PERF to skip python tests

* Fix CI

* Added option to specify Ninja generator

* Reverted proxy related changes

* ggml-webgpu: move from parameter buffer pool to single buffer with offsets (ggml-org#21278)

* Work towards removing bitcast

* Move rest of existing types over

* Add timeout back to wait and remove synchronous set_tensor/memset_tensor

* move to unpackf16 for wider compatibility

* cleanup

* Remove deadlock condition in free_bufs

* Start work on removing parameter buffer pools

* Simplify and optimize further

* simplify profile futures

* Fix stride

* Try using a single command buffer per batch

* formatting

* llama: add custom newline split for Gemma 4 (ggml-org#21406)

* llama-model: read final_logit_softcapping for Gemma 4 (ggml-org#21390)

* common : respect specified tag, only fallback when tag is empty (ggml-org#21413)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* server: Fix undefined timing measurement errors in server context (ggml-org#21201)

Co-authored-by: Dan Hoffman <dhoffman@cyket.net>

* common : add gemma 4 specialized parser (ggml-org#21418)

* common : add gemma4 dedicated parser

* cont : add '<|tool_response>' as eog

* cont : emit JSON from Gemma4 tool call AST

* cont : more fixes

* cont : refactor convert function

* cont : refine rules and mapping

* cont : add more tests

* cont : clean up

* cont : remove autoparser gemma4 implementation

* cont : more cleanup

* cont : rename gemma4.jinja to match the others

* cont : add custom template to support interleaved thinking

* cont : preserve reasoning in model turns

* cont : fix initializer error

* cont : fix unused vars

* cont : fix accidental static

* cont : fix specialized_template signature

* fix extra semicolon

* remove debug line and extra space [no ci]

* ci: fix vulkan workflow referencing non-existent action (ggml-org#21442)

* ci: lower cuda12 floor to 12.8.1 for broader host compatibility (ggml-org#21438)

Co-authored-by: M1DNYT3 <m1dnyt3@MacBookPro.lan>

* server : fix logging of build + system info (ggml-org#21460)

This PR changes the logging that occurs at startup of llama-server.
Currently, it is redundant (including CPU information twice) and it is
missing the build + commit info.

* ci : use default RISE RISC-V Runners (ggml-org#21263)

* model : add HunyuanOCR support (ggml-org#21395)

* HunyuanOCR: add support for text and vision models

- Add HunyuanOCR vision projector (perceiver-based) with Conv2d merge
- Add separate HUNYUAN_OCR chat template (content-before-role format)
- Handle HunyuanOCR's invalid pad_token_id=-1 in converter
- Fix EOS/EOT token IDs from generation_config.json
- Support xdrope RoPE scaling type
- Add tensor mappings for perceiver projector (mm.before_rms, mm.after_rms, etc.)
- Register HunYuanVLForConditionalGeneration for both text and mmproj conversion

* fix proper mapping

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

* Update tools/mtmd/clip.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

* address comments

* update

* Fix typecheck

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* llama : correct platform-independent loading of BOOL metadata (ggml-org#21428)

* model-loader : fix GGUF bool array conversion

* model-loader : fix remaining GGUF bool pointer uses

* hexagon: slight optimization for argsort output init (ggml-org#21463)

* sycl : handle other FA case (ggml-org#21377)

* convert : set "add bos" == True for Gemma 4 (ggml-org#21500)

* convert : set "add bos" == True for Gemma 4

* cont : handle old GGUFs

* docs: add hunyuan-ocr gguf, also add test [no ci] (ggml-org#21490)

* server : handle unsuccessful sink.write in chunked stream provider (ggml-org#21478)

Check the return value of sink.write() in the chunked content provider
and return false when the write fails, matching cpp-httplib's own
streaming contract. This prevents logging chunks as sent when the sink
rejected them and properly aborts the stream on connection failure.
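
A minimal sketch of the fixed pattern, using cpp-httplib's chunked content provider API; the route and payload below are illustrative, not the actual server endpoint:

```cpp
#include <httplib.h>
#include <string>

void register_stream_route(httplib::Server & srv) {
    srv.Get("/stream", [](const httplib::Request &, httplib::Response & res) {
        res.set_chunked_content_provider("text/event-stream",
            [](size_t /*offset*/, httplib::DataSink & sink) {
                const std::string chunk = "data: {}\n\n"; // illustrative payload
                // Before the fix the return value was ignored, so a chunk was
                // logged as sent even after the client had disconnected.
                if (!sink.write(chunk.data(), chunk.size())) {
                    return false; // abort the stream, per httplib's contract
                }
                sink.done(); // end of stream
                return true;
            });
    });
}
```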

* convert : fix block_ff_dim retrieval for lfm2 (ggml-org#21508)

* vocab : add byte token handling to BPE detokenizer for Gemma4 (ggml-org#21488)

* llama-bench: add `-fitc` and `-fitt` to arguments (ggml-org#21304)

* llama-bench: add `-fitc` and `-fitt` to arguments

* update README.md

* address review comments

* update compare-llama-bench.py

* [CUDA ] Write an optimized flash_attn_stream_k_fixup kernel (ggml-org#21159)

* Write an optimized flash_attn_stream_k_fixup kernel

Write a specialized and more optimized kernel for cases where nblocks_stream_k is a multiple of ntiles_dst.
Make nblocks_stream_k a multiple of ntiles_dst if nblocks_stream_k > 2 * ntiles_dst (see the sketch after this list).

* Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs

* Address review comments

* Address review comments

* Revert variable names to original
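
Roughly, the rounding described in the first bullets amounts to the following, using the variable names from the commit message; the rounding direction is an assumption, and the real code may apply further constraints:

```cpp
#include <cstdint>

// Keep nblocks_stream_k a multiple of ntiles_dst when there is enough
// parallelism left over, so the specialized fixup kernel applies.
int64_t adjust_stream_k_blocks(int64_t nblocks_stream_k_raw, int64_t ntiles_dst) {
    if (nblocks_stream_k_raw > 2 * ntiles_dst) {
        // assumption: round down; the commit message does not state the direction
        return (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst;
    }
    return nblocks_stream_k_raw;
}
```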

* cli: fix stripping of \n in multiline input (ggml-org#21485)

* llama-cli: fix stripping of \n in multiline input

* Change & string to string_view

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Fix EditorConfig linter error

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* ggml: add Q1_0 1-bit quantization support (CPU) (ggml-org#21273)

* ggml: add Q1_0 and Q1_0_g128 1-bit quantization support (CPU)

* add generic fallback for x86

* remove Q1_0 (group size 32)

* rename Q1_0_g128 => Q1_0

* fix Q1_0 LlamaFileType Enum

* Fix trailing spaces; add generic fallback for other backends

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix \r\n spacing + arch-fallback

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* ggml-webgpu: Add the support of `MUL_MAT_ID` (ggml-org#21147)

* Add mul_mat_id support to WebGPU

* Apply suggestion from @reeselevine

---------

Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* docs: fix typo in build.md (emdawbwebgpu -> emdawnwebgpu) (ggml-org#21518)

* [SYCL] Add Q8_0 reorder optimization (~3x tg speedup on Intel Arc) (ggml-org#21527)

Extend the existing reorder optimization to Q8_0. The reorder
separates scale factors from weight data for coalesced memory
access -- was implemented for Q4_0/Q4_K/Q6_K but Q8_0 was missing.

On Arc Pro B70 (Xe2), Q8_0 tg goes from 4.88 to 15.24 t/s (3.1x)
on Qwen3.5-27B. BW utilization: 21% -> 66%.

The key fix beyond the kernels: Q8_0 was missing from the type
check in ggml_backend_sycl_buffer_init_tensor() that allocates
the extra struct carrying the reorder flag -- so the optimization
was silently skipped.

AI (Claude) was used to assist with root cause investigation and
writing the kernel code. All code was human-reviewed and tested
on real hardware.

Fixes: ggml-org#21517
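
The reorder itself is conceptually simple. A CPU-side sketch of the layout change (QK8_0 = 32 as in ggml; the actual SYCL implementation differs in detail):

```cpp
#include <cstdint>
#include <cstring>

// Interleaved Q8_0: per block, a half-precision scale followed by 32 int8
// weights. The reorder packs all weight bytes first and all scales last,
// so neighboring work-items read neighboring addresses.
struct block_q8_0_like {
    uint16_t d;      // scale (fp16 bits)
    int8_t   qs[32]; // quantized weights
};

void reorder_q8_0(const block_q8_0_like * src, uint8_t * dst, size_t nblocks) {
    int8_t   * qs_out = reinterpret_cast<int8_t *>(dst);
    uint16_t * d_out  = reinterpret_cast<uint16_t *>(dst + nblocks * 32);
    for (size_t i = 0; i < nblocks; ++i) {
        std::memcpy(qs_out + i * 32, src[i].qs, 32); // weights: contiguous
        d_out[i] = src[i].d;                         // scales: contiguous
    }
}
```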

* Fix rtl text rendering (ggml-org#21382)

* Fix Arabic RTL text rendering in web UI

- Add dir='auto' attributes to markdown containers and blocks
- Implement post-processing to add dir='auto' to all text elements
- Replace directional CSS properties with logical properties for proper RTL list alignment
- Ensure bidirectional text support for mixed Arabic/English content

* Clean up commented duplicate function

Remove the commented-out duplicate transformMdastNode function
that was left over from refactoring.

* Fix Arabic RTL text rendering in web UI

- Add dir='auto' attributes to markdown containers and blocks
- Implement post-processing to add dir='auto' to all text elements
- Replace directional CSS properties with logical properties for proper RTL list alignment
- Minor code formatting improvements

This ensures bidirectional text support for mixed Arabic/English content in the llama.cpp web UI.

* Implement rehype plugin for comprehensive RTL text support

- Add rehypeRtlSupport plugin that applies dir='auto' to all elements with children
- Replace DOMParser-based approach with efficient HAST tree processing
- Remove hardcoded element lists for better maintainability
- Ensure proper bidirectional text rendering for mixed RTL/LTR content

* Fix RTL text rendering with rehype plugin and cleanup

* fix: prettier formatting

* fix: Detect streaming state in reasoning content blocks (ggml-org#21549)

* ggml-cuda : fix CDNA2 compute capability constant for gfx90a (MI210) (ggml-org#21519)

GGML_CUDA_CC_CDNA2 was incorrectly set to 0x910; fix by setting the constant to 0x90a to match the actual gfx90a ISA.
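
A sketch of the corrected definition; the offset macro and its value are assumptions for illustration, and the point is only that the low bits now mirror the ISA name:

```cpp
// Convention: the low bits of the compute-capability constant match the
// gfx ISA name (gfx90a -> 0x90a).
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000                         // assumed value
#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // was + 0x910
```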

* webui : store reasoning_content so it is sent back in subsequent requests (ggml-org#21249)

* vulkan: add FA dequant for q4_1, q5_0, q5_1, iq4_nl (ggml-org#21029)

Add dequantize4() implementations for Q4_1, Q5_0, Q5_1, and IQ4_NL
in the flash attention base shader. Register them in the shader
generator, pipeline creation, and enable in the scalar/coopmat1 FA
support check.

* ggml: Vulkan build, Linux -- output error string for errno on fork failure (ggml-org#20868) (ggml-org#20904)

* ggml : deprecate GGML_OP_ADD1 (ggml-org#21363)

* ggml : deprecate GGML_OP_ADD1

* cont : remove tests

* cont : re-enable vulkan check

* server : fix restore for checkpoints with pos_min == 0 (ggml-org#21510)

* llama: remove per-arch tensor name lists (ggml-org#21531)

* unicode : add custom Qwen2 regex handler to fix segfault on long input (ggml-org#21257)

* unicode : add custom Qwen2 regex handler to fix segfault on long input

std::regex uses recursive backtracking internally, which causes a stack
overflow (segfault) when tokenizing long sequences of repeated characters
(e.g. 43K 'A's). The Qwen2 tokenizer regex differs from Llama3 only in
the digit pattern (\p{N} vs \p{N}{1,3}), so it was falling through to
the std::regex fallback path instead of using a custom handler.

Add unicode_regex_split_custom_qwen2() following the established pattern
used by gpt2, llama3, kimi_k2, and afmoe custom handlers.

Closes: ggml-org#21113

* cont : remove TODO comment

* cont : update comment to reflect original regex

* use the correct regex in the comment this time... [no ci]

---------

Co-authored-by: Aldehir Rojas <hello@alde.dev>
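
A minimal repro sketch of the failure mode described above; the pattern is hypothetical, and whether this actually crashes depends on the standard-library implementation and the stack size:

```cpp
#include <regex>
#include <string>

int main() {
    const std::string input(43000, 'A'); // long run of repeated characters
    // Common std::regex implementations match via recursive backtracking,
    // recursing roughly once per character here; on inputs this long that
    // can exhaust the stack and crash instead of failing gracefully.
    const std::regex re("[A-Za-z]+|[0-9]{1,3}");
    std::smatch m;
    (void) std::regex_search(input, m, re);
    return 0;
}
```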

* llama-server: fix model params not propagated (ggml-org#21509)

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* CUDA: check for buffer overlap before fusing (ggml-org#21566)

* CUDA: check for buffer overlap before fusing

* use ggml_cuda_check_fusion_memory_ranges
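
The core test is a standard half-open interval check; a generic sketch (the real helper, ggml_cuda_check_fusion_memory_ranges, also has to account for ggml tensor views):

```cpp
#include <cstdint>

struct mem_range {
    uintptr_t begin; // first byte
    uintptr_t end;   // one past the last byte
};

// Two half-open ranges [a.begin, a.end) and [b.begin, b.end) overlap iff
// each starts before the other ends. Fusion must be skipped when the
// destination of one op aliases a source of another.
static bool ranges_overlap(const mem_range & a, const mem_range & b) {
    return a.begin < b.end && b.begin < a.end;
}
```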

* ggml-webgpu: parameterize submission size and add iOS specific limits (ggml-org#21533)

* Work towards removing bitcast

* Move rest of existing types over

* Add timeout back to wait and remove synchronous set_tensor/memset_tensor

* move to unpackf16 for wider compatibility

* cleanup

* Remove deadlock condition in free_bufs

* Start work on removing parameter buffer pools

* Simplify and optimize further

* simplify profile futures

* Fix stride

* Try using a single command buffer per batch

* formatting

* Add parameters for different browsers in-flight submissions

* Update handling of batch size too

* Throttle ios as much as possible

* Increase timeout for llvm-pipe testing

* kv-cache : support attention rotation for heterogeneous iSWA (ggml-org#21513)

* kv-cache : support attention rotation for heterogeneous iSWA

* cont : remove assert

* gguf-py : fix missing comma after bad merge in tensor-mapping (ggml-org#21558)

This commit adds a missing comma in the vision encoder attention qkv
block.

The motivation for this change is that without the comma there will be
a string concatenation of the Kimi-K2.5 and the Nemotron Nano v2 VL
tensor mappings which will be broken.
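
The pitfall generalizes beyond Python: adjacent string literals concatenate silently in C and C++ as well. An illustrative (not actual) mapping table:

```cpp
// A missing comma fuses two entries into one broken string; the names
// below are placeholders, not the real gguf-py tensor mappings.
static const char * tensor_names[] = {
    "v.blk.attn_qkv.kimi",   // fine
    "vision.attn.qkv"        // <-- missing trailing comma here ...
    "model.visual.qkv",      // ... silently fuses with this entry
};
```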

* ggml-cuda: ds_read_b128 for q4_0 and q4_1 mmq kernels (ggml-org#21168)

* ds_read_b128 for q4_0 and q4_1 mmq kernels

The current for loop generates ds_read_b32 instructions with the HIP compiler; the new solution generates ds_read_b128 instructions for the same operation, saving some LDS bandwidth. Tested on MI50 and RX 6800 XT, it's faster on both.

* Vectorized lds load update: used ggml_cuda_get_max_cpy_bytes and ggml_cuda_memcpy_1 functions for generic implementation

* Explicit for loop in mmq, renamed vec into tmp

* Fixed max_cpy usage in the loading loop

* Fixed typo in q4_1 kernel

* Update ggml/src/ggml-cuda/mmq.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/mmq.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/mmq.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Removed trailing whitespace on line 500

* Update mmq.cuh: removed other blank lines

* Remove trailing whitespaces

---------

Co-authored-by: iacopPBK <iacopPBK@users.noreply.github.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: iacopPBK <iacop@deneb.com>

* CUDA: make cuda graphs props check faster (ggml-org#21472)

* CUDA: compute fast hash instead of expensive props check

* use seen node

* use memcpy

* devops: kleidiai: provide KleidiAI-Enabled ARM Release Artifact (ggml-org#21259)

* Unified macOS release setup with strategy-matrix block

* Added KleidiAI arm64 macOS release definition

Change-Id: I05520889ffc646488a178d06817a17f29274465a

Signed-off-by: Martin Klacer <martin.klacer@arm.com>

* webui: fix syntax highlighting lost after streaming for non-common languages (ggml-org#21206)

* webui: fix syntax highlighting lost for non-common languages after streaming

rehype-highlight uses lowlight internally, which only bundles 37 "common"
languages. The streaming code path uses highlight.js directly (192 languages),
so languages like Haskell highlight correctly while streaming but lose all
color once the code block closes. Pass the full lowlight language set to
rehype-highlight so both paths support the same languages.

* webui: rebuild static files after rebase

* model : support step3-vl-10b (ggml-org#21287)

* feat: support step3-vl-10b

* use fused QKV && mapping tensor in tensor_mapping.py

* guard hardcoded params and drop crop metadata

* get understand_projector_stride from global config

* img_u8_resize_bilinear_to_f32 move in step3vl class

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix the \r\n mess

* add width and heads to MmprojModel.set_gguf_parameters

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* chore: Remove legacy files (ggml-org#21606)

* chore: Update labeler to have separate labels for `server/webui` and `server` changes (ggml-org#21567)

* tests : remove obsolete .mjs script (ggml-org#21615)

* parser: fix MiniMax handling (ggml-org#21573)

* examples : disable cb_eval callback for --save-logits (ggml-org#21553)

This commit updates the debug example to not create the
base_callback_data.

The motivation is that when using `--save-logits`, which is used by the
examples/model-conversion scripts, we often don't care about the tensor
outputs and they just add noise to the output. This change makes the
output quiet by default; we can always remove --save-logits to get the
tensor outputs when debugging.

* gemma : perform per-layer projections in the first layer (ggml-org#21612)

* gemma : reduce graph splits by keeping per-layer ops in the input layer

* gemma : put the per-layer proj in the first layer

* cont : move the projection before the layer loop

* metal: Q1_0 backend (ggml-org#21528)

* initial Q1_0 Metal backend

* tuning q1_0 metal kernels

* add Q1_0 to test-backend-ops

* add Q1_0<->F32 copy test

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* webgpu : Query for adapter support when registering WebGPU backend (ggml-org#21579)

* kv-cache : extend cache quantization checks (ggml-org#21586)

to also check for enabled flash attention, instead of just auto.

* Propose fixes for a couple of typos (ggml-org#21581)

Signed-off-by: John E <jeis4wpi@outlook.com>

* webui : send both backend_sampling == false/true (ggml-org#18781)

* webui : send both backend_sampling == false/true

* feat: Parameter sync

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* vocab : remove </s> eog token if gemma4 (ggml-org#21492)

* server: respect the ignore eos flag (ggml-org#21203)

* fix: free ctx_copy in ggml_opt_free to plug per-training-session leak (ggml-org#21592)

* fix: free ctx_copy in ggml_opt_free to plug per-training-session leak

ggml_opt_alloc populates opt_ctx->ctx_copy via a free+init pair every
time the allocated graph shape changes. The last ctx_copy from the
final ggml_opt_alloc call survives until ggml_opt_free is invoked,
but ggml_opt_free was only freeing ctx_static and ctx_cpu, never
ctx_copy. Each opt_ctx lifetime therefore leaks the final per-batch
context — ~900 KB for a typical GNN training session in
sindarin-pkg-tensor, surfaced via AddressSanitizer.

ctx_copy is nullptr-initialized and ggml_free() handles NULL safely,
so the new release is guard-free.

* Update ggml/src/ggml-opt.cpp

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: realorko <realorko@nowhere.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
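
The shape of the fix, per the commit message; the real struct is internal to ggml-opt.cpp, and the stand-in below only mirrors the fields named above:

```cpp
#include "ggml.h" // ggml_context, ggml_free

// Stand-in with the three contexts the commit message names.
struct opt_ctx_like {
    struct ggml_context * ctx_static;
    struct ggml_context * ctx_cpu;
    struct ggml_context * ctx_copy; // nullptr until the first (re)alloc
};

void opt_free_sketch(opt_ctx_like * opt_ctx) {
    ggml_free(opt_ctx->ctx_static);
    ggml_free(opt_ctx->ctx_cpu);
    ggml_free(opt_ctx->ctx_copy); // previously missing; ggml_free(NULL) is a no-op
    delete opt_ctx;
}
```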

* CUDA: also store `node->src->data` ptrs for equality check (ggml-org#21635)

* CUDA: also store node->src->data ptrs for equality check

* address review comments

* common : skip non-primary GGUF split files when selecting model (ggml-org#21633)

We should not assume files are listed in order.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* vulkan: unify type macros to use Vx instead of _VECx (ggml-org#21605)

* ci: drop v5 `all:` composition from labeler.yml (ggml-org#21627)

actions/labeler@v6 removed the `all:` / `any:` composition keys.
The `server/webui` and `server` entries used `all:` to combine
`any-glob-to-any-file` with negated `all-globs-to-all-files`,
which now errors on every PR with:

    Unknown config options were under "changed-files": all

Flatten both entries to a single `any-glob-to-any-file`. PRs
touching both webui and other server files will now receive both
labels instead of only `server/webui`.

Co-authored-by: Marxist-Leninist <noreply@users.noreply.github.com>

* sycl : add flash-attn support for head size 512 (ggml-org#21654)

* sycl : add flash-attn support for head size 512

This patch extends the SYCL Flash Attention implementation to support head sizes (DKQ/DV) of 512.

Changes:
- Added DKQ/DV 512 cases to both tile and vector Flash Attention kernels.
- Updated kernel selection logic to allow vector kernels for head sizes up to 512 (previously 256).
- Removed unused/redundant AMD and RDNA-specific configuration functions in `fattn-tile.hpp`.
- Refactored `ggml_backend_sycl_buffer_init_tensor` to use a switch statement for clearer tensor extra buffer initialization.
- Added necessary template instances for the new 512 head size across various quantization types.

* remove defunct mxfp4 reorder from setting buffer type

* webui: Add option to pre-encode conversation for faster next turns (ggml-org#21034)

* server : fix grammar commandline args (ggml-org#21543)

Co-authored-by: AUTOMATIC <->

* fix: Model Selector choice sync (ggml-org#21628)

* metal : add missing mm-id specializations for q1_0 (ggml-org#21662)

* jinja : support ensure_ascii=true, string repetition and int/float self-filtering (ggml-org#21623)

* feat: jinja engine improvements for reka-edge

Port three Jinja engine improvements needed for the reka-edge model:
1. Python-style string repetition ("ab" * 3 → "ababab")
2. ensure_ascii=true support for tojson filter (escapes non-ASCII to \uXXXX)
3. int() builtin on value_int_t (identity, needed for Reka Edge template)

* fix: escape invalid utf8 bytes when ensure_ascii=true

The json_ensure_ascii_preserving_format function does not correctly
handle an edge case: if UTF-8 parsing fails, it adds the non-ASCII
character back to the output as a raw byte.

This commit fixes that by adding the unicode standard replacement
character \\ufffd to the output instead. This is the standard behavior
for various programming languages like Python, Rust, Go, etc.

* chore: address PR comments

1. Add todo comment for supporting string repetition for array/tuples
2. Add support for float identity operation
3. Move invalid ascii test case to test_fuzzing

* chore: accept suggestion for common/jinja/value.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
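
A simplified sketch of ensure_ascii-style escaping, assuming the input has already been decoded to code points and that invalid UTF-8 was replaced with U+FFFD during decoding; a full implementation must also emit UTF-16 surrogate pairs for code points above U+FFFF:

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

std::string escape_non_ascii(const std::vector<uint32_t> & codepoints) {
    std::string out;
    char buf[16];
    for (const uint32_t cp : codepoints) {
        if (cp < 0x80) {
            out += static_cast<char>(cp); // ASCII passes through unchanged
        } else if (cp <= 0xFFFF) {
            std::snprintf(buf, sizeof(buf), "\\u%04x", (unsigned) cp);
            out += buf; // e.g. 'é' -> "\u00e9"
        } else {
            // Simplification: real code encodes a surrogate pair
            // ("\uXXXX\uXXXX") here instead of the replacement character.
            out += "\\ufffd";
        }
    }
    return out;
}
```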

* vocab: add gemma4 tokenizer tests, fix edge case (ggml-org#21534)

* YATF (Yet Another Tokenizer Fix) for Gemma 4. With tests!
* Remove unnecessary hash from update script.
* minor: move constant

* mtmd: support dots.ocr (ggml-org#17575)

* convert gguf

* clip impl

* fix conversion

* wip

* corrections

* update docs

* add gguf to test script

* model: fix multimodal padding token for gemma3n/gemma4 (ggml-org#21625)

* model: fix multimodal padding token for gemma3n/gemma4

* nits

* common : simplify autoparser tagged parser rules (ggml-org#21216)

* common : simplify autoparser tagged parser rules

* cont : remove upper limit on optional args

* cont : revert changes to parsing at the end

* cont : undo arbitrary ordering of optional args

* cont : fix uninitialized required parameters

* revert to simplify merge

* re-apply patches

* restore flexible optional arg ordering tests

* common : fix ambiguous grammar rule in gemma4 (ggml-org#21661)

* common : fix ambiguous grammar rule in gemma4

* cont : fix missing comma...

* webui: add "Send message on Enter" setting (ggml-org#21577)

* webui: make Enter to send chat a setting

* Shorten description

* Use isMobile hook from $lib/hooks

* Rebuild static output

* requirements : update transformers to 5.5.1 (ggml-org#21617)

* requirements : update transformers to 5.5.0

This commit updates the transformers dependency to version 5.5.0.

The motivation for this is that transformers 5.5.0 includes support for
Gemma4 and is required to be able to convert Gemma4 models. This is also
causing issues for users of gguf-my-repo.

Refs: https://huggingface.co/spaces/ggml-org/gguf-my-repo/discussions/202

* fix huggingface_hub version

* set version of transformers to 5.5.0

* convert : add ty ignore directives to convert_hf_to_gguf.py

This commit adds `ty: ignore` directives to transformers tokenizers
field/methods to avoid type check errors. There might be better ways to
handle this and perhaps this can be done in a follow up commit.

The motivation for this is that it looks like in transformers 5.5.0
AutoTokenizer.from_pretrained can return generic tokenizer types or None
and the type checker now produces an error when the conversion script
accesses fields like tokenizer.vocab.

* convert : add ty ignore to suppress type check errors

* convert : remove incorrect type ignores

* convert : fix remaining python checks

I was running a newer version of ty locally but I've switched to
version 0.0.26 which is what CI uses and I was then able to reproduce
the errors. Sorry about the noise.

* update transformers version to 5.5.1

* ggml : check return value of CUB calls used in argsort and top-k (they all return cudaError_t) (ggml-org#21676)

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>

* ggml: backend-agnostic tensor parallelism (experimental) (ggml-org#19378)

* ggml: backend-agnostic tensor parallelism

* support for GPT-OSS, Qwen 3 MoE

* partial Vulkan fix

* add support for 4/8 GPUs

* unconditional peer access

* re-use buffers + ggml contexts

* fix output pattern

* NCCL support

* GGML: HIP: add RCCL support

* Remove shfl and AllReduce from backend interface

* move allocation workaround out of ggml-alloc.c

* 2d tensor set/get support

* Fix the seg fault without NCCL

* Apply suggestion from JohannesGaessler

* support for tensor dims % n_devs != 0

* fix view_offs scaling

* arbitrary num. of GPUs/tensor split

* fix compilation

* better granularity estimate

* Support device-specific host buffer types if all underlying backends expose the same type. This allows using pinned memory instead of pageable memory for CUDA.

Fix compilation errors.

* partial Qwen 3 Next support

* Fix qwen3 30b (ggml-org#8)

* Fix crash with Qwen-30B-A3B Q4_0

Qwen-30B-A3B Q4_0 has an intermediate dimension of 768, which at a granularity of 256 is only 3 blocks; 3 blocks cannot be divided evenly between the GPUs, and such an uneven split is not supported by the current implementation.

* Decide block size based on tensor quantization type

* Fix crashes due to KV cache serialization (ggml-org#9)

KV cache serialization requires non-zero offsets on the tensor. Add support in the meta backend to set/get a tensor with a non-zero offset.

* metal : fix build (ggml-org#7)

* static memory allocations, fix usage count

* fix tensor granularity

* more even memory distribution

* use BF16 for allreduce

* rebase fixup

* better error message for unsupported architectures

* Fix device mismatch during scatter of allReduce. (ggml-org#11)

There is a mismatch between the dst buffer device and the backend device, causing the use of sync copies

* Enable the previous allreduce implementation. It is better in both perf and stability (ggml-org#12)

* delay AllReduce for MoE for less I/O

* build : clean-up compile warnings

* backend : move most of the meta backend API to ggml-backend-impl.h

* cont : hide unused public API in the implementation

* llama : use llama_device + remove ggml_backend_dev_is_meta()

* ggml-backend : remove unused alloc include

* minor : remove regex include

* ggml : introduce ggml-ext.h for staging new APIs

* rebase fixup

* fix tests

* llama : more robust logic for determining Meta devices (ggml-org#16)

* llama : more robust logic for determining Meta devices

* cont : fix devs size check

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* cont : fix log type

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* disable roundtrip for meta backend

* fix arch selection

* Qwen 3.5 support

* fix Gemma 4 MoE

* fix OpenVino, SYCL

* fix test-llama-archs for CPU-only builds

* Fix Qwen 3.5 MoE

* disable meta backend tests for WebGPU

* tests : filter CPU-based devices from the Meta backend tests (ggml-org#17)

* meta : formatting, naming, indentation (ggml-org#18)

* formatting : llama-model.cpp

* formatting : ggml-ext.h

* formatting : ggml-backend-meta.cpp

* meta : add TODO

* add documentation

* better error messages

* fix GPT-OSS

---------

Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Gaurav Garg <gaugarg@nvidia.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
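
One detail from the list above worth expanding: KV-cache (de)serialization goes through the public ggml-backend tensor accessors, which already take a byte offset, and the "2d tensor set/get support" and serialization fixes make the meta backend honor a non-zero offset instead of asserting. A usage-shaped sketch (the accessor is the real ggml-backend API; the wrapper is illustrative):

```cpp
#include "ggml-backend.h"

// Restoring a saved KV cache writes into the middle of a cache tensor,
// i.e. with a non-zero byte offset. The meta backend now forwards such
// offsets to the per-device tensors it wraps.
void restore_cache_range(struct ggml_tensor * cache, const void * src,
                         size_t offset_bytes, size_t n_bytes) {
    ggml_backend_tensor_set(cache, src, offset_bytes, n_bytes);
}
```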

* HIP: add CDNA4 (gfx950) architecture support for MI350X/MI355X (ggml-org#21570)

Add AMD Instinct MI350X/MI355X (gfx950, CDNA4) support:

- vendors/hip.h: Add CDNA4 preprocessor define for __gfx950__
- common.cuh: Add GGML_CUDA_CC_CDNA4 and GGML_CUDA_CC_IS_CDNA4 macros
- mma.cuh: Route CDNA4 to compatible MFMA instructions:
  * f32 matmul: mfma_f32_16x16x4f32 (xf32 variant unavailable on gfx950)
  * bf16 matmul: mfma_f32_16x16x16bf16_1k (same as CDNA3)
  * int8 matmul: mfma_i32_16x16x32_i8/32x32x16 (same as CDNA3)
- mmq.cuh: Include CDNA4 in stream-k kernel dispatch

CDNA4 is largely compatible with CDNA3 except:
- No xf32 MFMA (mfma_f32_16x16x8_xf32) — routes to f32 path
- Different FP8 format (e4m3fn vs e4m3_fnuz) — not changed here

Tested on AMD Instinct MI355X (gfx950), ROCm 7.0.1:
- Build: compiles cleanly with -DAMDGPU_TARGETS=gfx950
- llama-bench (Qwen2.5-1.5B Q4_K_M, single GPU):
  * f16+FA: 40,013 tok/s prefill, 254 tok/s decode
  * q8_0+FA: functional
- Flash attention: works correctly
- MMQ: works correctly with stream-k dispatch

Co-authored-by: Andy Luo <andyluo7@users.noreply.github.com>

* CUDA: fuse muls (ggml-org#21665)

* common : add fluidity to the progress bar (ggml-org#21671)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* vulkan: Support Q1_0 (ggml-org#21539)

* vulkan: Support Q1_0

* use get_dm

* docs : fix broken link to ggml-openvino in OPENVINO.md (ggml-org#21709)

* common : enable reasoning budget sampler for gemma4 (ggml-org#21697)

* fix: enable reasoning budget sampler for gemma4

Add thinking_start_tag and thinking_end_tag to
common_chat_params_init_gemma4(). Without these, the reasoning
budget sampler never activates for gemma4.

Make the newline after "thought" optional in the PEG parser to
handle budget=0 (sampler forces end tag before the newline).

Add test case for empty thinking block.

Fixes ggml-org#21487

* use p.space() instead of p.optional(p.literal("\n")) in gemma4 thought parser

* webui: Static build output improvements (ggml-org#21667)

* refactor: Build improvements

* chore: Formatting + package lock update

* common: mark --split-mode tensor as experimental (ggml-org#21684)

* common : fix when loading a cached HF models with unavailable API (ggml-org#21670)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* server : ignore --alias when using --models-preset (ggml-org#21380)

I'm not sure what the purpose of keeping `--alias` was when using
`--models-preset`, but the result is really weird, as shown in the
following logs:

    $ build/bin/llama-server --models-preset preset.ini --alias "Gemma 4 E4B UD Q8_K_XL"
    ...
    init: using 31 threads for HTTP server
    srv   load_models: Loaded 2 cached model presets
    srv   load_models: Loaded 1 custom model presets from preset.ini
    main: failed to initialize router models: alias 'Gemma 4 E4B UD Q8_K_XL' for model 'angt/test-split-model-stories260K:F32' conflicts with existing model name

So I propose to simply ignore `--alias` too in this case. With this
commit, the server starts in routing mode correctly.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* ggml-webgpu: address quantization precision and backend lifecycle management (ggml-org#21521)

* ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after ggml-org#20618, and remove the busy webgpu log

* Merge with upstream

* Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants

* Update Unary wgsl EXP and EXPM1 for f16 stability

* Fix GET_ROWS IQ4_XS struct for NaN f16 canonicalization

* Fix numerical precision for unary sqrt when working with f16

* Fix NaN canonicalization for packed integers using f16

* Update err threshold for binary div ops when using f16

* backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend

* clean: uncomment existing code logs

* clean: clean the unnecessary debug info

* Refactor and generalize dequant helpers

* Remove deprecated quant structs

* Refactor shader defines to reduce repetition

* Remove error override for F16 type

* fix: fix the accidental removal of the proper initialization of ctx

* clean: clean legacy and format code

* fix: did not modify tests ops

---------

Co-authored-by: Jeremy J. Hartmann <jeremy@mtion.tv>

* ggml-webgpu: support non-square subgroup matrix configs for Intel GPUs (ggml-org#21669)

* model : make Gemma 4 shared-KV tail attn_k tensors optional on load (ggml-org#21739)

* common : add callback interface for download progress (ggml-org#21735)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* common : better align to the updated official gemma4 template (ggml-org#21704)

* hexagon: improved Op queuing, buffer and cache management (ggml-org#21705)

* hexagon: introduce op request batching and rewrite buffer management

The host now prepares batches of requests and dispatches them via a single dspqueue message.

Buffers are mapped explicitly by the NPU while processing batches.

* hex-dma: disable l2 bypass to work around a new issue caused by the lack of flushes between Ops

* hex-utils: add explicit l2flush and l2clear helpers

* hex-opreq: use fine-grain per tensor l2 management

* hex-opreq: avoid redundant invalidates for tensors we already flushed

* hex-opreq: update debug messages

* htp-opreq: reuse ops_context

* hex-opreq: do not flush or invalidate cache lines beyond buffer boundary

* hex-opreq: fix errors in log message

* Revert "hex-opreq: do not flush or invalidate cache lines beyond buffer boundry"

This reverts commit 8b7f0a55a750a6430ce4eb1874c7feb3d720056d.

* hexagon: limit l2 flushes to 1MB which covers l2 cache

* hex-opreq: limit cache flush to 4MB

Looks like 4 MB of contiguous virtual space should cover the 1 MB cache.

* hexagon: drop cache flush size to 2MB

* hex-opreq: start reworking opreq packing

* hex-opreq: introduce new way of packing opbatch where tensors are stored separately

* hex-opreq: add a simple fastrpc call to force unmap all buffers

* hex-l2flush: somehow 2MB does not seem robust, also cleanup step size to use line-size

* hex-opreq: bump opreq batch size to 256

* hex-mm: place src1 spad at the top of vtcm for easy reuse

* hex-ops: introduce internal types and disable src1 reuse for now

Nothing new, just formalizing the repack / qyn.quant types we've been using.

* htp-opreq: use tensor pointers instead of copies

* hex-opreq: introduce more robust way for tracking vtcm/spad reuse

This removes the SKIP_QUANTIZE flag that became fragile with the addition of HMX and other ops.

* hex-cumsum: fix error post opreq merge

* hex-opreq: move request batch handling into the session

Prepping everything for using dspqueue buffers and doing that inside the session is much cleaner.

* hex-mm: yet another fix for src1 reuse when we're mixing hmx/hvx

* hex-bufs: introduce pinned mmapings and use non-pinned ones for model buffers

* hex-buf: add support for allocating shared/pinned buffer for opreqs

* hex-opbatch: make opbatches configurable

* hex-naming: better name for ggml_hexagon_shared_buffer

* hex-naming: add session->c_name() helper

* hex-opbatch: start using shm but still copy for now

* hex-opbatch: use shared buffer for packing opbatch

* hex-opbatch: better naming for opbatch-related classes and code

* hex-opbatch: reuse batched tensors with same data/dims/strides

* hex-opbatch: update logging

* hex-opbatch: add support for vmem limit for op batching

* hex-opbatch: update htp side to properly support dynamic mmap/unmap

* hex-opbatch: add OB and OQ params for run-completion script and fix the asserts in batch processing

* hex-opbatch: fixed src1 handling in act ops

* hex-act: fix empty src1 handling in swiglu and friends

Simplify preamble macro while at it

* hex-mm: minor fix vtcm and dma handling in matmul

cleaning up some left-overs from merges

* hex-opbatch: allocate extra 1KB for dspqueue overhead

* hexagon: fix softmax for non-aligned tensors and cleanup vtcm alloc

* hex-mm: properly handle hmx_disabled flag

* hex-ops: update comments

* hex-ops: add debug output for get/set-rows

* hex-mmap: optimize un/mapping of buffers

* hex-opreq: global cache flush and invalidate beyond 128KB threshold

* hex-ops: add super simple opfilter regex for debugging

If an Op matches the regex, the hex backend will reject it.

* hex-opbatch: wireup newer ops missed in merge and update main switch to detect this in future

* hexagon: improved vtcm acquisition to remove inter-op overhead

Fully compatible with QNN-HTP coex

* hex-mm: fixed hvx fallback path

* hex-mm: lower the vmem threshold a bit further to ~3GB

* hexagon: update debug & error logs

This also fixes an issue with newer llvm merging repack and non-repack
functions. We use those pointers to distinguish between buffer types.

* hexagon: move ops context into main context

Just a cleanup. We don't need separate contexts at this point.

* hex-opbatch: cleanup naming and headers for opbatch and related descriptors

* hex-fa: it's now better to enable FA during TG to reduce graph splits

* hexagon: remove GGML_HEXAGON_EXPERIMENTAL env var

It's no longer useful. Please use the more flexible GGML_HEXAGON_OPFILTER to disable Ops
if needed for debugging or validation.

* hexagon: fixed editorconfig check

* Update ggml/src/ggml-hexagon/ggml-hexagon.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Trivikram Reddy <tamarnat@qti.qualcomm.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
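
The op-request batching described at the top of this entry boils down to accumulating descriptors host-side and flushing them in one queue message. A generic sketch with hypothetical names; the real dspqueue API and descriptor layout are Qualcomm/Hexagon specific:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical request descriptor; the real one carries tensor metadata.
struct op_request {
    uint32_t op;
    uint32_t src0, src1, dst; // indices into a table of mapped buffers
};

struct op_batch {
    std::vector<op_request> reqs;

    void add(const op_request & r) { reqs.push_back(r); }

    // One dspqueue message per batch, instead of one per op, amortizes the
    // host<->NPU round trip; the NPU maps buffers while processing it.
    void flush(/* dspqueue handle */) {
        // dspqueue_write(queue, reqs.data(), reqs.size() * sizeof(op_request));
        reqs.clear();
    }
};
```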

* hexagon: add support for linux on snapdragon (ggml-org#21707)

* hexagon: add support for debian on ex2

* hexagon: add -fvectorize to c/c++ cmake flags

* hexagon: remove trailing white space

* update onboarding steps

* hexagon: update linux setup documentation

* hexagon: update installation scripts

* Hexagon: update docs

* hexagon: update onboarding scripts

---------

Co-authored-by: Zack Li <zackli@qti.qualcomm.com>

* fix: Fix broken structured output when using $refs in json_schema (ggml-org#21699)

* CUDA: also store node->src ne/nb for graph equality (ggml-org#21736)

* py : Bump typer to latest to fix huggingface_hub issue (ggml-org#21701)

* ggml : fix a few instances of missing GGML_TYPE_Q1_0 cases (ggml-org#21716)

* TP: fix Qwen 3 Next data split (ggml-org#21732)

* opencl: add basic support for q5_k (ggml-org#21593)

* opencl: add general q5_k mv

* opencl: add flattened Q5_K mv and general Q5_K mm

* opencl: fix Q5_K unit tests

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
Signed-off-by: Martin Klacer <martin.klacer@arm.com>
Signed-off-by: John E <jeis4wpi@outlook.com>
Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
Co-authored-by: Ruben Ortlam <rortlam@redhat.com>
Co-authored-by: Zheyuan Chen <sephirotheca17@gmail.com>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
Co-authored-by: Bartowski <3266127+bartowski1182@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
Co-authored-by: Slobodan Josic <127323561+slojosic-amd@users.noreply.github.com>
Co-authored-by: Vishal Singh <vishal@zettabolt.com>
Co-authored-by: Aaron Teo <taronaeo@gmail.com>
Co-authored-by: Radoslav Gerganov <rgerganov@gmail.com>
Co-authored-by: sayap <sokann@gmail.com>
Co-authored-by: Tillerino <Tillerino@users.noreply.github.com>
Co-authored-by: uvos <carl@uvos.xyz>
Co-authored-by: Aaron Teo <aaron.teo1@ibm.com>
Co-authored-by: jeromew <jerome.wagner@m4x.org>
Co-authored-by: M1DNYT3 <42499082+M1DNYT3@users.noreply.github.com>
Co-authored-by: M1DNYT3 <m1dnyt3@MacBookPro.lan>
Co-authored-by: CISC <CISC@users.noreply.github.com>
Co-authored-by: Samanvya Tripathi <samanu09@gmail.com>
Co-authored-by: Yes You Can Have Your Own <188969017+yychyo@users.noreply.github.com>
Co-authored-by: Masato Nakasaka <masato.nakasaka@intel.com>
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: SamareshSingh <97642706+ssam18@users.noreply.github.com>
Co-authored-by: Adrien Gallouët <angt@huggingface.co>
Co-authored-by: Dan Hoffman <43101339+thedanhoffman@users.noreply.github.com>
Co-authored-by: Dan Hoffman <dhoffman@cyket.net>
Co-authored-by: Aldehir Rojas <hello@alde.dev>
Co-authored-by: Nicholas Sparks <157740354+nisparks@users.noreply.github.com>
Co-authored-by: ddh0 <chemist-mulches-39@icloud.com>
Co-authored-by: Ludovic Henry <ludovic@rivosinc.com>
Co-authored-by: Richard Davison <richard.davison1@gmail.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: anchortense <daniel.redshaw@uqconnect.edu.au>
Co-authored-by: Yarden Tal <yardent@qti.qualcomm.com>
Co-authored-by: Neo Zhang <zhang.jianyu@outlook.com>
Co-authored-by: lainon1 <271530700+lainon1@users.noreply.github.com>
Co-authored-by: Gaurav Garg <gaugarg@nvidia.com>
Co-authored-by: Bipin Yadav <83943505+bipinyadav3175@users.noreply.github.com>
Co-authored-by: Pasha Khosravi <khosravipasha@users.noreply.github.com>
Co-authored-by: Masashi Yoshimura <yoshimura.masashi.frbs@gmail.com>
Co-authored-by: Dmytro Romanov <casteldazur@gmail.com>
Co-authored-by: PMZFX <georgiopapairo@gmail.com>
Co-authored-by: Kabir08 <62639358+Kabir08@users.noreply.github.com>
Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
Co-authored-by: Antoine Viallon <antoine@lesviallon.fr>
Co-authored-by: mkoker <132301062+mkoker@users.noreply.github.com>
Co-authored-by: Tom Overlund <tomov@dilacero.org>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Son H. Nguyen <33925625+nhs000@users.noreply.github.com>
Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Co-authored-by: iacopPBK <iacopogiottorossi@gmail.com>
Co-authored-by: iacopPBK <iacopPBK@users.noreply.github.com>
Co-authored-by: iacopPBK <iacop@deneb.com>
Co-authored-by: Martin Klacer <martin.klacer@arm.com>
Co-authored-by: Hamish M. Blair <hmblair@stanford.edu>
Co-authored-by: forforever73 <63285796+forforever73@users.noreply.github.com>
Co-authored-by: Erik Scholz <Green-Sky@users.noreply.github.com>
Co-authored-by: John Eismeier <42679190+jeis4wpi@users.noreply.github.com>
Co-authored-by: Yuri Khrustalev <ykhrustalev@users.noreply.github.com>
Co-authored-by: RealOrko <45273739+RealOrko@users.noreply.github.com>
Co-authored-by: realorko <realorko@nowhere.com>
Co-authored-by: Marxist-Leninist <31905382+Marxist-Leninist@users.noreply.github.com>
Co-authored-by: Marxist-Leninist <noreply@users.noreply.github.com>
Co-authored-by: Akarshan Biswas <akarshan@menlo.ai>
Co-authored-by: AUTOMATIC1111 <16777216c@gmail.com>
Co-authored-by: Kwa Jie Hao <31984694+kwajiehao@users.noreply.github.com>
Co-authored-by: JvM <mourix@live.nl>
Co-authored-by: fairydreaming <166155368+fairydreaming@users.noreply.github.com>
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: andyluo7 <43718156+andyluo7@users.noreply.github.com>
Co-authored-by: Andy Luo <andyluo7@users.noreply.github.com>
Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
Co-authored-by: Belem Zhang <belem.zhang@intel.com>
Co-authored-by: Berk Idem <55372926+berkidem@users.noreply.github.com>
Co-authored-by: Chen Yuan <constant.chen@uwaterloo.ca>
Co-authored-by: Jeremy J. Hartmann <jeremy@mtion.tv>
Co-authored-by: Rithik Sharma <rithiksh02@gmail.com>
Co-authored-by: MoonRide303 <130458190+MoonRide303@users.noreply.github.com>
Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
Co-authored-by: Trivikram Reddy <tamarnat@qti.qualcomm.com>
Co-authored-by: Todor Boinovski <todorb@qti.qualcomm.com>
Co-authored-by: Zack Li <zackli@qti.qualcomm.com>
Co-authored-by: Galunid <karolek1231456@gmail.com>
Co-authored-by: shaofeiqi <shaoqi@qti.qualcomm.com>
slartibardfast pushed a commit to slartibardfast/llama.cpp that referenced this pull request Apr 12, 2026
oussamaahmia pushed a commit to oussamaahmia/llama-cpp-turboquant-gemma4 that referenced this pull request Apr 13, 2026
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Apr 21, 2026
mudler added a commit to mudler/LocalAI that referenced this pull request Apr 25, 2026

Adds split_mode (alias sm) to the llama.cpp backend options allowlist,
accepting none|layer|row|tensor. The tensor value targets the experimental
backend-agnostic tensor parallelism from ggml-org/llama.cpp#19378 and
requires a llama.cpp build that includes that PR, FlashAttention enabled,
KV-cache quantization disabled, and a manually set context size.


Assisted-by: Claude:claude-opus-4-7

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

Labels

AMD ZenDNN, Apple Metal, Ascend NPU, examples, ggml, Hexagon, IBM zDNN, model, Nvidia GPU, OpenCL, OpenVINO, SYCL, testing, Vulkan, WebGPU
