ggml-cpu: add RVV vec dot kernels for quantization types #7

Merged

taimur-10x merged 3 commits into 10x/riscv-quant-vec-dot from 10x/riscv-quant-vec-dot-kernels on Jan 20, 2026

Conversation

taimur-10x (Collaborator) commented on Jan 12, 2026:

Summary

This PR adds RVV (RISC-V Vector) dot-product kernels for several quantization types.
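
For orientation, each kernel implements ggml's shared vec_dot interface. The iq1_s/q8_K pair, for example, is declared along the following lines (paraphrased from ggml's quantization headers for reference, not copied from this PR):

```c
// n      : number of values in the dot product
// s, bs  : output and its stride (bs is used when nrc > 1)
// vx, bx : quantized weight row(s) and stride
// vy, by : q8_K-quantized activation row(s) and stride
// nrc    : number of result columns computed per call
void ggml_vec_dot_iq1_s_q8_K(int n, float * restrict s, size_t bs,
                             const void * restrict vx, size_t bx,
                             const void * restrict vy, size_t by, int nrc);
```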

Key Changes

  • Added the following RVV kernels (a sketch of the shared inner-loop pattern follows the table):

| Kernel | VLEN |
| --- | --- |
| `ggml_vec_dot_iq1_s_q8_K` | 256 |
| `ggml_vec_dot_iq1_m_q8_K` | 256 |
| `ggml_vec_dot_iq2_s_q8_K` | 128, 256 |
| `ggml_vec_dot_iq3_s_q8_K` | 256 |
| `ggml_vec_dot_tq1_0_q8_K` | 256 |
| `ggml_vec_dot_tq2_0_q8_K` | 256 |
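
All six kernels share the same inner-loop shape: unpack the quantized weights, widen-multiply them against the q8_K activations, accumulate with a widening reduction, and apply the per-block scales. The sketch below shows that core pattern with the standard RVV C intrinsics, reduced to a plain int8 dot product; it is a simplified illustration (no IQ*/TQ* bit-unpacking, no scales), not the kernel code from this PR.

```c
// Minimal sketch of the RVV widen-multiply + widening-reduction pattern.
// Build with a vector-enabled target, e.g. -march=rv64gcv.
#include <riscv_vector.h>
#include <stddef.h>
#include <stdint.h>

static int32_t dot_i8_rvv(const int8_t *x, const int8_t *y, size_t n) {
    vint32m1_t acc = __riscv_vmv_v_x_i32m1(0, 1);        // accumulator in element 0
    for (size_t i = 0; i < n; ) {
        size_t     vl = __riscv_vsetvl_e8m1(n - i);      // strip-mine by VLEN
        vint8m1_t  vx = __riscv_vle8_v_i8m1(x + i, vl);  // quantized weights
        vint8m1_t  vy = __riscv_vle8_v_i8m1(y + i, vl);  // q8_K activations
        vint16m2_t vp = __riscv_vwmul_vv_i16m2(vx, vy, vl);  // widen products to i16
        acc = __riscv_vwredsum_vs_i16m2_i32m1(vp, acc, vl);  // reduce into i32
        i += vl;
    }
    return __riscv_vmv_x_s_i32m1_i32(acc);               // extract scalar result
}
```

Because `vsetvl` picks the active vector length each iteration, this strip-mined form is VLEN-agnostic; the per-VLEN entries in the table above reflect which configurations each kernel's unpacking code currently targets.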

Testing

The kernels were functionally tested with `test-quantize-fns` at both 128-bit and 256-bit VLEN.
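
As a rough illustration of what such a functional test exercises, a hypothetical harness (not `test-quantize-fns` itself) would compare the vectorized result against a plain scalar reference over random inputs, e.g. using the `dot_i8_rvv` sketch above. Running it under an emulator configured for 128-bit and then 256-bit VLEN covers both code paths.

```c
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// Hypothetical check in the spirit of test-quantize-fns: for integer
// dot products the vectorized kernel must match the scalar reference.
static int32_t dot_i8_ref(const int8_t *x, const int8_t *y, size_t n) {
    int32_t acc = 0;
    for (size_t i = 0; i < n; i++) {
        acc += (int32_t)x[i] * (int32_t)y[i];
    }
    return acc;
}

int main(void) {
    enum { N = 4096 };
    static int8_t x[N], y[N];
    srand(42);                                   // fixed seed for reproducibility
    for (size_t i = 0; i < N; i++) {
        x[i] = (int8_t)(rand() % 256 - 128);
        y[i] = (int8_t)(rand() % 256 - 128);
    }
    const int32_t ref = dot_i8_ref(x, y, N);
    const int32_t vec = dot_i8_rvv(x, y, N);     // kernel under test (sketch above)
    printf("%s: ref=%d vec=%d\n", ref == vec ? "OK" : "FAIL", ref, vec);
    return ref != vec;
}
```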

Benchmarking Results

End-to-end benchmarks were collected on a Banana Pi BPI-F3 (VLEN=256) with `llama-bench`.

IQ1_S

Prefill (tokens/second)

| Model | Prompt Size | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 32 | 3.08 | 10.23 |
| TinyLlama 1.1B | 64 | 3.08 | 10.21 |
| TinyLlama 1.1B | 128 | 3.09 | 10.27 |
| TinyLlama 1.1B | 256 | 3.07 | 10.22 |
| TinyLlama 1.1B | 512 | 3.09 | 10.24 |

Decode (tokens/second, prompt size = 32)

| Model | Generated Tokens | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 10 | 2.44 | 7.08 |
| TinyLlama 1.1B | 16 | 2.53 | 7.04 |
| TinyLlama 1.1B | 32 | 2.65 | 7.14 |
| TinyLlama 1.1B | 64 | 2.58 | 7.11 |
| TinyLlama 1.1B | 100 | 2.54 | 7.11 |

IQ1_M

Prefill (tokens/second)

| Model | Prompt Size | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 32 | 1.96 | 8.01 |
| TinyLlama 1.1B | 64 | 1.95 | 8.05 |
| TinyLlama 1.1B | 128 | 1.96 | 8.04 |
| TinyLlama 1.1B | 256 | 1.96 | 8.05 |
| TinyLlama 1.1B | 512 | 1.94 | 7.98 |

Decode (tokens/second, prompt size = 32)

| Model | Generated Tokens | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 10 | 1.71 | 5.62 |
| TinyLlama 1.1B | 16 | 1.73 | 5.59 |
| TinyLlama 1.1B | 32 | 1.73 | 5.64 |
| TinyLlama 1.1B | 64 | 1.72 | 5.63 |
| TinyLlama 1.1B | 100 | 1.73 | 5.65 |

IQ2_S

Prefill (tokens/second)

| Model | Prompt Size | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 32 | 8.42 | 1.17 |
| TinyLlama 1.1B | 64 | 7.57 | 1.16 |
| TinyLlama 1.1B | 128 | 8.78 | 1.14 |
| TinyLlama 1.1B | 256 | 8.57 | 1.2 |
| TinyLlama 1.1B | 512 | 8.68 | 1.95 |

Decode (tokens/second, prompt size = 32)

| Model | Generated Tokens | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 10 | 3.11 | 1.18 |
| TinyLlama 1.1B | 16 | 3.45 | 1.02 |
| TinyLlama 1.1B | 32 | 3.25 | 1.06 |
| TinyLlama 1.1B | 64 | 3.27 | 1.06 |
| TinyLlama 1.1B | 100 | 3.15 | 1.04 |

IQ3_S

Prefill (tokens/second)

| Model | Prompt Size | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 32 | 8.42 | 1.19 |
| TinyLlama 1.1B | 64 | 7.57 | 1.25 |
| TinyLlama 1.1B | 128 | 8.78 | 1.21 |
| TinyLlama 1.1B | 256 | 8.57 | 1.18 |
| TinyLlama 1.1B | 512 | 8.68 | 1.12 |

Decode (tokens/second, prompt size = 32)

| Model | Generated Tokens | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 10 | 3.11 | 1.22 |
| TinyLlama 1.1B | 16 | 3.45 | 1.14 |
| TinyLlama 1.1B | 32 | 3.25 | 1.13 |
| TinyLlama 1.1B | 64 | 3.27 | 1.13 |
| TinyLlama 1.1B | 100 | 3.15 | 1.12 |

TQ1_0

Prefill (tokens/second)

| Model | Prompt Size | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 32 | 8.42 | 2.71 |
| TinyLlama 1.1B | 64 | 7.57 | 2.72 |
| TinyLlama 1.1B | 128 | 8.78 | 2.75 |
| TinyLlama 1.1B | 256 | 8.57 | 2.63 |
| TinyLlama 1.1B | 512 | 8.68 | 2.68 |

Decode (tokens/second, prompt size = 32)

| Model | Generated Tokens | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 10 | 3.11 | 2.50 |
| TinyLlama 1.1B | 16 | 3.45 | 2.41 |
| TinyLlama 1.1B | 32 | 3.25 | 2.46 |
| TinyLlama 1.1B | 64 | 3.27 | 2.51 |
| TinyLlama 1.1B | 100 | 3.15 | 2.37 |

TQ2_0

Prefill (tokens/second)

| Model | Prompt Size | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 32 | 8.42 | 4.37 |
| TinyLlama 1.1B | 64 | 7.57 | 3.89 |
| TinyLlama 1.1B | 128 | 8.78 | 3.81 |
| TinyLlama 1.1B | 256 | 8.57 | 3.81 |
| TinyLlama 1.1B | 512 | 8.68 | 3.79 |

Decode (tokens/second, prompt size = 32)

| Model | Generated Tokens | Scalar | Vectorized |
| --- | --- | --- | --- |
| TinyLlama 1.1B | 10 | 3.11 | 4.22 |
| TinyLlama 1.1B | 16 | 3.45 | 3.60 |
| TinyLlama 1.1B | 32 | 3.25 | 3.51 |
| TinyLlama 1.1B | 64 | 3.27 | 3.16 |
| TinyLlama 1.1B | 100 | 3.15 | 3.11 |

Future Work

Follow-up PRs will extend the existing RVV quantization kernels to additional VLENs.

taimur-10x marked this pull request as draft on January 12, 2026 14:57
github-actions bot added the ggml label on Jan 12, 2026
taimur-10x and others added 2 commits on January 12, 2026 20:01
Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
taimur-10x force-pushed the 10x/riscv-quant-vec-dot-kernels branch from ef71fd4 to d432cf5 on January 12, 2026 15:08
taimur-10x self-assigned this on Jan 12, 2026
taimur-10x force-pushed the 10x/riscv-quant-vec-dot-kernels branch from 716d818 to 85ecce6 on January 13, 2026 14:23
Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
taimur-10x force-pushed the 10x/riscv-quant-vec-dot-kernels branch from 85ecce6 to 8814361 on January 14, 2026 11:20
taimur-10x marked this pull request as ready for review on January 14, 2026 12:13
taimur-10x (Collaborator, Author) commented:

@luhenry, @xctan, could this be reviewed please? Thank you.

taimur-10x changed the base branch from master to 10x/riscv-quant-vec-dot on January 20, 2026 10:43
taimur-10x merged commit d436d68 into 10x/riscv-quant-vec-dot on Jan 20, 2026
58 of 76 checks passed
rehan-10xengineer pushed a commit that referenced this pull request on Apr 14, 2026:

* ggml: backend-agnostic tensor parallelism

* support for GPT-OSS, Qwen 3 MoE

* partial Vulkan fix

* add support for 4/8 GPUs

* unconditional peer access

* re-use buffers + ggml contexts

* fix output pattern

* NCCL support

* GGML: HIP: add RCCL support

* Remove shfl and AllReduce from backend interface

* move allocation workaround out of ggml-alloc.c

* 2d tensor set/get support

* Fix the seg fault without NCCL

* Apply suggestion from JohannesGaessler

* support for tensor dims % n_devs != 0

* fix view_offs scaling

* arbitrary num. of GPUs/tensor split

* fix compilation

* better granularity estimate

* Support device-specific host buffer types if all underlying backends expose the same type. This allows using pinned memory instead of pageable memory for CUDA.

Fix compilation errors.

* partial Qwen 3 Next support

* Fix qwen3 30b (#8)

* Fix crash with Qwen-30B-A3B Q4_0

Qwen-30B-A3B Q4_0 has an intermediate dimension of 768. Using a granularity of 256 forces an uneven split between GPUs, which is not supported by the current implementation.

* Decide block size based on tensor quantization type

* Fix crashes due to KV cache serialization (#9)

KV cache serialization requires non-zero offsets on the tensor. Add support in the meta backend to set/get a tensor with a non-zero offset.

* metal : fix build (#7)

* static memory allocations, fix usage count

* fix tensor granularity

* more even memory distribution

* use BF16 for allreduce

* rebase fixup

* better error message for unsupported architectures

* Fix device mismatch during scatter of allReduce. (#11)

There is a mismatch between the dst buffer device and the backend device, causing the use of sync copies

* Enable the previous allreduce implementation. It is better in both perf and stability (#12)

* delay AllReduce for Moe for less I/O

* build : clean-up compile warnings

* backend : move most of the meta backend API to ggml-backend-impl.h

* cont : hide unused public API in the implementation

* llama : use llama_device + remove ggml_backend_dev_is_meta()

* ggml-backend : remove unused alloc include

* minor : remove regex include

* ggml : introduce ggml-ext.h for staging new APIs

* rebase fixup

* fix tests

* llama : more robust logic for determining Meta devices (ggml-org#16)

* llama : more robust logic for determining Meta devices

* cont : fix devs size check

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* cont : fix log type

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* disable roundtrip for meta backend

* fix arch selection

* Qwen 3.5 support

* fix Gemma 4 MoE

* fix OpenVino, SYCL

* fix test-llama-archs for CPU-only builds

* Fix Qwen 3.5 MoE

* disable meta backend tests for WebGPU

* tests : filter CPU-based devices from the Meta backend tests (ggml-org#17)

* meta : formatting, naming, indentation (ggml-org#18)

* formatting : llama-model.cpp

* formatting : ggml-ext.h

* formatting : ggml-backend-meta.cpp

* meta : add TODO

* add documentation

* better error messages

* fix GPT-OSS

---------

Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Gaurav Garg <gaugarg@nvidia.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>