diff --git a/Cargo.lock b/Cargo.lock index 27e9865cd..a12f0ecec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -795,6 +795,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "block-buffer" version = "0.10.4" @@ -1292,6 +1298,17 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core-graphics-types" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb" +dependencies = [ + "bitflags 2.11.0", + "core-foundation 0.10.1", + "libc", +] + [[package]] name = "cpubits" version = "0.1.0" @@ -1454,6 +1471,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "cudarc" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f071cd6a7b5d51607df76aa2d426aaabc7a74bc6bdb885b8afa63a880572ad9b" +dependencies = [ + "libloading", +] + [[package]] name = "dap" version = "0.4.1-alpha1" @@ -2005,7 +2031,28 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ - "foreign-types-shared", + "foreign-types-shared 0.1.1", +] + +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared 0.3.1", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", ] [[package]] @@ -2014,6 +2061,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -3019,6 +3072,16 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" +[[package]] +name = "libloading" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.16" @@ -3159,6 +3222,15 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "markdown" version = "1.0.0" @@ -3220,6 +3292,21 @@ dependencies = [ "autocfg", ] +[[package]] +name = "metal" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c7047791b5bc903b8cd963014b355f71dc9864a9a0b727057676c1dcae5cbc15" +dependencies = [ + "bitflags 2.11.0", + "block", + "core-graphics-types", + "foreign-types 0.5.0", + "log", + "objc", + "paste", +] + [[package]] name = "mime" version = "0.3.17" @@ -3931,6 +4018,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8f8bdf33df195859076e54ab11ee78a1b208382d3a26ec40d142ffc1ecc49ef" +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] + [[package]] name = "object" version = "0.37.3" @@ -3980,7 +4076,7 @@ checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" dependencies = [ "bitflags 2.11.0", "cfg-if", - "foreign-types", + "foreign-types 0.3.2", "libc", "once_cell", "openssl-macros", @@ -4624,12 +4720,14 @@ dependencies = [ "base64", "blake3", "bytes", + "cudarc", "divan", "hex", "itertools 0.14.0", "keccak 0.2.0-rc.2", "mavros-artifacts", "mavros-vm", + "metal", "noirc_abi", "ntt", "postcard", @@ -7259,7 +7357,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/WizardOfMenlo/whir/?rev=0aeaa7f337c743d9ddfcb9d909628d6491e3355c#0aeaa7f337c743d9ddfcb9d909628d6491e3355c" +source = "git+https://github.com/zkfriendly/whir.git?branch=zkfr%2Fadd-metal-gpu-refactor#8c6dd165b27a29d71d87092b2d94658f3a3263d4" dependencies = [ "ark-ff 0.5.0", "ark-serialize 0.5.0", diff --git a/Cargo.toml b/Cargo.toml index 73d5ac541..27f9eba0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -159,6 +159,8 @@ xz2 = "0.1.7" zerocopy = "0.8.25" zeroize = "1.8.1" zstd = "0.13" +metal = "0.33.0" +cudarc = { version = "0.19.4", default-features = false, features = ["cuda-12000", "driver", "nvrtc", "dynamic-loading"] } # WASM-specific dependencies js-sys = "0.3" @@ -200,4 +202,4 @@ spongefish = { git = "https://github.com/arkworks-rs/spongefish", features = [ "sha2", ], rev = "fcc277f8a857fdeeadd7cca92ab08de63b1ff1a1" } spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish", rev = "fcc277f8a857fdeeadd7cca92ab08de63b1ff1a1" } -whir = { git ="https://github.com/WizardOfMenlo/whir/", rev="0aeaa7f337c743d9ddfcb9d909628d6491e3355c", features = ["tracing", "rs_in_order"] } +whir = { git = "https://github.com/zkfriendly/whir.git", branch = "zkfr/add-metal-gpu-refactor", features = ["tracing", "rs_in_order"] } diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index 46cb173d7..22b64a631 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -59,6 +59,12 @@ xz2.workspace = true [package.metadata.cargo-machete] ignored = ["keccak"] # Intentionally anchored to keep Noir beta.19-compatible RC version +[target.'cfg(target_os = "macos")'.dependencies] +metal.workspace = true + +[target.'cfg(target_os = "linux")'.dependencies] +cudarc.workspace = true + [dev-dependencies] divan.workspace = true proptest.workspace = true diff --git a/provekit/common/src/interner.rs b/provekit/common/src/interner.rs index 822a6a7dd..2189729ca 100644 --- a/provekit/common/src/interner.rs +++ b/provekit/common/src/interner.rs @@ -9,7 +9,7 @@ pub struct Interner { values: Vec, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] pub struct InternedFieldElement(usize); impl Default for Interner { diff --git 
a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 3c7e76032..80f4fdae5 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -41,23 +41,61 @@ pub use { /// /// Must be called once before any prove/verify operations. /// Idempotent — safe to call multiple times. -pub fn register_ntt() { - use std::sync::{Arc, Once}; +pub fn register_whir_backends() { + use std::sync::Once; static INIT: Once = Once::new(); INIT.call_once(|| { - // Register NTT for polynomial operations - #[cfg(not(feature = "provekit_ntt"))] - let ntt: Arc> = - Arc::new(whir::algebra::ntt::NttEngine::::new_from_fftfield()); - - #[cfg(feature = "provekit_ntt")] - let ntt: Arc> = - Arc::new(crate::ntt::RSFr); - - whir::algebra::ntt::NTT.insert(ntt); + let irs_committer = build_irs_committer(); + whir::protocols::irs_commit::IRS_COMMITTERS.insert(irs_committer); // Register Skyscraper (ProveKit-specific); WHIR's built-in engines // (SHA2, Keccak, Blake3, etc.) are pre-registered via whir::hash::ENGINES. - whir::hash::ENGINES.register(Arc::new(skyscraper::SkyscraperHashEngine)); + whir::hash::ENGINES + .register(std::sync::Arc::new(skyscraper::SkyscraperHashEngine)); }); } + +/// Build the IRS committer for BN254. +/// +/// With `provekit_ntt`: uses ProveKit's optimized NTT backends (Metal on +/// macOS with CPU fallback, CPU-only on other targets). +/// Without `provekit_ntt`: uses whir's built-in `NttEngine`. +fn build_irs_committer() +-> std::sync::Arc> { + use std::sync::Arc; + use whir::protocols::irs_commit::CpuIrsCommitter; + + #[cfg(feature = "provekit_ntt")] + { + #[cfg(target_os = "macos")] + match crate::ntt::MetalBn254Ntt::new() { + Ok(ntt) => return Arc::new(ntt), + Err(err) => { + tracing::info!( + error = %err, + "Metal BN254 IRS backend unavailable, using ProveKit CPU fallback" + ); + } + } + + #[cfg(target_os = "linux")] + match crate::ntt::CudaBn254Ntt::new() { + Ok(ntt) => return Arc::new(ntt), + Err(err) => { + tracing::info!( + error = %err, + "CUDA BN254 IRS backend unavailable, using ProveKit CPU fallback" + ); + } + } + + Arc::new(CpuIrsCommitter::new(Arc::new(crate::ntt::RSFr))) + } + + #[cfg(not(feature = "provekit_ntt"))] + { + Arc::new(CpuIrsCommitter::new(Arc::new( + whir::algebra::ntt::NttEngine::::new_from_fftfield(), + ))) + } +} diff --git a/provekit/common/src/ntt.rs b/provekit/common/src/ntt/backends/cpu.rs similarity index 100% rename from provekit/common/src/ntt.rs rename to provekit/common/src/ntt/backends/cpu.rs diff --git a/provekit/common/src/ntt/backends/cuda/BENCHMARK.md b/provekit/common/src/ntt/backends/cuda/BENCHMARK.md new file mode 100644 index 000000000..52ffa426a --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/BENCHMARK.md @@ -0,0 +1,430 @@ +# CUDA NTT backend — benchmark notes + +This document records how the CUDA backend (`provekit/common/src/ntt/backends/cuda/`) +was measured and what the numbers look like on the reference workstation. It mirrors +the existing Metal backend in structure, but on Linux + NVIDIA the matrix produced by +the encode kernel can stay on the GPU between commit and the WHIR open phase, which is +where most of the host-memory savings come from. + +The CUDA backend implements `whir::protocols::irs_commit::IrsCommitter` (in +`commit.rs`) so that: + +1. `encode_matrix` runs on the GPU and leaves the encoded matrix in a pooled + `CudaSlice` device buffer. +2. The Merkle leaf hashes are computed on the GPU (`encode_field_rows_le` + + `sha256_many`). +3. 
The internal Merkle nodes are computed on the GPU layer-by-layer over a + single tree buffer. +4. `DeviceRows::read_rows` and `DeviceMerkleWitness::read_nodes` lazily download + only the rows / nodes WHIR actually opens. + +When the workload is too small or the hash isn't SHA-2, the implementation falls +back cleanly to `CpuIrsCommitter` (same predicate as the Metal backend). + +--- + +## Environment + +| Component | Value | +|---|---| +| OS | Pop!_OS 24.04 LTS, Linux 6.18.7 (x86_64) | +| CPU | 12th Gen Intel Core i7-12700H (10 P+E cores, 20 threads) | +| RAM | 38 GiB | +| GPU | NVIDIA GeForce RTX 3060 Laptop, 6 GiB, compute capability 8.6 | +| Driver | 580.126.18 | +| CUDA toolkit | 12.0 (V12.0.140), `libnvrtc.so.12`, `libcuda.so.1` | +| Rust | nightly 2026-03-03 (1.96.0-nightly) | +| `cudarc` crate | 0.19.4 (`dynamic-loading` + `nvrtc`, no link-time CUDA dep) | +| ProveKit commit | `c2c969a5` (branch `zkfr/add-cuda-gpu`) | +| WHIR commit | `8742f70` ("feat: IRS commit") | + +The `cudarc` dependency is feature-gated on `target_os = "linux"` and uses +`dynamic-loading`, so the binary has no compile-time link to the CUDA libraries — +it simply fails the `CudaBn254Ntt::new()` initialisation and falls back to CPU +when the libraries aren't available at runtime. + +### Workload + +The benchmarks below all use the same Noir circuit: + +``` +noir-examples/noir-passport-monolithic/complete_age_check +``` + +with `prover.pkp` prepared using SHA-2 leaf/Merkle hashes (the configuration the +GPU commit path requires). The prove command itself is: + +```bash +cd noir-examples/noir-passport-monolithic/complete_age_check +target/release/provekit-cli prove ./prover.pkp ./Prover.toml +``` + +The same command is used for both modes; only the env var changes: + +- **CPU baseline** — `PROVEKIT_DISABLE_CUDA_NTT=1` (forces the fallback path). +- **CUDA** — no env var; `CudaBn254Ntt::new()` initialises and registers as the + `IrsCommitter` for `Fr = ark_bn254::Fr`. + +Other useful env vars: + +- `PROVEKIT_CUDA_NTT_TRACE=1` — emit per-call backend events on stderr (init + device, NTT roots-cache hits/misses, encode shapes, PTX cache hits). + +### Build + +```bash +cargo build --release --bin provekit-cli +``` + +The binary picks up CUDA automatically because the workspace `provekit-common` +dependency enables the `provekit_ntt` feature, and `lib.rs::build_irs_committer` +registers `Arc::new(CudaBn254Ntt)` directly as the IRS committer on Linux when +init succeeds. + +### Parity tests + +Three tests compare the GPU output to the CPU reference: + +```bash +cargo test --release -p provekit-common --features provekit_ntt \ + --lib ntt::cuda_tests -- --test-threads=1 +``` + +``` +running 3 tests +test ntt::cuda_tests::cuda_matches_cpu_for_large_case ... ok +test ntt::cuda_tests::cuda_matches_cpu_for_multi_poly_case ... ok +test ntt::cuda_tests::cuda_matches_cpu_with_masks ... ok + +test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 72 filtered out +``` + +--- + +## Measurement methodology + +Three measurement passes were used: + +1. **OS-level wall / CPU / memory** via `/usr/bin/time -v`, 5 runs per mode + (Python wrapper to compute mean ± stdev of each field). +2. **Per-process peak GPU memory** via `nvidia-smi --query-compute-apps` + sampled every 50 ms during a single CUDA prove. +3. 
**Per-shape commit timings** parsed from the prover's own + `tracing_flame`-style hierarchical logs; each `whir::irs_commit::commit` + span is paired with its closing duration line and decomposed into + "encode" (`gpu_encode` / `encode_matrix` for CUDA, `interleaved_encode` + for CPU) vs "rest" (leaf hash + Merkle tree). + +All three Python scripts are reproduced verbatim below so the numbers are +re-measurable exactly. + +--- + +## 1. OS-level stats — `/usr/bin/time -v` + +### Command + +5 runs per mode through the same Python wrapper. The wrapper was invoked from the +prover example directory: + +```python +import os, re, subprocess, statistics + +BIN = "/home/zkfriendly/dev/prove/provekit/target/release/provekit-cli" +ARGS = ["prove", "./prover.pkp", "./Prover.toml"] + +def one_run(env_extra): + env = dict(os.environ); env.update(env_extra) + p = subprocess.run(["/usr/bin/time", "-v", BIN, *ARGS], + env=env, capture_output=True) + err = p.stderr.decode(errors='ignore') + grab = lambda pat, conv=float: (lambda m: conv(m.group(1)) if m else None)(re.search(pat, err)) + wall_str = re.search(r'Elapsed \(wall clock\) time \(h:mm:ss or m:ss\):\s*(\S+)', err) + parts = [float(x) for x in wall_str.group(1).split(':')] if wall_str else [] + wall = (parts[-1] + (parts[-2]*60 if len(parts)>=2 else 0) + + (parts[-3]*3600 if len(parts)>=3 else 0)) if parts else None + return dict( + wall=wall, + user=grab(r'User time \(seconds\):\s*([\d.]+)'), + sys=grab(r'System time \(seconds\):\s*([\d.]+)'), + cpu_pct=grab(r'Percent of CPU this job got:\s*(\d+)\s*%'), + rss_mb=(grab(r'Maximum resident set size \(kbytes\):\s*(\d+)', int) or 0)/1024, + minor_pf=grab(r'Minor \(reclaiming a frame\) page faults:\s*(\d+)', int), + major_pf=grab(r'Major \(requiring I/O\) page faults:\s*(\d+)', int), + vol_cs=grab(r'Voluntary context switches:\s*(\d+)', int), + invol_cs=grab(r'Involuntary context switches:\s*(\d+)', int), + exit=p.returncode, + ) + +for label, env in [("CPU", {"PROVEKIT_DISABLE_CUDA_NTT":"1"}), ("CUDA", {})]: + runs = [one_run(env) for _ in range(5)] + for k in ("wall","user","sys","cpu_pct","rss_mb","minor_pf","major_pf","vol_cs","invol_cs"): + vs = [r[k] for r in runs if r[k] is not None] + m = statistics.mean(vs); s = statistics.stdev(vs) if len(vs)>1 else 0 + print(f" {label:<6}{k:<14}: {m:.2f} ± {s:.2f}") +``` + +### Results (mean ± stdev, n=5) + +| metric | CPU | CUDA | Δ | +|---|---:|---:|---| +| wall (s) | 3.60 ± 0.27 | 3.62 ± 0.20 | flat (within noise) | +| internal `run:` (s) | 3.59 ± 0.27 | 3.49 ± 0.20 | **−2.8 %** | +| user CPU (s) | 25.11 ± 3.62 | 21.24 ± 2.38 | **−15.4 %** | +| sys CPU (s) | 1.47 ± 0.30 | 1.67 ± 0.26 | +14 % (driver ioctls) | +| CPU utilization | 735 ± 62 % | 631 ± 52 % | **−104 pp** | +| **peak host RSS (MB)** | **920 ± 9** | **693 ± 10** | **−227 MB / −24.7 %** | +| minor page faults | 558 639 | 392 727 | −30 % | +| voluntary ctx switches | 17 924 | 14 750 | −17 % | +| major page faults | 0 | 0 | — | + +**The headline numbers are peak host RSS (−25 %) and user CPU time (−15 %).** +Wall clock is flat because the WHIR pipeline contains CPU-bound stages (sumcheck, +fold ops, file decompression) that aren't accelerated, and the GPU work runs on a +separate stream that overlaps with those stages. + +--- + +## 2. 
Peak GPU memory — `nvidia-smi` + +### Command + +```python +import subprocess, threading, time, os + +BIN = "/home/zkfriendly/dev/prove/provekit/target/release/provekit-cli" +ARGS = ["prove", "./prover.pkp", "./Prover.toml"] + +def baseline_used(): + out = subprocess.check_output( + ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"] + ).decode() + return int(out.strip().splitlines()[0]) + +def proc_used(): + out = subprocess.check_output( + ["nvidia-smi", "--query-compute-apps=pid,used_memory", + "--format=csv,noheader,nounits"] + ).decode().strip() + return sum(int(line.split(',')[1].strip()) for line in out.splitlines() if line) + +base = baseline_used() +print(f"baseline GPU used by other apps: {base} MiB") + +peak_total = base; peak_proc = 0; done = threading.Event() +def sample(): + nonlocal_peak = {"t": peak_total, "p": peak_proc} + while not done.is_set(): + nonlocal_peak["t"] = max(nonlocal_peak["t"], baseline_used()) + nonlocal_peak["p"] = max(nonlocal_peak["p"], proc_used()) + time.sleep(0.05) + sample.result = nonlocal_peak + +t = threading.Thread(target=sample); t.start() +subprocess.run([BIN, *ARGS], capture_output=True) +done.set(); t.join() +print(sample.result) +``` + +### Results + +``` +baseline GPU used by other apps: 1031 MiB (cosmic-comp + Xwayland) +peak GPU total used: 1855 MiB +peak provekit-cli on GPU: 1032 MiB +``` + +So the prover itself peaks at **~1 GiB on the GPU**. This is two pooled +working buffers (`current` + `transposed`) for the largest commit +(1×1048576 vector → codeword length 524 288 × 8 messages = 4 M GpuFields = 128 MiB +each; rounded up to next pow-of-two by the bucket pool), plus the persistent +matrix and Merkle-tree buffers held alive by the active `IrsCommitArtifact`, +plus the NTT roots-of-unity cache. Comfortably inside the 6 GiB VRAM on the 3060. + +### Net memory across host + device + +| backend | host peak | GPU peak | combined | +|---|---:|---:|---:| +| CPU | 920 MB | 0 MB | 920 MB | +| CUDA | 693 MB | 1032 MB | ~1.7 GB | + +We trade ~227 MB of host RAM for ~1 GB of GPU RAM. That trade is the whole point +of `IrsCommitter`: the encoded matrix and Merkle tree never have to be +materialised on the host. + +--- + +## 3. Commit-phase decomposition — parsed from the prover's tracing logs + +### Capture command + +```bash +cd noir-examples/noir-passport-monolithic/complete_age_check +PROVEKIT_DISABLE_CUDA_NTT=1 target/release/provekit-cli prove \ + ./prover.pkp ./Prover.toml > /tmp/prove_cpu.log 2>&1 +target/release/provekit-cli prove \ + ./prover.pkp ./Prover.toml > /tmp/prove_cuda.log 2>&1 +``` + +### Parser + +The prover emits hierarchical spans like + +``` +├─╮ whir::protocols::irs_commit::commit self=size 1×1048576/8 rate 2⁻2.00 … +│ ├─╮ provekit_common::ntt::backends::cuda::encode::encode_matrix … +│ ├─╯ encode_matrix: 33.20 ms duration … +├─╯ commit: 52.35 ms duration … +``` + +The parser strips ANSI, pairs each `├─╮` opening line with its sibling `├─╯` +closing line at the same depth, and groups by the `size A×B` shape suffix. +For each parent commit it also locates the inner encode child to split +"encode" from "rest" (leaf-hash + Merkle-tree). 
+ +```python +import re + +def parse(path): + s = open(path, 'rb').read().decode(errors='ignore') + return re.sub(r'\x1b\[[0-9;]*[mGKH]', '', s).splitlines() + +def depth(line): + head = line.split('├')[0] if '├' in line else line.split('╰')[0] if '╰' in line else line + return head.count('│') + +def to_ms(v, u): + v = float(v) + return v*1000 if u=='s' else (v/1000 if u=='μs' else (v/1e6 if u=='ns' else v)) + +DUR = re.compile(r'([0-9.]+)\s*(ms|μs|s|ns)\s*duration') + +def commits(path): + lines = parse(path) + rows = [] + for i, line in enumerate(lines): + if '├─╮' not in line: continue + m = re.search(r'irs_commit::commit self=size (\d+×\d+)/\d+', line) + if not m: continue + d = depth(line); shape = m.group(1) + # close at same depth + close = next((j for j in range(i+1, len(lines)) + if '├─╯' in lines[j] and depth(lines[j]) == d), None) + if close is None: continue + cm = DUR.search(lines[close]) + total_ms = to_ms(cm.group(1), cm.group(2)) if cm else None + # find the inner encode child + encode_ms, kind = None, None + for j in range(i+1, close): + cl = lines[j] + if '├─╮' not in cl: continue + if 'cuda::encode::encode_matrix' in cl: kind = 'gpu_encode' + elif 'cpu::interleaved_encode' in cl: kind = 'cpu_encode' + else: continue + cd = depth(cl) + cclose = next((k for k in range(j+1, close+1) + if '├─╯' in lines[k] and depth(lines[k]) == cd), None) + if cclose is not None: + em = DUR.search(lines[cclose]) + if em: encode_ms = to_ms(em.group(1), em.group(2)) + break + rows.append((shape, total_ms, encode_ms, kind)) + return rows +``` + +### Top-level commits per shape (mean of 2 calls each) + +| shape (vec×size) | CPU total | CUDA total | speedup | +|---|---:|---:|---:| +| 1×1048576 *(GPU)* | 190.5 ms | **52.4 ms** | **3.6×** | +| 1×131072 *(GPU)* | 88.7 ms | **22.4 ms** | **4.0×** | +| 1×16384 *(GPU)* | 36.0 ms | **10.5 ms** | **3.4×** | +| 21×4096 *(GPU)* | 10.1 ms | **5.8 ms** | **1.7×** | +| 1×2048 *(CPU)* | 15.0 ms | 11.9 ms | 1.3× (noise) | +| 1×256 *(CPU)* | 5.4 ms | 7.5 ms | noise (both CPU) | +| 1×512 / 1×64 / 1×32 / 1×8 *(CPU)* | sub-ms each | sub-ms each | flat | +| **TOTAL (all 20 commits)** | **698 ms** | **227 ms** | **3.1× / −67 %** | + +The four GPU-eligible shapes (≥ 2²⁰ elements **or** ≥ 64 rows, with the SHA-2 +leaf/Merkle hash) account for nearly all the savings: 650 → 183 ms = −467 ms. +The smaller shapes correctly fall through to `CpuIrsCommitter` and show no +meaningful change (same `RSFr` encoder both ways). + +### Inside each commit — encode vs hash + merkle + +| shape | path | total ms | encode ms | rest (leaf-hash + merkle) ms | +|---|---|---:|---:|---:| +| 1×1048576 | CPU encode + CPU sha+merkle | 190.5 | 160.5 | 30.0 | +| 1×1048576 | **GPU encode + GPU sha+merkle** | **52.4** | **33.2** | **19.2** | +| 1×131072 | CPU | 88.7 | 74.9 | 13.8 | +| 1×131072 | **GPU** | **22.4** | **13.8** | **8.6** | +| 1×16384 | CPU | 36.0 | 28.6 | 7.4 | +| 1×16384 | **GPU** | **10.5** | **5.6** | **4.9** | +| 21×4096 | CPU | 10.1 | 7.6 | 2.6 | +| 21×4096 | **GPU** | **5.8** | **1.9** | 4.0 | + +Per-row interpretation: + +- **encode** drops by **~4–5×** on the four GPU shapes — the NTT kernel is + the primary win. +- **hash + merkle** drops by **~1.5×**. CPU SHA-256 has hardware acceleration + (`sha_ni`) on this Alder Lake part, so the GPU's per-byte SHA throughput + edge over the CPU is much smaller than its NTT edge. The GPU still wins + because (a) it overlaps with the encode on the same stream and (b) it + avoids a 128 MB host download. 
+- For `21×4096` the "rest" cost actually rises slightly on GPU (4.0 vs 2.6
+  ms): with 168 small rows the hash kernel and the per-Merkle-layer launch
+  overhead start to dominate over the per-leaf savings. Still a net win on the
+  commit total.
+
+---
+
+## Why wall-clock barely moves despite these wins
+
+The GPU saves **471 ms of IRS-commit work** (698 → 227 ms) and **227 MB of host
+memory**, but the prover's wall-clock barely changes (3.60 → 3.62 s). Two reasons:
+
+1. The remainder of the prove pipeline — sumcheck rounds, the
+   `fold_weight_to_mask_size` calls, `evaluate_gamma_block`, and the LZMA-
+   compressed `.pkp` decompression — is unchanged and stays on the CPU. That
+   accounts for ~3 seconds of the ~3.6 s wall.
+2. The CUDA work runs on a separate CUDA stream and overlaps with those CPU
+   stages, so the saved work shows up as **less host CPU time** and **less
+   peak host RSS** rather than as a wall-clock drop.
+
+The matrix-stays-on-device pattern (the `IrsCommitter` impl + `DeviceRows` /
+`DeviceMerkleWitness`) is the key foundation for eventually pulling additional
+WHIR stages onto the GPU, since the matrix doesn't have to be marshalled
+back across PCIe between phases.
+
+---
+
+## Cross-reference to the source layout
+
+```
+provekit/common/src/ntt/backends/cuda/
+├── mod.rs      CudaBn254Ntt + ReedSolomon impl + new()/runtime() + GPU-shape filter
+├── engine.rs   CudaRuntime: cudarc context, default stream, nvrtc compile + PTX
+│               disk cache (~/.cache/provekit/cuda), kernel handles, NTT roots
+│               cache, byte-level pooled buffer pool, raw memset/memcpy helpers
+├── encode.rs   gpu_encode (returns Vec<Fr>) + encode_matrix (returns DeviceMatrix
+│               on device) + encode_shape — uploads &[Fr] directly via the
+│               layout-equivalent &[GpuField] view
+├── commit.rs   IrsCommitter, hash_rows_to_buffer, build_merkle_witness,
+│               DeviceRows: MatrixRows, DeviceMerkleWitness: WitnessTrait
+├── field.rs    Fr ↔ GpuField (4×u64 Montgomery limbs)
+├── types.rs    GpuField + DeviceRepr/ValidAsZeroBits + param structs +
+│               DeviceMatrix / DeviceRows / DeviceMerkleWitness / EncodeShape
+├── logging.rs  trace_event (PROVEKIT_CUDA_NTT_TRACE)
+└── kernels/
+    ├── common.cuh  Fe + struct layouts + BN254_MODULUS / N0PRIME / SHA256_K
+    ├── field.cuh   Montgomery add/sub/mul + from_mont (port of metal/field.metal)
+    ├── ntt.cu      bit_reverse_permute_rows_in_place, radix2_ntt_stage_rows_in_place,
+    │               replicate_first_coset (port of metal/ntt.metal)
+    ├── matrix.cu   transpose_matrix
+    └── sha256.cu   encode_field_rows_le + sha256_many (port of metal/sha256.metal)
+```
+
+The CUDA backend is gated on `cfg(target_os = "linux")`. On macOS the Metal
+backend (in `backends/metal/`) is selected by `lib.rs::build_irs_committer`;
+on other platforms the CPU committer is used.
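+
+For orientation, this is how the backend gets picked: the snippet below
+condenses the Linux arm of `lib.rs::build_irs_committer` from earlier in this
+diff (a paraphrase for reference, not additional code introduced by this
+change):
+
+```rust
+// Condensed from provekit/common/src/lib.rs in this PR: Linux tries the CUDA
+// backend first and falls back to the ProveKit CPU committer when
+// initialisation fails (no driver, no device, nvrtc unavailable, ...).
+#[cfg(target_os = "linux")]
+match crate::ntt::CudaBn254Ntt::new() {
+    Ok(ntt) => return Arc::new(ntt),
+    Err(err) => {
+        tracing::info!(
+            error = %err,
+            "CUDA BN254 IRS backend unavailable, using ProveKit CPU fallback"
+        );
+    }
+}
+Arc::new(CpuIrsCommitter::new(Arc::new(crate::ntt::RSFr)))
+```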
diff --git a/provekit/common/src/ntt/backends/cuda/commit.rs b/provekit/common/src/ntt/backends/cuda/commit.rs
new file mode 100644
index 000000000..67975cae3
--- /dev/null
+++ b/provekit/common/src/ntt/backends/cuda/commit.rs
@@ -0,0 +1,355 @@
+use {
+    super::{
+        engine::PooledBuffer,
+        field::gpu_to_fr,
+        types::{
+            DeviceMatrix, DeviceMerkleWitness, DeviceRows, EncodeFieldBytesParams, GpuField,
+            HashManyParams,
+        },
+        CudaBn254Ntt,
+    },
+    ark_bn254::Fr,
+    cudarc::driver::PushKernelArg,
+    std::{mem::size_of, sync::Arc},
+    whir::{
+        hash::Hash,
+        protocols::{
+            irs_commit::{CpuIrsCommitter, IrsCommitArtifact, IrsCommitter, MatrixRows},
+            matrix_commit::{Config as MatrixCommitConfig, Encodable},
+            merkle_tree::WitnessTrait,
+        },
+    },
+};
+
+impl IrsCommitter for CudaBn254Ntt {
+    fn commit(
+        &self,
+        messages: &[&[Fr]],
+        masks: &[Fr],
+        codeword_length: usize,
+        matrix_commit: &MatrixCommitConfig,
+    ) -> IrsCommitArtifact {
+        let cpu_commit = || {
+            CpuIrsCommitter::new(Arc::new(crate::ntt::RSFr)).commit(
+                messages,
+                masks,
+                codeword_length,
+                matrix_commit,
+            )
+        };
+
+        if !Self::supports_gpu_shape(codeword_length, messages)
+            || !Self::supports_gpu_commit(matrix_commit)
+        {
+            return cpu_commit();
+        }
+
+        let Ok(matrix) = self.encode_matrix(messages, masks, codeword_length) else {
+            return cpu_commit();
+        };
+        let Ok(leaf_hashes) = self.hash_rows_to_buffer(&matrix) else {
+            return cpu_commit();
+        };
+        let Ok(merkle_witness) = self.build_merkle_witness(matrix_commit, &leaf_hashes) else {
+            return cpu_commit();
+        };
+
+        IrsCommitArtifact {
+            root: merkle_witness.root(),
+            rows: Arc::new(DeviceRows {
+                rows: matrix.rows,
+                cols: matrix.cols,
+                buffer: matrix.buffer,
+            }),
+            matrix_witness: merkle_witness,
+        }
+    }
+}
+
+impl CudaBn254Ntt {
+    /// Hash every row of `matrix` with SHA-256 on the GPU. Returns a pooled
+    /// device buffer holding `matrix.rows * 32` bytes (one digest per row).
+    pub(super) fn hash_rows_to_buffer(
+        &self,
+        matrix: &DeviceMatrix,
+    ) -> Result<PooledBuffer, String> {
+        let runtime = self.runtime()?;
+        if matrix.rows == 0 {
+            return Ok(runtime.pooled_buffer::<Hash>(0));
+        }
+
+        let total_elements = matrix.rows * matrix.cols;
+        let total_bytes = total_elements * Fr::encoded_size();
+        let message_size = matrix.cols * Fr::encoded_size();
+        if total_elements > u32::MAX as usize || message_size > u32::MAX as usize {
+            return Err("GPU hash launch exceeds current 32-bit grid limit".into());
+        }
+
+        // Encoded canonical bytes (rows × cols × 32 bytes), in natural row
+        // order (the encode kernel applies the bit-reversal of the row
+        // index when reading from `matrix.buffer`).
+        let encoded = runtime.pooled_bytes(total_bytes);
+        let hashes = runtime.pooled_buffer::<Hash>(matrix.rows);
+
+        let encode_params = EncodeFieldBytesParams {
+            rows: matrix.rows as u32,
+            cols: matrix.cols as u32,
+        };
+        {
+            let cfg = runtime.launch_cfg_1d(total_elements);
+            // SAFETY: kernel signature matches; ranges in-bounds.
+            unsafe {
+                runtime
+                    .stream
+                    .launch_builder(&runtime.encode_bytes_function)
+                    .arg(matrix.buffer.slice())
+                    .arg(encoded.slice())
+                    .arg(&encode_params)
+                    .launch(cfg)
+            }
+            .map_err(|e| format!("launch encode_field_rows_le: {e:?}"))?;
+        }
+
+        let hash_params = HashManyParams {
+            size: message_size as u32,
+            count: matrix.rows as u32,
+        };
+        {
+            let cfg = runtime.launch_cfg_1d(matrix.rows);
+            // SAFETY: kernel signature matches; ranges in-bounds.
+            unsafe {
+                runtime
+                    .stream
+                    .launch_builder(&runtime.sha256_function)
+                    .arg(encoded.slice())
+                    .arg(hashes.slice())
+                    .arg(&hash_params)
+                    .launch(cfg)
+            }
+            .map_err(|e| format!("launch sha256_many (leaf): {e:?}"))?;
+        }
+
+        runtime.synchronize()?;
+        Ok(hashes)
+    }
+
+    /// Build the Merkle tree on device by repeatedly hashing pairs of nodes.
+    /// `leaf_hashes` is the (rows × 32 bytes) buffer produced by
+    /// [`hash_rows_to_buffer`]. The resulting `DeviceMerkleWitness` owns a
+    /// pooled tree buffer; `read_nodes` lazily downloads requested nodes.
+    pub(super) fn build_merkle_witness(
+        &self,
+        matrix_commit: &MatrixCommitConfig,
+        leaf_hashes: &PooledBuffer,
+    ) -> Result<Arc<DeviceMerkleWitness>, String> {
+        let runtime = self.runtime()?;
+        let num_leaves = matrix_commit.num_rows();
+        let leaf_capacity = 1usize << matrix_commit.merkle_tree.layers.len();
+        let num_nodes = matrix_commit.merkle_tree.num_nodes();
+        if leaf_capacity == 0 {
+            return Err("invalid empty Merkle leaf capacity".into());
+        }
+        if num_nodes == 0 {
+            return Err("invalid empty Merkle tree".into());
+        }
+        if num_leaves > leaf_capacity {
+            return Err("Merkle config has fewer layers than leaves require".into());
+        }
+        if leaf_capacity > u32::MAX as usize {
+            return Err("GPU Merkle launch exceeds current 32-bit grid limit".into());
+        }
+
+        let tree = runtime.pooled_buffer::<Hash>(num_nodes);
+        runtime.memset_zeros(&tree, 0, num_nodes * size_of::<Hash>())?;
+        if num_leaves != 0 {
+            runtime.memcpy_dtod_bytes(
+                &tree,
+                0,
+                leaf_hashes,
+                0,
+                num_leaves * size_of::<Hash>(),
+            )?;
+        }
+
+        // Walk the layers from leaf to root, each iteration hashing the
+        // 2^k current-layer nodes into 2^(k-1) parent-layer nodes.
+        let mut previous_offset = 0usize;
+        let mut previous_len = leaf_capacity;
+        for _ in matrix_commit.merkle_tree.layers.iter().rev() {
+            let current_len = previous_len / 2;
+            if current_len == 0 {
+                break;
+            }
+            if current_len > u32::MAX as usize {
+                return Err("GPU Merkle launch exceeds current 32-bit grid limit".into());
+            }
+            let params = HashManyParams {
+                size: 64,
+                count: current_len as u32,
+            };
+            let prev_byte_off = previous_offset * size_of::<Hash>();
+            let curr_byte_off = (previous_offset + previous_len) * size_of::<Hash>();
+
+            // The sha256_many kernel takes input/output by raw byte
+            // pointer; cudarc forwards the device pointer plus offset via a
+            // byte-offset CudaView.
+            let input_view = tree
+                .slice()
+                .try_slice(prev_byte_off..prev_byte_off + previous_len * size_of::<Hash>())
+                .ok_or_else(|| "merkle: input view oob".to_string())?;
+            let output_view = tree
+                .slice()
+                .try_slice(curr_byte_off..curr_byte_off + current_len * size_of::<Hash>())
+                .ok_or_else(|| "merkle: output view oob".to_string())?;
+            let cfg = runtime.launch_cfg_1d(current_len);
+            // SAFETY: kernel reads from `input_view` and writes to
+            // `output_view`; the two ranges are disjoint regions of the
+            // tree buffer (parent layer comes after child layer in memory).
+            unsafe {
+                runtime
+                    .stream
+                    .launch_builder(&runtime.sha256_function)
+                    .arg(&input_view)
+                    .arg(&output_view)
+                    .arg(&params)
+                    .launch(cfg)
+            }
+            .map_err(|e| format!("launch sha256_many (merkle): {e:?}"))?;
+
+            previous_offset += previous_len;
+            previous_len = current_len;
+        }
+
+        runtime.synchronize()?;
+
+        // Read just the root (last node) back to surface it via
+        // `IrsCommitArtifact.root`.
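+        // (Parent layers are written after child layers, so the final node
+        // in the tree buffer is the root.)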
+        let mut root_bytes = [0u8; size_of::<Hash>()];
+        runtime.download_bytes(&tree, (num_nodes - 1) * size_of::<Hash>(), &mut root_bytes)?;
+        let mut root = Hash::default();
+        root.0.copy_from_slice(&root_bytes);
+
+        Ok(Arc::new(DeviceMerkleWitness {
+            num_nodes,
+            root,
+            buffer: tree,
+        }))
+    }
+}
+
+impl WitnessTrait for DeviceMerkleWitness {
+    fn num_nodes(&self) -> usize {
+        self.num_nodes
+    }
+
+    fn read_nodes(&self, indices: &[usize]) -> Vec<Hash> {
+        if indices.is_empty() {
+            return Vec::new();
+        }
+        let runtime = match crate::ntt::CudaBn254Ntt.runtime() {
+            Ok(r) => Arc::clone(r),
+            Err(err) => panic!("CUDA runtime unavailable: {err}"),
+        };
+        // Batch all per-node downloads on the stream and synchronise once at
+        // the end. Without this, each node request would block the host on
+        // its own stream-sync round-trip, which adds up to ~hundreds of ms
+        // for the WHIR open phase (Merkle paths × in-domain samples).
+        let hash_bytes = size_of::<Hash>();
+        let mut staging = vec![0u8; indices.len() * hash_bytes];
+        for (i, &index) in indices.iter().enumerate() {
+            assert!(index < self.num_nodes, "Merkle node index out of bounds");
+            let dst = &mut staging[i * hash_bytes..(i + 1) * hash_bytes];
+            runtime
+                .download_bytes_async(&self.buffer, index * hash_bytes, dst)
+                .unwrap_or_else(|err| panic!("CUDA Merkle node download failed: {err}"));
+        }
+        runtime
+            .synchronize()
+            .unwrap_or_else(|err| panic!("CUDA Merkle node sync failed: {err}"));
+        let mut out = Vec::with_capacity(indices.len());
+        for chunk in staging.chunks_exact(hash_bytes) {
+            let mut hash = Hash::default();
+            hash.0.copy_from_slice(chunk);
+            out.push(hash);
+        }
+        out
+    }
+}
+
+impl std::fmt::Debug for DeviceRows {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("DeviceRows")
+            .field("rows", &self.rows)
+            .field("cols", &self.cols)
+            .finish()
+    }
+}
+
+impl MatrixRows for DeviceRows {
+    fn num_rows(&self) -> usize {
+        self.rows
+    }
+
+    fn num_cols(&self) -> usize {
+        self.cols
+    }
+
+    /// Lazily download just the requested rows from the device buffer,
+    /// applying the bit-reversal of the codeword index so that index `i`
+    /// in `indices` corresponds to natural codeword position `i`.
+    ///
+    /// All per-row memcpys are queued on the stream first; we synchronise
+    /// the host once at the end. Previously this issued one
+    /// `cuMemcpyDtoH_v2` + sync per row, which dominated the WHIR open phase
+    /// for the largest commit (~160 ms for 127 sampled indices).
+    fn read_rows(&self, indices: &[usize]) -> Vec<Fr> {
+        if indices.is_empty() {
+            return Vec::new();
+        }
+        let runtime = match crate::ntt::CudaBn254Ntt.runtime() {
+            Ok(r) => Arc::clone(r),
+            Err(err) => panic!("CUDA runtime unavailable: {err}"),
+        };
+        let cols = self.cols;
+        let row_bytes = cols * size_of::<GpuField>();
+        // Pre-allocate one contiguous staging buffer so each per-row memcpy
+        // writes into a disjoint slice. We never re-read from `staging`
+        // until after `synchronize` returns.
+        let mut staging: Vec<GpuField> = vec![GpuField::default(); indices.len() * cols];
+        for (i, &row) in indices.iter().enumerate() {
+            assert!(row < self.rows, "row index out of bounds");
+            let src_row = reverse_bit_index(row, self.rows);
+            let dst = &mut staging[i * cols..(i + 1) * cols];
+            // SAFETY: GpuField is plain repr(C) bytes; dst.len() == cols;
+            // each call writes into a disjoint slice of `staging` and
+            // `staging` is not read until after synchronize() below.
+            unsafe {
+                runtime
+                    .download_into_async::<GpuField>(&self.buffer, src_row * row_bytes, dst)
+                    .unwrap_or_else(|err| panic!("CUDA matrix row download failed: {err}"));
+            }
+        }
+        runtime
+            .synchronize()
+            .unwrap_or_else(|err| panic!("CUDA matrix row sync failed: {err}"));
+        staging.into_iter().map(gpu_to_fr).collect()
+    }
+}
+
+impl std::fmt::Debug for DeviceMerkleWitness {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("DeviceMerkleWitness")
+            .field("num_nodes", &self.num_nodes)
+            .field("root", &self.root)
+            .finish()
+    }
+}
+
+fn reverse_bit_index(index: usize, codeword_length: usize) -> usize {
+    let bits = usize::BITS - (codeword_length - 1).leading_zeros();
+    if bits == 0 {
+        index
+    } else {
+        index.reverse_bits() >> (usize::BITS - bits)
+    }
+}
diff --git a/provekit/common/src/ntt/backends/cuda/encode.rs b/provekit/common/src/ntt/backends/cuda/encode.rs
new file mode 100644
index 000000000..1f6fd4e5b
--- /dev/null
+++ b/provekit/common/src/ntt/backends/cuda/encode.rs
@@ -0,0 +1,328 @@
+use {
+    super::{
+        logging::trace_event,
+        types::{
+            BitReverseParams, DeviceMatrix, EncodeShape, GpuField, NttStageParams,
+            ReplicateCosetsParams, TransposeParams,
+        },
+        CudaBn254Ntt,
+    },
+    ark_bn254::Fr,
+    cudarc::driver::PushKernelArg,
+    std::mem::{align_of, size_of},
+    tracing::instrument,
+    whir::algebra::ntt::ReedSolomon,
+};
+
+impl CudaBn254Ntt {
+    /// Public entry point used by `ReedSolomon::interleaved_encode`. Runs the
+    /// GPU encode and downloads the resulting matrix into a `Vec<Fr>`.
+    ///
+    /// The `IrsCommitter` path uses `encode_matrix` directly to keep the
+    /// matrix on device.
+    #[instrument(skip(self, messages, masks), fields(
+        num_messages = messages.len(),
+        message_len = messages.first().map(|c| c.len()),
+        codeword_length = codeword_length,
+        mask_len = masks.len().checked_div(messages.len())
+    ))]
+    pub fn gpu_encode(
+        &self,
+        messages: &[&[Fr]],
+        masks: &[Fr],
+        codeword_length: usize,
+    ) -> Result<Vec<Fr>, String> {
+        // We rely on Fr having identical memory layout to GpuField (4×u64
+        // Montgomery limbs). This lets us host↔device memcpy &[Fr] slices
+        // directly, avoiding a host pack/unpack pass.
+        const _: () = assert!(size_of::<Fr>() == size_of::<GpuField>());
+        const _: () = assert!(align_of::<Fr>() == align_of::<GpuField>());
+
+        let matrix = self.encode_matrix(messages, masks, codeword_length)?;
+        if matrix.rows == 0 || matrix.cols == 0 {
+            return Ok(Vec::new());
+        }
+        let runtime = self.runtime()?;
+
+        // Read back the full matrix into a host Vec, applying the
+        // bit-reversal of the codeword index so the output matches the CPU
+        // ordering (host row r ⇄ codeword index r).
+        let total = matrix.rows * matrix.cols;
+        let mut host_buf: Vec<Fr> = Vec::with_capacity(total);
+        // SAFETY: every element is overwritten by the memcpy below before
+        // we read from `output`. Capacity is exactly `total`.
+        unsafe { host_buf.set_len(total) };
+        // SAFETY: Fr has identical layout to GpuField (asserted above);
+        // bytes copied = total * sizeof(GpuField).
+        unsafe {
+            runtime.download_into::<Fr>(&matrix.buffer, 0, &mut host_buf)?;
+        }
+
+        // The on-device layout is [codeword_index_in_BR_order][message_index].
+        // Apply bit-reversal of the row index to produce natural codeword
+        // order.
+        let cols = matrix.cols;
+        let rows = matrix.rows;
+        let mut output: Vec<Fr> = Vec::with_capacity(total);
+        // SAFETY: capacity == total; every element overwritten by the loop.
+        unsafe { output.set_len(total) };
+        for dst_row in 0..rows {
+            let natural_row = reverse_bit_index(dst_row, rows);
+            let src_start = natural_row * cols;
+            let dst_start = dst_row * cols;
+            output[dst_start..dst_start + cols]
+                .copy_from_slice(&host_buf[src_start..src_start + cols]);
+        }
+        Ok(output)
+    }
+
+    /// Encode the messages and masks into a device matrix. The returned
+    /// `DeviceMatrix.buffer` holds the encoded values laid out as
+    /// `[codeword_index_in_BR_order, message_index]` (row-major), with
+    /// `rows = codeword_length`, `cols = message_count`. Lifetime of the
+    /// buffer is tied to the returned `DeviceMatrix` (Arc'd via the pool).
+    #[instrument(skip(self, messages, masks), fields(
+        num_messages = messages.len(),
+        message_len = messages.first().map(|c| c.len()),
+        codeword_length = codeword_length,
+        mask_len = masks.len().checked_div(messages.len())
+    ))]
+    pub fn encode_matrix(
+        &self,
+        messages: &[&[Fr]],
+        masks: &[Fr],
+        codeword_length: usize,
+    ) -> Result<DeviceMatrix, String> {
+        let runtime = self.runtime()?;
+        let shape = Self::encode_shape(messages, masks, codeword_length)?;
+        if shape.total_elements == 0 {
+            return Ok(DeviceMatrix {
+                rows: 0,
+                cols: 0,
+                buffer: runtime.pooled_buffer::<GpuField>(0),
+            });
+        }
+
+        trace_event(format_args!(
+            "encode rows={} codeword_length={} num_cosets={} coset_size={} polynomials={} \
+             path=coset",
+            shape.row_count,
+            codeword_length,
+            shape.num_cosets,
+            shape.coset_size,
+            messages.len(),
+        ));
+
+        // Working buffer, zero-initialised so the unused tail of each row
+        // (between message_length+mask_length and coset_size) is zero before
+        // the replicate step overwrites the rest of the row.
+        let current = runtime.pooled_buffer::<GpuField>(shape.total_elements);
+        runtime.memset_zeros(&current, 0, shape.total_elements * size_of::<GpuField>())?;
+
+        // Upload messages and masks directly from the caller's &[Fr] slices,
+        // bypassing any host-side pack pass. (Fr ↔ GpuField layout is
+        // asserted equivalent.)
+        for (row_index, msg) in messages.iter().enumerate() {
+            let dst_offset = (row_index * shape.codeword_length) * size_of::<GpuField>();
+            // SAFETY: layout equivalence asserted; range fits in current.
+            unsafe {
+                runtime.upload_into::<Fr>(msg, &current, dst_offset)?;
+            }
+        }
+        if shape.mask_length != 0 {
+            // Masks are laid out [mask_col][row] in the caller's slice but we
+            // need them at [row][message_length + mask_col] in the device
+            // buffer. Pack on host into a small contiguous staging buffer,
+            // then upload per row (mask_length is typically very small).
+            let mut staging: Vec<Fr> = vec![Fr::default(); shape.mask_length];
+            for row_index in 0..shape.row_count {
+                for mask_col in 0..shape.mask_length {
+                    staging[mask_col] = masks[mask_col * shape.row_count + row_index];
+                }
+                let dst_offset = (row_index * shape.codeword_length + shape.message_length)
+                    * size_of::<GpuField>();
+                // SAFETY: layout equivalence; range fits in current.
+                unsafe {
+                    runtime.upload_into::<Fr>(&staging, &current, dst_offset)?;
+                }
+            }
+        }
+
+        let roots = runtime.roots_buffer(codeword_length)?;
+        let transposed = runtime.pooled_buffer::<GpuField>(shape.total_elements);
+
+        let stage_count = codeword_length.trailing_zeros() as usize;
+        let skipped_stage_count = shape.num_cosets.trailing_zeros() as usize;
+        let total_butterflies = shape.total_elements / 2;
+
+        // 1. Replicate the first coset across the rest of each row.
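+        // (The stage loop below starts at skipped_stage_count =
+        // log2(num_cosets); replicating the first coset up front stands in
+        // for those skipped butterfly stages on the zero-padded rows.)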
+        let replicate_params = ReplicateCosetsParams {
+            row_len: shape.codeword_length as u32,
+            coset_size: shape.coset_size as u32,
+            trailing_elements: shape
+                .row_count
+                .saturating_mul(shape.codeword_length - shape.coset_size)
+                as u32,
+        };
+        if replicate_params.trailing_elements != 0 {
+            let cfg = runtime.launch_cfg_1d(replicate_params.trailing_elements as usize);
+            // SAFETY: kernel signature matches arg list; arrays in-bounds.
+            unsafe {
+                runtime
+                    .stream
+                    .launch_builder(&runtime.replicate_cosets_function)
+                    .arg(current.slice())
+                    .arg(&replicate_params)
+                    .launch(cfg)
+            }
+            .map_err(|e| format!("launch replicate_first_coset: {e:?}"))?;
+        }
+
+        // 2. Bit-reverse permute each row (so the subsequent NTT proceeds
+        // in natural-order indexing through twiddles).
+        let bit_reverse_params = BitReverseParams {
+            row_len: shape.codeword_length as u32,
+            log_n: stage_count as u32,
+            total_elements: shape.total_elements as u32,
+            _pad0: 0,
+        };
+        {
+            let cfg = runtime.launch_cfg_1d(shape.total_elements);
+            // SAFETY: kernel signature matches; in-bounds.
+            unsafe {
+                runtime
+                    .stream
+                    .launch_builder(&runtime.bit_reverse_function)
+                    .arg(current.slice())
+                    .arg(&bit_reverse_params)
+                    .launch(cfg)
+            }
+            .map_err(|e| format!("launch bit_reverse: {e:?}"))?;
+        }
+
+        // 3. Iteratively run NTT butterfly stages.
+        let mut twiddle_offset = (1usize << skipped_stage_count).saturating_sub(1);
+        let total_butterflies_u32 = total_butterflies as u32;
+        for stage in skipped_stage_count..stage_count {
+            let half_m = 1usize << stage;
+            let params = NttStageParams {
+                row_len: shape.codeword_length as u32,
+                half_m: half_m as u32,
+                twiddle_offset: twiddle_offset as u32,
+                _pad0: 0,
+            };
+            let cfg = runtime.launch_cfg_1d(total_butterflies);
+            // SAFETY: kernel signature matches; ranges in-bounds.
+            unsafe {
+                runtime
+                    .stream
+                    .launch_builder(&runtime.ntt_stage_function)
+                    .arg(current.slice())
+                    .arg(&*roots)
+                    .arg(&params)
+                    .arg(&total_butterflies_u32)
+                    .launch(cfg)
+            }
+            .map_err(|e| format!("launch ntt_stage(stage={stage}): {e:?}"))?;
+            twiddle_offset += 1usize << stage;
+        }
+
+        // 4. Transpose to [codeword_index_in_BR_order][message_index].
+        let transpose_params = TransposeParams {
+            rows: shape.row_count as u32,
+            cols: shape.codeword_length as u32,
+            total_elements: shape.total_elements as u32,
+        };
+        {
+            let cfg = runtime.launch_cfg_1d(shape.total_elements);
+            // SAFETY: kernel signature matches; ranges in-bounds.
+            unsafe {
+                runtime
+                    .stream
+                    .launch_builder(&runtime.transpose_function)
+                    .arg(current.slice())
+                    .arg(transposed.slice())
+                    .arg(&transpose_params)
+                    .launch(cfg)
+            }
+            .map_err(|e| format!("launch transpose: {e:?}"))?;
+        }
+
+        // Synchronise so that `current` can be safely returned to the pool
+        // when it's dropped at the end of this function.
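+        // (Dropping `current` recycles its allocation into the runtime's
+        // power-of-two bucket pool rather than freeing it; see
+        // PooledBufferInner::drop in engine.rs.)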
+ runtime.synchronize()?; + + Ok(DeviceMatrix { + rows: shape.codeword_length, + cols: shape.row_count, + buffer: transposed, + }) + } + + pub fn encode_shape( + messages: &[&[Fr]], + masks: &[Fr], + codeword_length: usize, + ) -> Result { + if messages.is_empty() { + return Ok(EncodeShape { + row_count: 0, + codeword_length, + coset_size: 0, + message_length: 0, + mask_length: 0, + num_cosets: 0, + total_elements: 0, + }); + } + if !Self::supports_gpu_shape(codeword_length, messages) { + return Err("problem shape unsupported for GPU path".into()); + } + + let row_count = messages.len(); + let message_length = messages[0].len(); + if messages.iter().any(|row| row.len() != message_length) { + return Err("all messages must have the same length".into()); + } + if !masks.len().is_multiple_of(row_count) { + return Err("mask count must be divisible by row count".into()); + } + let mask_length = masks.len() / row_count; + let masked_message_length = message_length + mask_length; + let mut coset_size = Self + .next_order(masked_message_length) + .ok_or_else(|| "no supported coset size for encode".to_string())?; + while !codeword_length.is_multiple_of(coset_size) { + coset_size = Self + .next_order(coset_size + 1) + .ok_or_else(|| "no supported coset size for encode".to_string())?; + } + let num_cosets = codeword_length / coset_size; + + let total_elements = row_count + .checked_mul(codeword_length) + .ok_or_else(|| "GPU encode launch exceeds current 32-bit grid limit".to_string())?; + if total_elements > u32::MAX as usize { + return Err("GPU encode launch exceeds current 32-bit grid limit".into()); + } + + Ok(EncodeShape { + row_count, + codeword_length, + coset_size, + message_length, + mask_length, + num_cosets, + total_elements, + }) + } +} + +fn reverse_bit_index(index: usize, codeword_length: usize) -> usize { + let bits = usize::BITS - (codeword_length - 1).leading_zeros(); + if bits == 0 { + index + } else { + index.reverse_bits() >> (usize::BITS - bits) + } +} diff --git a/provekit/common/src/ntt/backends/cuda/engine.rs b/provekit/common/src/ntt/backends/cuda/engine.rs new file mode 100644 index 000000000..862c99501 --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/engine.rs @@ -0,0 +1,525 @@ +use { + super::{field::fr_to_gpu, logging::trace_event, types::GpuField}, + ark_bn254::Fr, + ark_ff::{FftField, Field}, + cudarc::{ + driver::{ + result, sys, CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, + DriverError, LaunchConfig, + }, + nvrtc::{compile_ptx_with_opts, CompileOptions, Ptx}, + }, + std::{ + collections::{hash_map::DefaultHasher, HashMap}, + fs, + hash::{Hash as _, Hasher}, + mem::{size_of, ManuallyDrop}, + path::PathBuf, + sync::{Arc, Mutex}, + }, +}; + +const CUDA_SOURCE: &str = concat!( + "// common.cuh\n", + include_str!("kernels/common.cuh"), + "\n// field.cuh\n", + include_str!("kernels/field.cuh"), + "\n// ntt.cu\n", + include_str!("kernels/ntt.cu"), + "\n// matrix.cu\n", + include_str!("kernels/matrix.cu"), + "\n// sha256.cu\n", + include_str!("kernels/sha256.cu"), + "\n", +); + +// --------------------------------------------------------------------------- +// PooledBuffer: an Arc-wrapped CudaSlice that returns to the runtime's +// pool on drop. We bucket allocations by power-of-two byte size, mirroring +// the Metal backend's PooledBuffer. All sizes are tracked in bytes so a +// single pool can serve GpuField, Hash, and raw-byte buffers. +// +// Buffers are exposed as `&CudaSlice`. 
cudarc's `launch_builder.arg(&s)` +// accepts an immutable reference even for kernels that mutate the slice (it +// only forwards the raw device pointer to the kernel), so we never need a +// `&mut CudaSlice` and can keep the buffer behind `Arc`. +// --------------------------------------------------------------------------- + +struct PooledBufferInner { + runtime: Arc, + bucket_bytes: usize, + /// `ManuallyDrop` so we can take the slice out in `Drop` and recycle it. + buffer: ManuallyDrop>, +} + +#[derive(Clone)] +pub struct PooledBuffer(Arc); + +impl PooledBuffer { + pub fn slice(&self) -> &CudaSlice { + &self.0.buffer + } + + pub fn bucket_bytes(&self) -> usize { + self.0.bucket_bytes + } +} + +impl std::fmt::Debug for PooledBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PooledBuffer") + .field("bucket_bytes", &self.0.bucket_bytes) + .finish() + } +} + +impl Drop for PooledBufferInner { + fn drop(&mut self) { + // SAFETY: `buffer` is initialised on construction and dropped only here. + let buffer = unsafe { ManuallyDrop::take(&mut self.buffer) }; + self.runtime.recycle_buffer(self.bucket_bytes, buffer); + } +} + +// --------------------------------------------------------------------------- +// CudaRuntime: long-lived context, default stream, kernel handles, NTT root +// table cache, and buffer pool. Initialised once per process via +// `CudaBn254Ntt::new`. +// --------------------------------------------------------------------------- + +pub struct CudaRuntime { + // Held to keep the context alive (the stream and modules transitively + // depend on it); we don't call methods on it directly. + #[allow(dead_code)] + pub context: Arc, + pub stream: Arc, + pub device_name: String, + pub compute_capability: (i32, i32), + pub max_block_size: u32, + #[allow(dead_code)] + module: Arc, + pub bit_reverse_function: CudaFunction, + pub ntt_stage_function: CudaFunction, + pub replicate_cosets_function: CudaFunction, + pub transpose_function: CudaFunction, + pub encode_bytes_function: CudaFunction, + pub sha256_function: CudaFunction, + roots_cache: Mutex>>>, + buffer_pool: Mutex>>>, +} + +impl CudaRuntime { + pub fn new() -> Result { + let context = CudaContext::new(0).map_err(driver_err)?; + let stream = context.default_stream(); + let device_name = context.name().map_err(driver_err)?; + let (cc_major, cc_minor) = context.compute_capability().map_err(driver_err)?; + + let arch = arch_for_compute_capability(cc_major, cc_minor); + let ptx = compile_or_load_ptx(CUDA_SOURCE, arch)?; + let module = context.load_module(ptx).map_err(driver_err)?; + + let bit_reverse_function = module + .load_function("bit_reverse_permute_rows_in_place") + .map_err(driver_err)?; + let ntt_stage_function = module + .load_function("radix2_ntt_stage_rows_in_place") + .map_err(driver_err)?; + let replicate_cosets_function = module + .load_function("replicate_first_coset") + .map_err(driver_err)?; + let transpose_function = module + .load_function("transpose_matrix") + .map_err(driver_err)?; + let encode_bytes_function = module + .load_function("encode_field_rows_le") + .map_err(driver_err)?; + let sha256_function = module + .load_function("sha256_many") + .map_err(driver_err)?; + + Ok(Self { + context, + stream, + device_name, + compute_capability: (cc_major, cc_minor), + max_block_size: 256, + module, + bit_reverse_function, + ntt_stage_function, + replicate_cosets_function, + transpose_function, + encode_bytes_function, + sha256_function, + roots_cache: 
Mutex::new(HashMap::new()), + buffer_pool: Mutex::new(HashMap::new()), + }) + } + + // ----- buffer pool ----------------------------------------------------- + + pub fn pooled_buffer(self: &Arc, len: usize) -> PooledBuffer { + self.pooled_bytes(len * size_of::()) + } + + pub fn pooled_bytes(self: &Arc, bytes: usize) -> PooledBuffer { + let bucket_bytes = bucket_bytes(bytes); + let buffer = self.take_buffer(bucket_bytes); + PooledBuffer(Arc::new(PooledBufferInner { + runtime: Arc::clone(self), + bucket_bytes, + buffer: ManuallyDrop::new(buffer), + })) + } + + fn take_buffer(&self, bucket_bytes: usize) -> CudaSlice { + // cudarc disallows zero-byte allocs; back empty buffers with a + // single byte instead. Callers never read it. + let alloc_bytes = bucket_bytes.max(1); + if let Some(buffer) = self + .buffer_pool + .lock() + .unwrap() + .get_mut(&bucket_bytes) + .and_then(Vec::pop) + { + return buffer; + } + // SAFETY: Caller is responsible for fully overwriting any region + // they later read from. `gpu_encode` zeroes the working buffer + // before any read; the SHA path memsets the tree buffer first; + // downloads only read regions explicitly written by a prior kernel. + unsafe { + self.stream + .alloc::(alloc_bytes) + .expect("CUDA pooled-buffer alloc") + } + } + + fn recycle_buffer(&self, bucket_bytes: usize, buffer: CudaSlice) { + if bucket_bytes == 0 { + return; + } + self.buffer_pool + .lock() + .unwrap() + .entry(bucket_bytes) + .or_default() + .push(buffer); + } + + // ----- raw byte-level memset / memcpy --------------------------------- + + /// Synchronously enqueue a byte memset of `bytes` bytes starting at + /// `offset_bytes` inside `dst`. + pub fn memset_zeros( + &self, + dst: &PooledBuffer, + offset_bytes: usize, + bytes: usize, + ) -> Result<(), String> { + if bytes == 0 { + return Ok(()); + } + debug_assert!(offset_bytes + bytes <= dst.bucket_bytes()); + let (ptr, _hold) = dst.slice().device_ptr(&self.stream); + // SAFETY: `ptr + offset_bytes` is in-bounds of the alloc; `bytes` + // does not exceed the bucket size; the stream is the alloc's stream. + unsafe { + result::memset_d8_async( + ptr + offset_bytes as sys::CUdeviceptr, + 0, + bytes, + self.stream.cu_stream(), + ) + } + .map_err(driver_err) + } + + /// Upload `host` into `dst` starting at `dst_offset_bytes`. `T` must be + /// layout-equivalent to `GpuField` (4×u64 Montgomery limbs); this is how + /// we support `&[Fr]` directly without a host pack pass. + /// + /// # Safety + /// + /// `T` must have identical size/alignment/layout to `GpuField`, and + /// `dst_offset_bytes + size_of::() * host.len()` must be `<= dst.bucket_bytes()`. + pub unsafe fn upload_into( + &self, + host: &[T], + dst: &PooledBuffer, + dst_offset_bytes: usize, + ) -> Result<(), String> { + if host.is_empty() { + return Ok(()); + } + debug_assert_eq!(size_of::(), size_of::()); + let bytes = std::mem::size_of_val(host); + debug_assert!(dst_offset_bytes + bytes <= dst.bucket_bytes()); + // SAFETY: caller guarantees layout equivalence. + let host_bytes: &[u8] = + unsafe { std::slice::from_raw_parts(host.as_ptr().cast::(), bytes) }; + let (ptr, _hold) = dst.slice().device_ptr(&self.stream); + // SAFETY: device pointer + offset is in-bounds; lifetime of `host` + // extends past the synchronisation we perform later (caller is + // expected to synchronise the stream before reusing the buffer). 
+ unsafe { + result::memcpy_htod_async( + ptr + dst_offset_bytes as sys::CUdeviceptr, + host_bytes, + self.stream.cu_stream(), + ) + } + .map_err(driver_err) + } + + /// Download `dst.len()` `T` elements from `src` (starting at + /// `src_offset_bytes`) into `dst`. Synchronously: blocks the host until + /// the copy is complete. For batched downloads, prefer pairing + /// [`download_into_async`] calls with a single trailing + /// [`synchronize`]. + /// + /// # Safety + /// + /// Same as [`upload_into`]. + pub unsafe fn download_into( + &self, + src: &PooledBuffer, + src_offset_bytes: usize, + dst: &mut [T], + ) -> Result<(), String> { + // SAFETY: forwarded. + unsafe { self.download_into_async::(src, src_offset_bytes, dst) }?; + self.synchronize() + } + + /// Asynchronous variant of [`download_into`]: queues the device-to-host + /// copy on the stream but does NOT synchronise. Caller MUST synchronise + /// (or otherwise wait on the stream) before reading from `dst`. + /// + /// # Safety + /// + /// Same as [`upload_into`], plus the caller must keep `dst` alive and + /// not move it until the stream synchronises. + pub unsafe fn download_into_async( + &self, + src: &PooledBuffer, + src_offset_bytes: usize, + dst: &mut [T], + ) -> Result<(), String> { + if dst.is_empty() { + return Ok(()); + } + debug_assert_eq!(size_of::(), size_of::()); + let bytes = std::mem::size_of_val(dst); + debug_assert!(src_offset_bytes + bytes <= src.bucket_bytes()); + // SAFETY: caller guarantees layout equivalence. + let dst_bytes: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(dst.as_mut_ptr().cast::(), bytes) }; + let (ptr, _hold) = src.slice().device_ptr(&self.stream); + // SAFETY: device pointer + offset is in-bounds; caller is + // responsible for synchronising before reading `dst`. + unsafe { + result::memcpy_dtoh_async( + dst_bytes, + ptr + src_offset_bytes as sys::CUdeviceptr, + self.stream.cu_stream(), + ) + .map_err(driver_err) + } + } + + /// Download raw bytes (no type assumption). Synchronous. + pub fn download_bytes( + &self, + src: &PooledBuffer, + src_offset_bytes: usize, + dst: &mut [u8], + ) -> Result<(), String> { + self.download_bytes_async(src, src_offset_bytes, dst)?; + self.synchronize() + } + + /// Asynchronous raw-byte download (no implicit synchronise). + pub fn download_bytes_async( + &self, + src: &PooledBuffer, + src_offset_bytes: usize, + dst: &mut [u8], + ) -> Result<(), String> { + if dst.is_empty() { + return Ok(()); + } + debug_assert!(src_offset_bytes + dst.len() <= src.bucket_bytes()); + let (ptr, _hold) = src.slice().device_ptr(&self.stream); + // SAFETY: caller must synchronise before reading `dst`. + unsafe { + result::memcpy_dtoh_async( + dst, + ptr + src_offset_bytes as sys::CUdeviceptr, + self.stream.cu_stream(), + ) + .map_err(driver_err) + } + } + + /// Device-to-device byte copy (`bytes` bytes from + /// `src[src_offset_bytes..]` into `dst[dst_offset_bytes..]`). + pub fn memcpy_dtod_bytes( + &self, + dst: &PooledBuffer, + dst_offset_bytes: usize, + src: &PooledBuffer, + src_offset_bytes: usize, + bytes: usize, + ) -> Result<(), String> { + if bytes == 0 { + return Ok(()); + } + debug_assert!(dst_offset_bytes + bytes <= dst.bucket_bytes()); + debug_assert!(src_offset_bytes + bytes <= src.bucket_bytes()); + let (dst_ptr, _hold_d) = dst.slice().device_ptr(&self.stream); + let (src_ptr, _hold_s) = src.slice().device_ptr(&self.stream); + // SAFETY: ranges are in-bounds; both buffers belong to this stream. 
+        unsafe {
+            result::memcpy_dtod_async(
+                dst_ptr + dst_offset_bytes as sys::CUdeviceptr,
+                src_ptr + src_offset_bytes as sys::CUdeviceptr,
+                bytes,
+                self.stream.cu_stream(),
+            )
+        }
+        .map_err(driver_err)
+    }
+
+    pub fn synchronize(&self) -> Result<(), String> {
+        self.stream.synchronize().map_err(driver_err)
+    }
+
+    // ----- roots cache -----------------------------------------------------
+
+    /// Roots-of-unity table for a `codeword_length`-point NTT. Stage layout
+    /// matches the Metal backend exactly so the same kernel indexing works.
+    pub fn roots_buffer(
+        &self,
+        codeword_length: usize,
+    ) -> Result<Arc<CudaSlice<GpuField>>, String> {
+        let mut cache = self.roots_cache.lock().unwrap();
+        if let Some(buffer) = cache.get(&codeword_length) {
+            trace_event(format_args!(
+                "roots cache hit codeword_length={codeword_length}"
+            ));
+            return Ok(Arc::clone(buffer));
+        }
+
+        let root = Fr::get_root_of_unity(codeword_length as u64).unwrap();
+        let stage_count = codeword_length.trailing_zeros() as usize;
+        let mut roots = Vec::with_capacity(codeword_length.saturating_sub(1));
+        for stage in 0..stage_count {
+            let stage_size = 1usize << (stage + 1);
+            let half_stage = stage_size >> 1;
+            let stage_root = root.pow([(codeword_length / stage_size) as u64]);
+            let mut current = Fr::ONE;
+            for _ in 0..half_stage {
+                roots.push(fr_to_gpu(current));
+                current *= stage_root;
+            }
+        }
+
+        let buffer = self.stream.clone_htod(&roots).map_err(driver_err)?;
+        let arc = Arc::new(buffer);
+        cache.insert(codeword_length, Arc::clone(&arc));
+        trace_event(format_args!(
+            "roots cache miss codeword_length={codeword_length}"
+        ));
+        Ok(arc)
+    }
+
+    pub fn launch_cfg_1d(&self, work: usize) -> LaunchConfig {
+        let block: u32 = self.max_block_size.max(1);
+        let work = work.max(1) as u32;
+        let grid = work.div_ceil(block);
+        LaunchConfig {
+            grid_dim: (grid, 1, 1),
+            block_dim: (block, 1, 1),
+            shared_mem_bytes: 0,
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// PTX caching: hash (source, arch) and store under $XDG_CACHE_HOME/provekit
+// so subsequent runs skip the (~hundreds of ms) nvrtc compile.
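+// (The cache key hashes the entire kernel source string plus the target
+// arch, so editing a kernel or targeting a different GPU invalidates stale
+// PTX automatically.)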
+// --------------------------------------------------------------------------- + +fn compile_or_load_ptx(source: &str, arch: Option<&'static str>) -> Result { + let cache_path = ptx_cache_path(source, arch); + if let Some(path) = cache_path.as_ref() { + if let Ok(text) = fs::read_to_string(path) { + trace_event(format_args!("ptx cache hit path={}", path.display())); + return Ok(Ptx::from_src(text)); + } + } + let mut opts = CompileOptions::default(); + opts.options.push("-std=c++17".into()); + if let Some(arch) = arch { + opts.arch = Some(arch); + } + let ptx = compile_ptx_with_opts(source, opts).map_err(|err| format!("nvrtc: {err:?}"))?; + if let Some(path) = cache_path.as_ref() { + if let Some(parent) = path.parent() { + let _ = fs::create_dir_all(parent); + } + let _ = fs::write(path, ptx.to_src()); + trace_event(format_args!("ptx cache wrote path={}", path.display())); + } + Ok(ptx) +} + +fn ptx_cache_path(source: &str, arch: Option<&str>) -> Option { + let cache_root = std::env::var_os("XDG_CACHE_HOME") + .map(PathBuf::from) + .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".cache")))?; + let mut hasher = DefaultHasher::new(); + "provekit-cuda-ntt".hash(&mut hasher); + arch.unwrap_or("generic").hash(&mut hasher); + source.hash(&mut hasher); + Some( + cache_root + .join("provekit") + .join("cuda") + .join(format!( + "ntt-{}-{:016x}.ptx", + arch.unwrap_or("generic"), + hasher.finish() + )), + ) +} + +fn arch_for_compute_capability(major: i32, minor: i32) -> Option<&'static str> { + match (major, minor) { + (5, 0) => Some("compute_50"), + (5, 2) => Some("compute_52"), + (6, 0) => Some("compute_60"), + (6, 1) => Some("compute_61"), + (7, 0) => Some("compute_70"), + (7, 5) => Some("compute_75"), + (8, 0) => Some("compute_80"), + (8, 6) => Some("compute_86"), + (8, 9) => Some("compute_89"), + (9, 0) => Some("compute_90"), + _ => None, + } +} + +fn driver_err(err: DriverError) -> String { + format!("{err:?}") +} + +fn bucket_bytes(bytes: usize) -> usize { + if bytes == 0 { + 0 + } else { + bytes.next_power_of_two() + } +} diff --git a/provekit/common/src/ntt/backends/cuda/field.rs b/provekit/common/src/ntt/backends/cuda/field.rs new file mode 100644 index 000000000..d2e6cf1ca --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/field.rs @@ -0,0 +1,14 @@ +use { + super::types::GpuField, + ark_bn254::{Fr, FrConfig}, + ark_ff::{BigInt, Fp, MontBackend}, + std::marker::PhantomData, +}; + +pub(super) fn fr_to_gpu(value: Fr) -> GpuField { + GpuField { limbs: value.0 .0 } +} + +pub(super) fn gpu_to_fr(value: GpuField) -> Fr { + Fp::, 4>(BigInt(value.limbs), PhantomData) +} diff --git a/provekit/common/src/ntt/backends/cuda/kernels/common.cuh b/provekit/common/src/ntt/backends/cuda/kernels/common.cuh new file mode 100644 index 000000000..9610e83ae --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/kernels/common.cuh @@ -0,0 +1,71 @@ +// CUDA equivalent of metal/kernels/common.metal — same struct layouts, same constants. 
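+//
+// BN254_MODULUS below is the BN254 scalar-field prime p in little-endian u64
+// limbs; BN254_N0PRIME is the Montgomery constant n0' = -p^(-1) mod 2^64 used
+// by mont_mul in field.cuh. FE_ONE is the integer 1 in standard
+// (non-Montgomery) form, so mont_mul(x, FE_ONE) computes x * R^(-1) mod p,
+// i.e. a conversion out of Montgomery form (see from_mont).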
+ +typedef unsigned long long ulong; +typedef unsigned int uint; +typedef unsigned char uchar; + +struct Bn254Element { + ulong limbs[4]; +}; + +typedef Bn254Element Fe; + +struct StageConfig { + uint row_len; + uint half_m; + uint twiddle_offset; + uint _pad0; +}; + +struct BitReverseParams { + uint row_len; + uint log_n; + uint total_elements; + uint _pad0; +}; + +struct TransposeParams { + uint rows; + uint cols; + uint total_elements; +}; + +struct HashManyParams { + uint size; + uint count; +}; + +struct ReplicateCosetsParams { + uint row_len; + uint coset_size; + uint trailing_elements; +}; + +__device__ __constant__ ulong BN254_MODULUS[4] = { + 0x43e1f593f0000001ull, + 0x2833e84879b97091ull, + 0xb85045b68181585dull, + 0x30644e72e131a029ull, +}; + +__device__ __constant__ ulong BN254_N0PRIME = 0xc2e1f593efffffffull; +__device__ __constant__ Fe FE_ONE = {{1ull, 0ull, 0ull, 0ull}}; + +__device__ __constant__ uint SHA256_K[64] = { + 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, + 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, + 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, + 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, + 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, + 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, + 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, + 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, + 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, + 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, + 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, + 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, + 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, + 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, + 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, + 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u, +}; diff --git a/provekit/common/src/ntt/backends/cuda/kernels/field.cuh b/provekit/common/src/ntt/backends/cuda/kernels/field.cuh new file mode 100644 index 000000000..8ee1c3374 --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/kernels/field.cuh @@ -0,0 +1,142 @@ +// CUDA equivalent of metal/kernels/field.metal — same Montgomery arithmetic. + +__device__ __forceinline__ Fe make_element(ulong a0, ulong a1, ulong a2, ulong a3) { + Fe value; + value.limbs[0] = a0; + value.limbs[1] = a1; + value.limbs[2] = a2; + value.limbs[3] = a3; + return value; +} + +__device__ __forceinline__ bool ge_modulus(Fe value) { + if (value.limbs[3] != BN254_MODULUS[3]) { + return value.limbs[3] > BN254_MODULUS[3]; + } + if (value.limbs[2] != BN254_MODULUS[2]) { + return value.limbs[2] > BN254_MODULUS[2]; + } + if (value.limbs[1] != BN254_MODULUS[1]) { + return value.limbs[1] > BN254_MODULUS[1]; + } + return value.limbs[0] >= BN254_MODULUS[0]; +} + +__device__ __forceinline__ ulong add_with_carry(ulong a, ulong b, ulong &carry) { + ulong sum = a + b; + ulong c1 = sum < a ? 1ull : 0ull; + ulong sum_with_carry = sum + carry; + ulong c2 = sum_with_carry < sum ? 1ull : 0ull; + carry = c1 + c2; + return sum_with_carry; +} + +__device__ __forceinline__ ulong sub_with_borrow(ulong a, ulong b, ulong &borrow) { + ulong diff = a - b; + ulong b1 = diff > a ? 1ull : 0ull; + ulong diff_with_borrow = diff - borrow; + ulong b2 = diff_with_borrow > diff ? 
1ull : 0ull; + borrow = b1 | b2; + return diff_with_borrow; +} + +__device__ __forceinline__ Fe sub_modulus(Fe value) { + ulong borrow = 0; + value.limbs[0] = sub_with_borrow(value.limbs[0], BN254_MODULUS[0], borrow); + value.limbs[1] = sub_with_borrow(value.limbs[1], BN254_MODULUS[1], borrow); + value.limbs[2] = sub_with_borrow(value.limbs[2], BN254_MODULUS[2], borrow); + value.limbs[3] = sub_with_borrow(value.limbs[3], BN254_MODULUS[3], borrow); + return value; +} + +__device__ __forceinline__ Fe add_modulus(Fe value) { + ulong carry = 0; + value.limbs[0] = add_with_carry(value.limbs[0], BN254_MODULUS[0], carry); + value.limbs[1] = add_with_carry(value.limbs[1], BN254_MODULUS[1], carry); + value.limbs[2] = add_with_carry(value.limbs[2], BN254_MODULUS[2], carry); + value.limbs[3] = add_with_carry(value.limbs[3], BN254_MODULUS[3], carry); + return value; +} + +__device__ __forceinline__ Fe add_mod(Fe lhs, Fe rhs) { + ulong carry = 0; + Fe result; + result.limbs[0] = add_with_carry(lhs.limbs[0], rhs.limbs[0], carry); + result.limbs[1] = add_with_carry(lhs.limbs[1], rhs.limbs[1], carry); + result.limbs[2] = add_with_carry(lhs.limbs[2], rhs.limbs[2], carry); + result.limbs[3] = add_with_carry(lhs.limbs[3], rhs.limbs[3], carry); + if (carry != 0 || ge_modulus(result)) { + result = sub_modulus(result); + } + return result; +} + +__device__ __forceinline__ Fe sub_mod(Fe lhs, Fe rhs) { + ulong borrow = 0; + Fe result; + result.limbs[0] = sub_with_borrow(lhs.limbs[0], rhs.limbs[0], borrow); + result.limbs[1] = sub_with_borrow(lhs.limbs[1], rhs.limbs[1], borrow); + result.limbs[2] = sub_with_borrow(lhs.limbs[2], rhs.limbs[2], borrow); + result.limbs[3] = sub_with_borrow(lhs.limbs[3], rhs.limbs[3], borrow); + if (borrow != 0) { + result = add_modulus(result); + } + return result; +} + +__device__ __forceinline__ void add_scaled_step(ulong &dst, ulong s, ulong a, ulong &carry) { + ulong product_lo = s * a; + ulong product_hi = __umul64hi(s, a); + ulong sum = dst + product_lo; + ulong carry0 = sum < dst ? 1ull : 0ull; + ulong sum_with_carry = sum + carry; + ulong carry1 = sum_with_carry < sum ? 
1ull : 0ull; + dst = sum_with_carry; + carry = product_hi + carry0 + carry1; +} + +__device__ __forceinline__ void add_scaled(ulong *dst, ulong s, ulong a0, ulong a1, ulong a2, ulong a3) { + ulong carry = 0; + add_scaled_step(dst[0], s, a0, carry); + add_scaled_step(dst[1], s, a1, carry); + add_scaled_step(dst[2], s, a2, carry); + add_scaled_step(dst[3], s, a3, carry); + dst[4] += carry; +} + +__device__ __forceinline__ Fe mont_mul(Fe lhs, Fe rhs) { + ulong buf[9] = {0, 0, 0, 0, 0, 0, 0, 0, 0}; + uint off = 0; + +#pragma unroll + for (uint i = 0; i < 4; i++) { + add_scaled(&buf[off], lhs.limbs[i], + rhs.limbs[0], rhs.limbs[1], rhs.limbs[2], rhs.limbs[3]); + ulong m = buf[off] * BN254_N0PRIME; + add_scaled(&buf[off], m, + BN254_MODULUS[0], BN254_MODULUS[1], BN254_MODULUS[2], BN254_MODULUS[3]); + off += 1; + buf[off + 4] = 0; + } + + Fe result = make_element(buf[off], buf[off + 1], buf[off + 2], buf[off + 3]); + if (ge_modulus(result)) { + result = sub_modulus(result); + } + return result; +} + +__device__ __forceinline__ Fe canonicalize(Fe value) { + if (ge_modulus(value)) { + return sub_modulus(value); + } + return value; +} + +__device__ __forceinline__ Fe from_mont(Fe value) { + return canonicalize(mont_mul(value, FE_ONE)); +} + +__device__ __forceinline__ uint reverse_bits_width(uint value, uint width) { + return __brev(value) >> (32u - width); +} diff --git a/provekit/common/src/ntt/backends/cuda/kernels/matrix.cu b/provekit/common/src/ntt/backends/cuda/kernels/matrix.cu new file mode 100644 index 000000000..f521e0fff --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/kernels/matrix.cu @@ -0,0 +1,17 @@ +// CUDA equivalent of metal/kernels/matrix.metal — same transpose. + +extern "C" __global__ void transpose_matrix( + const Fe *input, + Fe *output, + TransposeParams params +) { + uint gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= params.total_elements) { + return; + } + + uint row = gid / params.cols; + uint col = gid - row * params.cols; + uint dst = col * params.rows + row; + output[dst] = input[gid]; +} diff --git a/provekit/common/src/ntt/backends/cuda/kernels/ntt.cu b/provekit/common/src/ntt/backends/cuda/kernels/ntt.cu new file mode 100644 index 000000000..625d60b55 --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/kernels/ntt.cu @@ -0,0 +1,72 @@ +// CUDA equivalent of metal/kernels/ntt.metal — same NTT butterfly + coset replicate. 
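+//
+// Butterfly indexing sketch: at the stage with half-width half_m = 2^s, each
+// thread owns one (base, base + half_m) pair where
+//   base = row_base + group * (2 * half_m) + pair_in_group,
+// and multiplies the odd element by twiddles[twiddle_offset + pair_in_group].
+// The host writes 2^s twiddles per stage, so stage s starts at offset 2^s - 1.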
+ +extern "C" __global__ void bit_reverse_permute_rows_in_place( + Fe *values, + BitReverseParams config +) { + uint index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= config.total_elements || config.row_len <= 1u) { + return; + } + + uint row = index / config.row_len; + uint within = index - row * config.row_len; + uint reversed = reverse_bits_width(within, config.log_n); + if (reversed <= within) { + return; + } + + uint row_base = row * config.row_len; + uint mate = row_base + reversed; + uint current = row_base + within; + Fe tmp = values[current]; + values[current] = values[mate]; + values[mate] = tmp; +} + +extern "C" __global__ void radix2_ntt_stage_rows_in_place( + Fe *values, + const Fe *twiddles, + StageConfig config, + uint total_butterflies +) { + uint index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= total_butterflies) { + return; + } + + uint butterflies_per_row = config.row_len >> 1u; + uint row = index / butterflies_per_row; + uint local = index - row * butterflies_per_row; + uint half_m = config.half_m; + uint pair_in_group = local % half_m; + uint group = local / half_m; + uint row_base = row * config.row_len; + uint base = row_base + group * (half_m << 1u) + pair_in_group; + uint mate = base + half_m; + + Fe even = values[base]; + Fe odd = values[mate]; + Fe twiddle = twiddles[config.twiddle_offset + pair_in_group]; + Fe t = mont_mul(twiddle, odd); + + values[base] = add_mod(even, t); + values[mate] = sub_mod(even, t); +} + +extern "C" __global__ void replicate_first_coset( + Fe *buffer, + ReplicateCosetsParams params +) { + uint gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= params.trailing_elements) { + return; + } + + uint repeats_per_row = params.row_len - params.coset_size; + uint row = gid / repeats_per_row; + uint within = gid - row * repeats_per_row; + uint dst = row * params.row_len + params.coset_size + within; + uint src = row * params.row_len + (within % params.coset_size); + buffer[dst] = buffer[src]; +} diff --git a/provekit/common/src/ntt/backends/cuda/kernels/sha256.cu b/provekit/common/src/ntt/backends/cuda/kernels/sha256.cu new file mode 100644 index 000000000..8655da5c6 --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/kernels/sha256.cu @@ -0,0 +1,187 @@ +// CUDA equivalent of metal/kernels/sha256.metal. +// +// Two kernels, exactly like the Metal port: +// - encode_field_rows_le: writes canonical little-endian bytes for each field, +// bit-reversing the row index so the byte matrix is in natural codeword order. +// - sha256_many: hashes equal-sized byte messages. 
+ +struct FieldBytesParams { + uint rows; + uint cols; +}; + +__device__ __forceinline__ uint rotr32(uint x, uint n) { + return (x >> n) | (x << (32u - n)); +} + +__device__ __forceinline__ uint ch(uint x, uint y, uint z) { + return (x & y) ^ ((~x) & z); +} + +__device__ __forceinline__ uint maj(uint x, uint y, uint z) { + return (x & y) ^ (x & z) ^ (y & z); +} + +__device__ __forceinline__ uint big_sigma0(uint x) { + return rotr32(x, 2u) ^ rotr32(x, 13u) ^ rotr32(x, 22u); +} + +__device__ __forceinline__ uint big_sigma1(uint x) { + return rotr32(x, 6u) ^ rotr32(x, 11u) ^ rotr32(x, 25u); +} + +__device__ __forceinline__ uint small_sigma0(uint x) { + return rotr32(x, 7u) ^ rotr32(x, 18u) ^ (x >> 3u); +} + +__device__ __forceinline__ uint small_sigma1(uint x) { + return rotr32(x, 17u) ^ rotr32(x, 19u) ^ (x >> 10u); +} + +__device__ __forceinline__ void sha256_init(uint state[8]) { + state[0] = 0x6a09e667u; + state[1] = 0xbb67ae85u; + state[2] = 0x3c6ef372u; + state[3] = 0xa54ff53au; + state[4] = 0x510e527fu; + state[5] = 0x9b05688cu; + state[6] = 0x1f83d9abu; + state[7] = 0x5be0cd19u; +} + +__device__ __forceinline__ uchar sha256_padding_byte( + uint idx, uint size, uint total_padded_len, uint bit_len +) { + if (idx == size) { + return 0x80u; + } + if (idx >= total_padded_len - 8u) { + uint shift = (total_padded_len - 1u - idx) * 8u; + return shift >= 32u ? 0u : (uchar)((bit_len >> shift) & 0xffu); + } + return 0u; +} + +__device__ __forceinline__ uint sha256_load_byte_word( + const uchar *input, + uint offset, + uint block_base, + uint word_index, + uint size, + uint total_padded_len, + uint bit_len +) { + uint word = 0u; +#pragma unroll + for (uint j = 0; j < 4u; ++j) { + uint idx = block_base + word_index * 4u + j; + uchar byte = idx < size + ? input[offset + idx] + : sha256_padding_byte(idx, size, total_padded_len, bit_len); + word = (word << 8) | (uint)byte; + } + return word; +} + +__device__ __forceinline__ void sha256_extend_schedule(uint w[64]) { + for (uint i = 16u; i < 64u; ++i) { + w[i] = small_sigma1(w[i - 2u]) + w[i - 7u] + small_sigma0(w[i - 15u]) + w[i - 16u]; + } +} + +__device__ __forceinline__ void sha256_compress(uint state[8], const uint w[64]) { + uint a = state[0], b = state[1], c = state[2], d = state[3]; + uint e = state[4], f = state[5], g = state[6], h = state[7]; + + for (uint i = 0u; i < 64u; ++i) { + uint t1 = h + big_sigma1(e) + ch(e, f, g) + SHA256_K[i] + w[i]; + uint t2 = big_sigma0(a) + maj(a, b, c); + h = g; g = f; f = e; + e = d + t1; + d = c; c = b; b = a; + a = t1 + t2; + } + + state[0] += a; state[1] += b; state[2] += c; state[3] += d; + state[4] += e; state[5] += f; state[6] += g; state[7] += h; +} + +__device__ __forceinline__ void sha256_write_digest(uchar *out, const uint state[8]) { +#pragma unroll + for (uint i = 0u; i < 8u; ++i) { + out[i * 4u + 0u] = (uchar)((state[i] >> 24) & 0xffu); + out[i * 4u + 1u] = (uchar)((state[i] >> 16) & 0xffu); + out[i * 4u + 2u] = (uchar)((state[i] >> 8) & 0xffu); + out[i * 4u + 3u] = (uchar)( state[i] & 0xffu); + } +} + +// Convert each `Fe` from Montgomery form to canonical little-endian bytes. +// +// Reads `params.rows × params.cols` fields from `input` and writes +// `params.rows × params.cols × 32` bytes to `output`. The output rows are +// emitted in natural codeword order: the matrix in `input` is laid out in +// bit-reversed row order, so we apply `reverse_bits_width(row, log2(rows))` +// when reading from `input`. 
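+// For example, with params.rows = 8, output row 3 (0b011) reads from input
+// row 6 (0b110).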
+extern "C" __global__ void encode_field_rows_le( + const Fe *input, + uchar *output, + FieldBytesParams params +) { + uint total_elements = params.rows * params.cols; + uint gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= total_elements) { + return; + } + + uint row = gid / params.cols; + uint col = gid - row * params.cols; + uint row_bits = 31u - __clz(params.rows); + uint src_row = (row_bits == 0u) ? row : (__brev(row) >> (32u - row_bits)); + Fe canonical = from_mont(input[src_row * params.cols + col]); + uint byte_offset = gid * 32u; +#pragma unroll + for (uint limb = 0u; limb < 4u; ++limb) { + ulong value = canonical.limbs[limb]; +#pragma unroll + for (uint byte = 0u; byte < 8u; ++byte) { + output[byte_offset + limb * 8u + byte] = + (uchar)((value >> (byte * 8u)) & 0xffull); + } + } +} + +// Hash `params.count` equal-size messages of `params.size` bytes each. +extern "C" __global__ void sha256_many( + const uchar *input, + uchar *output, + HashManyParams params +) { + uint gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= params.count) { + return; + } + + uint offset = gid * params.size; + uint total_blocks = (params.size + 9u + 63u) / 64u; + uint total_padded_len = total_blocks * 64u; + uint bit_len = params.size * 8u; + uint state[8]; + sha256_init(state); + + for (uint block = 0u; block < total_blocks; ++block) { + uint block_base = block * 64u; + uint w[64]; +#pragma unroll + for (uint i = 0u; i < 16u; ++i) { + w[i] = sha256_load_byte_word( + input, offset, block_base, i, + params.size, total_padded_len, bit_len + ); + } + sha256_extend_schedule(w); + sha256_compress(state, w); + } + + sha256_write_digest(output + gid * 32u, state); +} diff --git a/provekit/common/src/ntt/backends/cuda/logging.rs b/provekit/common/src/ntt/backends/cuda/logging.rs new file mode 100644 index 000000000..fcdd35e72 --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/logging.rs @@ -0,0 +1,7 @@ +use std::env; + +pub fn trace_event(args: std::fmt::Arguments<'_>) { + if env::var_os("PROVEKIT_CUDA_NTT_TRACE").is_some() { + eprintln!("[provekit-cuda-ntt] {args}"); + } +} diff --git a/provekit/common/src/ntt/backends/cuda/mod.rs b/provekit/common/src/ntt/backends/cuda/mod.rs new file mode 100644 index 000000000..b62722ecf --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/mod.rs @@ -0,0 +1,202 @@ +#[cfg(target_os = "linux")] +mod commit; +#[cfg(target_os = "linux")] +mod encode; +#[cfg(target_os = "linux")] +mod engine; +#[cfg(target_os = "linux")] +mod field; +mod logging; +#[cfg(target_os = "linux")] +mod types; + +#[cfg(target_os = "linux")] +use { + self::engine::CudaRuntime, + std::{ + env, + sync::{Arc, OnceLock}, + }, + tracing::info, + whir::{hash::SHA2, protocols::matrix_commit::Config as MatrixCommitConfig}, +}; +use { + self::logging::trace_event, + crate::ntt::backends::RSFr, + ark_bn254::Fr, + ark_ff::{FftField, Field}, + tracing::instrument, + whir::algebra::ntt::ReedSolomon, +}; + +/// CUDA-accelerated Reed–Solomon committer for BN254. +/// +/// Mirrors the structure of the Metal backend in `backends/metal/`: +/// - implements `IrsCommitter` (in `commit.rs`), so the encoded matrix +/// and the Merkle tree can stay on the GPU between `commit` and `open`, +/// - implements `ReedSolomon` for code paths that don't go through +/// the IRS committer (`gpu_encode` / CPU fallback). 
+#[derive(Clone, Copy, Debug, Default)]
+pub struct CudaBn254Ntt;
+
+#[cfg(target_os = "linux")]
+static RUNTIME: OnceLock<Result<Arc<CudaRuntime>, String>> = OnceLock::new();
+
+impl CudaBn254Ntt {
+    /// Minimum problem sizes at which the GPU path is used. Below these
+    /// thresholds the CPU NTT (parallelised, SIMD-heavy) wins once per-call
+    /// GPU overheads are subtracted.
+    const MIN_GPU_TOTAL_ELEMENTS: usize = 1 << 18;
+    const MIN_GPU_ROW_COUNT: usize = 64;
+
+    #[cfg(target_os = "linux")]
+    pub fn new() -> Result<Self, String> {
+        if env::var_os("PROVEKIT_DISABLE_CUDA_NTT").is_some() {
+            return Err("CUDA NTT disabled via PROVEKIT_DISABLE_CUDA_NTT".into());
+        }
+
+        match RUNTIME.get_or_init(|| CudaRuntime::new().map(Arc::new)) {
+            Ok(runtime) => {
+                let (cc_major, cc_minor) = runtime.compute_capability;
+                info!(
+                    device = %runtime.device_name,
+                    compute_capability = format!("{cc_major}.{cc_minor}"),
+                    "initialized CUDA BN254 NTT backend"
+                );
+                trace_event(format_args!(
+                    "init device={} compute_capability={cc_major}.{cc_minor}",
+                    runtime.device_name,
+                ));
+                Ok(Self)
+            }
+            Err(err) => Err(err.clone()),
+        }
+    }
+
+    #[cfg(not(target_os = "linux"))]
+    pub fn new() -> Result<Self, String> {
+        Err("CUDA BN254 NTT is only available on Linux".into())
+    }
+
+    #[cfg(target_os = "linux")]
+    pub(super) fn runtime(&self) -> Result<&Arc<CudaRuntime>, String> {
+        match RUNTIME.get() {
+            Some(Ok(runtime)) => Ok(runtime),
+            Some(Err(err)) => Err(err.clone()),
+            None => Err("CUDA runtime not initialized".into()),
+        }
+    }
+
+    fn supports_gpu_shape(codeword_length: usize, row_coeffs: &[&[Fr]]) -> bool {
+        if row_coeffs.is_empty() {
+            return false;
+        }
+        if codeword_length <= 1 || !codeword_length.is_power_of_two() {
+            return false;
+        }
+        let total_elements = row_coeffs.len().saturating_mul(codeword_length);
+        total_elements >= Self::MIN_GPU_TOTAL_ELEMENTS
+            || row_coeffs.len() >= Self::MIN_GPU_ROW_COUNT
+    }
+
+    /// Only the SHA-2 hash family is implemented on GPU; for any other
+    /// configuration we fall back to the CPU committer.
+    #[cfg(target_os = "linux")]
+    fn supports_gpu_commit(matrix_commit: &MatrixCommitConfig) -> bool {
+        matrix_commit.leaf_hash_id == SHA2
+            && matrix_commit
+                .merkle_tree
+                .layers
+                .iter()
+                .all(|layer| layer.hash_id == SHA2)
+    }
+}
+
+impl ReedSolomon for CudaBn254Ntt {
+    fn next_order(&self, size: usize) -> Option<usize> {
+        let order = size.next_power_of_two();
+        if order <= 1 << 28 { Some(order) } else { None }
+    }
+
+    fn evaluation_points(
+        &self,
+        masked_message_length: usize,
+        codeword_length: usize,
+        indices: &[usize],
+    ) -> Vec<Fr> {
+        let _ = masked_message_length;
+        let generator = self.generator(codeword_length);
+        indices
+            .iter()
+            .map(|i| {
+                let bits = usize::BITS - (codeword_length - 1).leading_zeros();
+                let k = if bits == 0 {
+                    *i
+                } else {
+                    i.reverse_bits() >> (usize::BITS - bits)
+                };
+                generator.pow([k as u64])
+            })
+            .collect()
+    }
+
+    fn generator(&self, codeword_length: usize) -> Fr {
+        Fr::get_root_of_unity(codeword_length as u64).unwrap()
+    }
+
+    /// Try GPU encode first; fall back to CPU on shape mismatch or any GPU
+    /// error. (The IRS committer path in `commit.rs` is what actually keeps
+    /// the matrix on device for the WHIR open phase.)
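+    /// For example, 64 rows of a 2^12-point codeword (2^18 elements in total)
+    /// already clears both GPU thresholds, while a single short row stays on
+    /// `RSFr`.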
+ #[instrument(skip(self, messages, masks), fields( + num_messages = messages.len(), + message_len = messages.first().map(|c| c.len()), + codeword_length = codeword_length, + mask_len = masks.len().checked_div(messages.len()) + ))] + fn interleaved_encode( + &self, + messages: &[&[Fr]], + masks: &[Fr], + codeword_length: usize, + ) -> Vec { + if messages.is_empty() { + return vec![]; + } + + let num_messages = messages.len(); + let message_length = messages[0].len(); + for message in messages { + assert_eq!(message_length, message.len()); + } + assert!(masks.len().is_multiple_of(num_messages)); + + if !Self::supports_gpu_shape(codeword_length, messages) { + trace_event(format_args!( + "encode fallback path=cpu codeword_length={} rows={} reason=unsupported-shape", + codeword_length, num_messages, + )); + return RSFr.interleaved_encode(messages, masks, codeword_length); + } + + #[cfg(target_os = "linux")] + { + match self.gpu_encode(messages, masks, codeword_length) { + Ok(codeword) => return codeword, + Err(err) => { + trace_event(format_args!( + "encode fallback path=cpu codeword_length={} rows={} reason=gpu-error \ + error={}", + codeword_length, num_messages, err, + )); + } + } + } + + #[cfg(not(target_os = "linux"))] + trace_event(format_args!( + "encode fallback path=cpu codeword_length={} rows={} reason=unsupported-platform", + codeword_length, num_messages, + )); + + RSFr.interleaved_encode(messages, masks, codeword_length) + } +} diff --git a/provekit/common/src/ntt/backends/cuda/types.rs b/provekit/common/src/ntt/backends/cuda/types.rs new file mode 100644 index 000000000..8eb2e034a --- /dev/null +++ b/provekit/common/src/ntt/backends/cuda/types.rs @@ -0,0 +1,103 @@ +use {super::engine::PooledBuffer, whir::hash::Hash}; + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default)] +pub struct GpuField { + pub limbs: [u64; 4], +} + +// SAFETY: GpuField is repr(C) and contains only POD limbs. +unsafe impl cudarc::driver::DeviceRepr for GpuField {} +// SAFETY: GpuField has no padding; an all-zero bit pattern represents the +// field element 0 (same in standard and Montgomery form). +unsafe impl cudarc::driver::ValidAsZeroBits for GpuField {} + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default)] +pub struct BitReverseParams { + pub row_len: u32, + pub log_n: u32, + pub total_elements: u32, + pub _pad0: u32, +} +// SAFETY: POD repr(C) struct. +unsafe impl cudarc::driver::DeviceRepr for BitReverseParams {} + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default)] +pub struct NttStageParams { + pub row_len: u32, + pub half_m: u32, + pub twiddle_offset: u32, + pub _pad0: u32, +} +// SAFETY: POD repr(C) struct. +unsafe impl cudarc::driver::DeviceRepr for NttStageParams {} + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default)] +pub struct TransposeParams { + pub rows: u32, + pub cols: u32, + pub total_elements: u32, +} +// SAFETY: POD repr(C) struct. +unsafe impl cudarc::driver::DeviceRepr for TransposeParams {} + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default)] +pub struct EncodeFieldBytesParams { + pub rows: u32, + pub cols: u32, +} +// SAFETY: POD repr(C) struct. +unsafe impl cudarc::driver::DeviceRepr for EncodeFieldBytesParams {} + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default)] +pub struct HashManyParams { + pub size: u32, + pub count: u32, +} +// SAFETY: POD repr(C) struct. 
+unsafe impl cudarc::driver::DeviceRepr for HashManyParams {} + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default)] +pub struct ReplicateCosetsParams { + pub row_len: u32, + pub coset_size: u32, + pub trailing_elements: u32, +} +// SAFETY: POD repr(C) struct. +unsafe impl cudarc::driver::DeviceRepr for ReplicateCosetsParams {} + +pub struct DeviceMatrix { + pub rows: usize, + pub cols: usize, + pub buffer: PooledBuffer, +} + +#[derive(Clone)] +pub struct DeviceRows { + pub rows: usize, + pub cols: usize, + pub buffer: PooledBuffer, +} + +pub struct DeviceMerkleWitness { + pub num_nodes: usize, + pub root: Hash, + pub buffer: PooledBuffer, +} + +#[derive(Clone, Copy, Debug)] +pub struct EncodeShape { + pub row_count: usize, + pub codeword_length: usize, + pub coset_size: usize, + pub message_length: usize, + pub mask_length: usize, + pub num_cosets: usize, + pub total_elements: usize, +} diff --git a/provekit/common/src/ntt/backends/metal/commit.rs b/provekit/common/src/ntt/backends/metal/commit.rs new file mode 100644 index 000000000..1a4acbc85 --- /dev/null +++ b/provekit/common/src/ntt/backends/metal/commit.rs @@ -0,0 +1,319 @@ +use { + super::{ + engine::PooledBuffer, + field::gpu_to_fr, + types::{ + DeviceMatrix, DeviceMerkleWitness, DeviceRows, EncodeFieldBytesParams, GpuField, + HashManyParams, + }, + MetalBn254Ntt, + }, + ark_bn254::Fr, + metal::{MTLSize, NSRange, NSUInteger}, + std::{ffi::c_void, mem::size_of, sync::Arc}, + whir::{ + hash::Hash, + protocols::{ + irs_commit::{CpuIrsCommitter, IrsCommitArtifact, IrsCommitter, MatrixRows}, + matrix_commit::{Config as MatrixCommitConfig, Encodable}, + merkle_tree::WitnessTrait, + }, + }, +}; + +impl IrsCommitter for MetalBn254Ntt { + fn commit( + &self, + messages: &[&[Fr]], + masks: &[Fr], + codeword_length: usize, + matrix_commit: &MatrixCommitConfig, + ) -> IrsCommitArtifact { + let cpu_commit = || { + CpuIrsCommitter::new(Arc::new(crate::ntt::RSFr)).commit( + messages, + masks, + codeword_length, + matrix_commit, + ) + }; + + if !Self::supports_gpu_shape(codeword_length, messages) + || !Self::supports_gpu_commit(matrix_commit) + { + return cpu_commit(); + } + + let Ok(matrix) = self.encode_matrix(messages, masks, codeword_length) else { + return cpu_commit(); + }; + let Ok(leaf_hashes) = self.hash_rows_to_buffer(&matrix) else { + return cpu_commit(); + }; + let Ok(merkle_witness) = self.build_merkle_witness(matrix_commit, &leaf_hashes) else { + return cpu_commit(); + }; + + IrsCommitArtifact { + root: merkle_witness.root(), + rows: Arc::new(DeviceRows { + rows: matrix.rows, + cols: matrix.cols, + buffer: matrix.buffer, + }), + matrix_witness: merkle_witness, + } + } +} + +impl MetalBn254Ntt { + pub(super) fn hash_rows_to_buffer( + &self, + matrix: &DeviceMatrix, + ) -> Result { + if matrix.rows == 0 { + return Ok(self.runtime()?.pooled_buffer::(0)); + } + + let runtime = self.runtime()?; + let total_elements = matrix.rows * matrix.cols; + let total_bytes = total_elements * Fr::encoded_size(); + let message_size = matrix.cols * Fr::encoded_size(); + if total_elements > u32::MAX as usize || message_size > u32::MAX as usize { + return Err("GPU hash launch exceeds current 32-bit grid limit".into()); + } + + let encoded = runtime.pooled_bytes(total_bytes); + let hashes = runtime.pooled_buffer::(matrix.rows); + let encode_params = EncodeFieldBytesParams { + rows: matrix.rows as u32, + cols: matrix.cols as u32, + }; + let hash_params = HashManyParams { + size: message_size as u32, + count: matrix.rows as u32, + }; + let command_buffer = 
runtime.queue.new_command_buffer(); + + let encode_encoder = command_buffer.new_compute_command_encoder(); + encode_encoder.set_compute_pipeline_state(&runtime.encode_bytes_pipeline); + encode_encoder.set_buffer(0, Some(matrix.buffer.as_ref()), 0); + encode_encoder.set_buffer(1, Some(encoded.as_ref()), 0); + encode_encoder.set_bytes( + 2, + size_of::() as NSUInteger, + (&encode_params as *const EncodeFieldBytesParams).cast::(), + ); + let encode_threads = + runtime.threads_per_threadgroup(&runtime.encode_bytes_pipeline, total_elements); + encode_encoder.dispatch_threads( + MTLSize { + width: total_elements as u64, + height: 1, + depth: 1, + }, + encode_threads, + ); + encode_encoder.end_encoding(); + + let hash_encoder = command_buffer.new_compute_command_encoder(); + hash_encoder.set_compute_pipeline_state(&runtime.sha256_pipeline); + hash_encoder.set_buffer(0, Some(encoded.as_ref()), 0); + hash_encoder.set_buffer(1, Some(hashes.as_ref()), 0); + hash_encoder.set_bytes( + 2, + size_of::() as NSUInteger, + (&hash_params as *const HashManyParams).cast::(), + ); + let hash_threads = runtime.threads_per_threadgroup(&runtime.sha256_pipeline, matrix.rows); + hash_encoder.dispatch_threads( + MTLSize { + width: matrix.rows as u64, + height: 1, + depth: 1, + }, + hash_threads, + ); + hash_encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + Ok(hashes) + } + + pub(super) fn build_merkle_witness( + &self, + matrix_commit: &MatrixCommitConfig, + leaf_hashes: &PooledBuffer, + ) -> Result, String> { + let runtime = self.runtime()?; + let num_leaves = matrix_commit.num_rows(); + let leaf_capacity = 1usize << matrix_commit.merkle_tree.layers.len(); + let num_nodes = matrix_commit.merkle_tree.num_nodes(); + if leaf_capacity == 0 { + return Err("invalid empty Merkle leaf capacity".into()); + } + if num_nodes == 0 { + return Err("invalid empty Merkle tree".into()); + } + if num_leaves > leaf_capacity { + return Err("Merkle config has fewer layers than leaves require".into()); + } + if leaf_capacity > u32::MAX as usize { + return Err("GPU Merkle launch exceeds current 32-bit grid limit".into()); + } + + let tree = runtime.pooled_buffer::(num_nodes); + let command_buffer = runtime.queue.new_command_buffer(); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + tree.as_ref(), + NSRange::new(0, (num_nodes * size_of::()) as u64), + 0, + ); + if num_leaves != 0 { + blit.copy_from_buffer( + leaf_hashes.as_ref(), + 0, + tree.as_ref(), + 0, + (num_leaves * size_of::()) as u64, + ); + } + blit.end_encoding(); + + let mut previous_offset = 0usize; + let mut previous_len = leaf_capacity; + for _ in matrix_commit.merkle_tree.layers.iter().rev() { + let current_len = previous_len / 2; + if current_len == 0 { + break; + } + if current_len > u32::MAX as usize { + return Err("GPU Merkle launch exceeds current 32-bit grid limit".into()); + } + + let params = HashManyParams { + size: 64, + count: current_len as u32, + }; + let current_offset = previous_offset + previous_len; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&runtime.sha256_pipeline); + encoder.set_buffer( + 0, + Some(tree.as_ref()), + (previous_offset * size_of::()) as u64, + ); + encoder.set_buffer( + 1, + Some(tree.as_ref()), + (current_offset * size_of::()) as u64, + ); + encoder.set_bytes( + 2, + size_of::() as NSUInteger, + (¶ms as *const HashManyParams).cast::(), + ); + let threads = 
runtime.threads_per_threadgroup(&runtime.sha256_pipeline, current_len); + encoder.dispatch_threads( + MTLSize { + width: current_len as u64, + height: 1, + depth: 1, + }, + threads, + ); + encoder.end_encoding(); + previous_offset = current_offset; + previous_len = current_len; + } + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let root = runtime.buffer_slice::(tree.as_ref(), num_nodes)[num_nodes - 1]; + + Ok(Arc::new(DeviceMerkleWitness { + num_nodes, + root, + buffer: tree, + })) + } +} + +impl WitnessTrait for DeviceMerkleWitness { + fn num_nodes(&self) -> usize { + self.num_nodes + } + + fn read_nodes(&self, indices: &[usize]) -> Vec { + let nodes = unsafe { + std::slice::from_raw_parts( + self.buffer.as_ref().contents().cast::(), + self.num_nodes, + ) + }; + let mut out = Vec::with_capacity(indices.len()); + for &index in indices { + assert!(index < self.num_nodes, "Merkle node index out of bounds"); + out.push(nodes[index]); + } + out + } +} + +impl std::fmt::Debug for DeviceRows { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DeviceRows") + .field("rows", &self.rows) + .field("cols", &self.cols) + .finish() + } +} + +impl MatrixRows for DeviceRows { + fn num_rows(&self) -> usize { + self.rows + } + + fn num_cols(&self) -> usize { + self.cols + } + + fn read_rows(&self, indices: &[usize]) -> Vec { + let mut out = Vec::with_capacity(indices.len() * self.cols); + let fields = unsafe { + std::slice::from_raw_parts( + self.buffer.as_ref().contents().cast::(), + self.rows * self.cols, + ) + }; + for &row in indices { + assert!(row < self.rows, "row index out of bounds"); + let start = reverse_bit_index(row, self.rows) * self.cols; + let end = start + self.cols; + out.extend(fields[start..end].iter().copied().map(gpu_to_fr)); + } + out + } +} + +impl std::fmt::Debug for DeviceMerkleWitness { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DeviceMerkleWitness") + .field("num_nodes", &self.num_nodes) + .field("root", &self.root) + .finish() + } +} + +fn reverse_bit_index(index: usize, codeword_length: usize) -> usize { + let bits = usize::BITS - (codeword_length - 1).leading_zeros(); + if bits == 0 { + index + } else { + index.reverse_bits() >> (usize::BITS - bits) + } +} diff --git a/provekit/common/src/ntt/backends/metal/encode.rs b/provekit/common/src/ntt/backends/metal/encode.rs new file mode 100644 index 000000000..5a67d1900 --- /dev/null +++ b/provekit/common/src/ntt/backends/metal/encode.rs @@ -0,0 +1,321 @@ +use { + super::{ + field::{fr_to_gpu, gpu_to_fr}, + logging::trace_event, + types::{ + BitReverseParams, DeviceMatrix, EncodeShape, GpuField, NttStageParams, + ReplicateCosetsParams, + TransposeParams, + }, + MetalBn254Ntt, + }, + ark_bn254::Fr, + ark_ff::AdditiveGroup, + metal::{MTLSize, NSUInteger}, + rayon::prelude::*, + std::{ffi::c_void, mem::size_of}, + tracing::instrument, + whir::algebra::ntt::ReedSolomon, +}; + +impl MetalBn254Ntt { + #[instrument(skip(self, messages, masks), fields( + num_messages = messages.len(), + message_len = messages.first().map(|c| c.len()), + codeword_length = codeword_length, + mask_len = masks.len().checked_div(messages.len()) + ))] + pub fn gpu_encode( + &self, + messages: &[&[Fr]], + masks: &[Fr], + codeword_length: usize, + ) -> Result, String> { + let matrix = self.encode_matrix(messages, masks, codeword_length)?; + let fields = self + .runtime()? 
+ .buffer_slice::(matrix.buffer.as_ref(), matrix.rows * matrix.cols); + if matrix.rows == 0 || matrix.cols == 0 { + return Ok(Vec::new()); + } + + let mut output = vec![Fr::ZERO; matrix.rows * matrix.cols]; + output + .par_chunks_mut(matrix.cols) + .enumerate() + .for_each(|(dst_row, dst)| { + let natural_row = reverse_bit_index(dst_row, matrix.rows); + let src_start = natural_row * matrix.cols; + let src = &fields[src_start..src_start + matrix.cols]; + dst.iter_mut() + .zip(src.iter().copied()) + .for_each(|(dst, src)| *dst = gpu_to_fr(src)); + }); + Ok(output) + } + + #[instrument(skip(self, messages, masks), fields( + num_messages = messages.len(), + message_len = messages.first().map(|c| c.len()), + codeword_length = codeword_length, + mask_len = masks.len().checked_div(messages.len()) + ))] + pub fn encode_matrix( + &self, + messages: &[&[Fr]], + masks: &[Fr], + codeword_length: usize, + ) -> Result { + let runtime = self.runtime()?; + let shape = Self::encode_shape(messages, masks, codeword_length)?; + if shape.total_elements == 0 { + return Ok(DeviceMatrix { + rows: 0, + cols: 0, + buffer: runtime.pooled_buffer::(0), + }); + } + + trace_event(format_args!( + "encode rows={} codeword_length={} num_cosets={} coset_size={} polynomials={} \ + path=coset", + shape.row_count, + codeword_length, + shape.num_cosets, + shape.coset_size, + messages.len(), + )); + + let current = runtime.pooled_buffer::(shape.total_elements); + runtime.zero_buffer::(current.as_ref(), shape.total_elements); + pack_messages_and_masks_into_buffer( + runtime.buffer_slice_mut(current.as_ref(), shape.total_elements), + messages, + masks, + shape, + ); + let roots = runtime.roots_buffer(codeword_length)?; + + let transposed = runtime.pooled_buffer::(shape.total_elements); + let stage_count = codeword_length.trailing_zeros() as usize; + let skipped_stage_count = shape.num_cosets.trailing_zeros() as usize; + let total_butterflies = shape.total_elements / 2; + let bit_reverse_threads = runtime + .threads_per_threadgroup(&runtime.bit_reverse_pipeline, shape.total_elements); + let stage_threads = + runtime.threads_per_threadgroup(&runtime.ntt_stage_pipeline, total_butterflies); + let transpose_threads = + runtime.threads_per_threadgroup(&runtime.transpose_pipeline, shape.total_elements); + let bit_reverse_params = BitReverseParams { + row_len: shape.codeword_length as u32, + log_n: stage_count as u32, + total_elements: shape.total_elements as u32, + _pad0: 0, + }; + let transpose_params = TransposeParams { + rows: shape.row_count as u32, + cols: shape.codeword_length as u32, + total_elements: shape.total_elements as u32, + }; + + let command_buffer = runtime.queue.new_command_buffer(); + let replicate_params = ReplicateCosetsParams { + row_len: shape.codeword_length as u32, + coset_size: shape.coset_size as u32, + trailing_elements: shape + .row_count + .saturating_mul(shape.codeword_length - shape.coset_size) + as u32, + }; + if replicate_params.trailing_elements != 0 { + let replicate_encoder = command_buffer.new_compute_command_encoder(); + replicate_encoder.set_compute_pipeline_state(&runtime.replicate_cosets_pipeline); + replicate_encoder.set_buffer(0, Some(current.as_ref()), 0); + replicate_encoder.set_bytes( + 1, + size_of::() as u64, + (&replicate_params as *const ReplicateCosetsParams).cast::(), + ); + let replicate_threads = runtime.threads_per_threadgroup( + &runtime.replicate_cosets_pipeline, + replicate_params.trailing_elements as usize, + ); + replicate_encoder.dispatch_threads( + MTLSize { + width: 
replicate_params.trailing_elements as u64, + height: 1, + depth: 1, + }, + replicate_threads, + ); + replicate_encoder.end_encoding(); + } + let bit_reverse_encoder = command_buffer.new_compute_command_encoder(); + bit_reverse_encoder.set_compute_pipeline_state(&runtime.bit_reverse_pipeline); + bit_reverse_encoder.set_buffer(0, Some(current.as_ref()), 0); + bit_reverse_encoder.set_bytes( + 1, + size_of::() as NSUInteger, + (&bit_reverse_params as *const BitReverseParams).cast::(), + ); + bit_reverse_encoder.dispatch_threads( + MTLSize { + width: shape.total_elements as u64, + height: 1, + depth: 1, + }, + bit_reverse_threads, + ); + bit_reverse_encoder.end_encoding(); + + let stage_encoder = command_buffer.new_compute_command_encoder(); + stage_encoder.set_compute_pipeline_state(&runtime.ntt_stage_pipeline); + + let mut twiddle_offset = (1usize << skipped_stage_count).saturating_sub(1); + for stage in skipped_stage_count..stage_count { + let half_m = 1usize << stage; + let params = NttStageParams { + row_len: shape.codeword_length as u32, + half_m: half_m as u32, + twiddle_offset: twiddle_offset as u32, + _pad0: 0, + }; + stage_encoder.set_buffer(0, Some(current.as_ref()), 0); + stage_encoder.set_buffer(1, Some(roots.as_ref()), 0); + stage_encoder.set_bytes( + 2, + size_of::() as NSUInteger, + (¶ms as *const NttStageParams).cast::(), + ); + stage_encoder.dispatch_threads( + MTLSize { + width: total_butterflies as u64, + height: 1, + depth: 1, + }, + stage_threads, + ); + twiddle_offset += 1usize << stage; + } + stage_encoder.end_encoding(); + + let transpose_encoder = command_buffer.new_compute_command_encoder(); + transpose_encoder.set_compute_pipeline_state(&runtime.transpose_pipeline); + transpose_encoder.set_buffer(0, Some(current.as_ref()), 0); + transpose_encoder.set_buffer(1, Some(transposed.as_ref()), 0); + transpose_encoder.set_bytes( + 2, + size_of::() as NSUInteger, + (&transpose_params as *const TransposeParams).cast::(), + ); + transpose_encoder.dispatch_threads( + MTLSize { + width: shape.total_elements as u64, + height: 1, + depth: 1, + }, + transpose_threads, + ); + transpose_encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + Ok(DeviceMatrix { + rows: shape.codeword_length, + cols: shape.row_count, + buffer: transposed, + }) + } + + pub fn encode_shape( + messages: &[&[Fr]], + masks: &[Fr], + codeword_length: usize, + ) -> Result { + if messages.is_empty() { + return Ok(EncodeShape { + row_count: 0, + codeword_length, + coset_size: 0, + message_length: 0, + mask_length: 0, + num_cosets: 0, + total_elements: 0, + }); + } + if !Self::supports_gpu_shape(codeword_length, messages) { + return Err("problem shape unsupported for GPU path".into()); + } + + let row_count = messages.len(); + let message_length = messages[0].len(); + if messages.iter().any(|row| row.len() != message_length) { + return Err("all messages must have the same length".into()); + } + if !masks.len().is_multiple_of(row_count) { + return Err("mask count must be divisible by row count".into()); + } + let mask_length = masks.len() / row_count; + let masked_message_length = message_length + mask_length; + let mut coset_size = Self::default() + .next_order(masked_message_length) + .ok_or_else(|| "no supported coset size for encode".to_string())?; + while !codeword_length.is_multiple_of(coset_size) { + coset_size = Self::default() + .next_order(coset_size + 1) + .ok_or_else(|| "no supported coset size for encode".to_string())?; + } + let num_cosets = codeword_length / 
coset_size; + + let total_elements = row_count + .checked_mul(codeword_length) + .ok_or_else(|| "GPU encode launch exceeds current 32-bit grid limit".to_string())?; + if total_elements > u32::MAX as usize { + return Err("GPU encode launch exceeds current 32-bit grid limit".into()); + } + + Ok(EncodeShape { + row_count, + codeword_length, + coset_size, + message_length, + mask_length, + num_cosets, + total_elements, + }) + } + +} + +fn pack_messages_and_masks_into_buffer( + packed: &mut [GpuField], + messages: &[&[Fr]], + masks: &[Fr], + shape: EncodeShape, +) { + packed + .par_chunks_mut(shape.codeword_length) + .enumerate() + .for_each(|(row_index, row)| { + for (dst, &coeff) in row[..shape.message_length] + .iter_mut() + .zip(messages[row_index]) + { + *dst = fr_to_gpu(coeff); + } + for mask_column in 0..shape.mask_length { + row[shape.message_length + mask_column] = + fr_to_gpu(masks[mask_column * shape.row_count + row_index]); + } + }); +} + +fn reverse_bit_index(index: usize, codeword_length: usize) -> usize { + let bits = usize::BITS - (codeword_length - 1).leading_zeros(); + if bits == 0 { + index + } else { + index.reverse_bits() >> (usize::BITS - bits) + } +} diff --git a/provekit/common/src/ntt/backends/metal/engine.rs b/provekit/common/src/ntt/backends/metal/engine.rs new file mode 100644 index 000000000..08e69d7f6 --- /dev/null +++ b/provekit/common/src/ntt/backends/metal/engine.rs @@ -0,0 +1,268 @@ +use { + super::{field::fr_to_gpu, logging::trace_event}, + ark_bn254::Fr, + ark_ff::{FftField, Field}, + metal::{ + objc::rc::autoreleasepool, Buffer, CommandQueue, CompileOptions, ComputePipelineState, + Device, Library, MTLResourceOptions, MTLSize, NSUInteger, + }, + std::{ + collections::HashMap, + ffi::c_void, + mem::size_of, + ptr, + sync::{Arc, Mutex}, + }, +}; + +const SHADER_SOURCE: &str = concat!( + include_str!("kernels/common.metal"), + "\n", + include_str!("kernels/field.metal"), + "\n", + include_str!("kernels/ntt.metal"), + "\n", + include_str!("kernels/matrix.metal"), + "\n", + include_str!("kernels/sha256.metal"), + "\n", +); + +struct PooledBufferInner { + runtime: Arc, + bucket_bytes: usize, + buffer: Buffer, +} + +#[derive(Clone)] +pub struct PooledBuffer(Arc); + +impl PooledBuffer { + pub fn as_ref(&self) -> &Buffer { + &self.0.buffer + } +} + +impl std::fmt::Debug for PooledBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PooledBuffer") + .field("length", &self.0.buffer.length()) + .finish() + } +} + +impl std::ops::Deref for PooledBuffer { + type Target = Buffer; + + fn deref(&self) -> &Self::Target { + &self.0.buffer + } +} + +impl AsRef for PooledBuffer { + fn as_ref(&self) -> &Buffer { + &self.0.buffer + } +} + +impl Drop for PooledBufferInner { + fn drop(&mut self) { + self.runtime + .recycle_buffer(self.bucket_bytes, self.buffer.to_owned()); + } +} + +pub struct MetalRuntime { + pub device: Device, + pub queue: CommandQueue, + pub bit_reverse_pipeline: ComputePipelineState, + pub ntt_stage_pipeline: ComputePipelineState, + pub replicate_cosets_pipeline: ComputePipelineState, + pub transpose_pipeline: ComputePipelineState, + pub encode_bytes_pipeline: ComputePipelineState, + pub sha256_pipeline: ComputePipelineState, + roots_cache: Mutex>>, + buffer_pool: Mutex>>, +} + +impl MetalRuntime { + pub fn new() -> Result { + autoreleasepool(|| { + let device = Device::system_default() + .or_else(|| Device::all().into_iter().next()) + .ok_or_else(|| { + "no Metal device found; sandboxed macOS processes may not 
expose Metal" + .to_string() + })?; + let options = CompileOptions::new(); + let library = device.new_library_with_source(SHADER_SOURCE, &options)?; + + Ok(Self { + device: device.to_owned(), + queue: device.new_command_queue(), + bit_reverse_pipeline: Self::new_pipeline( + &device, + &library, + "bit_reverse_permute_rows_in_place", + )?, + ntt_stage_pipeline: Self::new_pipeline( + &device, + &library, + "radix2_ntt_stage_rows_in_place", + )?, + replicate_cosets_pipeline: Self::new_pipeline( + &device, + &library, + "replicate_first_coset", + )?, + transpose_pipeline: Self::new_pipeline( + &device, + &library, + "transpose_matrix", + )?, + encode_bytes_pipeline: Self::new_pipeline( + &device, + &library, + "encode_field_rows_le", + )?, + sha256_pipeline: Self::new_pipeline(&device, &library, "sha256_many")?, + roots_cache: Mutex::new(HashMap::new()), + buffer_pool: Mutex::new(HashMap::new()), + }) + }) + } + + pub fn buffer_with_data(&self, values: &[T]) -> Buffer { + self.device.new_buffer_with_data( + values.as_ptr().cast::(), + std::mem::size_of_val(values) as NSUInteger, + MTLResourceOptions::StorageModeShared, + ) + } + + pub fn pooled_buffer(self: &Arc, len: usize) -> PooledBuffer { + self.pooled_bytes(len * size_of::()) + } + + pub fn pooled_bytes(self: &Arc, len: usize) -> PooledBuffer { + let bucket_bytes = bucket_bytes(len); + let buffer = self.take_buffer(bucket_bytes); + PooledBuffer(Arc::new(PooledBufferInner { + runtime: Arc::clone(self), + bucket_bytes, + buffer, + })) + } + + pub fn buffer_slice<'a, T>(&self, buffer: &'a Buffer, len: usize) -> &'a [T] { + let ptr = buffer.contents().cast::(); + unsafe { std::slice::from_raw_parts(ptr, len) } + } + + pub fn buffer_slice_mut<'a, T>(&self, buffer: &'a Buffer, len: usize) -> &'a mut [T] { + let ptr = buffer.contents().cast::(); + unsafe { std::slice::from_raw_parts_mut(ptr, len) } + } + + pub fn zero_buffer(&self, buffer: &Buffer, len: usize) { + if len == 0 { + return; + } + unsafe { + ptr::write_bytes(buffer.contents(), 0, len * size_of::()); + } + } + + pub fn threads_per_threadgroup( + &self, + pipeline: &ComputePipelineState, + work_items: usize, + ) -> MTLSize { + let width = pipeline + .thread_execution_width() + .min(pipeline.max_total_threads_per_threadgroup()) + .min(work_items as u64) + .max(1); + MTLSize { + width, + height: 1, + depth: 1, + } + } + + pub fn roots_buffer(&self, codeword_length: usize) -> Result, String> { + let mut cache = self.roots_cache.lock().unwrap(); + if let Some(buffer) = cache.get(&codeword_length) { + trace_event(format_args!( + "roots cache hit codeword_length={codeword_length}" + )); + return Ok(Arc::clone(buffer)); + } + + let root = Fr::get_root_of_unity(codeword_length as u64).unwrap(); + let stage_count = codeword_length.trailing_zeros() as usize; + let mut roots = Vec::with_capacity(codeword_length.saturating_sub(1)); + for stage in 0..stage_count { + let stage_size = 1usize << (stage + 1); + let half_stage = stage_size >> 1; + let stage_root = root.pow([(codeword_length / stage_size) as u64]); + let mut current = Fr::ONE; + for _ in 0..half_stage { + roots.push(fr_to_gpu(current)); + current *= stage_root; + } + } + + let buffer = Arc::new(self.buffer_with_data(&roots)); + cache.insert(codeword_length, Arc::clone(&buffer)); + trace_event(format_args!( + "roots cache miss codeword_length={codeword_length}" + )); + Ok(buffer) + } + + fn take_buffer(&self, bucket_bytes: usize) -> Buffer { + if bucket_bytes == 0 { + return self + .device + .new_buffer(0, 
MTLResourceOptions::StorageModeShared); + } + + let mut pool = self.buffer_pool.lock().unwrap(); + if let Some(buffer) = pool.get_mut(&bucket_bytes).and_then(Vec::pop) { + return buffer; + } + drop(pool); + + self.device + .new_buffer(bucket_bytes as u64, MTLResourceOptions::StorageModeShared) + } + + fn recycle_buffer(&self, bucket_bytes: usize, buffer: Buffer) { + if bucket_bytes == 0 { + return; + } + + let mut pool = self.buffer_pool.lock().unwrap(); + pool.entry(bucket_bytes).or_default().push(buffer); + } + + fn new_pipeline( + device: &Device, + library: &Library, + function_name: &str, + ) -> Result { + library + .get_function(function_name, None) + .map_err(|err| err.to_string()) + .and_then(|function| device.new_compute_pipeline_state_with_function(&function)) + } +} + +fn bucket_bytes(bytes: usize) -> usize { + if bytes == 0 { + 0 + } else { + bytes.next_power_of_two() + } +} diff --git a/provekit/common/src/ntt/backends/metal/field.rs b/provekit/common/src/ntt/backends/metal/field.rs new file mode 100644 index 000000000..d2e6cf1ca --- /dev/null +++ b/provekit/common/src/ntt/backends/metal/field.rs @@ -0,0 +1,14 @@ +use { + super::types::GpuField, + ark_bn254::{Fr, FrConfig}, + ark_ff::{BigInt, Fp, MontBackend}, + std::marker::PhantomData, +}; + +pub(super) fn fr_to_gpu(value: Fr) -> GpuField { + GpuField { limbs: value.0 .0 } +} + +pub(super) fn gpu_to_fr(value: GpuField) -> Fr { + Fp::, 4>(BigInt(value.limbs), PhantomData) +} diff --git a/provekit/common/src/ntt/backends/metal/kernels/common.metal b/provekit/common/src/ntt/backends/metal/kernels/common.metal new file mode 100644 index 000000000..e4aed3178 --- /dev/null +++ b/provekit/common/src/ntt/backends/metal/kernels/common.metal @@ -0,0 +1,74 @@ +#include + +using namespace metal; + +struct Bn254Element { + ulong limbs[4]; +}; + +typedef Bn254Element Fe; + +struct StageConfig { + uint row_len; + uint half_m; + uint twiddle_offset; + uint _pad0; +}; + +struct BitReverseParams { + uint row_len; + uint log_n; + uint total_elements; + uint _pad0; +}; + +struct TransposeParams { + uint rows; + uint cols; + uint total_elements; +}; + +struct FieldBytesParams { + uint rows; + uint cols; +}; + +struct HashManyParams { + uint size; + uint count; +}; + +struct ReplicateCosetsParams { + uint row_len; + uint coset_size; + uint trailing_elements; +}; + +constant ulong BN254_MODULUS[4] = { + 0x43e1f593f0000001ul, + 0x2833e84879b97091ul, + 0xb85045b68181585dul, + 0x30644e72e131a029ul, +}; + +constant ulong BN254_N0PRIME = 0xc2e1f593effffffful; +constant Fe FE_ONE = {{1ul, 0ul, 0ul, 0ul}}; + +constant uint SHA256_K[64] = { + 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, + 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, + 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, + 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, + 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, + 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, + 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, + 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, + 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, + 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, + 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, + 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, + 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, + 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, + 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, + 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u +}; diff --git 
diff --git a/provekit/common/src/ntt/backends/metal/kernels/field.metal b/provekit/common/src/ntt/backends/metal/kernels/field.metal
new file mode 100644
index 000000000..453571f7e
--- /dev/null
+++ b/provekit/common/src/ntt/backends/metal/kernels/field.metal
@@ -0,0 +1,160 @@
+inline Fe make_element(ulong a0, ulong a1, ulong a2, ulong a3) {
+    Fe value;
+    value.limbs[0] = a0;
+    value.limbs[1] = a1;
+    value.limbs[2] = a2;
+    value.limbs[3] = a3;
+    return value;
+}
+
+inline bool ge_modulus(Fe value) {
+    if (value.limbs[3] != BN254_MODULUS[3]) {
+        return value.limbs[3] > BN254_MODULUS[3];
+    }
+    if (value.limbs[2] != BN254_MODULUS[2]) {
+        return value.limbs[2] > BN254_MODULUS[2];
+    }
+    if (value.limbs[1] != BN254_MODULUS[1]) {
+        return value.limbs[1] > BN254_MODULUS[1];
+    }
+    return value.limbs[0] >= BN254_MODULUS[0];
+}
+
+inline ulong add_with_carry(ulong a, ulong b, thread ulong &carry) {
+    ulong sum = a + b;
+    ulong c1 = sum < a ? 1ul : 0ul;
+    ulong sum_with_carry = sum + carry;
+    ulong c2 = sum_with_carry < sum ? 1ul : 0ul;
+    carry = c1 + c2;
+    return sum_with_carry;
+}
+
+inline ulong sub_with_borrow(ulong a, ulong b, thread ulong &borrow) {
+    ulong diff = a - b;
+    ulong b1 = diff > a ? 1ul : 0ul;
+    ulong diff_with_borrow = diff - borrow;
+    ulong b2 = diff_with_borrow > diff ? 1ul : 0ul;
+    borrow = b1 | b2;
+    return diff_with_borrow;
+}
+
+inline Fe sub_modulus(Fe value) {
+    ulong borrow = 0;
+    value.limbs[0] = sub_with_borrow(value.limbs[0], BN254_MODULUS[0], borrow);
+    value.limbs[1] = sub_with_borrow(value.limbs[1], BN254_MODULUS[1], borrow);
+    value.limbs[2] = sub_with_borrow(value.limbs[2], BN254_MODULUS[2], borrow);
+    value.limbs[3] = sub_with_borrow(value.limbs[3], BN254_MODULUS[3], borrow);
+    return value;
+}
+
+inline Fe add_modulus(Fe value) {
+    ulong carry = 0;
+    value.limbs[0] = add_with_carry(value.limbs[0], BN254_MODULUS[0], carry);
+    value.limbs[1] = add_with_carry(value.limbs[1], BN254_MODULUS[1], carry);
+    value.limbs[2] = add_with_carry(value.limbs[2], BN254_MODULUS[2], carry);
+    value.limbs[3] = add_with_carry(value.limbs[3], BN254_MODULUS[3], carry);
+    return value;
+}
+
+inline Fe add_mod(Fe lhs, Fe rhs) {
+    ulong carry = 0;
+    Fe result;
+    result.limbs[0] = add_with_carry(lhs.limbs[0], rhs.limbs[0], carry);
+    result.limbs[1] = add_with_carry(lhs.limbs[1], rhs.limbs[1], carry);
+    result.limbs[2] = add_with_carry(lhs.limbs[2], rhs.limbs[2], carry);
+    result.limbs[3] = add_with_carry(lhs.limbs[3], rhs.limbs[3], carry);
+
+    if (carry != 0 || ge_modulus(result)) {
+        result = sub_modulus(result);
+    }
+
+    return result;
+}
+
+inline Fe sub_mod(Fe lhs, Fe rhs) {
+    ulong borrow = 0;
+    Fe result;
+    result.limbs[0] = sub_with_borrow(lhs.limbs[0], rhs.limbs[0], borrow);
+    result.limbs[1] = sub_with_borrow(lhs.limbs[1], rhs.limbs[1], borrow);
+    result.limbs[2] = sub_with_borrow(lhs.limbs[2], rhs.limbs[2], borrow);
+    result.limbs[3] = sub_with_borrow(lhs.limbs[3], rhs.limbs[3], borrow);
+
+    if (borrow != 0) {
+        result = add_modulus(result);
+    }
+
+    return result;
+}
+
+inline void add_scaled_step(thread ulong &dst, ulong s, ulong a, thread ulong &carry) {
+    ulong product_lo = s * a;
+    ulong product_hi = mulhi(s, a);
+
+    ulong sum = dst + product_lo;
+    ulong carry0 = sum < dst ? 1ul : 0ul;
+    ulong sum_with_carry = sum + carry;
+    ulong carry1 = sum_with_carry < sum ? 1ul : 0ul;
+
+    dst = sum_with_carry;
+    carry = product_hi + carry0 + carry1;
+}
+
+inline void add_scaled(thread ulong *dst, ulong s, ulong a0, ulong a1, ulong a2, ulong a3) {
+    ulong carry = 0;
+    add_scaled_step(dst[0], s, a0, carry);
+    add_scaled_step(dst[1], s, a1, carry);
+    add_scaled_step(dst[2], s, a2, carry);
+    add_scaled_step(dst[3], s, a3, carry);
+    dst[4] += carry;
+}
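+
+// mont_mul below is operand-scanning Montgomery multiplication (CIOS-style):
+// each of the four rounds accumulates lhs.limbs[i] * rhs into a sliding
+// five-limb window of buf, then adds m * p with m = buf[off] * BN254_N0PRIME,
+// which zeroes the lowest live limb so the window can shift by one limb
+// (off += 1). After four rounds buf[4..8] holds lhs * rhs * 2^-256 mod p,
+// up to one final conditional subtraction.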
+
+inline Fe mont_mul(Fe lhs, Fe rhs) {
+    ulong buf[9] = {0};
+    uint off = 0;
+
+#pragma clang loop unroll(enable)
+    for (uint i = 0; i < 4; i++) {
+        add_scaled(
+            &buf[off],
+            lhs.limbs[i],
+            rhs.limbs[0],
+            rhs.limbs[1],
+            rhs.limbs[2],
+            rhs.limbs[3]
+        );
+
+        ulong m = buf[off] * BN254_N0PRIME;
+        add_scaled(
+            &buf[off],
+            m,
+            BN254_MODULUS[0],
+            BN254_MODULUS[1],
+            BN254_MODULUS[2],
+            BN254_MODULUS[3]
+        );
+
+        off += 1;
+        buf[off + 4] = 0;
+    }
+
+    Fe result = make_element(buf[off], buf[off + 1], buf[off + 2], buf[off + 3]);
+    if (ge_modulus(result)) {
+        result = sub_modulus(result);
+    }
+    return result;
+}
+
+inline Fe canonicalize(Fe value) {
+    if (ge_modulus(value)) {
+        return sub_modulus(value);
+    }
+    return value;
+}
+
+inline Fe from_mont(Fe value) {
+    return canonicalize(mont_mul(value, FE_ONE));
+}
+
+inline uint reverse_bits_width(uint value, uint width) {
+    return reverse_bits(value) >> (32u - width);
+}
diff --git a/provekit/common/src/ntt/backends/metal/kernels/matrix.metal b/provekit/common/src/ntt/backends/metal/kernels/matrix.metal
new file mode 100644
index 000000000..69acf7146
--- /dev/null
+++ b/provekit/common/src/ntt/backends/metal/kernels/matrix.metal
@@ -0,0 +1,43 @@
+[[kernel]]
+void transpose_matrix(
+    device const Fe *input [[buffer(0)]],
+    device Fe *output [[buffer(1)]],
+    constant TransposeParams &params [[buffer(2)]],
+    uint gid [[thread_position_in_grid]]
+) {
+    if (gid >= params.total_elements) {
+        return;
+    }
+
+    uint row = gid / params.cols;
+    uint col = gid - row * params.cols;
+    uint dst = col * params.rows + row;
+    output[dst] = input[gid];
+}
+
+[[kernel]]
+void encode_field_rows_le(
+    device const Fe *input [[buffer(0)]],
+    device uchar *output [[buffer(1)]],
+    constant FieldBytesParams &params [[buffer(2)]],
+    uint gid [[thread_position_in_grid]]
+) {
+    uint total_elements = params.rows * params.cols;
+    if (gid >= total_elements) {
+        return;
+    }
+
+    uint row = gid / params.cols;
+    uint col = gid - row * params.cols;
+    uint row_bits = 31u - clz(params.rows);
+    uint src_row = reverse_bits_width(row, row_bits);
+    Fe canonical = from_mont(input[src_row * params.cols + col]);
+    uint byte_offset = gid * 32u;
+    for (uint limb = 0; limb < 4; ++limb) {
+        ulong value = canonical.limbs[limb];
+        for (uint byte = 0; byte < 8; ++byte) {
+            output[byte_offset + limb * 8u + byte] = uchar((value >> (byte * 8u)) & 0xfful);
+        }
+    }
+}
+
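+// Index sketch for the two kernels above, derived from their arithmetic:
+// transpose_matrix reads gid row-major as (row, col) and writes column-major,
+// e.g. rows = 2, cols = 3 sends gid = 1 -> (0, 1) -> dst = 1 * 2 + 0 = 2.
+// encode_field_rows_le additionally bit-reverses the source row (with
+// rows = 8, output row 3 = 0b011 reads input row 6 = 0b110) and serializes
+// each element as 32 little-endian bytes.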
diff --git a/provekit/common/src/ntt/backends/metal/kernels/ntt.metal b/provekit/common/src/ntt/backends/metal/kernels/ntt.metal
new file mode 100644
index 000000000..52160bb0f
--- /dev/null
+++ b/provekit/common/src/ntt/backends/metal/kernels/ntt.metal
@@ -0,0 +1,69 @@
+[[kernel]]
+void bit_reverse_permute_rows_in_place(
+    device Fe *values [[buffer(0)]],
+    constant BitReverseParams &config [[buffer(1)]],
+    uint index [[thread_position_in_grid]]
+) {
+    if (index >= config.total_elements || config.row_len <= 1u) {
+        return;
+    }
+
+    uint row = index / config.row_len;
+    uint within = index - row * config.row_len;
+    uint reversed = reverse_bits_width(within, config.log_n);
+    if (reversed <= within) {
+        return;
+    }
+
+    uint row_base = row * config.row_len;
+    uint mate = row_base + reversed;
+    uint current = row_base + within;
+    Fe tmp = values[current];
+    values[current] = values[mate];
+    values[mate] = tmp;
+}
+
+[[kernel]]
+void radix2_ntt_stage_rows_in_place(
+    device Fe *values [[buffer(0)]],
+    device const Fe *twiddles [[buffer(1)]],
+    constant StageConfig &config [[buffer(2)]],
+    uint index [[thread_position_in_grid]]
+) {
+    uint butterflies_per_row = config.row_len >> 1u;
+    uint row = index / butterflies_per_row;
+    uint local = index - row * butterflies_per_row;
+    uint half_m = config.half_m;
+    uint pair_in_group = local % half_m;
+    uint group = local / half_m;
+    uint row_base = row * config.row_len;
+    uint base = row_base + group * (half_m << 1u) + pair_in_group;
+    uint mate = base + half_m;
+
+    Fe even = values[base];
+    Fe odd = values[mate];
+    Fe twiddle = twiddles[config.twiddle_offset + pair_in_group];
+    Fe t = mont_mul(twiddle, odd);
+
+    values[base] = add_mod(even, t);
+    values[mate] = sub_mod(even, t);
+}
+
+[[kernel]]
+void replicate_first_coset(
+    device Fe *buffer [[buffer(0)]],
+    constant ReplicateCosetsParams &params [[buffer(1)]],
+    uint gid [[thread_position_in_grid]]
+) {
+    if (gid >= params.trailing_elements) {
+        return;
+    }
+
+    uint repeats_per_row = params.row_len - params.coset_size;
+    uint row = gid / repeats_per_row;
+    uint within = gid - row * repeats_per_row;
+    uint dst = row * params.row_len + params.coset_size + within;
+    uint src = row * params.row_len + (within % params.coset_size);
+    buffer[dst] = buffer[src];
+}
+
diff --git a/provekit/common/src/ntt/backends/metal/kernels/sha256.metal b/provekit/common/src/ntt/backends/metal/kernels/sha256.metal
new file mode 100644
index 000000000..b23c20c7d
--- /dev/null
+++ b/provekit/common/src/ntt/backends/metal/kernels/sha256.metal
@@ -0,0 +1,242 @@
+inline uint rotr32(uint x, uint n) {
+    return (x >> n) | (x << (32 - n));
+}
+
+inline uint ch(uint x, uint y, uint z) {
+    return (x & y) ^ ((~x) & z);
+}
+
+inline uint maj(uint x, uint y, uint z) {
+    return (x & y) ^ (x & z) ^ (y & z);
+}
+
+inline uint big_sigma0(uint x) {
+    return rotr32(x, 2) ^ rotr32(x, 13) ^ rotr32(x, 22);
+}
+
+inline uint big_sigma1(uint x) {
+    return rotr32(x, 6) ^ rotr32(x, 11) ^ rotr32(x, 25);
+}
+
+inline uint small_sigma0(uint x) {
+    return rotr32(x, 7) ^ rotr32(x, 18) ^ (x >> 3);
+}
+
+inline uint small_sigma1(uint x) {
+    return rotr32(x, 17) ^ rotr32(x, 19) ^ (x >> 10);
+}
+
+inline uchar field_byte(Fe value, uint byte_index) {
+    ulong limb = value.limbs[byte_index >> 3u];
+    uint shift = (byte_index & 7u) << 3u;
+    return uchar((limb >> shift) & 0xfful);
+}
+
+inline void sha256_init(thread uint state[8]) {
+    state[0] = 0x6a09e667u;
+    state[1] = 0xbb67ae85u;
+    state[2] = 0x3c6ef372u;
+    state[3] = 0xa54ff53au;
+    state[4] = 0x510e527fu;
+    state[5] = 0x9b05688cu;
+    state[6] = 0x1f83d9abu;
+    state[7] = 0x5be0cd19u;
+}
+
+inline uchar sha256_padding_byte(uint idx, uint size, uint total_padded_len, uint bit_len) {
+    if (idx == size) {
+        return 0x80u;
+    }
+    if (idx >= total_padded_len - 8u) {
+        uint shift = (total_padded_len - 1u - idx) * 8u;
+        return shift >= 32u ? 0u : uchar((bit_len >> shift) & 0xffu);
+    }
+    return 0u;
+}
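+
+// Padding sketch (standard SHA-256 with a 32-bit message length, as encoded
+// above): for size = 64, total_blocks = (64 + 9 + 63) / 64 = 2, so
+// total_padded_len = 128 and bit_len = 512; byte 64 is 0x80, bytes 65..119
+// are zero, and bytes 120..127 carry the big-endian length, of which the
+// top four are always zero because bit_len is a uint.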
+
+inline uint sha256_load_field_word(
+    Fe field0,
+    Fe field1,
+    uint block_base,
+    uint word_index,
+    uint size,
+    uint total_padded_len,
+    uint bit_len
+) {
+    uint word = 0u;
+#pragma clang loop unroll(enable)
+    for (uint j = 0; j < 4; ++j) {
+        uint idx = block_base + word_index * 4u + j;
+        uchar byte = 0u;
+        if (idx < size) {
+            uint byte_in_block = idx - block_base;
+            byte = byte_in_block < 32u
+                ? field_byte(field0, byte_in_block)
+                : field_byte(field1, byte_in_block - 32u);
+        } else {
+            byte = sha256_padding_byte(idx, size, total_padded_len, bit_len);
+        }
+        word = (word << 8) | uint(byte);
+    }
+    return word;
+}
+
+inline uint sha256_load_byte_word(
+    device const uchar *input,
+    uint offset,
+    uint block_base,
+    uint word_index,
+    uint size,
+    uint total_padded_len,
+    uint bit_len
+) {
+    uint word = 0u;
+#pragma clang loop unroll(enable)
+    for (uint j = 0; j < 4; ++j) {
+        uint idx = block_base + word_index * 4u + j;
+        uchar byte = idx < size
+            ? input[offset + idx]
+            : sha256_padding_byte(idx, size, total_padded_len, bit_len);
+        word = (word << 8) | uint(byte);
+    }
+    return word;
+}
+
+inline void sha256_extend_schedule(thread uint w[64]) {
+    for (uint i = 16; i < 64; ++i) {
+        w[i] = small_sigma1(w[i - 2]) + w[i - 7] + small_sigma0(w[i - 15]) + w[i - 16];
+    }
+}
+
+inline void sha256_compress(thread uint state[8], thread const uint w[64]) {
+    uint a = state[0];
+    uint b = state[1];
+    uint c = state[2];
+    uint d = state[3];
+    uint e = state[4];
+    uint f = state[5];
+    uint g = state[6];
+    uint h = state[7];
+
+    for (uint i = 0; i < 64; ++i) {
+        uint t1 = h + big_sigma1(e) + ch(e, f, g) + SHA256_K[i] + w[i];
+        uint t2 = big_sigma0(a) + maj(a, b, c);
+        h = g;
+        g = f;
+        f = e;
+        e = d + t1;
+        d = c;
+        c = b;
+        b = a;
+        a = t1 + t2;
+    }
+
+    state[0] += a;
+    state[1] += b;
+    state[2] += c;
+    state[3] += d;
+    state[4] += e;
+    state[5] += f;
+    state[6] += g;
+    state[7] += h;
+}
+
+inline void sha256_write_digest(device uchar *out, thread const uint state[8]) {
+#pragma clang loop unroll(enable)
+    for (uint i = 0; i < 8; ++i) {
+        out[i * 4 + 0] = uchar((state[i] >> 24) & 0xffu);
+        out[i * 4 + 1] = uchar((state[i] >> 16) & 0xffu);
+        out[i * 4 + 2] = uchar((state[i] >> 8) & 0xffu);
+        out[i * 4 + 3] = uchar(state[i] & 0xffu);
+    }
+}
+
+[[kernel]]
+void sha256_field_rows(
+    device const Fe *input [[buffer(0)]],
+    device uchar *output [[buffer(1)]],
+    constant HashManyParams &params [[buffer(2)]],
+    uint gid [[thread_position_in_grid]]
+) {
+    if (gid >= params.count) {
+        return;
+    }
+
+    uint field_count = params.size >> 5u;
+    uint row_offset = gid * field_count;
+    uint total_blocks = (params.size + 9u + 63u) / 64u;
+    uint total_padded_len = total_blocks * 64u;
+    uint bit_len = params.size * 8u;
+    uint state[8];
+    sha256_init(state);
+
+    for (uint block = 0; block < total_blocks; ++block) {
+        uint block_base = block * 64u;
+        uint field_base = block << 1u;
+        bool has_field0 = block_base < params.size;
+        bool has_field1 = block_base + 32u < params.size;
+        Fe field0 = has_field0 ? from_mont(input[row_offset + field_base]) : FE_ONE;
+        Fe field1 = has_field1 ? from_mont(input[row_offset + field_base + 1u]) : FE_ONE;
+        uint w[64];
+
+#pragma clang loop unroll(enable)
+        for (uint i = 0; i < 16; ++i) {
+            w[i] = sha256_load_field_word(
+                field0,
+                field1,
+                block_base,
+                i,
+                params.size,
+                total_padded_len,
+                bit_len
+            );
+        }
+
+        sha256_extend_schedule(w);
+        sha256_compress(state, w);
+    }
+
+    sha256_write_digest(output + gid * 32u, state);
+}
+
+[[kernel]]
+void sha256_many(
+    device const uchar *input [[buffer(0)]],
+    device uchar *output [[buffer(1)]],
+    constant HashManyParams &params [[buffer(2)]],
+    uint gid [[thread_position_in_grid]]
+) {
+    if (gid >= params.count) {
+        return;
+    }
+
+    uint offset = gid * params.size;
+    uint total_blocks = (params.size + 9u + 63u) / 64u;
+    uint total_padded_len = total_blocks * 64u;
+    uint bit_len = params.size * 8u;
+    uint state[8];
+    sha256_init(state);
+
+    for (uint block = 0; block < total_blocks; ++block) {
+        uint block_base = block * 64u;
+        uint w[64];
+
+#pragma clang loop unroll(enable)
+        for (uint i = 0; i < 16; ++i) {
+            w[i] = sha256_load_byte_word(
+                input,
+                offset,
+                block_base,
+                i,
+                params.size,
+                total_padded_len,
+                bit_len
+            );
+        }
+
+        sha256_extend_schedule(w);
+        sha256_compress(state, w);
+    }
+
+    sha256_write_digest(output + gid * 32u, state);
+}
diff --git a/provekit/common/src/ntt/backends/metal/logging.rs b/provekit/common/src/ntt/backends/metal/logging.rs
new file mode 100644
index 000000000..e75603ec7
--- /dev/null
+++ b/provekit/common/src/ntt/backends/metal/logging.rs
@@ -0,0 +1,7 @@
+use std::env;
+
+/// Emits a diagnostic line on stderr when the `PROVEKIT_METAL_NTT_TRACE`
+/// environment variable is set (e.g. run with `PROVEKIT_METAL_NTT_TRACE=1`).
+pub fn trace_event(args: std::fmt::Arguments<'_>) {
+    if env::var_os("PROVEKIT_METAL_NTT_TRACE").is_some() {
+        eprintln!("[provekit-metal-ntt] {args}");
+    }
+}
diff --git a/provekit/common/src/ntt/backends/metal/mod.rs b/provekit/common/src/ntt/backends/metal/mod.rs
new file mode 100644
index 000000000..714b1d041
--- /dev/null
+++ b/provekit/common/src/ntt/backends/metal/mod.rs
@@ -0,0 +1,209 @@
+#[cfg(target_os = "macos")]
+mod commit;
+#[cfg(target_os = "macos")]
+mod encode;
+#[cfg(target_os = "macos")]
+mod engine;
+#[cfg(target_os = "macos")]
+mod field;
+mod logging;
+#[cfg(target_os = "macos")]
+mod types;
+
+#[cfg(target_os = "macos")]
+use self::engine::MetalRuntime;
+#[cfg(target_os = "macos")]
+use std::{
+    env,
+    sync::{Arc, OnceLock},
+};
+#[cfg(target_os = "macos")]
+use tracing::info;
+#[cfg(target_os = "macos")]
+use whir::{hash::SHA2, protocols::matrix_commit::Config as MatrixCommitConfig};
+use {
+    self::logging::trace_event,
+    crate::ntt::backends::RSFr,
+    ark_bn254::Fr,
+    ark_ff::{FftField, Field},
+    tracing::instrument,
+    whir::algebra::ntt::ReedSolomon,
+};
+
+#[derive(Clone, Copy, Debug, Default)]
+pub struct MetalBn254Ntt;
+
+#[cfg(target_os = "macos")]
+static RUNTIME: OnceLock<Result<Arc<MetalRuntime>, String>> = OnceLock::new();
+
+impl MetalBn254Ntt {
+    const MIN_GPU_TOTAL_ELEMENTS: usize = 1 << 20;
+    const MIN_GPU_ROW_COUNT: usize = 64;
+
+    #[cfg(target_os = "macos")]
+    pub fn new() -> Result<Self, String> {
+        if env::var_os("PROVEKIT_DISABLE_METAL_NTT").is_some() {
+            return Err("Metal NTT disabled via PROVEKIT_DISABLE_METAL_NTT".into());
+        }
+
+        match RUNTIME.get_or_init(|| MetalRuntime::new().map(Arc::new)) {
+            Ok(runtime) => {
+                info!(
+                    device = runtime.device.name(),
+                    thread_execution_width = runtime.ntt_stage_pipeline.thread_execution_width(),
+                    max_total_threads_per_threadgroup = runtime
+                        .ntt_stage_pipeline
+                        .max_total_threads_per_threadgroup(),
+                    "initialized Metal BN254 NTT backend"
+                );
+                trace_event(format_args!(
+                    "init device={} thread_execution_width={} \
max_total_threads_per_threadgroup={}", + runtime.device.name(), + runtime.ntt_stage_pipeline.thread_execution_width(), + runtime + .ntt_stage_pipeline + .max_total_threads_per_threadgroup(), + )); + Ok(Self) + } + Err(err) => Err(err.clone()), + } + } + + #[cfg(not(target_os = "macos"))] + pub fn new() -> Result { + Err("Metal BN254 NTT is only available on macOS".into()) + } + + #[cfg(target_os = "macos")] + pub fn runtime(&self) -> Result<&Arc, String> { + match RUNTIME.get() { + Some(Ok(runtime)) => Ok(runtime), + Some(Err(err)) => Err(err.clone()), + None => Err("metal runtime not initialized".into()), + } + } + + fn supports_gpu_shape(codeword_length: usize, row_coeffs: &[&[Fr]]) -> bool { + if row_coeffs.is_empty() { + return false; + } + if codeword_length <= 1 || !codeword_length.is_power_of_two() { + return false; + } + let total_elements = row_coeffs.len().saturating_mul(codeword_length); + total_elements >= Self::MIN_GPU_TOTAL_ELEMENTS + || row_coeffs.len() >= Self::MIN_GPU_ROW_COUNT + } + + #[cfg(target_os = "macos")] + fn supports_gpu_commit(matrix_commit: &MatrixCommitConfig) -> bool { + matrix_commit.leaf_hash_id == SHA2 + && matrix_commit + .merkle_tree + .layers + .iter() + .all(|layer| layer.hash_id == SHA2) + } +} + +impl ReedSolomon for MetalBn254Ntt { + /// @dev: Metal does not need next_order. + /// implementing this because trait requires it. + fn next_order(&self, size: usize) -> Option { + let order = size.next_power_of_two(); + if order <= 1 << 28 { + Some(order) + } else { + None + } + } + + /// @dev: Metal does not need evaluation_points. + /// implementing this because trait requires it. + fn evaluation_points( + &self, + masked_message_length: usize, + codeword_length: usize, + indices: &[usize], + ) -> Vec { + let _ = masked_message_length; + let generator = self.generator(codeword_length); + + indices + .iter() + .map(|i| { + let bits = usize::BITS - (codeword_length - 1).leading_zeros(); + let k = if bits == 0 { + *i + } else { + i.reverse_bits() >> (usize::BITS - bits) + }; + + generator.pow([k as u64]) + }) + .collect() + } + + /// @dev: Metal does not need generator. + /// implementing this because trait requires it. + fn generator(&self, codeword_length: usize) -> Fr { + Fr::get_root_of_unity(codeword_length as u64).unwrap() + } + + /// @note: tries GPU encode first, falls back to CPU if workload is too + /// small or too large, or if GPU fails. 
+
+    /// @dev: the Metal backend itself does not use `generator`; implemented
+    /// only because the trait requires it.
+    fn generator(&self, codeword_length: usize) -> Fr {
+        Fr::get_root_of_unity(codeword_length as u64).unwrap()
+    }
+
+    /// @note: tries the GPU encode first and falls back to CPU when the
+    /// workload is too small or too large, or when the GPU encode fails.
+    #[instrument(skip(self, messages, masks), fields(
+        num_messages = messages.len(),
+        message_len = messages.first().map(|c| c.len()),
+        codeword_length = codeword_length,
+        mask_len = masks.len().checked_div(messages.len())
+    ))]
+    fn interleaved_encode(
+        &self,
+        messages: &[&[Fr]],
+        masks: &[Fr],
+        codeword_length: usize,
+    ) -> Vec<Fr> {
+        if messages.is_empty() {
+            return vec![];
+        }
+
+        let num_messages = messages.len();
+        let message_length = messages[0].len();
+        for message in messages {
+            assert_eq!(message_length, message.len());
+        }
+        assert!(masks.len().is_multiple_of(num_messages));
+        let _mask_length = masks.len() / num_messages;
+        if !Self::supports_gpu_shape(codeword_length, messages) {
+            trace_event(format_args!(
+                "encode fallback path=cpu codeword_length={} rows={} reason=unsupported-shape",
+                codeword_length, num_messages,
+            ));
+            return RSFr.interleaved_encode(messages, masks, codeword_length);
+        }
+
+        #[cfg(target_os = "macos")]
+        {
+            match self.gpu_encode(messages, masks, codeword_length) {
+                Ok(codeword) => return codeword,
+                Err(err) => {
+                    trace_event(format_args!(
+                        "encode fallback path=cpu codeword_length={} rows={} reason=gpu-error \
+                         error={}",
+                        codeword_length, num_messages, err,
+                    ));
+                }
+            }
+        }
+
+        #[cfg(not(target_os = "macos"))]
+        trace_event(format_args!(
+            "encode fallback path=cpu codeword_length={} rows={} reason=unsupported-platform",
+            codeword_length, num_messages,
+        ));
+
+        RSFr.interleaved_encode(messages, masks, codeword_length)
+    }
+}
diff --git a/provekit/common/src/ntt/backends/metal/types.rs b/provekit/common/src/ntt/backends/metal/types.rs
new file mode 100644
index 000000000..c38eb0b21
--- /dev/null
+++ b/provekit/common/src/ntt/backends/metal/types.rs
@@ -0,0 +1,85 @@
+use {super::engine::PooledBuffer, whir::hash::Hash};
+
+#[repr(C)]
+#[derive(Clone, Copy, Debug, Default)]
+pub struct GpuField {
+    pub limbs: [u64; 4],
+}
+
+#[repr(C)]
+#[derive(Clone, Copy, Debug, Default)]
+pub struct BitReverseParams {
+    pub row_len: u32,
+    pub log_n: u32,
+    pub total_elements: u32,
+    pub _pad0: u32,
+}
+
+#[repr(C)]
+#[derive(Clone, Copy, Debug, Default)]
+pub struct NttStageParams {
+    pub row_len: u32,
+    pub half_m: u32,
+    pub twiddle_offset: u32,
+    pub _pad0: u32,
+}
+
+#[repr(C)]
+#[derive(Clone, Copy, Debug, Default)]
+pub struct TransposeParams {
+    pub rows: u32,
+    pub cols: u32,
+    pub total_elements: u32,
+}
+
+#[repr(C)]
+#[derive(Clone, Copy, Debug, Default)]
+pub struct EncodeFieldBytesParams {
+    pub rows: u32,
+    pub cols: u32,
+}
+
+#[repr(C)]
+#[derive(Clone, Copy, Debug, Default)]
+pub struct HashManyParams {
+    pub size: u32,
+    pub count: u32,
+}
+
+pub struct DeviceMatrix {
+    pub rows: usize,
+    pub cols: usize,
+    pub buffer: PooledBuffer,
+}
+
+#[derive(Clone)]
+pub struct DeviceRows {
+    pub rows: usize,
+    pub cols: usize,
+    pub buffer: PooledBuffer,
+}
+
+pub struct DeviceMerkleWitness {
+    pub num_nodes: usize,
+    pub root: Hash,
+    pub buffer: PooledBuffer,
+}
+
+#[derive(Clone, Copy, Debug)]
+pub struct EncodeShape {
+    pub row_count: usize,
+    pub codeword_length: usize,
+    pub coset_size: usize,
+    pub message_length: usize,
+    pub mask_length: usize,
+    pub num_cosets: usize,
+    pub total_elements: usize,
+}
+
+#[repr(C)]
+#[derive(Clone, Copy, Debug, Default)]
+pub struct ReplicateCosetsParams {
+    pub row_len: u32,
+    pub coset_size: u32,
+    pub trailing_elements: u32,
+}
diff --git a/provekit/common/src/ntt/backends/mod.rs b/provekit/common/src/ntt/backends/mod.rs
new file mode 100644
index 000000000..7ed4beb01
--- /dev/null
+++ b/provekit/common/src/ntt/backends/mod.rs
@@ -0,0 +1,12 @@
+pub mod cpu;
+
+#[cfg(target_os = "linux")]
+pub mod cuda;
+#[cfg(target_os = "macos")]
+pub mod metal;
+
+pub use cpu::RSFr;
+#[cfg(target_os = "linux")]
+pub use cuda::CudaBn254Ntt;
+#[cfg(target_os = "macos")]
+pub use metal::MetalBn254Ntt;
diff --git a/provekit/common/src/ntt/mod.rs b/provekit/common/src/ntt/mod.rs
new file mode 100644
index 000000000..8ad49bea4
--- /dev/null
+++ b/provekit/common/src/ntt/mod.rs
@@ -0,0 +1,124 @@
+pub mod backends;
+
+#[cfg(target_os = "linux")]
+pub use backends::CudaBn254Ntt;
+#[cfg(target_os = "macos")]
+pub use backends::MetalBn254Ntt;
+pub use backends::RSFr;
+
+#[cfg(all(test, target_os = "linux"))]
+mod cuda_tests {
+    use {
+        super::{CudaBn254Ntt, RSFr},
+        ark_bn254::Fr,
+        ark_ff::UniformRand,
+        whir::algebra::ntt::ReedSolomon,
+    };
+
+    fn try_init() -> Option<CudaBn254Ntt> {
+        match CudaBn254Ntt::new() {
+            Ok(gpu) => Some(gpu),
+            Err(err) => {
+                eprintln!("skipping CUDA test: {err}");
+                None
+            }
+        }
+    }
+
+    #[test]
+    fn cuda_matches_cpu_for_large_case() {
+        let Some(gpu) = try_init() else { return };
+        let mut rng = ark_std::test_rng();
+        let coeffs: Vec<_> = (0..(1 << 12)).map(|_| Fr::rand(&mut rng)).collect();
+        let messages = [&coeffs[..1 << 11], &coeffs[1 << 11..]];
+        let cpu = RSFr.interleaved_encode(&messages, &[], 1 << 11);
+        let actual = gpu.interleaved_encode(&messages, &[], 1 << 11);
+        assert_eq!(cpu, actual);
+    }
+
+    #[test]
+    fn cuda_matches_cpu_for_multi_poly_case() {
+        let Some(gpu) = try_init() else { return };
+        let mut rng = ark_std::test_rng();
+        let storage: Vec<Vec<Fr>> = (0..64)
+            .map(|_| (0..16).map(|_| Fr::rand(&mut rng)).collect())
+            .collect();
+        let messages: Vec<&[Fr]> = storage.iter().map(Vec::as_slice).collect();
+        let cpu = RSFr.interleaved_encode(&messages, &[], 32);
+        let actual = gpu.interleaved_encode(&messages, &[], 32);
+        assert_eq!(cpu, actual);
+    }
+
+    #[test]
+    fn cuda_matches_cpu_with_masks() {
+        let Some(gpu) = try_init() else { return };
+        let mut rng = ark_std::test_rng();
+        let storage: Vec<Vec<Fr>> = (0..64)
+            .map(|_| (0..16).map(|_| Fr::rand(&mut rng)).collect())
+            .collect();
+        let messages: Vec<&[Fr]> = storage.iter().map(Vec::as_slice).collect();
+        let masks: Vec<Fr> = (0..(64 * 4)).map(|_| Fr::rand(&mut rng)).collect();
+        let cpu = RSFr.interleaved_encode(&messages, &masks, 32);
+        let actual = gpu.interleaved_encode(&messages, &masks, 32);
+        assert_eq!(cpu, actual);
+    }
+}
+
+#[cfg(all(test, target_os = "macos"))]
+mod tests {
+    use {
+        super::{MetalBn254Ntt, RSFr},
+        ark_bn254::Fr,
+        ark_ff::UniformRand,
+        whir::algebra::ntt::ReedSolomon,
+    };
+
+    #[test]
+    fn metal_matches_cpu_for_small_case() {
+        let gpu = MetalBn254Ntt::new().unwrap();
+        eprintln!(
+            "using Metal device: {}",
+            gpu.runtime().unwrap().device.name()
+        );
+
+        let mut rng = ark_std::test_rng();
+        let coeffs: Vec<_> = (0..(1 << 12)).map(|_| Fr::rand(&mut rng)).collect();
+        let messages = [&coeffs[..1 << 11], &coeffs[1 << 11..]];
+        let cpu = RSFr.interleaved_encode(&messages, &[], 1 << 11);
+        let gpu = gpu.interleaved_encode(&messages, &[], 1 << 11);
+        assert_eq!(cpu, gpu);
+    }
+
+    #[test]
+    fn metal_matches_cpu_for_small_codeword_case() {
+        let gpu = MetalBn254Ntt::new().unwrap();
+        let mut rng = ark_std::test_rng();
+        let messages_storage: Vec<_> = (0..2)
+            .map(|_| (0..16).map(|_| Fr::rand(&mut rng)).collect::<Vec<_>>())
+            .collect();
+        let messages = messages_storage
+            .iter()
+            .map(Vec::as_slice)
+            .collect::<Vec<_>>();
+        let masks: Vec<_> = (0..(2 * 4)).map(|_| Fr::rand(&mut rng)).collect();
+        let cpu = RSFr.interleaved_encode(&messages, &masks, 32);
+        let gpu = gpu.interleaved_encode(&messages, &masks, 32);
+        assert_eq!(cpu, gpu);
+    }
+
+    #[test]
+    fn metal_matches_cpu_for_multi_poly_case() {
+        let gpu = MetalBn254Ntt::new().unwrap();
+        let mut rng = ark_std::test_rng();
+        let messages_storage: Vec<_> = (0..4)
+            .map(|_| (0..16).map(|_| Fr::rand(&mut rng)).collect::<Vec<_>>())
+            .collect();
+        let messages = messages_storage
+            .iter()
+            .map(Vec::as_slice)
+            .collect::<Vec<_>>();
+        let cpu = RSFr.interleaved_encode(&messages, &[], 32);
+        let gpu = gpu.interleaved_encode(&messages, &[], 32);
+        assert_eq!(cpu, gpu);
+    }
+}
diff --git a/provekit/common/src/prefix_covector.rs b/provekit/common/src/prefix_covector.rs
index 1e332893e..7bcc2a3ac 100644
--- a/provekit/common/src/prefix_covector.rs
+++ b/provekit/common/src/prefix_covector.rs
@@ -4,6 +4,39 @@ use {
     whir::algebra::{dot, linear_form::LinearForm, multilinear_extend},
 };
 
+/// Apply `acc[i] += scalar * vec[i]` to the prefix `acc[..vec.len()]`, using
+/// rayon when the workload is big enough and we are not already on a rayon
+/// worker (top-level threads only; nested-parallelism scheduler overhead
+/// turned out to be a 20% CPU sink).
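+/// E.g. `acc = [1, 1, 1, 1]`, `vec = [2, 3]` and `scalar = 5` leave
+/// `acc = [11, 16, 1, 1]`; `acc[vec.len()..]` is never read or written.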
+#[inline]
+fn scalar_mul_add_prefix(acc: &mut [FieldElement], vec: &[FieldElement], scalar: FieldElement) {
+    debug_assert!(acc.len() >= vec.len());
+    // Roughly L2-sized chunk, picked to match WHIR's `workload_size::<FieldElement>()`.
+    const CHUNK: usize = 1 << 13;
+    let n = vec.len();
+    if n == 0 {
+        return;
+    }
+
+    #[cfg(feature = "parallel")]
+    if n > CHUNK && rayon::current_thread_index().is_none() {
+        use rayon::prelude::*;
+        acc[..n]
+            .par_chunks_mut(CHUNK)
+            .zip(vec.par_chunks(CHUNK))
+            .for_each(|(a, v)| {
+                for (a_i, v_i) in a.iter_mut().zip(v) {
+                    *a_i += scalar * *v_i;
+                }
+            });
+        return;
+    }
+
+    for (a_i, v_i) in acc[..n].iter_mut().zip(vec) {
+        *a_i += scalar * *v_i;
+    }
+}
+
 /// A covector that stores only a power-of-two prefix, with the rest
 /// implicitly zero-padded to `domain_size`. Saves memory when the
 /// covector is known to be zero beyond the prefix (e.g. R1CS alpha
@@ -76,12 +109,7 @@
     }
 
     fn accumulate(&self, accumulator: &mut [FieldElement], scalar: FieldElement) {
-        for (acc, val) in accumulator[..self.vector.len()]
-            .iter_mut()
-            .zip(&self.vector)
-        {
-            *acc += scalar * *val;
-        }
+        scalar_mul_add_prefix(accumulator, &self.vector, scalar);
     }
 }
@@ -139,12 +167,8 @@
     fn accumulate(&self, accumulator: &mut [FieldElement], scalar: FieldElement) {
-        for (acc, &w) in accumulator[self.offset..self.offset + self.weights.len()]
-            .iter_mut()
-            .zip(&self.weights)
-        {
-            *acc += scalar * w;
-        }
+        let dst = &mut accumulator[self.offset..self.offset + self.weights.len()];
+        scalar_mul_add_prefix(dst, &self.weights, scalar);
     }
 }
diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs
index 28af714d9..058ef3492 100644
--- a/provekit/prover/src/lib.rs
+++ b/provekit/prover/src/lib.rs
@@ -103,7 +103,7 @@ impl Prove for NoirProver {
         self,
         acir_witness_idx_to_value_map: WitnessMap<FieldElement>,
     ) -> Result<NoirProof> {
-        provekit_common::register_ntt();
+        provekit_common::register_whir_backends();
 
         let mut public_input_indices = self.program.functions[0].public_inputs().indices();
         public_input_indices.sort_unstable();
@@ -271,7 +271,7 @@ impl Prove for MavrosProver {
     #[cfg(feature = "witness-generation")]
     fn prove(mut self, input_map: InputMap) -> Result<NoirProof> {
-        provekit_common::register_ntt();
+        provekit_common::register_whir_backends();
 
         let params = crate::input_utils::ordered_params_from_btreemap(&self.abi, &input_map)?;
         let phase1 = mavros_interpreter::run_phase1(
diff --git a/provekit/verifier/src/lib.rs b/provekit/verifier/src/lib.rs
index 1c3461fa8..f879c0b12 100644
--- a/provekit/verifier/src/lib.rs
+++ b/provekit/verifier/src/lib.rs
@@ -14,7 +14,7 @@ pub trait Verify {
 impl Verify for Verifier {
     #[instrument(skip_all)]
     fn verify(&mut self, proof: &NoirProof) -> Result<()> {
-        provekit_common::register_ntt();
+        provekit_common::register_whir_backends();
 
         self.whir_for_witness
             .take()
diff --git a/tooling/provekit-ffi/src/ffi.rs b/tooling/provekit-ffi/src/ffi.rs
index 7e687b59e..0e8d501db 100644
--- a/tooling/provekit-ffi/src/ffi.rs
+++ b/tooling/provekit-ffi/src/ffi.rs
@@ -90,7 +90,7 @@ pub unsafe extern "C" fn pk_get_last_error(out_buf: *mut PKBuf) -> c_int {
 /// Must be called once before using any other ProveKit functions.
 #[no_mangle]
 pub extern "C" fn pk_init() -> c_int {
-    provekit_common::register_ntt();
+    provekit_common::register_whir_backends();
 
     PKStatus::Success.into()
 }