From d1a0abf4fff108ac80e07c3ca0b5bb00cf0eb970 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 May 2026 15:12:54 -0300 Subject: [PATCH 01/16] add first cuda files --- Cargo.lock | 32 + Cargo.toml | 1 + crypto/math-cuda/Cargo.toml | 22 + crypto/math-cuda/build.rs | 57 + crypto/math-cuda/kernels/arith.cu | 83 ++ crypto/math-cuda/kernels/ext3.cuh | 121 +++ crypto/math-cuda/kernels/goldilocks.cuh | 69 ++ crypto/math-cuda/kernels/ntt.cu | 285 +++++ crypto/math-cuda/src/device.rs | 251 +++++ crypto/math-cuda/src/lde.rs | 984 ++++++++++++++++++ crypto/math-cuda/src/lib.rs | 152 +++ crypto/math-cuda/src/ntt.rs | 211 ++++ crypto/math-cuda/tests/bench_quick.rs | 356 +++++++ crypto/math-cuda/tests/evaluate_coset_ext3.rs | 143 +++ crypto/math-cuda/tests/ext3.rs | 87 ++ crypto/math-cuda/tests/goldilocks.rs | 127 +++ crypto/math-cuda/tests/lde.rs | 112 ++ crypto/math-cuda/tests/lde_batch.rs | 96 ++ crypto/math-cuda/tests/lde_batch_ext3.rs | 161 +++ crypto/math-cuda/tests/ntt.rs | 136 +++ crypto/stark/Cargo.toml | 4 + 21 files changed, 3490 insertions(+) create mode 100644 crypto/math-cuda/Cargo.toml create mode 100644 crypto/math-cuda/build.rs create mode 100644 crypto/math-cuda/kernels/arith.cu create mode 100644 crypto/math-cuda/kernels/ext3.cuh create mode 100644 crypto/math-cuda/kernels/goldilocks.cuh create mode 100644 crypto/math-cuda/kernels/ntt.cu create mode 100644 crypto/math-cuda/src/device.rs create mode 100644 crypto/math-cuda/src/lde.rs create mode 100644 crypto/math-cuda/src/lib.rs create mode 100644 crypto/math-cuda/src/ntt.rs create mode 100644 crypto/math-cuda/tests/bench_quick.rs create mode 100644 crypto/math-cuda/tests/evaluate_coset_ext3.rs create mode 100644 crypto/math-cuda/tests/ext3.rs create mode 100644 crypto/math-cuda/tests/goldilocks.rs create mode 100644 crypto/math-cuda/tests/lde.rs create mode 100644 crypto/math-cuda/tests/lde_batch.rs create mode 100644 crypto/math-cuda/tests/lde_batch_ext3.rs create mode 100644 crypto/math-cuda/tests/ntt.rs diff --git a/Cargo.lock b/Cargo.lock index f6eea84d6..7b6ed3c62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -803,6 +803,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "cudarc" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f071cd6a7b5d51607df76aa2d426aaabc7a74bc6bdb885b8afa63a880572ad9b" +dependencies = [ + "libloading", +] + [[package]] name = "darling" version = "0.21.3" @@ -1989,6 +1998,16 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +[[package]] +name = "libloading" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.15" @@ -2105,6 +2124,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "math-cuda" +version = "0.1.0" +dependencies = [ + "cudarc", + "math", + "rand 0.8.5", + "rand_chacha 0.3.1", + "rayon", + "sha3", +] + [[package]] name = "memchr" version = "2.7.6" @@ -3172,6 +3203,7 @@ dependencies = [ "itertools 0.11.0", "log", "math", + "math-cuda", "rayon", "serde", "serde-wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 4d10b7c44..e43dc7f0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "crypto/stark", "crypto/crypto", "crypto/math", + "crypto/math-cuda", "bin/cli", ] diff --git a/crypto/math-cuda/Cargo.toml b/crypto/math-cuda/Cargo.toml new file mode 100644 index 000000000..8c22d1110 --- /dev/null +++ b/crypto/math-cuda/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "math-cuda" +description = "CUDA-accelerated FFT/NTT for Goldilocks (base field) used by the lambda-vm STARK prover" +version = "0.1.0" +edition = "2024" + +[dependencies] +cudarc = { version = "0.19", default-features = false, features = [ + "driver", + "nvrtc", + "std", + "cuda-12080", + "dynamic-loading", +] } +math = { path = "../math" } +rayon = "1.7" + +[dev-dependencies] +rand = { version = "0.8.5", features = ["std"] } +rand_chacha = "0.3.1" +rayon = "1.7" +sha3 = "0.10.8" diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs new file mode 100644 index 000000000..a6defb5ab --- /dev/null +++ b/crypto/math-cuda/build.rs @@ -0,0 +1,57 @@ +use std::env; +use std::path::PathBuf; +use std::process::Command; + +fn cuda_home() -> PathBuf { + env::var_os("CUDA_HOME") + .or_else(|| env::var_os("CUDA_PATH")) + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from("/usr/local/cuda")) +} + +fn nvcc_path() -> PathBuf { + cuda_home().join("bin").join("nvcc") +} + +fn compile_ptx(src: &str, out_name: &str) { + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let src_path = manifest_dir.join("kernels").join(src); + let out_path = out_dir.join(out_name); + + println!("cargo:rerun-if-changed=kernels/{src}"); + println!("cargo:rerun-if-env-changed=CUDA_HOME"); + println!("cargo:rerun-if-env-changed=CUDA_PATH"); + println!("cargo:rerun-if-env-changed=CUDARC_NVCC_ARCH"); + + // Emit PTX for a virtual architecture; the CUDA driver JIT-compiles it for the + // actual GPU at load time, so one PTX works across Ada/Hopper/Blackwell. Override + // with CUDARC_NVCC_ARCH to pin a specific compute capability. + let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| "compute_89".to_string()); + + let status = Command::new(nvcc_path()) + .args([ + "--ptx", + "-O3", + "-std=c++17", + "-arch", + &arch, + "-o", + ]) + .arg(&out_path) + .arg(&src_path) + .status() + .expect("failed to invoke nvcc — is CUDA installed and CUDA_HOME set?"); + + if !status.success() { + panic!("nvcc failed compiling {}", src_path.display()); + } +} + +fn main() { + // Headers are not compiled; emit rerun-if-changed so edits trigger rebuilds. + println!("cargo:rerun-if-changed=kernels/goldilocks.cuh"); + println!("cargo:rerun-if-changed=kernels/ext3.cuh"); + compile_ptx("arith.cu", "arith.ptx"); + compile_ptx("ntt.cu", "ntt.ptx"); +} diff --git a/crypto/math-cuda/kernels/arith.cu b/crypto/math-cuda/kernels/arith.cu new file mode 100644 index 000000000..4bee9b8bb --- /dev/null +++ b/crypto/math-cuda/kernels/arith.cu @@ -0,0 +1,83 @@ +// Element-wise Goldilocks kernels used by the Phase-2 parity tests. These mirror +// the CPU reference in `crypto/math/src/field/goldilocks.rs` so raw u64 outputs +// are bit-identical to the CPU path. + +#include "goldilocks.cuh" +#include "ext3.cuh" + +using goldilocks::add; +using goldilocks::sub; +using goldilocks::mul; +using goldilocks::neg; + +extern "C" __global__ void vector_add_u64(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = a[tid] + b[tid]; // plain wrapping u64 add — toolchain sanity only. +} + +extern "C" __global__ void gl_add_kernel(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = add(a[tid], b[tid]); +} + +extern "C" __global__ void gl_sub_kernel(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = sub(a[tid], b[tid]); +} + +extern "C" __global__ void gl_mul_kernel(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = mul(a[tid], b[tid]); +} + +extern "C" __global__ void gl_neg_kernel(const uint64_t *a, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = neg(a[tid]); +} + +// --------------------------------------------------------------------------- +// Ext3 (Goldilocks cubic extension) test kernels. +// Input/output arrays are interleaved [a_0, b_0, c_0, a_1, b_1, c_1, ...]. +// --------------------------------------------------------------------------- + +extern "C" __global__ void ext3_mul_kernel(const uint64_t *a_int, + const uint64_t *b_int, + uint64_t *c_int, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]); + ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]); + ext3::Fe3 r = ext3::mul(a, b); + c_int[tid*3 + 0] = r.a; + c_int[tid*3 + 1] = r.b; + c_int[tid*3 + 2] = r.c; +} + +extern "C" __global__ void ext3_add_kernel(const uint64_t *a_int, + const uint64_t *b_int, + uint64_t *c_int, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]); + ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]); + ext3::Fe3 r = ext3::add(a, b); + c_int[tid*3 + 0] = r.a; + c_int[tid*3 + 1] = r.b; + c_int[tid*3 + 2] = r.c; +} diff --git a/crypto/math-cuda/kernels/ext3.cuh b/crypto/math-cuda/kernels/ext3.cuh new file mode 100644 index 000000000..2f4040714 --- /dev/null +++ b/crypto/math-cuda/kernels/ext3.cuh @@ -0,0 +1,121 @@ +// Goldilocks cubic extension on device: Fp3 = Fp[w] / (w^3 - 2) +// where Fp is Goldilocks (2^64 - 2^32 + 1). +// +// Layout matches the CPU `Degree3GoldilocksExtensionField` (see +// `crypto/math/src/field/extensions_goldilocks.rs`): an element is a +// 3-tuple `(a, b, c)` representing `a + b*w + c*w^2`. +// +// The reducible `w^3 = 2` means cross-term products get a factor of 2: +// (a0 + a1*w + a2*w^2) * (b0 + b1*w + b2*w^2) +// = (a0*b0 + 2*(a1*b2 + a2*b1)) +// + (a0*b1 + a1*b0 + 2*a2*b2) * w +// + (a0*b2 + a1*b1 + a2*b0) * w^2 +// +// We use the same dot-product-of-three folding as the CPU (which saves +// reductions by summing u128 products before `reduce128`). CUDA has +// `__umul64hi` so we implement `dot_product_3` inline. + +#pragma once +#include "goldilocks.cuh" + +namespace ext3 { + +struct Fe3 { + uint64_t a, b, c; +}; + +__device__ __forceinline__ Fe3 make(uint64_t a, uint64_t b, uint64_t c) { + Fe3 r = {a, b, c}; + return r; +} + +__device__ __forceinline__ Fe3 zero() { return make(0, 0, 0); } +__device__ __forceinline__ Fe3 one() { return make(1, 0, 0); } + +__device__ __forceinline__ Fe3 add(const Fe3 &x, const Fe3 &y) { + return make(goldilocks::add(x.a, y.a), + goldilocks::add(x.b, y.b), + goldilocks::add(x.c, y.c)); +} + +__device__ __forceinline__ Fe3 sub(const Fe3 &x, const Fe3 &y) { + return make(goldilocks::sub(x.a, y.a), + goldilocks::sub(x.b, y.b), + goldilocks::sub(x.c, y.c)); +} + +__device__ __forceinline__ Fe3 neg(const Fe3 &x) { + return make(goldilocks::neg(x.a), + goldilocks::neg(x.b), + goldilocks::neg(x.c)); +} + +/// Mixed: base * ext3 → ext3 (componentwise). +__device__ __forceinline__ Fe3 mul_base(const Fe3 &x, uint64_t s) { + return make(goldilocks::mul(x.a, s), + goldilocks::mul(x.b, s), + goldilocks::mul(x.c, s)); +} + +/// Dot-product of three (a0*b0 + a1*b1 + a2*b2) mod p, with one reduce128 +/// on the sum of three u128 products. Matches CPU `dot_product_3`. +__device__ __forceinline__ uint64_t dot3(uint64_t a0, uint64_t b0, + uint64_t a1, uint64_t b1, + uint64_t a2, uint64_t b2) { + // Split the sum of three u128 products into hi/lo u128 halves, then + // reduce once. We track overflow-count (at most 2) and add EPSILON^2 + // per overflow, matching the CPU path. + // prod_i = a_i * b_i (u128) + uint64_t lo0 = a0 * b0, hi0 = __umul64hi(a0, b0); + uint64_t lo1 = a1 * b1, hi1 = __umul64hi(a1, b1); + uint64_t lo2 = a2 * b2, hi2 = __umul64hi(a2, b2); + + // sum01 = prod0 + prod1 (in u128 lanes) + uint64_t s01_lo = lo0 + lo1; + uint64_t carry01 = (s01_lo < lo0) ? 1ULL : 0ULL; + uint64_t s01_hi = hi0 + hi1 + carry01; + uint32_t over1 = (s01_hi < hi0 + carry01) ? 1u : 0u; // low-pass overflow + + // sum012 = sum01 + prod2 + uint64_t s012_lo = s01_lo + lo2; + uint64_t carry012 = (s012_lo < s01_lo) ? 1ULL : 0ULL; + uint64_t s012_hi = s01_hi + hi2 + carry012; + uint32_t over2 = (s012_hi < hi2 + carry012) ? 1u : 0u; + + uint64_t reduced = goldilocks::reduce128(s012_lo, s012_hi); + + uint32_t overflow_count = over1 + over2; + if (overflow_count > 0) { + // 2^128 mod p = EPSILON^2 (= (2^32 - 1)^2). + uint64_t eps = goldilocks::EPSILON; + uint64_t eps_sq = eps * eps; + reduced = goldilocks::add_no_canonicalize(reduced, eps_sq); + if (overflow_count > 1) { + reduced = goldilocks::add_no_canonicalize(reduced, eps_sq); + } + } + return reduced; +} + +/// Full ext3 × ext3 multiplication (matches CPU +/// `Degree3GoldilocksExtensionField::mul`). +__device__ __forceinline__ Fe3 mul(const Fe3 &x, const Fe3 &y) { + // c0 = x.a*y.a + x.b*(2*y.c) + x.c*(2*y.b) + // c1 = x.a*y.b + x.b*y.a + x.c*(2*y.c) + // c2 = x.a*y.c + x.b*y.b + x.c*y.a + uint64_t b1_2 = goldilocks::add(y.b, y.b); + uint64_t b2_2 = goldilocks::add(y.c, y.c); + + uint64_t c0 = dot3(x.a, y.a, x.b, b2_2, x.c, b1_2); + uint64_t c1 = dot3(x.a, y.b, x.b, y.a, x.c, b2_2); + uint64_t c2 = dot3(x.a, y.c, x.b, y.b, x.c, y.a); + return make(c0, c1, c2); +} + +__device__ __forceinline__ Fe3 canonical(const Fe3 &x) { + return make(goldilocks::canonical(x.a), + goldilocks::canonical(x.b), + goldilocks::canonical(x.c)); +} + +} // namespace ext3 diff --git a/crypto/math-cuda/kernels/goldilocks.cuh b/crypto/math-cuda/kernels/goldilocks.cuh new file mode 100644 index 000000000..5e296a390 --- /dev/null +++ b/crypto/math-cuda/kernels/goldilocks.cuh @@ -0,0 +1,69 @@ +// Goldilocks field on device. Ports `crypto/math/src/field/goldilocks.rs` one-to-one: +// - Representation: non-canonical u64 in [0, 2^64). Canonicalise only at boundaries. +// - Prime: 2^64 - 2^32 + 1. +// - Reduction: exploits 2^64 ≡ EPSILON (mod p) and 2^96 ≡ -1 (mod p). +// +// The arithmetic here must produce bit-identical u64 outputs to the CPU path so +// LDE parity tests can assert raw equality. + +#pragma once +#include + +namespace goldilocks { + +__device__ constexpr uint64_t PRIME = 0xFFFFFFFF00000001ULL; +__device__ constexpr uint64_t EPSILON = 0xFFFFFFFFULL; // 2^32 - 1 + +__device__ __forceinline__ uint64_t add_no_canonicalize(uint64_t x, uint64_t y) { + // Mirror of `add_no_canonicalize_trashing_input`: one add, one EPSILON bump on carry. + uint64_t sum = x + y; + return sum + (sum < x ? EPSILON : 0ULL); +} + +__device__ __forceinline__ uint64_t add(uint64_t a, uint64_t b) { + uint64_t sum = a + b; + uint64_t over1 = (sum < a) ? EPSILON : 0ULL; + uint64_t sum2 = sum + over1; + uint64_t over2 = (sum2 < sum) ? EPSILON : 0ULL; + return sum2 + over2; +} + +__device__ __forceinline__ uint64_t sub(uint64_t a, uint64_t b) { + uint64_t diff = a - b; + uint64_t under1 = (a < b) ? EPSILON : 0ULL; + uint64_t diff2 = diff - under1; + uint64_t under2 = (diff2 > diff) ? EPSILON : 0ULL; + return diff2 - under2; +} + +__device__ __forceinline__ uint64_t reduce128(uint64_t lo, uint64_t hi) { + uint64_t x_hi_hi = hi >> 32; + uint64_t x_hi_lo = hi & EPSILON; + + // 2^96 ≡ -1 (mod p): subtract x_hi_hi from lo, EPSILON-correct on borrow. + uint64_t t0 = lo - x_hi_hi; + if (lo < x_hi_hi) t0 -= EPSILON; + + // 2^64 ≡ EPSILON (mod p): x_hi_lo * EPSILON = (x_hi_lo << 32) - x_hi_lo. + uint64_t t1 = (x_hi_lo << 32) - x_hi_lo; + + return add_no_canonicalize(t0, t1); +} + +__device__ __forceinline__ uint64_t mul(uint64_t a, uint64_t b) { + uint64_t lo = a * b; + uint64_t hi = __umul64hi(a, b); + return reduce128(lo, hi); +} + +__device__ __forceinline__ uint64_t neg(uint64_t a) { + // `a` may be non-canonical. Canonicalise first, then p - a (or 0). + uint64_t canon = (a >= PRIME) ? (a - PRIME) : a; + return canon == 0 ? 0 : (PRIME - canon); +} + +__device__ __forceinline__ uint64_t canonical(uint64_t a) { + return (a >= PRIME) ? (a - PRIME) : a; +} + +} // namespace goldilocks diff --git a/crypto/math-cuda/kernels/ntt.cu b/crypto/math-cuda/kernels/ntt.cu new file mode 100644 index 000000000..2a5c8c786 --- /dev/null +++ b/crypto/math-cuda/kernels/ntt.cu @@ -0,0 +1,285 @@ +// Radix-2 DIT NTT over Goldilocks. One kernel per butterfly level; the caller +// runs `bit_reverse_permute` once before the first level. +// +// Input layout: bit-reversed-order coefficients (after `bit_reverse_permute`). +// Output layout: natural-order evaluations — matches the CPU `evaluate_fft` contract. +// +// Twiddle table: `tw[i] = ω^i` for i in [0, n/2). Stride-indexed per level. + +#include "goldilocks.cuh" + +using goldilocks::add; +using goldilocks::sub; +using goldilocks::mul; + +/// Reverse the low `log_n` bits of each index and swap x[i] ↔ x[rev(i)]. +/// One thread per index; guarded by `tid < rev` to avoid double-swap. +extern "C" __global__ void bit_reverse_permute(uint64_t *x, + uint64_t n, + uint64_t log_n) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + + // __brevll reverses all 64 bits; shift right so result lives in [0, n). + uint64_t rev = __brevll(tid) >> (64 - log_n); + if (tid < rev) { + uint64_t tmp = x[tid]; + x[tid] = x[rev]; + x[rev] = tmp; + } +} + +/// Pointwise multiply: x[i] *= w[i]. Used for coset scaling (w = g^i weights). +extern "C" __global__ void pointwise_mul(uint64_t *x, + const uint64_t *w, + uint64_t n) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) x[tid] = mul(x[tid], w[tid]); +} + +/// Broadcast scalar multiply: x[i] *= c. Used for the 1/n factor at the end of iNTT. +extern "C" __global__ void scalar_mul(uint64_t *x, + uint64_t c, + uint64_t n) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) x[tid] = mul(x[tid], c); +} + +// ============================================================================ +// BATCHED KERNELS +// +// One launch processes M columns at once. The device buffer holds M columns +// back-to-back; column `c` starts at `data + c * col_stride`. gridDim.y is +// the column index, so each block handles one (column, butterfly-window) pair. +// +// The same twiddle table is shared across all columns of a batch (they all +// NTT on the same domain). The coset weights are also shared. +// ============================================================================ + +extern "C" __global__ void bit_reverse_permute_batched(uint64_t *data, + uint64_t n, + uint64_t log_n, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + + uint64_t rev = __brevll(tid) >> (64 - log_n); + if (tid < rev) { + uint64_t tmp = x[tid]; + x[tid] = x[rev]; + x[rev] = tmp; + } +} + +extern "C" __global__ void ntt_dit_level_batched(uint64_t *data, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t level, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n_half = n >> 1; + if (tid >= n_half) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + + uint64_t half = 1ULL << level; + uint64_t block_size = half << 1; + uint64_t block_idx = tid >> level; + uint64_t k = tid & (half - 1); + + uint64_t i0 = block_idx * block_size + k; + uint64_t i1 = i0 + half; + + uint64_t tw_index = k << (log_n - level - 1); + uint64_t w = tw[tw_index]; + + uint64_t u = x[i0]; + uint64_t v = mul(w, x[i1]); + x[i0] = add(u, v); + x[i1] = sub(u, v); +} + +extern "C" __global__ void ntt_dit_8_levels_batched(uint64_t *data, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t base_step, + uint64_t col_stride) { + __shared__ uint64_t tile[256]; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + + uint32_t n_loc_steps = (uint32_t)min((uint64_t)8, log_n - base_step); + + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + + uint64_t group_size = 1ULL << base_step; + uint64_t n_groups = n >> base_step; + uint64_t low_bits = tid / n_groups; + uint64_t high_bits = tid & (n_groups - 1); + uint64_t row = high_bits * group_size + low_bits; + + tile[threadIdx.x] = x[row]; + __syncthreads(); + + uint32_t remaining_high_bits = (uint32_t)(log_n - base_step - 1); + uint32_t high_mask = (1u << remaining_high_bits) - 1u; + + for (uint32_t loc_step = 0; loc_step < n_loc_steps; ++loc_step) { + if (threadIdx.x < 128) { + uint32_t i = threadIdx.x; + uint32_t half = 1u << loc_step; + uint32_t grp = i >> loc_step; + uint32_t grp_pos = i & (half - 1); + uint32_t idx1 = (grp << (loc_step + 1)) + grp_pos; + uint32_t idx2 = idx1 + half; + + uint32_t gs = (uint32_t)base_step + loc_step; + uint32_t ggp = (blockIdx.x << 7) + i; + ggp = ((ggp & high_mask) << (uint32_t)base_step) + (ggp >> remaining_high_bits); + ggp = ggp & ((1u << gs) - 1u); + uint64_t factor = tw[(uint64_t)ggp * (n >> (gs + 1))]; + + uint64_t u = tile[idx1]; + uint64_t v = mul(tile[idx2], factor); + tile[idx1] = add(u, v); + tile[idx2] = sub(u, v); + } + __syncthreads(); + } + + x[row] = tile[threadIdx.x]; +} + + +/// Batched pointwise multiply: first n elements of each column multiplied by +/// the SHARED weight vector `w` (size n). Used for coset scaling — every +/// column of a table sees the same `g^i / N` weights. +extern "C" __global__ void pointwise_mul_batched(uint64_t *data, + const uint64_t *w, + uint64_t n, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + x[tid] = mul(x[tid], w[tid]); +} + +/// Batched broadcast scalar multiply — one scalar c applied to the first n +/// elements of every column. +extern "C" __global__ void scalar_mul_batched(uint64_t *data, + uint64_t c, + uint64_t n, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + x[tid] = mul(x[tid], c); +} + +/// One DIT butterfly level. Thread `tid` (of n/2 total) owns exactly one +/// butterfly pair (i0, i1 = i0 + half). Twiddle picked from the shared full +/// `tw` table at stride `n / block_size`. Kept for log_n < 8 where shmem +/// fusion is overkill. +extern "C" __global__ void ntt_dit_level(uint64_t *x, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t level) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n_half = n >> 1; + if (tid >= n_half) return; + + uint64_t half = 1ULL << level; // 2^ℓ + uint64_t block_size = half << 1; // 2^{ℓ+1} + uint64_t block_idx = tid >> level; // floor(tid / half) + uint64_t k = tid & (half - 1); // tid mod half + + uint64_t i0 = block_idx * block_size + k; + uint64_t i1 = i0 + half; + + // Stride = n / block_size = n >> (level + 1). + uint64_t tw_index = k << (log_n - level - 1); + uint64_t w = tw[tw_index]; + + uint64_t u = x[i0]; + uint64_t v = mul(w, x[i1]); + x[i0] = add(u, v); + x[i1] = sub(u, v); +} + +/// Up to 8 DIT butterfly levels fused in one kernel using shared memory. +/// +/// Ported from Zisk's `br_ntt_8_steps` (`pil2-stark/src/goldilocks/src/ntt_goldilocks.cu`), +/// simplified to single-column. Each block of 256 threads processes 256 +/// elements in on-chip shared memory, running up to 8 butterfly levels +/// without writing to global memory between them — cuts DRAM traffic by up +/// to 8× vs the per-level kernel. +/// +/// `base_step` selects which 8-level window this launch handles (0, 8, 16…). +/// For levels 0–7 the implicit DIT element layout already places all pair +/// mates inside the same 256-block; for higher base_step we remap the loaded +/// row so pair mates land in consecutive shared-memory slots. +/// +/// Expects bit-reversed input (the caller runs `bit_reverse_permute` once +/// before the first kernel launch). +/// +/// Assumes `n` is a multiple of 256, i.e. `log_n >= 8`. +extern "C" __global__ void ntt_dit_8_levels(uint64_t *x, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t base_step) { + __shared__ uint64_t tile[256]; + + uint32_t n_loc_steps = (uint32_t)min((uint64_t)8, log_n - base_step); + + // tid is the *unpermuted* flat index the block/thread would own. + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + + // Row remap: for base_step > 0, gather elements that pair at levels + // `base_step..base_step+7` so they land consecutively in the block. + uint64_t group_size = 1ULL << base_step; + uint64_t n_groups = n >> base_step; // = n / group_size + uint64_t low_bits = tid / n_groups; + uint64_t high_bits = tid & (n_groups - 1); + uint64_t row = high_bits * group_size + low_bits; + + // Load one element per thread. + tile[threadIdx.x] = x[row]; + __syncthreads(); + + // Each butterfly level uses half the threads (128 butterflies per block). + // The global butterfly index `ggp` is recovered from blockIdx + threadIdx + // and reshaped by the same row-remap to find the right twiddle. + uint32_t remaining_high_bits = (uint32_t)(log_n - base_step - 1); // log2(n_groups / 2) + uint32_t high_mask = (1u << remaining_high_bits) - 1u; + + for (uint32_t loc_step = 0; loc_step < n_loc_steps; ++loc_step) { + if (threadIdx.x < 128) { + uint32_t i = threadIdx.x; + uint32_t half = 1u << loc_step; + uint32_t grp = i >> loc_step; + uint32_t grp_pos = i & (half - 1); + uint32_t idx1 = (grp << (loc_step + 1)) + grp_pos; + uint32_t idx2 = idx1 + half; + + // Global step and butterfly position for twiddle lookup. + uint32_t gs = (uint32_t)base_step + loc_step; + uint32_t ggp = (blockIdx.x << 7) + i; // blockIdx * 128 + i + // Un-remap ggp to find its position in the natural ordering. + ggp = ((ggp & high_mask) << (uint32_t)base_step) + (ggp >> remaining_high_bits); + ggp = ggp & ((1u << gs) - 1u); + uint64_t factor = tw[(uint64_t)ggp * (n >> (gs + 1))]; + + uint64_t u = tile[idx1]; + uint64_t v = mul(tile[idx2], factor); + tile[idx1] = add(u, v); + tile[idx2] = sub(u, v); + } + __syncthreads(); + } + + // Store back to the remapped row. + x[row] = tile[threadIdx.x]; +} diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs new file mode 100644 index 000000000..2c70716a6 --- /dev/null +++ b/crypto/math-cuda/src/device.rs @@ -0,0 +1,251 @@ +//! CUDA device context, stream pool, kernel handles, and twiddle cache. +//! +//! One process-wide backend — lazy-initialised on first use. All kernels live +//! on a single CUDA context; a pool of streams lets rayon-parallel callers +//! overlap H2D / compute / D2H. + +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; + +use cudarc::driver::{CudaContext, CudaFunction, CudaSlice, CudaStream}; +use cudarc::nvrtc::Ptx; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsFFTField; + +use crate::Result; +use crate::ntt::{twiddles_forward, twiddles_inverse}; + +/// Reusable pinned host staging buffer. One per stream; the stream's LDE call +/// holds its buffer's lock across the D2H + memcpy-to-user-Vecs window. +/// +/// Allocated with `cuMemHostAlloc(flags=0)` — portable, non-write-combined, +/// so both DMA writes from device and CPU reads into user Vecs run at full +/// speed. Grows power-of-two; never shrinks. +pub struct PinnedStaging { + ptr: *mut u64, + capacity_elems: usize, +} + +// SAFETY: the raw pointer aliases host memory allocated via cuMemHostAlloc. +// We guard concurrent access with a Mutex; the pointer is valid for the +// lifetime of this struct and is freed on drop. +unsafe impl Send for PinnedStaging {} +unsafe impl Sync for PinnedStaging {} + +impl PinnedStaging { + const fn empty() -> Self { + Self { + ptr: std::ptr::null_mut(), + capacity_elems: 0, + } + } + + pub fn ensure_capacity( + &mut self, + min_elems: usize, + ctx: &CudaContext, + ) -> Result<()> { + if self.capacity_elems >= min_elems { + return Ok(()); + } + // cuMemHostAlloc requires the context to be current on this thread. + ctx.bind_to_thread()?; + // Free old (if any) before allocating the new one. + if !self.ptr.is_null() { + unsafe { + let _ = cudarc::driver::sys::cuMemFreeHost(self.ptr as *mut _); + } + self.ptr = std::ptr::null_mut(); + self.capacity_elems = 0; + } + let new_cap = min_elems.next_power_of_two().max(1 << 20); // at least 8 MB + let bytes = new_cap * std::mem::size_of::(); + let ptr = unsafe { + cudarc::driver::result::malloc_host(bytes, 0 /* flags: non-WC */)? + } as *mut u64; + self.ptr = ptr; + self.capacity_elems = new_cap; + Ok(()) + } + + /// View of the first `len` elements. Caller must hold this `PinnedStaging` + /// locked while using the slice; the slice aliases the internal pointer. + /// + /// # Safety + /// Caller must not outlive the `PinnedStaging` and must not race with + /// concurrent uses. + pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [u64] { + assert!(len <= self.capacity_elems); + if len == 0 { + return &mut []; + } + unsafe { std::slice::from_raw_parts_mut(self.ptr, len) } + } +} + +impl Drop for PinnedStaging { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { + let _ = cudarc::driver::sys::cuMemFreeHost(self.ptr as *mut _); + } + } + } +} + +const ARITH_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/arith.ptx")); +const NTT_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/ntt.ptx")); +/// Number of CUDA streams in the pool. Larger pools let many rayon-parallel +/// callers overlap on the GPU without serializing on stream ownership. The +/// default stream is deliberately excluded because it synchronises with all +/// other streams, defeating the point of the pool. +const STREAM_POOL_SIZE: usize = 32; + +pub struct Backend { + pub ctx: Arc, + streams: Vec>, + /// Single shared pinned staging buffer, grown to the biggest LDE size + /// seen. Concurrent batched LDE calls serialise on it; in exchange the + /// process keeps only ONE gigabyte-sized pinned allocation (per-stream + /// buffers 32×-inflated memory use and multiplied the one-time pinning + /// cost for every first use of a new table size). + pinned_staging: Mutex, + util_stream: Arc, + next: AtomicUsize, + + // arith.ptx + pub vector_add_u64: CudaFunction, + pub gl_add: CudaFunction, + pub gl_sub: CudaFunction, + pub gl_mul: CudaFunction, + pub gl_neg: CudaFunction, + pub ext3_mul: CudaFunction, + pub ext3_add: CudaFunction, + + // ntt.ptx + pub bit_reverse_permute: CudaFunction, + pub ntt_dit_level: CudaFunction, + pub ntt_dit_8_levels: CudaFunction, + pub pointwise_mul: CudaFunction, + pub scalar_mul: CudaFunction, + pub bit_reverse_permute_batched: CudaFunction, + pub ntt_dit_level_batched: CudaFunction, + pub ntt_dit_8_levels_batched: CudaFunction, + pub pointwise_mul_batched: CudaFunction, + pub scalar_mul_batched: CudaFunction, + + // Twiddle caches keyed by log_n. + fwd_twiddles: Mutex>>>>, + inv_twiddles: Mutex>>>>, +} + +impl Backend { + fn init() -> Result { + let ctx = CudaContext::new(0)?; + // cudarc's default per-slice CudaEvent tracking adds two driver calls + // per alloc and serialises under the context lock. We never share + // slices across streams (every call scopes its own buffers and syncs + // before returning), so the tracking is pure overhead. Disable it. + unsafe { ctx.disable_event_tracking() }; + + let arith = ctx.load_module(Ptx::from_src(ARITH_PTX))?; + let ntt = ctx.load_module(Ptx::from_src(NTT_PTX))?; + + let mut streams = Vec::with_capacity(STREAM_POOL_SIZE); + for _ in 0..STREAM_POOL_SIZE { + streams.push(ctx.new_stream()?); + } + let pinned_staging = Mutex::new(PinnedStaging::empty()); + // Separate "utility" stream for twiddle uploads and other bookkeeping; + // not part of the pool that callers rotate through. + let util_stream = ctx.new_stream()?; + + // Goldilocks TWO_ADICITY is 32, so log_n ≤ 32 covers every LDE size + // the prover can produce. Overshoot by one for safety. + let max_log = GoldilocksField::TWO_ADICITY as usize + 1; + + Ok(Self { + vector_add_u64: arith.load_function("vector_add_u64")?, + gl_add: arith.load_function("gl_add_kernel")?, + gl_sub: arith.load_function("gl_sub_kernel")?, + gl_mul: arith.load_function("gl_mul_kernel")?, + gl_neg: arith.load_function("gl_neg_kernel")?, + ext3_mul: arith.load_function("ext3_mul_kernel")?, + ext3_add: arith.load_function("ext3_add_kernel")?, + bit_reverse_permute: ntt.load_function("bit_reverse_permute")?, + ntt_dit_level: ntt.load_function("ntt_dit_level")?, + ntt_dit_8_levels: ntt.load_function("ntt_dit_8_levels")?, + pointwise_mul: ntt.load_function("pointwise_mul")?, + scalar_mul: ntt.load_function("scalar_mul")?, + bit_reverse_permute_batched: ntt.load_function("bit_reverse_permute_batched")?, + ntt_dit_level_batched: ntt.load_function("ntt_dit_level_batched")?, + ntt_dit_8_levels_batched: ntt.load_function("ntt_dit_8_levels_batched")?, + pointwise_mul_batched: ntt.load_function("pointwise_mul_batched")?, + scalar_mul_batched: ntt.load_function("scalar_mul_batched")?, + fwd_twiddles: Mutex::new(vec![None; max_log]), + inv_twiddles: Mutex::new(vec![None; max_log]), + ctx, + streams, + pinned_staging, + util_stream, + next: AtomicUsize::new(0), + }) + } + + /// Round-robin over the stream pool. Concurrent callers get different + /// streams so their kernel launches overlap on the GPU. + pub fn next_stream(&self) -> Arc { + let idx = self.next.fetch_add(1, Ordering::Relaxed) % self.streams.len(); + self.streams[idx].clone() + } + + /// Shared pinned staging buffer. Grows to the largest LDE the process + /// has seen so far. Concurrent callers serialise on the mutex. + pub fn pinned_staging(&self) -> &Mutex { + &self.pinned_staging + } + + pub fn fwd_twiddles_for(&self, log_n: u64) -> Result>> { + self.cached_twiddles(log_n, true) + } + + pub fn inv_twiddles_for(&self, log_n: u64) -> Result>> { + self.cached_twiddles(log_n, false) + } + + fn cached_twiddles(&self, log_n: u64, forward: bool) -> Result>> { + let idx = log_n as usize; + let cache = if forward { + &self.fwd_twiddles + } else { + &self.inv_twiddles + }; + { + let guard = cache.lock().unwrap(); + if let Some(t) = &guard[idx] { + return Ok(t.clone()); + } + } + // Compute on host, upload on the utility stream. Another thread may + // have populated the cache in the meantime; prefer that entry. + let host = if forward { + twiddles_forward(log_n) + } else { + twiddles_inverse(log_n) + }; + let dev = Arc::new(self.util_stream.clone_htod(&host)?); + self.util_stream.synchronize()?; + let mut guard = cache.lock().unwrap(); + if let Some(t) = &guard[idx] { + Ok(t.clone()) + } else { + guard[idx] = Some(dev.clone()); + Ok(dev) + } + } +} + +pub fn backend() -> &'static Backend { + static BACKEND: OnceLock = OnceLock::new(); + BACKEND.get_or_init(|| Backend::init().expect("failed to initialise CUDA backend")) +} diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs new file mode 100644 index 000000000..fd25d1bca --- /dev/null +++ b/crypto/math-cuda/src/lde.rs @@ -0,0 +1,984 @@ +//! Full coset LDE on device. Mirrors `Polynomial::coset_lde_full_expand` in +//! `crypto/math/src/fft/polynomial.rs` algebraically: +//! +//! Input : N evaluations (natural order) of a poly on the standard subgroup, +//! plus coset weights (size N). The weights include the `1/N` iFFT +//! normalisation, matching the `LdeTwiddles::coset_weights` format at +//! `crypto/stark/src/prover.rs:248` — i.e. `weights[i] = g^i / N`. +//! Output : N*blowup_factor evaluations (natural order) on the coset. +//! +//! On-device steps, picks a stream from the shared pool so rayon-parallel +//! callers overlap on the GPU. Twiddles are cached in the backend. + +use std::sync::Arc; + +use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; +use crate::ntt::run_ntt_body; + +/// Handle to a base-field LDE kept live on device after R1 commit. +/// Layout: `m` columns, each `lde_size` u64s, column `c` at byte offset +/// `c * lde_size * 8` within `buf`. Freed when `buf` Arc drops. +#[derive(Clone)] +pub struct GpuLdeBase { + pub buf: Arc>, + pub m: usize, + pub lde_size: usize, +} + +/// Handle to an ext3 LDE kept live on device, de-interleaved into 3 base +/// slabs per column. Column `c` component `k` at u64 offset +/// `(c*3 + k) * lde_size` within `buf`. +#[derive(Clone)] +pub struct GpuLdeExt3 { + pub buf: Arc>, + pub m: usize, + pub lde_size: usize, +} + +pub fn coset_lde_base( + evals: &[u64], + blowup_factor: usize, + weights: &[u64], +) -> Result> { + let n = evals.len(); + assert!(n.is_power_of_two(), "evals length must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match evals"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + if n == 0 { + return Ok(Vec::new()); + } + let lde_size = n * blowup_factor; + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + + // Device buffer of lde_size, zero-padded tail, first N filled by copy. + let mut buf = stream.alloc_zeros::(lde_size)?; + { + let mut head = buf.slice_mut(0..n); + stream.memcpy_htod(evals, &mut head)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + + // === 1. iNTT on first N: bit_reverse + 8-level-fused DIT body === + unsafe { + stream + .launch_builder(&be.bit_reverse_permute) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + // Note: `run_ntt_body` expects a standalone CudaSlice; we pass `buf` and + // the kernel walks the first `n_u64` elements via its own indexing. + run_ntt_body(stream.as_ref(), &mut buf, inv_tw.as_ref(), n_u64, log_n)?; + // Note: the CPU iFFT does not include 1/N — it's folded into `weights`. The + // next pointwise multiply applies both the coset shift and the 1/N factor. + + // === 2. Pointwise multiply first N by coset weights (includes 1/N) === + unsafe { + stream + .launch_builder(&be.pointwise_mul) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + + // === 3. Forward NTT on full buffer === + unsafe { + stream + .launch_builder(&be.bit_reverse_permute) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .launch(LaunchConfig::for_num_elems(lde_size as u32))?; + } + run_ntt_body(stream.as_ref(), &mut buf, fwd_tw.as_ref(), lde_u64, log_lde)?; + + let out = stream.clone_dtoh(&buf)?; + stream.synchronize()?; + Ok(out) +} + +/// Batched coset LDE: processes `m` columns (all the same domain) in a single +/// pipeline on one stream. One H2D per column, then per-level batched kernels +/// that launch with `grid.y = m` so a single launch does the butterflies for +/// every column at that level. +/// +/// Returns one `Vec` per input column, each of length `n * blowup_factor`. +pub fn coset_lde_batch_base( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], +) -> Result>> { + if columns.is_empty() { + return Ok(Vec::new()); + } + let m = columns.len(); + let n = columns[0].len(); + assert!(n.is_power_of_two(), "column length must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match column length"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + for c in columns.iter() { + assert_eq!(c.len(), n, "all columns must be the same size"); + } + + if n == 0 { + return Ok(vec![Vec::new(); m]); + } + let lde_size = n * blowup_factor; + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let debug_phases = std::env::var("MATH_CUDA_PHASE_TIMING").is_ok(); + let t_start = if debug_phases { Some(std::time::Instant::now()) } else { None }; + let phase = |label: &str, prev: &mut Option| { + if let Some(p) = prev.as_ref() { + let now = std::time::Instant::now(); + eprintln!(" [{:>6.2} ms] {}", (now - *p).as_secs_f64() * 1e3, label); + *prev = Some(now); + } + }; + let mut last = t_start; + + // Pinned staging. Lock and grow to max(m*n for upload, m*lde_size for + // download). Holding the guard across the whole call serialises concurrent + // batched calls that happened to hash to the same stream slot, but that's + // exactly what we want — one stream can only do one sequence at a time. + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + // SAFETY: staging is locked, the slice alias ends before we unlock. + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + if debug_phases { phase("staging lock + grow", &mut last); } + + // Pack columns into first m*n slots of the pinned buffer. Parallel: pinned + // writes are DRAM-bandwidth bound, saturates at ~8 cores on modern + // hardware, so rayon shaves 20+ ms at prover scale. + use rayon::prelude::*; + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of + // `pinned`, and the outer `staging` lock guarantees no other call is + // using the buffer concurrently. + let dst = unsafe { + std::slice::from_raw_parts_mut( + (pinned_base_ptr as *mut u64).add(c * n), + n, + ) + }; + dst.copy_from_slice(col); + }); + if debug_phases { phase("host pack (pinned, rayon)", &mut last); } + + // Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size) + // tail of each column is already the zero-pad the CPU path does. + let mut buf = stream.alloc_zeros::(m * lde_size)?; + if debug_phases { stream.synchronize()?; phase("alloc_zeros", &mut last); } + // One memcpy per column from the pinned buffer into the strided slots. + // The pinned source hits PCIe line-rate. + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + if debug_phases { stream.synchronize()?; phase("H2D cols (pinned)", &mut last); } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + if debug_phases { stream.synchronize()?; phase("twiddles + weights", &mut last); } + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // === 1. Bit-reverse first N of every column === + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + if debug_phases { stream.synchronize()?; phase("bit_reverse N", &mut last); } + // === 2. iNTT body over all columns === + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + if debug_phases { stream.synchronize()?; phase("iNTT body", &mut last); } + + // === 3. Pointwise multiply by coset weights (includes 1/N) === + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + // === 4. Bit-reverse full LDE of every column === + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + if debug_phases { stream.synchronize()?; phase("pointwise + bit_reverse LDE", &mut last); } + // === 5. Forward NTT on full LDE of every column === + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + if debug_phases { stream.synchronize()?; phase("forward NTT body", &mut last); } + + // Single big D2H into the reusable pinned staging buffer — pinned, one + // call to the driver, saturates PCIe. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + stream.synchronize()?; + if debug_phases { phase("D2H (one shot into pinned)", &mut last); } + + // Split pinned → per-column Vecs. The first write to each virgin + // Vec page-faults, which can dominate total time (~75 ms for 128 MB). + // Parallelise so the fault cost spreads across CPU cores. + use rayon::prelude::*; + let pinned_ptr = pinned.as_ptr() as usize; + let out: Vec> = (0..m) + .into_par_iter() + .map(|c| { + let mut v = Vec::::with_capacity(lde_size); + unsafe { v.set_len(lde_size) }; + let src = unsafe { + std::slice::from_raw_parts( + (pinned_ptr as *const u64).add(c * lde_size), + lde_size, + ) + }; + v.copy_from_slice(src); + v + }) + .collect(); + if debug_phases { phase("copy out (rayon pinned → Vecs)", &mut last); } + drop(staging); + Ok(out) +} + +/// Like `coset_lde_batch_base` but writes directly into caller-provided +/// output slices instead of allocating fresh `Vec`s. Each output slice +/// must already have length `n * blowup_factor`. Saves ~50–100 ms of pageable +/// allocator work + page faults at prover scale because the caller's Vecs +/// have been sized once and are reused across calls. +pub fn coset_lde_batch_base_into( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + if columns.is_empty() { + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m, "outputs must match columns count"); + let n = columns[0].len(); + assert!(n.is_power_of_two(), "column length must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match column length"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + for c in columns.iter() { + assert_eq!(c.len(), n, "all columns must be the same size"); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), lde_size, "each output must be lde_size"); + } + if n == 0 { + return Ok(()); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + + for (c, col) in columns.iter().enumerate() { + pinned[c * n..c * n + n].copy_from_slice(col); + } + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT bit-reverse + body, pointwise mul, forward bit-reverse + body. + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + stream.synchronize()?; + + // Parallel copy pinned → caller outputs. Caller's Vecs may still fault + // on first write; we spread that cost across rayon cores. + #[allow(unused_imports)] + use rayon::prelude::*; + let pinned_ptr = pinned.as_ptr() as usize; + outputs + .par_iter_mut() + .enumerate() + .for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (pinned_ptr as *const u64).add(c * lde_size), + lde_size, + ) + }; + dst.copy_from_slice(src); + }); + drop(staging); + Ok(()) +} + +/// Batched ext3 polynomial → coset evaluation. +/// +/// Input: M ext3 columns of `n` coefficients each (interleaved, 3n u64). +/// Output: M ext3 columns of `n * blowup_factor` evaluations each at the +/// offset-coset. +/// +/// Skips the iFFT stage of [`coset_lde_batch_ext3_into`] (input is +/// coefficients, not evaluations). Weights encode the coset shift: +/// `weights[k] = offset^k` (NO 1/N because iFFT normalisation doesn't apply). +/// +/// Used by the stark prover to GPU-accelerate +/// `evaluate_polynomial_on_lde_domain` calls inside the +/// `number_of_parts > 2` branch of the composition-polynomial LDE. +pub fn evaluate_poly_coset_batch_ext3_into( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + false, + ) + .map(|_| ()) +} + +/// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- +/// interleaved LDE device buffer as a `GpuLdeExt3` handle. Lets R2 commit +/// and R4 DEEP composition read the composition-parts LDE without +/// re-H2D'ing. +pub fn evaluate_poly_coset_batch_ext3_into_keep( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result { + let opt = evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) +} + +fn evaluate_poly_coset_batch_ext3_into_inner( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + keep_device_buf: bool, +) -> Result> { + if coefs.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = coefs.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in coefs.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + if n == 0 { + return Ok(None); + } + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + coefs.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + // Apply coset scaling: x[k] *= weights[k] for k in 0..n (no iFFT first). + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + // Bit-reverse full lde_size slab, then forward DIT NTT. + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + stream.synchronize()?; + + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + drop(staging); + if keep_device_buf { + Ok(Some(GpuLdeExt3 { + buf: std::sync::Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Batched coset LDE for Goldilocks **cubic extension** columns. +/// +/// A degree-3 extension element is `(a, b, c)` in memory (three contiguous +/// u64s). The NTT butterfly multiplies `v = (a, b, c)` by a base-field +/// twiddle `t`: `t * v = (t*a, t*b, t*c)`. Addition is componentwise. So an +/// NTT over M ext3 columns is algebraically equivalent to **3M parallel +/// base-field NTTs** sharing the same twiddles and coset weights. We +/// exploit this to reuse the base-field kernels with no modification: +/// +/// 1. Host pack de-interleaves each ext3 column into 3 consecutive +/// base-field slabs inside the pinned staging buffer (slab 0 has all the +/// a-components, slab 1 all the b's, slab 2 all the c's — 3M base slabs +/// in total). +/// 2. Existing `bit_reverse_permute_batched` / `ntt_dit_*_batched` / +/// `pointwise_mul_batched` run over those 3M base slabs on device. +/// 3. D2H, then re-interleave 3 slabs per output ext3 column. +/// +/// Input/output layout: each slice is 3*n or 3*n*blowup u64s, packed as +/// `[a0, b0, c0, a1, b1, c1, ...]` — the natural `[FieldElement]` +/// memory representation. +pub fn coset_lde_batch_ext3_into( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + if columns.is_empty() { + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m, "outputs must match columns count"); + assert!(n.is_power_of_two(), "n must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match n"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n, "each ext3 column must be 3*n u64s"); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size, "each output must be 3*lde_size u64s"); + } + if n == 0 { + return Ok(()); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + // 3 base slabs per ext3 column; slab index `c*3 + k` holds component `k`. + let mb = 3 * m; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + // Pack: for each ext3 column, write 3 base slabs into pinned. The slab + // for column c, component k lives at `pinned[(c*3 + k)*n .. (c*3+k)*n + n]`. + // We de-interleave from the interleaved `[a, b, c, a, b, c, ...]` input. + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: disjoint regions per c; staging lock held. + let slab_a = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), + n, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), + n, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), + n, + ) + }; + for i in 0..n { + slab_a[i] = col[i * 3 + 0]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + // Allocate + zero-pad device buffer holding 3M slabs of `lde_size`. + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + // H2D: slab by slab into the first N slots of each `lde_size`-slab. + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + // === Butterflies: identical to the base-field batched path, but with + // grid.y = 3M instead of M. === + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + stream.synchronize()?; + + // Unpack: for each output column, re-interleave 3 slabs back into the + // ext3-per-element layout. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 0) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3 + 0] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + drop(staging); + Ok(()) +} + +/// Run the DIT butterfly body of a bit-reversed-input NTT over `m` batched +/// columns in one device buffer. Same fusion strategy as `run_ntt_body`: +/// first 8 levels shmem-fused (coalesced), subsequent levels one kernel each. +fn run_batched_ntt_body( + stream: &cudarc::driver::CudaStream, + x_dev: &mut cudarc::driver::CudaSlice, + tw_dev: &cudarc::driver::CudaSlice, + n: u64, + log_n: u64, + col_stride: u64, + m: u32, +) -> Result<()> { + let be = backend(); + let fused = core::cmp::min(log_n, 8); + if fused >= 8 { + let grid_x = (n / 256) as u32; + let cfg = LaunchConfig { + grid_dim: (grid_x, m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + let base_step = 0u64; + unsafe { + stream + .launch_builder(&be.ntt_dit_8_levels_batched) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&base_step) + .arg(&col_stride) + .launch(cfg)?; + } + } else { + let grid_x = ((n / 2) as u32).div_ceil(256).max(1); + let cfg = LaunchConfig { + grid_dim: (grid_x, m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + for level in 0..fused { + unsafe { + stream + .launch_builder(&be.ntt_dit_level_batched) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .arg(&col_stride) + .launch(cfg)?; + } + } + } + + let grid_x = ((n / 2) as u32).div_ceil(256).max(1); + let cfg = LaunchConfig { + grid_dim: (grid_x, m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + for level in fused..log_n { + unsafe { + stream + .launch_builder(&be.ntt_dit_level_batched) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .arg(&col_stride) + .launch(cfg)?; + } + } + Ok(()) +} + diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs new file mode 100644 index 000000000..821f5bd3a --- /dev/null +++ b/crypto/math-cuda/src/lib.rs @@ -0,0 +1,152 @@ +//! GPU backend for the lambda-vm STARK prover. +//! +//! Primary entry point: [`lde::coset_lde_base`]. Everything else (`ntt`, +//! element-wise arith) is either internal to the LDE pipeline or used by the +//! parity test suite. + +pub mod device; +pub mod lde; +pub mod ntt; + +use cudarc::driver::{LaunchConfig, PushKernelArg}; + +use crate::device::{Backend, backend}; + +pub type Result = std::result::Result; + +/// Toolchain sanity: plain wrapping u64 vector add. Not a field op. +pub fn vector_add_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.vector_add_u64) +} + +/// Goldilocks field add on device, element-wise. Inputs may be non-canonical. +pub fn gl_add_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.gl_add) +} + +pub fn gl_sub_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.gl_sub) +} + +pub fn gl_mul_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.gl_mul) +} + +pub fn gl_neg_u64(a: &[u64]) -> Result> { + let n = a.len(); + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let a_dev = stream.clone_htod(a)?; + let mut c_dev = stream.alloc_zeros::(n)?; + + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.gl_neg) + .arg(&a_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Element-wise ext3 multiply. `a` and `b` are 3n u64s (interleaved +/// [a0,a1,a2,b0,b1,b2,...]). Test helper for the `ext3.cuh` header. +pub fn ext3_mul_u64(a: &[u64], b: &[u64]) -> Result> { + assert_eq!(a.len(), b.len()); + assert_eq!(a.len() % 3, 0); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(3 * n)?; + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.ext3_mul) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Element-wise ext3 add. +pub fn ext3_add_u64(a: &[u64], b: &[u64]) -> Result> { + assert_eq!(a.len(), b.len()); + assert_eq!(a.len() % 3, 0); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(3 * n)?; + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.ext3_add) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + +fn launch_binary_u64(a: &[u64], b: &[u64], pick: F) -> Result> +where + F: for<'a> Fn(&'a Backend) -> &'a cudarc::driver::CudaFunction, +{ + assert_eq!(a.len(), b.len(), "length mismatch"); + let n = a.len(); + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(n)?; + + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(pick(be)) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/src/ntt.rs b/crypto/math-cuda/src/ntt.rs new file mode 100644 index 000000000..0ebb015ea --- /dev/null +++ b/crypto/math-cuda/src/ntt.rs @@ -0,0 +1,211 @@ +//! Forward and inverse NTT over Goldilocks base field. Matches the algebraic +//! contract of `math::polynomial::Polynomial::evaluate_fft` / +//! `interpolate_fft`: +//! input = n elements in natural order +//! output = n elements in natural order. +//! +//! Parity is checked by `tests/ntt.rs` against the CPU implementation. + +use cudarc::driver::{LaunchConfig, PushKernelArg}; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsFFTField, IsField}; + +use crate::Result; +use crate::device::backend; + +/// Host-side twiddle table: `[ω^0, ω^1, …, ω^{n/2-1}]` where ω is the +/// primitive n-th root of unity. Exposed for `device::Backend::cached_twiddles` +/// and for direct use in tests / benches. +pub fn twiddles_forward(log_n: u64) -> Vec { + let omega = *GoldilocksField::get_primitive_root_of_unity(log_n) + .expect("primitive root") + .value(); + powers_of(omega, 1usize << (log_n - 1)) +} + +/// Inverse twiddle table: `[ω^{-i}]` for i in [0, n/2). +pub fn twiddles_inverse(log_n: u64) -> Vec { + let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("primitive root"); + let omega_inv = FieldElement::::inv(&omega).expect("inverse"); + powers_of(*omega_inv.value(), 1usize << (log_n - 1)) +} + +fn powers_of(base: u64, count: usize) -> Vec { + let mut out = Vec::with_capacity(count); + let mut w = 1u64; + for _ in 0..count { + out.push(w); + w = GoldilocksField::mul(&w, &base); + } + out +} + +/// Forward NTT on a slice of `n = 2^log_n` Goldilocks coefficients. Takes +/// natural-order input and returns natural-order evaluations. +pub fn forward(coeffs: &[u64]) -> Result> { + ntt_inplace(coeffs, /*forward=*/ true) +} + +/// Inverse NTT on a slice of `n = 2^log_n` Goldilocks evaluations. Takes +/// natural-order evaluations and returns natural-order coefficients. Includes +/// the 1/n scaling. +pub fn inverse(evals: &[u64]) -> Result> { + ntt_inplace(evals, /*forward=*/ false) +} + +fn ntt_inplace(input: &[u64], forward: bool) -> Result> { + let n = input.len(); + assert!(n.is_power_of_two(), "ntt length must be a power of two"); + if n <= 1 { + return Ok(input.to_vec()); + } + let log_n = n.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + + let mut x_dev = stream.clone_htod(input)?; + let tw_dev = if forward { + be.fwd_twiddles_for(log_n)? + } else { + be.inv_twiddles_for(log_n)? + }; + + let n_u64 = n as u64; + + // 1. Bit-reverse: natural → bit-reversed. + unsafe { + stream + .launch_builder(&be.bit_reverse_permute) + .arg(&mut x_dev) + .arg(&n_u64) + .arg(&log_n) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + + // 2. DIT butterfly levels. For log_n >= 8 we fuse 8 levels per kernel via + // the shmem kernel; for very small sizes (< 256 elements) we stick with + // the per-level kernel because the shmem block dimensions assume n ≥ 256. + run_ntt_body( + stream.as_ref(), + &mut x_dev, + tw_dev.as_ref(), + n_u64, + log_n, + )?; + + // 3. For iNTT, multiply by 1/n. + if !forward { + let n_fe = FieldElement::::from(n as u64); + let inv_n = *n_fe.inv().expect("n is non-zero").value(); + unsafe { + stream + .launch_builder(&be.scalar_mul) + .arg(&mut x_dev) + .arg(&inv_n) + .arg(&n_u64) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + } + + let out = stream.clone_dtoh(&x_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Run the butterfly body of a bit-reversed-input DIT NTT. Split out so the +/// LDE orchestrator can reuse it on the same device buffer. +pub(crate) fn run_ntt_body( + stream: &cudarc::driver::CudaStream, + x_dev: &mut cudarc::driver::CudaSlice, + tw_dev: &cudarc::driver::CudaSlice, + n: u64, + log_n: u64, +) -> Result<()> { + let be = backend(); + // Levels 0..min(log_n, 8): one shmem-fused launch. Loads are fully + // coalesced (base_step=0 → `row = tid`) and 8 butterfly rounds stay on + // chip. This is the big DRAM-bandwidth win. + let fused = core::cmp::min(log_n, 8); + if fused >= 8 { + let grid_x = (n / 256) as u32; + let cfg = LaunchConfig { + grid_dim: (grid_x, 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + let base_step = 0u64; + unsafe { + stream + .launch_builder(&be.ntt_dit_8_levels) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&base_step) + .launch(cfg)?; + } + } else { + // Sub-256-element NTT. Use per-level. + let half_cfg = LaunchConfig::for_num_elems((n / 2) as u32); + for level in 0..fused { + unsafe { + stream + .launch_builder(&be.ntt_dit_level) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .launch(half_cfg)?; + } + } + } + + // Levels 8..log_n: per-level kernels. Loads are fully coalesced in the + // per-level path; switching to fused-with-row-remap at base_step>0 tanks + // DRAM throughput enough to wipe out the launch savings. + let half_cfg = LaunchConfig::for_num_elems((n / 2) as u32); + for level in fused..log_n { + unsafe { + stream + .launch_builder(&be.ntt_dit_level) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .launch(half_cfg)?; + } + } + Ok(()) +} + +/// Pointwise multiply: `x[i] *= w[i]`. +pub fn pointwise_mul(x: &[u64], w: &[u64]) -> Result> { + assert_eq!(x.len(), w.len()); + let n = x.len(); + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let mut x_dev = stream.clone_htod(x)?; + let w_dev = stream.clone_htod(w)?; + + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.pointwise_mul) + .arg(&mut x_dev) + .arg(&w_dev) + .arg(&n_u64) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + + let out = stream.clone_dtoh(&x_dev)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/tests/bench_quick.rs b/crypto/math-cuda/tests/bench_quick.rs new file mode 100644 index 000000000..561331b74 --- /dev/null +++ b/crypto/math-cuda/tests/bench_quick.rs @@ -0,0 +1,356 @@ +//! Informal timing comparison for single-column and multi-column LDE. +//! Ignored by default; run with `cargo test ... -- --ignored --nocapture`. + +use std::time::Instant; + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsField; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::prelude::*; + +type Fp = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *FieldElement::::from(n as u64).inv().unwrap().value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_2_to_18_blowup_4() { + let log_n = 18; + let blowup = 4; + let n = 1usize << log_n; + let lde = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(1); + let input: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + let weights = coset_weights(n, 7); + + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new(lde.trailing_zeros() as u64).unwrap(); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + + const TRIALS: u32 = 10; + + let t0 = Instant::now(); + for _ in 0..TRIALS { + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + } + let gpu_ns = t0.elapsed().as_nanos() / TRIALS as u128; + + let t0 = Instant::now(); + for _ in 0..TRIALS { + let mut buf: Vec = input.iter().map(|&x| Fp::from_raw(x)).collect(); + Polynomial::coset_lde_full_expand::( + &mut buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + std::hint::black_box(&buf); + } + let cpu_ns = t0.elapsed().as_nanos() / TRIALS as u128; + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "single-column LDE 2^{log_n} blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x", + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_2_to_16_blowup_4() { + let log_n = 16; + let blowup = 4; + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(2); + let input: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + let weights = coset_weights(n, 7); + + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + + const TRIALS: u32 = 20; + + let t0 = Instant::now(); + for _ in 0..TRIALS { + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + } + let gpu_ns = t0.elapsed().as_nanos() / TRIALS as u128; + println!("single-column LDE 2^{log_n} blowup={blowup}: gpu={gpu_ns}ns"); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_multi_column_parallel() { + // Simulates the prover's Phase A: many columns processed via rayon. + // log_n = 16 keeps memory footprint manageable while still stressing streams. + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let lde = n * blowup; + let num_cols = 64; + + // Warm up. + let _ = math_cuda::lde::coset_lde_base( + &vec![0u64; n], + blowup, + &coset_weights(n, 7), + ) + .unwrap(); + + // Build input data. + let mut rng = ChaCha8Rng::seed_from_u64(11); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new(lde.trailing_zeros() as u64).unwrap(); + + // GPU: rayon parallel across columns, each column picks a stream. + let t0 = Instant::now(); + let _gpu_results: Vec> = columns + .par_iter() + .map(|col| math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap()) + .collect(); + let gpu_ns = t0.elapsed().as_nanos(); + + // CPU: same rayon parallel pattern as the prover's `expand_columns_to_lde`. + let mut cpu_bufs: Vec> = columns + .iter() + .map(|c| c.iter().map(|&x| Fp::from_raw(x)).collect()) + .collect(); + let t0 = Instant::now(); + cpu_bufs.par_iter_mut().for_each(|buf| { + Polynomial::coset_lde_full_expand::( + buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + }); + let cpu_ns = t0.elapsed().as_nanos(); + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "{num_cols}-column LDE 2^{log_n} blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x (cores={})", + rayon::current_num_threads(), + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_batched_prover_scale() { + // Realistic large-table shape from the 1M-fib prover: ~1M rows, blowup 4, + // a few dozen columns. This is what actually runs in expand_columns_to_lde. + let log_n = 20u32; // 1M rows + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 20; + + let mut rng = ChaCha8Rng::seed_from_u64(31); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new( + (n * blowup).trailing_zeros() as u64, + ) + .unwrap(); + + let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + for _ in 0..8 { + let _ = + math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + } + + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let mut gpu_samples = Vec::with_capacity(10); + for _ in 0..10 { + let t0 = Instant::now(); + let _ = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + gpu_samples.push(t0.elapsed().as_nanos()); + } + gpu_samples.sort(); + let gpu_ns = gpu_samples[gpu_samples.len() / 2]; // median + + let mut cpu_samples = Vec::with_capacity(10); + for _ in 0..10 { + let mut cpu_bufs: Vec> = columns + .iter() + .map(|c| c.iter().map(|&x| Fp::from_raw(x)).collect()) + .collect(); + let t0 = Instant::now(); + cpu_bufs.par_iter_mut().for_each(|buf| { + Polynomial::coset_lde_full_expand::( + buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + }); + cpu_samples.push(t0.elapsed().as_nanos()); + } + cpu_samples.sort(); + let cpu_ns = cpu_samples[cpu_samples.len() / 2]; // median + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "prover-scale batched {num_cols} cols, log_n={log_n}, blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x (median of 10)", + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_batched_vs_rayon_cpu() { + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 64; + + let mut rng = ChaCha8Rng::seed_from_u64(21); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + + // Warm up every stream slot so subsequent iterations don't pay the + // one-time pinned staging alloc cost. + let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + for _ in 0..64 { + let _ = + math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + } + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new( + (n * blowup).trailing_zeros() as u64, + ) + .unwrap(); + + // GPU batched — first run may include lazy device init; do a few to stabilise. + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let mut gpu_ns = u128::MAX; + for _ in 0..5 { + let t0 = Instant::now(); + let _ = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + gpu_ns = gpu_ns.min(t0.elapsed().as_nanos()); + } + + // CPU rayon (same pattern as prover). + let mut cpu_bufs: Vec> = columns + .iter() + .map(|c| c.iter().map(|&x| Fp::from_raw(x)).collect()) + .collect(); + let t0 = Instant::now(); + cpu_bufs.par_iter_mut().for_each(|buf| { + Polynomial::coset_lde_full_expand::( + buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + }); + let cpu_ns = t0.elapsed().as_nanos(); + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "batched {num_cols} cols, log_n={log_n}, blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x (cores={})", + rayon::current_num_threads(), + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_multi_column_serialized_gpu() { + use std::sync::Mutex; + + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 64; + + let _ = math_cuda::lde::coset_lde_base( + &vec![0u64; n], + blowup, + &coset_weights(n, 7), + ) + .unwrap(); + + let mut rng = ChaCha8Rng::seed_from_u64(13); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + + // Single global Mutex so only one thread at a time calls GPU. + let gpu_lock = Mutex::new(()); + let t0 = Instant::now(); + let _: Vec> = columns + .par_iter() + .map(|col| { + let _guard = gpu_lock.lock().unwrap(); + math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap() + }) + .collect(); + let gpu_ns = t0.elapsed().as_nanos(); + println!("GPU mutex-serialised from rayon: {gpu_ns}ns for {num_cols} cols"); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_multi_column_gpu_limited_threads() { + // Same as multi_column_parallel but forces rayon to use only 8 threads + // (matching the GPU stream pool rough capacity). Tests whether oversubscribed + // rayon + many streams is the bottleneck. + let gpu_pool = rayon::ThreadPoolBuilder::new() + .num_threads(8) + .build() + .unwrap(); + + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 64; + + let _ = math_cuda::lde::coset_lde_base( + &vec![0u64; n], + blowup, + &coset_weights(n, 7), + ) + .unwrap(); + + let mut rng = ChaCha8Rng::seed_from_u64(12); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + + let t0 = Instant::now(); + let _gpu_results: Vec> = gpu_pool.install(|| { + columns + .par_iter() + .map(|col| math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap()) + .collect() + }); + let gpu_ns = t0.elapsed().as_nanos(); + + let t0 = Instant::now(); + let _serial_gpu_results: Vec> = columns + .iter() + .map(|col| math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap()) + .collect(); + let gpu_serial_ns = t0.elapsed().as_nanos(); + + println!( + "GPU-only 8-thread: gpu-parallel={gpu_ns}ns gpu-serial={gpu_serial_ns}ns speedup={:.2}x", + gpu_serial_ns as f64 / gpu_ns as f64, + ); +} diff --git a/crypto/math-cuda/tests/evaluate_coset_ext3.rs b/crypto/math-cuda/tests/evaluate_coset_ext3.rs new file mode 100644 index 000000000..a79195291 --- /dev/null +++ b/crypto/math-cuda/tests/evaluate_coset_ext3.rs @@ -0,0 +1,143 @@ +//! Parity test for `evaluate_poly_coset_batch_ext3_into`. +//! +//! Reference: `math::polynomial::Polynomial::evaluate_offset_fft` on an ext3 +//! polynomial, then canonicalise. The GPU path should produce the same +//! evaluations on the offset-coset at `n * blowup` points. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn offset_weights(n: usize, offset: u64) -> Vec { + let mut w = Vec::with_capacity(n); + let mut cur = 1u64; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &offset); + } + w +} + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn u64s_to_ext3(raw: &[u64]) -> Vec { + let mut out = Vec::with_capacity(raw.len() / 3); + for i in 0..raw.len() / 3 { + out.push(Fp3::new([ + Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3 + 1]), + Fp::from_raw(raw[i * 3 + 2]), + ])); + } + out +} + +fn canon_fp3(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +fn assert_evaluate_coset(log_n: u64, blowup: usize, m: usize, offset: u64, seed: u64) { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + + // M ext3 polynomials, each of degree < n. + let polys: Vec> = (0..m) + .map(|_| (0..n).map(|_| rand_ext3(&mut rng)).collect()) + .collect(); + + let weights = offset_weights(n, offset); + + // CPU reference: evaluate each polynomial at `offset`-coset of size lde_size. + let offset_fp = Fp::from_raw(offset); + let cpu: Vec> = polys + .iter() + .map(|coefs| { + let p = Polynomial::new(coefs); + Polynomial::evaluate_offset_fft::( + &p, + blowup, + Some(n), + &offset_fp, + ) + .unwrap() + }) + .collect(); + + // GPU: flatten each poly to 3n u64s, pre-allocate 3*lde_size u64 outputs. + let flat_inputs: Vec> = polys.iter().map(|p| ext3_to_u64s(p)).collect(); + let input_slices: Vec<&[u64]> = flat_inputs.iter().map(|v| v.as_slice()).collect(); + let mut flat_outputs: Vec> = (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); + { + let mut out_slices: Vec<&mut [u64]> = + flat_outputs.iter_mut().map(|v| v.as_mut_slice()).collect(); + math_cuda::lde::evaluate_poly_coset_batch_ext3_into( + &input_slices, + n, + blowup, + &weights, + &mut out_slices, + ) + .unwrap(); + } + + for c in 0..m { + let gpu: Vec = u64s_to_ext3(&flat_outputs[c]); + assert_eq!(gpu.len(), cpu[c].len(), "length mismatch"); + for i in 0..gpu.len() { + let g = canon_fp3(&gpu[i]); + let cc = canon_fp3(&cpu[c][i]); + assert_eq!(g, cc, "eval mismatch col={c} row={i} log_n={log_n} blowup={blowup}"); + } + } +} + +#[test] +fn ext3_evaluate_coset_small() { + for &m in &[1usize, 4] { + for log_n in 4..=10 { + for &blowup in &[2usize, 4] { + assert_evaluate_coset(log_n, blowup, m, 7, 100 + log_n * 10 + m as u64); + } + } + } +} + +#[test] +fn ext3_evaluate_coset_medium() { + for log_n in 11..=14 { + assert_evaluate_coset(log_n, 4, 2, 7, 200 + log_n); + } +} + +#[test] +fn ext3_evaluate_coset_large_one_column() { + assert_evaluate_coset(16, 4, 1, 7, 0xCAFE); +} diff --git a/crypto/math-cuda/tests/ext3.rs b/crypto/math-cuda/tests/ext3.rs new file mode 100644 index 000000000..c9aabbc27 --- /dev/null +++ b/crypto/math-cuda/tests/ext3.rs @@ -0,0 +1,87 @@ +//! Parity: GPU ext3 arithmetic must agree (canonically) with CPU +//! `Degree3GoldilocksExtensionField` on random ext3 inputs. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +const N: usize = 10_000; + +fn random_fp3s(seed: u64, count: usize) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..count) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() +} + +fn to_u64s(col: &[Fp3]) -> Vec { + let mut v = Vec::with_capacity(col.len() * 3); + for e in col { + v.push(*e.value()[0].value()); + v.push(*e.value()[1].value()); + v.push(*e.value()[2].value()); + } + v +} + +fn canon_triplet(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +fn canon_triplet_raw(t: &[u64]) -> [u64; 3] { + [ + GoldilocksField::canonical(&t[0]), + GoldilocksField::canonical(&t[1]), + GoldilocksField::canonical(&t[2]), + ] +} + +#[test] +fn ext3_mul_matches_cpu() { + let a = random_fp3s(11, N); + let b = random_fp3s(22, N); + let a_raw = to_u64s(&a); + let b_raw = to_u64s(&b); + let gpu = math_cuda::ext3_mul_u64(&a_raw, &b_raw).unwrap(); + assert_eq!(gpu.len(), 3 * N); + for i in 0..N { + use math::field::traits::IsField; + let cpu = Degree3GoldilocksExtensionField::mul(a[i].value(), b[i].value()); + let cpu_fp3 = Fp3::new(cpu); + let g = canon_triplet_raw(&gpu[i * 3..(i + 1) * 3]); + let c = canon_triplet(&cpu_fp3); + assert_eq!(g, c, "ext3 mul mismatch at {i}"); + } +} + +#[test] +fn ext3_add_matches_cpu() { + let a = random_fp3s(33, N); + let b = random_fp3s(44, N); + let a_raw = to_u64s(&a); + let b_raw = to_u64s(&b); + let gpu = math_cuda::ext3_add_u64(&a_raw, &b_raw).unwrap(); + for i in 0..N { + let cpu = Degree3GoldilocksExtensionField::add(a[i].value(), b[i].value()); + let cpu_fp3 = Fp3::new(cpu); + let g = canon_triplet_raw(&gpu[i * 3..(i + 1) * 3]); + let c = canon_triplet(&cpu_fp3); + assert_eq!(g, c, "ext3 add mismatch at {i}"); + } +} diff --git a/crypto/math-cuda/tests/goldilocks.rs b/crypto/math-cuda/tests/goldilocks.rs new file mode 100644 index 000000000..317ffb0f8 --- /dev/null +++ b/crypto/math-cuda/tests/goldilocks.rs @@ -0,0 +1,127 @@ +//! GPU must produce bit-identical u64 outputs to `GoldilocksField` for every op. +//! Non-canonical inputs are expected (CPU operates on the full [0, 2^64) range), +//! so the test inputs include values above the prime. + +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +const N: usize = 10_000; + +fn sample_inputs(seed: u64) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..N).map(|_| rng.r#gen::()).collect() +} + +fn assert_raw_eq(op: &str, expected: &[u64], actual: &[u64]) { + assert_eq!(expected.len(), actual.len()); + for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() { + if e != a { + panic!( + "{op} mismatch at {i}: cpu={e:#018x} (canon {:#018x}), gpu={a:#018x} (canon {:#018x})", + GoldilocksField::canonical(e), + GoldilocksField::canonical(a), + ); + } + } +} + +#[test] +fn gpu_vector_add_u64_matches_wrapping() { + let a = sample_inputs(0xC0FFEE); + let b = sample_inputs(0xDEADBEEF); + let expected: Vec = a.iter().zip(&b).map(|(x, y)| x.wrapping_add(*y)).collect(); + let actual = math_cuda::vector_add_u64(&a, &b).expect("GPU vector_add_u64"); + assert_raw_eq("vector_add (wrapping)", &expected, &actual); +} + +#[test] +fn gpu_gl_add_matches_cpu() { + let a = sample_inputs(1); + let b = sample_inputs(2); + let expected: Vec = a + .iter() + .zip(&b) + .map(|(x, y)| GoldilocksField::add(x, y)) + .collect(); + let actual = math_cuda::gl_add_u64(&a, &b).expect("GPU gl_add"); + assert_raw_eq("gl_add", &expected, &actual); +} + +#[test] +fn gpu_gl_sub_matches_cpu() { + let a = sample_inputs(3); + let b = sample_inputs(4); + let expected: Vec = a + .iter() + .zip(&b) + .map(|(x, y)| GoldilocksField::sub(x, y)) + .collect(); + let actual = math_cuda::gl_sub_u64(&a, &b).expect("GPU gl_sub"); + assert_raw_eq("gl_sub", &expected, &actual); +} + +#[test] +fn gpu_gl_mul_matches_cpu() { + let a = sample_inputs(5); + let b = sample_inputs(6); + let expected: Vec = a + .iter() + .zip(&b) + .map(|(x, y)| GoldilocksField::mul(x, y)) + .collect(); + let actual = math_cuda::gl_mul_u64(&a, &b).expect("GPU gl_mul"); + assert_raw_eq("gl_mul", &expected, &actual); +} + +#[test] +fn gpu_gl_neg_matches_cpu() { + let a = sample_inputs(7); + let expected: Vec = a.iter().map(|x| GoldilocksField::neg(x)).collect(); + let actual = math_cuda::gl_neg_u64(&a).expect("GPU gl_neg"); + assert_raw_eq("gl_neg", &expected, &actual); +} + +/// Edge cases the random generator is unlikely to hit: 0, 1, p-1, p, p+1, 2p-1, +/// u64::MAX, EPSILON boundary values. Covers double-overflow / double-underflow. +#[test] +fn gpu_goldilocks_edge_cases() { + const P: u64 = 0xFFFF_FFFF_0000_0001; + const EPS: u64 = 0xFFFF_FFFF; + let edge: [u64; 11] = [ + 0, + 1, + P - 1, + P, + P + 1, + 2u64.wrapping_mul(P).wrapping_sub(1), + u64::MAX, + u64::MAX - EPS, + u64::MAX - 1, + EPS, + EPS - 1, + ]; + // All pairs via nested loops, materialised as flat a[], b[] of length edge^2. + let mut a = Vec::with_capacity(edge.len() * edge.len()); + let mut b = Vec::with_capacity(edge.len() * edge.len()); + for &x in &edge { + for &y in &edge { + a.push(x); + b.push(y); + } + } + + let cases: &[(&str, fn(&[u64], &[u64]) -> math_cuda::Result>, fn(&u64, &u64) -> u64)] = + &[ + ("gl_add", math_cuda::gl_add_u64, GoldilocksField::add), + ("gl_sub", math_cuda::gl_sub_u64, GoldilocksField::sub), + ("gl_mul", math_cuda::gl_mul_u64, GoldilocksField::mul), + ]; + + for (op, gpu_fn, cpu_fn) in cases { + let expected: Vec = a.iter().zip(&b).map(|(x, y)| cpu_fn(x, y)).collect(); + let actual = gpu_fn(&a, &b).expect("GPU op"); + assert_raw_eq(op, &expected, &actual); + } +} diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs new file mode 100644 index 000000000..9648f833a --- /dev/null +++ b/crypto/math-cuda/tests/lde.rs @@ -0,0 +1,112 @@ +//! Phase-5 parity: GPU `coset_lde_base` must match the CPU +//! `Polynomial::coset_lde_full_expand` for a sweep of realistic sizes and +//! blowup factors. + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +/// Build the coset weights `[1/N, g/N, g²/N, …, g^{n-1}/N]` — this is the +/// layout `crypto/stark/src/prover.rs:248` uses, with `1/N` pre-folded into the +/// first coefficient so the iFFT step does not need a separate scaling pass. +fn coset_weights(n: usize, coset_offset: u64) -> Vec { + let inv_n_fe = FieldElement::::from(n as u64) + .inv() + .expect("n is non-zero"); + let mut w = Vec::with_capacity(n); + let mut cur = *inv_n_fe.value(); + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &coset_offset); + } + w +} + +fn cpu_lde(evals: &[u64], blowup_factor: usize, coset_offset: u64) -> Vec { + let n = evals.len(); + let log_n = n.trailing_zeros() as u64; + let log_lde = (n * blowup_factor).trailing_zeros() as u64; + + let inv_tw = LayerTwiddles::::new_inverse(log_n).expect("inv tw"); + let fwd_tw = LayerTwiddles::::new(log_lde).expect("fwd tw"); + let weights_raw = coset_weights(n, coset_offset); + let weights: Vec = weights_raw.iter().map(|&w| Fp::from_raw(w)).collect(); + + let mut buf: Vec = evals.iter().map(|&x| Fp::from_raw(x)).collect(); + Polynomial::coset_lde_full_expand::( + &mut buf, + blowup_factor, + &weights, + &inv_tw, + &fwd_tw, + ) + .expect("cpu lde"); + + buf.into_iter().map(|e| *e.value()).collect() +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn assert_lde_match(log_n: u64, blowup_factor: usize, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let evals: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + + // Use a fixed, public coset offset. For lambda-vm the coset offset is the + // generator of Goldilocks' multiplicative subgroup; any non-trivial element + // works for an isolated correctness check. + let coset_offset: u64 = 7; + let weights = coset_weights(n, coset_offset); + + let cpu = cpu_lde(&evals, blowup_factor, coset_offset); + let gpu = math_cuda::lde::coset_lde_base(&evals, blowup_factor, &weights).expect("gpu lde"); + + assert_eq!(cpu.len(), gpu.len(), "length mismatch (log_n={log_n}, blowup={blowup_factor})"); + let cpu_c = canon(&cpu); + let gpu_c = canon(&gpu); + for (i, (e, a)) in cpu_c.iter().zip(&gpu_c).enumerate() { + if e != a { + panic!( + "lde mismatch log_n={log_n} blowup={blowup_factor} i={i}: cpu {e:#018x}, gpu {a:#018x}", + ); + } + } +} + +#[test] +fn lde_small() { + for log_n in 4..=10 { + for &blow in &[2usize, 4, 8] { + assert_lde_match(log_n, blow, 1_000 + log_n + (blow as u64)); + } + } +} + +#[test] +fn lde_medium() { + for log_n in 11..=14 { + for &blow in &[2usize, 4] { + assert_lde_match(log_n, blow, 2_000 + log_n + (blow as u64)); + } + } +} + +#[test] +fn lde_large_2_to_18() { + // 2^18 × blowup 4 = 2^20 LDE — representative of Phase A trace columns. + assert_lde_match(18, 4, 0xCAFE); +} + +#[test] +fn lde_largest_2_to_20() { + // 2^20 LDE is the hot size; blowup 2 keeps total = 2^21 (within TWO_ADICITY). + assert_lde_match(20, 2, 0xF00D); +} diff --git a/crypto/math-cuda/tests/lde_batch.rs b/crypto/math-cuda/tests/lde_batch.rs new file mode 100644 index 000000000..67f975728 --- /dev/null +++ b/crypto/math-cuda/tests/lde_batch.rs @@ -0,0 +1,96 @@ +//! Batched coset LDE must agree with running the CPU single-column LDE on +//! each column independently. Sweeps a few realistic (n, blowup, m) tuples. + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *FieldElement::::from(n as u64) + .inv() + .unwrap() + .value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +fn cpu_lde_one(col: &[u64], blowup: usize, weights_fp: &[Fp], inv_tw: &LayerTwiddles, fwd_tw: &LayerTwiddles) -> Vec { + let mut buf: Vec = col.iter().map(|&x| Fp::from_raw(x)).collect(); + Polynomial::coset_lde_full_expand::( + &mut buf, blowup, weights_fp, inv_tw, fwd_tw, + ) + .unwrap(); + buf.into_iter().map(|e| *e.value()).collect() +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn assert_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let columns: Vec> = (0..m) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + + let coset_offset: u64 = 7; + let weights = coset_weights(n, coset_offset); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + + let inv_tw = LayerTwiddles::::new_inverse(log_n).unwrap(); + let fwd_tw = + LayerTwiddles::::new((n * blowup).trailing_zeros() as u64).unwrap(); + + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let gpu_all = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + assert_eq!(gpu_all.len(), m); + + for (c, col) in columns.iter().enumerate() { + let cpu = cpu_lde_one(col, blowup, &weights_fp, &inv_tw, &fwd_tw); + assert_eq!( + canon(&gpu_all[c]), + canon(&cpu), + "batch mismatch at col {c}, log_n={log_n}, blowup={blowup}" + ); + } +} + +#[test] +fn batch_small() { + for &m in &[1usize, 4, 16] { + for log_n in 4..=10 { + assert_batch(log_n, 4, m, 100 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn batch_medium() { + for &m in &[2usize, 32] { + for log_n in 11..=14 { + assert_batch(log_n, 4, m, 200 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn batch_large_one_column() { + assert_batch(18, 4, 1, 0xCAFE); +} + +#[test] +fn batch_large_32_columns() { + assert_batch(15, 4, 32, 0xBEEF); +} diff --git a/crypto/math-cuda/tests/lde_batch_ext3.rs b/crypto/math-cuda/tests/lde_batch_ext3.rs new file mode 100644 index 000000000..0a86197a5 --- /dev/null +++ b/crypto/math-cuda/tests/lde_batch_ext3.rs @@ -0,0 +1,161 @@ +//! Ext3 batched coset LDE must agree with the CPU `coset_lde_full_expand` +//! on each column independently when run over `FieldElement`. + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *FieldElement::::from(n as u64) + .inv() + .unwrap() + .value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + // Each Fp3 is [u64; 3] in memory; we just flatten componentwise. + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn u64s_to_ext3(raw: &[u64]) -> Vec { + assert_eq!(raw.len() % 3, 0); + let mut out = Vec::with_capacity(raw.len() / 3); + for i in 0..raw.len() / 3 { + out.push(Fp3::new([ + Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3 + 1]), + Fp::from_raw(raw[i * 3 + 2]), + ])); + } + out +} + +fn cpu_lde_one_ext3( + col: &[Fp3], + blowup: usize, + weights_fp: &[Fp], + inv_tw: &LayerTwiddles, + fwd_tw: &LayerTwiddles, +) -> Vec { + let mut buf = col.to_vec(); + Polynomial::coset_lde_full_expand::( + &mut buf, blowup, weights_fp, inv_tw, fwd_tw, + ) + .unwrap(); + buf +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn assert_ext3_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let columns: Vec> = (0..m) + .map(|_| (0..n).map(|_| rand_ext3(&mut rng)).collect()) + .collect(); + + let coset_offset: u64 = 7; + let weights = coset_weights(n, coset_offset); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n).unwrap(); + let fwd_tw = LayerTwiddles::::new(lde_size.trailing_zeros() as u64).unwrap(); + + // Flatten each ext3 column to 3n u64s for the GPU API. + let flat_inputs: Vec> = columns.iter().map(|c| ext3_to_u64s(c)).collect(); + let input_slices: Vec<&[u64]> = flat_inputs.iter().map(|v| v.as_slice()).collect(); + + // Pre-allocate outputs, each 3*lde_size u64s. + let mut flat_outputs: Vec> = + (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); + { + let mut out_slices: Vec<&mut [u64]> = + flat_outputs.iter_mut().map(|v| v.as_mut_slice()).collect(); + math_cuda::lde::coset_lde_batch_ext3_into( + &input_slices, + n, + blowup, + &weights, + &mut out_slices, + ) + .unwrap(); + } + + for (c, col) in columns.iter().enumerate() { + let cpu = cpu_lde_one_ext3(col, blowup, &weights_fp, &inv_tw, &fwd_tw); + let gpu: Vec = u64s_to_ext3(&flat_outputs[c]); + assert_eq!(gpu.len(), cpu.len(), "length mismatch"); + for i in 0..cpu.len() { + for k in 0..3 { + let cv = *cpu[i].value()[k].value(); + let gv = *gpu[i].value()[k].value(); + let cc = GoldilocksField::canonical(&cv); + let gc = GoldilocksField::canonical(&gv); + if cc != gc { + panic!( + "ext3 batch mismatch col={c} row={i} comp={k} log_n={log_n} blowup={blowup}: cpu={cv:#018x} (canon {cc:#018x}), gpu={gv:#018x} (canon {gc:#018x})", + ); + } + } + } + } + // Also sanity-check raw canonical equality per column. + for (c, col) in columns.iter().enumerate() { + let cpu_raw = ext3_to_u64s(&cpu_lde_one_ext3(col, blowup, &weights_fp, &inv_tw, &fwd_tw)); + assert_eq!(canon(&cpu_raw), canon(&flat_outputs[c])); + } +} + +#[test] +fn ext3_batch_small() { + for &m in &[1usize, 4, 16] { + for log_n in 4..=10 { + assert_ext3_batch(log_n, 4, m, 100 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn ext3_batch_medium() { + for &m in &[2usize, 8] { + for log_n in 11..=14 { + assert_ext3_batch(log_n, 4, m, 300 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn ext3_batch_large_one_column() { + assert_ext3_batch(16, 4, 1, 0xCAFE); +} diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs new file mode 100644 index 000000000..d7cf3680a --- /dev/null +++ b/crypto/math-cuda/tests/ntt.rs @@ -0,0 +1,136 @@ +//! Phase-3 parity: GPU forward NTT must agree with `Polynomial::evaluate_fft` +//! as a field element, across a sweep of sizes from 2^4 to 2^20. +//! +//! Non-canonical u64s can differ between CPU and GPU while representing the +//! same element; we canonicalise both sides before comparing. + +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsPrimeField; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +fn cpu_fft(coeffs: &[u64]) -> Vec { + let elems: Vec = coeffs.iter().map(|&x| Fp::from_raw(x)).collect(); + let poly = Polynomial::new(&elems); + let evals = Polynomial::evaluate_fft::(&poly, 1, None).expect("cpu fft"); + evals.into_iter().map(|e| *e.value()).collect() +} + +fn canonicalize(xs: &[u64]) -> Vec { + xs.iter() + .map(|x| GoldilocksField::canonical(x)) + .collect() +} + +fn assert_ntt_match(log_n: u64, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let input: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + + let cpu = cpu_fft(&input); + let gpu = math_cuda::ntt::forward(&input).expect("gpu ntt"); + + assert_eq!(cpu.len(), gpu.len(), "length mismatch at log_n = {log_n}"); + let cpu_c = canonicalize(&cpu); + let gpu_c = canonicalize(&gpu); + for i in 0..n { + if cpu_c[i] != gpu_c[i] { + panic!( + "log_n={log_n} i={i}: cpu={:#018x} (canon {:#018x}), gpu={:#018x} (canon {:#018x})", + cpu[i], cpu_c[i], gpu[i], gpu_c[i], + ); + } + } +} + +#[test] +fn ntt_sizes_small() { + for log_n in 4..=10 { + assert_ntt_match(log_n, 100 + log_n); + } +} + +#[test] +fn ntt_sizes_medium() { + for log_n in 11..=16 { + assert_ntt_match(log_n, 200 + log_n); + } +} + +#[test] +fn ntt_size_2_to_20() { + // The hot LDE size. One seed is enough; any mismatch screams loudly. + assert_ntt_match(20, 0xDEAD); +} + +#[test] +fn ntt_trivial_sizes() { + // Power-of-two below the interesting range — should still pass. + assert_ntt_match(1, 1); + assert_ntt_match(2, 2); + assert_ntt_match(3, 3); +} + +fn assert_intt_match(log_n: u64, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let evals: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + + let elems: Vec = evals.iter().map(|&x| Fp::from_raw(x)).collect(); + let cpu_poly = + Polynomial::interpolate_fft::(&elems).expect("cpu intt"); + let cpu: Vec = cpu_poly.coefficients.into_iter().map(|e| *e.value()).collect(); + + let gpu = math_cuda::ntt::inverse(&evals).expect("gpu intt"); + + let cpu_c = canonicalize(&cpu); + let gpu_c = canonicalize(&gpu); + for i in 0..n { + if cpu_c[i] != gpu_c[i] { + panic!( + "iNTT log_n={log_n} i={i}: cpu canon {:#018x}, gpu canon {:#018x}", + cpu_c[i], gpu_c[i], + ); + } + } +} + +#[test] +fn intt_sizes_small() { + for log_n in 4..=10 { + assert_intt_match(log_n, 700 + log_n); + } +} + +#[test] +fn intt_sizes_medium() { + for log_n in 11..=16 { + assert_intt_match(log_n, 800 + log_n); + } +} + +#[test] +fn intt_size_2_to_20() { + assert_intt_match(20, 0xBEEF); +} + +#[test] +fn ntt_round_trip() { + // inverse(forward(x)) == x up to canonical form. + let log_n = 14; + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(42); + let x: Vec = (0..n).map(|_| rng.r#gen::() % 0xFFFF_FFFF_0000_0001).collect(); + + let evals = math_cuda::ntt::forward(&x).expect("forward"); + let back = math_cuda::ntt::inverse(&evals).expect("inverse"); + + let x_c = canonicalize(&x); + let back_c = canonicalize(&back); + assert_eq!(x_c, back_c, "round trip failed"); +} + diff --git a/crypto/stark/Cargo.toml b/crypto/stark/Cargo.toml index 53b205996..4d1f2cbca 100644 --- a/crypto/stark/Cargo.toml +++ b/crypto/stark/Cargo.toml @@ -22,6 +22,9 @@ itertools = "0.11.0" # Parallelization crates rayon = { version = "1.8.0", optional = true } +# GPU backend for trace LDE — only linked when `cuda` is enabled. +math-cuda = { path = "../math-cuda", optional = true } + # wasm wasm-bindgen = { version = "0.2", optional = true } serde-wasm-bindgen = { version = "0.5", optional = true } @@ -39,6 +42,7 @@ test_fiat_shamir = [] instruments = [] # This enables timing prints in prover and verifier debug-checks = [] # Enables validate_trace + bus balance report in prover parallel = ["dep:rayon", "crypto/parallel"] +cuda = ["dep:math-cuda"] wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"] From 79634ff2ef8e956de53326870acec43c30a81a87 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 May 2026 15:26:58 -0300 Subject: [PATCH 02/16] fmt --- crypto/math-cuda/build.rs | 9 +- crypto/math-cuda/src/device.rs | 6 +- crypto/math-cuda/src/lde.rs | 151 +++++++++--------- crypto/math-cuda/src/ntt.rs | 8 +- crypto/math-cuda/tests/bench_quick.rs | 68 ++++---- crypto/math-cuda/tests/evaluate_coset_ext3.rs | 14 +- crypto/math-cuda/tests/goldilocks.rs | 15 +- crypto/math-cuda/tests/lde.rs | 6 +- crypto/math-cuda/tests/lde_batch.rs | 8 +- crypto/math-cuda/tests/lde_batch_ext3.rs | 11 +- crypto/math-cuda/tests/ntt.rs | 18 ++- 11 files changed, 159 insertions(+), 155 deletions(-) diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index a6defb5ab..d6b803bed 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -30,14 +30,7 @@ fn compile_ptx(src: &str, out_name: &str) { let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| "compute_89".to_string()); let status = Command::new(nvcc_path()) - .args([ - "--ptx", - "-O3", - "-std=c++17", - "-arch", - &arch, - "-o", - ]) + .args(["--ptx", "-O3", "-std=c++17", "-arch", &arch, "-o"]) .arg(&out_path) .arg(&src_path) .status() diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 2c70716a6..03f0b67ca 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -40,11 +40,7 @@ impl PinnedStaging { } } - pub fn ensure_capacity( - &mut self, - min_elems: usize, - ctx: &CudaContext, - ) -> Result<()> { + pub fn ensure_capacity(&mut self, min_elems: usize, ctx: &CudaContext) -> Result<()> { if self.capacity_elems >= min_elems { return Ok(()); } diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index fd25d1bca..8b9da522c 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -38,15 +38,14 @@ pub struct GpuLdeExt3 { pub lde_size: usize, } -pub fn coset_lde_base( - evals: &[u64], - blowup_factor: usize, - weights: &[u64], -) -> Result> { +pub fn coset_lde_base(evals: &[u64], blowup_factor: usize, weights: &[u64]) -> Result> { let n = evals.len(); assert!(n.is_power_of_two(), "evals length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match evals"); - assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + assert!( + blowup_factor.is_power_of_two(), + "blowup must be power of two" + ); if n == 0 { return Ok(Vec::new()); } @@ -130,7 +129,10 @@ pub fn coset_lde_batch_base( let n = columns[0].len(); assert!(n.is_power_of_two(), "column length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match column length"); - assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + assert!( + blowup_factor.is_power_of_two(), + "blowup must be power of two" + ); for c in columns.iter() { assert_eq!(c.len(), n, "all columns must be the same size"); } @@ -147,7 +149,11 @@ pub fn coset_lde_batch_base( let staging_slot = be.pinned_staging(); let debug_phases = std::env::var("MATH_CUDA_PHASE_TIMING").is_ok(); - let t_start = if debug_phases { Some(std::time::Instant::now()) } else { None }; + let t_start = if debug_phases { + Some(std::time::Instant::now()) + } else { + None + }; let phase = |label: &str, prev: &mut Option| { if let Some(p) = prev.as_ref() { let now = std::time::Instant::now(); @@ -165,7 +171,9 @@ pub fn coset_lde_batch_base( staging.ensure_capacity(m * lde_size, &be.ctx)?; // SAFETY: staging is locked, the slice alias ends before we unlock. let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - if debug_phases { phase("staging lock + grow", &mut last); } + if debug_phases { + phase("staging lock + grow", &mut last); + } // Pack columns into first m*n slots of the pinned buffer. Parallel: pinned // writes are DRAM-bandwidth bound, saturates at ~8 cores on modern @@ -176,32 +184,39 @@ pub fn coset_lde_batch_base( // SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of // `pinned`, and the outer `staging` lock guarantees no other call is // using the buffer concurrently. - let dst = unsafe { - std::slice::from_raw_parts_mut( - (pinned_base_ptr as *mut u64).add(c * n), - n, - ) - }; + let dst = + unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; dst.copy_from_slice(col); }); - if debug_phases { phase("host pack (pinned, rayon)", &mut last); } + if debug_phases { + phase("host pack (pinned, rayon)", &mut last); + } // Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size) // tail of each column is already the zero-pad the CPU path does. let mut buf = stream.alloc_zeros::(m * lde_size)?; - if debug_phases { stream.synchronize()?; phase("alloc_zeros", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("alloc_zeros", &mut last); + } // One memcpy per column from the pinned buffer into the strided slots. // The pinned source hits PCIe line-rate. for c in 0..m { let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; } - if debug_phases { stream.synchronize()?; phase("H2D cols (pinned)", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("H2D cols (pinned)", &mut last); + } let inv_tw = be.inv_twiddles_for(log_n)?; let fwd_tw = be.fwd_twiddles_for(log_lde)?; let weights_dev = stream.clone_htod(weights)?; - if debug_phases { stream.synchronize()?; phase("twiddles + weights", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("twiddles + weights", &mut last); + } let n_u64 = n as u64; let lde_u64 = lde_size as u64; @@ -227,7 +242,10 @@ pub fn coset_lde_batch_base( } } - if debug_phases { stream.synchronize()?; phase("bit_reverse N", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("bit_reverse N", &mut last); + } // === 2. iNTT body over all columns === run_batched_ntt_body( stream.as_ref(), @@ -238,7 +256,10 @@ pub fn coset_lde_batch_base( col_stride_u64, m_u32, )?; - if debug_phases { stream.synchronize()?; phase("iNTT body", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("iNTT body", &mut last); + } // === 3. Pointwise multiply by coset weights (includes 1/N) === { @@ -278,7 +299,10 @@ pub fn coset_lde_batch_base( } } - if debug_phases { stream.synchronize()?; phase("pointwise + bit_reverse LDE", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("pointwise + bit_reverse LDE", &mut last); + } // === 5. Forward NTT on full LDE of every column === run_batched_ntt_body( stream.as_ref(), @@ -289,13 +313,18 @@ pub fn coset_lde_batch_base( col_stride_u64, m_u32, )?; - if debug_phases { stream.synchronize()?; phase("forward NTT body", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("forward NTT body", &mut last); + } // Single big D2H into the reusable pinned staging buffer — pinned, one // call to the driver, saturates PCIe. stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; stream.synchronize()?; - if debug_phases { phase("D2H (one shot into pinned)", &mut last); } + if debug_phases { + phase("D2H (one shot into pinned)", &mut last); + } // Split pinned → per-column Vecs. The first write to each virgin // Vec page-faults, which can dominate total time (~75 ms for 128 MB). @@ -308,16 +337,15 @@ pub fn coset_lde_batch_base( let mut v = Vec::::with_capacity(lde_size); unsafe { v.set_len(lde_size) }; let src = unsafe { - std::slice::from_raw_parts( - (pinned_ptr as *const u64).add(c * lde_size), - lde_size, - ) + std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) }; v.copy_from_slice(src); v }) .collect(); - if debug_phases { phase("copy out (rayon pinned → Vecs)", &mut last); } + if debug_phases { + phase("copy out (rayon pinned → Vecs)", &mut last); + } drop(staging); Ok(out) } @@ -341,7 +369,10 @@ pub fn coset_lde_batch_base_into( let n = columns[0].len(); assert!(n.is_power_of_two(), "column length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match column length"); - assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + assert!( + blowup_factor.is_power_of_two(), + "blowup must be power of two" + ); for c in columns.iter() { assert_eq!(c.len(), n, "all columns must be the same size"); } @@ -461,18 +492,12 @@ pub fn coset_lde_batch_base_into( #[allow(unused_imports)] use rayon::prelude::*; let pinned_ptr = pinned.as_ptr() as usize; - outputs - .par_iter_mut() - .enumerate() - .for_each(|(c, dst)| { - let src = unsafe { - std::slice::from_raw_parts( - (pinned_ptr as *const u64).add(c * lde_size), - lde_size, - ) - }; - dst.copy_from_slice(src); - }); + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) + }; + dst.copy_from_slice(src); + }); drop(staging); Ok(()) } @@ -497,15 +522,8 @@ pub fn evaluate_poly_coset_batch_ext3_into( weights: &[u64], outputs: &mut [&mut [u64]], ) -> Result<()> { - evaluate_poly_coset_batch_ext3_into_inner( - coefs, - n, - blowup_factor, - weights, - outputs, - false, - ) - .map(|_| ()) + evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, false) + .map(|_| ()) } /// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- @@ -519,14 +537,8 @@ pub fn evaluate_poly_coset_batch_ext3_into_keep( weights: &[u64], outputs: &mut [&mut [u64]], ) -> Result { - let opt = evaluate_poly_coset_batch_ext3_into_inner( - coefs, - n, - blowup_factor, - weights, - outputs, - true, - )?; + let opt = + evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, true)?; Ok(opt.expect("keep_device_buf=true must return Some")) } @@ -724,7 +736,10 @@ pub fn coset_lde_batch_ext3_into( assert_eq!(outputs.len(), m, "outputs must match columns count"); assert!(n.is_power_of_two(), "n must be a power of two"); assert_eq!(weights.len(), n, "weights length must match n"); - assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + assert!( + blowup_factor.is_power_of_two(), + "blowup must be power of two" + ); for c in columns.iter() { assert_eq!(c.len(), 3 * n, "each ext3 column must be 3*n u64s"); } @@ -757,22 +772,13 @@ pub fn coset_lde_batch_ext3_into( columns.par_iter().enumerate().for_each(|(c, col)| { // SAFETY: disjoint regions per c; staging lock held. let slab_a = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), - n, - ) + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), n) }; let slab_b = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), - n, - ) + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) }; let slab_c = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), - n, - ) + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) }; for i in 0..n { slab_a[i] = col[i * 3 + 0]; @@ -981,4 +987,3 @@ fn run_batched_ntt_body( } Ok(()) } - diff --git a/crypto/math-cuda/src/ntt.rs b/crypto/math-cuda/src/ntt.rs index 0ebb015ea..31333c3ae 100644 --- a/crypto/math-cuda/src/ntt.rs +++ b/crypto/math-cuda/src/ntt.rs @@ -87,13 +87,7 @@ fn ntt_inplace(input: &[u64], forward: bool) -> Result> { // 2. DIT butterfly levels. For log_n >= 8 we fuse 8 levels per kernel via // the shmem kernel; for very small sizes (< 256 elements) we stick with // the per-level kernel because the shmem block dimensions assume n ≥ 256. - run_ntt_body( - stream.as_ref(), - &mut x_dev, - tw_dev.as_ref(), - n_u64, - log_n, - )?; + run_ntt_body(stream.as_ref(), &mut x_dev, tw_dev.as_ref(), n_u64, log_n)?; // 3. For iNTT, multiply by 1/n. if !forward { diff --git a/crypto/math-cuda/tests/bench_quick.rs b/crypto/math-cuda/tests/bench_quick.rs index 561331b74..41e9444c3 100644 --- a/crypto/math-cuda/tests/bench_quick.rs +++ b/crypto/math-cuda/tests/bench_quick.rs @@ -15,7 +15,10 @@ use rayon::prelude::*; type Fp = FieldElement; fn coset_weights(n: usize, g: u64) -> Vec { - let inv_n = *FieldElement::::from(n as u64).inv().unwrap().value(); + let inv_n = *FieldElement::::from(n as u64) + .inv() + .unwrap() + .value(); let mut w = Vec::with_capacity(n); let mut cur = inv_n; for _ in 0..n { @@ -54,7 +57,11 @@ fn bench_lde_2_to_18_blowup_4() { for _ in 0..TRIALS { let mut buf: Vec = input.iter().map(|&x| Fp::from_raw(x)).collect(); Polynomial::coset_lde_full_expand::( - &mut buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + &mut buf, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, ) .unwrap(); std::hint::black_box(&buf); @@ -101,12 +108,7 @@ fn bench_lde_multi_column_parallel() { let num_cols = 64; // Warm up. - let _ = math_cuda::lde::coset_lde_base( - &vec![0u64; n], - blowup, - &coset_weights(n, 7), - ) - .unwrap(); + let _ = math_cuda::lde::coset_lde_base(&vec![0u64; n], blowup, &coset_weights(n, 7)).unwrap(); // Build input data. let mut rng = ChaCha8Rng::seed_from_u64(11); @@ -134,7 +136,11 @@ fn bench_lde_multi_column_parallel() { let t0 = Instant::now(); cpu_bufs.par_iter_mut().for_each(|buf| { Polynomial::coset_lde_full_expand::( - buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + buf, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, ) .unwrap(); }); @@ -164,15 +170,12 @@ fn bench_lde_batched_prover_scale() { let weights = coset_weights(n, 7); let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); - let fwd_tw = LayerTwiddles::::new( - (n * blowup).trailing_zeros() as u64, - ) - .unwrap(); + let fwd_tw = + LayerTwiddles::::new((n * blowup).trailing_zeros() as u64).unwrap(); let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); for _ in 0..8 { - let _ = - math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + let _ = math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); } let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); @@ -194,7 +197,11 @@ fn bench_lde_batched_prover_scale() { let t0 = Instant::now(); cpu_bufs.par_iter_mut().for_each(|buf| { Polynomial::coset_lde_full_expand::( - buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + buf, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, ) .unwrap(); }); @@ -227,15 +234,12 @@ fn bench_lde_batched_vs_rayon_cpu() { // one-time pinned staging alloc cost. let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); for _ in 0..64 { - let _ = - math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + let _ = math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); } let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); - let fwd_tw = LayerTwiddles::::new( - (n * blowup).trailing_zeros() as u64, - ) - .unwrap(); + let fwd_tw = + LayerTwiddles::::new((n * blowup).trailing_zeros() as u64).unwrap(); // GPU batched — first run may include lazy device init; do a few to stabilise. let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); @@ -254,7 +258,11 @@ fn bench_lde_batched_vs_rayon_cpu() { let t0 = Instant::now(); cpu_bufs.par_iter_mut().for_each(|buf| { Polynomial::coset_lde_full_expand::( - buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + buf, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, ) .unwrap(); }); @@ -277,12 +285,7 @@ fn bench_lde_multi_column_serialized_gpu() { let n = 1usize << log_n; let num_cols = 64; - let _ = math_cuda::lde::coset_lde_base( - &vec![0u64; n], - blowup, - &coset_weights(n, 7), - ) - .unwrap(); + let _ = math_cuda::lde::coset_lde_base(&vec![0u64; n], blowup, &coset_weights(n, 7)).unwrap(); let mut rng = ChaCha8Rng::seed_from_u64(13); let columns: Vec> = (0..num_cols) @@ -320,12 +323,7 @@ fn bench_lde_multi_column_gpu_limited_threads() { let n = 1usize << log_n; let num_cols = 64; - let _ = math_cuda::lde::coset_lde_base( - &vec![0u64; n], - blowup, - &coset_weights(n, 7), - ) - .unwrap(); + let _ = math_cuda::lde::coset_lde_base(&vec![0u64; n], blowup, &coset_weights(n, 7)).unwrap(); let mut rng = ChaCha8Rng::seed_from_u64(12); let columns: Vec> = (0..num_cols) diff --git a/crypto/math-cuda/tests/evaluate_coset_ext3.rs b/crypto/math-cuda/tests/evaluate_coset_ext3.rs index a79195291..020c39241 100644 --- a/crypto/math-cuda/tests/evaluate_coset_ext3.rs +++ b/crypto/math-cuda/tests/evaluate_coset_ext3.rs @@ -81,13 +81,8 @@ fn assert_evaluate_coset(log_n: u64, blowup: usize, m: usize, offset: u64, seed: .iter() .map(|coefs| { let p = Polynomial::new(coefs); - Polynomial::evaluate_offset_fft::( - &p, - blowup, - Some(n), - &offset_fp, - ) - .unwrap() + Polynomial::evaluate_offset_fft::(&p, blowup, Some(n), &offset_fp) + .unwrap() }) .collect(); @@ -114,7 +109,10 @@ fn assert_evaluate_coset(log_n: u64, blowup: usize, m: usize, offset: u64, seed: for i in 0..gpu.len() { let g = canon_fp3(&gpu[i]); let cc = canon_fp3(&cpu[c][i]); - assert_eq!(g, cc, "eval mismatch col={c} row={i} log_n={log_n} blowup={blowup}"); + assert_eq!( + g, cc, + "eval mismatch col={c} row={i} log_n={log_n} blowup={blowup}" + ); } } } diff --git a/crypto/math-cuda/tests/goldilocks.rs b/crypto/math-cuda/tests/goldilocks.rs index 317ffb0f8..37a5d8533 100644 --- a/crypto/math-cuda/tests/goldilocks.rs +++ b/crypto/math-cuda/tests/goldilocks.rs @@ -112,12 +112,15 @@ fn gpu_goldilocks_edge_cases() { } } - let cases: &[(&str, fn(&[u64], &[u64]) -> math_cuda::Result>, fn(&u64, &u64) -> u64)] = - &[ - ("gl_add", math_cuda::gl_add_u64, GoldilocksField::add), - ("gl_sub", math_cuda::gl_sub_u64, GoldilocksField::sub), - ("gl_mul", math_cuda::gl_mul_u64, GoldilocksField::mul), - ]; + let cases: &[( + &str, + fn(&[u64], &[u64]) -> math_cuda::Result>, + fn(&u64, &u64) -> u64, + )] = &[ + ("gl_add", math_cuda::gl_add_u64, GoldilocksField::add), + ("gl_sub", math_cuda::gl_sub_u64, GoldilocksField::sub), + ("gl_mul", math_cuda::gl_mul_u64, GoldilocksField::mul), + ]; for (op, gpu_fn, cpu_fn) in cases { let expected: Vec = a.iter().zip(&b).map(|(x, y)| cpu_fn(x, y)).collect(); diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs index 9648f833a..75ea80867 100644 --- a/crypto/math-cuda/tests/lde.rs +++ b/crypto/math-cuda/tests/lde.rs @@ -69,7 +69,11 @@ fn assert_lde_match(log_n: u64, blowup_factor: usize, seed: u64) { let cpu = cpu_lde(&evals, blowup_factor, coset_offset); let gpu = math_cuda::lde::coset_lde_base(&evals, blowup_factor, &weights).expect("gpu lde"); - assert_eq!(cpu.len(), gpu.len(), "length mismatch (log_n={log_n}, blowup={blowup_factor})"); + assert_eq!( + cpu.len(), + gpu.len(), + "length mismatch (log_n={log_n}, blowup={blowup_factor})" + ); let cpu_c = canon(&cpu); let gpu_c = canon(&gpu); for (i, (e, a)) in cpu_c.iter().zip(&gpu_c).enumerate() { diff --git a/crypto/math-cuda/tests/lde_batch.rs b/crypto/math-cuda/tests/lde_batch.rs index 67f975728..153e5d3e5 100644 --- a/crypto/math-cuda/tests/lde_batch.rs +++ b/crypto/math-cuda/tests/lde_batch.rs @@ -25,7 +25,13 @@ fn coset_weights(n: usize, g: u64) -> Vec { w } -fn cpu_lde_one(col: &[u64], blowup: usize, weights_fp: &[Fp], inv_tw: &LayerTwiddles, fwd_tw: &LayerTwiddles) -> Vec { +fn cpu_lde_one( + col: &[u64], + blowup: usize, + weights_fp: &[Fp], + inv_tw: &LayerTwiddles, + fwd_tw: &LayerTwiddles, +) -> Vec { let mut buf: Vec = col.iter().map(|&x| Fp::from_raw(x)).collect(); Polynomial::coset_lde_full_expand::( &mut buf, blowup, weights_fp, inv_tw, fwd_tw, diff --git a/crypto/math-cuda/tests/lde_batch_ext3.rs b/crypto/math-cuda/tests/lde_batch_ext3.rs index 0a86197a5..ff19f6c56 100644 --- a/crypto/math-cuda/tests/lde_batch_ext3.rs +++ b/crypto/math-cuda/tests/lde_batch_ext3.rs @@ -97,8 +97,7 @@ fn assert_ext3_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { let input_slices: Vec<&[u64]> = flat_inputs.iter().map(|v| v.as_slice()).collect(); // Pre-allocate outputs, each 3*lde_size u64s. - let mut flat_outputs: Vec> = - (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); + let mut flat_outputs: Vec> = (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); { let mut out_slices: Vec<&mut [u64]> = flat_outputs.iter_mut().map(|v| v.as_mut_slice()).collect(); @@ -132,7 +131,13 @@ fn assert_ext3_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { } // Also sanity-check raw canonical equality per column. for (c, col) in columns.iter().enumerate() { - let cpu_raw = ext3_to_u64s(&cpu_lde_one_ext3(col, blowup, &weights_fp, &inv_tw, &fwd_tw)); + let cpu_raw = ext3_to_u64s(&cpu_lde_one_ext3( + col, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, + )); assert_eq!(canon(&cpu_raw), canon(&flat_outputs[c])); } } diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs index d7cf3680a..b6697b82d 100644 --- a/crypto/math-cuda/tests/ntt.rs +++ b/crypto/math-cuda/tests/ntt.rs @@ -21,9 +21,7 @@ fn cpu_fft(coeffs: &[u64]) -> Vec { } fn canonicalize(xs: &[u64]) -> Vec { - xs.iter() - .map(|x| GoldilocksField::canonical(x)) - .collect() + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() } fn assert_ntt_match(log_n: u64, seed: u64) { @@ -81,9 +79,12 @@ fn assert_intt_match(log_n: u64, seed: u64) { let evals: Vec = (0..n).map(|_| rng.r#gen::()).collect(); let elems: Vec = evals.iter().map(|&x| Fp::from_raw(x)).collect(); - let cpu_poly = - Polynomial::interpolate_fft::(&elems).expect("cpu intt"); - let cpu: Vec = cpu_poly.coefficients.into_iter().map(|e| *e.value()).collect(); + let cpu_poly = Polynomial::interpolate_fft::(&elems).expect("cpu intt"); + let cpu: Vec = cpu_poly + .coefficients + .into_iter() + .map(|e| *e.value()) + .collect(); let gpu = math_cuda::ntt::inverse(&evals).expect("gpu intt"); @@ -124,7 +125,9 @@ fn ntt_round_trip() { let log_n = 14; let n = 1usize << log_n; let mut rng = ChaCha8Rng::seed_from_u64(42); - let x: Vec = (0..n).map(|_| rng.r#gen::() % 0xFFFF_FFFF_0000_0001).collect(); + let x: Vec = (0..n) + .map(|_| rng.r#gen::() % 0xFFFF_FFFF_0000_0001) + .collect(); let evals = math_cuda::ntt::forward(&x).expect("forward"); let back = math_cuda::ntt::inverse(&evals).expect("inverse"); @@ -133,4 +136,3 @@ fn ntt_round_trip() { let back_c = canonicalize(&back); assert_eq!(x_c, back_c, "round trip failed"); } - From ac6fbb5972cc4cd4e228bd1adb2e659760c7b6e1 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 May 2026 15:55:58 -0300 Subject: [PATCH 03/16] fix clippy --- crypto/math-cuda/build.rs | 34 +++++++++++++++++-- crypto/math-cuda/src/lde.rs | 20 +++++++---- crypto/math-cuda/tests/evaluate_coset_ext3.rs | 2 +- crypto/math-cuda/tests/goldilocks.rs | 10 +++--- crypto/math-cuda/tests/lde.rs | 2 +- crypto/math-cuda/tests/lde_batch.rs | 2 +- crypto/math-cuda/tests/lde_batch_ext3.rs | 4 +-- crypto/math-cuda/tests/ntt.rs | 2 +- 8 files changed, 54 insertions(+), 22 deletions(-) diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index d6b803bed..43edaa31c 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -1,4 +1,5 @@ use std::env; +use std::fs; use std::path::PathBuf; use std::process::Command; @@ -13,7 +14,7 @@ fn nvcc_path() -> PathBuf { cuda_home().join("bin").join("nvcc") } -fn compile_ptx(src: &str, out_name: &str) { +fn compile_ptx(src: &str, out_name: &str, have_nvcc: bool) { let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); let src_path = manifest_dir.join("kernels").join(src); @@ -24,6 +25,16 @@ fn compile_ptx(src: &str, out_name: &str) { println!("cargo:rerun-if-env-changed=CUDA_PATH"); println!("cargo:rerun-if-env-changed=CUDARC_NVCC_ARCH"); + // No nvcc on PATH → emit an empty PTX stub so the crate still compiles. + // include_str! in src/device.rs needs the file to exist at build time. + // Any runtime kernel call will then panic from cudarc when loading the + // empty module — which is the right failure mode (we can't run GPU code + // without nvcc on the build host anyway). + if !have_nvcc { + fs::write(&out_path, "").expect("failed to write empty PTX stub"); + return; + } + // Emit PTX for a virtual architecture; the CUDA driver JIT-compiles it for the // actual GPU at load time, so one PTX works across Ada/Hopper/Blackwell. Override // with CUDARC_NVCC_ARCH to pin a specific compute capability. @@ -45,6 +56,23 @@ fn main() { // Headers are not compiled; emit rerun-if-changed so edits trigger rebuilds. println!("cargo:rerun-if-changed=kernels/goldilocks.cuh"); println!("cargo:rerun-if-changed=kernels/ext3.cuh"); - compile_ptx("arith.cu", "arith.ptx"); - compile_ptx("ntt.cu", "ntt.ptx"); + + // Probe for nvcc once. Workspace consumers (clippy, fmt, CPU-only test + // runners) build math-cuda incidentally without using its kernels; allow + // those to succeed by stubbing out PTX when nvcc is unavailable. + let have_nvcc = Command::new(nvcc_path()) + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false); + if !have_nvcc { + println!( + "cargo:warning=math-cuda: nvcc not found at {} — emitting empty PTX stubs. \ + Runtime GPU calls will panic. Install CUDA and rebuild for a working backend.", + nvcc_path().display() + ); + } + + compile_ptx("arith.cu", "arith.ptx", have_nvcc); + compile_ptx("ntt.cu", "ntt.ptx", have_nvcc); } diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 8b9da522c..22d62ca3e 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -329,13 +329,19 @@ pub fn coset_lde_batch_base( // Split pinned → per-column Vecs. The first write to each virgin // Vec page-faults, which can dominate total time (~75 ms for 128 MB). // Parallelise so the fault cost spreads across CPU cores. - use rayon::prelude::*; let pinned_ptr = pinned.as_ptr() as usize; let out: Vec> = (0..m) .into_par_iter() .map(|c| { - let mut v = Vec::::with_capacity(lde_size); - unsafe { v.set_len(lde_size) }; + // set_len skips the O(N) zero-init that vec![0; n] would do + // (saves ~75 ms per 128 MB at prover scale). copy_from_slice + // below writes every slot before any reader sees the Vec. + #[allow(clippy::uninit_vec)] + let mut v = { + let mut v = Vec::::with_capacity(lde_size); + unsafe { v.set_len(lde_size) }; + v + }; let src = unsafe { std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) }; @@ -772,7 +778,7 @@ pub fn coset_lde_batch_ext3_into( columns.par_iter().enumerate().for_each(|(c, col)| { // SAFETY: disjoint regions per c; staging lock held. let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), n) + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) }; let slab_b = unsafe { std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) @@ -781,7 +787,7 @@ pub fn coset_lde_batch_ext3_into( std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) }; for i in 0..n { - slab_a[i] = col[i * 3 + 0]; + slab_a[i] = col[i * 3]; slab_b[i] = col[i * 3 + 1]; slab_c[i] = col[i * 3 + 2]; } @@ -885,7 +891,7 @@ pub fn coset_lde_batch_ext3_into( outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { let slab_a = unsafe { std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 0) * lde_size), + (pinned_const as *const u64).add((c * 3) * lde_size), lde_size, ) }; @@ -902,7 +908,7 @@ pub fn coset_lde_batch_ext3_into( ) }; for i in 0..lde_size { - dst[i * 3 + 0] = slab_a[i]; + dst[i * 3] = slab_a[i]; dst[i * 3 + 1] = slab_b[i]; dst[i * 3 + 2] = slab_c[i]; } diff --git a/crypto/math-cuda/tests/evaluate_coset_ext3.rs b/crypto/math-cuda/tests/evaluate_coset_ext3.rs index 020c39241..334b2601d 100644 --- a/crypto/math-cuda/tests/evaluate_coset_ext3.rs +++ b/crypto/math-cuda/tests/evaluate_coset_ext3.rs @@ -47,7 +47,7 @@ fn u64s_to_ext3(raw: &[u64]) -> Vec { let mut out = Vec::with_capacity(raw.len() / 3); for i in 0..raw.len() / 3 { out.push(Fp3::new([ - Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3]), Fp::from_raw(raw[i * 3 + 1]), Fp::from_raw(raw[i * 3 + 2]), ])); diff --git a/crypto/math-cuda/tests/goldilocks.rs b/crypto/math-cuda/tests/goldilocks.rs index 37a5d8533..33b7bf5e7 100644 --- a/crypto/math-cuda/tests/goldilocks.rs +++ b/crypto/math-cuda/tests/goldilocks.rs @@ -78,7 +78,7 @@ fn gpu_gl_mul_matches_cpu() { #[test] fn gpu_gl_neg_matches_cpu() { let a = sample_inputs(7); - let expected: Vec = a.iter().map(|x| GoldilocksField::neg(x)).collect(); + let expected: Vec = a.iter().map(GoldilocksField::neg).collect(); let actual = math_cuda::gl_neg_u64(&a).expect("GPU gl_neg"); assert_raw_eq("gl_neg", &expected, &actual); } @@ -112,11 +112,9 @@ fn gpu_goldilocks_edge_cases() { } } - let cases: &[( - &str, - fn(&[u64], &[u64]) -> math_cuda::Result>, - fn(&u64, &u64) -> u64, - )] = &[ + type GpuOp = fn(&[u64], &[u64]) -> math_cuda::Result>; + type CpuOp = fn(&u64, &u64) -> u64; + let cases: &[(&str, GpuOp, CpuOp)] = &[ ("gl_add", math_cuda::gl_add_u64, GoldilocksField::add), ("gl_sub", math_cuda::gl_sub_u64, GoldilocksField::sub), ("gl_mul", math_cuda::gl_mul_u64, GoldilocksField::mul), diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs index 75ea80867..facd2d861 100644 --- a/crypto/math-cuda/tests/lde.rs +++ b/crypto/math-cuda/tests/lde.rs @@ -52,7 +52,7 @@ fn cpu_lde(evals: &[u64], blowup_factor: usize, coset_offset: u64) -> Vec { } fn canon(xs: &[u64]) -> Vec { - xs.iter().map(|x| GoldilocksField::canonical(x)).collect() + xs.iter().map(GoldilocksField::canonical).collect() } fn assert_lde_match(log_n: u64, blowup_factor: usize, seed: u64) { diff --git a/crypto/math-cuda/tests/lde_batch.rs b/crypto/math-cuda/tests/lde_batch.rs index 153e5d3e5..b7120ca28 100644 --- a/crypto/math-cuda/tests/lde_batch.rs +++ b/crypto/math-cuda/tests/lde_batch.rs @@ -41,7 +41,7 @@ fn cpu_lde_one( } fn canon(xs: &[u64]) -> Vec { - xs.iter().map(|x| GoldilocksField::canonical(x)).collect() + xs.iter().map(GoldilocksField::canonical).collect() } fn assert_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { diff --git a/crypto/math-cuda/tests/lde_batch_ext3.rs b/crypto/math-cuda/tests/lde_batch_ext3.rs index ff19f6c56..c1156edaa 100644 --- a/crypto/math-cuda/tests/lde_batch_ext3.rs +++ b/crypto/math-cuda/tests/lde_batch_ext3.rs @@ -51,7 +51,7 @@ fn u64s_to_ext3(raw: &[u64]) -> Vec { let mut out = Vec::with_capacity(raw.len() / 3); for i in 0..raw.len() / 3 { out.push(Fp3::new([ - Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3]), Fp::from_raw(raw[i * 3 + 1]), Fp::from_raw(raw[i * 3 + 2]), ])); @@ -75,7 +75,7 @@ fn cpu_lde_one_ext3( } fn canon(xs: &[u64]) -> Vec { - xs.iter().map(|x| GoldilocksField::canonical(x)).collect() + xs.iter().map(GoldilocksField::canonical).collect() } fn assert_ext3_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs index b6697b82d..f3689cf94 100644 --- a/crypto/math-cuda/tests/ntt.rs +++ b/crypto/math-cuda/tests/ntt.rs @@ -21,7 +21,7 @@ fn cpu_fft(coeffs: &[u64]) -> Vec { } fn canonicalize(xs: &[u64]) -> Vec { - xs.iter().map(|x| GoldilocksField::canonical(x)).collect() + xs.iter().map(GoldilocksField::canonical).collect() } fn assert_ntt_match(log_n: u64, seed: u64) { From 2ceb3b0ef435f63702c1c4891f9116bac7e32a3f Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 May 2026 17:15:10 -0300 Subject: [PATCH 04/16] gpu 2nd part --- crypto/math-cuda/build.rs | 1 + crypto/math-cuda/kernels/keccak.cu | 347 +++++++ crypto/math-cuda/src/device.rs | 27 + crypto/math-cuda/src/lde.rs | 1189 +++++++++++++++++++++++ crypto/math-cuda/src/lib.rs | 1 + crypto/math-cuda/src/merkle.rs | 415 ++++++++ crypto/math-cuda/tests/keccak_leaves.rs | 140 +++ crypto/math-cuda/tests/merkle_tree.rs | 92 ++ 8 files changed, 2212 insertions(+) create mode 100644 crypto/math-cuda/kernels/keccak.cu create mode 100644 crypto/math-cuda/src/merkle.rs create mode 100644 crypto/math-cuda/tests/keccak_leaves.rs create mode 100644 crypto/math-cuda/tests/merkle_tree.rs diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index 43edaa31c..7c417fb9c 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -75,4 +75,5 @@ fn main() { compile_ptx("arith.cu", "arith.ptx", have_nvcc); compile_ptx("ntt.cu", "ntt.ptx", have_nvcc); + compile_ptx("keccak.cu", "keccak.ptx", have_nvcc); } diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu new file mode 100644 index 000000000..68ddce3b4 --- /dev/null +++ b/crypto/math-cuda/kernels/keccak.cu @@ -0,0 +1,347 @@ +// CUDA Keccak-256 (original Keccak, NOT SHA3-256 — uses 0x01 padding delimiter). +// +// Used by the lambda-vm prover's Merkle commit: +// leaf = Keccak-256(concat(col_0[br_idx].to_be_bytes(), col_1[br_idx].to_be_bytes(), …)) +// where `br_idx = bit_reverse(row_idx, log_num_rows)` and each element is +// written in BIG-ENDIAN canonical form (per `FieldElement::write_bytes_be`). +// +// Keccak state is 5x5 lanes of u64, interpreted little-endian. Rate = 136 B +// (17 lanes) for 256-bit output, capacity = 64 B (8 lanes). +// +// Since every input byte is u64-aligned (each field element is 8 or 24 bytes), +// we can absorb lane-by-lane instead of byte-by-byte. Canonicalise + byte-swap +// each u64 on read to turn a BE-serialised element into its LE-interpreted +// lane value. + +#include +#include "goldilocks.cuh" + +__device__ __constant__ uint64_t KECCAK_RC[24] = { + 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, + 0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL, + 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL, + 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, + 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, + 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL, + 0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL, + 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, +}; + +// Rotation offsets indexed by lane position x + 5*y. Standard Keccak rho. +__device__ __constant__ uint32_t KECCAK_RHO_OFFSETS[25] = { + 0, 1, 62, 28, 27, // y=0: x=0..4 + 36, 44, 6, 55, 20, // y=1 + 3, 10, 43, 25, 39, // y=2 + 41, 45, 15, 21, 8, // y=3 + 18, 2, 61, 56, 14, // y=4 +}; + +__device__ __forceinline__ uint64_t rotl64(uint64_t x, uint32_t n) { + return (n == 0) ? x : ((x << n) | (x >> (64 - n))); +} + +__device__ __forceinline__ uint64_t bswap64(uint64_t x) { + // Reverse byte order: turns a BE-serialised u64 into its LE-read lane. + x = ((x & 0x00ff00ff00ff00ffULL) << 8) | ((x & 0xff00ff00ff00ff00ULL) >> 8); + x = ((x & 0x0000ffff0000ffffULL) << 16) | ((x & 0xffff0000ffff0000ULL) >> 16); + return (x << 32) | (x >> 32); +} + +__device__ __forceinline__ void keccak_f1600(uint64_t st[25]) { + uint64_t C[5], D[5], B[25]; + #pragma unroll + for (int r = 0; r < 24; ++r) { + // Theta + #pragma unroll + for (int x = 0; x < 5; ++x) { + C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20]; + } + #pragma unroll + for (int x = 0; x < 5; ++x) { + D[x] = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1); + } + #pragma unroll + for (int y = 0; y < 5; ++y) { + #pragma unroll + for (int x = 0; x < 5; ++x) { + st[x + 5 * y] ^= D[x]; + } + } + + // Rho + Pi: B[pi(x,y)] = rotl(st[x,y], rho(x,y)) + // pi: (x', y') = (y, (2x + 3y) mod 5) + #pragma unroll + for (int y = 0; y < 5; ++y) { + #pragma unroll + for (int x = 0; x < 5; ++x) { + int nx = y; + int ny = (2 * x + 3 * y) % 5; + B[nx + 5 * ny] = rotl64(st[x + 5 * y], KECCAK_RHO_OFFSETS[x + 5 * y]); + } + } + + // Chi + #pragma unroll + for (int y = 0; y < 5; ++y) { + #pragma unroll + for (int x = 0; x < 5; ++x) { + st[x + 5 * y] = + B[x + 5 * y] ^ ((~B[((x + 1) % 5) + 5 * y]) & B[((x + 2) % 5) + 5 * y]); + } + } + + // Iota + st[0] ^= KECCAK_RC[r]; + } +} + +// --------------------------------------------------------------------------- +// Helper: absorb one 8-byte lane (already in lane form — i.e. LE interpretation +// of the BE-serialised u64) into the sponge at `rate_pos` (in bytes). Permutes +// when a full 136-byte block has been absorbed. +// --------------------------------------------------------------------------- +__device__ __forceinline__ void absorb_lane(uint64_t st[25], + uint32_t &rate_pos, + uint64_t lane) { + st[rate_pos / 8] ^= lane; + rate_pos += 8; + if (rate_pos == 136) { + keccak_f1600(st); + rate_pos = 0; + } +} + +// --------------------------------------------------------------------------- +// After all data lanes absorbed, apply Keccak (pre-SHA-3) padding: a single +// 0x01 byte at the current position, then bit 0x80 on the last rate byte +// (byte 135 = last byte of lane 16). Then permute and squeeze 32 bytes from +// the first four lanes in LE order. +// --------------------------------------------------------------------------- +__device__ __forceinline__ void finalize_keccak256(uint64_t st[25], + uint32_t rate_pos, + uint8_t *out32) { + // 0x01 at rate_pos + st[rate_pos / 8] ^= ((uint64_t)0x01) << ((rate_pos & 7) * 8); + // 0x80 at byte 135 (last byte of lane 16) + st[16] ^= ((uint64_t)0x80) << 56; + keccak_f1600(st); + + // Squeeze 32 bytes: 4 lanes, each LE-serialised. + #pragma unroll + for (int i = 0; i < 4; ++i) { + uint64_t lane = st[i]; + #pragma unroll + for (int b = 0; b < 8; ++b) { + out32[i * 8 + b] = (uint8_t)((lane >> (b * 8)) & 0xff); + } + } +} + +// --------------------------------------------------------------------------- +// Goldilocks BASE-FIELD leaf hashing. +// +// For output row `row_idx` (natural order), the leaf hashes the canonical BE +// byte representation of `columns[c][bit_reverse(row_idx, log_num_rows)]` for +// `c` in `[0, num_cols)`, concatenated in column order. Writes 32 bytes to +// `hashed_leaves_out[row_idx * 32 ..]`. +// +// `columns_base_ptr` points to a `num_cols * col_stride * u64` buffer; column +// `c` is the contiguous slab `[c*col_stride .. c*col_stride + num_rows]`. The +// remaining `col_stride - num_rows` entries (if any) are ignored. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak256_leaves_base_batched( + const uint64_t *columns_base_ptr, + uint64_t col_stride, + uint64_t num_cols, + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *hashed_leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_rows) return; + + // Bit-reverse the row index so we read columns at `br` but write the + // hashed leaf at `tid` — matching the CPU `commit_columns_bit_reversed`. + uint64_t br = __brevll(tid) >> (64 - log_num_rows); + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + for (uint64_t c = 0; c < num_cols; ++c) { + uint64_t v = columns_base_ptr[c * col_stride + br]; + // Canonicalise to match `canonical_u64().to_be_bytes()` on host. + uint64_t canon = goldilocks::canonical(v); + // The on-disk leaf bytes are canon.to_be_bytes(); Keccak reads those + // as a LE lane, which equals bswap64(canon). + uint64_t lane = bswap64(canon); + absorb_lane(st, rate_pos, lane); + } + + finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// Goldilocks EXT3 leaf hashing (3 base-field components per ext3 element). +// +// Components live in three separate base-field slabs (our de-interleaved +// layout). Column `c` component `k` is at `columns_base_ptr[(c*3 + k)*col_stride +// + br]`. Per-element BE bytes are `[comp0, comp1, comp2]` each 8 BE bytes +// (matches `FieldElement::::write_bytes_be`). +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak256_leaves_ext3_batched( + const uint64_t *columns_base_ptr, + uint64_t col_stride, + uint64_t num_cols, // number of ext3 columns (NOT slabs) + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *hashed_leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_rows) return; + uint64_t br = __brevll(tid) >> (64 - log_num_rows); + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + for (uint64_t c = 0; c < num_cols; ++c) { + #pragma unroll + for (int k = 0; k < 3; ++k) { + uint64_t v = columns_base_ptr[(c * 3 + (uint64_t)k) * col_stride + br]; + uint64_t canon = goldilocks::canonical(v); + uint64_t lane = bswap64(canon); + absorb_lane(st, rate_pos, lane); + } + } + + finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// R2 composition-polynomial leaf hashing. +// +// Each leaf hashes `2 * num_parts` ext3 values taken from bit-reversed rows +// `br_0 = reverse_index(2*leaf_idx)` and `br_1 = reverse_index(2*leaf_idx+1)` +// across all `num_parts` parts, in (br_0 row: part 0..K-1) then (br_1 row: +// part 0..K-1) order. Each ext3 value is 3 base components × 8 BE bytes. +// +// Columns arrive in the de-interleaved 3-slab layout: part `p` component +// `k` is at `parts_base_ptr[(p*3 + k) * col_stride + row]`. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak_comp_poly_leaves_ext3( + const uint64_t *parts_base_ptr, + uint64_t col_stride, + uint64_t num_parts, + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t num_leaves = num_rows >> 1; + if (tid >= num_leaves) return; + + uint64_t br_0 = __brevll(2 * tid) >> (64 - log_num_rows); + uint64_t br_1 = __brevll(2 * tid + 1) >> (64 - log_num_rows); + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + // First row (br_0): part 0..K-1 × 3 components each. + for (uint64_t p = 0; p < num_parts; ++p) { + #pragma unroll + for (int k = 0; k < 3; ++k) { + uint64_t v = parts_base_ptr[(p * 3 + (uint64_t)k) * col_stride + br_0]; + uint64_t canon = goldilocks::canonical(v); + absorb_lane(st, rate_pos, bswap64(canon)); + } + } + // Second row (br_1). + for (uint64_t p = 0; p < num_parts; ++p) { + #pragma unroll + for (int k = 0; k < 3; ++k) { + uint64_t v = parts_base_ptr[(p * 3 + (uint64_t)k) * col_stride + br_1]; + uint64_t canon = goldilocks::canonical(v); + absorb_lane(st, rate_pos, bswap64(canon)); + } + } + + finalize_keccak256(st, rate_pos, leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// FRI layer leaf hashing. +// +// Each leaf hashes 2 consecutive ext3 values: Keccak256 over +// evals[2j].to_bytes_be() ++ evals[2j+1].to_bytes_be() +// = 48 BE bytes = 6 u64 BE lanes. No bit reversal; no column slab layout — +// the input is a single interleaved ext3 eval vector `[a0,a1,a2,b0,b1,b2,...]`. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak_fri_leaves_ext3( + const uint64_t *evals_interleaved, // 3 * num_evals u64s (ext3 interleaved) + uint64_t num_leaves, // = num_evals / 2 + uint8_t *leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_leaves) return; + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + uint32_t rate_pos = 0; + + const uint64_t *left = evals_interleaved + 2 * tid * 3; // 3 u64s + const uint64_t *right = left + 3; + #pragma unroll + for (int i = 0; i < 3; ++i) { + uint64_t canon = goldilocks::canonical(left[i]); + absorb_lane(st, rate_pos, bswap64(canon)); + } + #pragma unroll + for (int i = 0; i < 3; ++i) { + uint64_t canon = goldilocks::canonical(right[i]); + absorb_lane(st, rate_pos, bswap64(canon)); + } + + finalize_keccak256(st, rate_pos, leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// Merkle inner-tree pair hash: one level of the inner Merkle tree. +// +// `nodes` is the full Merkle node buffer (length `2*leaves_len - 1`, each +// element 32 bytes). `parent_begin` is the node-index offset of the first +// parent slot in this level; children live at `parent_begin + n_pairs`. +// The layout mirrors `crypto/crypto/src/merkle_tree/merkle.rs`: +// +// children: nodes[parent_begin + n_pairs .. parent_begin + 3 * n_pairs] +// parents: nodes[parent_begin .. parent_begin + n_pairs] +// +// Each thread hashes one child pair → one parent. Keccak-256 of the +// concatenation of two 32-byte siblings; identical to +// `FieldElementVectorBackend::hash_new_parent` on host. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak_merkle_level( + uint8_t *nodes, + uint64_t parent_begin, // node index (counted in 32-byte nodes) + uint64_t n_pairs) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_pairs) return; + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + const uint64_t *left = reinterpret_cast( + nodes + (parent_begin + n_pairs + 2 * tid) * 32); + #pragma unroll + for (int i = 0; i < 4; ++i) absorb_lane(st, rate_pos, left[i]); + + const uint64_t *right = reinterpret_cast( + nodes + (parent_begin + n_pairs + 2 * tid + 1) * 32); + #pragma unroll + for (int i = 0; i < 4; ++i) absorb_lane(st, rate_pos, right[i]); + + finalize_keccak256(st, rate_pos, nodes + (parent_begin + tid) * 32); +} diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 03f0b67ca..8e956eef3 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -91,6 +91,7 @@ impl Drop for PinnedStaging { const ARITH_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/arith.ptx")); const NTT_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/ntt.ptx")); +const KECCAK_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/keccak.ptx")); /// Number of CUDA streams in the pool. Larger pools let many rayon-parallel /// callers overlap on the GPU without serializing on stream ownership. The /// default stream is deliberately excluded because it synchronises with all @@ -106,6 +107,11 @@ pub struct Backend { /// buffers 32×-inflated memory use and multiplied the one-time pinning /// cost for every first use of a new table size). pinned_staging: Mutex, + /// Separate pinned staging for Merkle leaf hashes. Sized `num_rows * 32` + /// bytes; lives alongside the LDE staging so the GPU→host D2H for + /// hashed leaves runs at full PCIe line-rate instead of the pageable + /// ~1.3 GB/s path that would otherwise eat ~100 ms per main-trace commit. + pinned_hashes: Mutex, util_stream: Arc, next: AtomicUsize, @@ -130,6 +136,13 @@ pub struct Backend { pub pointwise_mul_batched: CudaFunction, pub scalar_mul_batched: CudaFunction, + // keccak.ptx + pub keccak256_leaves_base_batched: CudaFunction, + pub keccak256_leaves_ext3_batched: CudaFunction, + pub keccak_comp_poly_leaves_ext3: CudaFunction, + pub keccak_fri_leaves_ext3: CudaFunction, + pub keccak_merkle_level: CudaFunction, + // Twiddle caches keyed by log_n. fwd_twiddles: Mutex>>>>, inv_twiddles: Mutex>>>>, @@ -146,12 +159,14 @@ impl Backend { let arith = ctx.load_module(Ptx::from_src(ARITH_PTX))?; let ntt = ctx.load_module(Ptx::from_src(NTT_PTX))?; + let keccak = ctx.load_module(Ptx::from_src(KECCAK_PTX))?; let mut streams = Vec::with_capacity(STREAM_POOL_SIZE); for _ in 0..STREAM_POOL_SIZE { streams.push(ctx.new_stream()?); } let pinned_staging = Mutex::new(PinnedStaging::empty()); + let pinned_hashes = Mutex::new(PinnedStaging::empty()); // Separate "utility" stream for twiddle uploads and other bookkeeping; // not part of the pool that callers rotate through. let util_stream = ctx.new_stream()?; @@ -178,11 +193,17 @@ impl Backend { ntt_dit_8_levels_batched: ntt.load_function("ntt_dit_8_levels_batched")?, pointwise_mul_batched: ntt.load_function("pointwise_mul_batched")?, scalar_mul_batched: ntt.load_function("scalar_mul_batched")?, + keccak256_leaves_base_batched: keccak.load_function("keccak256_leaves_base_batched")?, + keccak256_leaves_ext3_batched: keccak.load_function("keccak256_leaves_ext3_batched")?, + keccak_comp_poly_leaves_ext3: keccak.load_function("keccak_comp_poly_leaves_ext3")?, + keccak_fri_leaves_ext3: keccak.load_function("keccak_fri_leaves_ext3")?, + keccak_merkle_level: keccak.load_function("keccak_merkle_level")?, fwd_twiddles: Mutex::new(vec![None; max_log]), inv_twiddles: Mutex::new(vec![None; max_log]), ctx, streams, pinned_staging, + pinned_hashes, util_stream, next: AtomicUsize::new(0), }) @@ -201,6 +222,12 @@ impl Backend { &self.pinned_staging } + /// Separate pinned staging for Merkle leaf hash output. Sized in u64 + /// units; caller should reserve `(num_rows * 32 + 7) / 8` u64s. + pub fn pinned_hashes(&self) -> &Mutex { + &self.pinned_hashes + } + pub fn fwd_twiddles_for(&self, log_n: u64) -> Result>> { self.cached_twiddles(log_n, true) } diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 22d62ca3e..ccf5abb1d 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -16,6 +16,7 @@ use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; use crate::Result; use crate::device::backend; +use crate::merkle::{launch_keccak_base, launch_keccak_ext3}; use crate::ntt::run_ntt_body; /// Handle to a base-field LDE kept live on device after R1 commit. @@ -508,6 +509,963 @@ pub fn coset_lde_batch_base_into( Ok(()) } +/// Variant of `coset_lde_batch_base_into` that also emits the Keccak-256 +/// Merkle leaf hashes from the LDE output — all on GPU, no second H2D of +/// the LDE data. Leaves are computed reading columns at bit-reversed rows +/// (matching `commit_columns_bit_reversed` on the CPU side). +/// +/// `hashed_leaves_out` must be `lde_size * 32` bytes (one 32-byte digest +/// per output row, in natural row order). +pub fn coset_lde_batch_base_into_with_leaf_hash( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + hashed_leaves_out: &mut [u8], +) -> Result<()> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + let n = columns[0].len(); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), lde_size); + } + assert_eq!(hashed_leaves_out.len(), lde_size * 32); + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + + use rayon::prelude::*; + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: disjoint regions per c, outer staging lock held. + let dst = + unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; + dst.copy_from_slice(col); + }); + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + // pointwise coset scale + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + // forward NTT on full LDE slab + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + // Keccak-256 leaf hashing directly on the device LDE buffer. + let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; + launch_keccak_base( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut hashes_dev, + )?; + + // D2H the LDE into the pinned LDE staging and the hashes into a + // dedicated pinned hash staging, in parallel on the same stream. Both + // at pinned PCIe line-rate — pageable D2H of the 128 MB hash buffer + // would otherwise cost ~100 ms per main-trace commit. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + let hashes_u64_len = (lde_size * 32).div_ceil(8); + let hashes_staging_slot = be.pinned_hashes(); + let mut hashes_staging = hashes_staging_slot.lock().unwrap(); + hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; + let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; + // `memcpy_dtoh` needs a byte slice. Reinterpret the u64 pinned buffer + // as bytes — same allocation, just typed differently. + let hashes_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(hashes_pinned.as_mut_ptr() as *mut u8, lde_size * 32) + }; + stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; + stream.synchronize()?; + + // Copy pinned → caller outputs in parallel with the hash memcpy. + let pinned_ptr = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) + }; + dst.copy_from_slice(src); + }); + // Rayon-parallel memcpy of 128 MB from pinned → caller. Single-threaded + // `copy_from_slice` faults virgin pageable pages one at a time; the + // mm_struct rwsem serialises them into ~100 ms at 1M-fib scale. Chunk + // the slice so ~N cores pre-fault+write in parallel. + const CHUNK: usize = 64 * 1024; // 64 KiB ≈ 16 pages per chunk + let pinned_hash_ptr = hashes_pinned_bytes.as_ptr() as usize; + hashed_leaves_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_hash_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(hashes_staging); + drop(staging); + Ok(()) +} + +/// Like `coset_lde_batch_base_into_with_leaf_hash`, but also builds the full +/// Merkle tree on device and returns the `2*lde_size - 1` node buffer back +/// to the caller in `merkle_nodes_out` (byte length `(2*lde_size - 1) * 32`). +/// +/// The leaf hashes are never exposed to the caller — they stay on device and +/// feed straight into the pair-hash tree kernel, avoiding the +/// pinned→pageable→pinned round-trip that the separate-step GPU tree build +/// would pay. +pub fn coset_lde_batch_base_into_with_merkle_tree( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + false, + ) + .map(|_| ()) +} + +/// Fused LDE + leaf-hash + Merkle tree build. If `keep_device_buf` is true, +/// returns an `Arc>` wrapping the LDE device buffer so callers +/// (R2–R4 GPU paths) can reuse the LDE without a re-H2D. +pub fn coset_lde_batch_base_into_with_merkle_tree_keep( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result { + let opt = coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + true, + )?; + let handle = opt.expect("keep_device_buf=true must return Some"); + Ok(handle) +} + +fn coset_lde_batch_base_into_with_merkle_tree_inner( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], + keep_device_buf: bool, +) -> Result> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + let n = columns[0].len(); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), lde_size); + } + let total_nodes = 2 * lde_size - 1; + assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + + use rayon::prelude::*; + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let dst = + unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; + dst.copy_from_slice(col); + }); + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + // forward NTT at LDE size + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + // Allocate the full node buffer; leaves occupy the tail slab, inner + // nodes are written by the pair-hash level kernel below. `alloc` (not + // `alloc_zeros`) is safe because every byte is written before it is + // read: leaf kernel fills the tail, tree kernel fills the head. + // + // The leaf kernel writes to `nodes_dev` starting at byte offset + // `(lde_size - 1) * 32`; we pass the base pointer as-is because the + // kernel indexes linearly from `hashed_leaves_out[row_idx * 32]`, so we + // build an offset device slice and feed that to the launch. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (lde_size - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); + let log_num_rows_leaves = lde_size.trailing_zeros() as u64; + let num_cols_u64 = m as u64; + let grid = (lde_size as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak256_leaves_base_batched) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_cols_u64) + .arg(&lde_u64) + .arg(&log_num_rows_leaves) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree levels — mirror the CPU `build(nodes, leaves_len)` scan. + { + let mut level_begin: u64 = (lde_size - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // D2H the LDE and the tree nodes via pinned staging. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + + let tree_u64_len = (total_nodes * 32).div_ceil(8); + let tree_staging_slot = be.pinned_hashes(); + let mut tree_staging = tree_staging_slot.lock().unwrap(); + tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; + let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; + let tree_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, total_nodes * 32) + }; + stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; + stream.synchronize()?; + + // Parallel memcpy pinned → caller. + let pinned_ptr = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) + }; + dst.copy_from_slice(src); + }); + const CHUNK: usize = 64 * 1024; + let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; + merkle_nodes_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(tree_staging); + drop(staging); + + if keep_device_buf { + Ok(Some(GpuLdeBase { + buf: Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Ext3 variant of `coset_lde_batch_base_into_with_leaf_hash`: run an LDE +/// over ext3 columns AND emit Keccak-256 Merkle leaves, all in one on-device +/// pipeline. +pub fn coset_lde_batch_ext3_into_with_leaf_hash( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + hashed_leaves_out: &mut [u8], +) -> Result<()> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + assert_eq!(hashed_leaves_out.len(), lde_size * 32); + if n == 0 { + return Ok(()); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Keccak-256 on the de-interleaved device buffer (3M base slabs). + let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; + launch_keccak_ext3( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut hashes_dev, + )?; + + // D2H LDE (mb * lde_size u64) and hashes (lde_size * 32 bytes). + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + let hashes_u64_len = (lde_size * 32).div_ceil(8); + let hashes_staging_slot = be.pinned_hashes(); + let mut hashes_staging = hashes_staging_slot.lock().unwrap(); + hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; + let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; + let hashes_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(hashes_pinned.as_mut_ptr() as *mut u8, lde_size * 32) + }; + stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; + stream.synchronize()?; + + // Re-interleave pinned → caller ext3 outputs, parallel. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + + // Parallel memcpy of pinned hashes → caller. + const CHUNK: usize = 64 * 1024; + let hash_src_ptr = hashes_pinned_bytes.as_ptr() as usize; + hashed_leaves_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((hash_src_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(hashes_staging); + drop(staging); + Ok(()) +} + +/// Ext3 variant of the fused `coset_lde_batch_base_into_with_merkle_tree`. +/// LDE + leaf hashing + inner-tree build, all on device; D2Hs only the LDE +/// evaluations and the full `2*lde_size - 1` node buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + false, + ) + .map(|_| ()) +} + +/// Ext3 variant of [`coset_lde_batch_base_into_with_merkle_tree_keep`] — +/// returns an `Arc>` handle to the de-interleaved LDE device +/// buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree_keep( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result { + let opt = coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) +} + +fn coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], + keep_device_buf: bool, +) -> Result> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + let total_nodes = 2 * lde_size - 1; + assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + if n == 0 { + return Ok(None); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Allocate full tree buffer; leaf kernel writes to the tail slab. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (lde_size - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); + let log_num_rows_leaves = lde_size.trailing_zeros() as u64; + let num_cols_u64 = m as u64; + let grid = (lde_size as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak256_leaves_ext3_batched) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_cols_u64) + .arg(&lde_u64) + .arg(&log_num_rows_leaves) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree levels. + { + let mut level_begin: u64 = (lde_size - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // D2H LDE (mb * lde_size u64) and tree nodes. + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + let tree_u64_len = (total_nodes * 32).div_ceil(8); + let tree_staging_slot = be.pinned_hashes(); + let mut tree_staging = tree_staging_slot.lock().unwrap(); + tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; + let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; + let tree_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, total_nodes * 32) + }; + stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; + stream.synchronize()?; + + // Re-interleave pinned → caller ext3 outputs. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + + const CHUNK: usize = 64 * 1024; + let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; + merkle_nodes_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(tree_staging); + drop(staging); + + if keep_device_buf { + Ok(Some(GpuLdeExt3 { + buf: Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + /// Batched ext3 polynomial → coset evaluation. /// /// Input: M ext3 columns of `n` coefficients each (interleaved, 3n u64). @@ -708,6 +1666,237 @@ fn evaluate_poly_coset_batch_ext3_into_inner( } } +/// Fused variant of [`evaluate_poly_coset_batch_ext3_into`]: in addition to +/// the LDE output, builds the R2 composition-polynomial Merkle tree on device +/// (row-pair Keccak leaves at bit-reversed indices + pair-hash inner tree). +/// +/// `merkle_nodes_out` must have byte length `(2 * lde_size - 1) * 32`. +/// Requires `lde_size >= 2`. +pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + if coefs.is_empty() { + return Ok(()); + } + let m = coefs.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in coefs.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + assert!(lde_size >= 2); + let total_nodes = 2 * lde_size - 1; + assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + if n == 0 { + return Ok(()); + } + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + coefs.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Build the row-pair Merkle tree on device. + // + // Row-pair commit: each leaf hashes 2 rows (bit-reversed indices) → + // num_leaves = lde_size / 2. Tree size: 2*num_leaves - 1 = lde_size - 1. + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let log_num_rows = log_lde; + let num_parts_u64 = m as u64; + let grid = (num_leaves as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&lde_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // D2H LDE and tree. + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + let tree_u64_len = (tight_total_nodes * 32).div_ceil(8); + let tree_staging_slot = be.pinned_hashes(); + let mut tree_staging = tree_staging_slot.lock().unwrap(); + tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; + let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; + let tree_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, tight_total_nodes * 32) + }; + stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; + stream.synchronize()?; + + // Re-interleave pinned → caller ext3 outputs. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + + // Copy pinned tree → caller nodes_out. `merkle_nodes_out.len() == + // total_nodes * 32` is oversized relative to our tight tree; we write + // only the first `tight_total_nodes * 32` bytes and the caller trims. + // Expose the tight byte count via the slice length so the caller can + // construct the MerkleTree with the right node count. + assert!(merkle_nodes_out.len() >= tight_total_nodes * 32); + const CHUNK: usize = 64 * 1024; + let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; + merkle_nodes_out[..tight_total_nodes * 32] + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(tree_staging); + drop(staging); + Ok(()) +} /// Batched coset LDE for Goldilocks **cubic extension** columns. /// /// A degree-3 extension element is `(a, b, c)` in memory (three contiguous diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs index 821f5bd3a..d5d0c9fbc 100644 --- a/crypto/math-cuda/src/lib.rs +++ b/crypto/math-cuda/src/lib.rs @@ -6,6 +6,7 @@ pub mod device; pub mod lde; +pub mod merkle; pub mod ntt; use cudarc::driver::{LaunchConfig, PushKernelArg}; diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs new file mode 100644 index 000000000..0f80206a5 --- /dev/null +++ b/crypto/math-cuda/src/merkle.rs @@ -0,0 +1,415 @@ +//! GPU Keccak-256 leaf hashing for Merkle commits. +//! +//! Matches `FieldElementVectorBackend::hash_data` in +//! `crypto/crypto/src/merkle_tree/backends/field_element_vector.rs`, combined +//! with the `reverse_index` row read pattern used in +//! `commit_columns_bit_reversed` at `crypto/stark/src/prover.rs:368`. +//! +//! Caller supplies base-field column slabs already laid out as +//! `[col * col_stride + row]` (the same layout `coset_lde_batch_base_into` +//! writes to the pinned staging buffer). The kernel bit-reverses `row_idx`, +//! reads each column's canonical u64 at that row, byte-swaps it into a +//! Keccak lane, absorbs lane-by-lane, and squeezes 32 bytes per leaf. +//! +//! For ext3 columns the layout is `[col*3*col_stride + k*col_stride + row]` +//! — three base slabs per ext3 column — and the kernel reads three u64s per +//! column in component order 0,1,2 to match `FieldElement::::write_bytes_be`. + +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; + +/// Run GPU Keccak-256 leaf hashing on a base-field column buffer. +/// +/// `columns` must hold `num_cols * col_stride` u64s with column `c`'s data +/// at `[c*col_stride .. c*col_stride + num_rows]`. Returns `num_rows * 32` +/// hash bytes in natural (non-bit-reversed) row order. +pub fn keccak_leaves_base( + columns: &[u64], + col_stride: usize, + num_cols: usize, + num_rows: usize, +) -> Result> { + assert!(num_rows.is_power_of_two()); + assert!(columns.len() >= num_cols * col_stride); + let be = backend(); + let stream = be.next_stream(); + let cols_dev = stream.clone_htod(&columns[..num_cols * col_stride])?; + let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; + launch_keccak_base( + stream.as_ref(), + &cols_dev, + col_stride as u64, + num_cols as u64, + num_rows as u64, + &mut out_dev, + )?; + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Ext3 variant — columns interleaved as three base slabs per ext3 column. +/// `columns.len() >= num_cols * 3 * col_stride`. +pub fn keccak_leaves_ext3( + columns: &[u64], + col_stride: usize, + num_cols: usize, + num_rows: usize, +) -> Result> { + assert!(num_rows.is_power_of_two()); + assert!(columns.len() >= num_cols * 3 * col_stride); + let be = backend(); + let stream = be.next_stream(); + let cols_dev = stream.clone_htod(&columns[..num_cols * 3 * col_stride])?; + let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; + launch_keccak_ext3( + stream.as_ref(), + &cols_dev, + col_stride as u64, + num_cols as u64, + num_rows as u64, + &mut out_dev, + )?; + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Block size for Keccak kernels. Per-thread register footprint is ~60 regs +/// (25-lane state + auxiliaries); the default 256 threads/block pushes the +/// block register file past the hardware limit on sm_120 (Blackwell). 128 +/// keeps us inside the budget with some head-room. +const KECCAK_BLOCK_DIM: u32 = 128; + +fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { + let grid = (num_rows as u32).div_ceil(KECCAK_BLOCK_DIM); + LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (KECCAK_BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + } +} + +pub(crate) fn launch_keccak_base( + stream: &CudaStream, + cols_dev: &CudaSlice, + col_stride: u64, + num_cols: u64, + num_rows: u64, + out_dev: &mut CudaSlice, +) -> Result<()> { + let be = backend(); + let log_num_rows = num_rows.trailing_zeros() as u64; + let cfg = keccak_launch_cfg(num_rows); + unsafe { + stream + .launch_builder(&be.keccak256_leaves_base_batched) + .arg(cols_dev) + .arg(&col_stride) + .arg(&num_cols) + .arg(&num_rows) + .arg(&log_num_rows) + .arg(out_dev) + .launch(cfg)?; + } + Ok(()) +} + +/// Given `hashed_leaves` of length `leaves_len * 32`, build the full Merkle +/// tree on device and return the complete node buffer `(2*leaves_len - 1) * +/// 32` bytes in the standard layout: +/// +/// `nodes[0..leaves_len - 1]` are inner nodes (root at index 0), and +/// `nodes[leaves_len - 1..]` are the leaves themselves. +/// +/// Matches the CPU `crypto/crypto/src/merkle_tree/merkle.rs` construction so +/// the resulting `nodes` Vec plugs straight into `MerkleTree { root, nodes }` +/// for downstream proof generation. +/// +/// `leaves_len` must be a power of two and ≥ 2. +pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { + assert!(hashed_leaves.len().is_multiple_of(32)); + let leaves_len = hashed_leaves.len() / 32; + assert!(leaves_len >= 2, "tree needs at least two leaves"); + assert!( + leaves_len.is_power_of_two(), + "leaves_len must be a power of two" + ); + + let total_nodes = 2 * leaves_len - 1; + let be = backend(); + let stream = be.next_stream(); + + // Allocate the full node buffer without zero-fill — we overwrite the + // leaf half via H2D immediately, and every inner node is written by the + // pair-hash kernel below. + // SAFETY: every byte is written before it is read: leaves are filled by + // the H2D below; inner nodes are filled by the level loop that follows. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (leaves_len - 1) * 32; + // SAFETY: target slice `nodes_dev[leaves_offset_bytes..]` has exactly + // `leaves_len * 32 == hashed_leaves.len()` bytes capacity. + { + let mut slice = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + hashed_leaves.len()); + stream.memcpy_htod(hashed_leaves, &mut slice)?; + } + + // Build level by level. The CPU `build(nodes, leaves_len)` starts with + // level_begin_index = leaves_len - 1 + // level_end_index = 2 * level_begin_index + // and each iteration computes: + // new_level_begin_index = level_begin_index / 2 + // new_level_length = level_begin_index - new_level_begin_index + // The parents occupy [new_level_begin_index, level_begin_index); the + // children occupy [level_begin_index, level_end_index + 1). + let mut level_begin: u64 = (leaves_len - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + + let cfg = keccak_launch_cfg(n_pairs); + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + + let out = stream.clone_dtoh(&nodes_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Row-pair Keccak leaf + Merkle tree build for R2 composition-polynomial +/// commit. `parts_interleaved` is `num_parts` slices, each holding an ext3 +/// LDE column interleaved as `[a0,a1,a2, b0,b1,b2, ...]` of length `3*lde_size`. +/// +/// Returns `(2*(lde_size/2) - 1) * 32` bytes of tree nodes in the standard +/// layout (root at byte offset 0, leaves in the tail). +pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Result> { + assert!(!parts_interleaved.is_empty()); + let m = parts_interleaved.len(); + let ext3_elems = parts_interleaved[0].len() / 3; + assert_eq!( + parts_interleaved[0].len(), + 3 * ext3_elems, + "ext3 buffer length must be 3 * lde_size" + ); + for p in parts_interleaved.iter() { + assert_eq!(p.len(), 3 * ext3_elems); + } + let lde_size = ext3_elems; + assert!(lde_size.is_power_of_two() && lde_size >= 2); + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + // Stage: de-interleave each part into 3 base slabs in pinned memory. + let mb = 3 * m; + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + parts_interleaved + .par_iter() + .enumerate() + .for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + // H2D the de-interleaved parts. + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + stream.memcpy_htod(&pinned[..mb * lde_size], &mut buf)?; + + // Leaves into tail of a tight node buffer. + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let col_stride_u64 = lde_size as u64; + let num_parts_u64 = m as u64; + let num_rows_u64 = lde_size as u64; + let log_num_rows = lde_size.trailing_zeros() as u64; + let grid = (num_leaves as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&num_rows_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree. + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + let out = stream.clone_dtoh(&nodes_dev)?; + stream.synchronize()?; + drop(staging); + Ok(out) +} + +/// Build a FRI-layer Merkle tree on device from an interleaved ext3 eval +/// vector. Each leaf hashes two consecutive ext3 values; `num_leaves = +/// evals.len() / 6` (since each ext3 is 3 u64s). +/// +/// Returns the `(2*num_leaves - 1) * 32`-byte node buffer in standard layout. +pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { + assert!( + evals.len().is_multiple_of(6), + "evals must hold whole pair-leaves" + ); + let num_evals = evals.len() / 3; + let num_leaves = num_evals / 2; + assert!(num_leaves.is_power_of_two() && num_leaves >= 1); + let tight_total_nodes = 2 * num_leaves - 1; + if tight_total_nodes == 0 { + return Ok(Vec::new()); + } + + let be = backend(); + let stream = be.next_stream(); + + let evals_dev = stream.clone_htod(evals)?; + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + + // Leaf kernel: num_leaves threads, one leaf each. + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let num_leaves_u64 = num_leaves as u64; + let grid = (num_leaves as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_fri_leaves_ext3) + .arg(&evals_dev) + .arg(&num_leaves_u64) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree levels — identical to the R2 version. + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + let out = stream.clone_dtoh(&nodes_dev)?; + stream.synchronize()?; + Ok(out) +} + +pub(crate) fn launch_keccak_ext3( + stream: &CudaStream, + cols_dev: &CudaSlice, + col_stride: u64, + num_cols: u64, + num_rows: u64, + out_dev: &mut CudaSlice, +) -> Result<()> { + let be = backend(); + let log_num_rows = num_rows.trailing_zeros() as u64; + let cfg = keccak_launch_cfg(num_rows); + unsafe { + stream + .launch_builder(&be.keccak256_leaves_ext3_batched) + .arg(cols_dev) + .arg(&col_stride) + .arg(&num_cols) + .arg(&num_rows) + .arg(&log_num_rows) + .arg(out_dev) + .launch(cfg)?; + } + Ok(()) +} diff --git a/crypto/math-cuda/tests/keccak_leaves.rs b/crypto/math-cuda/tests/keccak_leaves.rs new file mode 100644 index 000000000..1e451b386 --- /dev/null +++ b/crypto/math-cuda/tests/keccak_leaves.rs @@ -0,0 +1,140 @@ +//! Parity: GPU Keccak-256 leaf hashes must match CPU +//! `FieldElementVectorBackend::::hash_data` applied to +//! bit-reversed rows (same pattern as `commit_columns_bit_reversed` in the +//! stark prover). + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::traits::ByteConversion; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::{Digest, Keccak256}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn reverse_index(i: u64, n: u64) -> u64 { + let log_n = n.trailing_zeros(); + i.reverse_bits() >> (64 - log_n) +} + +fn cpu_leaves_base(columns: &[Vec]) -> Vec<[u8; 32]> { + let num_rows = columns[0].len(); + let num_cols = columns.len(); + let byte_len = 8; + (0..num_rows) + .map(|row_idx| { + let br = reverse_index(row_idx as u64, num_rows as u64) as usize; + let mut buf = vec![0u8; num_cols * byte_len]; + for c in 0..num_cols { + columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); + } + let mut h = Keccak256::new(); + h.update(&buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out + }) + .collect() +} + +fn cpu_leaves_ext3(columns: &[Vec]) -> Vec<[u8; 32]> { + let num_rows = columns[0].len(); + let num_cols = columns.len(); + let byte_len = 24; + (0..num_rows) + .map(|row_idx| { + let br = reverse_index(row_idx as u64, num_rows as u64) as usize; + let mut buf = vec![0u8; num_cols * byte_len]; + for c in 0..num_cols { + columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); + } + let mut h = Keccak256::new(); + h.update(&buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out + }) + .collect() +} + +#[test] +fn keccak_leaves_base_matches_cpu() { + for log_n in [4u32, 6, 8, 10, 12] { + for num_cols in [1usize, 5, 17, 41] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(100 + log_n as u64 + num_cols as u64); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| Fp::from_raw(rng.r#gen::())).collect()) + .collect(); + + let cpu = cpu_leaves_base(&columns); + + // Flatten columns into a contiguous base slab layout matching + // `coset_lde_batch_base_into`'s pinned staging format: + // `[col * stride + row]`. Use stride = num_rows for compactness. + let mut flat = vec![0u64; num_cols * n]; + for (c, col) in columns.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + flat[c * n + r] = *e.value(); + } + } + let gpu = math_cuda::merkle::keccak_leaves_base(&flat, n, num_cols, n).unwrap(); + assert_eq!(gpu.len(), n * 32); + for i in 0..n { + assert_eq!( + &gpu[i * 32..(i + 1) * 32], + &cpu[i][..], + "base leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" + ); + } + } + } +} + +#[test] +fn keccak_leaves_ext3_matches_cpu() { + for log_n in [4u32, 6, 8, 10] { + for num_cols in [1usize, 3, 11, 20] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(200 + log_n as u64 + num_cols as u64); + let columns: Vec> = (0..num_cols) + .map(|_| { + (0..n) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() + }) + .collect(); + + let cpu = cpu_leaves_ext3(&columns); + + // GPU expects 3 base slabs per ext3 column in the order + // [col*3+0 (comp a), col*3+1 (comp b), col*3+2 (comp c)], each a + // contiguous slab of n u64s (length = num_cols * 3 * n). + let mut flat = vec![0u64; num_cols * 3 * n]; + for (c, col) in columns.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + flat[(c * 3) * n + r] = *e.value()[0].value(); + flat[(c * 3 + 1) * n + r] = *e.value()[1].value(); + flat[(c * 3 + 2) * n + r] = *e.value()[2].value(); + } + } + let gpu = math_cuda::merkle::keccak_leaves_ext3(&flat, n, num_cols, n).unwrap(); + assert_eq!(gpu.len(), n * 32); + for i in 0..n { + assert_eq!( + &gpu[i * 32..(i + 1) * 32], + &cpu[i][..], + "ext3 leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" + ); + } + } + } +} diff --git a/crypto/math-cuda/tests/merkle_tree.rs b/crypto/math-cuda/tests/merkle_tree.rs new file mode 100644 index 000000000..34d44c767 --- /dev/null +++ b/crypto/math-cuda/tests/merkle_tree.rs @@ -0,0 +1,92 @@ +//! Parity: GPU Merkle inner-tree construction must match the CPU +//! `crypto/crypto/src/merkle_tree/merkle.rs` `build_from_hashed_leaves` +//! (Keccak-256 pair hash at each level). + +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::{Digest, Keccak256}; + +fn cpu_hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { + let mut h = Keccak256::new(); + h.update(left); + h.update(right); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out +} + +/// CPU reference: same algorithm as `build_from_hashed_leaves`. +fn cpu_merkle_nodes(leaves: &[[u8; 32]]) -> Vec<[u8; 32]> { + let leaves_len = leaves.len(); + assert!(leaves_len.is_power_of_two() && leaves_len >= 2); + let total = 2 * leaves_len - 1; + + let mut nodes: Vec<[u8; 32]> = vec![[0u8; 32]; total]; + for (i, leaf) in leaves.iter().enumerate() { + nodes[leaves_len - 1 + i] = *leaf; + } + + let mut level_begin = leaves_len - 1; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + for j in 0..n_pairs { + let left = nodes[level_begin + 2 * j]; + let right = nodes[level_begin + 2 * j + 1]; + nodes[new_begin + j] = cpu_hash_pair(&left, &right); + } + level_begin = new_begin; + } + nodes +} + +fn run_parity(log_n: u32, seed: u64) { + let leaves_len = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let leaves: Vec<[u8; 32]> = (0..leaves_len) + .map(|_| { + let mut arr = [0u8; 32]; + rng.fill(&mut arr[..]); + arr + }) + .collect(); + + // Flat byte layout for the GPU entry point. + let mut flat = Vec::with_capacity(leaves_len * 32); + for l in &leaves { + flat.extend_from_slice(l); + } + + let gpu_nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(&flat).unwrap(); + assert_eq!(gpu_nodes_bytes.len(), (2 * leaves_len - 1) * 32); + + let cpu_nodes = cpu_merkle_nodes(&leaves); + + for i in 0..cpu_nodes.len() { + let g = &gpu_nodes_bytes[i * 32..(i + 1) * 32]; + let c = &cpu_nodes[i]; + assert_eq!( + g, c, + "node {i} mismatch at log_n={log_n} (cpu={c:?}, gpu={g:?})" + ); + } +} + +#[test] +fn merkle_tree_small() { + for log_n in 1u32..=6 { + run_parity(log_n, 100 + log_n as u64); + } +} + +#[test] +fn merkle_tree_medium() { + for log_n in [10u32, 12, 14] { + run_parity(log_n, 500 + log_n as u64); + } +} + +#[test] +fn merkle_tree_large() { + run_parity(18, 9999); +} From affceb10c0aa2eeeb1434b98369887f7b7318aa4 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 May 2026 18:49:01 -0300 Subject: [PATCH 05/16] feat(cuda): Round 1 GPU LDE+commit dispatch + device-resident handles --- crypto/crypto/src/merkle_tree/merkle.rs | 24 + crypto/stark/src/gpu_lde.rs | 958 ++++++++++++++++++++++++ crypto/stark/src/lib.rs | 2 + crypto/stark/src/prover.rs | 279 +++++-- crypto/stark/src/trace.rs | 38 + prover/Cargo.toml | 1 + prover/tests/bench_single.rs | 12 + 7 files changed, 1258 insertions(+), 56 deletions(-) create mode 100644 crypto/stark/src/gpu_lde.rs create mode 100644 prover/tests/bench_single.rs diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index 55fa49a83..789adf1b6 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -54,6 +54,30 @@ where Self::build_from_hashed_leaves(hashed_leaves) } + /// Build a `MerkleTree` from an already-filled node vector whose layout + /// matches [`build_from_hashed_leaves`] output: + /// + /// - `nodes.len() == 2 * leaves_len - 1` where `leaves_len` is a power of two + /// - `nodes[0]` is the root + /// - `nodes[leaves_len - 1 .. 2*leaves_len - 1]` are the leaves + /// + /// Useful when the tree was constructed elsewhere (e.g. on a GPU) and + /// the caller just wants to hand the finished layout to the stark prover. + /// Performs no hashing. + pub fn from_precomputed_nodes(nodes: Vec) -> Option { + if nodes.is_empty() { + return None; + } + // Validate (cheap) that (nodes.len() + 1) is a power of two: there + // must be `leaves_len - 1 + leaves_len = 2*leaves_len - 1` entries. + let total = nodes.len(); + if !(total + 1).is_power_of_two() { + return None; + } + let root = nodes[ROOT].clone(); + Some(MerkleTree { root, nodes }) + } + /// Create a Merkle tree from pre-hashed leaf nodes. /// /// This skips the `hash_leaves` step, useful when leaves have already been diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs new file mode 100644 index 000000000..05e281549 --- /dev/null +++ b/crypto/stark/src/gpu_lde.rs @@ -0,0 +1,958 @@ +//! GPU dispatch layer for the per-column coset LDE. Lives in the stark crate +//! (not `math`) to avoid a dependency cycle between `math` and `math-cuda`. +//! +//! Handles only Goldilocks base-field columns above a size threshold; falls +//! back to CPU for extension-field columns and small columns where kernel +//! launch overhead dominates. Produces the same natural-order, non-canonical +//! LDE evaluations as the CPU path. + +use core::any::type_name; + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsSubFieldOf}; + +use crate::domain::Domain; + +/// Break-even LDE size. Below this, the CPU `coset_lde_full_expand` completes +/// in a few hundred microseconds and the GPU's ~37 kernel launches plus +/// H2D/D2H round-trip is a net loss. The check is on **lde size**, not trace +/// length, because that's what determines the FFT workload. +/// +/// 2^19 is a conservative default calibrated against a 46-core machine where +/// rayon-parallel CPU LDE is already fast. Override via env var for tuning +/// on smaller machines; see `/workspace/lambda_vm/crypto/math-cuda/tests/bench_quick.rs`. +const DEFAULT_GPU_LDE_THRESHOLD: usize = 1 << 19; + +fn gpu_lde_threshold() -> usize { + static CACHED: std::sync::OnceLock = std::sync::OnceLock::new(); + *CACHED.get_or_init(|| { + std::env::var("LAMBDA_VM_GPU_LDE_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_GPU_LDE_THRESHOLD) + }) +} + +/// Atomically counted by `try_expand_column` every time it actually routes a +/// column to the GPU. Used by benchmarks to confirm the GPU path fired. +static GPU_LDE_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); + +pub fn gpu_lde_calls() -> u64 { + GPU_LDE_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +pub fn reset_gpu_lde_calls() { + GPU_LDE_CALLS.store(0, std::sync::atomic::Ordering::Relaxed); +} + +pub(crate) static GPU_EXTEND_HALVES_CALLS: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(0); +pub fn gpu_extend_halves_calls() -> u64 { + GPU_EXTEND_HALVES_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +/// Try to GPU-batch all columns in one pass. +/// +/// Only engaged for Goldilocks-base tables whose LDE size is above the +/// threshold. The prover's `expand_columns_to_lde` hands us every column of +/// one table at once; those columns all share twiddles and coset weights so +/// they can be processed in a single batched pipeline on one stream. +/// +/// Returns `true` if the batch was handled on GPU (and `columns` now contains +/// the LDE evaluations). Returns `false` to let the caller run the per-column +/// CPU fallback. +#[inline] +pub(crate) fn try_expand_columns_batched( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> bool +where + F: IsField, + E: IsField, +{ + if columns.is_empty() { + return true; // nothing to do — same as CPU path + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return false; + } + if type_name::() != type_name::() { + return false; + } + // All columns within one call must be the same size (invariant of the + // caller), but double-check before unsafe extraction. + if columns.iter().any(|c| c.len() != n) { + return false; + } + + // Ext3 fast path: decompose each ext3 column into its 3 base components + // and dispatch to the base-field batched NTT with 3×M logical columns. + // Butterflies with a base-field twiddle act componentwise on ext3, so + // this is exactly equivalent to running the NTT in the extension field. + if type_name::() == type_name::() { + return try_expand_columns_batched_ext3::(columns, blowup_factor, weights); + } + + if type_name::() != type_name::() { + return false; + } + + // Extract raw u64 slices. SAFETY: type_name above confirms + // `E == GoldilocksField`, so `FieldElement` wraps u64 one-to-one. + let raw_columns: Vec> = columns + .iter() + .map(|col| { + col.iter() + .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) + .collect() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + // Pre-size caller Vecs to lde_size so the GPU path can write directly + // into the same backing allocation the caller already holds. This skips + // the intermediate `Vec>` allocation (which would page-fault + // per column) and is the main reason `coset_lde_batch_base_into` exists. + for col in columns.iter_mut() { + // SAFETY: set_len is valid here because capacity is already >= + // lde_size (the caller sized columns via `extract_columns_main(lde_size)`) + // and we're about to overwrite every slot via the GPU copy below. + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + + // Borrow each caller Vec as a raw `&mut [u64]` slice; safe because each + // FieldElement aliases a single u64 when E == GoldilocksField. + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len(); + // SAFETY: see above — single-u64 layout, caller still owns. + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + math_cuda::lde::coset_lde_batch_base_into( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + ) + .expect("GPU batched coset LDE failed"); + true +} + +/// GPU path for `Prover::extend_half_to_lde`. +/// +/// Inside `decompose_and_extend_d2` (R2 quotient decomposition) the prover +/// does `rayon::join` of two calls: `iFFT(N on g²-coset) → FFT(2N on g-coset)` +/// over ext3 halves H0 and H1. They share the same domain/offset and sizes, +/// so we batch them into a single GPU call with M=2 ext3 columns. +/// +/// Weights = `[1/N, g^(-1)/N, g^(-2)/N, …, g^(-(N-1))/N]`. This bakes the +/// `(g²)^(-k)` input-coset-undo from `interpolate_offset_fft` together with +/// the `g^k` forward-coset-shift from `evaluate_polynomial_on_lde_domain` — +/// net is `g^(-k)` — plus the `1/N` iFFT normalisation. +/// +/// Returns `None` when the GPU path doesn't apply (too small, or CPU path +/// should be used); in that case the caller runs its existing rayon::join. +pub(crate) fn try_extend_two_halves_gpu( + h0: &[FieldElement], + h1: &[FieldElement], + squared_offset: &FieldElement, + domain: &Domain, +) -> Option<(Vec>, Vec>)> +where + F: math::field::traits::IsFFTField + IsField, + E: IsField, + F: IsSubFieldOf, +{ + if h0.len() != h1.len() { + return None; + } + let n = h0.len(); + let blowup = 2; // extend_half_to_lde extends N → 2N always + let lde_size = n * blowup; + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + GPU_EXTEND_HALVES_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + // squared_offset should be `g²`. We recover `g` as `domain.coset_offset` + // and use it to build the `g^(-k) / N` weights. + let _ = squared_offset; // unused (we derive weights from domain) + + // Flatten ext3 slices to raw 3*n u64 buffers. + let to_u64 = |col: &[FieldElement]| -> Vec { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }; + let h0_raw = to_u64(h0); + let h1_raw = to_u64(h1); + + // weights[k] = g^(-k) / N as a u64. + let inv_n = FieldElement::::from(n as u64).inv().expect("N nonzero"); + let g = &domain.coset_offset; + let g_inv = g.inv().expect("g nonzero"); + let mut weights_u64 = Vec::with_capacity(n); + let mut w = inv_n.clone(); + for _ in 0..n { + // F == GoldilocksField by type_name check above, so value is u64. + let v: u64 = unsafe { *(w.value() as *const _ as *const u64) }; + weights_u64.push(v); + w = w * &g_inv; + } + + // Pre-allocate outputs. + let mut lde_h0 = vec![FieldElement::::zero(); lde_size]; + let mut lde_h1 = vec![FieldElement::::zero(); lde_size]; + + GPU_LDE_CALLS.fetch_add(6, std::sync::atomic::Ordering::Relaxed); // 2 ext3 cols × 3 components + { + let inputs: [&[u64]; 2] = [&h0_raw, &h1_raw]; + // View each output Vec> as &mut [u64] of length 3*lde_size. + let out0_ptr = lde_h0.as_mut_ptr() as *mut u64; + let out1_ptr = lde_h1.as_mut_ptr() as *mut u64; + // SAFETY: ext3 FieldElement is [u64; 3] in memory, and the Vec has len + // = lde_size so the backing is 3*lde_size u64s. + let out0_slice = unsafe { core::slice::from_raw_parts_mut(out0_ptr, 3 * lde_size) }; + let out1_slice = unsafe { core::slice::from_raw_parts_mut(out1_ptr, 3 * lde_size) }; + let mut outputs: [&mut [u64]; 2] = [out0_slice, out1_slice]; + math_cuda::lde::coset_lde_batch_ext3_into(&inputs, n, blowup, &weights_u64, &mut outputs) + .expect("GPU extend_half_to_lde failed"); + } + + Some((lde_h0, lde_h1)) +} + +/// Combined GPU LDE + Merkle leaf hash for the base-field main trace. +/// +/// Keeps LDE output on device, runs Keccak-256 on the device buffer directly, +/// D2Hs both LDE columns (for Round 2-4 reuse) and hashed leaves (for tree +/// construction). Avoids the second H2D that a separate GPU Merkle commit +/// path would require. +/// +/// On success: resizes each `columns[c]` to `lde_size` with the LDE output, +/// and returns `Vec` — the Keccak-256 hashed leaves in natural +/// row order, ready to pass to `BatchedMerkleTree::build_from_hashed_leaves`. +#[allow(dead_code)] +pub(crate) fn try_expand_and_leaf_hash_batched( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option> +where + F: IsField, + E: IsField, +{ + if columns.is_empty() { + return Some(Vec::new()); + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if columns.iter().any(|c| c.len() != n) { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + col.iter() + .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) + .collect() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len(); + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + // Allocate as Vec<[u8; 32]> directly so we both skip the zero-fill pass + // AND avoid re-chunking afterwards. Fresh pages still fault on first + // write (inside the GPU-side memcpy), but only once each. + let mut leaves: Vec<[u8; 32]> = Vec::with_capacity(lde_size); + // SAFETY: we fill every byte via memcpy_dtoh below. + unsafe { leaves.set_len(lde_size) }; + let hashed_bytes_ptr = leaves.as_mut_ptr() as *mut u8; + let hashed_bytes: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(hashed_bytes_ptr, lde_size * 32) }; + + GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + math_cuda::lde::coset_lde_batch_base_into_with_leaf_hash( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + hashed_bytes, + ) + .expect("GPU LDE+leaf-hash failed"); + + Some(leaves) +} + +pub(crate) static GPU_LEAF_HASH_CALLS: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(0); +pub fn gpu_leaf_hash_calls() -> u64 { + GPU_LEAF_HASH_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +/// Fused variant: LDE + leaf-hash + Merkle tree build, all on device. Skips +/// the pinned→pageable→pinned leaf dance of the separate-step pipeline. +/// Returns the filled `MerkleTree` alongside populating `columns` with +/// the LDE-expanded evaluations. +#[allow(dead_code)] +pub(crate) fn try_expand_leaf_and_tree_batched( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option> +where + F: IsField, + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if columns.is_empty() { + return None; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if columns.iter().any(|c| c.len() != n) { + return None; + } + // Tree layout needs `2*lde_size - 1` nodes; must be a power-of-two leaf + // count. LDE size is always pow2 here (checked above). + if lde_size < 2 { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + col.iter() + .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) + .collect() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len(); + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let total_nodes = 2 * lde_size - 1; + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + // SAFETY: every byte is written by the D2H below. + unsafe { nodes.set_len(total_nodes) }; + let nodes_bytes: &mut [u8] = + unsafe { core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) }; + + GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + .expect("GPU LDE+leaf-hash+tree failed"); + + crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) +} + +/// Same as [`try_expand_leaf_and_tree_batched`] but ALSO retains the LDE +/// device buffer so R2–R4 GPU paths can reuse the LDE without a re-H2D. +/// Returns `(tree, gpu_handle)` on success, `None` if the GPU path doesn't +/// apply (same gates as the non-`_keep` variant). +pub(crate) fn try_expand_leaf_and_tree_batched_keep( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option<( + crypto::merkle_tree::merkle::MerkleTree, + math_cuda::lde::GpuLdeBase, +)> +where + F: IsField, + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if columns.is_empty() { + return None; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if columns.iter().any(|c| c.len() != n) { + return None; + } + if lde_size < 2 { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + col.iter() + .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) + .collect() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len(); + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let total_nodes = 2 * lde_size - 1; + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + unsafe { nodes.set_len(total_nodes) }; + let nodes_bytes: &mut [u8] = + unsafe { core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) }; + + GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let handle = math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree_keep( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + .expect("GPU LDE+leaf-hash+tree+keep failed"); + + let tree = crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes)?; + Some((tree, handle)) +} + +/// Ext3 variant of [`try_expand_leaf_and_tree_batched`]. Same fused flow +/// (LDE → leaf-hash → tree build) but over ext3 columns via the three-slab +/// decomposition; `B::Node = [u8; 32]` by construction for +/// `BatchKeccak256Backend`. +#[allow(dead_code)] +pub(crate) fn try_expand_leaf_and_tree_batched_ext3( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option> +where + F: IsField, + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if columns.is_empty() { + return None; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if lde_size < 2 { + return None; + } + + // SAFETY: `E == Degree3Goldilocks`; each `FieldElement` is + // memory-equivalent to `[u64; 3]`. Copy out a Vec view per column. + let raw_columns: Vec> = columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let total_nodes = 2 * lde_size - 1; + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + unsafe { nodes.set_len(total_nodes) }; + let nodes_bytes: &mut [u8] = + unsafe { core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) }; + + GPU_LDE_CALLS.fetch_add( + (columns.len() * 3) as u64, + std::sync::atomic::Ordering::Relaxed, + ); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + .expect("GPU ext3 LDE+leaf-hash+tree failed"); + + crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) +} + +/// Same as [`try_expand_leaf_and_tree_batched_ext3`] but also returns the +/// ext3 LDE device buffer (de-interleaved 3-slab layout) so downstream GPU +/// rounds can reuse it. +pub(crate) fn try_expand_leaf_and_tree_batched_ext3_keep( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option<( + crypto::merkle_tree::merkle::MerkleTree, + math_cuda::lde::GpuLdeExt3, +)> +where + F: IsField, + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if columns.is_empty() { + return None; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if lde_size < 2 { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let total_nodes = 2 * lde_size - 1; + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + unsafe { nodes.set_len(total_nodes) }; + let nodes_bytes: &mut [u8] = + unsafe { core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) }; + + GPU_LDE_CALLS.fetch_add( + (columns.len() * 3) as u64, + std::sync::atomic::Ordering::Relaxed, + ); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let handle = math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree_keep( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + .expect("GPU ext3 LDE+leaf-hash+tree+keep failed"); + + let tree = crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes)?; + Some((tree, handle)) +} + +/// Ext3 variant of [`try_expand_and_leaf_hash_batched`] for the aux trace. +/// Decomposes each ext3 column into three base slabs, runs the LDE + Keccak +/// ext3 kernel in one on-device pipeline, re-interleaves LDE output back to +/// ext3 layout, and returns hashed leaves. +#[allow(dead_code)] +pub(crate) fn try_expand_and_leaf_hash_batched_ext3( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option> +where + F: IsField, + E: IsField, +{ + if columns.is_empty() { + return Some(Vec::new()); + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if columns.iter().any(|c| c.len() != n) { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let mut leaves: Vec<[u8; 32]> = Vec::with_capacity(lde_size); + unsafe { leaves.set_len(lde_size) }; + let hashed_bytes: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(leaves.as_mut_ptr() as *mut u8, lde_size * 32) }; + + GPU_LDE_CALLS.fetch_add( + (columns.len() * 3) as u64, + std::sync::atomic::Ordering::Relaxed, + ); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + math_cuda::lde::coset_lde_batch_ext3_into_with_leaf_hash( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + hashed_bytes, + ) + .expect("GPU ext3 LDE+leaf-hash failed"); + + Some(leaves) +} + +/// Ext3 specialisation of [`try_expand_columns_batched`]. `E` is known to be +/// `Degree3GoldilocksExtensionField` by type_name match at the caller. +fn try_expand_columns_batched_ext3( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> bool +where + F: IsField, + E: IsField, +{ + if columns.is_empty() { + return true; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + + // SAFETY: caller confirmed `E == Degree3GoldilocksExtensionField` via + // type_name. That means `FieldElement` wraps `[FieldElement; 3]`, + // which is memory-equivalent to `[u64; 3]`. A `&[FieldElement]` of + // length `n` is therefore a contiguous `3 * n * 8` byte buffer. + let raw_columns: Vec> = columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + // Copy rather than borrow: the caller still owns `col` and will + // reuse its backing storage after we resize + rewrite below. + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }) + .collect(); + // F is `type_name::() == GoldilocksField` by caller precondition; + // `F::BaseType == u64`, so we can read each `w.value()` as a `*const u64`. + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + // Pre-size each ext3 column to lde_size so its backing Vec has the right + // length for the output re-interleave. Capacity must already be >= + // lde_size (caller's `extract_columns_main(lde_size)` ensures this). + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + // SAFETY: overwritten fully by the GPU path below. + unsafe { col.set_len(lde_size) }; + } + + // View each column's backing memory as a `&mut [u64]` of length + // `3*lde_size`. Safe because ext3 elements are `[u64; 3]` layouts. + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + // Account each ext3 column as 3 logical GPU LDE "calls" (base-field + // components) so the counter matches the base-field batched path. + GPU_LDE_CALLS.fetch_add( + (columns.len() * 3) as u64, + std::sync::atomic::Ordering::Relaxed, + ); + math_cuda::lde::coset_lde_batch_ext3_into( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + ) + .expect("GPU batched ext3 coset LDE failed"); + true +} + +// ============================================================================ +// GPU barycentric OOD evaluation +// ============================================================================ +// +// Infrastructure for future use: these wrappers drive +// `math_cuda::barycentric::barycentric_{base,ext3}` and apply the trailing ext3 +// scalar on host. See the CPU reference in +// `crypto/math/src/polynomial/mod.rs::interpolate_coset_eval_*_with_g_n_inv`. +// +// NOT currently wired into the prover — a benchmark on fib_iterative_{1M, 4M} +// showed the CPU path (rayon over ~50 columns) already finishes in <1 ms wall +// because the GPU is busy with LDE and Merkle on parallel streams, so moving +// R3 OOD to the GPU just serialises work without freeing CPU wall time. +// Kept here and covered by parity tests in `crypto/math-cuda/tests/barycentric.rs` +// because it remains a net win for single-table or very-large-trace workloads. +// +// The GPU kernel returns the unscaled sum +// S = Σ_i point_i · eval_i · inv_denom_i +// per column; the final barycentric value is +// f(z) = scalar · (z^N − g^N) · S +// with `scalar = n_inv · g_n_inv` kept in the base field. + +// ============================================================================ +// GPU Merkle inner-tree construction +// ============================================================================ +// +// After the GPU keccak leaf-hash kernels produce a flat `[u8; 32]` leaf vec, +// the inner tree construction on CPU via `build_from_hashed_leaves` is a +// rayon-parallel pair-hash scan that still takes ~50-100 ms per table on a +// 46-core host. Delegating it to `math_cuda::merkle::build_merkle_tree_on_device` +// pushes it below 10 ms — the leaf buffer is already on host (it came out of +// `try_expand_and_leaf_hash_batched`), we H2D it once, the GPU does ~log₂(N) +// small kernel launches, and we D2H the full `2*leaves_len - 1` node array. + +static GPU_MERKLE_TREE_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); +pub fn gpu_merkle_tree_calls() -> u64 { + GPU_MERKLE_TREE_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +/// Build a Merkle tree from already-hashed leaves using the GPU pair-hash +/// kernel. Returns the filled `MerkleTree` in the same layout as the CPU +/// `build_from_hashed_leaves` would produce — plug straight in anywhere the +/// prover expected that. +/// +/// Returns `None` if the GPU path is disabled by threshold (`leaves_len < +/// GPU_MERKLE_TREE_THRESHOLD`), falling back to the caller's CPU path. +/// +/// Currently unwired in the prover: benchmarking showed the savings from +/// the GPU pair-hash are eaten by the H2D of leaves + D2H of the tree +/// because the leaves are in pageable memory (they're the caller's Vec from +/// `try_expand_and_leaf_hash_batched`). A proper fusion would keep the +/// leaf buffer on device and run the tree kernel immediately on the GPU +/// copy — left as future work. +#[allow(dead_code)] +pub(crate) fn try_build_merkle_tree_gpu( + hashed_leaves: &[B::Node], +) -> Option> +where + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + let leaves_len = hashed_leaves.len(); + if leaves_len < gpu_merkle_tree_threshold() || !leaves_len.is_power_of_two() || leaves_len < 2 { + return None; + } + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Flatten host-side leaves into a contiguous byte buffer for the GPU + // kernel. SAFETY: `[u8; 32]` is POD and the slice is contiguous. + let leaves_bytes: &[u8] = unsafe { + core::slice::from_raw_parts(hashed_leaves.as_ptr() as *const u8, leaves_len * 32) + }; + let nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(leaves_bytes) + .expect("GPU merkle tree build failed"); + + let total_nodes = 2 * leaves_len - 1; + debug_assert_eq!(nodes_bytes.len(), total_nodes * 32); + + // Re-chunk into `Vec<[u8; 32]>` without re-allocating. We'd need an + // explicit copy because Vec and Vec<[u8; 32]> have different + // layouts in the allocator metadata (align differs on some platforms). + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + for i in 0..total_nodes { + let mut n = [0u8; 32]; + n.copy_from_slice(&nodes_bytes[i * 32..(i + 1) * 32]); + nodes.push(n); + } + + crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) +} + +/// Below this (tree size), stay on CPU — rayon pair-hash is already well +/// under a millisecond for small N and would lose to any PCIe round-trip. +const DEFAULT_GPU_MERKLE_TREE_THRESHOLD: usize = 1 << 15; + +fn gpu_merkle_tree_threshold() -> usize { + static CACHED: std::sync::OnceLock = std::sync::OnceLock::new(); + *CACHED.get_or_init(|| { + std::env::var("LAMBDA_VM_GPU_MERKLE_TREE_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_GPU_MERKLE_TREE_THRESHOLD) + }) +} diff --git a/crypto/stark/src/lib.rs b/crypto/stark/src/lib.rs index 09ca16ed4..24c149afe 100644 --- a/crypto/stark/src/lib.rs +++ b/crypto/stark/src/lib.rs @@ -8,6 +8,8 @@ pub mod domain; pub mod examples; pub mod frame; pub mod fri; +#[cfg(feature = "cuda")] +pub mod gpu_lde; pub mod grinding; #[cfg(feature = "instruments")] pub mod instruments; diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index a5386017a..e71cda72f 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -193,6 +193,30 @@ where struct Lde { main: Vec>>, aux: Vec>>, + /// Device-side main LDE buffer, populated only when the R1 GPU fused + /// pipeline ran for this table. Kept so R2/R3/R4 GPU paths can read + /// the LDE without re-H2D. + #[cfg(feature = "cuda")] + gpu_main: Option, + #[cfg(feature = "cuda")] + gpu_aux: Option, +} + +/// Result of `commit_main_trace` / `commit_preprocessed_trace`. Wraps the +/// commitment Merkle data plus the owned LDE columns, and — when the R1 +/// fused GPU pipeline ran — the retained device LDE handle. +pub struct MainTraceCommitResult +where + FieldElement: AsBytes, +{ + tree: BatchedMerkleTree, + root: Commitment, + precomputed_tree: Option>, + precomputed_root: Option, + num_precomputed_cols: usize, + columns: Vec>>, + #[cfg(feature = "cuda")] + gpu_main: Option, } impl Round1Commitments @@ -210,7 +234,18 @@ where blowup_factor: usize, has_aux_trace: bool, ) -> Round1 { - let lde_trace = LDETraceTable::from_columns(lde.main, lde.aux, step_size, blowup_factor); + #[allow(unused_mut)] + let mut lde_trace = + LDETraceTable::from_columns(lde.main, lde.aux, step_size, blowup_factor); + #[cfg(feature = "cuda")] + { + if let Some(h) = lde.gpu_main { + lde_trace.set_gpu_main(h); + } + if let Some(h) = lde.gpu_aux { + lde_trace.set_gpu_aux(h); + } + } let main = Round1CommitmentData:: { lde_trace_merkle_tree: Arc::clone(&self.main_merkle_tree), @@ -517,6 +552,19 @@ pub trait IsStarkProver< return; } + // GPU batched fast path: all columns at once in one pipeline on one + // stream. Falls through to per-column rayon when the table is too + // small, the element type isn't Goldilocks, or the `cuda` feature is + // off. + #[cfg(feature = "cuda")] + if crate::gpu_lde::try_expand_columns_batched::( + columns, + domain.blowup_factor, + &twiddles.coset_weights, + ) { + return; + } + #[cfg(feature = "parallel")] let iter = columns.par_iter_mut(); #[cfg(not(feature = "parallel"))] @@ -534,29 +582,52 @@ pub trait IsStarkProver< } /// Compute main LDE, commit, and return the Merkle tree/root along with the - /// owned LDE columns (consumed later in Phase D). + /// owned LDE columns (consumed later in Phase D). When the fused GPU + /// pipeline runs, the device LDE buffer is also kept alive and returned so + /// downstream rounds can read it without a re-H2D. #[allow(clippy::type_complexity)] fn commit_main_trace( trace: &TraceTable, domain: &Domain, twiddles: &LdeTwiddles, - ) -> Result< - ( - BatchedMerkleTree, - Commitment, - Option>, - Option, - usize, - Vec>>, - ), - ProvingError, - > + ) -> Result, ProvingError> where FieldElement: AsBytes, FieldElement: AsBytes, { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_main(lde_size); + + #[cfg(feature = "cuda")] + { + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + if let Some((tree, handle)) = + crate::gpu_lde::try_expand_leaf_and_tree_batched_keep::< + Field, + Field, + BatchedMerkleTreeBackend, + >(&mut columns, domain.blowup_factor, &twiddles.coset_weights) + { + #[cfg(feature = "instruments")] + let main_lde_dur = t_sub.elapsed(); + #[cfg(feature = "instruments")] + let zero = std::time::Duration::from_secs(0); + let root = tree.root; + #[cfg(feature = "instruments")] + crate::instruments::accum_r1_main(main_lde_dur, zero); + return Ok(MainTraceCommitResult { + tree, + root, + precomputed_tree: None, + precomputed_root: None, + num_precomputed_cols: 0, + columns, + gpu_main: Some(handle), + }); + } + } + #[cfg(feature = "instruments")] let t_sub = Instant::now(); Self::expand_columns_to_lde::(&mut columns, domain, twiddles); @@ -570,7 +641,16 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] crate::instruments::accum_r1_main(main_lde_dur, t_sub.elapsed()); - Ok((tree, root, None, None, 0, columns)) + Ok(MainTraceCommitResult { + tree, + root, + precomputed_tree: None, + precomputed_root: None, + num_precomputed_cols: 0, + columns, + #[cfg(feature = "cuda")] + gpu_main: None, + }) } /// Commit preprocessed trace: precomputed and multiplicity columns get separate trees. @@ -581,17 +661,7 @@ pub trait IsStarkProver< precomputed_commitment: Commitment, num_precomputed_cols: usize, twiddles: &LdeTwiddles, - ) -> Result< - ( - BatchedMerkleTree, - Commitment, - Option>, - Option, - usize, - Vec>>, - ), - ProvingError, - > + ) -> Result, ProvingError> where FieldElement: AsBytes, FieldElement: AsBytes, @@ -621,14 +691,16 @@ pub trait IsStarkProver< "Prover's precomputed commitment doesn't match hardcoded AIR commitment" ); - Ok(( - mult_tree, - mult_root, - Some(precomputed_tree), - Some(precomputed_root), + Ok(MainTraceCommitResult { + tree: mult_tree, + root: mult_root, + precomputed_tree: Some(precomputed_tree), + precomputed_root: Some(precomputed_root), num_precomputed_cols, columns, - )) + #[cfg(feature = "cuda")] + gpu_main: None, + }) } /// Recompute Round1 from the trace, reusing the Merkle trees stored in commitments. @@ -841,6 +913,18 @@ pub trait IsStarkProver< // The squared coset offset is g² (= coset_offset²). let coset_offset_squared = &domain.coset_offset * &domain.coset_offset; + // GPU fast path: batch both halves into one ext3 LDE call. Requires + // `cuda` feature and a qualifying size; falls through to CPU when not. + #[cfg(feature = "cuda")] + if let Some((lde_h0, lde_h1)) = crate::gpu_lde::try_extend_two_halves_gpu( + &h0_evals, + &h1_evals, + &coset_offset_squared, + domain, + ) { + return vec![lde_h0, lde_h1]; + } + #[cfg(feature = "parallel")] let (lde_h0, lde_h1) = rayon::join( || Self::extend_half_to_lde(&h0_evals, &coset_offset_squared, domain), @@ -1626,6 +1710,9 @@ pub trait IsStarkProver< let mut main_commits: Vec> = Vec::with_capacity(num_airs); let mut main_ldes: Vec>>> = Vec::with_capacity(num_airs); + #[cfg(feature = "cuda")] + let mut main_gpu_handles: Vec> = + Vec::with_capacity(num_airs); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1658,19 +1745,21 @@ pub trait IsStarkProver< // Sequential: append roots to shared transcript (Fiat-Shamir ordering) for result in chunk_results { - let (tree, root, pre_tree, pre_root, n_pre, cached_main) = result?; - if let Some(ref pre_r) = pre_root { + let r = result?; + if let Some(ref pre_r) = r.precomputed_root { transcript.append_bytes(pre_r); } - transcript.append_bytes(&root); + transcript.append_bytes(&r.root); main_commits.push(MainCommitData { - main_tree: Arc::new(tree), - main_root: root, - precomputed_tree: pre_tree.map(Arc::new), - precomputed_root: pre_root, - num_precomputed_cols: n_pre, + main_tree: Arc::new(r.tree), + main_root: r.root, + precomputed_tree: r.precomputed_tree.map(Arc::new), + precomputed_root: r.precomputed_root, + num_precomputed_cols: r.num_precomputed_cols, }); - main_ldes.push(cached_main); + main_ldes.push(r.columns); + #[cfg(feature = "cuda")] + main_gpu_handles.push(r.gpu_main); } } @@ -1747,13 +1836,24 @@ pub trait IsStarkProver< }) .collect(); - // Parallel aux commit in chunks of K - #[allow(clippy::type_complexity)] - let mut aux_results: Vec<( - Option>>, + // Parallel aux commit in chunks of K. Fourth field is an optional + // GPU ext3 LDE handle retained when the R1 fused pipeline fires. + #[cfg(feature = "cuda")] + type AuxResult = ( + Option>>, Option, - Vec>>, - )> = Vec::with_capacity(num_airs); + Vec>>, + Option, + ); + #[cfg(not(feature = "cuda"))] + type AuxResult = ( + Option>>, + Option, + Vec>>, + (), + ); + #[allow(clippy::type_complexity)] + let mut aux_results: Vec> = Vec::with_capacity(num_airs); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1773,6 +1873,40 @@ pub trait IsStarkProver< if air.has_aux_trace() { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_aux(lde_size); + + // GPU combined path: ext3 LDE + Keccak-256 leaf + // hashing + Merkle tree build in one on-device + // pipeline. The fused `_keep` variant also returns + // the device LDE handle for downstream GPU rounds. + #[cfg(feature = "cuda")] + { + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + if let Some((tree, handle)) = + crate::gpu_lde::try_expand_leaf_and_tree_batched_ext3_keep::< + Field, + FieldExtension, + BatchedMerkleTreeBackend, + >( + &mut columns, domain.blowup_factor, &twiddles.coset_weights + ) + { + #[cfg(feature = "instruments")] + let aux_lde_dur = t_sub.elapsed(); + #[cfg(feature = "instruments")] + let zero = std::time::Duration::from_secs(0); + let root = tree.root; + #[cfg(feature = "instruments")] + crate::instruments::accum_r1_aux(aux_lde_dur, zero); + return Ok(( + Some(Arc::new(tree)), + Some(root), + columns, + Some(handle), + )); + } + } + #[cfg(feature = "instruments")] let t_sub = Instant::now(); Self::expand_columns_to_lde::( @@ -1789,20 +1923,28 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] crate::instruments::accum_r1_aux(aux_lde_dur, t_sub.elapsed()); - Ok((Some(Arc::new(tree)), Some(root), columns)) + #[cfg(feature = "cuda")] + let aux_gpu: Option = None; + #[cfg(not(feature = "cuda"))] + let aux_gpu: () = (); + Ok((Some(Arc::new(tree)), Some(root), columns, aux_gpu)) } else { - Ok((None, None, Vec::new())) + #[cfg(feature = "cuda")] + let aux_gpu: Option = None; + #[cfg(not(feature = "cuda"))] + let aux_gpu: () = (); + Ok((None, None, Vec::new(), aux_gpu)) } }) .collect(); // Sequential: append aux roots to forked transcripts for (j, result) in chunk_aux.into_iter().enumerate() { - let (aux_tree, aux_root, cached_aux) = result?; + let (aux_tree, aux_root, cached_aux, aux_gpu) = result?; if let Some(ref root) = aux_root { table_transcripts[chunk_start + j].append_bytes(root); } - aux_results.push((aux_tree, aux_root, cached_aux)); + aux_results.push((aux_tree, aux_root, cached_aux, aux_gpu)); } } @@ -1811,12 +1953,25 @@ pub trait IsStarkProver< let mut commitments: Vec> = Vec::with_capacity(num_airs); let mut cached_ldes: Vec> = Vec::with_capacity(num_airs); - for (((main_commit, main_lde), (aux_tree, aux_root, cached_aux)), bus_public_inputs) in - main_commits - .into_iter() - .zip(main_ldes) - .zip(aux_results) - .zip(bus_inputs_vec) + // Zip in the optional GPU handles so the Lde constructor always + // has a value for its gpu_main/gpu_aux. Under `cfg(not(cuda))` the + // handles are `()` (see AuxResult type alias) — we just discard them. + #[cfg(feature = "cuda")] + let main_gpu_iter: Box>> = + Box::new(main_gpu_handles.into_iter()); + #[cfg(not(feature = "cuda"))] + let main_gpu_iter: Box> = + Box::new(std::iter::repeat_with(|| ()).take(num_airs)); + + for ( + (((main_commit, main_lde), main_gpu_h), (aux_tree, aux_root, cached_aux, aux_gpu_h)), + bus_public_inputs, + ) in main_commits + .into_iter() + .zip(main_ldes) + .zip(main_gpu_iter) + .zip(aux_results) + .zip(bus_inputs_vec) { commitments.push(Round1Commitments { main_merkle_tree: main_commit.main_tree, @@ -1829,10 +1984,22 @@ pub trait IsStarkProver< rap_challenges: lookup_challenges.clone(), bus_public_inputs, }); + #[cfg(feature = "cuda")] cached_ldes.push(Lde { main: main_lde, aux: cached_aux, + gpu_main: main_gpu_h, + gpu_aux: aux_gpu_h, }); + #[cfg(not(feature = "cuda"))] + { + #[allow(clippy::let_unit_value)] + let _ = (main_gpu_h, aux_gpu_h); + cached_ldes.push(Lde { + main: main_lde, + aux: cached_aux, + }); + } } #[cfg(feature = "instruments")] diff --git a/crypto/stark/src/trace.rs b/crypto/stark/src/trace.rs index ef6ee7833..dd1f8979c 100644 --- a/crypto/stark/src/trace.rs +++ b/crypto/stark/src/trace.rs @@ -193,6 +193,16 @@ where pub(crate) aux_columns: Vec>>, pub(crate) lde_step_size: usize, pub(crate) blowup_factor: usize, + /// If the main trace was LDE'd on the GPU via the fused pipeline, + /// the device buffer is retained here so downstream GPU rounds can + /// read the LDE without a re-H2D. `None` when the GPU LDE didn't + /// run (small tables, cuda feature off, fallback path). + #[cfg(feature = "cuda")] + pub(crate) gpu_main: Option, + /// Same as `gpu_main` but for the aux trace (ext3 de-interleaved + /// layout on device). + #[cfg(feature = "cuda")] + pub(crate) gpu_aux: Option, } impl LDETraceTable @@ -215,9 +225,37 @@ where aux_columns, lde_step_size, blowup_factor, + #[cfg(feature = "cuda")] + gpu_main: None, + #[cfg(feature = "cuda")] + gpu_aux: None, } } + /// Attach an already-populated device LDE handle for the main columns. + /// Only set when the GPU fused pipeline produced the LDE — callers that + /// ran the CPU path should leave this alone. + #[cfg(feature = "cuda")] + pub fn set_gpu_main(&mut self, h: math_cuda::lde::GpuLdeBase) { + self.gpu_main = Some(h); + } + + /// Attach an already-populated device LDE handle for the aux columns. + #[cfg(feature = "cuda")] + pub fn set_gpu_aux(&mut self, h: math_cuda::lde::GpuLdeExt3) { + self.gpu_aux = Some(h); + } + + #[cfg(feature = "cuda")] + pub fn gpu_main(&self) -> Option<&math_cuda::lde::GpuLdeBase> { + self.gpu_main.as_ref() + } + + #[cfg(feature = "cuda")] + pub fn gpu_aux(&self) -> Option<&math_cuda::lde::GpuLdeExt3> { + self.gpu_aux.as_ref() + } + /// Consume self and return the owned column vectors. #[allow(clippy::type_complexity)] pub fn into_columns(self) -> (Vec>>, Vec>>) { diff --git a/prover/Cargo.toml b/prover/Cargo.toml index dac711002..bf55a251d 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -6,6 +6,7 @@ edition = "2024" [features] default = ["parallel"] parallel = ["stark/parallel", "math/parallel", "crypto/parallel", "dep:rayon"] +cuda = ["stark/cuda"] debug-checks = ["stark/debug-checks"] instruments = ["stark/instruments"] diff --git a/prover/tests/bench_single.rs b/prover/tests/bench_single.rs new file mode 100644 index 000000000..947f0fddf --- /dev/null +++ b/prover/tests/bench_single.rs @@ -0,0 +1,12 @@ +//! Single-prove bench for profiling with nsys / ncu. +use lambda_vm_prover::test_utils::asm_elf_bytes; + +#[test] +#[ignore = "bench; run with --ignored --nocapture"] +fn prove_fib_1m_once() { + let elf = asm_elf_bytes("fib_iterative_1M"); + // Warm-up pays one-time costs (PTX load, pool warm-up). + let _ = lambda_vm_prover::prove(&elf).expect("warm-up"); + // The profiled run: + let _ = lambda_vm_prover::prove(&elf).expect("prove"); +} From 01172f216ee559a72cc4751a69ff432cbc92c2ac Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Tue, 19 May 2026 15:45:36 -0300 Subject: [PATCH 06/16] merge main --- .github/scripts/publish_bench_vs.sh | 16 +- .github/workflows/bench-vs-nightly.yml | 17 +- Cargo.lock | 3 + Makefile | 41 +- README.md | 107 +- bench_vs/run_ethrex.sh | 183 ++ bin/cli/README.md | 77 +- bin/cli/src/main.rs | 6 +- crypto/crypto/src/merkle_tree/merkle.rs | 7 + crypto/math-cuda/Cargo.toml | 5 +- crypto/math-cuda/build.rs | 57 +- crypto/math-cuda/kernels/arith.cu | 16 +- crypto/math-cuda/kernels/keccak.cu | 24 +- crypto/math-cuda/kernels/ntt.cu | 12 +- crypto/math-cuda/src/device.rs | 43 +- crypto/math-cuda/src/lde.rs | 1984 ++++++----------- crypto/math-cuda/src/lib.rs | 37 +- crypto/math-cuda/src/merkle.rs | 238 +- crypto/math-cuda/src/ntt.rs | 21 +- crypto/math-cuda/tests/bench_quick.rs | 9 +- crypto/math-cuda/tests/ext3_edge.rs | 176 ++ crypto/math-cuda/tests/ext3_sub.rs | 109 + crypto/math-cuda/tests/keccak_leaves.rs | 164 +- crypto/math-cuda/tests/lde.rs | 6 +- crypto/math-cuda/tests/lde_batch_into.rs | 87 + crypto/math-cuda/tests/merkle_tree.rs | 57 +- crypto/math-cuda/tests/ntt.rs | 4 +- crypto/math-cuda/tests/ntt_known.rs | 125 ++ crypto/stark/src/prover.rs | 170 +- docs/cryptography/lookup.md | 168 +- docs/cryptography/proof_system.md | 27 +- docs/general_flow.md | 34 +- executor/README.md | 53 + executor/programs/asm/test_keccak.s | 38 + executor/programs/asm/test_keccak_multi.s | 48 + .../rust/ef_io_demo/.cargo/config.toml | 5 + executor/programs/rust/ef_io_demo/Cargo.toml | 9 + executor/programs/rust/ef_io_demo/src/main.rs | 22 + executor/src/vm/instruction/execution.rs | 212 +- executor/src/vm/memory.rs | 157 +- executor/tests/asm.rs | 46 + executor/tests/rust.rs | 14 + prover/Cargo.toml | 1 + prover/README.md | 54 + prover/src/constraints/cpu.rs | 2 +- prover/src/lib.rs | 36 +- prover/src/tables/cpu.rs | 18 +- prover/src/tables/keccak.rs | 567 +++++ prover/src/tables/keccak_rc.rs | 190 ++ prover/src/tables/keccak_rnd.rs | 986 ++++++++ prover/src/tables/mod.rs | 3 + prover/src/tables/trace_builder.rs | 649 +++++- prover/src/tables/types.rs | 8 + prover/src/test_utils.rs | 63 + prover/src/tests/cpu_tests.rs | 2 +- prover/src/tests/prove_elfs_tests.rs | 285 ++- spec/README.md | 26 +- syscalls/README.md | 52 + syscalls/src/ef_io.rs | 84 + syscalls/src/lib.rs | 1 + syscalls/src/syscalls.rs | 26 +- 61 files changed, 5742 insertions(+), 1945 deletions(-) create mode 100755 bench_vs/run_ethrex.sh create mode 100644 crypto/math-cuda/tests/ext3_edge.rs create mode 100644 crypto/math-cuda/tests/ext3_sub.rs create mode 100644 crypto/math-cuda/tests/lde_batch_into.rs create mode 100644 crypto/math-cuda/tests/ntt_known.rs create mode 100644 executor/README.md create mode 100644 executor/programs/asm/test_keccak.s create mode 100644 executor/programs/asm/test_keccak_multi.s create mode 100644 executor/programs/rust/ef_io_demo/.cargo/config.toml create mode 100644 executor/programs/rust/ef_io_demo/Cargo.toml create mode 100644 executor/programs/rust/ef_io_demo/src/main.rs create mode 100644 prover/src/tables/keccak.rs create mode 100644 prover/src/tables/keccak_rc.rs create mode 100644 prover/src/tables/keccak_rnd.rs create mode 100644 syscalls/README.md create mode 100644 syscalls/src/ef_io.rs diff --git a/.github/scripts/publish_bench_vs.sh b/.github/scripts/publish_bench_vs.sh index 4408c17c0..de79ce64b 100644 --- a/.github/scripts/publish_bench_vs.sh +++ b/.github/scripts/publish_bench_vs.sh @@ -79,6 +79,20 @@ if [ -n "$LAMBDA_PROJECTED_H" ] || [ -n "$SP1_PROJECTED_H" ]; then PROJ_SECTION=',{"type":"divider"},{"type":"header","text":{"type":"plain_text","text":"Linear Projection"}},{"type":"section","text":{"type":"mrkdwn","text":"'"$PROJ_MRKDWN"'"}}' fi +ETHREX_METRICS_FILE="bench_vs_artifacts/ethrex_metrics.txt" +ETHREX_SECTION="" +if [ -f "$ETHREX_METRICS_FILE" ]; then + ETHREX_TIME=$(grep '^ethrex_empty_block_time_s=' "$ETHREX_METRICS_FILE" | cut -d= -f2-) + ETHREX_CYCLES=$(grep '^ethrex_empty_block_cycles=' "$ETHREX_METRICS_FILE" | cut -d= -f2-) + if [ -n "$ETHREX_TIME" ]; then + ETHREX_MRKDWN="*Empty block:* ${ETHREX_TIME}s" + if [ -n "$ETHREX_CYCLES" ] && [ "$ETHREX_CYCLES" != "n/a" ]; then + ETHREX_MRKDWN="${ETHREX_MRKDWN} (${ETHREX_CYCLES} cycles)" + fi + ETHREX_SECTION=',{"type":"divider"},{"type":"header","text":{"type":"plain_text","text":"Lambda VM - Ethrex Empty"}},{"type":"section","text":{"type":"mrkdwn","text":"'"$ETHREX_MRKDWN"'"}}' + fi +fi + curl -X POST "$WEBHOOK_URL" \ -H 'Content-Type: application/json; charset=utf-8' \ - --data '{"blocks":[{"type":"header","text":{"type":"plain_text","text":"Lambda VM vs SP1 v6 - Nightly Benchmark"}},{"type":"divider"},{"type":"section","text":{"type":"mrkdwn","text":"'"$RESULTS_MRKDWN"'"}}'"$PROJ_SECTION"']}' + --data '{"blocks":[{"type":"header","text":{"type":"plain_text","text":"Lambda VM vs SP1 v6 - Nightly Benchmark"}},{"type":"context","elements":[{"type":"mrkdwn","text":"*Program:* Fibonacci · *Device:* CPU"}]},{"type":"divider"},{"type":"section","text":{"type":"mrkdwn","text":"'"$RESULTS_MRKDWN"'"}}'"$PROJ_SECTION$ETHREX_SECTION"']}' diff --git a/.github/workflows/bench-vs-nightly.yml b/.github/workflows/bench-vs-nightly.yml index 2118632f8..c1fdd7c86 100644 --- a/.github/workflows/bench-vs-nightly.yml +++ b/.github/workflows/bench-vs-nightly.yml @@ -43,7 +43,22 @@ jobs: - name: Run nightly benchmark run: | bash ./bench_vs/run.sh \ - --steps 1000000 2000000 4000000 8000000 \ + --steps 1000000 2000000 4000000 8000000 16000000 \ + --report-dir bench_vs_artifacts \ + --no-color + + - name: Restore cached ethrex.elf (TEMPORARY — until /opt/lambda-vm-sysroot is provisioned on the bench runner) + continue-on-error: true + run: | + mkdir -p executor/program_artifacts/rust + cp /home/app/cached_artifacts/ethrex.elf executor/program_artifacts/rust/ethrex.elf + ls -la executor/program_artifacts/rust/ethrex.elf + sha256sum executor/program_artifacts/rust/ethrex.elf + + - name: Run ethrex empty block benchmark + continue-on-error: true + run: | + bash ./bench_vs/run_ethrex.sh \ --report-dir bench_vs_artifacts \ --no-color diff --git a/Cargo.lock b/Cargo.lock index 7b6ed3c62..0f01bf090 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1955,6 +1955,7 @@ dependencies = [ "rayon", "serde", "stark", + "tiny-keccak", ] [[package]] @@ -2128,12 +2129,14 @@ dependencies = [ name = "math-cuda" version = "0.1.0" dependencies = [ + "crypto", "cudarc", "math", "rand 0.8.5", "rand_chacha 0.3.1", "rayon", "sha3", + "stark", ] [[package]] diff --git a/Makefile b/Makefile index c02bffc49..aadd3d961 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY: deps deps-linux deps-macos prepare-test-data compile-programs-asm compile-programs-rust compile-bench \ compile-programs clean-asm clean-rust clean-bench clean-shared clean test test-asm test-no-compile \ test-asm-no-compile test-rust test-rust-no-compile test-executor flamegraph-prover \ -test-fast test-prover test-prover-all build check clippy fmt lint +test-fast test-prover test-prover-all test-math-cuda bench-math-cuda build check clippy fmt lint UNAME := $(shell uname) @@ -46,9 +46,15 @@ BENCH_ARTIFACTS := $(addprefix $(BENCH_ARTIFACTS_DIR)/, $(addsuffix .elf, $(BENC ETHREX_FILE := executor/tests/ethrex_hoodi.bin ETHREX_URL := https://lambda.alignedlayer.com/ethrex_hoodi.bin -SYSROOT_DIR := /opt/lambda-vm-sysroot +# Override with: make ... SYSROOT_DIR=$HOME/.lambda-vm-sysroot +# to install the sysroot in a user-writable location and avoid sudo. +SYSROOT_DIR ?= /opt/lambda-vm-sysroot SYSROOT_TARBALL := /tmp/lambda-vm-sysroot-rv64im.tar.gz SYSROOT_URL := https://lambda.alignedlayer.com/lambda-vm-sysroot-rv64im.tar.gz +# CFLAGS for ckzg / ethrex guest programs: overrides the hardcoded `/opt/lambda-vm-sysroot` +# in their .cargo/config.toml so cargo picks up our $(SYSROOT_DIR) instead. +# $(abspath ...) because the build rule cd's into the program dir before invoking cargo. +SYSROOT_CFLAGS := --target=riscv64 -march=rv64im -mabi=lp64 --sysroot=$(abspath $(SYSROOT_DIR)) # Custom RV64IM target spec location RV64_TARGET_SPEC=$(CURDIR)/executor/programs/riscv64im-lambda-vm-elf.json @@ -64,15 +70,26 @@ prepare-test-data: fi prepare-sysroot: - @if [ ! -d "$(SYSROOT_DIR)" ]; then \ + @if [ -d "$(SYSROOT_DIR)/include" ] && [ -d "$(SYSROOT_DIR)/lib" ]; then \ + echo "Sysroot already exists at $(SYSROOT_DIR)"; \ + else \ echo "Downloading lambda-vm-sysroot-rv64im.tar.gz..."; \ curl -L "$(SYSROOT_URL)" -o "$(SYSROOT_TARBALL)"; \ echo "Extracting sysroot to $(SYSROOT_DIR)..."; \ - sudo mkdir -p /opt && sudo tar -xzf "$(SYSROOT_TARBALL)" -C /opt; \ + if mkdir -p "$(SYSROOT_DIR)" 2>/dev/null && [ -w "$(SYSROOT_DIR)" ]; then \ + tar -xzf "$(SYSROOT_TARBALL)" -C "$(SYSROOT_DIR)" --strip-components=1 \ + || { rm -rf "$(SYSROOT_DIR)" "$(SYSROOT_TARBALL)"; exit 1; }; \ + else \ + echo "$(SYSROOT_DIR) is not writable; using sudo."; \ + echo "Tip: re-run with SYSROOT_DIR=\$$HOME/.lambda-vm-sysroot to avoid sudo."; \ + sudo mkdir -p "$(SYSROOT_DIR)" \ + && sudo tar -xzf "$(SYSROOT_TARBALL)" -C "$(SYSROOT_DIR)" --strip-components=1 \ + || { sudo rm -rf "$(SYSROOT_DIR)"; rm -f "$(SYSROOT_TARBALL)"; exit 1; }; \ + fi; \ rm "$(SYSROOT_TARBALL)"; \ - else \ - echo "Sysroot already exists at $(SYSROOT_DIR)"; \ fi +# Note: the tarball rm above only runs on success — each error handler +# cleans up the tarball itself before `exit 1`. compile-programs-asm: @mkdir -p $(ASM_ARTIFACTS_DIR) @@ -83,7 +100,7 @@ compile-programs-asm: compile-programs-rust: prepare-sysroot $(RUST_ARTIFACTS) -compile-bench: $(BENCH_ARTIFACTS) +compile-bench: prepare-sysroot $(BENCH_ARTIFACTS) compile-programs: compile-programs-asm compile-programs-rust compile-bench @@ -93,6 +110,7 @@ $(RUST_ARTIFACTS_DIR)/%.elf: $(RUST_PROGRAMS_DIR)/%/Cargo.toml @mkdir -p $(RUST_ARTIFACTS_DIR) cd $(RUST_PROGRAMS_DIR)/$* && \ CARGO_TARGET_DIR=$(abspath $(SHARED_TARGET_DIR)) \ + CFLAGS_riscv64im_lambda_vm_elf="$(SYSROOT_CFLAGS)" \ rustup run nightly-2026-02-01 cargo build --release \ --target $(RV64_TARGET_SPEC) \ -Z build-std=core,alloc,std,compiler_builtins,panic_abort \ @@ -105,6 +123,7 @@ $(BENCH_ARTIFACTS_DIR)/%.elf: $(BENCH_PROGRAMS_DIR)/%/Cargo.toml @mkdir -p $(BENCH_ARTIFACTS_DIR) cd $(BENCH_PROGRAMS_DIR)/$* && \ CARGO_TARGET_DIR=$(abspath $(SHARED_TARGET_DIR)) \ + CFLAGS_riscv64im_lambda_vm_elf="$(SYSROOT_CFLAGS)" \ rustup run nightly-2026-02-01 cargo build --release \ --target $(RV64_TARGET_SPEC) \ -Z build-std=core,alloc,std,compiler_builtins,panic_abort \ @@ -166,6 +185,14 @@ test-prover-all: test-prover-debug: cargo test -p lambda-vm-prover --features debug-checks -- --nocapture +# math-cuda parity tests (requires NVIDIA GPU + nvcc) +test-math-cuda: + cargo test -p math-cuda --release + +# math-cuda quick microbench (median of 10 runs) +bench-math-cuda: + cargo test -p math-cuda --release --test bench_quick -- --ignored --nocapture + # Build all build: cargo build --workspace diff --git a/README.md b/README.md index df751528d..e8f00d229 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ Right now, this is a project under development and experimentation and must not ### Dependencies - Rust nightly with `rust-src` component +- Clang with RISC-V target support and LLD linker (used by `make compile-programs-asm`) + - **macOS**: `brew install llvm` (the Homebrew LLVM includes `clang` and `lld` with RISC-V support) + - **Linux**: `apt install clang lld` (or equivalent for your distribution) ### Dev dependencies @@ -26,11 +29,10 @@ Install Rust using [rustup](https://rustup.rs/): curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh ``` -Then install the nightly toolchain with the `rust-src` component (required for building `std` for the custom RISC-V target): +Add the `rust-src` component to the pinned nightly toolchain used to build guest programs (required for building `std` for the custom RISC-V target — `make compile-programs-rust` will auto-fetch the toolchain itself): ```sh -rustup toolchain install nightly -rustup component add rust-src --toolchain nightly +rustup component add rust-src --toolchain nightly-2026-02-01 ``` #### Compile sysroot @@ -39,12 +41,21 @@ Some of the tests require linking with C libraries. ##### Download pre-installed C libraries +The easiest way is to let `make` do it: + +```sh +make prepare-sysroot # installs to /opt (uses sudo) +SYSROOT_DIR=$HOME/.lambda-vm-sysroot make prepare-sysroot # user-writable, no sudo +``` + +Or do it manually: + ```sh wget https://lambda.alignedlayer.com/lambda-vm-sysroot-rv64im.tar.gz sudo mkdir -p /opt && sudo tar -xzf lambda-vm-sysroot-rv64im.tar.gz -C /opt ``` -##### Compile it directly +##### Compile them directly ```sh sudo apt-get install -y autoconf automake autotools-dev curl python3 \ @@ -65,20 +76,62 @@ sudo mkdir -p /opt && sudo tar -xzf lambda-vm-sysroot-rv64im.tar.gz -C /opt cp -r /opt/riscv64-newlib/riscv64-unknown-elf/lib /opt/lambda-vm-sysroot/ ``` -#### Install the dependencies +Then, you can check that the executor works by running: ```sh -make deps +make test-executor ``` -**Note:** At the moment, `make deps` only works on macOS. +### Using the CLI -Then, you can check that the executor works by running: +The `cli` binary lets you execute, prove, and verify RISC-V ELF programs. Build it once with: ```sh -make test-executor +cargo build --release -p cli +``` + +The binary will be available at `target/release/cli`. + +To get a sample program to work with, compile the bundled assembly tests: + +```sh +make compile-programs-asm +``` + +This emits ELF files under `executor/program_artifacts/asm/`. With those in place, you can run the three core commands: + +#### Execute + +Run a program without generating a proof. Useful for sanity checks and debugging: + +```sh +cargo run -p cli --release -- execute executor/program_artifacts/asm/add.elf +``` + +#### Prove + +Generate a STARK proof of the execution: + +```sh +cargo run -p cli --release -- prove executor/program_artifacts/asm/add.elf -o /tmp/proof.bin +``` + +#### Verify + +Verify a proof against the ELF it was generated from. The command exits `0` on success and `1` on failure: + +```sh +cargo run -p cli --release -- verify /tmp/proof.bin executor/program_artifacts/asm/add.elf ``` +For the full CLI reference — including private inputs, blowup factor tuning, timing, and flamegraph profiling — see [`bin/cli/README.md`](./bin/cli/README.md). + +### Writing a guest program + +Guest programs are written in Rust (or RISC-V assembly) and cross-compiled to the custom RV64IM target. The guest SDK [`lambda-vm-syscalls`](./syscalls/README.md) provides the syscalls a program uses to read private input, commit public output, halt, and call precompiles like Keccak. The [`executor`](./executor/README.md) crate is what loads your compiled ELF and emits the per-instruction logs the prover consumes. + +To add a new Rust guest, drop a project under `executor/programs/rust//` and run `make compile-programs-rust`. See [`executor/programs/rust/`](./executor/programs/rust/) for examples (`fibonacci`, `keccak`, `hashmap`, …). + ## Design choices - The Instruction Set Architecture is RISCV64IM @@ -98,7 +151,18 @@ Following [ethrex](https://github.com/lambdaclass/ethrex): ## Documentation -Full documentation can be found in [docs](./docs/). It is currently a work in progress, we expect that as more features and components become ready, they will be included in the docs. +High-level documentation lives in [`docs/`](./docs/): + +- [Overview of VM flow](./docs/general_flow.md) — the pipeline from source code to proof +- [Proof system overview](./docs/cryptography/proof_system.md) — design goals and primitives +- [Lookup arguments](./docs/cryptography/lookup.md) — how tables are linked via LogUp +- [Recommended reading](./docs/other_resources.md) — papers and tutorials + +### Specification + +A formal specification of the VM is written in [Typst](https://typst.app/) under [`spec/`](./spec/) and rendered as a browsable wiki (HTML) or PDF using [`shiroa`](https://myriad-dreamin.github.io/shiroa/). With both tools installed, run `shiroa serve` from `spec/` to host the wiki locally. + +See [`spec/README.md`](./spec/README.md) for full setup instructions. ## Testing @@ -114,6 +178,7 @@ Full documentation can be found in [docs](./docs/). It is currently a work in pr | `make test-asm` | Compile and run ASM tests | | `make test-rust` | Compile and run Rust tests | | `make test-executor` | Compile all programs and run executor tests | +| `make test-math-cuda` | math-cuda parity tests (requires NVIDIA GPU + nvcc) | | `make build` | Build all workspace crates | | `make check` | Check all crates (faster than build, no codegen) | | `make clippy` | Run clippy on all crates | @@ -128,8 +193,8 @@ To run all tests across the project use ### ASM Tests -In order to add a new asm test you should add the `.s` file under `programs/asm` -Then add the corresponding test under `tests/asm.rs` +In order to add a new asm test you should add the `.s` file under `executor/programs/asm` +Then add the corresponding test under `executor/tests/asm.rs` To run them you can use @@ -139,9 +204,9 @@ This will compile them and run the tests ### Rust Tests -In order to add a new rust test you should add the cargo project under `programs/rust` as a new directory. +In order to add a new rust test you should add the cargo project under `executor/programs/rust` as a new directory. The folder should have the same name as the `Cargo.toml` program name. -Then add the corresponding test under `tests/rust.rs` +Then add the corresponding test under `executor/tests/rust.rs` You can run it with @@ -151,8 +216,20 @@ You can run it with You can create a flamegraph for proof generation using the following target: +```sh +make flamegraph-prover ``` - make flamegraph-prover + +This profiles the synthetic `fibonacci_multi_column` STARK example in `crypto/stark` (i.e. the STARK engine itself, not a real guest ELF). To profile the VM prover end-to-end on a real ELF, use the dedicated bench in the `prover` crate: + +```sh +samply record cargo bench --bench profile_vm_prover --features parallel +``` + +For a quick GPU microbench (requires an NVIDIA GPU + `nvcc`): + +```sh +make bench-math-cuda ``` ## Debug Checks diff --git a/bench_vs/run_ethrex.sh b/bench_vs/run_ethrex.sh new file mode 100755 index 000000000..e42ee5356 --- /dev/null +++ b/bench_vs/run_ethrex.sh @@ -0,0 +1,183 @@ +#!/usr/bin/env bash +# Benchmark: Lambda VM proving an empty ethrex block. +# +# Usage: ./bench_vs/run_ethrex.sh [--report-dir DIR] [--no-color] +# +# Prerequisites: +# - Lambda VM CLI build dependencies available +# - Sysroot present at /opt/lambda-vm-sysroot (run `make prepare-sysroot` first) +# - Rust stable + nightly-2026-02-01 installed + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" +TMP_DIR="/tmp/bench_ethrex" +REPORT_DIR="" +NO_COLOR=false + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BOLD='\033[1m' +NC='\033[0m' + +# --- Parse args ------------------------------------------------------------- +while [[ $# -gt 0 ]]; do + case $1 in + --report-dir) + if [[ $# -lt 2 ]]; then echo "--report-dir requires an argument"; exit 1; fi + REPORT_DIR=$2 + shift 2 + ;; + --no-color) + NO_COLOR=true + shift + ;; + -h|--help) + echo "Usage: $0 [--report-dir DIR] [--no-color]" + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +if $NO_COLOR; then + RED='' + GREEN='' + YELLOW='' + BOLD='' + NC='' +fi + +mkdir -p "$TMP_DIR" +rm -rf "${TMP_DIR:?}"/* + +if [ -n "$REPORT_DIR" ]; then + mkdir -p "$REPORT_DIR/raw" +fi + +extract_proving_time() { + sed -nE '/Proving time: [0-9.]+s/ { + s/.*Proving time: ([0-9.]+)s.*/\1/ + p + q + }' +} + +extract_cycles() { + sed -nE '/Cycles: [0-9]+/ { + s/.*Cycles: ([0-9]+).*/\1/ + p + q + }' +} + +# --- Pre-build -------------------------------------------------------------- + +CLI="$ROOT_DIR/target/release/cli" +ETHREX_ELF="$ROOT_DIR/executor/program_artifacts/rust/ethrex.elf" +ETHREX_INPUT="$ROOT_DIR/executor/tests/ethrex_empty_block.bin" +echo -e "${BOLD}=== Ethrex Empty Block Benchmark: Lambda VM ===${NC}" +echo "" + +echo -e "${GREEN}[Lambda VM] Building CLI...${NC}" +cargo build --release -p cli --manifest-path "$ROOT_DIR/Cargo.toml" 2>&1 | tail -5 + +if [ -f "$ETHREX_ELF" ]; then + echo -e "${YELLOW}[Lambda VM] Using pre-existing ethrex.elf at $ETHREX_ELF${NC}" +else + echo -e "${GREEN}[Lambda VM] Building ethrex guest ELF...${NC}" + make -C "$ROOT_DIR" executor/program_artifacts/rust/ethrex.elf 2>&1 | tail -5 +fi + +if [ ! -f "$ETHREX_ELF" ]; then + echo -e "${RED}[Lambda VM] Build failed — ethrex.elf not found at $ETHREX_ELF${NC}" + exit 1 +fi + +if [ ! -f "$ETHREX_INPUT" ]; then + echo -e "${RED}Input file not found: $ETHREX_INPUT${NC}" + exit 1 +fi + +# --- Run benchmark --------------------------------------------------- +echo "" +echo -e "${BOLD}--- Proving empty ethrex block ---${NC}" + +proof_file="$TMP_DIR/ethrex_empty_block.proof" +stderr_file="$TMP_DIR/ethrex_empty_block.stderr" + +echo -e " ${GREEN}[Lambda VM] Proving...${NC}" +if ! lambda_output=$("$CLI" prove "$ETHREX_ELF" \ + -o "$proof_file" \ + --private-input "$ETHREX_INPUT" \ + --time --cycles 2>"$stderr_file"); then + echo -e " ${RED}[Lambda VM] FAILED:${NC}" + cat "$stderr_file" + exit 1 +fi +rm -f "$proof_file" + +lambda_time=$(printf "%s\n" "$lambda_output" | extract_proving_time) +lambda_cycles=$(printf "%s\n" "$lambda_output" | extract_cycles) + +if [ -z "$lambda_time" ]; then + echo -e " ${RED}[Lambda VM] FAILED: could not parse proving time${NC}" + printf "%s\n" "$lambda_output" + exit 1 +fi +if [ -z "$lambda_cycles" ]; then + lambda_cycles="n/a" +fi + +if [ "$lambda_cycles" != "n/a" ]; then + echo -e " Lambda VM: ${BOLD}${lambda_time}s${NC} (${lambda_cycles} cycles)" +else + echo -e " Lambda VM: ${BOLD}${lambda_time}s${NC}" +fi + +if [ -n "$REPORT_DIR" ]; then + printf "%s\n" "$lambda_output" > "$REPORT_DIR/raw/ethrex_empty_block.stdout" + cp "$stderr_file" "$REPORT_DIR/raw/ethrex_empty_block.stderr" +fi + +# --- Summary table ---------------------------------------------------------- + +echo "" +echo -e "${BOLD}=== Summary ===${NC}" +echo -e "Program: ethrex empty block" +echo "" + +printf " %-22s %14s %14s\n" "Program" "Lambda (s)" "Lambda cycles" +printf " %-22s %14s %14s\n" "----------------------" "----------" "-------------" +printf " %-22s %13ss %14s\n" "ethrex empty block" "$lambda_time" "$lambda_cycles" + +echo "" +echo -e "Timing window covers single-shot end-to-end proving; excludes verification." +echo "Raw data in $TMP_DIR/" + +# --- Machine-readable report ------------------------------------------------ + +if [ -n "$REPORT_DIR" ]; then + { + echo "program=ethrex_empty_block" + echo "input_file=$ETHREX_INPUT" + echo "timing_window=single_shot_end_to_end_prove_no_verify" + echo "ethrex_empty_block_time_s=$lambda_time" + echo "ethrex_empty_block_cycles=$lambda_cycles" + } > "$REPORT_DIR/ethrex_metrics.txt" + + { + echo "# Ethrex Empty Block — Lambda VM" + echo + echo "Timing window: \`single-shot end-to-end prove\` (excludes verification)." + echo + echo "| Program | Lambda VM (s) | Lambda cycles |" + echo "|---------|--------------:|--------------:|" + printf "| ethrex empty block | %s | %s |\n" "$lambda_time" "$lambda_cycles" + } > "$REPORT_DIR/ethrex_summary.md" +fi diff --git a/bin/cli/README.md b/bin/cli/README.md index 9ce2c2674..da82620c0 100644 --- a/bin/cli/README.md +++ b/bin/cli/README.md @@ -4,46 +4,88 @@ A command-line interface for executing, proving, and verifying RISC-V ELF progra ## Installation -```bash -cargo build -p cli --release +```sh +cargo build --release -p cli ``` The binary will be available at `target/release/cli`. +## Producing an ELF + +The CLI consumes RISC-V ELF binaries. The repo ships ready-to-use guest programs that you can compile with: + +```sh +# RISC-V assembly tests → executor/program_artifacts/asm/*.elf +make compile-programs-asm + +# Rust guest programs → executor/program_artifacts/rust/*.elf (needs the sysroot + nightly toolchain) +make compile-programs-rust + +# Benchmark programs → executor/program_artifacts/bench/*.elf (needs the sysroot + nightly toolchain) +make compile-bench +``` + +See the root [`README.md`](../../README.md) for the toolchain setup. + ## Commands ### Execute Run a RISC-V ELF program without generating a proof. Useful for testing and debugging. -```bash -cargo run -p cli --release -- execute +```sh +cargo run -p cli --release -- execute [--private-input ] [--flamegraph ] ``` -See [Guest Program Flamegraphs](#guest-program-flamegraphs) for profiling execution. +| Flag | Description | +|---|---| +| `--private-input ` | Pass private input bytes to the guest (read via `get_private_input()`). | +| `--flamegraph ` | Generate folded-stack flamegraph output. See [Guest Program Flamegraphs](#guest-program-flamegraphs). | ### Prove Generate a STARK proof for a RISC-V ELF program execution. -```bash -cargo run -p cli --release -- prove -o proof.bin +```sh +cargo run -p cli --release -- prove -o proof.bin [flags] ``` +| Flag | Description | +|---|---| +| `-o, --output ` | Output path for the serialized proof bundle. Required. | +| `--private-input ` | Pass private input bytes to the guest. | +| `--blowup ` | FRI blowup factor (power of 2). Higher = fewer queries, smaller proof, slower proving. [default: 2] | +| `--time` | Print total proving time. | +| `--cycles` | Run one extra pre-pass outside the timer and print the dynamic instruction count. | +| `--elements` | Build traces and print main-trace and aux-trace field element counts. | + ### Verify -Verify a proof generated by the `prove` command. +Verify a proof generated by `prove`. -```bash -cargo run -p cli --release -- verify proof.bin +```sh +cargo run -p cli --release -- verify [flags] ``` -Returns exit code 0 on successful verification, 1 on failure. +| Flag | Description | +|---|---| +| `--blowup ` | FRI blowup factor used during proving. Must match. [default: 2] | +| `--time` | Print verification time. | + +Returns exit code `0` on successful verification, `1` on failure. + +### Count Elements + +Build traces and print main-trace and aux-trace field element counts **without** running the proof step. Useful for sizing. + +```sh +cargo run -p cli --release -- count-elements [--private-input ] +``` ## Examples -```bash -# Compile test programs (if not already done) +```sh +# Compile the bundled assembly tests make compile-programs-asm # Execute a simple program @@ -52,6 +94,9 @@ cargo run -p cli --release -- execute executor/program_artifacts/asm/add.elf # Generate and verify a proof cargo run -p cli --release -- prove executor/program_artifacts/asm/add.elf -o /tmp/proof.bin cargo run -p cli --release -- verify /tmp/proof.bin executor/program_artifacts/asm/add.elf + +# Prove with private input and print metrics +cargo run -p cli --release -- prove program.elf -o /tmp/proof.bin --private-input input.bin --time --cycles ``` ## Guest Program Flamegraphs @@ -60,7 +105,7 @@ Generate flamegraphs showing where the guest RISC-V program spends its execution ### Generate Folded Stacks -```bash +```sh cargo run -p cli --release -- execute --flamegraph folded.txt ``` @@ -68,7 +113,7 @@ cargo run -p cli --release -- execute --flamegraph folded.txt Requires [inferno](https://github.com/jonhoo/inferno) or [flamegraph.pl](https://github.com/brendangregg/FlameGraph): -```bash +```sh # Install inferno (one-time) cargo install inferno @@ -78,7 +123,7 @@ cat folded.txt | inferno-flamegraph > flamegraph.svg ### Example -```bash +```sh # Generate flamegraph for quicksort benchmark cargo run -p cli --release -- execute executor/program_artifacts/bench/quicksort.elf --flamegraph /tmp/quicksort.txt cat /tmp/quicksort.txt | inferno-flamegraph --title "quicksort" > quicksort_flamegraph.svg diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs index aa633ef9a..f166e751d 100644 --- a/bin/cli/src/main.rs +++ b/bin/cli/src/main.rs @@ -125,7 +125,7 @@ enum Commands { private_input: Option, /// Blowup factor (power of 2). Higher = fewer queries, smaller proof, slower proving. - #[arg(long)] + #[arg(long, default_value = "2")] blowup: Option, /// Print proving time @@ -153,10 +153,10 @@ enum Commands { elf: PathBuf, /// Blowup factor used during proving (must match) - #[arg(long)] + #[arg(long, default_value = "2")] blowup: Option, - /// Print timing breakdown + /// Print verification time #[arg(long)] time: bool, }, diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index 789adf1b6..a9bdd8d46 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -105,6 +105,13 @@ where }) } + /// Read-only access to the full node buffer in standard layout: + /// `nodes[0..leaves_len - 1]` are inner nodes (root at index 0) and + /// `nodes[leaves_len - 1..]` are the leaves. + pub fn nodes(&self) -> &[B::Node] { + &self.nodes + } + /// Returns a Merkle proof for the element/s at position pos /// For example, give me an inclusion proof for the 3rd element in the /// Merkle tree diff --git a/crypto/math-cuda/Cargo.toml b/crypto/math-cuda/Cargo.toml index 8c22d1110..e700ec73e 100644 --- a/crypto/math-cuda/Cargo.toml +++ b/crypto/math-cuda/Cargo.toml @@ -9,14 +9,17 @@ cudarc = { version = "0.19", default-features = false, features = [ "driver", "nvrtc", "std", - "cuda-12080", + "cuda-version-from-build-system", + "fallback-latest", "dynamic-loading", ] } math = { path = "../math" } rayon = "1.7" [dev-dependencies] +crypto = { path = "../crypto" } rand = { version = "0.8.5", features = ["std"] } rand_chacha = "0.3.1" rayon = "1.7" sha3 = "0.10.8" +stark = { path = "../stark" } diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index 7c417fb9c..bc84f2653 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -14,6 +14,38 @@ fn nvcc_path() -> PathBuf { cuda_home().join("bin").join("nvcc") } +/// Query `nvidia-smi` for the local GPU's compute capability (e.g. "12.0" +/// for Blackwell). Returns a `compute_XX` target on success, falling back +/// to `compute_89` (Ada) when no GPU is visible or the query fails. +fn detect_arch() -> String { + const FALLBACK: &str = "compute_89"; + let output = match Command::new("nvidia-smi") + .args(["--query-gpu=compute_cap", "--format=csv,noheader"]) + .output() + { + Ok(o) if o.status.success() => o, + _ => return FALLBACK.to_string(), + }; + let line = match std::str::from_utf8(&output.stdout) { + Ok(s) => s, + Err(_) => return FALLBACK.to_string(), + }; + // First line, first comma-separated value (covers multi-GPU hosts). + let cap = match line.lines().next() { + Some(l) => l.split(',').next().unwrap_or("").trim(), + None => return FALLBACK.to_string(), + }; + let (major, minor) = match cap.split_once('.') { + Some((m, n)) => (m.trim(), n.trim()), + None => return FALLBACK.to_string(), + }; + if major.chars().all(|c| c.is_ascii_digit()) && minor.chars().all(|c| c.is_ascii_digit()) { + format!("compute_{major}{minor}") + } else { + FALLBACK.to_string() + } +} + fn compile_ptx(src: &str, out_name: &str, have_nvcc: bool) { let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); @@ -25,20 +57,22 @@ fn compile_ptx(src: &str, out_name: &str, have_nvcc: bool) { println!("cargo:rerun-if-env-changed=CUDA_PATH"); println!("cargo:rerun-if-env-changed=CUDARC_NVCC_ARCH"); - // No nvcc on PATH → emit an empty PTX stub so the crate still compiles. - // include_str! in src/device.rs needs the file to exist at build time. - // Any runtime kernel call will then panic from cudarc when loading the - // empty module — which is the right failure mode (we can't run GPU code - // without nvcc on the build host anyway). + // When nvcc is missing from PATH, emit an empty PTX stub so the crate + // still compiles. include_str! in src/device.rs needs the file to exist + // at build time. Any runtime kernel call panics in cudarc when loading + // the empty module. We can't run GPU code without nvcc on the build + // host anyway. if !have_nvcc { fs::write(&out_path, "").expect("failed to write empty PTX stub"); return; } // Emit PTX for a virtual architecture; the CUDA driver JIT-compiles it for the - // actual GPU at load time, so one PTX works across Ada/Hopper/Blackwell. Override - // with CUDARC_NVCC_ARCH to pin a specific compute capability. - let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| "compute_89".to_string()); + // actual GPU at load time. Override with CUDARC_NVCC_ARCH to pin a specific + // compute capability. If unset, try `nvidia-smi` to match the host GPU + // (avoids JIT failures like nvcc-13.0 PTX rejected on Blackwell drivers); + // fall back to compute_89 (Ada) when detection fails. + let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| detect_arch()); let status = Command::new(nvcc_path()) .args(["--ptx", "-O3", "-std=c++17", "-arch", &arch, "-o"]) @@ -53,13 +87,14 @@ fn compile_ptx(src: &str, out_name: &str, have_nvcc: bool) { } fn main() { - // Headers are not compiled; emit rerun-if-changed so edits trigger rebuilds. + // Headers aren't compiled, so emit rerun-if-changed to rebuild on + // header edits. println!("cargo:rerun-if-changed=kernels/goldilocks.cuh"); println!("cargo:rerun-if-changed=kernels/ext3.cuh"); // Probe for nvcc once. Workspace consumers (clippy, fmt, CPU-only test - // runners) build math-cuda incidentally without using its kernels; allow - // those to succeed by stubbing out PTX when nvcc is unavailable. + // runners) build math-cuda incidentally without using its kernels. Stub + // out PTX when nvcc is unavailable so those builds succeed. let have_nvcc = Command::new(nvcc_path()) .arg("--version") .output() diff --git a/crypto/math-cuda/kernels/arith.cu b/crypto/math-cuda/kernels/arith.cu index 4bee9b8bb..b1a6bb8ab 100644 --- a/crypto/math-cuda/kernels/arith.cu +++ b/crypto/math-cuda/kernels/arith.cu @@ -1,4 +1,4 @@ -// Element-wise Goldilocks kernels used by the Phase-2 parity tests. These mirror +// Element-wise Goldilocks kernels used by the parity tests. These mirror // the CPU reference in `crypto/math/src/field/goldilocks.rs` so raw u64 outputs // are bit-identical to the CPU path. @@ -81,3 +81,17 @@ extern "C" __global__ void ext3_add_kernel(const uint64_t *a_int, c_int[tid*3 + 1] = r.b; c_int[tid*3 + 2] = r.c; } + +extern "C" __global__ void ext3_sub_kernel(const uint64_t *a_int, + const uint64_t *b_int, + uint64_t *c_int, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]); + ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]); + ext3::Fe3 r = ext3::sub(a, b); + c_int[tid*3 + 0] = r.a; + c_int[tid*3 + 1] = r.b; + c_int[tid*3 + 2] = r.c; +} diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu index 68ddce3b4..c22bc4d05 100644 --- a/crypto/math-cuda/kernels/keccak.cu +++ b/crypto/math-cuda/kernels/keccak.cu @@ -1,7 +1,7 @@ -// CUDA Keccak-256 (original Keccak, NOT SHA3-256 — uses 0x01 padding delimiter). +// Original Keccak-256 (0x01 padding) // // Used by the lambda-vm prover's Merkle commit: -// leaf = Keccak-256(concat(col_0[br_idx].to_be_bytes(), col_1[br_idx].to_be_bytes(), …)) +// leaf = Keccak-256(concat(col_0[br_idx].to_be_bytes(), col_1[br_idx].to_be_bytes(), ...)) // where `br_idx = bit_reverse(row_idx, log_num_rows)` and each element is // written in BIG-ENDIAN canonical form (per `FieldElement::write_bytes_be`). // @@ -49,7 +49,7 @@ __device__ __forceinline__ uint64_t bswap64(uint64_t x) { __device__ __forceinline__ void keccak_f1600(uint64_t st[25]) { uint64_t C[5], D[5], B[25]; - #pragma unroll + // No outer unroll: fully unrolling the 24 rounds slowed the kernel ~7.5% on RTX 5090. for (int r = 0; r < 24; ++r) { // Theta #pragma unroll @@ -96,9 +96,9 @@ __device__ __forceinline__ void keccak_f1600(uint64_t st[25]) { } // --------------------------------------------------------------------------- -// Helper: absorb one 8-byte lane (already in lane form — i.e. LE interpretation -// of the BE-serialised u64) into the sponge at `rate_pos` (in bytes). Permutes -// when a full 136-byte block has been absorbed. +// Helper: absorb one 8-byte lane (already byte-swapped from BE serialisation +// into Keccak's LE lane form) into the sponge at `rate_pos` (in bytes). +// Permutes when a full 136-byte block has been absorbed. // --------------------------------------------------------------------------- __device__ __forceinline__ void absorb_lane(uint64_t st[25], uint32_t &rate_pos, @@ -172,7 +172,7 @@ extern "C" __global__ void keccak256_leaves_base_batched( uint64_t v = columns_base_ptr[c * col_stride + br]; // Canonicalise to match `canonical_u64().to_be_bytes()` on host. uint64_t canon = goldilocks::canonical(v); - // The on-disk leaf bytes are canon.to_be_bytes(); Keccak reads those + // The on-disk leaf bytes are canon.to_be_bytes(). Keccak reads those // as a LE lane, which equals bswap64(canon). uint64_t lane = bswap64(canon); absorb_lane(st, rate_pos, lane); @@ -275,8 +275,8 @@ extern "C" __global__ void keccak_comp_poly_leaves_ext3( // // Each leaf hashes 2 consecutive ext3 values: Keccak256 over // evals[2j].to_bytes_be() ++ evals[2j+1].to_bytes_be() -// = 48 BE bytes = 6 u64 BE lanes. No bit reversal; no column slab layout — -// the input is a single interleaved ext3 eval vector `[a0,a1,a2,b0,b1,b2,...]`. +// = 48 BE bytes = 6 u64 BE lanes. No bit reversal, no column slab layout. +// The input is a single interleaved ext3 eval vector `[a0,a1,a2,b0,b1,b2,...]`. // --------------------------------------------------------------------------- extern "C" __global__ void keccak_fri_leaves_ext3( const uint64_t *evals_interleaved, // 3 * num_evals u64s (ext3 interleaved) @@ -311,14 +311,14 @@ extern "C" __global__ void keccak_fri_leaves_ext3( // // `nodes` is the full Merkle node buffer (length `2*leaves_len - 1`, each // element 32 bytes). `parent_begin` is the node-index offset of the first -// parent slot in this level; children live at `parent_begin + n_pairs`. +// parent slot in this level. Children live at `parent_begin + n_pairs`. // The layout mirrors `crypto/crypto/src/merkle_tree/merkle.rs`: // // children: nodes[parent_begin + n_pairs .. parent_begin + 3 * n_pairs] // parents: nodes[parent_begin .. parent_begin + n_pairs] // // Each thread hashes one child pair → one parent. Keccak-256 of the -// concatenation of two 32-byte siblings; identical to +// concatenation of two 32-byte siblings, identical to // `FieldElementVectorBackend::hash_new_parent` on host. // --------------------------------------------------------------------------- extern "C" __global__ void keccak_merkle_level( @@ -333,6 +333,8 @@ extern "C" __global__ void keccak_merkle_level( for (int i = 0; i < 25; ++i) st[i] = 0; uint32_t rate_pos = 0; + // `nodes` comes from cuMemAlloc (256-byte aligned); each 32-byte node + // sits at a 32-byte-aligned offset, so the u64 cast is safe. const uint64_t *left = reinterpret_cast( nodes + (parent_begin + n_pairs + 2 * tid) * 32); #pragma unroll diff --git a/crypto/math-cuda/kernels/ntt.cu b/crypto/math-cuda/kernels/ntt.cu index 2a5c8c786..cf5e1df2c 100644 --- a/crypto/math-cuda/kernels/ntt.cu +++ b/crypto/math-cuda/kernels/ntt.cu @@ -1,5 +1,6 @@ -// Radix-2 DIT NTT over Goldilocks. One kernel per butterfly level; the caller -// runs `bit_reverse_permute` once before the first level. +// Radix-2 DIT NTT over Goldilocks: per-level, fused 8-level (shmem), and +// batched (multi-column) variants. The caller runs `bit_reverse_permute` +// once before the first butterfly level. // // Input layout: bit-reversed-order coefficients (after `bit_reverse_permute`). // Output layout: natural-order evaluations — matches the CPU `evaluate_fft` contract. @@ -179,8 +180,9 @@ extern "C" __global__ void scalar_mul_batched(uint64_t *data, /// One DIT butterfly level. Thread `tid` (of n/2 total) owns exactly one /// butterfly pair (i0, i1 = i0 + half). Twiddle picked from the shared full -/// `tw` table at stride `n / block_size`. Kept for log_n < 8 where shmem -/// fusion is overkill. +/// `tw` table at stride `n / block_size`. Used for levels 0..7 when n < 256 +/// (shmem fusion needs at least 256 elements), and for levels >= 8 of any +/// size (above the shmem-fusion window). extern "C" __global__ void ntt_dit_level(uint64_t *x, const uint64_t *tw, uint64_t n, @@ -216,7 +218,7 @@ extern "C" __global__ void ntt_dit_level(uint64_t *x, /// without writing to global memory between them — cuts DRAM traffic by up /// to 8× vs the per-level kernel. /// -/// `base_step` selects which 8-level window this launch handles (0, 8, 16…). +/// `base_step` selects which 8-level window this launch handles (0, 8, 16, ...). /// For levels 0–7 the implicit DIT element layout already places all pair /// mates inside the same 256-block; for higher base_step we remap the loaded /// row so pair mates land in consecutive shared-memory slots. diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 8e956eef3..d6d5fc403 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -15,8 +15,9 @@ use math::field::traits::IsFFTField; use crate::Result; use crate::ntt::{twiddles_forward, twiddles_inverse}; -/// Reusable pinned host staging buffer. One per stream; the stream's LDE call -/// holds its buffer's lock across the D2H + memcpy-to-user-Vecs window. +/// Reusable pinned host staging buffer. Shared across all streams via a +/// `Mutex` (see `Backend::pinned_staging`); the LDE call holds the lock +/// across the D2H + memcpy-to-user-Vecs window. /// /// Allocated with `cuMemHostAlloc(flags=0)` — portable, non-write-combined, /// so both DMA writes from device and CPU reads into user Vecs run at full @@ -108,9 +109,8 @@ pub struct Backend { /// cost for every first use of a new table size). pinned_staging: Mutex, /// Separate pinned staging for Merkle leaf hashes. Sized `num_rows * 32` - /// bytes; lives alongside the LDE staging so the GPU→host D2H for - /// hashed leaves runs at full PCIe line-rate instead of the pageable - /// ~1.3 GB/s path that would otherwise eat ~100 ms per main-trace commit. + /// bytes. It lives alongside the LDE staging so the GPU→host D2H for + /// hashed leaves runs at full PCIe line-rate. pinned_hashes: Mutex, util_stream: Arc, next: AtomicUsize, @@ -123,6 +123,7 @@ pub struct Backend { pub gl_neg: CudaFunction, pub ext3_mul: CudaFunction, pub ext3_add: CudaFunction, + pub ext3_sub: CudaFunction, // ntt.ptx pub bit_reverse_permute: CudaFunction, @@ -171,8 +172,9 @@ impl Backend { // not part of the pool that callers rotate through. let util_stream = ctx.new_stream()?; - // Goldilocks TWO_ADICITY is 32, so log_n ≤ 32 covers every LDE size - // the prover can produce. Overshoot by one for safety. + // Cache is indexed by log_n. Valid range is [0, TWO_ADICITY] since + // Goldilocks has roots of unity for orders 2^0..=2^TWO_ADICITY only. + // Length = TWO_ADICITY + 1 to allow indexing at log_n = TWO_ADICITY. let max_log = GoldilocksField::TWO_ADICITY as usize + 1; Ok(Self { @@ -183,6 +185,7 @@ impl Backend { gl_neg: arith.load_function("gl_neg_kernel")?, ext3_mul: arith.load_function("ext3_mul_kernel")?, ext3_add: arith.load_function("ext3_add_kernel")?, + ext3_sub: arith.load_function("ext3_sub_kernel")?, bit_reverse_permute: ntt.load_function("bit_reverse_permute")?, ntt_dit_level: ntt.load_function("ntt_dit_level")?, ntt_dit_8_levels: ntt.load_function("ntt_dit_8_levels")?, @@ -223,7 +226,7 @@ impl Backend { } /// Separate pinned staging for Merkle leaf hash output. Sized in u64 - /// units; caller should reserve `(num_rows * 32 + 7) / 8` u64s. + /// units. Caller should reserve `(num_rows * 32 + 7) / 8` u64s. pub fn pinned_hashes(&self) -> &Mutex { &self.pinned_hashes } @@ -243,6 +246,14 @@ impl Backend { } else { &self.inv_twiddles }; + // Cache is sized TWO_ADICITY + 1 in `Backend::init`. Callers derive + // log_n from `trailing_zeros` of valid Goldilocks domain sizes so it + // must stay in range; assert in debug to catch regressions. + debug_assert!( + log_n <= GoldilocksField::TWO_ADICITY, + "log_n {log_n} exceeds Goldilocks TWO_ADICITY ({})", + GoldilocksField::TWO_ADICITY, + ); { let guard = cache.lock().unwrap(); if let Some(t) = &guard[idx] { @@ -268,7 +279,19 @@ impl Backend { } } -pub fn backend() -> &'static Backend { +/// Returns the process-wide CUDA backend, initialising it on first call. +/// +/// Returns `Err` when CUDA initialisation fails (no driver, no GPU, PTX load +/// failure). Initialisation is retried on the next call until one succeeds — +/// only a successful `Backend` is cached. The race window where two threads +/// init concurrently is harmless: at most one extra `Backend::init()` runs +/// and the loser is dropped. +pub fn backend() -> Result<&'static Backend> { static BACKEND: OnceLock = OnceLock::new(); - BACKEND.get_or_init(|| Backend::init().expect("failed to initialise CUDA backend")) + if let Some(b) = BACKEND.get() { + return Ok(b); + } + let b = Backend::init()?; + let _ = BACKEND.set(b); + Ok(BACKEND.get().expect("backend just initialised")) } diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index ccf5abb1d..02f109938 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -4,7 +4,7 @@ //! Input : N evaluations (natural order) of a poly on the standard subgroup, //! plus coset weights (size N). The weights include the `1/N` iFFT //! normalisation, matching the `LdeTwiddles::coset_weights` format at -//! `crypto/stark/src/prover.rs:248` — i.e. `weights[i] = g^i / N`. +//! `crypto/stark/src/prover.rs` — i.e. `weights[i] = g^i / N`. //! Output : N*blowup_factor evaluations (natural order) on the coset. //! //! On-device steps, picks a stream from the shared pool so rayon-parallel @@ -12,13 +12,217 @@ use std::sync::Arc; -use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; +use rayon::prelude::*; use crate::Result; -use crate::device::backend; -use crate::merkle::{launch_keccak_base, launch_keccak_ext3}; +use crate::device::{Backend, backend}; +use crate::merkle::{keccak_launch_cfg, launch_keccak_base, launch_keccak_ext3}; use crate::ntt::run_ntt_body; +/// Goldilocks `TWO_ADICITY = 32` puts the theoretical domain ceiling at +/// `2^32`, where a downstream `as u32` cast would silently truncate to zero +/// and the corresponding kernel launch would do nothing. Assert at each +/// public entry point before any cast that depends on it. +#[inline] +fn assert_u32_domain(n: usize, what: &str) { + assert!( + n <= u32::MAX as usize, + "{what}: {n} exceeds u32 range — kernel grid would silently truncate", + ); +} + +/// Output shape requested from the fused LDE + Keccak entry points. +#[derive(Copy, Clone, PartialEq, Eq)] +enum KeccakCommit { + /// Only the `lde_size` keccak-256 leaves; no inner-tree build. Caller + /// receives `lde_size * 32` bytes. + LeavesOnly, + /// Full Merkle tree: leaves at the tail + inner nodes built on-device. + /// Caller receives `(2*lde_size - 1) * 32` bytes. + FullTree, +} + +impl KeccakCommit { + fn total_nodes_bytes(self, lde_size: usize) -> usize { + match self { + KeccakCommit::LeavesOnly => lde_size * 32, + KeccakCommit::FullTree => (2 * lde_size - 1) * 32, + } + } + + fn leaves_offset_bytes(self, lde_size: usize) -> usize { + match self { + KeccakCommit::LeavesOnly => 0, + KeccakCommit::FullTree => (lde_size - 1) * 32, + } + } +} + +/// De-interleave `columns` (each `3*n` u64s, ext3-per-element layout +/// `[a, b, c, a, b, c, ...]`) into `pinned` as `3*m` base-field slabs. +/// Component `k` of column `c` lands at `pinned[(c*3 + k)*n .. (c*3 + k)*n + n]`. +/// +/// Caller invariants: `pinned.len() >= 3 * columns.len() * n` and each +/// `columns[c].len() >= 3 * n`. The caller must hold the pinned-staging lock. +pub(crate) fn pack_ext3_to_pinned_slabs(columns: &[&[u64]], pinned: &mut [u64], n: usize) { + let m = columns.len(); + debug_assert!(pinned.len() >= 3 * m * n); + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: each task writes to disjoint `[(c*3 + k)*n .. ..+n]` regions + // of `pinned`. The outer `&mut [u64]` borrow guarantees no aliasing. + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); +} + +/// Re-interleave the `3*m` base-field slabs in `pinned` (layout matches +/// `pack_ext3_to_pinned_slabs`) into `outputs`, writing each as +/// `3*lde_size` interleaved u64s. +fn unpack_pinned_slabs_to_ext3(pinned: &[u64], outputs: &mut [&mut [u64]], lde_size: usize) { + let m = outputs.len(); + debug_assert!(pinned.len() >= 3 * m * lde_size); + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + // SAFETY: each task reads from disjoint `[(c*3 + k)*lde_size .. ..+lde_size]` + // regions of `pinned`. Caller borrows `pinned` for the duration of the call. + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); +} + +/// Run `bit_reverse_permute_batched` over `m` columns of length `n` each +/// (column stride `col_stride`). 256 threads per block, grid sized to cover +/// `n` per column. +fn launch_bit_reverse_batched( + stream: &CudaStream, + be: &Backend, + buf: &mut CudaSlice, + n: u64, + log_n: u64, + col_stride: u64, + m: u32, +) -> Result<()> { + let cfg = LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(buf) + .arg(&n) + .arg(&log_n) + .arg(&col_stride) + .launch(cfg)?; + } + Ok(()) +} + +/// D2H `dst.len()` bytes from `dev_bytes` into the caller's pageable `dst` +/// via the pinned-hashes staging buffer. Synchronises the stream first (so +/// any other D2H queued on the same stream also drains), then does a rayon +/// chunked memcpy pinned → caller to spread page-fault cost across cores. +fn d2h_bytes_via_pinned_hashes( + stream: &Arc, + be: &Backend, + dev_bytes: &CudaSlice, + dst: &mut [u8], +) -> Result<()> { + let n_bytes = dst.len(); + let u64_len = n_bytes.div_ceil(8); + let staging_slot = be.pinned_hashes(); + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(u64_len, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(u64_len) }; + // Reinterpret the u64 pinned buffer as bytes — same allocation, just + // typed differently. SAFETY: u64 has stricter alignment than u8 and the + // byte length fits in the `u64_len` capacity (rounded up to u64). + let pinned_bytes: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(pinned.as_mut_ptr() as *mut u8, n_bytes) }; + stream.memcpy_dtoh(dev_bytes, pinned_bytes)?; + stream.synchronize()?; + + // Single-threaded `copy_from_slice` faults virgin pageable pages one at + // a time; the mm_struct rwsem serialises them at prover scale. Chunk so + // ~N cores pre-fault+write in parallel. + const CHUNK: usize = 64 * 1024; + let src_ptr = pinned_bytes.as_ptr() as usize; + dst.par_chunks_mut(CHUNK).enumerate().for_each(|(i, d)| { + // SAFETY: each task reads `[i*CHUNK .. i*CHUNK + d.len()]` of + // `pinned_bytes`, which is disjoint per `i` and lives until `staging` + // is dropped below. + let src = + unsafe { std::slice::from_raw_parts((src_ptr as *const u8).add(i * CHUNK), d.len()) }; + d.copy_from_slice(src); + }); + drop(staging); + Ok(()) +} + +/// Run `pointwise_mul_batched`: `buf[c*col_stride + i] *= weights[i]` for +/// `m` columns, `n` elements each. +fn launch_pointwise_mul_batched( + stream: &CudaStream, + be: &Backend, + buf: &mut CudaSlice, + weights: &CudaSlice, + n: u64, + col_stride: u64, + m: u32, +) -> Result<()> { + let cfg = LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(buf) + .arg(weights) + .arg(&n) + .arg(&col_stride) + .launch(cfg)?; + } + Ok(()) +} + /// Handle to a base-field LDE kept live on device after R1 commit. /// Layout: `m` columns, each `lde_size` u64s, column `c` at byte offset /// `c * lde_size * 8` within `buf`. Freed when `buf` Arc drops. @@ -41,20 +245,23 @@ pub struct GpuLdeExt3 { pub fn coset_lde_base(evals: &[u64], blowup_factor: usize, weights: &[u64]) -> Result> { let n = evals.len(); + // Empty input must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(Vec::new()); + } assert!(n.is_power_of_two(), "evals length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match evals"); assert!( blowup_factor.is_power_of_two(), "blowup must be power of two" ); - if n == 0 { - return Ok(Vec::new()); - } let lde_size = n * blowup_factor; + assert_u32_domain(lde_size, "coset_lde_base lde_size"); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); // Device buffer of lde_size, zero-padded tail, first N filled by copy. @@ -128,6 +335,11 @@ pub fn coset_lde_batch_base( } let m = columns.len(); let n = columns[0].len(); + // Empty columns must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(vec![Vec::new(); m]); + } assert!(n.is_power_of_two(), "column length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match column length"); assert!( @@ -137,33 +349,15 @@ pub fn coset_lde_batch_base( for c in columns.iter() { assert_eq!(c.len(), n, "all columns must be the same size"); } - - if n == 0 { - return Ok(vec![Vec::new(); m]); - } let lde_size = n * blowup_factor; + assert_u32_domain(lde_size, "coset_lde_batch_base lde_size"); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); - let debug_phases = std::env::var("MATH_CUDA_PHASE_TIMING").is_ok(); - let t_start = if debug_phases { - Some(std::time::Instant::now()) - } else { - None - }; - let phase = |label: &str, prev: &mut Option| { - if let Some(p) = prev.as_ref() { - let now = std::time::Instant::now(); - eprintln!(" [{:>6.2} ms] {}", (now - *p).as_secs_f64() * 1e3, label); - *prev = Some(now); - } - }; - let mut last = t_start; - // Pinned staging. Lock and grow to max(m*n for upload, m*lde_size for // download). Holding the guard across the whole call serialises concurrent // batched calls that happened to hash to the same stream slot, but that's @@ -172,14 +366,10 @@ pub fn coset_lde_batch_base( staging.ensure_capacity(m * lde_size, &be.ctx)?; // SAFETY: staging is locked, the slice alias ends before we unlock. let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - if debug_phases { - phase("staging lock + grow", &mut last); - } // Pack columns into first m*n slots of the pinned buffer. Parallel: pinned - // writes are DRAM-bandwidth bound, saturates at ~8 cores on modern - // hardware, so rayon shaves 20+ ms at prover scale. - use rayon::prelude::*; + // writes are DRAM-bandwidth bound, so rayon spreads the cost across CPU + // cores. let pinned_base_ptr = pinned.as_mut_ptr() as usize; columns.par_iter().enumerate().for_each(|(c, col)| { // SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of @@ -189,35 +379,20 @@ pub fn coset_lde_batch_base( unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; dst.copy_from_slice(col); }); - if debug_phases { - phase("host pack (pinned, rayon)", &mut last); - } // Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size) // tail of each column is already the zero-pad the CPU path does. let mut buf = stream.alloc_zeros::(m * lde_size)?; - if debug_phases { - stream.synchronize()?; - phase("alloc_zeros", &mut last); - } // One memcpy per column from the pinned buffer into the strided slots. // The pinned source hits PCIe line-rate. for c in 0..m { let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; } - if debug_phases { - stream.synchronize()?; - phase("H2D cols (pinned)", &mut last); - } let inv_tw = be.inv_twiddles_for(log_n)?; let fwd_tw = be.fwd_twiddles_for(log_lde)?; let weights_dev = stream.clone_htod(weights)?; - if debug_phases { - stream.synchronize()?; - phase("twiddles + weights", &mut last); - } let n_u64 = n as u64; let lde_u64 = lde_size as u64; @@ -225,28 +400,16 @@ pub fn coset_lde_batch_base( let m_u32 = m as u32; // === 1. Bit-reverse first N of every column === - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; - if debug_phases { - stream.synchronize()?; - phase("bit_reverse N", &mut last); - } // === 2. iNTT body over all columns === run_batched_ntt_body( stream.as_ref(), @@ -257,53 +420,29 @@ pub fn coset_lde_batch_base( col_stride_u64, m_u32, )?; - if debug_phases { - stream.synchronize()?; - phase("iNTT body", &mut last); - } // === 3. Pointwise multiply by coset weights (includes 1/N) === - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + m_u32, + )?; // === 4. Bit-reverse full LDE of every column === - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; - if debug_phases { - stream.synchronize()?; - phase("pointwise + bit_reverse LDE", &mut last); - } // === 5. Forward NTT on full LDE of every column === run_batched_ntt_body( stream.as_ref(), @@ -314,29 +453,22 @@ pub fn coset_lde_batch_base( col_stride_u64, m_u32, )?; - if debug_phases { - stream.synchronize()?; - phase("forward NTT body", &mut last); - } // Single big D2H into the reusable pinned staging buffer — pinned, one // call to the driver, saturates PCIe. stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; stream.synchronize()?; - if debug_phases { - phase("D2H (one shot into pinned)", &mut last); - } // Split pinned → per-column Vecs. The first write to each virgin - // Vec page-faults, which can dominate total time (~75 ms for 128 MB). - // Parallelise so the fault cost spreads across CPU cores. + // Vec page-faults, which can dominate total time. Parallelise so the + // fault cost spreads across CPU cores. let pinned_ptr = pinned.as_ptr() as usize; let out: Vec> = (0..m) .into_par_iter() .map(|c| { - // set_len skips the O(N) zero-init that vec![0; n] would do - // (saves ~75 ms per 128 MB at prover scale). copy_from_slice - // below writes every slot before any reader sees the Vec. + // set_len skips the O(N) zero-init that vec![0; n] would do. + // copy_from_slice below writes every slot before any reader + // sees the Vec. #[allow(clippy::uninit_vec)] let mut v = { let mut v = Vec::::with_capacity(lde_size); @@ -350,18 +482,15 @@ pub fn coset_lde_batch_base( v }) .collect(); - if debug_phases { - phase("copy out (rayon pinned → Vecs)", &mut last); - } drop(staging); Ok(out) } /// Like `coset_lde_batch_base` but writes directly into caller-provided /// output slices instead of allocating fresh `Vec`s. Each output slice -/// must already have length `n * blowup_factor`. Saves ~50–100 ms of pageable -/// allocator work + page faults at prover scale because the caller's Vecs -/// have been sized once and are reused across calls. +/// must already have length `n * blowup_factor`. Avoids pageable allocator +/// work and page faults at prover scale because the caller's Vecs have been +/// sized once and are reused across calls. pub fn coset_lde_batch_base_into( columns: &[&[u64]], blowup_factor: usize, @@ -374,6 +503,11 @@ pub fn coset_lde_batch_base_into( let m = columns.len(); assert_eq!(outputs.len(), m, "outputs must match columns count"); let n = columns[0].len(); + // Empty columns must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(()); + } assert!(n.is_power_of_two(), "column length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match column length"); assert!( @@ -387,13 +521,11 @@ pub fn coset_lde_batch_base_into( for o in outputs.iter() { assert_eq!(o.len(), lde_size, "each output must be lde_size"); } - if n == 0 { - return Ok(()); - } + assert_u32_domain(lde_size, "coset_lde_batch_base_into lde_size"); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); @@ -421,23 +553,15 @@ pub fn coset_lde_batch_base_into( let m_u32 = m as u32; // iNTT bit-reverse + body, pointwise mul, forward bit-reverse + body. - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -447,40 +571,24 @@ pub fn coset_lde_batch_base_into( col_stride_u64, m_u32, )?; - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + m_u32, + )?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -496,8 +604,6 @@ pub fn coset_lde_batch_base_into( // Parallel copy pinned → caller outputs. Caller's Vecs may still fault // on first write; we spread that cost across rayon cores. - #[allow(unused_imports)] - use rayon::prelude::*; let pinned_ptr = pinned.as_ptr() as usize; outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { let src = unsafe { @@ -509,13 +615,12 @@ pub fn coset_lde_batch_base_into( Ok(()) } -/// Variant of `coset_lde_batch_base_into` that also emits the Keccak-256 -/// Merkle leaf hashes from the LDE output — all on GPU, no second H2D of -/// the LDE data. Leaves are computed reading columns at bit-reversed rows -/// (matching `commit_columns_bit_reversed` on the CPU side). -/// -/// `hashed_leaves_out` must be `lde_size * 32` bytes (one 32-byte digest -/// per output row, in natural row order). +/// Fused LDE + Keccak-256 leaf hashing. Caller receives the `lde_size * 32` +/// bytes of leaf hashes in `hashed_leaves_out` (one 32-byte digest per output +/// row, in natural row order; leaves are computed reading columns at +/// bit-reversed rows, matching `commit_columns_bit_reversed` on the CPU +/// side). Thin wrapper over `coset_lde_batch_base_into_with_merkle_tree_inner` +/// with `LeavesOnly` — no inner-tree build, no device handle. pub fn coset_lde_batch_base_into_with_leaf_hash( columns: &[&[u64]], blowup_factor: usize, @@ -523,172 +628,16 @@ pub fn coset_lde_batch_base_into_with_leaf_hash( outputs: &mut [&mut [u64]], hashed_leaves_out: &mut [u8], ) -> Result<()> { - if columns.is_empty() { - assert_eq!(outputs.len(), 0); - return Ok(()); - } - let m = columns.len(); - assert_eq!(outputs.len(), m); - let n = columns[0].len(); - assert!(n.is_power_of_two()); - assert_eq!(weights.len(), n); - assert!(blowup_factor.is_power_of_two()); - let lde_size = n * blowup_factor; - for o in outputs.iter() { - assert_eq!(o.len(), lde_size); - } - assert_eq!(hashed_leaves_out.len(), lde_size * 32); - let log_n = n.trailing_zeros() as u64; - let log_lde = lde_size.trailing_zeros() as u64; - - let be = backend(); - let stream = be.next_stream(); - let staging_slot = be.pinned_staging(); - - let mut staging = staging_slot.lock().unwrap(); - staging.ensure_capacity(m * lde_size, &be.ctx)?; - let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - - use rayon::prelude::*; - let pinned_base_ptr = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - // SAFETY: disjoint regions per c, outer staging lock held. - let dst = - unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; - dst.copy_from_slice(col); - }); - - let mut buf = stream.alloc_zeros::(m * lde_size)?; - for c in 0..m { - let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); - stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; - } - - let inv_tw = be.inv_twiddles_for(log_n)?; - let fwd_tw = be.fwd_twiddles_for(log_lde)?; - let weights_dev = stream.clone_htod(weights)?; - - let n_u64 = n as u64; - let lde_u64 = lde_size as u64; - let col_stride_u64 = lde_size as u64; - let m_u32 = m as u32; - - // iNTT - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - run_batched_ntt_body( - stream.as_ref(), - &mut buf, - inv_tw.as_ref(), - n_u64, - log_n, - col_stride_u64, - m_u32, - )?; - // pointwise coset scale - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - // forward NTT on full LDE slab - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - run_batched_ntt_body( - stream.as_ref(), - &mut buf, - fwd_tw.as_ref(), - lde_u64, - log_lde, - col_stride_u64, - m_u32, - )?; - - // Keccak-256 leaf hashing directly on the device LDE buffer. - let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; - launch_keccak_base( - stream.as_ref(), - &buf, - col_stride_u64, - m as u64, - lde_u64, - &mut hashes_dev, - )?; - - // D2H the LDE into the pinned LDE staging and the hashes into a - // dedicated pinned hash staging, in parallel on the same stream. Both - // at pinned PCIe line-rate — pageable D2H of the 128 MB hash buffer - // would otherwise cost ~100 ms per main-trace commit. - stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; - let hashes_u64_len = (lde_size * 32).div_ceil(8); - let hashes_staging_slot = be.pinned_hashes(); - let mut hashes_staging = hashes_staging_slot.lock().unwrap(); - hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; - let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; - // `memcpy_dtoh` needs a byte slice. Reinterpret the u64 pinned buffer - // as bytes — same allocation, just typed differently. - let hashes_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(hashes_pinned.as_mut_ptr() as *mut u8, lde_size * 32) - }; - stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; - stream.synchronize()?; - - // Copy pinned → caller outputs in parallel with the hash memcpy. - let pinned_ptr = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) - }; - dst.copy_from_slice(src); - }); - // Rayon-parallel memcpy of 128 MB from pinned → caller. Single-threaded - // `copy_from_slice` faults virgin pageable pages one at a time; the - // mm_struct rwsem serialises them into ~100 ms at 1M-fib scale. Chunk - // the slice so ~N cores pre-fault+write in parallel. - const CHUNK: usize = 64 * 1024; // 64 KiB ≈ 16 pages per chunk - let pinned_hash_ptr = hashes_pinned_bytes.as_ptr() as usize; - hashed_leaves_out - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_hash_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(hashes_staging); - drop(staging); - Ok(()) + coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + hashed_leaves_out, + KeccakCommit::LeavesOnly, + false, + ) + .map(|_| ()) } /// Like `coset_lde_batch_base_into_with_leaf_hash`, but also builds the full @@ -712,6 +661,7 @@ pub fn coset_lde_batch_base_into_with_merkle_tree( weights, outputs, merkle_nodes_out, + KeccakCommit::FullTree, false, ) .map(|_| ()) @@ -733,6 +683,7 @@ pub fn coset_lde_batch_base_into_with_merkle_tree_keep( weights, outputs, merkle_nodes_out, + KeccakCommit::FullTree, true, )?; let handle = opt.expect("keep_device_buf=true must return Some"); @@ -744,7 +695,8 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( blowup_factor: usize, weights: &[u64], outputs: &mut [&mut [u64]], - merkle_nodes_out: &mut [u8], + nodes_out: &mut [u8], + commit: KeccakCommit, keep_device_buf: bool, ) -> Result> { if columns.is_empty() { @@ -754,19 +706,27 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( let m = columns.len(); assert_eq!(outputs.len(), m); let n = columns[0].len(); + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(None); + } assert!(n.is_power_of_two()); assert_eq!(weights.len(), n); assert!(blowup_factor.is_power_of_two()); let lde_size = n * blowup_factor; + assert_u32_domain( + lde_size, + "coset_lde_batch_base_into_with_merkle_tree lde_size", + ); for o in outputs.iter() { assert_eq!(o.len(), lde_size); } - let total_nodes = 2 * lde_size - 1; - assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + let nodes_dev_bytes = commit.total_nodes_bytes(lde_size); + assert_eq!(nodes_out.len(), nodes_dev_bytes); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); @@ -774,7 +734,6 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( staging.ensure_capacity(m * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - use rayon::prelude::*; let pinned_base_ptr = pinned.as_mut_ptr() as usize; columns.par_iter().enumerate().for_each(|(c, col)| { let dst = @@ -798,19 +757,15 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( let m_u32 = m as u32; // iNTT - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -820,33 +775,25 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( col_stride_u64, m_u32, )?; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + m_u32, + )?; // forward NTT at LDE size - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -857,80 +804,36 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( m_u32, )?; - // Allocate the full node buffer; leaves occupy the tail slab, inner - // nodes are written by the pair-hash level kernel below. `alloc` (not - // `alloc_zeros`) is safe because every byte is written before it is - // read: leaf kernel fills the tail, tree kernel fills the head. - // - // The leaf kernel writes to `nodes_dev` starting at byte offset - // `(lde_size - 1) * 32`; we pass the base pointer as-is because the - // kernel indexes linearly from `hashed_leaves_out[row_idx * 32]`, so we - // build an offset device slice and feed that to the launch. - let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; - let leaves_offset_bytes = (lde_size - 1) * 32; + // Allocate the device output buffer. In `LeavesOnly` mode this is just + // `lde_size * 32` bytes (the leaves themselves); in `FullTree` mode it's + // `(2*lde_size - 1) * 32` bytes (leaves in the tail + inner nodes filled + // below). `alloc` (not `alloc_zeros`) is safe because every byte is + // written before any reader sees it: the keccak kernel fills the + // leaves slab, the inner-tree pass (when present) fills the head. + let mut nodes_dev = unsafe { stream.alloc::(nodes_dev_bytes) }?; + let leaves_offset_bytes = commit.leaves_offset_bytes(lde_size); { let mut leaves_view = nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); - let log_num_rows_leaves = lde_size.trailing_zeros() as u64; - let num_cols_u64 = m as u64; - let grid = (lde_size as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak256_leaves_base_batched) - .arg(&buf) - .arg(&col_stride_u64) - .arg(&num_cols_u64) - .arg(&lde_u64) - .arg(&log_num_rows_leaves) - .arg(&mut leaves_view) - .launch(cfg)?; - } + launch_keccak_base( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut leaves_view, + )?; } - // Inner tree levels — mirror the CPU `build(nodes, leaves_len)` scan. - { - let mut level_begin: u64 = (lde_size - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } + if commit == KeccakCommit::FullTree { + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; } - // D2H the LDE and the tree nodes via pinned staging. + // D2H the LDE and the tree/leaves nodes via pinned staging. stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; - let tree_u64_len = (total_nodes * 32).div_ceil(8); - let tree_staging_slot = be.pinned_hashes(); - let mut tree_staging = tree_staging_slot.lock().unwrap(); - tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; - let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; - let tree_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, total_nodes * 32) - }; - stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; - stream.synchronize()?; - - // Parallel memcpy pinned → caller. + // Pinned LDE → caller outputs (post-sync host memcpy). let pinned_ptr = pinned.as_ptr() as usize; outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { let src = unsafe { @@ -938,18 +841,6 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( }; dst.copy_from_slice(src); }); - const CHUNK: usize = 64 * 1024; - let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; - merkle_nodes_out - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(tree_staging); drop(staging); if keep_device_buf { @@ -964,9 +855,9 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( } } -/// Ext3 variant of `coset_lde_batch_base_into_with_leaf_hash`: run an LDE -/// over ext3 columns AND emit Keccak-256 Merkle leaves, all in one on-device -/// pipeline. +/// Ext3 variant of `coset_lde_batch_base_into_with_leaf_hash`: fused +/// LDE + Keccak-256 leaf hashing over ext3 columns. Thin wrapper over +/// `coset_lde_batch_ext3_into_with_merkle_tree_inner` with `LeavesOnly`. pub fn coset_lde_batch_ext3_into_with_leaf_hash( columns: &[&[u64]], n: usize, @@ -975,9 +866,85 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( outputs: &mut [&mut [u64]], hashed_leaves_out: &mut [u8], ) -> Result<()> { + coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + hashed_leaves_out, + KeccakCommit::LeavesOnly, + false, + ) + .map(|_| ()) +} + +/// Ext3 variant of the fused `coset_lde_batch_base_into_with_merkle_tree`. +/// LDE + leaf hashing + inner-tree build, all on device; D2Hs only the LDE +/// evaluations and the full `2*lde_size - 1` node buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + KeccakCommit::FullTree, + false, + ) + .map(|_| ()) +} + +/// Ext3 variant of [`coset_lde_batch_base_into_with_merkle_tree_keep`] — +/// returns an `Arc>` handle to the de-interleaved LDE device +/// buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree_keep( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result { + let opt = coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + KeccakCommit::FullTree, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) +} + +#[allow(clippy::too_many_arguments)] +fn coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + nodes_out: &mut [u8], + commit: KeccakCommit, + keep_device_buf: bool, +) -> Result> { if columns.is_empty() { assert_eq!(outputs.len(), 0); - return Ok(()); + return Ok(None); + } + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(None); } let m = columns.len(); assert_eq!(outputs.len(), m); @@ -988,18 +955,20 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( assert_eq!(c.len(), 3 * n); } let lde_size = n * blowup_factor; + assert_u32_domain( + lde_size, + "coset_lde_batch_ext3_into_with_merkle_tree lde_size", + ); for o in outputs.iter() { assert_eq!(o.len(), 3 * lde_size); } - assert_eq!(hashed_leaves_out.len(), lde_size * 32); - if n == 0 { - return Ok(()); - } + let nodes_dev_bytes = commit.total_nodes_bytes(lde_size); + assert_eq!(nodes_out.len(), nodes_dev_bytes); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; let mb = 3 * m; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); @@ -1007,24 +976,7 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - use rayon::prelude::*; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(columns, pinned, n); let mut buf = stream.alloc_zeros::(mb * lde_size)?; for s in 0..mb { @@ -1041,303 +993,42 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( let col_stride_u64 = lde_size as u64; let mb_u32 = mb as u32; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - run_batched_ntt_body( + launch_bit_reverse_batched( stream.as_ref(), + be, &mut buf, - inv_tw.as_ref(), n_u64, log_n, col_stride_u64, mb_u32, )?; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } run_batched_ntt_body( stream.as_ref(), &mut buf, - fwd_tw.as_ref(), - lde_u64, - log_lde, + inv_tw.as_ref(), + n_u64, + log_n, col_stride_u64, mb_u32, )?; - - // Keccak-256 on the de-interleaved device buffer (3M base slabs). - let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; - launch_keccak_ext3( + launch_pointwise_mul_batched( stream.as_ref(), - &buf, + be, + &mut buf, + &weights_dev, + n_u64, col_stride_u64, - m as u64, - lde_u64, - &mut hashes_dev, - )?; - - // D2H LDE (mb * lde_size u64) and hashes (lde_size * 32 bytes). - stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - let hashes_u64_len = (lde_size * 32).div_ceil(8); - let hashes_staging_slot = be.pinned_hashes(); - let mut hashes_staging = hashes_staging_slot.lock().unwrap(); - hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; - let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; - let hashes_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(hashes_pinned.as_mut_ptr() as *mut u8, lde_size * 32) - }; - stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; - stream.synchronize()?; - - // Re-interleave pinned → caller ext3 outputs, parallel. - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); - - // Parallel memcpy of pinned hashes → caller. - const CHUNK: usize = 64 * 1024; - let hash_src_ptr = hashes_pinned_bytes.as_ptr() as usize; - hashed_leaves_out - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((hash_src_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(hashes_staging); - drop(staging); - Ok(()) -} - -/// Ext3 variant of the fused `coset_lde_batch_base_into_with_merkle_tree`. -/// LDE + leaf hashing + inner-tree build, all on device; D2Hs only the LDE -/// evaluations and the full `2*lde_size - 1` node buffer. -pub fn coset_lde_batch_ext3_into_with_merkle_tree( - columns: &[&[u64]], - n: usize, - blowup_factor: usize, - weights: &[u64], - outputs: &mut [&mut [u64]], - merkle_nodes_out: &mut [u8], -) -> Result<()> { - coset_lde_batch_ext3_into_with_merkle_tree_inner( - columns, - n, - blowup_factor, - weights, - outputs, - merkle_nodes_out, - false, - ) - .map(|_| ()) -} - -/// Ext3 variant of [`coset_lde_batch_base_into_with_merkle_tree_keep`] — -/// returns an `Arc>` handle to the de-interleaved LDE device -/// buffer. -pub fn coset_lde_batch_ext3_into_with_merkle_tree_keep( - columns: &[&[u64]], - n: usize, - blowup_factor: usize, - weights: &[u64], - outputs: &mut [&mut [u64]], - merkle_nodes_out: &mut [u8], -) -> Result { - let opt = coset_lde_batch_ext3_into_with_merkle_tree_inner( - columns, - n, - blowup_factor, - weights, - outputs, - merkle_nodes_out, - true, + mb_u32, )?; - Ok(opt.expect("keep_device_buf=true must return Some")) -} - -fn coset_lde_batch_ext3_into_with_merkle_tree_inner( - columns: &[&[u64]], - n: usize, - blowup_factor: usize, - weights: &[u64], - outputs: &mut [&mut [u64]], - merkle_nodes_out: &mut [u8], - keep_device_buf: bool, -) -> Result> { - if columns.is_empty() { - assert_eq!(outputs.len(), 0); - return Ok(None); - } - let m = columns.len(); - assert_eq!(outputs.len(), m); - assert!(n.is_power_of_two()); - assert_eq!(weights.len(), n); - assert!(blowup_factor.is_power_of_two()); - for c in columns.iter() { - assert_eq!(c.len(), 3 * n); - } - let lde_size = n * blowup_factor; - for o in outputs.iter() { - assert_eq!(o.len(), 3 * lde_size); - } - let total_nodes = 2 * lde_size - 1; - assert_eq!(merkle_nodes_out.len(), total_nodes * 32); - if n == 0 { - return Ok(None); - } - let log_n = n.trailing_zeros() as u64; - let log_lde = lde_size.trailing_zeros() as u64; - - let mb = 3 * m; - let be = backend(); - let stream = be.next_stream(); - let staging_slot = be.pinned_staging(); - - let mut staging = staging_slot.lock().unwrap(); - staging.ensure_capacity(mb * lde_size, &be.ctx)?; - let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - - use rayon::prelude::*; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); - - let mut buf = stream.alloc_zeros::(mb * lde_size)?; - for s in 0..mb { - let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); - stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; - } - - let inv_tw = be.inv_twiddles_for(log_n)?; - let fwd_tw = be.fwd_twiddles_for(log_lde)?; - let weights_dev = stream.clone_htod(weights)?; - - let n_u64 = n as u64; - let lde_u64 = lde_size as u64; - let col_stride_u64 = lde_size as u64; - let mb_u32 = mb as u32; - - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - run_batched_ntt_body( + launch_bit_reverse_batched( stream.as_ref(), + be, &mut buf, - inv_tw.as_ref(), - n_u64, - log_n, + lde_u64, + log_lde, col_stride_u64, mb_u32, )?; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1348,110 +1039,33 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( mb_u32, )?; - // Allocate full tree buffer; leaf kernel writes to the tail slab. - let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; - let leaves_offset_bytes = (lde_size - 1) * 32; + // Allocate device output buffer (LeavesOnly → lde_size*32; FullTree → + // (2*lde_size - 1)*32). Leaf kernel writes to the leaves slab; the + // inner-tree pass (when present) fills the head. + let mut nodes_dev = unsafe { stream.alloc::(nodes_dev_bytes) }?; + let leaves_offset_bytes = commit.leaves_offset_bytes(lde_size); { let mut leaves_view = nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); - let log_num_rows_leaves = lde_size.trailing_zeros() as u64; - let num_cols_u64 = m as u64; - let grid = (lde_size as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak256_leaves_ext3_batched) - .arg(&buf) - .arg(&col_stride_u64) - .arg(&num_cols_u64) - .arg(&lde_u64) - .arg(&log_num_rows_leaves) - .arg(&mut leaves_view) - .launch(cfg)?; - } + launch_keccak_ext3( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut leaves_view, + )?; } - // Inner tree levels. - { - let mut level_begin: u64 = (lde_size - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } + if commit == KeccakCommit::FullTree { + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; } - // D2H LDE (mb * lde_size u64) and tree nodes. + // D2H LDE (mb * lde_size u64) and tree/leaves nodes. stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - let tree_u64_len = (total_nodes * 32).div_ceil(8); - let tree_staging_slot = be.pinned_hashes(); - let mut tree_staging = tree_staging_slot.lock().unwrap(); - tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; - let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; - let tree_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, total_nodes * 32) - }; - stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; - stream.synchronize()?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; - // Re-interleave pinned → caller ext3 outputs. - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); - - const CHUNK: usize = 64 * 1024; - let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; - merkle_nodes_out - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(tree_staging); + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); drop(staging); if keep_device_buf { @@ -1475,10 +1089,6 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( /// Skips the iFFT stage of [`coset_lde_batch_ext3_into`] (input is /// coefficients, not evaluations). Weights encode the coset shift: /// `weights[k] = offset^k` (NO 1/N because iFFT normalisation doesn't apply). -/// -/// Used by the stark prover to GPU-accelerate -/// `evaluate_polynomial_on_lde_domain` calls inside the -/// `number_of_parts > 2` branch of the composition-polynomial LDE. pub fn evaluate_poly_coset_batch_ext3_into( coefs: &[&[u64]], n: usize, @@ -1486,205 +1096,60 @@ pub fn evaluate_poly_coset_batch_ext3_into( weights: &[u64], outputs: &mut [&mut [u64]], ) -> Result<()> { - evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, false) - .map(|_| ()) + evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + None, + false, + ) + .map(|_| ()) } /// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- -/// interleaved LDE device buffer as a `GpuLdeExt3` handle. Lets R2 commit -/// and R4 DEEP composition read the composition-parts LDE without -/// re-H2D'ing. +/// interleaved LDE device buffer as a `GpuLdeExt3` handle so callers can +/// reuse the LDE without a re-H2D. pub fn evaluate_poly_coset_batch_ext3_into_keep( coefs: &[&[u64]], n: usize, - blowup_factor: usize, - weights: &[u64], - outputs: &mut [&mut [u64]], -) -> Result { - let opt = - evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, true)?; - Ok(opt.expect("keep_device_buf=true must return Some")) -} - -fn evaluate_poly_coset_batch_ext3_into_inner( - coefs: &[&[u64]], - n: usize, - blowup_factor: usize, - weights: &[u64], - outputs: &mut [&mut [u64]], - keep_device_buf: bool, -) -> Result> { - if coefs.is_empty() { - assert_eq!(outputs.len(), 0); - return Ok(None); - } - let m = coefs.len(); - assert_eq!(outputs.len(), m); - assert!(n.is_power_of_two()); - assert_eq!(weights.len(), n); - assert!(blowup_factor.is_power_of_two()); - for c in coefs.iter() { - assert_eq!(c.len(), 3 * n); - } - let lde_size = n * blowup_factor; - for o in outputs.iter() { - assert_eq!(o.len(), 3 * lde_size); - } - if n == 0 { - return Ok(None); - } - let log_lde = lde_size.trailing_zeros() as u64; - - let mb = 3 * m; - let be = backend(); - let stream = be.next_stream(); - let staging_slot = be.pinned_staging(); - - let mut staging = staging_slot.lock().unwrap(); - staging.ensure_capacity(mb * lde_size, &be.ctx)?; - let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - - use rayon::prelude::*; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - coefs.par_iter().enumerate().for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); - - let mut buf = stream.alloc_zeros::(mb * lde_size)?; - for s in 0..mb { - let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); - stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; - } - - let fwd_tw = be.fwd_twiddles_for(log_lde)?; - let weights_dev = stream.clone_htod(weights)?; - - let n_u64 = n as u64; - let lde_u64 = lde_size as u64; - let col_stride_u64 = lde_size as u64; - let mb_u32 = mb as u32; - - // Apply coset scaling: x[k] *= weights[k] for k in 0..n (no iFFT first). - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } - - // Bit-reverse full lde_size slab, then forward DIT NTT. - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } - run_batched_ntt_body( - stream.as_ref(), - &mut buf, - fwd_tw.as_ref(), - lde_u64, - log_lde, - col_stride_u64, - mb_u32, - )?; - - stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - stream.synchronize()?; - - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); - drop(staging); - if keep_device_buf { - Ok(Some(GpuLdeExt3 { - buf: std::sync::Arc::new(buf), - m, - lde_size, - })) - } else { - drop(buf); - Ok(None) - } + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result { + let opt = evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + None, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) } -/// Fused variant of [`evaluate_poly_coset_batch_ext3_into`]: in addition to -/// the LDE output, builds the R2 composition-polynomial Merkle tree on device -/// (row-pair Keccak leaves at bit-reversed indices + pair-hash inner tree). -/// -/// `merkle_nodes_out` must have byte length `(2 * lde_size - 1) * 32`. -/// Requires `lde_size >= 2`. -pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( +fn evaluate_poly_coset_batch_ext3_into_inner( coefs: &[&[u64]], n: usize, blowup_factor: usize, weights: &[u64], outputs: &mut [&mut [u64]], - merkle_nodes_out: &mut [u8], -) -> Result<()> { + merkle_nodes_out: Option<&mut [u8]>, + keep_device_buf: bool, +) -> Result> { if coefs.is_empty() { - return Ok(()); + assert_eq!(outputs.len(), 0); + return Ok(None); } let m = coefs.len(); assert_eq!(outputs.len(), m); + // Empty domain must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(None); + } assert!(n.is_power_of_two()); assert_eq!(weights.len(), n); assert!(blowup_factor.is_power_of_two()); @@ -1695,16 +1160,14 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( for o in outputs.iter() { assert_eq!(o.len(), 3 * lde_size); } - assert!(lde_size >= 2); - let total_nodes = 2 * lde_size - 1; - assert_eq!(merkle_nodes_out.len(), total_nodes * 32); - if n == 0 { - return Ok(()); + assert_u32_domain(lde_size, "evaluate_poly_coset_batch_ext3_into lde_size"); + if merkle_nodes_out.is_some() { + assert!(lde_size >= 2); } let log_lde = lde_size.trailing_zeros() as u64; let mb = 3 * m; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); @@ -1712,24 +1175,7 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - use rayon::prelude::*; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - coefs.par_iter().enumerate().for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(coefs, pinned, n); let mut buf = stream.alloc_zeros::(mb * lde_size)?; for s in 0..mb { @@ -1745,32 +1191,27 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( let col_stride_u64 = lde_size as u64; let mb_u32 = mb as u32; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + // Apply coset scaling: x[k] *= weights[k] for k in 0..n (no iFFT first). + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + mb_u32, + )?; + + // Bit-reverse full lde_size slab, then forward DIT NTT. + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1781,121 +1222,79 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( mb_u32, )?; - // Build the row-pair Merkle tree on device. - // - // Row-pair commit: each leaf hashes 2 rows (bit-reversed indices) → - // num_leaves = lde_size / 2. Tree size: 2*num_leaves - 1 = lde_size - 1. - let num_leaves = lde_size / 2; - let tight_total_nodes = 2 * num_leaves - 1; - let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; - let leaves_offset_bytes = (num_leaves - 1) * 32; - { - let mut leaves_view = - nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); - let log_num_rows = log_lde; - let num_parts_u64 = m as u64; - let grid = (num_leaves as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_comp_poly_leaves_ext3) - .arg(&buf) - .arg(&col_stride_u64) - .arg(&num_parts_u64) - .arg(&lde_u64) - .arg(&log_num_rows) - .arg(&mut leaves_view) - .launch(cfg)?; - } - } - { - let mut level_begin: u64 = (num_leaves - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; + // Optional R2-style row-pair Merkle tree build on the LDE buffer. + if let Some(nodes_out) = merkle_nodes_out { + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + assert_eq!(nodes_out.len(), tight_total_nodes * 32); + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let log_num_rows = log_lde; + let num_parts_u64 = m as u64; + let cfg = keccak_launch_cfg(num_leaves as u64); unsafe { stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&lde_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) .launch(cfg)?; } - level_begin = new_begin; } - } + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; - // D2H LDE and tree. - stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - let tree_u64_len = (tight_total_nodes * 32).div_ceil(8); - let tree_staging_slot = be.pinned_hashes(); - let mut tree_staging = tree_staging_slot.lock().unwrap(); - tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; - let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; - let tree_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, tight_total_nodes * 32) - }; - stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; - stream.synchronize()?; - - // Re-interleave pinned → caller ext3 outputs. - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; + } else { + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + stream.synchronize()?; + } - // Copy pinned tree → caller nodes_out. `merkle_nodes_out.len() == - // total_nodes * 32` is oversized relative to our tight tree; we write - // only the first `tight_total_nodes * 32` bytes and the caller trims. - // Expose the tight byte count via the slice length so the caller can - // construct the MerkleTree with the right node count. - assert!(merkle_nodes_out.len() >= tight_total_nodes * 32); - const CHUNK: usize = 64 * 1024; - let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; - merkle_nodes_out[..tight_total_nodes * 32] - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(tree_staging); + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); drop(staging); - Ok(()) + if keep_device_buf { + Ok(Some(GpuLdeExt3 { + buf: std::sync::Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Fused variant of [`evaluate_poly_coset_batch_ext3_into`]: in addition to +/// the LDE output, builds the R2 composition-polynomial Merkle tree on device +/// (row-pair Keccak leaves at bit-reversed indices + pair-hash inner tree). +/// +/// Row-pair commit: each leaf hashes 2 bit-reversed rows, so the tree has +/// `lde_size / 2` leaves and `merkle_nodes_out` must have byte length +/// `(lde_size - 1) * 32`. Requires `lde_size >= 2`. +pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + Some(merkle_nodes_out), + false, + ) + .map(|_| ()) } /// Batched coset LDE for Goldilocks **cubic extension** columns. /// @@ -1929,6 +1328,11 @@ pub fn coset_lde_batch_ext3_into( } let m = columns.len(); assert_eq!(outputs.len(), m, "outputs must match columns count"); + // Empty domain must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(()); + } assert!(n.is_power_of_two(), "n must be a power of two"); assert_eq!(weights.len(), n, "weights length must match n"); assert!( @@ -1942,16 +1346,14 @@ pub fn coset_lde_batch_ext3_into( for o in outputs.iter() { assert_eq!(o.len(), 3 * lde_size, "each output must be 3*lde_size u64s"); } - if n == 0 { - return Ok(()); - } + assert_u32_domain(lde_size, "coset_lde_batch_ext3_into lde_size"); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; // 3 base slabs per ext3 column; slab index `c*3 + k` holds component `k`. let mb = 3 * m; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); @@ -1959,28 +1361,7 @@ pub fn coset_lde_batch_ext3_into( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - // Pack: for each ext3 column, write 3 base slabs into pinned. The slab - // for column c, component k lives at `pinned[(c*3 + k)*n .. (c*3+k)*n + n]`. - // We de-interleave from the interleaved `[a, b, c, a, b, c, ...]` input. - use rayon::prelude::*; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - // SAFETY: disjoint regions per c; staging lock held. - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(columns, pinned, n); // Allocate + zero-pad device buffer holding 3M slabs of `lde_size`. let mut buf = stream.alloc_zeros::(mb * lde_size)?; @@ -2001,23 +1382,15 @@ pub fn coset_lde_batch_ext3_into( // === Butterflies: identical to the base-field batched path, but with // grid.y = 3M instead of M. === - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -2027,40 +1400,24 @@ pub fn coset_lde_batch_ext3_into( col_stride_u64, mb_u32, )?; - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + mb_u32, + )?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -2076,32 +1433,7 @@ pub fn coset_lde_batch_ext3_into( // Unpack: for each output column, re-interleave 3 slabs back into the // ext3-per-element layout. - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); drop(staging); Ok(()) } @@ -2118,7 +1450,7 @@ fn run_batched_ntt_body( col_stride: u64, m: u32, ) -> Result<()> { - let be = backend(); + let be = backend()?; let fused = core::cmp::min(log_n, 8); if fused >= 8 { let grid_x = (n / 256) as u32; diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs index d5d0c9fbc..d1b4e1210 100644 --- a/crypto/math-cuda/src/lib.rs +++ b/crypto/math-cuda/src/lib.rs @@ -38,7 +38,7 @@ pub fn gl_neg_u64(a: &[u64]) -> Result> { if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let a_dev = stream.clone_htod(a)?; @@ -69,7 +69,7 @@ pub fn ext3_mul_u64(a: &[u64], b: &[u64]) -> Result> { if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let a_dev = stream.clone_htod(a)?; let b_dev = stream.clone_htod(b)?; @@ -90,6 +90,35 @@ pub fn ext3_mul_u64(a: &[u64], b: &[u64]) -> Result> { Ok(out) } +/// Element-wise ext3 subtract. Test helper for `ext3::sub` in `ext3.cuh`. +pub fn ext3_sub_u64(a: &[u64], b: &[u64]) -> Result> { + assert_eq!(a.len(), b.len()); + assert_eq!(a.len() % 3, 0); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + let be = backend()?; + let stream = be.next_stream(); + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(3 * n)?; + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.ext3_sub) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + /// Element-wise ext3 add. pub fn ext3_add_u64(a: &[u64], b: &[u64]) -> Result> { assert_eq!(a.len(), b.len()); @@ -98,7 +127,7 @@ pub fn ext3_add_u64(a: &[u64], b: &[u64]) -> Result> { if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let a_dev = stream.clone_htod(a)?; let b_dev = stream.clone_htod(b)?; @@ -128,7 +157,7 @@ where if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let a_dev = stream.clone_htod(a)?; diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index 0f80206a5..6faf12b51 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -3,7 +3,7 @@ //! Matches `FieldElementVectorBackend::hash_data` in //! `crypto/crypto/src/merkle_tree/backends/field_element_vector.rs`, combined //! with the `reverse_index` row read pattern used in -//! `commit_columns_bit_reversed` at `crypto/stark/src/prover.rs:368`. +//! `commit_columns_bit_reversed` at `crypto/stark/src/prover.rs`. //! //! Caller supplies base-field column slabs already laid out as //! `[col * col_stride + row]` (the same layout `coset_lde_batch_base_into` @@ -11,14 +11,16 @@ //! reads each column's canonical u64 at that row, byte-swaps it into a //! Keccak lane, absorbs lane-by-lane, and squeezes 32 bytes per leaf. //! -//! For ext3 columns the layout is `[col*3*col_stride + k*col_stride + row]` -//! — three base slabs per ext3 column — and the kernel reads three u64s per -//! column in component order 0,1,2 to match `FieldElement::::write_bytes_be`. +//! For ext3 columns the layout is `[col*3*col_stride + k*col_stride + row]`, +//! three base-field components per ext3 column, indexed by `k ∈ {0,1,2}`, +//! and the kernel reads three u64s per column in component order 0,1,2 +//! to match `FieldElement::::write_bytes_be`. -use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; +use cudarc::driver::{CudaSlice, CudaStream, CudaViewMut, LaunchConfig, PushKernelArg}; use crate::Result; -use crate::device::backend; +use crate::device::{Backend, backend}; +use crate::lde::pack_ext3_to_pinned_slabs; /// Run GPU Keccak-256 leaf hashing on a base-field column buffer. /// @@ -32,10 +34,17 @@ pub fn keccak_leaves_base( num_rows: usize, ) -> Result> { assert!(num_rows.is_power_of_two()); - assert!(columns.len() >= num_cols * col_stride); - let be = backend(); + assert!( + col_stride >= num_rows, + "col_stride must be >= num_rows to keep per-column reads in-bounds" + ); + let total = num_cols + .checked_mul(col_stride) + .expect("num_cols * col_stride overflows usize"); + assert!(columns.len() >= total); + let be = backend()?; let stream = be.next_stream(); - let cols_dev = stream.clone_htod(&columns[..num_cols * col_stride])?; + let cols_dev = stream.clone_htod(&columns[..total])?; let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; launch_keccak_base( stream.as_ref(), @@ -43,14 +52,13 @@ pub fn keccak_leaves_base( col_stride as u64, num_cols as u64, num_rows as u64, - &mut out_dev, + &mut out_dev.as_view_mut(), )?; let out = stream.clone_dtoh(&out_dev)?; - stream.synchronize()?; Ok(out) } -/// Ext3 variant — columns interleaved as three base slabs per ext3 column. +/// Ext3 variant. Columns interleaved as three base slabs per ext3 column. /// `columns.len() >= num_cols * 3 * col_stride`. pub fn keccak_leaves_ext3( columns: &[u64], @@ -59,10 +67,18 @@ pub fn keccak_leaves_ext3( num_rows: usize, ) -> Result> { assert!(num_rows.is_power_of_two()); - assert!(columns.len() >= num_cols * 3 * col_stride); - let be = backend(); + assert!( + col_stride >= num_rows, + "col_stride must be >= num_rows to keep per-column reads in-bounds" + ); + let total = num_cols + .checked_mul(3) + .and_then(|v| v.checked_mul(col_stride)) + .expect("num_cols * 3 * col_stride overflows usize"); + assert!(columns.len() >= total); + let be = backend()?; let stream = be.next_stream(); - let cols_dev = stream.clone_htod(&columns[..num_cols * 3 * col_stride])?; + let cols_dev = stream.clone_htod(&columns[..total])?; let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; launch_keccak_ext3( stream.as_ref(), @@ -70,20 +86,23 @@ pub fn keccak_leaves_ext3( col_stride as u64, num_cols as u64, num_rows as u64, - &mut out_dev, + &mut out_dev.as_view_mut(), )?; let out = stream.clone_dtoh(&out_dev)?; - stream.synchronize()?; Ok(out) } /// Block size for Keccak kernels. Per-thread register footprint is ~60 regs -/// (25-lane state + auxiliaries); the default 256 threads/block pushes the +/// (25-lane state + auxiliaries). The default 256 threads/block pushes the /// block register file past the hardware limit on sm_120 (Blackwell). 128 /// keeps us inside the budget with some head-room. const KECCAK_BLOCK_DIM: u32 = 128; -fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { +pub(crate) fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { + debug_assert!( + num_rows <= u32::MAX as u64, + "keccak_launch_cfg: num_rows ({num_rows}) exceeds u32 grid range", + ); let grid = (num_rows as u32).div_ceil(KECCAK_BLOCK_DIM); LaunchConfig { grid_dim: (grid, 1, 1), @@ -92,15 +111,47 @@ fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { } } +/// Walk the inner Merkle tree on device. `nodes_dev` already has the +/// `leaves_len` hashed leaves written into the tail; this loops +/// `log2(leaves_len)` times invoking `keccak_merkle_level` to fill in the +/// inner nodes from the bottom up. Mirrors the CPU `build(nodes, leaves_len)` +/// scan in `crypto/crypto/src/merkle_tree/merkle.rs`. +pub(crate) fn build_inner_tree_levels( + stream: &CudaStream, + be: &Backend, + nodes_dev: &mut CudaSlice, + leaves_len: usize, +) -> Result<()> { + let mut level_begin: u64 = (leaves_len - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let cfg = keccak_launch_cfg(n_pairs); + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut *nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + Ok(()) +} + pub(crate) fn launch_keccak_base( stream: &CudaStream, cols_dev: &CudaSlice, col_stride: u64, num_cols: u64, num_rows: u64, - out_dev: &mut CudaSlice, + out_dev: &mut CudaViewMut<'_, u8>, ) -> Result<()> { - let be = backend(); + // The kernel computes `__brevll(tid) >> (64 - log_num_rows)`, which is UB + // for `log_num_rows == 0` (single-row trees are degenerate anyway). + debug_assert!(num_rows >= 2, "keccak leaf kernel: num_rows must be >= 2"); + let be = backend()?; let log_num_rows = num_rows.trailing_zeros() as u64; let cfg = keccak_launch_cfg(num_rows); unsafe { @@ -128,7 +179,7 @@ pub(crate) fn launch_keccak_base( /// the resulting `nodes` Vec plugs straight into `MerkleTree { root, nodes }` /// for downstream proof generation. /// -/// `leaves_len` must be a power of two and ≥ 2. +/// `leaves_len` must be a power of two and >= 2. pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { assert!(hashed_leaves.len().is_multiple_of(32)); let leaves_len = hashed_leaves.len() / 32; @@ -139,10 +190,10 @@ pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { ); let total_nodes = 2 * leaves_len - 1; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); - // Allocate the full node buffer without zero-fill — we overwrite the + // Allocate the full node buffer without zero-fill. We overwrite the // leaf half via H2D immediately, and every inner node is written by the // pair-hash kernel below. // SAFETY: every byte is written before it is read: leaves are filled by @@ -157,33 +208,9 @@ pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { stream.memcpy_htod(hashed_leaves, &mut slice)?; } - // Build level by level. The CPU `build(nodes, leaves_len)` starts with - // level_begin_index = leaves_len - 1 - // level_end_index = 2 * level_begin_index - // and each iteration computes: - // new_level_begin_index = level_begin_index / 2 - // new_level_length = level_begin_index - new_level_begin_index - // The parents occupy [new_level_begin_index, level_begin_index); the - // children occupy [level_begin_index, level_end_index + 1). - let mut level_begin: u64 = (leaves_len - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - - let cfg = keccak_launch_cfg(n_pairs); - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } + build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, leaves_len)?; let out = stream.clone_dtoh(&nodes_dev)?; - stream.synchronize()?; Ok(out) } @@ -210,7 +237,7 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res let num_leaves = lde_size / 2; let tight_total_nodes = 2 * num_leaves - 1; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); @@ -220,36 +247,7 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - use rayon::prelude::*; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - parts_interleaved - .par_iter() - .enumerate() - .for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(parts_interleaved, pinned, lde_size); // H2D the de-interleaved parts. let mut buf = stream.alloc_zeros::(mb * lde_size)?; @@ -265,12 +263,7 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res let num_parts_u64 = m as u64; let num_rows_u64 = lde_size as u64; let log_num_rows = lde_size.trailing_zeros() as u64; - let grid = (num_leaves as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; + let cfg = keccak_launch_cfg(num_leaves as u64); unsafe { stream .launch_builder(&be.keccak_comp_poly_leaves_ext3) @@ -284,38 +277,15 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res } } - // Inner tree. - { - let mut level_begin: u64 = (num_leaves - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } - } + build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; let out = stream.clone_dtoh(&nodes_dev)?; - stream.synchronize()?; drop(staging); Ok(out) } /// Build a FRI-layer Merkle tree on device from an interleaved ext3 eval -/// vector. Each leaf hashes two consecutive ext3 values; `num_leaves = +/// vector. Each leaf hashes two consecutive ext3 values. `num_leaves = /// evals.len() / 6` (since each ext3 is 3 u64s). /// /// Returns the `(2*num_leaves - 1) * 32`-byte node buffer in standard layout. @@ -326,13 +296,10 @@ pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { ); let num_evals = evals.len() / 3; let num_leaves = num_evals / 2; - assert!(num_leaves.is_power_of_two() && num_leaves >= 1); + assert!(num_leaves.is_power_of_two() && num_leaves >= 2); let tight_total_nodes = 2 * num_leaves - 1; - if tight_total_nodes == 0 { - return Ok(Vec::new()); - } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let evals_dev = stream.clone_htod(evals)?; @@ -344,12 +311,7 @@ pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { let mut leaves_view = nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); let num_leaves_u64 = num_leaves as u64; - let grid = (num_leaves as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; + let cfg = keccak_launch_cfg(num_leaves as u64); unsafe { stream .launch_builder(&be.keccak_fri_leaves_ext3) @@ -360,32 +322,9 @@ pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { } } - // Inner tree levels — identical to the R2 version. - { - let mut level_begin: u64 = (num_leaves - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } - } + build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; let out = stream.clone_dtoh(&nodes_dev)?; - stream.synchronize()?; Ok(out) } @@ -395,9 +334,12 @@ pub(crate) fn launch_keccak_ext3( col_stride: u64, num_cols: u64, num_rows: u64, - out_dev: &mut CudaSlice, + out_dev: &mut CudaViewMut<'_, u8>, ) -> Result<()> { - let be = backend(); + // The kernel computes `__brevll(tid) >> (64 - log_num_rows)`, which is UB + // for `log_num_rows == 0` (single-row trees are degenerate anyway). + debug_assert!(num_rows >= 2, "keccak leaf kernel: num_rows must be >= 2"); + let be = backend()?; let log_num_rows = num_rows.trailing_zeros() as u64; let cfg = keccak_launch_cfg(num_rows); unsafe { diff --git a/crypto/math-cuda/src/ntt.rs b/crypto/math-cuda/src/ntt.rs index 31333c3ae..f5accc3b1 100644 --- a/crypto/math-cuda/src/ntt.rs +++ b/crypto/math-cuda/src/ntt.rs @@ -14,10 +14,13 @@ use math::field::traits::{IsFFTField, IsField}; use crate::Result; use crate::device::backend; -/// Host-side twiddle table: `[ω^0, ω^1, …, ω^{n/2-1}]` where ω is the +/// Host-side twiddle table: `[ω^0, ω^1, ..., ω^{n/2-1}]` where ω is the /// primitive n-th root of unity. Exposed for `device::Backend::cached_twiddles` /// and for direct use in tests / benches. pub fn twiddles_forward(log_n: u64) -> Vec { + // Smallest meaningful NTT is size 2 (log_n = 1); size-1 has nothing to + // twiddle. The shift `1 << (log_n - 1)` underflows for log_n = 0. + assert!(log_n >= 1, "twiddles_forward: log_n must be >= 1"); let omega = *GoldilocksField::get_primitive_root_of_unity(log_n) .expect("primitive root") .value(); @@ -26,6 +29,7 @@ pub fn twiddles_forward(log_n: u64) -> Vec { /// Inverse twiddle table: `[ω^{-i}]` for i in [0, n/2). pub fn twiddles_inverse(log_n: u64) -> Vec { + assert!(log_n >= 1, "twiddles_inverse: log_n must be >= 1"); let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("primitive root"); let omega_inv = FieldElement::::inv(&omega).expect("inverse"); powers_of(*omega_inv.value(), 1usize << (log_n - 1)) @@ -56,13 +60,20 @@ pub fn inverse(evals: &[u64]) -> Result> { fn ntt_inplace(input: &[u64], forward: bool) -> Result> { let n = input.len(); - assert!(n.is_power_of_two(), "ntt length must be a power of two"); + // Empty / size-1 has no work to do. `is_power_of_two()` returns false for + // 0, so this branch must come before the assert to avoid panicking on + // empty input. if n <= 1 { return Ok(input.to_vec()); } + assert!(n.is_power_of_two(), "ntt length must be a power of two"); + assert!( + n <= u32::MAX as usize, + "ntt length {n} exceeds u32 range — kernel grid would silently truncate", + ); let log_n = n.trailing_zeros() as u64; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let mut x_dev = stream.clone_htod(input)?; @@ -117,7 +128,7 @@ pub(crate) fn run_ntt_body( n: u64, log_n: u64, ) -> Result<()> { - let be = backend(); + let be = backend()?; // Levels 0..min(log_n, 8): one shmem-fused launch. Loads are fully // coalesced (base_step=0 → `row = tid`) and 8 butterfly rounds stay on // chip. This is the big DRAM-bandwidth win. @@ -183,7 +194,7 @@ pub fn pointwise_mul(x: &[u64], w: &[u64]) -> Result> { if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let mut x_dev = stream.clone_htod(x)?; diff --git a/crypto/math-cuda/tests/bench_quick.rs b/crypto/math-cuda/tests/bench_quick.rs index 41e9444c3..9e3883085 100644 --- a/crypto/math-cuda/tests/bench_quick.rs +++ b/crypto/math-cuda/tests/bench_quick.rs @@ -99,8 +99,9 @@ fn bench_lde_2_to_16_blowup_4() { #[test] #[ignore = "informal perf probe; run with --ignored"] fn bench_lde_multi_column_parallel() { - // Simulates the prover's Phase A: many columns processed via rayon. - // log_n = 16 keeps memory footprint manageable while still stressing streams. + // Simulates a multi-column workload processed via rayon: many columns + // dispatched concurrently to stress the stream pool. log_n = 16 keeps + // memory footprint manageable. let log_n = 16u32; let blowup = 4usize; let n = 1usize << log_n; @@ -156,8 +157,8 @@ fn bench_lde_multi_column_parallel() { #[test] #[ignore = "informal perf probe; run with --ignored"] fn bench_lde_batched_prover_scale() { - // Realistic large-table shape from the 1M-fib prover: ~1M rows, blowup 4, - // a few dozen columns. This is what actually runs in expand_columns_to_lde. + // Realistic large-table shape: ~1M rows, blowup 4, a few dozen columns. + // Exercises batched LDE at prover-scale sizes. let log_n = 20u32; // 1M rows let blowup = 4usize; let n = 1usize << log_n; diff --git a/crypto/math-cuda/tests/ext3_edge.rs b/crypto/math-cuda/tests/ext3_edge.rs new file mode 100644 index 000000000..f298fe884 --- /dev/null +++ b/crypto/math-cuda/tests/ext3_edge.rs @@ -0,0 +1,176 @@ +//! Adversarial edge-case parity tests for `ext3::mul` on the GPU. +//! +//! The CUDA `dot3` in kernels/ext3.cuh manually tracks overflow when +//! summing three u128 products in split u64 hi/lo registers. The CPU +//! reference (`crypto/math/src/field/goldilocks.rs::dot_product_3` via +//! `Degree3GoldilocksExtensionField::mul`) uses native u128 and so reaches +//! the same answer via a totally different code path. These tests pick +//! inputs that maximally stress the overflow-count tracking, the +//! non-canonical input handling, and the identity/zero cases that random +//! tests are unlikely to cover. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +const P: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks prime +const EPSILON: u64 = 0xFFFF_FFFF; // 2^32 - 1 + +fn canon(x: u64) -> u64 { + GoldilocksField::canonical(&x) +} + +fn canon3(t: [u64; 3]) -> [u64; 3] { + [canon(t[0]), canon(t[1]), canon(t[2])] +} + +/// Run `pairs` of (a, b) through the GPU `ext3_mul_u64` and compare +/// canonical output to the CPU `Degree3GoldilocksExtensionField::mul`. +/// `label_fn(i)` produces a per-case label printed on failure. +fn assert_ext3_mul_pairs(pairs: &[([u64; 3], [u64; 3])], label_fn: impl Fn(usize) -> String) { + let mut a_raw = Vec::with_capacity(pairs.len() * 3); + let mut b_raw = Vec::with_capacity(pairs.len() * 3); + for (a, b) in pairs { + a_raw.extend_from_slice(a); + b_raw.extend_from_slice(b); + } + let gpu = math_cuda::ext3_mul_u64(&a_raw, &b_raw).expect("GPU ext3 mul launch"); + assert_eq!(gpu.len(), 3 * pairs.len()); + for (i, (a, b)) in pairs.iter().enumerate() { + // CPU reference. Build via from_raw so non-canonical inputs (like + // u64::MAX or p, p+1) are passed in untouched — matching what the + // GPU sees. + let ae = [Fp::from_raw(a[0]), Fp::from_raw(a[1]), Fp::from_raw(a[2])]; + let be = [Fp::from_raw(b[0]), Fp::from_raw(b[1]), Fp::from_raw(b[2])]; + let cpu = Degree3GoldilocksExtensionField::mul(&ae, &be); + let cpu_fp3 = Fp3::new(cpu); + let g = canon3([gpu[3 * i], gpu[3 * i + 1], gpu[3 * i + 2]]); + let c = canon3([ + *cpu_fp3.value()[0].value(), + *cpu_fp3.value()[1].value(), + *cpu_fp3.value()[2].value(), + ]); + assert_eq!( + g, + c, + "ext3 mul mismatch [{}]: a={:?} b={:?} gpu={:?} cpu={:?}", + label_fn(i), + a, + b, + g, + c, + ); + } +} + +#[test] +fn ext3_mul_max_canonical_inputs() { + // (p-1, p-1, p-1) * (p-1, p-1, p-1) — every base limb is (p-1), every + // dot3 product a_i*b_i = (p-1)^2 ~ 2^128, so summing three of them + // forces the overflow path twice on each component. + let m1 = [P - 1, P - 1, P - 1]; + let pairs = vec![(m1, m1)]; + assert_ext3_mul_pairs(&pairs, |_| "(p-1)^3 squared".into()); +} + +#[test] +fn ext3_mul_zero_cases() { + // (0,0,0) * (p-1,p-1,p-1) must be zero; covers the "all-zero a" path + // where every dot3 product is zero and no overflow occurs. + let z = [0u64, 0, 0]; + let m = [P - 1, P - 1, P - 1]; + let pairs = vec![(z, m), (m, z), (z, z)]; + assert_ext3_mul_pairs(&pairs, |i| format!("zero case {i}")); +} + +#[test] +fn ext3_mul_identity() { + // (1, 0, 0) * (a, b, c) == (a, b, c). One-component is multiplicative + // identity in Fp[w]/(w^3 - 2). Use varied b to also exercise small + // non-zero dot3 products. + let id = [1u64, 0, 0]; + let cases: Vec<([u64; 3], [u64; 3])> = vec![ + (id, [0, 0, 0]), + (id, [1, 0, 0]), + (id, [0, 1, 0]), + (id, [0, 0, 1]), + (id, [P - 1, 1, 2]), + (id, [123, 456, 789]), + (id, [P - 1, P - 1, P - 1]), + // Reverse order: (a, b, c) * (1, 0, 0). + ([0, 0, 0], id), + ([1, 0, 0], id), + ([0, 1, 0], id), + ([0, 0, 1], id), + ([P - 1, 1, 2], id), + ([123, 456, 789], id), + ]; + assert_ext3_mul_pairs(&cases, |i| format!("identity case {i}")); +} + +#[test] +fn ext3_mul_non_canonical_zero_p() { + // (p, p, p) is a non-canonical representation of (0, 0, 0). The CPU + // canonicalises before the dot3 in some code paths, the GPU does not. + // Either way the product must canonicalise to zero. + let p = [P, P, P]; + let some = [123u64, 456, 789]; + let pairs = vec![(p, p), (p, some), (some, p)]; + assert_ext3_mul_pairs(&pairs, |i| format!("non-canonical-p case {i}")); +} + +#[test] +fn ext3_mul_u64_max_all_overflow_paths() { + // (u64::MAX, u64::MAX, u64::MAX) for both operands. u64::MAX = p + (2^32 - 2), + // i.e., a non-canonical representation of (2^32 - 2) mod p. Every dot3 + // product is ~2^128 - small, so summing three of them is the hardest + // possible exercise of `over1`, `over2`, and the EPSILON^2 correction + // path in `dot3`. + let m = [u64::MAX, u64::MAX, u64::MAX]; + assert_ext3_mul_pairs(&[(m, m)], |_| "u64::MAX^3 squared".into()); +} + +#[test] +fn ext3_mul_base_edge_pairs_embedded() { + // Embed every base-field edge value as the `a` component of an ext3 + // element (so b = c = 0) and run all NxN pairs through GPU mul. This + // reduces to base-field multiplication on the a-component but + // exercises all the dot3 zero/non-zero combinations. + let edges: Vec = vec![0, 1, P - 1, P, P + 1, u64::MAX, EPSILON]; + let mut pairs = Vec::with_capacity(edges.len() * edges.len()); + for x in &edges { + for y in &edges { + pairs.push(([*x, 0, 0], [*y, 0, 0])); + } + } + assert_ext3_mul_pairs(&pairs, |i| { + let xi = i / edges.len(); + let yi = i % edges.len(); + format!("base-edge a={:#x} b={:#x}", edges[xi], edges[yi]) + }); +} + +#[test] +fn ext3_mul_base_edges_in_b_and_c_slots() { + // Same edge values, but placed in the b and c slots so the cross terms + // (which involve the `b1_2 = 2*y.b`, `b2_2 = 2*y.c` doubling) are also + // exercised at edge inputs. Non-canonical doubling of P-1 etc. is a + // path that random tests rarely hit. + let edges: Vec = vec![0, 1, P - 1, P, P + 1, u64::MAX, EPSILON]; + let mut pairs = Vec::with_capacity(edges.len() * edges.len()); + for x in &edges { + for y in &edges { + // Put edge values in b/c slots, a non-trivial. + pairs.push(([1, *x, *y], [1, *x, *y])); + } + } + assert_ext3_mul_pairs(&pairs, |i| { + let xi = i / edges.len(); + let yi = i % edges.len(); + format!("ext3-bc-edge b={:#x} c={:#x}", edges[xi], edges[yi]) + }); +} diff --git a/crypto/math-cuda/tests/ext3_sub.rs b/crypto/math-cuda/tests/ext3_sub.rs new file mode 100644 index 000000000..1e94682a0 --- /dev/null +++ b/crypto/math-cuda/tests/ext3_sub.rs @@ -0,0 +1,109 @@ +//! Parity test for `ext3::sub` in kernels/ext3.cuh. This device +//! function is part of the public ext3 header but is not invoked by any +//! kernel in the PR — every other ext3 caller uses `mul`, `add`, or +//! `mul_base`. The PR's review test infrastructure adds an +//! `ext3_sub_kernel` so we can call it directly here for parity vs the +//! CPU `Degree3GoldilocksExtensionField::sub`. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +const N: usize = 10_000; +const P: u64 = 0xFFFF_FFFF_0000_0001; + +fn random_fp3s(seed: u64, count: usize) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..count) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() +} + +fn to_u64s(col: &[Fp3]) -> Vec { + let mut v = Vec::with_capacity(col.len() * 3); + for e in col { + v.push(*e.value()[0].value()); + v.push(*e.value()[1].value()); + v.push(*e.value()[2].value()); + } + v +} + +fn canon(x: u64) -> u64 { + GoldilocksField::canonical(&x) +} + +#[test] +fn ext3_sub_matches_cpu_random() { + let a = random_fp3s(101, N); + let b = random_fp3s(202, N); + let a_raw = to_u64s(&a); + let b_raw = to_u64s(&b); + let gpu = math_cuda::ext3_sub_u64(&a_raw, &b_raw).expect("GPU ext3 sub launch"); + assert_eq!(gpu.len(), 3 * N); + for i in 0..N { + let cpu = Degree3GoldilocksExtensionField::sub(a[i].value(), b[i].value()); + let cpu_fp3 = Fp3::new(cpu); + let g = [ + canon(gpu[3 * i]), + canon(gpu[3 * i + 1]), + canon(gpu[3 * i + 2]), + ]; + let c = [ + canon(*cpu_fp3.value()[0].value()), + canon(*cpu_fp3.value()[1].value()), + canon(*cpu_fp3.value()[2].value()), + ]; + assert_eq!(g, c, "ext3 sub mismatch at {i}"); + } +} + +#[test] +fn ext3_sub_edge_cases() { + // Underflow cases: a < b on each component, plus non-canonical p + // representations. + let cases: Vec<([u64; 3], [u64; 3])> = vec![ + ([0, 0, 0], [P - 1, P - 1, P - 1]), + ([1, 2, 3], [P - 1, P - 1, P - 1]), + ([P - 1, P - 1, P - 1], [0, 0, 0]), + ([P, P, P], [P, P, P]), // (0,0,0) - (0,0,0) + ([u64::MAX, u64::MAX, u64::MAX], [0, 0, 0]), + ([0, 0, 0], [u64::MAX, u64::MAX, u64::MAX]), + ]; + let mut a_raw = Vec::new(); + let mut b_raw = Vec::new(); + for (a, b) in &cases { + a_raw.extend_from_slice(a); + b_raw.extend_from_slice(b); + } + let gpu = math_cuda::ext3_sub_u64(&a_raw, &b_raw).expect("GPU ext3 sub launch"); + for (i, (a, b)) in cases.iter().enumerate() { + let ae = [Fp::from_raw(a[0]), Fp::from_raw(a[1]), Fp::from_raw(a[2])]; + let be = [Fp::from_raw(b[0]), Fp::from_raw(b[1]), Fp::from_raw(b[2])]; + let cpu = Degree3GoldilocksExtensionField::sub(&ae, &be); + let cpu_fp3 = Fp3::new(cpu); + let g = [ + canon(gpu[3 * i]), + canon(gpu[3 * i + 1]), + canon(gpu[3 * i + 2]), + ]; + let c = [ + canon(*cpu_fp3.value()[0].value()), + canon(*cpu_fp3.value()[1].value()), + canon(*cpu_fp3.value()[2].value()), + ]; + assert_eq!(g, c, "ext3 sub edge mismatch at {i}: a={a:?} b={b:?}"); + } +} diff --git a/crypto/math-cuda/tests/keccak_leaves.rs b/crypto/math-cuda/tests/keccak_leaves.rs index 1e451b386..d614e233d 100644 --- a/crypto/math-cuda/tests/keccak_leaves.rs +++ b/crypto/math-cuda/tests/keccak_leaves.rs @@ -1,64 +1,22 @@ -//! Parity: GPU Keccak-256 leaf hashes must match CPU -//! `FieldElementVectorBackend::::hash_data` applied to -//! bit-reversed rows (same pattern as `commit_columns_bit_reversed` in the -//! stark prover). +//! Parity: GPU Keccak-256 leaf hashes must match the CPU prover's leaf +//! hashing helpers. `stark::prover::keccak_leaves_bit_reversed` for +//! per-row commits, `keccak_leaves_row_pair_bit_reversed` for the R2 +//! composition commit, and `FriLayerMerkleTreeBackend::hash_data` for the +//! FRI commit. These are the same helpers the prover itself calls so any +//! change to the CPU leaf-hash contract surfaces here. +use crypto::merkle_tree::traits::IsMerkleTreeBackend; use math::field::element::FieldElement; use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; use math::field::goldilocks::GoldilocksField; -use math::traits::ByteConversion; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use sha3::{Digest, Keccak256}; +use stark::config::FriLayerMerkleTreeBackend; +use stark::prover::{keccak_leaves_bit_reversed, keccak_leaves_row_pair_bit_reversed}; type Fp = FieldElement; type Fp3 = FieldElement; -fn reverse_index(i: u64, n: u64) -> u64 { - let log_n = n.trailing_zeros(); - i.reverse_bits() >> (64 - log_n) -} - -fn cpu_leaves_base(columns: &[Vec]) -> Vec<[u8; 32]> { - let num_rows = columns[0].len(); - let num_cols = columns.len(); - let byte_len = 8; - (0..num_rows) - .map(|row_idx| { - let br = reverse_index(row_idx as u64, num_rows as u64) as usize; - let mut buf = vec![0u8; num_cols * byte_len]; - for c in 0..num_cols { - columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); - } - let mut h = Keccak256::new(); - h.update(&buf); - let mut out = [0u8; 32]; - out.copy_from_slice(&h.finalize()); - out - }) - .collect() -} - -fn cpu_leaves_ext3(columns: &[Vec]) -> Vec<[u8; 32]> { - let num_rows = columns[0].len(); - let num_cols = columns.len(); - let byte_len = 24; - (0..num_rows) - .map(|row_idx| { - let br = reverse_index(row_idx as u64, num_rows as u64) as usize; - let mut buf = vec![0u8; num_cols * byte_len]; - for c in 0..num_cols { - columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); - } - let mut h = Keccak256::new(); - h.update(&buf); - let mut out = [0u8; 32]; - out.copy_from_slice(&h.finalize()); - out - }) - .collect() -} - #[test] fn keccak_leaves_base_matches_cpu() { for log_n in [4u32, 6, 8, 10, 12] { @@ -69,7 +27,7 @@ fn keccak_leaves_base_matches_cpu() { .map(|_| (0..n).map(|_| Fp::from_raw(rng.r#gen::())).collect()) .collect(); - let cpu = cpu_leaves_base(&columns); + let cpu = keccak_leaves_bit_reversed(&columns); // Flatten columns into a contiguous base slab layout matching // `coset_lde_batch_base_into`'s pinned staging format: @@ -113,7 +71,7 @@ fn keccak_leaves_ext3_matches_cpu() { }) .collect(); - let cpu = cpu_leaves_ext3(&columns); + let cpu = keccak_leaves_bit_reversed(&columns); // GPU expects 3 base slabs per ext3 column in the order // [col*3+0 (comp a), col*3+1 (comp b), col*3+2 (comp c)], each a @@ -138,3 +96,103 @@ fn keccak_leaves_ext3_matches_cpu() { } } } + +#[test] +fn keccak_comp_poly_leaves_matches_cpu() { + // Built tree's leaves live at byte offset `(num_leaves - 1) * 32` and + // span `num_leaves * 32` bytes. Compare those to the CPU reference. + for log_lde in [2u32, 4, 6, 8, 10, 12] { + for num_parts in [1usize, 2, 5, 17] { + let lde_size = 1usize << log_lde; + let mut rng = ChaCha8Rng::seed_from_u64(300 + log_lde as u64 + num_parts as u64); + let parts: Vec> = (0..num_parts) + .map(|_| { + (0..lde_size) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() + }) + .collect(); + let cpu = keccak_leaves_row_pair_bit_reversed(&parts); + + // Each part is passed as `[a0,a1,a2, b0,b1,b2, ...]` of length `3 * lde_size`. + let parts_interleaved: Vec> = parts + .iter() + .map(|p| { + let mut v = vec![0u64; 3 * lde_size]; + for (i, e) in p.iter().enumerate() { + v[i * 3] = *e.value()[0].value(); + v[i * 3 + 1] = *e.value()[1].value(); + v[i * 3 + 2] = *e.value()[2].value(); + } + v + }) + .collect(); + let parts_slices: Vec<&[u64]> = + parts_interleaved.iter().map(|v| v.as_slice()).collect(); + + let nodes = + math_cuda::merkle::build_comp_poly_tree_from_evals_ext3(&parts_slices).unwrap(); + let num_leaves = lde_size / 2; + let leaves_offset = (num_leaves - 1) * 32; + for i in 0..num_leaves { + assert_eq!( + &nodes[leaves_offset + i * 32..leaves_offset + (i + 1) * 32], + &cpu[i][..], + "comp-poly leaf mismatch at i={i} (log_lde={log_lde}, parts={num_parts})" + ); + } + } + } +} + +#[test] +fn keccak_fri_leaves_matches_cpu() { + for log_lde in [2u32, 4, 6, 8, 10, 12] { + let lde_size = 1usize << log_lde; + let mut rng = ChaCha8Rng::seed_from_u64(400 + log_lde as u64); + let evals: Vec = (0..lde_size) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect(); + + // CPU reference: consecutive ext3 pairs hashed via the prover's + // FRI-layer Merkle backend. + let cpu: Vec<[u8; 32]> = evals + .chunks_exact(2) + .map(|c| { + FriLayerMerkleTreeBackend::::hash_data(&[ + c[0], c[1], + ]) + }) + .collect(); + + let mut evals_interleaved = vec![0u64; 3 * lde_size]; + for (i, e) in evals.iter().enumerate() { + evals_interleaved[i * 3] = *e.value()[0].value(); + evals_interleaved[i * 3 + 1] = *e.value()[1].value(); + evals_interleaved[i * 3 + 2] = *e.value()[2].value(); + } + let nodes = + math_cuda::merkle::build_fri_layer_tree_from_evals_ext3(&evals_interleaved).unwrap(); + let num_leaves = lde_size / 2; + let leaves_offset = (num_leaves - 1) * 32; + for i in 0..num_leaves { + assert_eq!( + &nodes[leaves_offset + i * 32..leaves_offset + (i + 1) * 32], + &cpu[i][..], + "fri leaf mismatch at i={i} (log_lde={log_lde})" + ); + } + } +} diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs index facd2d861..110997e6e 100644 --- a/crypto/math-cuda/tests/lde.rs +++ b/crypto/math-cuda/tests/lde.rs @@ -1,4 +1,4 @@ -//! Phase-5 parity: GPU `coset_lde_base` must match the CPU +//! Parity: GPU `coset_lde_base` must match the CPU //! `Polynomial::coset_lde_full_expand` for a sweep of realistic sizes and //! blowup factors. @@ -12,8 +12,8 @@ use rand_chacha::ChaCha8Rng; type Fp = FieldElement; -/// Build the coset weights `[1/N, g/N, g²/N, …, g^{n-1}/N]` — this is the -/// layout `crypto/stark/src/prover.rs:248` uses, with `1/N` pre-folded into the +/// Build the coset weights `[1/N, g/N, g²/N, ..., g^{n-1}/N]` — this is the +/// layout `crypto/stark/src/prover.rs` uses, with `1/N` pre-folded into the /// first coefficient so the iFFT step does not need a separate scaling pass. fn coset_weights(n: usize, coset_offset: u64) -> Vec { let inv_n_fe = FieldElement::::from(n as u64) diff --git a/crypto/math-cuda/tests/lde_batch_into.rs b/crypto/math-cuda/tests/lde_batch_into.rs new file mode 100644 index 000000000..c3d25adbf --- /dev/null +++ b/crypto/math-cuda/tests/lde_batch_into.rs @@ -0,0 +1,87 @@ +//! Direct parity test for `coset_lde_batch_base_into` (lde.rs), the +//! caller-allocated-buffer variant of `coset_lde_batch_base`. The two should +//! produce bit-identical canonical output for the same inputs; the only +//! difference between them is who owns the output Vec. +//! +//! This is otherwise covered indirectly through `coset_lde_batch_base_into_with_leaf_hash` +//! and similar, but the base `_into` variant ships as public API with no +//! direct test in the original PR. + +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *Fp::from(n as u64).inv().unwrap().value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(GoldilocksField::canonical).collect() +} + +fn run_pair(log_n: u64, blowup: usize, m: usize, seed: u64) { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let columns: Vec> = (0..m) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + + let weights = coset_weights(n, 7); + + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + + // Reference: the existing batch-allocates-Vec API. Random-input tests + // already cross-validate this against the CPU single-column LDE in + // `lde_batch.rs`, so any divergence from it pinpoints `_into`. + let ref_out = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + assert_eq!(ref_out.len(), m); + + // Caller-allocated buffers. + let mut owned: Vec> = (0..m).map(|_| vec![0u64; lde_size]).collect(); + { + let mut outs: Vec<&mut [u64]> = owned.iter_mut().map(|v| v.as_mut_slice()).collect(); + math_cuda::lde::coset_lde_batch_base_into(&slices, blowup, &weights, &mut outs) + .expect("into variant"); + } + + for c in 0..m { + assert_eq!( + canon(&owned[c]), + canon(&ref_out[c]), + "_into vs _batch_base diverge at column {c}, log_n={log_n}, blowup={blowup}, m={m}" + ); + } +} + +#[test] +fn into_matches_batch_base_small() { + run_pair(8, 4, 4, 1); + run_pair(10, 4, 1, 2); + run_pair(10, 4, 16, 3); +} + +#[test] +fn into_matches_batch_base_medium() { + run_pair(14, 4, 8, 4); + run_pair(15, 4, 32, 5); +} + +#[test] +fn into_matches_batch_base_uneven_blowup() { + // Non-default blowup factors (still power of two) — confirms the + // _into variant respects blowup_factor identically. + run_pair(8, 2, 4, 6); + run_pair(8, 8, 4, 7); +} diff --git a/crypto/math-cuda/tests/merkle_tree.rs b/crypto/math-cuda/tests/merkle_tree.rs index 34d44c767..76fdeb919 100644 --- a/crypto/math-cuda/tests/merkle_tree.rs +++ b/crypto/math-cuda/tests/merkle_tree.rs @@ -1,44 +1,17 @@ //! Parity: GPU Merkle inner-tree construction must match the CPU //! `crypto/crypto/src/merkle_tree/merkle.rs` `build_from_hashed_leaves` -//! (Keccak-256 pair hash at each level). +//! (Keccak-256 pair hash at each level). Uses the prover's +//! `FieldElementVectorBackend<_, Keccak256, 32>` directly so any change to +//! the CPU tree builder is automatically exercised here. +use crypto::merkle_tree::backends::field_element_vector::FieldElementVectorBackend; +use crypto::merkle_tree::merkle::MerkleTree; +use math::field::goldilocks::GoldilocksField; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use sha3::{Digest, Keccak256}; +use sha3::Keccak256; -fn cpu_hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { - let mut h = Keccak256::new(); - h.update(left); - h.update(right); - let mut out = [0u8; 32]; - out.copy_from_slice(&h.finalize()); - out -} - -/// CPU reference: same algorithm as `build_from_hashed_leaves`. -fn cpu_merkle_nodes(leaves: &[[u8; 32]]) -> Vec<[u8; 32]> { - let leaves_len = leaves.len(); - assert!(leaves_len.is_power_of_two() && leaves_len >= 2); - let total = 2 * leaves_len - 1; - - let mut nodes: Vec<[u8; 32]> = vec![[0u8; 32]; total]; - for (i, leaf) in leaves.iter().enumerate() { - nodes[leaves_len - 1 + i] = *leaf; - } - - let mut level_begin = leaves_len - 1; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - for j in 0..n_pairs { - let left = nodes[level_begin + 2 * j]; - let right = nodes[level_begin + 2 * j + 1]; - nodes[new_begin + j] = cpu_hash_pair(&left, &right); - } - level_begin = new_begin; - } - nodes -} +type CpuTree = MerkleTree>; fn run_parity(log_n: u32, seed: u64) { let leaves_len = 1usize << log_n; @@ -60,11 +33,12 @@ fn run_parity(log_n: u32, seed: u64) { let gpu_nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(&flat).unwrap(); assert_eq!(gpu_nodes_bytes.len(), (2 * leaves_len - 1) * 32); - let cpu_nodes = cpu_merkle_nodes(&leaves); + // CPU reference: the prover's MerkleTree builder over the same backend. + let cpu_tree = CpuTree::build_from_hashed_leaves(leaves).unwrap(); + let cpu_nodes = cpu_tree.nodes(); - for i in 0..cpu_nodes.len() { + for (i, c) in cpu_nodes.iter().enumerate() { let g = &gpu_nodes_bytes[i * 32..(i + 1) * 32]; - let c = &cpu_nodes[i]; assert_eq!( g, c, "node {i} mismatch at log_n={log_n} (cpu={c:?}, gpu={g:?})" @@ -79,13 +53,6 @@ fn merkle_tree_small() { } } -#[test] -fn merkle_tree_medium() { - for log_n in [10u32, 12, 14] { - run_parity(log_n, 500 + log_n as u64); - } -} - #[test] fn merkle_tree_large() { run_parity(18, 9999); diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs index f3689cf94..c02892204 100644 --- a/crypto/math-cuda/tests/ntt.rs +++ b/crypto/math-cuda/tests/ntt.rs @@ -1,4 +1,4 @@ -//! Phase-3 parity: GPU forward NTT must agree with `Polynomial::evaluate_fft` +//! Parity: GPU forward NTT must agree with `Polynomial::evaluate_fft` //! as a field element, across a sweep of sizes from 2^4 to 2^20. //! //! Non-canonical u64s can differ between CPU and GPU while representing the @@ -61,13 +61,11 @@ fn ntt_sizes_medium() { #[test] fn ntt_size_2_to_20() { - // The hot LDE size. One seed is enough; any mismatch screams loudly. assert_ntt_match(20, 0xDEAD); } #[test] fn ntt_trivial_sizes() { - // Power-of-two below the interesting range — should still pass. assert_ntt_match(1, 1); assert_ntt_match(2, 2); assert_ntt_match(3, 3); diff --git a/crypto/math-cuda/tests/ntt_known.rs b/crypto/math-cuda/tests/ntt_known.rs new file mode 100644 index 000000000..f0a7b9f5c --- /dev/null +++ b/crypto/math-cuda/tests/ntt_known.rs @@ -0,0 +1,125 @@ +//! Known-answer NTT test. Random-input tests catch most bugs but can mask +//! systematic errors (sign flips, off-by-one twiddle indices, wrong-direction +//! butterflies) that would cancel under noise. This test picks a polynomial +//! with a known closed-form evaluation at every root of unity and compares +//! the GPU forward NTT to that reference, computed independently from any +//! FFT code path. + +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsFFTField, IsPrimeField}; + +type Fp = FieldElement; + +fn canon(x: u64) -> u64 { + GoldilocksField::canonical(&x) +} + +/// p(x) = 1 + x. NTT at size N is the vector [p(omega^0), p(omega^1), ..., +/// p(omega^(N-1))] for `omega` a primitive N-th root of unity. We compute +/// the reference by direct exponentiation of `omega` (no FFT involved) +/// so a bug in either the CPU or GPU NTT can't hide here. +#[test] +fn ntt_known_polynomial_x_plus_one_size_256() { + let log_n: u64 = 8; + let n: usize = 1 << log_n; + + // GoldilocksField uses bit-reversed coefficient layout for forward NTT: + // input coeffs at index `i` become `p(omega^bitrev(i))`. Confirm by + // matching against the existing forward() API which random-input tests + // already validate against `Polynomial::evaluate_fft`. The known-poly + // value lets us catch systematic errors in that pipeline that random + // inputs miss. + + // Input coefficients: [1, 1, 0, 0, ..., 0]. (Natural order, lowest + // degree first. `Polynomial::new` and `math_cuda::ntt::forward` both + // expect this convention.) + let mut input = vec![0u64; n]; + input[0] = 1; + input[1] = 1; + + let gpu = math_cuda::ntt::forward(&input).expect("gpu ntt"); + assert_eq!(gpu.len(), n); + + // Reference: omega = primitive N-th root of unity in Goldilocks. + // p(omega^i) = 1 + omega^i. + let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("root of unity"); + let one = Fp::from_raw(1); + + let mut expected = Vec::with_capacity(n); + let mut omega_i = one; // omega^0 + for _ in 0..n { + let val = &one + &omega_i; + expected.push(*val.value()); + omega_i = &omega_i * ω + } + + for (i, (&g_raw, &e_raw)) in gpu.iter().zip(expected.iter()).enumerate() { + let g = canon(g_raw); + let e = canon(e_raw); + if g != e { + panic!( + "p(omega^{i}) mismatch: gpu canon {:#018x}, expected canon {:#018x} (omega^{i} computed independently of any FFT)", + g, e, + ); + } + } +} + +/// Same idea, smaller: p(x) = 1 + x at size 2^4 = 16. A failure at this +/// size with passes at larger sizes (or vice-versa) would point at a +/// boundary bug between the recursive base case and the 8-level +/// shared-memory fused step in the GPU NTT. +#[test] +fn ntt_known_polynomial_x_plus_one_size_16() { + let log_n: u64 = 4; + let n: usize = 1 << log_n; + + let mut input = vec![0u64; n]; + input[0] = 1; + input[1] = 1; + + let gpu = math_cuda::ntt::forward(&input).expect("gpu ntt"); + + let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("root"); + let one = Fp::from_raw(1); + let mut omega_i = one; + for (i, &g) in gpu.iter().enumerate() { + let exp = &one + &omega_i; + assert_eq!( + canon(g), + canon(*exp.value()), + "p(omega^{i}) mismatch at size 16" + ); + omega_i = &omega_i * ω + } +} + +/// p(x) = x^k for k = N/2. p(omega^i) = omega^(k*i). With k = N/2, +/// omega^(k*i) = (-1)^i since omega^(N/2) = -1 in any field with a +/// primitive N-th root of unity. So evaluations alternate +1, -1, +1, -1. +/// This is a strong test of twiddle-index direction and sign. +#[test] +fn ntt_known_polynomial_x_half_alternating() { + let log_n: u64 = 8; + let n: usize = 1 << log_n; + let k = n / 2; + + let mut input = vec![0u64; n]; + input[k] = 1; // p(x) = x^(N/2) + + let gpu = math_cuda::ntt::forward(&input).expect("gpu ntt"); + + // Expected: omega^(k*i) for i = 0..N, which is (omega^k)^i = (-1)^i. + // canonical(+1) = 1; canonical(-1) = p - 1. + let p_minus_one = 0xFFFF_FFFF_0000_0001u64 - 1; + for (i, &g) in gpu.iter().enumerate() { + let exp = if i % 2 == 0 { 1u64 } else { p_minus_one }; + assert_eq!( + canon(g), + exp, + "x^(N/2) NTT alternation mismatch at i={i}: got {:#018x}", + canon(g) + ); + } +} diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index e71cda72f..c4458763f 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -10,7 +10,7 @@ use math::fft::errors::FFTError; use log::info; use math::field::traits::{IsField, IsSubFieldOf}; -use math::traits::AsBytes; +use math::traits::{AsBytes, ByteConversion}; use math::{ field::{element::FieldElement, traits::IsFFTField}, polynomial::Polynomial, @@ -409,6 +409,102 @@ where } } +/// Compute Keccak-256 leaf hashes for `commit_columns_bit_reversed`: one +/// leaf per row, where each row is read at `reverse_index(row_idx)` and the +/// columns are concatenated as big-endian bytes before hashing. +/// +/// Returns `Vec` with the same length as `columns[0]`. Exposed +/// (instead of being a closure inside `commit_columns_bit_reversed`) so +/// parity tests in dependent crates can compare against the same code path +/// the prover uses. +pub fn keccak_leaves_bit_reversed(columns: &[Vec>]) -> Vec +where + E: IsField, + FieldElement: AsBytes + Sync + Send + ByteConversion, +{ + if columns.is_empty() || columns[0].is_empty() { + return Vec::new(); + } + + let num_rows = columns[0].len(); + let num_cols = columns.len(); + let byte_len = as ByteConversion>::BYTE_LEN; + + debug_assert!( + num_rows.is_power_of_two(), + "num_rows must be a power of two for reverse_index" + ); + + #[cfg(feature = "parallel")] + let iter = (0..num_rows).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = 0..num_rows; + + iter.map(|row_idx| { + let br_idx = reverse_index(row_idx, num_rows as u64); + let total_bytes = num_cols * byte_len; + let mut buf = vec![0u8; total_bytes]; + for col_idx in 0..num_cols { + columns[col_idx][br_idx] + .write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) + .collect() +} + +/// Compute Keccak-256 leaf hashes for `commit_composition_polynomial`: one +/// leaf per row-pair, where leaf `i` hashes the BE concatenation of +/// `parts[..][br_0] ++ parts[..][br_1]` with +/// `br_k = reverse_index(2*i + k, num_rows)`. +/// +/// Returns `Vec` of length `parts[0].len() / 2`. +pub fn keccak_leaves_row_pair_bit_reversed(parts: &[Vec>]) -> Vec +where + E: IsField, + FieldElement: AsBytes + Sync + Send + ByteConversion, +{ + let num_parts = parts.len(); + if num_parts == 0 { + return Vec::new(); + } + let num_rows = parts[0].len(); + if num_rows == 0 { + return Vec::new(); + } + + let num_leaves = num_rows / 2; + debug_assert!( + num_rows.is_power_of_two(), + "num_rows must be a power of two for reverse_index" + ); + + let byte_len = as ByteConversion>::BYTE_LEN; + + #[cfg(feature = "parallel")] + let iter = (0..num_leaves).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = 0..num_leaves; + + iter.map(|leaf_idx| { + let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); + let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); + let total_bytes = 2 * num_parts * byte_len; + let mut buf = vec![0u8; total_bytes]; + let mut offset = 0; + for part in parts.iter() { + part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + for part in parts.iter() { + part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) + .collect() +} + /// The functionality of a STARK prover providing methods to run the STARK Prove protocol /// https://lambdaclass.github.io/lambdaworks/starks/protocol.html /// The default implementation is complete and is compatible with Stone prover @@ -435,41 +531,10 @@ pub trait IsStarkProver< FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, E: IsField, { - use math::traits::ByteConversion; - if columns.is_empty() || columns[0].is_empty() { return None; } - - let num_rows = columns[0].len(); - let num_cols = columns.len(); - let byte_len = as ByteConversion>::BYTE_LEN; - - debug_assert!( - num_rows.is_power_of_two(), - "num_rows must be a power of two for reverse_index" - ); - - #[cfg(feature = "parallel")] - let iter = (0..num_rows).into_par_iter(); - #[cfg(not(feature = "parallel"))] - let iter = 0..num_rows; - - // One allocation per row (was one per field element): write all columns - // into a single buffer, then hash once. - let hashed_leaves: Vec = iter - .map(|row_idx| { - let br_idx = reverse_index(row_idx, num_rows as u64); - let total_bytes = num_cols * byte_len; - let mut buf = vec![0u8; total_bytes]; - for col_idx in 0..num_cols { - columns[col_idx][br_idx] - .write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) - .collect(); - + let hashed_leaves = keccak_leaves_bit_reversed(columns); let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) @@ -795,8 +860,6 @@ pub trait IsStarkProver< FieldElement: AsBytes + Sync + Send, FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, { - use math::traits::ByteConversion; - let num_parts = lde_composition_poly_parts_evaluations.len(); if num_parts == 0 { return None; @@ -805,41 +868,8 @@ pub trait IsStarkProver< if num_rows == 0 { return None; } - - let num_leaves = num_rows / 2; - debug_assert!( - num_rows.is_power_of_two(), - "num_rows must be a power of two for reverse_index" - ); - - let byte_len = as ByteConversion>::BYTE_LEN; - - // One allocation per leaf (was one per field element): write all parts - // into a single buffer. Each leaf = row_pair[2*i] ++ row_pair[2*i+1] after bit-reverse. - #[cfg(feature = "parallel")] - let iter = (0..num_leaves).into_par_iter(); - #[cfg(not(feature = "parallel"))] - let iter = 0..num_leaves; - - let hashed_leaves: Vec = iter - .map(|leaf_idx| { - let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); - let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); - let total_bytes = 2 * num_parts * byte_len; - let mut buf = vec![0u8; total_bytes]; - let mut offset = 0; - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) - .collect(); - + let hashed_leaves = + keccak_leaves_row_pair_bit_reversed(lde_composition_poly_parts_evaluations); let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) diff --git a/docs/cryptography/lookup.md b/docs/cryptography/lookup.md index 99c2bf803..efde8ecdc 100644 --- a/docs/cryptography/lookup.md +++ b/docs/cryptography/lookup.md @@ -1,111 +1,141 @@ -# Lookup Arguments +# Lookup arguments -Lookup arguments are a cryptographic technique that allows a prover to demonstrate that values in one table appear in another table, without revealing the actual values. They are essential for building efficient virtual machines where different components (CPU, memory, ALU) need to verify consistency with each other. +Lambda VM uses **LogUp** lookup arguments to connect its trace tables. Each table generates and consumes "tokens" on one or more **named buses**; the system is sound when, across every bus, the total sender contribution equals the total receiver contribution. -## Why Lookup Arguments? +Lookups are how the prover proves cross-table relations without duplicating constraints. For example, the CPU table proves it dispatched a bitwise AND instruction by sending a token `(AndByte, x, y, x & y)` on the `AndByte` bus; the BITWISE table proves it has a row matching that token by sending a receiver token on the same bus. If both sides match, the bus balances. If a sender has no matching receiver (or vice versa), the bus does not balance and verification fails. -In a virtual machine proof, different execution tables need to communicate: +The implementation lives in [`crypto/stark/src/lookup.rs`](../../crypto/stark/src/lookup.rs). -- The **CPU table** performs operations and accesses memory -- The **Memory table** stores and retrieves values -- The **ALU table** computes arithmetic operations +## The `BusInteraction` struct -Without lookups, verifying that "the CPU read value X from address Y" would require expensive polynomial constraints. Lookup arguments provide an efficient way to prove these cross-table relationships. +A single lookup contribution is a `BusInteraction`: -## The LogUp Protocol +```rust +pub struct BusInteraction { + pub bus_id: u64, + pub multiplicity: Multiplicity, + pub values: Vec, + pub is_sender: bool, +} +``` -We use the **LogUp** (Logarithmic Derivative Lookup) protocol, which is based on a key mathematical insight: two multisets are equal if and only if their logarithmic derivatives are equal. +| Field | Role | +|---|---| +| `bus_id` | Names the bus. Senders and receivers must use the same `bus_id` for their tokens to match. Different buses use different IDs so that, e.g., an `And` token doesn't accidentally cancel an `Xor` token. | +| `multiplicity` | How many times this row contributes (see below). | +| `values` | The token payload — the data being looked up. | +| `is_sender` | Carries the sign. Senders add to the bus sum, receivers subtract. The balance check is `Σ sender − Σ receiver = 0`. | -### Fingerprints +Build them with `BusInteraction::sender(bus_id, mul, values)` or `::receiver(bus_id, mul, values)`. -Each row in a table is compressed into a single field element called a **fingerprint**: +## Named buses (`BusId`) -``` -fingerprint = 1 / (z - (v₀ + v₁·α + v₂·α² + ...)) -``` +Bus IDs are declared in [`prover/src/tables/types.rs`](../../prover/src/tables/types.rs) as a `#[repr(u64)]` enum: -Where: -- `z` and `α` are random challenges sampled via Fiat-Shamir -- `v₀, v₁, v₂, ...` are the column values in that row +```rust +#[repr(u64)] +pub enum BusId { + IsByte = 0, + IsHalfword, + IsB20, + AndByte, + OrByte, + XorByte, + Msb8, + Msb16, + Zero, + Hwsl, + Lt, + Mul, + Dvrm, + Shift, + Memw, + Load, + Memory, + // ... +} +``` -The linear combination `v₀ + v₁·α + v₂·α² + ...` compresses multiple columns into one value, and `z` shifts it to enable the logarithmic derivative form. +Each value is a `u64` discriminant, auto-incremented from 0. `BusInteraction::new` takes `impl Into` so you pass `BusId::AndByte` directly. -### Running Sum +## Multiplicity -For each table interaction, we build an auxiliary column that accumulates fingerprints: +How many copies of the token a row contributes. Most rows are `One`, but tables that deduplicate or have flag-gated participation use richer forms: -``` -s[i+1] = s[i] + multiplicity[i] / (z - linear_combination[i]) +```rust +pub enum Multiplicity { + One, // 1 + Column(usize), // col[i] + Sum(usize, usize), // col[a] + col[b] + Negated(usize), // 1 - col[i] (col must be a bit) + Diff(usize, usize), // col[a] - col[b] + Sum3(usize, usize, usize), // col[a] + col[b] + col[c] + Linear(Vec), // arbitrary signed combination +} ``` -Where `multiplicity` indicates how many times this row participates in the lookup: -- **Positive** for rows being "looked up" (proving side) -- **Negative** for rows doing the "looking" (assuming side) +`Linear` is the escape hatch — it supports signed coefficients and large unsigned coefficients (e.g. `2^{-32} mod p`), and is how interactions like `μ − read2 − read4 − read8` are expressed. -### Bus Balancing +## Bus values (the token payload) -The key property: if all lookups are valid, the sum of all fingerprints across all tables equals zero. This is because every "send" (negative multiplicity) has a matching "receive" (positive multiplicity). +Each entry in `BusInteraction.values` is a `BusValue`: -``` -Σ (sends) + Σ (receives) = 0 +```rust +pub enum BusValue { + Packed { start_column: usize, packing: Packing }, + Linear(Vec), +} ``` -This is verified by checking that the final values of all running-sum columns sum to zero. +- `Packed` reads consecutive trace columns and combines them via a `Packing` formula (powers of 2). For example, `Packing::Word2L` at `start_column = 4` reads columns 4 and 5 and computes `c₄ + 2¹⁶·c₅`, producing one bus element representing a 32-bit word. +- `Linear` is an arbitrary signed linear combination over columns and constants — used when the value is a flag, a constant tag, or a derived expression that doesn't fit a `Packing`. -## Multi-Table Challenge Sharing +The `Packing` enum supports primitive shapes (`Direct`, `Word2L`, `Word4L`) and compound shapes (`DWordHL`, `DWordBL`, `QuadHL`, …) that produce multiple bus elements. A 64-bit double-word stored as 4 half-words is one `BusValue::Packed { packing: DWordHL, .. }` that yields two bus elements. -For the bus to balance correctly, **all tables must use the same random challenges** `(z, α)`. This is critical for security and correctness. +## Two-stage value combination -### Protocol Flow +A token's contribution to the bus is computed in **two stages**: -1. **Commit all main traces**: Each table commits its main execution trace to the transcript -2. **Sample shared challenges**: After ALL main traces are committed, sample `z` and `α` once -3. **Build auxiliary traces**: Each table builds its running-sum columns using the shared challenges -4. **Commit auxiliary traces**: Each table commits its auxiliary trace -5. **Continue STARK protocol**: Proceed with composition polynomial, FRI, etc. +1. **Limb packing.** Within each `BusValue`, columns are combined using powers of 2 according to the chosen `Packing`. This is how multi-limb values are formed from their column-level representation (e.g. assembling a 32-bit word from four 8-bit byte columns). -### Why Share Challenges? +2. **Bus fingerprint.** All bus elements from the interaction — starting with the `bus_id`, then the elements produced by each `BusValue` — are folded together using powers of a single challenge α: -If tables used different challenges: -- Table A computes fingerprints with `(z₁, α₁)` -- Table B computes fingerprints with `(z₂, α₂)` -- The fingerprints don't match even for identical values -- The bus cannot balance, and valid proofs become impossible + ``` + fingerprint = z − (bus_id + α·v₁ + α²·v₂ + … + α^(k−1)·v_{k−1}) + ``` -By sharing challenges, fingerprints for the same values are identical across tables, enabling the bus to balance. + The interaction's contribution at this row is `± multiplicity / fingerprint`, with the sign coming from `is_sender`. -## Implementation +The `bus_id` is the first bus element. This is what makes tokens on different buses non-interfering: two interactions on `BusId::Mul` and `BusId::Lt` have different fingerprints even when all the data values match. -### Challenge Constants +## Challenges -```rust -// Index of the `z` challenge - evaluation point for fingerprints -pub const LOGUP_CHALLENGE_Z: usize = 0; +LogUp uses two challenges sampled from the transcript after the main trace is committed: -// Index of the `α` challenge - base for linear combination -pub const LOGUP_CHALLENGE_ALPHA: usize = 1; +- `z` — read as `challenges[0]`; the subtractor in the fingerprint denominator. +- `α` — read as `challenges[LOGUP_CHALLENGE_ALPHA]` where `LOGUP_CHALLENGE_ALPHA = 1`; the base for the powers-of-α combination of bus elements. -// Total number of LogUp challenges -pub const LOGUP_NUM_CHALLENGES: usize = 2; -``` +The total challenge count is `LOGUP_NUM_CHALLENGES = 2`. There is no separate `LOGUP_CHALLENGE_Z` constant — `z` is just the first challenge by convention. -### Table Interactions +## Bus balance -Each AIR defines its lookup interactions via `TableInteraction`: +For every bus to be sound, across all tables and all rows: -```rust -pub struct TableInteraction { - pub flag_columns: Vec, // Columns indicating participation (multiplicity) - pub value_columns: Vec, // Columns containing the looked-up values -} +``` +Σ (over senders) multiplicity / fingerprint +− +Σ (over receivers) multiplicity / fingerprint += 0 ``` -### Auxiliary Trace +In code this becomes: every table contributes a per-bus running sum (the "table contribution") to its auxiliary trace; the verifier checks that the sum of table contributions equals the expected bus balance. For most buses the expected balance is zero. The `Commit` bus is an exception: its expected balance is recomputed by the verifier from the public output bytes (see `compute_commit_bus_offset` in [`prover/src/lib.rs`](../../prover/src/lib.rs)) so that tampering with the proof's public output is caught. -The auxiliary trace contains one running-sum column per interaction, plus optionally a grand-sum column that aggregates all interactions. +When `--features debug-checks` is on, [`crypto/stark/src/bus_debug.rs`](../../crypto/stark/src/bus_debug.rs) prints per-bus sender vs. receiver sums to help diagnose imbalances during development. -## Security Considerations +## Implementation pointers -1. **Challenge derivation**: Challenges must be sampled via Fiat-Shamir after all main traces are committed to prevent manipulation -2. **Shared challenges**: All tables in a multi-table proof MUST use identical challenges -3. **Field size**: The field must be large enough that random challenges don't accidentally cause fingerprint collisions +- Interaction shape, packing, balance: [`crypto/stark/src/lookup.rs`](../../crypto/stark/src/lookup.rs) +- Named bus IDs used by the VM: [`prover/src/tables/types.rs`](../../prover/src/tables/types.rs) +- Per-table interactions: [`prover/src/tables/`](../../prover/src/tables/) (one file per table) +- Verifier-side bus offset for the COMMIT bus: [`prover/src/lib.rs`](../../prover/src/lib.rs) (`compute_commit_bus_offset`) +- Debug-checks bus diagnostics: [`crypto/stark/src/bus_debug.rs`](../../crypto/stark/src/bus_debug.rs) diff --git a/docs/cryptography/proof_system.md b/docs/cryptography/proof_system.md index 8683c1747..89b427426 100644 --- a/docs/cryptography/proof_system.md +++ b/docs/cryptography/proof_system.md @@ -8,17 +8,20 @@ The proof system is the component responsible for generating the certificate of 5. Have short proofs. This section will cover the basic cryptographic primitives needed for the proof system and a description of the whole proof system and arguments used. Core concepts are: -1. [Finite field](./finite_field.md) -2. [Polynomials](./polynomials.md) -3. [Extension field](./extension_field.md) -4. [Hash function](./hash_function.md) -5. [Fast-Fourier transform](./fast_fourier_transform.md) -6. [Reed-Solomon codes](./reed_solomon_codes.md) -7. [Constraint](./constraint.md) -8. [Algebraic intermediate representation](./air.md) -9. [Interactive oracle proof](./iop.md) -10. [Fast Reed-Solomon Interactive Oracle Proof of Proximity (FRI)](./fri.md) -11. [Provable security and conjectured security](./security.md) -12. [Lookup argument](./lookup.md) + +> **Note:** the chapters below are a work in progress. + +1. Finite field +2. Polynomials +3. Extension field +4. Hash function +5. Fast-Fourier transform +6. Reed-Solomon codes +7. Constraint +8. Algebraic intermediate representation +9. Interactive oracle proof +10. Fast Reed-Solomon Interactive Oracle Proof of Proximity (FRI) +11. Provable security and conjectured security +12. Lookup argument The flow of the proof system is described in the following section. \ No newline at end of file diff --git a/docs/general_flow.md b/docs/general_flow.md index dae345683..deee5e4fe 100644 --- a/docs/general_flow.md +++ b/docs/general_flow.md @@ -1,14 +1,20 @@ -# Description - -The different components that form the pipeline for proving the correctness of the execution of a given program on an input stream are: -1. The source code of the program, written in high-level language -2. The program binary, ready for the virtual machine -3. The execution record of the binary over the VM architecture for a given input -4. The witness of the computation, generated from the execution record. Typically, this will consist of several tables, called trace tables. -5. The proof of validity of the witness for some language and VM architecture - -The steps are as follows: -1. The *compiler* transforms the program into the binary. -2. The *executor* takes a binary, an input stream and an VM architecture and produces the execution record. -3. The *witness generator* transforms the execution record into a witness compatible with the chosen arithmetisation and constraint system. -4. The *proof system* takes the witness and the constraint system, and produces a (set of) proof(s) that the former satisfies the latter. \ No newline at end of file +# Overview of VM flow + +The Lambda VM proves correct execution of a RISC-V (RV64IM) program against an input stream. The pipeline has five artifacts and four transformations. + +## Artifacts + +1. **Source code** — high-level Rust (using [`syscalls/`](../syscalls/) for guest-host I/O) or RISC-V assembly. +2. **ELF binary** — the program in the VM's ISA, ready to load. +3. **Execution record** — per-instruction logs emitted by running the ELF on the VM. +4. **Witness** — a set of trace tables (CPU, decode, MEMW, LOAD, bitwise, branch, LT, shift, MUL, DVRM, page, register, halt, commit, keccak, …) derived from the execution record. Each table is an AIR (Algebraic Intermediate Representation); tables are linked by LogUp lookup arguments. +5. **Proof** — a multi-table STARK proof (transparent, hash-based, post-quantum secure) that the witness satisfies all AIR constraints and lookup arguments. Low-degree of the witness polynomials is verified via FRI. + +## Transformations + +1. **Compiler** — `rustc` cross-compiles to the custom RISC-V target spec ([`executor/programs/riscv64im-lambda-vm-elf.json`](../executor/programs/riscv64im-lambda-vm-elf.json)) and produces the ELF. The `lambda-vm-syscalls` crate exposes guest-side syscalls (`commit`, `get_private_input`, `print_string`, `keccak_permute`, `sys_halt`). +2. **Executor** ([`executor/`](../executor/)) — loads the ELF, runs the program against the VM's memory and register state, handles syscalls and precompiles (e.g. Keccak), and emits the per-instruction logs. +3. **Witness generator** ([`prover/src/tables/`](../prover/src/tables/)) — turns the logs into trace tables, populates AIR columns, and computes the LogUp auxiliary columns that connect tables. +4. **Proof system** ([`crypto/stark/`](../crypto/stark/)) — commits to each table's trace via Merkle trees, samples challenges via Fiat-Shamir, and runs FRI for the low-degree test. Produces a `MultiProof`; the verifier replays the transcript and checks all AIR and lookup constraints. + +For a deeper dive into each component see the [proof system overview](./cryptography/proof_system.md). diff --git a/executor/README.md b/executor/README.md new file mode 100644 index 000000000..df9911038 --- /dev/null +++ b/executor/README.md @@ -0,0 +1,53 @@ +# Lambda VM Executor + +RISC-V (RV64IM) emulator for the Lambda VM. Loads ELF binaries, runs them against an in-memory VM state, and emits the per-instruction execution logs that the [prover](../prover) turns into a STARK trace. + +Published as `executor`. Used directly by the CLI and the prover; you can also drive it from Rust. + +## Usage + +```rust +use executor::elf::Elf; +use executor::vm::execution::Executor; + +let elf_bytes = std::fs::read("program.elf")?; +let program = Elf::load(&elf_bytes)?; +let executor = Executor::new(&program, /* private input */ vec![])?; +let result = executor.run()?; + +println!("Executed {} instructions", result.logs.len()); +``` + +For chunked execution (useful when you don't want to hold all logs in memory), drive the executor via `executor.resume()` in a loop until it yields `None`, then call `executor.finish()`. See [`bin/cli/src/main.rs`](../bin/cli/src/main.rs) for an example. + +## Example programs + +The repo ships ready-to-use guest programs in three flavours, all compiled by Makefile targets at the repo root: + +- [`programs/asm/`](./programs/asm/) — raw RISC-V assembly. Built with `make compile-programs-asm` into `program_artifacts/asm/`. +- [`programs/rust/`](./programs/rust/) — Rust guest projects (`fibonacci`, `keccak`, `hashmap`, …). Built with `make compile-programs-rust` into `program_artifacts/rust/`. Requires the pinned nightly toolchain and sysroot — see the root [`README.md`](../README.md). +- [`programs/bench/`](./programs/bench/) — benchmark programs. Built with `make compile-bench`. + +The custom RISC-V target spec used for Rust guests lives at [`programs/riscv64im-lambda-vm-elf.json`](./programs/riscv64im-lambda-vm-elf.json). + +## Tests + +```sh +# Compile all programs and run executor tests +make test-executor + +# Just the asm tests +make test-asm + +# Just the Rust tests +make test-rust +``` + +To add a new test: + +- **ASM**: add a `.s` file under [`programs/asm/`](./programs/asm/) and a matching entry in [`tests/asm.rs`](./tests/asm.rs). +- **Rust**: add a cargo project under [`programs/rust//`](./programs/rust/) (the directory and the `Cargo.toml` package name must match) and a matching entry in [`tests/rust.rs`](./tests/rust.rs). + +## Flamegraphs + +The executor includes a flamegraph generator (`executor::flamegraph::FlamegraphGenerator`) that produces folded-stack output by instruction count. Drive it via the CLI: `cli execute --flamegraph stacks.txt`. See [`bin/cli/README.md`](../bin/cli/README.md) for details. diff --git a/executor/programs/asm/test_keccak.s b/executor/programs/asm/test_keccak.s new file mode 100644 index 000000000..31cd93be6 --- /dev/null +++ b/executor/programs/asm/test_keccak.s @@ -0,0 +1,38 @@ + .attribute 5, "rv64i2p1_m2p0_zmmul1p0" +.Lfunc_end0: + .globl main +main: + # Allocate 200 bytes on the stack for the Keccak state (25 × u64) + addi sp, sp, -200 + + # Zero out the state (200 bytes = 25 doublewords) + mv t0, sp + li t1, 25 +.Lzero_loop: + sd zero, 0(t0) + addi t0, t0, 8 + addi t1, t1, -1 + bnez t1, .Lzero_loop + + # Call keccak-f[1600] permutation + # a0 = pointer to 200-byte state + # a7 = syscall number (0xFFFFFFFFFFFFFFFE = u64::MAX - 1) + mv a0, sp + li a7, -2 + ecall + + # Commit the post-permutation state so the test can verify the KAT. + # Commit syscall: a0=fd(1), a1=buf_addr, a2=count, a7=64 + li a0, 1 + mv a1, sp + li a2, 200 + li a7, 64 + ecall + + # Restore stack and halt + addi sp, sp, 200 + li a0, 0 + li a7, 93 + ecall +.Lfunc_end1: + .size main, .Lfunc_end1-main diff --git a/executor/programs/asm/test_keccak_multi.s b/executor/programs/asm/test_keccak_multi.s new file mode 100644 index 000000000..fcd192de7 --- /dev/null +++ b/executor/programs/asm/test_keccak_multi.s @@ -0,0 +1,48 @@ + .attribute 5, "rv64i2p1_m2p0_zmmul1p0" +.Lfunc_end0: + .globl main +main: + # Allocate 200 bytes on the stack for the Keccak state (25 × u64). + addi sp, sp, -200 + + # Initialize a non-zero, deterministic state: lane[i] = i + 1. + # Used by the host test as the initial state for tiny-keccak::keccakf + # cross-checking. + mv t0, sp + li t1, 1 + li t2, 26 +.Linit_loop: + sd t1, 0(t0) + addi t0, t0, 8 + addi t1, t1, 1 + bne t1, t2, .Linit_loop + + # First keccak-f[1600] call. + mv a0, sp + li a7, -2 + ecall + + # Second keccak-f[1600] call on the result. + mv a0, sp + li a7, -2 + ecall + + # Third keccak-f[1600] call on the result. + mv a0, sp + li a7, -2 + ecall + + # Commit the final 200-byte state. + li a0, 1 + mv a1, sp + li a2, 200 + li a7, 64 + ecall + + # Restore stack and halt. + addi sp, sp, 200 + li a0, 0 + li a7, 93 + ecall +.Lfunc_end1: + .size main, .Lfunc_end1-main diff --git a/executor/programs/rust/ef_io_demo/.cargo/config.toml b/executor/programs/rust/ef_io_demo/.cargo/config.toml new file mode 100644 index 000000000..ca99a3f45 --- /dev/null +++ b/executor/programs/rust/ef_io_demo/.cargo/config.toml @@ -0,0 +1,5 @@ +[target.riscv64im-lambda-vm-elf] +rustflags = [ + "--cfg", "getrandom_backend=\"custom\"", + "-C", "passes=lower-atomic" +] diff --git a/executor/programs/rust/ef_io_demo/Cargo.toml b/executor/programs/rust/ef_io_demo/Cargo.toml new file mode 100644 index 000000000..f1c6f812a --- /dev/null +++ b/executor/programs/rust/ef_io_demo/Cargo.toml @@ -0,0 +1,9 @@ +[workspace] + +[package] +name = "ef_io_demo" +version = "0.1.0" +edition = "2024" + +[dependencies] +lambda-vm-syscalls = { path = "../../../../syscalls" } diff --git a/executor/programs/rust/ef_io_demo/src/main.rs b/executor/programs/rust/ef_io_demo/src/main.rs new file mode 100644 index 000000000..ef0690398 --- /dev/null +++ b/executor/programs/rust/ef_io_demo/src/main.rs @@ -0,0 +1,22 @@ +// Demo guest exercising the EF zkVM IO interface (`read_input` / `write_output`). +// +// Reads the private input via the EF zero-copy `read_input` shim, then emits it +// back as the public output in TWO `write_output` calls (split in halves) to +// exercise the multi-call concatenation requirement of the EF spec. +use lambda_vm_syscalls as syscalls; + +pub fn main() { + let mut buf_ptr: *const u8 = core::ptr::null(); + let mut buf_size: usize = 0; + unsafe { + syscalls::ef_io::read_input(&mut buf_ptr, &mut buf_size); + } + + if buf_size > 0 { + let half = buf_size / 2; + unsafe { + syscalls::ef_io::write_output(buf_ptr, half); + syscalls::ef_io::write_output(buf_ptr.add(half), buf_size - half); + } + } +} diff --git a/executor/src/vm/instruction/execution.rs b/executor/src/vm/instruction/execution.rs index a5222557a..219414745 100644 --- a/executor/src/vm/instruction/execution.rs +++ b/executor/src/vm/instruction/execution.rs @@ -8,12 +8,20 @@ use crate::vm::{ const REGULAR_PC_UPDATE: u64 = 4; pub enum SyscallNumbers { + // Placeholder discriminant. The actual syscall value is KECCAK_SYSCALL_NUMBER. + KeccakPermute = 0, Print = 1, Panic = 2, Commit = 64, Halt = 93, } +/// Syscall number for KeccakPermute (u64::MAX - 1 = 0xFFFF_FFFF_FFFF_FFFE). +/// +/// Cannot be an enum discriminant because it exceeds isize::MAX. +pub const KECCAK_SYSCALL_NUMBER: u64 = u64::MAX - 1; +const KECCAK_STATE_BYTES: u64 = 25 * 8; + impl TryFrom for SyscallNumbers { type Error = (); fn try_from(value: u64) -> Result { @@ -22,6 +30,7 @@ impl TryFrom for SyscallNumbers { 2 => Ok(SyscallNumbers::Panic), 64 => Ok(SyscallNumbers::Commit), 93 => Ok(SyscallNumbers::Halt), + v if v == KECCAK_SYSCALL_NUMBER => Ok(SyscallNumbers::KeccakPermute), _ => Err(()), } } @@ -295,7 +304,7 @@ impl Instruction { // It is not the correct implementation of ecall/ebreak let pointer = registers.read(10)?; let len = registers.read(11)?; - let bytes = memory.load_bytes(pointer, len); + let bytes = memory.load_bytes(pointer, len)?; let value = str::from_utf8(&bytes).map_err(|_| ExecutionError::IncorrectMessage)?; println!("PRINT VM: {}", value); @@ -304,7 +313,7 @@ impl Instruction { // panic let pointer = registers.read(10)?; let len = registers.read(11)?; - let bytes = memory.load_bytes(pointer, len); + let bytes = memory.load_bytes(pointer, len)?; let value = str::from_utf8(&bytes).map_err(|_| ExecutionError::IncorrectMessage)?; return Err(ExecutionError::Panic(value.to_owned())); @@ -324,6 +333,32 @@ impl Instruction { src2_val = buf_addr; dst_val = count; } + SyscallNumbers::KeccakPermute => { + // keccak-f[1600] permutation on 200 bytes (25 × u64) at address in x10 + let state_addr = registers.read(10)?; + if !state_addr.is_multiple_of(8) { + return Err(ExecutionError::UnalignedKeccakStateAddress(state_addr)); + } + state_addr + .checked_add(KECCAK_STATE_BYTES - 1) + .ok_or(ExecutionError::KeccakStateAddressOverflow(state_addr))?; + + let mut state = [0u64; 25]; + for (i, lane) in state.iter_mut().enumerate() { + let lane_addr = state_addr + .checked_add((i as u64) * 8) + .ok_or(ExecutionError::KeccakStateAddressOverflow(state_addr))?; + *lane = memory.load_doubleword(lane_addr)?; + } + keccak_f1600(&mut state); + for (i, &lane) in state.iter().enumerate() { + let lane_addr = state_addr + .checked_add((i as u64) * 8) + .ok_or(ExecutionError::KeccakStateAddressOverflow(state_addr))?; + memory.store_doubleword(lane_addr, lane)?; + } + src2_val = state_addr; + } SyscallNumbers::Halt => { // halt return Ok(Log { @@ -496,4 +531,177 @@ pub enum ExecutionError { InvalidWSuffixOperation(ArithOp), #[error("Invalid commit fd: expected 1 (stdout), got {0}")] InvalidCommitFd(u64), + #[error("Unaligned Keccak state address: {0:#018x}")] + UnalignedKeccakStateAddress(u64), + #[error("Keccak state address range overflows: {0:#018x}")] + KeccakStateAddressOverflow(u64), +} + +// ============================================================================= +// Keccak-f[1600] permutation +// ============================================================================= + +/// Round constants for Keccak-f[1600] (24 rounds). +pub const KECCAK_RC: [u64; 24] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, +]; + +/// Rotation offsets R[x][y] for the rho step of Keccak-f[1600]. +pub const KECCAK_RHO: [[u32; 5]; 5] = [ + [0, 36, 3, 41, 18], + [1, 44, 10, 45, 2], + [62, 6, 43, 15, 61], + [28, 55, 25, 21, 56], + [27, 20, 39, 8, 14], +]; + +/// Apply the Keccak-f[1600] permutation (24 rounds) to a 25-word state. +/// +/// The state is indexed as `state[x + 5*y]` where `x, y ∈ {0..4}`. +pub fn keccak_f1600(state: &mut [u64; 25]) { + for &rc in &KECCAK_RC { + // θ (theta) + let mut c = [0u64; 5]; + for x in 0..5 { + c[x] = state[x] ^ state[x + 5] ^ state[x + 10] ^ state[x + 15] ^ state[x + 20]; + } + let mut d = [0u64; 5]; + for x in 0..5 { + d[x] = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1); + } + for x in 0..5 { + for y in 0..5 { + state[x + 5 * y] ^= d[x]; + } + } + + // ρ (rho) and π (pi) + let mut b = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + b[y + 5 * ((2 * x + 3 * y) % 5)] = state[x + 5 * y].rotate_left(KECCAK_RHO[x][y]); + } + } + + // χ (chi) + for x in 0..5 { + for y in 0..5 { + state[x + 5 * y] = + b[x + 5 * y] ^ (!b[(x + 1) % 5 + 5 * y] & b[(x + 2) % 5 + 5 * y]); + } + } + + // ι (iota) + state[0] ^= rc; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_keccak_f1600_zero_input() { + let mut state = [0u64; 25]; + keccak_f1600(&mut state); + + let expected: [u64; 25] = [ + 0xF1258F7940E1DDE7, + 0x84D5CCF933C0478A, + 0xD598261EA65AA9EE, + 0xBD1547306F80494D, + 0x8B284E056253D057, + 0xFF97A42D7F8E6FD4, + 0x90FEE5A0A44647C4, + 0x8C5BDA0CD6192E76, + 0xAD30A6F71B19059C, + 0x30935AB7D08FFC64, + 0xEB5AA93F2317D635, + 0xA9A6E6260D712103, + 0x81A57C16DBCF555F, + 0x43B831CD0347C826, + 0x01F22F1A11A5569F, + 0x05E5635A21D9AE61, + 0x64BEFEF28CC970F2, + 0x613670957BC46611, + 0xB87C5A554FD00ECB, + 0x8C3EE88A1CCF32C8, + 0x940C7922AE3A2614, + 0x1841F924A2C509E4, + 0x16F53526E70465C2, + 0x75F644E97F30A13B, + 0xEAF1FF7B5CECA249, + ]; + + assert_eq!(state, expected, "keccak-f[1600] on zero input mismatch"); + } + + #[test] + fn test_keccak_f1600_nonzero_input() { + let mut state = [0u64; 25]; + state[0] = 1; + let original = state; + keccak_f1600(&mut state); + assert_ne!(state, original); + assert!(state.iter().any(|&x| x != 0)); + } + + #[test] + fn test_keccak_syscall_rejects_unaligned_state_addr() { + let mut pc = 0; + let mut registers = Registers::default(); + let mut memory = Memory::default(); + + registers.write(17, KECCAK_SYSCALL_NUMBER).unwrap(); + registers.write(10, 0x1001).unwrap(); + + let err = Instruction::EcallEbreak + .run(&mut pc, &mut registers, &mut memory) + .unwrap_err(); + assert!(matches!( + err, + ExecutionError::UnalignedKeccakStateAddress(0x1001) + )); + } + + #[test] + fn test_keccak_syscall_rejects_overflowing_state_range() { + let mut pc = 0; + let mut registers = Registers::default(); + let mut memory = Memory::default(); + + registers.write(17, KECCAK_SYSCALL_NUMBER).unwrap(); + registers.write(10, u64::MAX - 191).unwrap(); + + let err = Instruction::EcallEbreak + .run(&mut pc, &mut registers, &mut memory) + .unwrap_err(); + assert!(matches!( + err, + ExecutionError::KeccakStateAddressOverflow(addr) if addr == u64::MAX - 191 + )); + } } diff --git a/executor/src/vm/memory.rs b/executor/src/vm/memory.rs index b1f047ee1..b78c98d44 100644 --- a/executor/src/vm/memory.rs +++ b/executor/src/vm/memory.rs @@ -38,9 +38,10 @@ impl BuildHasher for U64BuildHasher { pub type U64HashMap = HashMap; -// TODO: Correctly define this -const MAX_PUBLIC_OUTPUT_COMMIT_SIZE: u64 = 1024; -const PUBLIC_OUTPUT_START_INDEX: u64 = 0; +/// Total cap on public output bytes across all `commit_public_output` calls. +/// The COMMIT AIR concatenates calls via the running `x254` index, so this +/// is enforced as a running-total budget rather than a per-call limit. +pub const MAX_PUBLIC_OUTPUT_TOTAL_SIZE: u64 = 1024 * 1024; /// Maximum size of the private input memory region (in bytes). pub const MAX_PRIVATE_INPUT_SIZE: u64 = 6700000; /// Fixed high address where private input is mapped. Guest programs can read @@ -50,19 +51,30 @@ pub const MAX_PRIVATE_INPUT_SIZE: u64 = 6700000; pub const PRIVATE_INPUT_START_INDEX: u64 = 0xFF000000; #[derive(Default, Debug)] -pub struct Memory(U64HashMap<[u8; 4]>); +pub struct Memory { + cells: U64HashMap<[u8; 4]>, + /// Bytes committed to public output via `commit_public_output`. The + /// COMMIT AIR doesn't write to a fixed memory region (it streams bytes + /// onto the Commit bus by `index`), so this buffer is purely the + /// executor's view used by `read_return_value` and CLI display. + public_output: Vec, +} impl Memory { pub fn load_byte(&self, address: u64) -> u8 { let aligned_address = address - address % 4; - let value = self.0.get(&aligned_address).cloned().unwrap_or_default(); + let value = self + .cells + .get(&aligned_address) + .cloned() + .unwrap_or_default(); value[(address % 4) as usize] } pub fn store_byte(&mut self, address: u64, value: u8) { let aligned_address = address - address % 4; let entry = self - .0 + .cells .entry(aligned_address) .or_insert_with(|| [0, 0, 0, 0]); entry[(address % 4) as usize] = value; @@ -72,7 +84,7 @@ impl Memory { if !address.is_multiple_of(4) { return Err(MemoryError::UnalignedAccess); } - let bytes = self.0.get(&address).cloned().unwrap_or_default(); + let bytes = self.cells.get(&address).cloned().unwrap_or_default(); Ok(u32::from_le_bytes(bytes)) } @@ -81,7 +93,7 @@ impl Memory { return Err(MemoryError::UnalignedAccess); } let bytes = value.to_le_bytes(); - self.0.insert(address, bytes); + self.cells.insert(address, bytes); Ok(()) } @@ -90,8 +102,8 @@ impl Memory { if !address.is_multiple_of(8) { return Err(MemoryError::UnalignedAccess); } - let low_bytes = self.0.get(&address).cloned().unwrap_or_default(); - let high_bytes = self.0.get(&(address + 4)).cloned().unwrap_or_default(); + let low_bytes = self.cells.get(&address).cloned().unwrap_or_default(); + let high_bytes = self.cells.get(&(address + 4)).cloned().unwrap_or_default(); let low = u32::from_le_bytes(low_bytes) as u64; let high = u32::from_le_bytes(high_bytes) as u64; Ok(low | (high << 32)) @@ -104,8 +116,8 @@ impl Memory { } let low = (value & 0xFFFFFFFF) as u32; let high = (value >> 32) as u32; - self.0.insert(address, low.to_le_bytes()); - self.0.insert(address + 4, high.to_le_bytes()); + self.cells.insert(address, low.to_le_bytes()); + self.cells.insert(address + 4, high.to_le_bytes()); Ok(()) } @@ -117,7 +129,11 @@ impl Memory { ); } let aligned_address = address - address % 4; - let bytes = self.0.get(&aligned_address).cloned().unwrap_or_default(); + let bytes = self + .cells + .get(&aligned_address) + .cloned() + .unwrap_or_default(); let value = &bytes[(address % 4) as usize..(address % 4) as usize + 2]; Ok(u16::from_le_bytes( value.try_into().map_err(|_| MemoryError::LoadHalf)?, @@ -130,7 +146,7 @@ impl Memory { } let aligned_address = address - address % 4; let entry = self - .0 + .cells .entry(aligned_address) .or_insert_with(|| [0, 0, 0, 0]); let bytes = value.to_le_bytes(); @@ -139,19 +155,25 @@ impl Memory { Ok(()) } + /// Append `length` bytes from guest memory starting at `address` to the + /// public output. The COMMIT AIR concatenates calls via the running + /// `x254` index, and the trace builder accumulates `commit_ops` into + /// `VmProof.public_output`; this method maintains the executor's view + /// of the same byte stream so `read_return_value` matches. pub fn commit_public_output(&mut self, address: u64, length: u64) -> Result<(), MemoryError> { - if length > MAX_PUBLIC_OUTPUT_COMMIT_SIZE { + let new_total = (self.public_output.len() as u64) + .checked_add(length) + .ok_or(MemoryError::CommitSizeExceeded)?; + if new_total > MAX_PUBLIC_OUTPUT_TOTAL_SIZE { return Err(MemoryError::CommitSizeExceeded); } - self.store_word(PUBLIC_OUTPUT_START_INDEX, length as u32)?; - let inputs = self.load_bytes(address, length); - self.set_bytes_aligned(PUBLIC_OUTPUT_START_INDEX + 4, &inputs)?; + let bytes = self.load_bytes(address, length)?; + self.public_output.extend_from_slice(&bytes); Ok(()) } pub fn read_return_value(&self) -> Result, MemoryError> { - let size = self.load_word(PUBLIC_OUTPUT_START_INDEX)?; - Ok(self.load_bytes(PUBLIC_OUTPUT_START_INDEX + 4, size as u64)) + Ok(self.public_output.clone()) } /// Pre-loads private input bytes at `PRIVATE_INPUT_START_INDEX` as a @@ -164,23 +186,29 @@ impl Memory { if inputs.len() as u64 > MAX_PRIVATE_INPUT_SIZE { return Err(MemoryError::PrivateInputSizeExceeded); } - self.store_word(PRIVATE_INPUT_START_INDEX, inputs.len() as u32)?; + let len_u32 = + u32::try_from(inputs.len()).map_err(|_| MemoryError::PrivateInputSizeExceeded)?; + self.store_word(PRIVATE_INPUT_START_INDEX, len_u32)?; self.set_bytes_aligned(PRIVATE_INPUT_START_INDEX + 4, &inputs)?; Ok(()) } - pub fn load_bytes(&self, mut addr: u64, len: u64) -> Vec { - let mut result = Vec::with_capacity(len as usize); - let end = addr + len; + pub fn load_bytes(&self, mut addr: u64, len: u64) -> Result, MemoryError> { + let end = addr.checked_add(len).ok_or(MemoryError::AddressOverflow)?; + let len_usize = usize::try_from(len).map_err(|_| MemoryError::AllocationFailed)?; + let mut result = Vec::new(); + result + .try_reserve_exact(len_usize) + .map_err(|_| MemoryError::AllocationFailed)?; while addr < end { let aligned = addr - (addr % 4); - let bytes = self.0.get(&aligned).cloned().unwrap_or_default(); + let bytes = self.cells.get(&aligned).cloned().unwrap_or_default(); let offset = (addr % 4) as usize; let take = std::cmp::min(4 - offset, (end - addr) as usize); result.extend_from_slice(&bytes[offset..offset + take]); addr += take as u64; } - result + Ok(result) } /// Helper method to store a given input at an aligned address. It may also overwrite existing bytes with zero if inputs is not divisible by 4 @@ -192,7 +220,7 @@ impl Memory { for chunk in inputs.chunks(4) { let mut bytes = [0u8; 4]; bytes[..chunk.len()].copy_from_slice(chunk); - self.0.insert(addr, bytes); + self.cells.insert(addr, bytes); addr += 4; } Ok(()) @@ -209,6 +237,10 @@ pub enum MemoryError { CommitSizeExceeded, #[error("Private input size exceeded")] PrivateInputSizeExceeded, + #[error("Address range exceeds u64::MAX")] + AddressOverflow, + #[error("Failed to allocate memory for load_bytes")] + AllocationFailed, } #[cfg(test)] @@ -234,7 +266,7 @@ mod tests { } #[test] - fn test_commit_public_output_overwrites() { + fn test_commit_public_output_appends() { let mut memory = Memory::default(); memory.store_byte(0x100, b'a'); memory.store_byte(0x101, b'b'); @@ -248,19 +280,78 @@ mod tests { .commit_public_output(0x104, 2) .expect("second commit should succeed"); - // Overwrite semantics: second commit replaces first + // Append semantics: calls concatenate (EF zkVM IO interface). assert_eq!( memory .read_return_value() .expect("public output should be readable"), - b"cd".to_vec() + b"abcd".to_vec() ); } #[test] - fn test_commit_public_output_size_exceeded() { + fn test_commit_public_output_empty_is_ok() { + let mut memory = Memory::default(); + memory + .commit_public_output(0, 0) + .expect("zero-length commit should succeed"); + assert!( + memory + .read_return_value() + .expect("public output should be readable") + .is_empty() + ); + } + + #[test] + fn test_commit_public_output_address_overflow() { + let mut memory = Memory::default(); + let err = memory + .commit_public_output(u64::MAX, 2) + .expect_err("address overflow must error, not panic"); + assert!(matches!(err, super::MemoryError::AddressOverflow)); + } + + #[test] + fn test_load_bytes_huge_len_returns_alloc_error() { + let memory = Memory::default(); + // A multi-petabyte allocation request from a guest must fail cleanly, + // not abort the host process via OOM. `addr=0` and `len=1<<50` keep + // `checked_add` happy so the path reaches the allocation. + let huge = 1u64 << 50; + let err = memory + .load_bytes(0, huge) + .expect_err("huge alloc must error, not abort"); + assert!(matches!(err, super::MemoryError::AllocationFailed)); + } + + #[test] + fn test_load_bytes_overflow_errors() { + let memory = Memory::default(); + let err = memory + .load_bytes(u64::MAX, 2) + .expect_err("address overflow must error, not panic"); + assert!(matches!(err, super::MemoryError::AddressOverflow)); + } + + #[test] + fn test_commit_public_output_total_cap() { let mut memory = Memory::default(); - let err = memory.commit_public_output(0x100, 1025); - assert!(err.is_err()); + // Seed enough source bytes for two 512 KB writes. + let chunk = vec![0xAB; 512 * 1024]; + memory + .set_bytes_aligned(0x1_0000, &chunk) + .expect("seed should succeed"); + + memory + .commit_public_output(0x1_0000, 512 * 1024) + .expect("first 512 KB commit should succeed"); + memory + .commit_public_output(0x1_0000, 512 * 1024) + .expect("second 512 KB commit should succeed (total = 1 MB)"); + + // One more byte exceeds the 1 MB total cap. + let err = memory.commit_public_output(0x1_0000, 1).unwrap_err(); + assert!(matches!(err, super::MemoryError::CommitSizeExceeded)); } } diff --git a/executor/tests/asm.rs b/executor/tests/asm.rs index cbc1adec5..86722b82c 100644 --- a/executor/tests/asm.rs +++ b/executor/tests/asm.rs @@ -801,3 +801,49 @@ fn test_sub_64bit() { fn test_sub_underflow() { run_program("./program_artifacts/asm/sub_underflow.elf"); } + +// ==================== Keccak Precompile ==================== + +#[test] +fn test_keccak() { + // Runs keccak-f[1600] on a zeroed state and commits the 200-byte result. + // Expected output is the FIPS-202 zero-input KAT. + let elf_data = std::fs::read("./program_artifacts/asm/test_keccak.elf").unwrap(); + let program = Elf::load(&elf_data).unwrap(); + let executor = Executor::new(&program, vec![]).expect("Failed to create executor"); + let result = executor.run().expect("Failed to run program"); + + let expected_state: [u64; 25] = [ + 0xF1258F7940E1DDE7, + 0x84D5CCF933C0478A, + 0xD598261EA65AA9EE, + 0xBD1547306F80494D, + 0x8B284E056253D057, + 0xFF97A42D7F8E6FD4, + 0x90FEE5A0A44647C4, + 0x8C5BDA0CD6192E76, + 0xAD30A6F71B19059C, + 0x30935AB7D08FFC64, + 0xEB5AA93F2317D635, + 0xA9A6E6260D712103, + 0x81A57C16DBCF555F, + 0x43B831CD0347C826, + 0x01F22F1A11A5569F, + 0x05E5635A21D9AE61, + 0x64BEFEF28CC970F2, + 0x613670957BC46611, + 0xB87C5A554FD00ECB, + 0x8C3EE88A1CCF32C8, + 0x940C7922AE3A2614, + 0x1841F924A2C509E4, + 0x16F53526E70465C2, + 0x75F644E97F30A13B, + 0xEAF1FF7B5CECA249, + ]; + let mut expected_bytes = Vec::with_capacity(200); + for lane in expected_state { + expected_bytes.extend_from_slice(&lane.to_le_bytes()); + } + assert_eq!(result.return_values.memory_values, expected_bytes); + assert_eq!(result.return_values.register_values.0, 0); +} diff --git a/executor/tests/rust.rs b/executor/tests/rust.rs index fab183571..b15530a63 100644 --- a/executor/tests/rust.rs +++ b/executor/tests/rust.rs @@ -160,6 +160,20 @@ fn test_commit() { ); } +#[test] +fn test_ef_io_demo_concatenates_writes() { + // Demo guest reads its private input via EF `read_input`, then emits it + // back as the public output via TWO `write_output` calls (split in halves). + // The COMMIT AIR concatenates the two calls; the executor's + // `commit_public_output` appends in the same order. + let input: Vec = b"hello world!".to_vec(); + run_program_and_check_public_output( + "./program_artifacts/rust/ef_io_demo.elf", + input.clone(), + input, + ); +} + #[test] fn test_commit_sum() { run_program_and_check_public_output( diff --git a/prover/Cargo.toml b/prover/Cargo.toml index bf55a251d..cd5d2534b 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -21,6 +21,7 @@ rayon = { version = "1.8.0", optional = true } [dev-dependencies] env_logger = "*" criterion = { version = "0.5", default-features = false } +tiny-keccak = { version = "2.0", features = ["keccak"] } [[bench]] name = "vm_prover_benchmark" diff --git a/prover/README.md b/prover/README.md index e69de29bb..b523c1c3b 100644 --- a/prover/README.md +++ b/prover/README.md @@ -0,0 +1,54 @@ +# Lambda VM Prover + +STARK prover for the Lambda VM. Proves correct execution of RISC-V ELF binaries by generating a multi-table STARK proof (CPU, decode, bitwise, branch, LT, shift, MUL, DVRM, MEMW, LOAD, page, register, halt, commit, keccak) and provides the matching native verifier. + +Published as `lambda-vm-prover`. Consumed by [`bin/cli`](../bin/cli) and by the benchmarks; you can also use it directly from Rust. + +## Usage + +```rust +use lambda_vm_prover as prover; + +let elf_bytes = std::fs::read("program.elf").unwrap(); + +let proof = prover::prove(&elf_bytes).unwrap(); +assert!(prover::verify(&proof, &elf_bytes).unwrap()); +``` + +With private inputs: + +```rust +let private_inputs = std::fs::read("input.bin").unwrap(); +let proof = prover::prove_with_inputs(&elf_bytes, &private_inputs).unwrap(); +``` + +## Public API + +| Function | Description | +|---|---| +| `prove(elf)` | Generate a proof with default options (blowup = 2). | +| `prove_with_inputs(elf, private)` | Same, with private input bytes. | +| `prove_with_options(elf, opts, max_rows)` | Custom proof options and max-rows config. | +| `prove_with_options_and_inputs(...)` | Most general entry point. | +| `verify(proof, elf)` | Verify a proof with default options. | +| `verify_with_options(proof, elf, opts)` | Verify with caller-chosen options (the verifier enforces its own security parameters, not the prover's). | +| `prove_and_verify(elf)` | Prove + verify in one call (convenience). | +| `count_elements(elf, private)` | Build traces and return `(main, aux)` field-element counts without running the proof step. | + +The proof bundle type is `VmProof`, containing the multi-table STARK proof, the public output bytes committed by the guest, and the metadata the verifier needs to reconstruct the AIR configuration (table chunk counts, runtime page ranges, number of private-input pages). + +## Features + +| Feature | Description | +|---|---| +| `parallel` (default) | Rayon parallelism across tables and FFTs. | +| `debug-checks` | Runs `validate_trace` and prints a per-bus LogUp balance report after proving. Forwarded to `crypto/stark`. | +| `instruments` | Per-phase timing and heap-usage report (execute, trace build, AIR construction, prove). | + +To run the test suite with debug output: + +```sh +cargo test --release -p lambda-vm-prover --features debug-checks -- --nocapture +``` + +See the root [`README.md`](../README.md) for the end-to-end workflow (compiling guest programs, the CLI wrapper, benchmarks). diff --git a/prover/src/constraints/cpu.rs b/prover/src/constraints/cpu.rs index 64d6b7e3e..546f2f2a4 100644 --- a/prover/src/constraints/cpu.rs +++ b/prover/src/constraints/cpu.rs @@ -1033,7 +1033,7 @@ pub fn create_jalr_constraints(constraint_idx_start: usize) -> (Vec, pub halt: VmAir, pub commit: VmAir, + pub keccak: VmAir, + pub keccak_rnd: VmAir, + pub keccak_rc: VmAir, pub register: VmAir, pub pages: Vec, pub memw_registers: Vec, @@ -213,6 +217,9 @@ impl VmAirs { (&self.decode, &mut traces.decode, &()), (&self.halt, &mut traces.halt, &()), (&self.commit, &mut traces.commit, &()), + (&self.keccak, &mut traces.keccak, &()), + (&self.keccak_rnd, &mut traces.keccak_rnd, &()), + (&self.keccak_rc, &mut traces.keccak_rc, &()), (&self.register, &mut traces.register, &()), ]; @@ -268,6 +275,9 @@ impl VmAirs { &self.decode, &self.halt, &self.commit, + &self.keccak, + &self.keccak_rnd, + &self.keccak_rc, &self.register, ]; @@ -363,6 +373,12 @@ impl VmAirs { .collect(); let halt = create_halt_air(proof_options); let commit = create_commit_air(proof_options); + let keccak = create_keccak_air(proof_options); + let keccak_rnd = create_keccak_rnd_air(proof_options); + let keccak_rc = create_keccak_rc_air(proof_options).with_preprocessed( + tables::keccak_rc::preprocessed_commitment(proof_options), + tables::keccak_rc::NUM_PRECOMPUTED_COLS, + ); let register = create_register_air(proof_options).with_preprocessed( register::preprocessed_commitment(proof_options, elf.entry_point), register::NUM_PREPROCESSED_COLS, @@ -406,6 +422,9 @@ impl VmAirs { branches, halt, commit, + keccak, + keccak_rnd, + keccak_rc, register, pages, memw_registers, @@ -632,6 +651,11 @@ pub fn prove_with_options_and_inputs( .filter(|c| c.is_private_input) .count(); + debug_assert_eq!( + traces.public_output_bytes, result.return_values.memory_values, + "public output diverged between executor view and trace reconstruction" + ); + Ok(VmProof { proof, runtime_page_ranges, @@ -690,11 +714,11 @@ pub fn verify_with_options( ); // Cross-check: table_counts must match the number of sub-proofs. - // Fixed tables (bitwise, decode, halt, commit, register) = 5, plus page tables. - let expected_proof_count = vm_proof.table_counts.total() + 5 + page_configs.len(); + // Fixed tables (bitwise, decode, halt, commit, keccak, keccak_rnd, keccak_rc, register) = 8, plus page tables. + let expected_proof_count = vm_proof.table_counts.total() + 8 + page_configs.len(); if expected_proof_count != vm_proof.proof.proofs.len() { return Err(Error::InvalidTableCounts(format!( - "table_counts total ({}) + 5 fixed + {} pages = {}, but proof contains {} sub-proofs", + "table_counts total ({}) + 8 fixed + {} pages = {}, but proof contains {} sub-proofs", vm_proof.table_counts.total(), page_configs.len(), expected_proof_count, diff --git a/prover/src/tables/cpu.rs b/prover/src/tables/cpu.rs index 70ae8c501..57f207d4d 100644 --- a/prover/src/tables/cpu.rs +++ b/prover/src/tables/cpu.rs @@ -306,6 +306,12 @@ pub struct CpuOperation { /// For Commit ECALLs: byte count from x12 pub commit_count: u64, + + /// Whether this ECALL is a KeccakPermute syscall + pub ecall_keccak: bool, + + /// For KeccakPermute ECALLs: state address from x10 + pub keccak_state_addr: u64, } impl CpuOperation { @@ -641,6 +647,9 @@ impl CpuOperation { } else { (0, 0) }; + let ecall_keccak = decode.op_ecall + && log.src1_val == executor::vm::instruction::execution::KECCAK_SYSCALL_NUMBER; + let keccak_state_addr = if ecall_keccak { log.src2_val } else { 0 }; // CM50: (1 - read_register2) * rv2[i] = 0. When read_register2=0, rv2 must be 0. // For example, ECALL has read_register2=0 (rs2 defaults to 0). The commit buf_addr is // carried separately in commit_buf_addr and does not go through rv2. @@ -663,6 +672,8 @@ impl CpuOperation { ecall_commit, commit_buf_addr, commit_count, + ecall_keccak, + keccak_state_addr, }; // Compute runtime-specific values based on instruction type @@ -2035,12 +2046,9 @@ pub fn bus_interactions() -> Vec { } } - // ECALL interaction (single shared bus for HALT and COMMIT) + // ECALL interaction (shared bus for HALT, COMMIT, and KECCAK) // ------------------------------------------------------------------------- - // Sends to both HALT and COMMIT tables. Each receiver pattern-matches on - // the syscall number in the payload. - // multiplicity = ECALL - // rv1 = value of a7 register (syscall number). + // multiplicity = ECALL (all ECALLs, each receiver matches on syscall number) interactions.push(BusInteraction::sender( BusId::Ecall, Multiplicity::Column(cols::ECALL), diff --git a/prover/src/tables/keccak.rs b/prover/src/tables/keccak.rs new file mode 100644 index 000000000..87e8dc122 --- /dev/null +++ b/prover/src/tables/keccak.rs @@ -0,0 +1,567 @@ +//! KECCAK core chip — handles ECALL, memory I/O, and delegation to the round chip. +//! +//! One row per keccak permutation call. Reads/writes 25 u64 lanes from/to memory, +//! sends input state to the round chip via the Keccak bus, and receives the output +//! state after 24 rounds. +//! +//! ## Column layout (~511 columns) +//! +//! | Group | Size | Description | +//! |----------------|------|------------------------------------------------| +//! | timestamp | 2 | DWordWL | +//! | addr | 8 | State address as DWordBL (8 bytes) | +//! | input_state | 200 | Input state bytes [5][5][8] | +//! | output_state | 200 | Output state bytes [5][5][8] | +//! | state_ptr | 100 | Per-lane DWordHL addresses [25][4] | +//! | mu | 1 | Multiplicity flag | + +use executor::vm::instruction::execution::KECCAK_SYSCALL_NUMBER; +use math::field::element::FieldElement; +use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; +use stark::table::TableView; +use stark::trace::TraceTable; + +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; +use crate::constraints::templates::{AddConstraint, AddOperand, INV_SHIFT_32}; + +// ========================================================================= +// Column indices +// ========================================================================= + +pub mod cols { + pub const TIMESTAMP_0: usize = 0; + pub const TIMESTAMP_1: usize = 1; + + // addr[8] — state address as 8 bytes (DWordBL) + pub const ADDR: usize = 2; + + // input_state[5][5][8] = 200 bytes + pub const INPUT_STATE: usize = ADDR + 8; // 10 + + // output_state[5][5][8] = 200 bytes + pub const OUTPUT_STATE: usize = INPUT_STATE + 200; // 210 + + // state_ptr[25][4] = 100 halfwords (DWordHL per lane) + pub const STATE_PTR: usize = OUTPUT_STATE + 200; // 410 + + pub const MU: usize = STATE_PTR + 100; // 510 + + pub const NUM_COLUMNS: usize = MU + 1; // 511 + + // ------------------------------------------------------------------------- + // Index helpers + // ------------------------------------------------------------------------- + + #[inline] + pub const fn addr(byte: usize) -> usize { + ADDR + byte + } + + /// Index into input_state[x][y][byte] + #[inline] + pub const fn input_state(x: usize, y: usize, byte: usize) -> usize { + INPUT_STATE + (x + 5 * y) * 8 + byte + } + + /// Index into output_state[x][y][byte] + #[inline] + pub const fn output_state(x: usize, y: usize, byte: usize) -> usize { + OUTPUT_STATE + (x + 5 * y) * 8 + byte + } + + /// Index into state_ptr[lane_idx][halfword] (DWordHL = 4 halfwords) + #[inline] + pub const fn state_ptr(lane_idx: usize, hw: usize) -> usize { + STATE_PTR + lane_idx * 4 + hw + } +} + +// ========================================================================= +// Operation struct +// ========================================================================= + +#[derive(Debug, Clone)] +pub struct KeccakOperation { + pub timestamp: u64, + pub state_addr: u64, + pub input: [u64; 25], + pub output: [u64; 25], +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +fn byte_of(val: u64, b: usize) -> u8 { + ((val >> (b * 8)) & 0xFF) as u8 +} + +pub fn generate_keccak_trace( + ops: &[KeccakOperation], +) -> TraceTable { + let n = ops.len(); + let num_rows = n.next_power_of_two().max(4); + let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; + + for (row_idx, op) in ops.iter().enumerate() { + let base = row_idx * cols::NUM_COLUMNS; + + // Timestamp + data[base + cols::TIMESTAMP_0] = FE::from(op.timestamp & 0xFFFF_FFFF); + data[base + cols::TIMESTAMP_1] = FE::from(op.timestamp >> 32); + + // Address as 8 bytes + for b in 0..8 { + data[base + cols::addr(b)] = FE::from(byte_of(op.state_addr, b) as u64); + } + + // Input state as bytes + for x in 0..5 { + for y in 0..5 { + let lane = op.input[x + 5 * y]; + for b in 0..8 { + data[base + cols::input_state(x, y, b)] = FE::from(byte_of(lane, b) as u64); + } + } + } + + // Output state as bytes + for x in 0..5 { + for y in 0..5 { + let lane = op.output[x + 5 * y]; + for b in 0..8 { + data[base + cols::output_state(x, y, b)] = FE::from(byte_of(lane, b) as u64); + } + } + } + + // State pointers: state_ptr[lane] = addr + 8 * lane_idx + for lane_idx in 0..25 { + let ptr = op + .state_addr + .checked_add(lane_idx as u64 * 8) + .expect("keccak state address range must be validated by the executor"); + data[base + cols::state_ptr(lane_idx, 0)] = FE::from(ptr & 0xFFFF); + data[base + cols::state_ptr(lane_idx, 1)] = FE::from((ptr >> 16) & 0xFFFF); + data[base + cols::state_ptr(lane_idx, 2)] = FE::from((ptr >> 32) & 0xFFFF); + data[base + cols::state_ptr(lane_idx, 3)] = FE::from((ptr >> 48) & 0xFFFF); + } + + // mu = 1 (real row) + data[base + cols::MU] = FE::one(); + } + + // Padding rows: state_ptr[lane][0] = 8 * lane_idx (per spec keccak.toml pad). + // Halfwords 1..3 stay zero since 8*24 = 192 fits in the low halfword. + // mu = 0 gates all bus interactions and the ADD constraint, so these values + // only need to satisfy the pad requirement, not reconstruct a real address. + for row_idx in n..num_rows { + let base = row_idx * cols::NUM_COLUMNS; + for lane_idx in 0..25 { + data[base + cols::state_ptr(lane_idx, 0)] = FE::from((lane_idx as u64) * 8); + } + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions +// ========================================================================= + +pub fn bus_interactions() -> Vec { + let syscall_lo = KECCAK_SYSCALL_NUMBER & 0xFFFF_FFFF; + let syscall_hi = KECCAK_SYSCALL_NUMBER >> 32; + let mut interactions = Vec::with_capacity(160); + + // 1. ECALL receiver (shared bus, per spec keccak:c:output) + // Payload: [ts_lo, ts_hi, syscall_lo32, syscall_hi32] in DWordWL [lo, hi] + // ordering, matching the CPU ECALL sender shared with HALT/COMMIT. + interactions.push(BusInteraction::receiver( + BusId::Ecall, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::constant(syscall_lo), + BusValue::constant(syscall_hi), + ], + )); + + // 2. MEMW read_addr: read register x10 to bind addr (per spec keccak:c:read_addr) + // Format: [old[8], is_register=1, base_addr=[20,0], value[8], ts, ts_hi, write2=1, write4=0, write8=0] + // For register read: old = value = addr as WL + 6 zeros + { + // addr as DWordWL from DWordBL bytes: lo32 = sum(addr[0..4] * 256^i), hi32 = sum(addr[4..8] * 256^i) + let addr_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::addr(0), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::addr(1), + }, + LinearTerm::Column { + coefficient: 65536, + column: cols::addr(2), + }, + LinearTerm::Column { + coefficient: 16777216, + column: cols::addr(3), + }, + ]); + let addr_hi = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::addr(4), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::addr(5), + }, + LinearTerm::Column { + coefficient: 65536, + column: cols::addr(6), + }, + LinearTerm::Column { + coefficient: 16777216, + column: cols::addr(7), + }, + ]); + let mut values = Vec::with_capacity(24); + // old[0..7] = addr as WL + 6 zeros + values.push(addr_lo.clone()); + values.push(addr_hi.clone()); + for _ in 2..8 { + values.push(BusValue::constant(0)); + } + // is_register = 1 + values.push(BusValue::constant(1)); + // base_address = 2*10 = 20 (register x10) + values.push(BusValue::constant(20)); + values.push(BusValue::constant(0)); + // value[0..7] = same as old (read) + values.push(addr_lo); + values.push(addr_hi); + for _ in 2..8 { + values.push(BusValue::constant(0)); + } + // timestamp + values.push(BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }); + values.push(BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }); + // write2=1, write4=0, write8=0 (register access) + values.push(BusValue::constant(1)); + values.push(BusValue::constant(0)); + values.push(BusValue::constant(0)); + interactions.push(BusInteraction::sender( + BusId::Memw, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 2. Keccak bus: send (timestamp, 0, input_state[200]) + // Per spec keccak.toml: input = ["timestamp", 0, "input_state"] where + // input_state is [[[Byte, 8], 5], 5] — 200 Byte elements, each its own + // bus element (no packing). + { + let mut values = vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::constant(0), // round = 0 + ]; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::input_state(x, y, b), + packing: Packing::Direct, + }); + } + } + } + interactions.push(BusInteraction::sender( + BusId::Keccak, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 3. Keccak bus: receive (timestamp, 24, output_state[200]) + { + let mut values = vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::constant(24), // round = 24 + ]; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::output_state(x, y, b), + packing: Packing::Direct, + }); + } + } + } + interactions.push(BusInteraction::receiver( + BusId::Keccak, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 4. IS_HALF range checks on state_ptr (100 interactions) + for lane_idx in 0..25 { + for hw in 0..4 { + interactions.push(BusInteraction::sender( + BusId::IsHalfword, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::state_ptr(lane_idx, hw), + packing: Packing::Direct, + }], + )); + } + } + + // 5. Alignment: addr[0] & 7 = 0, which enforces addr % 8 == 0. + interactions.push(BusInteraction::sender( + BusId::AndByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::addr(0), + packing: Packing::Direct, + }, + BusValue::constant(7), + BusValue::constant(0), + ], + )); + + // 6. Range-check every addr byte. The addr columns are reconstructed as a + // linear combination (addr_lo = b0 + 256*b1 + 65536*b2 + 2^24*b3, etc.) + // for the MEMW lookup and the no-overflow / alignment constraints. Without + // an explicit byte range check on each cell, an attacker can keep the + // field-element value of that linear combination correct while encoding + // arbitrary non-byte values in the individual cells (e.g. addr[0]=0, + // addr[1]=V_lo * 256^{-1} mod p), bypassing the alignment check. + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::IsByte, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::addr(b), + packing: Packing::Direct, + }], + )); + } + + // 7. MEMW interactions: 25 combined read+write per lane (per spec) + // Format: [old[8], is_register, addr_lo32, addr_hi32, value[8], ts[2], w2, w4, w8] = 24 + // old = input_state (read), value = output_state (write) + for lane_idx in 0..25 { + let x = lane_idx % 5; + let y = lane_idx / 5; + + // Address as DWordWL: lo32 = h0 + 2^16*h1, hi32 = h2 + 2^16*h3 + let addr_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::state_ptr(lane_idx, 0), + }, + LinearTerm::Column { + coefficient: 65536, + column: cols::state_ptr(lane_idx, 1), + }, + ]); + let addr_hi = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::state_ptr(lane_idx, 2), + }, + LinearTerm::Column { + coefficient: 65536, + column: cols::state_ptr(lane_idx, 3), + }, + ]); + + let mut values = Vec::with_capacity(24); + // old[0..8] = input_state bytes (the value being read) + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::input_state(x, y, b), + packing: Packing::Direct, + }); + } + // is_register = 0 + values.push(BusValue::constant(0)); + // address as DWordWL + values.push(addr_lo); + values.push(addr_hi); + // value[0..8] = output_state bytes (the value being written) + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::output_state(x, y, b), + packing: Packing::Direct, + }); + } + // timestamp + values.push(BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }); + values.push(BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }); + // write2=0, write4=0, write8=1 + values.push(BusValue::constant(0)); + values.push(BusValue::constant(0)); + values.push(BusValue::constant(1)); + + interactions.push(BusInteraction::sender( + BusId::Memw, + Multiplicity::Column(cols::MU), + values, + )); + } + + interactions +} + +// ========================================================================= +// Constraints +// ========================================================================= + +struct KeccakAddressNoOverflowConstraint { + constraint_idx: usize, +} + +impl KeccakAddressNoOverflowConstraint { + fn new(constraint_idx: usize) -> Self { + Self { constraint_idx } + } + + fn compute(&self, step: &TableView) -> FieldElement + where + F: IsSubFieldOf, + E: IsField, + { + let addr_lo = step.get_main_evaluation_element(0, cols::addr(0)).clone() + + step.get_main_evaluation_element(0, cols::addr(1)) * FieldElement::::from(256) + + step.get_main_evaluation_element(0, cols::addr(2)) * FieldElement::::from(65536) + + step.get_main_evaluation_element(0, cols::addr(3)) + * FieldElement::::from(16777216); + let addr_hi = step.get_main_evaluation_element(0, cols::addr(4)).clone() + + step.get_main_evaluation_element(0, cols::addr(5)) * FieldElement::::from(256) + + step.get_main_evaluation_element(0, cols::addr(6)) * FieldElement::::from(65536) + + step.get_main_evaluation_element(0, cols::addr(7)) + * FieldElement::::from(16777216); + + let ptr_lo = step + .get_main_evaluation_element(0, cols::state_ptr(24, 0)) + .clone() + + step.get_main_evaluation_element(0, cols::state_ptr(24, 1)) + * FieldElement::::from(65536); + let ptr_hi = step + .get_main_evaluation_element(0, cols::state_ptr(24, 2)) + .clone() + + step.get_main_evaluation_element(0, cols::state_ptr(24, 3)) + * FieldElement::::from(65536); + + let inv_2_32 = FieldElement::::from(INV_SHIFT_32); + let carry_0 = (addr_lo + FieldElement::::from(192) - ptr_lo) * inv_2_32.clone(); + let carry_1 = (addr_hi + carry_0 - ptr_hi) * inv_2_32; + step.get_main_evaluation_element(0, cols::MU).clone() * carry_1 + } +} + +impl TransitionConstraint + for KeccakAddressNoOverflowConstraint +{ + fn degree(&self) -> usize { + 2 + } + + fn constraint_idx(&self) -> usize { + self.constraint_idx + } + + fn evaluate(&self, step: &TableView) -> FieldElement + where + F: IsSubFieldOf, + E: IsField, + { + self.compute(step) + } +} + +/// Create constraints for the KECCAK core chip. +/// +/// Per spec (keccak:c:state_ptr): ADD template for each lane: +/// state_ptr[lane] = addr + 8 * lane_idx +/// +/// 25 lane pointers × 2 constraints per ADD + 1 top-lane no-overflow +/// constraint = 51 constraints total. +/// Conditional on mu (only real rows). +pub fn create_constraints( + constraint_idx_start: usize, +) -> ( + Vec>>, + usize, +) { + let mut constraints: Vec< + Box>, + > = Vec::with_capacity(51); + let mut idx = constraint_idx_start; + + // state_ptr[lane] = addr + 8*lane_idx + // addr is DWordBL (8 bytes), state_ptr is DWordHL (4 halfwords) + // ADD: lhs = addr (DWordBL→DWordWL), rhs = 8*lane_idx (constant), sum = state_ptr (DWordHL→DWordWL) + for lane_idx in 0..25 { + let offset = (lane_idx * 8) as i64; + let (c0, c1) = AddConstraint::new_pair( + vec![cols::MU], // conditional on mu + AddOperand::from_dword_bl(cols::ADDR), + AddOperand::constant(offset), + AddOperand::from_dword_hl(cols::state_ptr(lane_idx, 0)), + idx, + ); + constraints.push(c0.boxed()); + constraints.push(c1.boxed()); + idx += 2; + } + + constraints.push(KeccakAddressNoOverflowConstraint::new(idx).boxed()); + idx += 1; + + (constraints, idx) +} diff --git a/prover/src/tables/keccak_rc.rs b/prover/src/tables/keccak_rc.rs new file mode 100644 index 000000000..c2e14d643 --- /dev/null +++ b/prover/src/tables/keccak_rc.rs @@ -0,0 +1,190 @@ +//! KECCAK_RC: Precomputed round constant lookup table for Keccak-f[1600]. +//! +//! 24 rows (one per round), padded to 32. Each row maps a round index to its +//! 8-byte round constant. The round chip looks up `(round) → rc[8]` via the +//! `KeccakRc` bus. +//! +//! Follows the BITWISE preprocessed-table pattern: precomputed columns are +//! committed once and cached via `OnceLock`. + +use std::sync::OnceLock; + +use math::fft::cpu::bit_reversing::in_place_bit_reverse_permute; +use math::field::element::FieldElement; +use math::polynomial::Polynomial; +use stark::config::{BatchedMerkleTree, Commitment}; +use stark::lookup::{BusInteraction, BusValue, Multiplicity, Packing}; +use stark::proof::options::ProofOptions; +use stark::prover::evaluate_polynomial_on_lde_domain; +use stark::trace::{TraceTable, columns2rows}; + +use executor::vm::instruction::execution::KECCAK_RC; + +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; + +// ========================================================================= +// Column indices +// ========================================================================= + +pub mod cols { + /// Round index (0..23) + pub const ROUND: usize = 0; + /// RC bytes [0..7] — 8 bytes of the round constant (little-endian) + pub const RC: usize = 1; + pub const RC_END: usize = RC + 8; // = 9 + /// Multiplicity (how many times this row is looked up) + pub const MU: usize = 9; + + pub const NUM_COLUMNS: usize = 10; +} + +/// Number of precomputed columns (everything except MU). +pub const NUM_PRECOMPUTED_COLS: usize = 9; + +/// Number of real rows (one per keccak round). +pub const NUM_REAL_ROWS: usize = 24; + +/// Number of rows in the trace (padded to next power of 2). +pub const NUM_ROWS: usize = 32; + +/// Whether this table is preprocessed. +pub const fn is_preprocessed() -> bool { + true +} + +/// Generate one precomputed row: [round, rc_byte0, ..., rc_byte7]. +pub const fn generate_row(round: usize) -> [u64; NUM_PRECOMPUTED_COLS] { + let rc_val = if round < 24 { KECCAK_RC[round] } else { 0 }; + [ + round as u64, + rc_val & 0xFF, + (rc_val >> 8) & 0xFF, + (rc_val >> 16) & 0xFF, + (rc_val >> 24) & 0xFF, + (rc_val >> 32) & 0xFF, + (rc_val >> 40) & 0xFF, + (rc_val >> 48) & 0xFF, + (rc_val >> 56) & 0xFF, + ] +} + +// ========================================================================= +// Preprocessed commitment +// ========================================================================= + +static KECCAK_RC_COMMITMENT: OnceLock = OnceLock::new(); + +fn compute_preprocessed_commitment(options: &ProofOptions) -> Commitment { + // Generate precomputed columns + let mut columns: Vec> = (0..NUM_PRECOMPUTED_COLS) + .map(|_| Vec::with_capacity(NUM_ROWS)) + .collect(); + for idx in 0..NUM_ROWS { + let row = generate_row(idx); + for (col_idx, &value) in row.iter().enumerate() { + columns[col_idx].push(FE::from(value)); + } + } + + // Interpolate each column to a polynomial + let polys: Vec> = columns + .iter() + .map(|col| { + Polynomial::interpolate_fft::(col) + .expect("FFT interpolation failed for keccak_rc column") + }) + .collect(); + + // Evaluate on LDE domain + let blowup_factor = options.blowup_factor as usize; + let coset_offset = FE::from(options.coset_offset); + let mut lde_columns: Vec> = polys + .iter() + .map(|poly| { + evaluate_polynomial_on_lde_domain(poly, blowup_factor, NUM_ROWS, &coset_offset) + .expect("LDE evaluation failed for keccak_rc polynomial") + }) + .collect(); + + // Bit-reverse permute + for col in lde_columns.iter_mut() { + in_place_bit_reverse_permute(col); + } + + // Build Merkle tree + let lde_rows = columns2rows(lde_columns); + let tree = BatchedMerkleTree::::build(&lde_rows) + .expect("Failed to build Merkle tree for keccak_rc LDE"); + + tree.root +} + +#[inline] +pub fn preprocessed_commitment(options: &ProofOptions) -> Commitment { + *KECCAK_RC_COMMITMENT.get_or_init(|| compute_preprocessed_commitment(options)) +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +/// Generate the KECCAK_RC trace table. +/// +/// All precomputed columns are filled; MU is initialized to zero and must be +/// updated via `update_multiplicities` after all round-chip lookups are known. +pub fn generate_keccak_rc_trace() -> TraceTable { + let mut data = vec![FE::zero(); NUM_ROWS * cols::NUM_COLUMNS]; + + for idx in 0..NUM_ROWS { + let base = idx * cols::NUM_COLUMNS; + let row = generate_row(idx); + for (col_idx, &value) in row.iter().enumerate() { + data[base + col_idx] = FE::from(value); + } + // MU = 0 (will be updated later) + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +/// Increment MU for each round lookup. +/// +/// Called after the round chip's trace is generated. Each keccak permutation +/// call produces 24 round lookups (one per round), so each round row's MU +/// equals the number of keccak operations. +pub fn update_multiplicities( + trace: &mut TraceTable, + num_keccak_ops: usize, +) { + let mu = FieldElement::from(num_keccak_ops as u64); + for round in 0..NUM_REAL_ROWS { + let base = round * cols::NUM_COLUMNS; + trace.main_table.data[base + cols::MU] = mu; + } +} + +// ========================================================================= +// Bus interactions +// ========================================================================= + +/// Single receiver on the KeccakRc bus. +/// +/// Format: [round(Direct), rc[0](Direct), ..., rc[7](Direct)] +pub fn bus_interactions() -> Vec { + let mut values = vec![BusValue::Packed { + start_column: cols::ROUND, + packing: Packing::Direct, + }]; + for i in 0..8 { + values.push(BusValue::Packed { + start_column: cols::RC + i, + packing: Packing::Direct, + }); + } + + vec![BusInteraction::receiver( + BusId::KeccakRc, + Multiplicity::Column(cols::MU), + values, + )] +} diff --git a/prover/src/tables/keccak_rnd.rs b/prover/src/tables/keccak_rnd.rs new file mode 100644 index 000000000..277281583 --- /dev/null +++ b/prover/src/tables/keccak_rnd.rs @@ -0,0 +1,986 @@ +//! KECCAK_RND: Round chip for Keccak-f[1600] permutation. +//! +//! One row per round (24 rows per keccak call). All bitwise operations are +//! delegated to BITWISE lookup tables (XOR_BYTE, AND_BYTE, HWSL, IS_BYTE). +//! +//! ## Column layout (1,480 columns) +//! +//! | Group | Size | Description | +//! |----------------|------|---------------------------------------------------| +//! | timestamp | 2 | DWordWL | +//! | round | 1 | Round index (0..23) | +//! | start | 200 | Input state bytes [5][5][8] | +//! | Cxz | 160 | Column parity chain [5][4][8] | +//! | Cxz_left | 40 | Left component of rotated C [5][8] | +//! | Cxz_right | 20 | Carry bits of HWSL(C[x],1) [5][4] | +//! | Dxz | 40 | D values [5][8] | +//! | theta | 200 | State after θ [5][5][8] | +//! | rot_left | 200 | Left half of ρ rotation [5][5][8] | +//! | rot_right | 200 | Right half of ρ rotation [5][5][8] | +//! | chi_ands | 200 | AND results for χ [5][5][8] | +//! | chi | 200 | State after χ [5][5][8] | +//! | rc | 8 | Round constant bytes | +//! | iota | 8 | χ[0][0] ⊕ rc | +//! | mu | 1 | Multiplicity (1 for real, 0 for padding) | +//! +//! Note: spec [[variables.constant]] `rnc` and `rbc` are inlined as compile-time +//! constants derived from `KECCAK_RHO[x][y]`, not materialized as columns. +//! `Cxz_right` is typed `[Bit, 4]` per spec d75944ee — HWSL with shift=1 +//! produces a single-bit carry, range-checked via IS_BIT polynomial constraints. + +use executor::vm::instruction::execution::{KECCAK_RC, KECCAK_RHO}; +use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; +use stark::trace::TraceTable; + +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; + +// ========================================================================= +// Column indices +// ========================================================================= + +pub mod cols { + pub const TIMESTAMP_0: usize = 0; + pub const TIMESTAMP_1: usize = 1; + pub const ROUND: usize = 2; + + // start[5][5][8] = 200 bytes — input state for this round + pub const START: usize = 3; + + // Cxz[5][4][8] = 160 bytes — partial XOR chain for column parities + pub const CXZ: usize = START + 200; // 203 + + // Cxz_left[5][8] = 40 bytes — left shift component of rotated C + pub const CXZ_LEFT: usize = CXZ + 160; // 363 + + // Cxz_right[5][4] = 20 bits — carry bit of HWSL(C[x] halfword[hw], 1). + // For shift=1, HWSL emits a single-bit carry; one column per halfword. + pub const CXZ_RIGHT: usize = CXZ_LEFT + 40; // 403 + + // Dxz[5][8] = 40 bytes + pub const DXZ: usize = CXZ_RIGHT + 20; // 423 + + // theta[5][5][8] = 200 bytes — state after θ + pub const THETA: usize = DXZ + 40; // 463 + + // rot_left[5][5][8] = 200 bytes + pub const ROT_LEFT: usize = THETA + 200; // 663 + + // rot_right[5][5][8] = 200 bytes + pub const ROT_RIGHT: usize = ROT_LEFT + 200; // 863 + + // chi_ands[5][5][8] = 200 bytes + // (pi is a spec [[variables.virtual]] — inlined as rot_left + rot_right at + // compile-resolved offsets, not materialized as columns.) + pub const CHI_ANDS: usize = ROT_RIGHT + 200; // 1063 + + // chi[5][5][8] = 200 bytes — state after χ + pub const CHI: usize = CHI_ANDS + 200; // 1263 + + // rc[8] — round constant bytes + pub const RC: usize = CHI + 200; // 1463 + + // iota[8] — χ[0][0] ⊕ rc + pub const IOTA: usize = RC + 8; // 1471 + + // mu — multiplicity flag. + // rnc and rbc (spec [[variables.constant]]) are inlined as compile-time + // constants from KECCAK_RHO, not allocated as columns. + pub const MU: usize = IOTA + 8; // 1479 + + pub const NUM_COLUMNS: usize = MU + 1; // 1480 + + // ------------------------------------------------------------------------- + // Index helpers + // ------------------------------------------------------------------------- + + /// Index into start[x][y][byte] (200 bytes, row-major: y varies fastest) + #[inline] + pub const fn start(x: usize, y: usize, byte: usize) -> usize { + START + (x + 5 * y) * 8 + byte + } + + /// Index into Cxz[x][stage][byte] (160 bytes) + #[inline] + pub const fn cxz(x: usize, stage: usize, byte: usize) -> usize { + CXZ + (x * 4 + stage) * 8 + byte + } + + /// Index into Cxz_left[x][byte] + #[inline] + pub const fn cxz_left(x: usize, byte: usize) -> usize { + CXZ_LEFT + x * 8 + byte + } + + /// Index into Cxz_right[x][hw] — single-bit carry for halfword `hw` of x. + #[inline] + pub const fn cxz_right_bit(x: usize, hw: usize) -> usize { + CXZ_RIGHT + x * 4 + hw + } + + /// For byte `b` of the rotated_Cxz output, return Some(hw) if a Cxz_right + /// bit contributes (even b), else None (odd b → only Cxz_left contributes). + /// Spec d75944ee/9143370f: rotated_Cxz[z] = Cxz_left[z] + (1 - z%2) * + /// Cxz_right[(z/2 - 1) mod 4]. + #[inline] + pub const fn cxz_right_bit_for_byte(b: usize) -> Option { + if b.is_multiple_of(2) { + Some((b / 2 + 3) % 4) + } else { + None + } + } + + /// Index into Dxz[x][byte] + #[inline] + pub const fn dxz(x: usize, byte: usize) -> usize { + DXZ + x * 8 + byte + } + + /// Index into theta[x][y][byte] + #[inline] + pub const fn theta(x: usize, y: usize, byte: usize) -> usize { + THETA + (x + 5 * y) * 8 + byte + } + + /// Index into rot_left[x][y][byte] + #[inline] + pub const fn rot_left(x: usize, y: usize, byte: usize) -> usize { + ROT_LEFT + (x + 5 * y) * 8 + byte + } + + /// Index into rot_right[x][y][byte] + #[inline] + pub const fn rot_right(x: usize, y: usize, byte: usize) -> usize { + ROT_RIGHT + (x + 5 * y) * 8 + byte + } + + /// Resolve pi[x][y][z] (spec virtual) to the (rot_left_col, rot_right_col) + /// pair whose sum equals pi[x][y][z]. rbc is compile-time constant. + #[inline] + pub fn pi_src_cols(x: usize, y: usize, z: usize) -> (usize, usize) { + use executor::vm::instruction::execution::KECCAK_RHO; + let sx = (x + 3 * y) % 5; + let sy = x; + let rho_offset = KECCAK_RHO[sx][sy] as usize; + let rbc_val = rho_offset / 16; + let (l_byte, r_byte) = match rbc_val { + 0 => (z, (z + 6) % 8), + 1 => ((z + 6) % 8, (z + 4) % 8), + 2 => ((z + 4) % 8, (z + 2) % 8), + 3 => ((z + 2) % 8, z), + _ => unreachable!(), + }; + (rot_left(sx, sy, l_byte), rot_right(sx, sy, r_byte)) + } + + /// Index into chi_ands[x][y][byte] + #[inline] + pub const fn chi_ands(x: usize, y: usize, byte: usize) -> usize { + CHI_ANDS + (x + 5 * y) * 8 + byte + } + + /// Index into chi[x][y][byte] + #[inline] + pub const fn chi(x: usize, y: usize, byte: usize) -> usize { + CHI + (x + 5 * y) * 8 + byte + } + + /// Index into rc[byte] + #[inline] + pub const fn rc(byte: usize) -> usize { + RC + byte + } + + /// Index into iota[byte] + #[inline] + pub const fn iota(byte: usize) -> usize { + IOTA + byte + } +} + +// ========================================================================= +// Operation struct +// ========================================================================= + +/// One keccak permutation call's worth of data (produces 24 rows). +#[derive(Debug, Clone)] +pub struct KeccakRoundOperation { + pub timestamp: u64, + pub input: [u64; 25], + pub output: [u64; 25], +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +/// Extract byte `b` (0..8) from a u64 value. +#[inline] +fn byte_of(val: u64, b: usize) -> u8 { + ((val >> (b * 8)) & 0xFF) as u8 +} + +/// Compute halfword shift left: (value << shift) mod 2^16 and value >> (16 - shift). +#[inline] +fn hwsl(halfword: u16, shift: u8) -> (u16, u16) { + if shift == 0 { + (halfword, 0) + } else { + ( + halfword << shift, // u16 naturally wraps at 16 bits + halfword >> (16 - shift), + ) + } +} + +#[allow(clippy::needless_range_loop)] +/// Generate the KECCAK_RND trace table. +/// +/// Each `KeccakRoundOperation` produces 24 rows (one per round). The trace +/// computes all intermediate values (θ, ρ, π, χ, ι) at byte granularity. +pub fn generate_keccak_rnd_trace( + ops: &[KeccakRoundOperation], +) -> TraceTable { + let n_rows = (ops.len() * 24).next_power_of_two().max(4); + let mut data = vec![FE::zero(); n_rows * cols::NUM_COLUMNS]; + + for (op_idx, op) in ops.iter().enumerate() { + // Execute round-by-round, tracking the state + let mut state = op.input; + + for round in 0..24 { + let row_idx = op_idx * 24 + round; + let base = row_idx * cols::NUM_COLUMNS; + + // Timestamp & round + data[base + cols::TIMESTAMP_0] = FE::from(op.timestamp & 0xFFFF_FFFF); + data[base + cols::TIMESTAMP_1] = FE::from(op.timestamp >> 32); + data[base + cols::ROUND] = FE::from(round as u64); + + // start = current state as bytes + for x in 0..5 { + for y in 0..5 { + let lane = state[x + 5 * y]; + for b in 0..8 { + data[base + cols::start(x, y, b)] = FE::from(byte_of(lane, b) as u64); + } + } + } + + // === θ (theta) === + // Column parities: C[x] = XOR of all 5 lanes in column x + // Computed as a chain: Cxz[x][0] = start[x,0] XOR start[x,1] + // Cxz[x][k] = Cxz[x][k-1] XOR start[x,k+1] + let mut c_bytes = [[0u8; 8]; 5]; // C[x][byte] = final parity + let mut cxz = [[[0u8; 8]; 4]; 5]; // Cxz[x][stage][byte] + for x in 0..5 { + // Stage 0: XOR(start[x,0], start[x,1]) + for b in 0..8 { + let v0 = byte_of(state[x], b); + let v1 = byte_of(state[x + 5], b); + cxz[x][0][b] = v0 ^ v1; + data[base + cols::cxz(x, 0, b)] = FE::from(cxz[x][0][b] as u64); + } + // Stages 1..3: XOR(Cxz[x][k-1], start[x, k+1]) + for stage in 1..4 { + let y = stage + 1; + for b in 0..8 { + let prev = cxz[x][stage - 1][b]; + let sv = byte_of(state[x + 5 * y], b); + cxz[x][stage][b] = prev ^ sv; + data[base + cols::cxz(x, stage, b)] = FE::from(cxz[x][stage][b] as u64); + } + } + c_bytes[x] = cxz[x][3]; + } + + // Rotate C left by 1 bit using HWSL decomposition. + // HWSL shifts each halfword (u16) independently. For shift=1, the + // carry is a single bit (top bit of the halfword); we store it in + // one column per halfword (Cxz_right[x][hw], spec d75944ee). + // rotated_Cxz[z] = Cxz_left[z] + (1 - z%2) * Cxz_right[(z/2 - 1) mod 4] + let mut cxz_left_bytes = [[0u8; 8]; 5]; + let mut cxz_right_bits = [[0u8; 4]; 5]; + let mut rotated_c = [[0u8; 8]; 5]; + for x in 0..5 { + for hw in 0..4 { + let lo = c_bytes[x][hw * 2] as u16; + let hi = c_bytes[x][hw * 2 + 1] as u16; + let halfword = lo | (hi << 8); + let (shifted, carry) = hwsl(halfword, 1); + cxz_left_bytes[x][hw * 2] = (shifted & 0xFF) as u8; + cxz_left_bytes[x][hw * 2 + 1] = (shifted >> 8) as u8; + // For shift=1, carry ∈ {0, 1}. + cxz_right_bits[x][hw] = carry as u8; + data[base + cols::cxz_left(x, hw * 2)] = + FE::from(cxz_left_bytes[x][hw * 2] as u64); + data[base + cols::cxz_left(x, hw * 2 + 1)] = + FE::from(cxz_left_bytes[x][hw * 2 + 1] as u64); + data[base + cols::cxz_right_bit(x, hw)] = + FE::from(cxz_right_bits[x][hw] as u64); + } + // Reconstruct: left[b] + (1 - b%2) * right[(b/2 + 3) mod 4] + for b in 0..8 { + let right_contribution = match cols::cxz_right_bit_for_byte(b) { + Some(hw) => cxz_right_bits[x][hw], + None => 0, + }; + rotated_c[x][b] = cxz_left_bytes[x][b].wrapping_add(right_contribution); + } + } + + // D[x] = C[(x-1)%5] XOR rotated_C[(x+1)%5] + let mut d_bytes = [[0u8; 8]; 5]; + for x in 0..5 { + for b in 0..8 { + let val = c_bytes[(x + 4) % 5][b] ^ rotated_c[(x + 1) % 5][b]; + d_bytes[x][b] = val; + data[base + cols::dxz(x, b)] = FE::from(val as u64); + } + } + + // theta[x][y] = start[x][y] XOR D[x] + let mut theta_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let lane = state[x + 5 * y]; + let mut d_lane = 0u64; + for b in 0..8 { + d_lane |= (d_bytes[x][b] as u64) << (b * 8); + } + theta_lanes[x + 5 * y] = lane ^ d_lane; + for b in 0..8 { + data[base + cols::theta(x, y, b)] = + FE::from(byte_of(theta_lanes[x + 5 * y], b) as u64); + } + } + } + + // === ρ (rho) === + // For each lane, rotate theta[x][y] by KECCAK_RHO[x][y] bits. + // Decompose rotation as: rnc (nibble, 0..15) + 16*rbc[0] + 32*rbc[1]. + // rnc and rbc are inlined as compile-time constants per spec + // [[variables.constant]]; only HWSL outputs are stored in the trace. + for x in 0..5 { + for y in 0..5 { + let rho_offset = KECCAK_RHO[x][y] as usize; + let rnc_val = (rho_offset % 16) as u8; + let theta_lane = theta_lanes[x + 5 * y]; + for hw in 0..4 { + let halfword = ((theta_lane >> (hw * 16)) & 0xFFFF) as u16; + let (shifted, carry) = hwsl(halfword, rnc_val); + data[base + cols::rot_left(x, y, hw * 2)] = + FE::from((shifted & 0xFF) as u64); + data[base + cols::rot_left(x, y, hw * 2 + 1)] = + FE::from((shifted >> 8) as u64); + data[base + cols::rot_right(x, y, hw * 2)] = + FE::from((carry & 0xFF) as u64); + data[base + cols::rot_right(x, y, hw * 2 + 1)] = + FE::from((carry >> 8) as u64); + } + } + } + + // === π (pi) === + // pi[x][y] = rho[(x+3y)%5][x] where rho is the rotated theta. + // pi is a spec [[variables.virtual]] — not stored as trace columns. + // It's reconstructed inline in chi bus interactions as + // pi[x][y][z] = rot_left[sx,sy,l_byte] + rot_right[sx,sy,r_byte] + // with (sx, sy) = ((x+3y)%5, x) and (l_byte, r_byte) resolved from + // the compile-time rbc constant. pi_lanes is still computed here + // for the chi step below. + let mut pi_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let rotated = theta_lanes[x + 5 * y].rotate_left(KECCAK_RHO[x][y]); + let dst_x = y; + let dst_y = (2 * x + 3 * y) % 5; + pi_lanes[dst_x + 5 * dst_y] = rotated; + } + } + + // === χ (chi) === + let mut chi_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let not_next = !pi_lanes[(x + 1) % 5 + 5 * y]; + let next2 = pi_lanes[(x + 2) % 5 + 5 * y]; + let and_val = not_next & next2; + chi_lanes[x + 5 * y] = pi_lanes[x + 5 * y] ^ and_val; + for b in 0..8 { + data[base + cols::chi_ands(x, y, b)] = FE::from(byte_of(and_val, b) as u64); + data[base + cols::chi(x, y, b)] = + FE::from(byte_of(chi_lanes[x + 5 * y], b) as u64); + } + } + } + + // === ι (iota) === + let rc_val = KECCAK_RC[round]; + for b in 0..8 { + data[base + cols::rc(b)] = FE::from(byte_of(rc_val, b) as u64); + let iota_byte = byte_of(chi_lanes[0], b) ^ byte_of(rc_val, b); + data[base + cols::iota(b)] = FE::from(iota_byte as u64); + } + + // Update state for next round + chi_lanes[0] ^= rc_val; + state = chi_lanes; + + // mu = 1 (real row) + data[base + cols::MU] = FE::one(); + } + } + + // Padding rows have mu=0 and all zeros (default) + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions (1,371 total) +// ========================================================================= + +#[allow(clippy::needless_range_loop)] +pub fn bus_interactions() -> Vec { + let mut interactions = Vec::with_capacity(1371); + + // --- IO group (3) --- + + // 1. KECCAK bus: receive (timestamp, round, start[200]) + // Per spec keccak_round.toml: input = ["timestamp", "round", "start"] where + // start is [[[Byte, 8], 5], 5] — 200 Byte elements, each its own bus element. + { + let mut values = vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::ROUND, + packing: Packing::Direct, + }, + ]; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::start(x, y, b), + packing: Packing::Direct, + }); + } + } + } + interactions.push(BusInteraction::receiver( + BusId::Keccak, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 2. KECCAK bus: send (timestamp, round+1, out[200]) + // out[0][0] = iota, out[x][y] = chi for (x,y) != (0,0) + { + let mut values = vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::ROUND, + }, + LinearTerm::Constant(1), + ]), + ]; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let col = if x == 0 && y == 0 { + cols::IOTA + b + } else { + cols::chi(x, y, b) + }; + values.push(BusValue::Packed { + start_column: col, + packing: Packing::Direct, + }); + } + } + } + interactions.push(BusInteraction::sender( + BusId::Keccak, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 3. KECCAK_RC: lookup (round) → rc[8] + { + let mut values = vec![BusValue::Packed { + start_column: cols::ROUND, + packing: Packing::Direct, + }]; + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::rc(b), + packing: Packing::Direct, + }); + } + interactions.push(BusInteraction::sender( + BusId::KeccakRc, + Multiplicity::Column(cols::MU), + values, + )); + } + + // --- Theta: Cxz chain XOR_BYTE (160) --- + // Stage 0: XOR(start[x,0,z], start[x,1,z]) → Cxz[x,0,z] + for x in 0..5 { + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::start(x, 0, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::start(x, 1, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::cxz(x, 0, b), + packing: Packing::Direct, + }, + ], + )); + } + } + // Stages 1..3: XOR(Cxz[x,stage-1,z], start[x,stage+1,z]) → Cxz[x,stage,z] + for x in 0..5 { + for stage in 1..4usize { + let y = stage + 1; + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::cxz(x, stage - 1, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::start(x, y, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::cxz(x, stage, b), + packing: Packing::Direct, + }, + ], + )); + } + } + } + + // --- Theta: HWSL for rotated C (20) --- + // HWSL(C[x] halfword[hw], 1) → (Cxz_left, Cxz_right) + // Cxz_right is a single carry bit zero-extended to a halfword (spec d75944ee). + for x in 0..5 { + for hw in 0..4 { + interactions.push(BusInteraction::sender( + BusId::Hwsl, + Multiplicity::Column(cols::MU), + vec![ + // Input halfword: Cxz[x][3][hw*2] + 256 * Cxz[x][3][hw*2+1] + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::cxz(x, 3, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::cxz(x, 3, hw * 2 + 1), + }, + ]), + // Shift amount = 1 + BusValue::constant(1), + // Output: shifted + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::cxz_left(x, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::cxz_left(x, hw * 2 + 1), + }, + ]), + // Output: carry (single bit cast to Half — high byte = 0). + BusValue::Packed { + start_column: cols::cxz_right_bit(x, hw), + packing: Packing::Direct, + }, + ], + )); + } + } + + // --- Theta: IS_BYTE range checks on Cxz_left (40) --- + // Cxz_right uses IS_BIT polynomial constraints (see create_constraints). + for x in 0..5 { + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::IsByte, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::cxz_left(x, b), + packing: Packing::Direct, + }], + )); + } + } + + // --- Theta: Dxz XOR_BYTE (40) --- + // D[x][b] = C[(x-1)%5][b] XOR rotated_C[(x+1)%5][b] + // rotated_C[x'][b] = Cxz_left[x'][b] + (1 - b%2) * Cxz_right[x'][(b/2 - 1)%4] + // (spec d75944ee/9143370f). For odd b only Cxz_left contributes. + for x in 0..5 { + for b in 0..8 { + let mut rotated_c_terms = vec![LinearTerm::Column { + coefficient: 1, + column: cols::cxz_left((x + 1) % 5, b), + }]; + if let Some(hw) = cols::cxz_right_bit_for_byte(b) { + rotated_c_terms.push(LinearTerm::Column { + coefficient: 1, + column: cols::cxz_right_bit((x + 1) % 5, hw), + }); + } + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::cxz((x + 4) % 5, 3, b), + packing: Packing::Direct, + }, + BusValue::linear(rotated_c_terms), + BusValue::Packed { + start_column: cols::dxz(x, b), + packing: Packing::Direct, + }, + ], + )); + } + } + + // --- Theta final: XOR_BYTE (200) --- + // theta[x][y][b] = start[x][y][b] XOR D[x][b] + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::start(x, y, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::dxz(x, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::theta(x, y, b), + packing: Packing::Direct, + }, + ], + )); + } + } + } + + // --- Rho: HWSL (100) --- + // HWSL(theta[x][y] halfword[hw], rnc[x][y]) → (rot_left, rot_right) + // rnc is inlined as a constant: KECCAK_RHO[x][y] % 16. + for x in 0..5 { + for y in 0..5 { + let rnc_val = (KECCAK_RHO[x][y] % 16) as u64; + for hw in 0..4 { + interactions.push(BusInteraction::sender( + BusId::Hwsl, + Multiplicity::Column(cols::MU), + vec![ + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::theta(x, y, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::theta(x, y, hw * 2 + 1), + }, + ]), + BusValue::constant(rnc_val), + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::rot_left(x, y, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::rot_left(x, y, hw * 2 + 1), + }, + ]), + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::rot_right(x, y, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::rot_right(x, y, hw * 2 + 1), + }, + ]), + ], + )); + } + } + } + + // --- Rho: IS_BYTE range checks on rot_left + rot_right (400) --- + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::IsByte, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::rot_left(x, y, b), + packing: Packing::Direct, + }], + )); + interactions.push(BusInteraction::sender( + BusId::IsByte, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::rot_right(x, y, b), + packing: Packing::Direct, + }], + )); + } + } + } + + // --- Chi: AND_BYTE (200) --- + // chi_ands[x][y][b] = (255 - pi[(x+1)%5][y][b]) AND pi[(x+2)%5][y][b] + // pi is virtual: pi[x][y][z] = rot_left[sx,sy,l_byte] + rot_right[sx,sy,r_byte] + // with src lane (sx,sy) = ((x+3y)%5, x) and byte offsets from KECCAK_RHO. + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let (p1_l, p1_r) = cols::pi_src_cols((x + 1) % 5, y, b); + let (p2_l, p2_r) = cols::pi_src_cols((x + 2) % 5, y, b); + interactions.push(BusInteraction::sender( + BusId::AndByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::linear(vec![ + LinearTerm::Constant(255), + LinearTerm::Column { + coefficient: -1, + column: p1_l, + }, + LinearTerm::Column { + coefficient: -1, + column: p1_r, + }, + ]), + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: p2_l, + }, + LinearTerm::Column { + coefficient: 1, + column: p2_r, + }, + ]), + BusValue::Packed { + start_column: cols::chi_ands(x, y, b), + packing: Packing::Direct, + }, + ], + )); + } + } + } + + // --- Chi: XOR_BYTE (200) --- + // chi[x][y][b] = pi[x][y][b] XOR chi_ands[x][y][b] (pi virtual). + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let (p_l, p_r) = cols::pi_src_cols(x, y, b); + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: p_l, + }, + LinearTerm::Column { + coefficient: 1, + column: p_r, + }, + ]), + BusValue::Packed { + start_column: cols::chi_ands(x, y, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::chi(x, y, b), + packing: Packing::Direct, + }, + ], + )); + } + } + } + + // --- Iota: XOR_BYTE (8) --- + // iota[b] = chi[0][0][b] XOR rc[b] + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::chi(0, 0, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::rc(b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::iota(b), + packing: Packing::Direct, + }, + ], + )); + } + + interactions +} + +// ========================================================================= +// Constraints +// ========================================================================= + +/// KECCAK_RND polynomial constraints: 20 IS_BIT(μ; Cxz_right) constraints. +/// +/// Per spec d75944ee, `Cxz_right` is typed `[Bit, 4], 5` and range-checked via +/// IS_BIT polynomial constraints (kind="template", cond="μ"), not lookups: +/// μ * Cxz_right[x][hw] * (1 - Cxz_right[x][hw]) = 0 +/// +/// - pi is a spec [[variables.virtual]] inlined in chi bus interactions. +/// - rnc/rbc are spec [[variables.constant]] inlined as compile-time constants. +/// +/// All other checks (XOR, AND, HWSL, IS_BYTE, IS_HALF, KECCAK, KECCAK_RC) are +/// enforced via bus interactions against the BITWISE/KECCAK_RC chips. +pub fn create_constraints( + constraint_idx_start: usize, +) -> ( + Vec>>, + usize, +) { + use crate::constraints::templates::IsBitConstraint; + + let mut constraints: Vec< + Box>, + > = Vec::with_capacity(20); + let mut idx = constraint_idx_start; + for x in 0..5 { + for hw in 0..4 { + constraints + .push(IsBitConstraint::new(cols::MU, cols::cxz_right_bit(x, hw), idx).boxed()); + idx += 1; + } + } + (constraints, idx) +} + +#[cfg(test)] +mod tests { + use super::*; + use executor::vm::instruction::execution::keccak_f1600; + + /// pi is a spec virtual variable. Verify the inlined expression + /// (rot_left[sx,sy,l_byte] + rot_right[sx,sy,r_byte]) matches the byte of + /// rho(theta) for a non-trivial state. Uses mu=0 padding rows as a trivial + /// sanity check (all zeros), then a non-zero-input round as the real test. + #[test] + fn test_pi_virtual_matches_rotate() { + // Use a non-zero input so theta_lanes are non-trivial. + let input = [0x0102030405060708u64; 25]; + let mut output = input; + keccak_f1600(&mut output); + let op = KeccakRoundOperation { + timestamp: 42, + input, + output, + }; + let trace = generate_keccak_rnd_trace(&[op]); + let base = 0; + + // Recompute theta for round 0 in u64 to compare against virtual pi. + let mut c = [0u64; 5]; + for x in 0..5 { + c[x] = input[x] ^ input[x + 5] ^ input[x + 10] ^ input[x + 15] ^ input[x + 20]; + } + let mut d = [0u64; 5]; + for x in 0..5 { + d[x] = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1); + } + let mut theta_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + theta_lanes[x + 5 * y] = input[x + 5 * y] ^ d[x]; + } + } + + for x in 0..5 { + for y in 0..5 { + let sx = (x + 3 * y) % 5; + let sy = x; + let rotated = theta_lanes[sx + 5 * sy].rotate_left(KECCAK_RHO[sx][sy]); + for z in 0..8 { + let (l_col, r_col) = cols::pi_src_cols(x, y, z); + let virtual_pi = + &trace.main_table.data[base + l_col] + &trace.main_table.data[base + r_col]; + let expected = FE::from((rotated >> (z * 8)) & 0xFF); + assert_eq!( + virtual_pi, expected, + "virtual pi mismatch at ({x},{y},{z}): sx={sx}, sy={sy}" + ); + } + } + } + } +} diff --git a/prover/src/tables/mod.rs b/prover/src/tables/mod.rs index 19d14411d..4a6032ef2 100644 --- a/prover/src/tables/mod.rs +++ b/prover/src/tables/mod.rs @@ -28,6 +28,9 @@ pub mod cpu; pub mod decode; pub mod dvrm; pub mod halt; +pub mod keccak; +pub mod keccak_rc; +pub mod keccak_rnd; pub mod load; pub mod lt; pub mod memw; diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index d2743a1e5..bd5a3f2d2 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -10,10 +10,10 @@ //! ```text //! PHASE 0: ELF → DECODE, MEMORY_INIT (preprocessed tables) //! PHASE 1: Logs → CPU ops -//! PHASE 2: CPU ops → MEMW, MEMW_A, MEMW_R, LOAD, LT, Bitwise (with state tracking for MEMW/LOAD) +//! PHASE 2: CPU ops → MEMW, MEMW_A, MEMW_R, LOAD, LT, Bitwise, KECCAK (with state tracking for MEMW/LOAD/ECALL) //! PHASE 3: MEMW/MEMW_A → LT ops (timestamp ordering); MEMW_R uses IS_HALFWORD instead -//! PHASE 4: LT, MEMW_A, MEMW_R → Bitwise lookups -//! PHASE 5: Generate all traces +//! PHASE 4: LT, MEMW_A, MEMW_R, KECCAK → Bitwise lookups +//! PHASE 5: Generate all traces (including KECCAK core, KECCAK_RND, KECCAK_RC) //! ``` //! //! ## Usage @@ -40,6 +40,9 @@ use super::cpu::{self, CpuOperation}; use super::decode; use super::dvrm::{self, DvrmOperation}; use super::halt; +use super::keccak::{self, KeccakOperation}; +use super::keccak_rc; +use super::keccak_rnd::{self, KeccakRoundOperation}; use super::load::{self, LoadOperation}; use super::lt::{self, LtOperation}; use super::memw::{self, MemwOperation}; @@ -335,7 +338,7 @@ fn collect_cpu_ops( /// /// MEMW and LOAD collection requires sequential processing with state tracking. /// -/// Returns: (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops) +/// Returns: (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) #[allow(clippy::type_complexity)] fn collect_ops_from_cpu( cpu_ops: &[CpuOperation], @@ -348,6 +351,7 @@ fn collect_ops_from_cpu( Vec, Vec, Vec, + Vec, ) { let mut memw_ops = Vec::with_capacity(cpu_ops.len() * 3); let mut load_ops = Vec::with_capacity(cpu_ops.len() / 8 + 1); @@ -355,6 +359,7 @@ fn collect_ops_from_cpu( let mut shift_ops = Vec::with_capacity(cpu_ops.len() / 10 + 1); let mut bitwise_ops = Vec::with_capacity(cpu_ops.len() * 4); let mut commit_ops = Vec::new(); + let mut keccak_ops = Vec::new(); let mut current_commit_index = 0u32; let mut commit_ecall_count = 0u32; @@ -397,6 +402,38 @@ fn collect_ops_from_cpu( commit_ecall_count += 1; } + // Collect KeccakPermute ECALL operations + if op.ecall_keccak { + let state_addr = op.keccak_state_addr; + let mut input = [0u64; 25]; + for (i, lane) in input.iter_mut().enumerate() { + let addr = state_addr + .checked_add(i as u64 * 8) + .expect("keccak state address range must be validated by the executor"); + let mut val = 0u64; + for b in 0..8 { + let byte_addr = addr + .checked_add(b as u64) + .expect("keccak state address range must be validated by the executor"); + let (byte_val, _ts) = memory_state.read_byte(byte_addr); + val |= (byte_val as u64) << (b * 8); + } + *lane = val; + } + let mut output = input; + executor::vm::instruction::execution::keccak_f1600(&mut output); + // collect_keccak_memw_ops handles memory_state + register_state updates + let keccak_memw_ops = + collect_keccak_memw_ops(op, &input, &output, memory_state, register_state); + memw_ops.extend(keccak_memw_ops); + keccak_ops.push(KeccakOperation { + timestamp: op.timestamp, + state_addr, + input, + output, + }); + } + // --- LT, SHIFT, and Bitwise (no state tracking needed) --- // Collect LT operations from SLT/BLT instructions @@ -440,6 +477,7 @@ fn collect_ops_from_cpu( shift_ops, bitwise_ops, commit_ops, + keccak_ops, ) } @@ -781,6 +819,73 @@ fn collect_halt_ops(register_state: &mut RegisterState) -> Vec { // ============================================================================= /// Collects LT operations from MEMW for timestamp ordering. +/// Collect MEMW operations for a KeccakPermute ECALL. +/// +/// Generates 25 read operations (input lanes at timestamp) and 25 write +/// operations (output lanes at timestamp+1). Each operation is 8 bytes wide. +fn collect_keccak_memw_ops( + op: &CpuOperation, + input: &[u64; 25], + output: &[u64; 25], + memory_state: &mut MemoryState, + register_state: &mut RegisterState, +) -> Vec { + let ts = op.timestamp; + let state_addr = op.keccak_state_addr; + let mut memw_ops = Vec::with_capacity(26); // 1 register read + 25 lane ops + + // Per spec (keccak:c:read_addr): read register x10 to get state_addr + { + let reg_value = pack_register_value(state_addr); + let reg_addr = 2 * 10u64; // x10 → address 20 + let (_old_val, old_ts) = register_state.read(10); + let old_timestamps = [old_ts, old_ts, 0, 0, 0, 0, 0, 0]; + let memw_op = MemwOperation::new(true, reg_addr, reg_value, ts, 2, true) + .with_old(reg_value, old_timestamps); + memw_ops.push(memw_op); + register_state.write(10, state_addr, ts); + } + + // Per spec (keccak:c:load_store_state): single combined read+write MEMW per lane. + // input = [0, state_ptr, output_state, timestamp, 0, 0, 1], output = input_state + // The MEMW table sees: old=input_state, value=output_state, is_read=true. + for (lane_idx, (&in_lane, &out_lane)) in input.iter().zip(output.iter()).enumerate() { + let lane_addr = state_addr + .checked_add(lane_idx as u64 * 8) + .expect("keccak state address range must be validated by the executor"); + + let mut old_bytes = [0u64; 8]; + let mut old_timestamps = [0u64; 8]; + for b in 0..8 { + old_bytes[b] = (in_lane >> (b * 8)) & 0xFF; + let byte_addr = lane_addr + .checked_add(b as u64) + .expect("keccak state address range must be validated by the executor"); + let (_old_val, old_ts) = memory_state.read_byte(byte_addr); + old_timestamps[b] = old_ts; + } + + let mut value_bytes = [0u64; 8]; + for (b, byte) in value_bytes.iter_mut().enumerate() { + *byte = (out_lane >> (b * 8)) & 0xFF; + } + + let memw_op = MemwOperation::new(false, lane_addr, value_bytes, ts, 8, true) + .with_old(old_bytes, old_timestamps); + memw_ops.push(memw_op); + + // Update memory state + for (b, &val) in value_bytes.iter().enumerate() { + let byte_addr = lane_addr + .checked_add(b as u64) + .expect("keccak state address range must be validated by the executor"); + memory_state.write_byte(byte_addr, val as u8, ts); + } + } + + memw_ops +} + /// /// From spec memw.md: /// - MEMW-C4 through MEMW-C7: old_timestamp[i] < timestamp (based on width) @@ -1544,6 +1649,264 @@ fn collect_bitwise_from_commit(commit_ops: &[CommitOperation]) -> Vec Vec { + use executor::vm::instruction::execution::{KECCAK_RC, KECCAK_RHO}; + + let mut ops = Vec::new(); + + for kop in keccak_ops { + let state_addr = kop.state_addr; + + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::AndByte, + (state_addr & 0xFF) as u8, + 7, + )); + + // Range-check addr bytes (paired with the IS_BYTE sends in + // keccak::bus_interactions): without this the field-element value of + // the addr_lo / addr_hi linear combinations is unconstrained per byte. + for b in 0..8 { + let byte = ((state_addr >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + byte, + )); + } + + // IS_HALF for state_ptr halfwords (100 per call) + for lane_idx in 0..25 { + let ptr = state_addr + .checked_add(lane_idx as u64 * 8) + .expect("keccak state address range must be validated by the executor"); + for shift in [0, 16, 32, 48] { + let half = ((ptr >> shift) & 0xFFFF) as u16; + ops.push(BitwiseOperation::halfword( + BitwiseOperationType::IsHalf, + (half & 0xFF) as u8, + ((half >> 8) & 0xFF) as u8, + )); + } + } + + // Replay keccak round computation to extract bitwise lookups + let mut state = kop.input; + for round in 0..24 { + // --- theta: Cxz chain XOR_BYTE (160) --- + let mut cxz = [[[0u8; 8]; 4]; 5]; + for x in 0..5 { + for b in 0..8 { + let v0 = ((state[x] >> (b * 8)) & 0xFF) as u8; + let v1 = ((state[x + 5] >> (b * 8)) & 0xFF) as u8; + cxz[x][0][b] = v0 ^ v1; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + v0, + v1, + )); + } + for stage in 1..4usize { + let y = stage + 1; + for b in 0..8 { + let prev = cxz[x][stage - 1][b]; + let sv = ((state[x + 5 * y] >> (b * 8)) & 0xFF) as u8; + cxz[x][stage][b] = prev ^ sv; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + prev, + sv, + )); + } + } + } + + // theta: HWSL for rotated C (20) + IS_BYTE on Cxz_left (40). + // Cxz_right is range-checked via IS_BIT polynomial constraints + // on the keccak_rnd chip, not via lookups (spec d75944ee). + let mut rotated_c = [[0u8; 8]; 5]; + for x in 0..5 { + let c = cxz[x][3]; + for hw in 0..4 { + let halfword = (c[hw * 2] as u16) | ((c[hw * 2 + 1] as u16) << 8); + let shifted = halfword << 1; // u16 wraps + ops.push(BitwiseOperation::new( + BitwiseOperationType::Hwsl, + (halfword & 0xFF) as u8, + ((halfword >> 8) & 0xFF) as u8, + 1, + )); + // IS_BYTE for cxz_left bytes + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + (shifted & 0xFF) as u8, + )); + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + ((shifted >> 8) & 0xFF) as u8, + )); + } + // Reconstruct rotated_c using the bit-typed Cxz_right. + let mut left_bytes = [0u8; 8]; + let mut right_bits = [0u8; 4]; + for hw in 0..4 { + let halfword = (c[hw * 2] as u16) | ((c[hw * 2 + 1] as u16) << 8); + let shifted = halfword << 1; + left_bytes[hw * 2] = (shifted & 0xFF) as u8; + left_bytes[hw * 2 + 1] = ((shifted >> 8) & 0xFF) as u8; + right_bits[hw] = (halfword >> 15) as u8; + } + for b in 0usize..8 { + let right_contribution = if b.is_multiple_of(2) { + right_bits[(b / 2 + 3) % 4] + } else { + 0 + }; + rotated_c[x][b] = left_bytes[b].wrapping_add(right_contribution); + } + } + + // theta: Dxz XOR_BYTE (40) + let mut d_bytes = [[0u8; 8]; 5]; + for x in 0..5 { + for b in 0..8 { + let a = cxz[(x + 4) % 5][3][b]; + let rb = rotated_c[(x + 1) % 5][b]; + d_bytes[x][b] = a ^ rb; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + a, + rb, + )); + } + } + + // theta final: XOR_BYTE (200) + let mut theta_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let lane = state[x + 5 * y]; + let mut d_lane = 0u64; + for b in 0..8 { + d_lane |= (d_bytes[x][b] as u64) << (b * 8); + } + theta_lanes[x + 5 * y] = lane ^ d_lane; + for b in 0..8 { + let s = ((lane >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + s, + d_bytes[x][b], + )); + } + } + } + + // rho: HWSL (100) + IS_BYTE (400) + for x in 0..5 { + for y in 0..5 { + let rho_offset = KECCAK_RHO[x][y] as usize; + let rnc_val = (rho_offset % 16) as u8; + let theta_lane = theta_lanes[x + 5 * y]; + for hw in 0..4 { + let halfword = ((theta_lane >> (hw * 16)) & 0xFFFF) as u16; + let (shifted, carry) = if rnc_val == 0 { + (halfword, 0u16) + } else { + (halfword << rnc_val, halfword >> (16 - rnc_val)) + }; + ops.push(BitwiseOperation::new( + BitwiseOperationType::Hwsl, + (halfword & 0xFF) as u8, + ((halfword >> 8) & 0xFF) as u8, + rnc_val, + )); + // IS_BYTE for rot_left + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + (shifted & 0xFF) as u8, + )); + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + ((shifted >> 8) & 0xFF) as u8, + )); + // IS_BYTE for rot_right + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + (carry & 0xFF) as u8, + )); + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + ((carry >> 8) & 0xFF) as u8, + )); + } + } + } + + // pi: compute pi_lanes + let mut pi_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let rotated = theta_lanes[x + 5 * y].rotate_left(KECCAK_RHO[x][y]); + let dst_x = y; + let dst_y = (2 * x + 3 * y) % 5; + pi_lanes[dst_x + 5 * dst_y] = rotated; + } + } + + // chi: AND_BYTE (200) + XOR_BYTE (200) + let mut chi_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let not_next = !pi_lanes[(x + 1) % 5 + 5 * y]; + let next2 = pi_lanes[(x + 2) % 5 + 5 * y]; + let and_val = not_next & next2; + chi_lanes[x + 5 * y] = pi_lanes[x + 5 * y] ^ and_val; + for b in 0..8 { + let not_byte = ((not_next >> (b * 8)) & 0xFF) as u8; + let n2_byte = ((next2 >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::AndByte, + not_byte, + n2_byte, + )); + let pi_byte = ((pi_lanes[x + 5 * y] >> (b * 8)) & 0xFF) as u8; + let and_byte = ((and_val >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + pi_byte, + and_byte, + )); + } + } + } + + // iota: XOR_BYTE (8) + let rc_val = KECCAK_RC[round]; + for b in 0..8 { + let chi_byte = ((chi_lanes[0] >> (b * 8)) & 0xFF) as u8; + let rc_byte = ((rc_val >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + chi_byte, + rc_byte, + )); + } + + // Update state + chi_lanes[0] ^= rc_val; + state = chi_lanes; + } + } + + ops +} + /// every address accessed during execution (ELF init + runtime stores/loads). /// ELF pages get their init data from the binary; all others are zero-init. fn generate_page_tables( @@ -1664,6 +2027,15 @@ pub struct Traces { /// COMMIT table for write syscall (byte-by-byte commit with recursive bus) pub commit: TraceTable, + /// KECCAK core table (one row per keccak permutation call) + pub keccak: TraceTable, + + /// KECCAK_RND round table (24 rows per keccak call) + pub keccak_rnd: TraceTable, + + /// KECCAK_RC precomputed round constant table (32 rows) + pub keccak_rc: TraceTable, + /// MEMW_R register-only fast-path traces (split into chunks of max_rows::MEMW_R) pub memw_registers: Vec>, } @@ -1683,6 +2055,7 @@ struct CollectedOps { mul_ops: Vec<(MulOperation, bool)>, dvrm_ops: Vec<(DvrmOperation, bool)>, commit_ops: Vec, + keccak_ops: Vec, } /// Chunk raw ops and generate one trace table per chunk. @@ -1711,6 +2084,7 @@ fn collect_all_ops( shift_ops: Vec, bitwise_ops: Vec, commit_ops: Vec, + keccak_ops: Vec, register_state: &mut RegisterState, ) -> CollectedOps { // HALT finalization: 33 register MEMW operations at timestamp u64::MAX. @@ -1800,6 +2174,7 @@ fn collect_all_ops( mul_ops, dvrm_ops, commit_ops, + keccak_ops, } } @@ -1832,6 +2207,7 @@ fn build_traces( mul_ops, dvrm_ops, commit_ops, + keccak_ops, } = ops; // ===================================================================== @@ -1863,6 +2239,8 @@ fn build_traces( .collect(); // COMMIT table sends IsByte and IsHalfword lookups bitwise_ops.extend(collect_bitwise_from_commit(&commit_ops)); + // KECCAK_RND sends XOR/AND/IS_BYTE/HWSL; KECCAK core sends IS_HALF + bitwise_ops.extend(collect_bitwise_from_keccak(&keccak_ops)); // CPU padding rows send IS_BYTE with all-zero values. // Add corresponding ops so the bitwise table multiplicities balance. @@ -1921,6 +2299,21 @@ fn build_traces( // Generate remaining traces in parallel (page, register, halt, commit). // chunk_and_generate already handled cpu, lt, memw, load, mul, dvrm, branch above. let commit_trace = commit::generate_commit_trace(&commit_ops); + + // Generate keccak traces (core table + per-round table + preprocessed RC) + let keccak_rnd_ops: Vec = keccak_ops + .iter() + .map(|op| KeccakRoundOperation { + timestamp: op.timestamp, + input: op.input, + output: op.output, + }) + .collect(); + let keccak_trace = keccak::generate_keccak_trace(&keccak_ops); + let keccak_rnd_trace = keccak_rnd::generate_keccak_rnd_trace(&keccak_rnd_ops); + let mut keccak_rc_trace = keccak_rc::generate_keccak_rc_trace(); + keccak_rc::update_multiplicities(&mut keccak_rc_trace, keccak_ops.len()); + let (pages, page_configs, register_trace, halt_trace); #[cfg(feature = "parallel")] { @@ -1977,6 +2370,9 @@ fn build_traces( branches, halt: halt_trace, commit: commit_trace, + keccak: keccak_trace, + keccak_rnd: keccak_rnd_trace, + keccak_rc: keccak_rc_trace, memw_registers, }) } @@ -1999,6 +2395,10 @@ impl Traces { use super::decode::cols::NUM_COLUMNS as DECODE_COLS; use super::dvrm::cols::NUM_COLUMNS as DVRM_COLS; use super::halt::cols::NUM_COLUMNS as HALT_COLS; + use super::keccak::cols::NUM_COLUMNS as KECCAK_COLS; + use super::keccak_rc::NUM_PRECOMPUTED_COLS as KECCAK_RC_PRECOMPUTED; + use super::keccak_rc::cols::NUM_COLUMNS as KECCAK_RC_COLS; + use super::keccak_rnd::cols::NUM_COLUMNS as KECCAK_RND_COLS; use super::load::cols::NUM_COLUMNS as LOAD_COLS; use super::lt::cols::NUM_COLUMNS as LT_COLS; use super::memw::cols::NUM_COLUMNS as MEMW_COLS; @@ -2027,6 +2427,9 @@ impl Traces { branches, halt, commit, + keccak, + keccak_rnd, + keccak_rc, memw_registers, page_configs: _, public_output_bytes: _, @@ -2071,6 +2474,9 @@ impl Traces { for t in memw_registers { total += (t.num_rows() * MEMW_R_COLS) as u64; } + total += (keccak.num_rows() * KECCAK_COLS) as u64; + total += (keccak_rnd.num_rows() * KECCAK_RND_COLS) as u64; + total += (keccak_rc.num_rows() * (KECCAK_RC_COLS - KECCAK_RC_PRECOMPUTED)) as u64; total } @@ -2103,6 +2509,9 @@ impl Traces { // page::bus_interactions count is constant regardless of page_base. let n_page = aux_cols(super::page::bus_interactions(0).len()); let n_memw_r = aux_cols(super::memw_register::bus_interactions().len()); + let n_keccak = aux_cols(super::keccak::bus_interactions().len()); + let n_keccak_rnd = aux_cols(super::keccak_rnd::bus_interactions().len()); + let n_keccak_rc = aux_cols(super::keccak_rc::bus_interactions().len()); let Traces { cpus, @@ -2120,6 +2529,9 @@ impl Traces { branches, halt, commit, + keccak, + keccak_rnd, + keccak_rc, memw_registers, page_configs: _, public_output_bytes: _, @@ -2164,6 +2576,9 @@ impl Traces { for t in memw_registers { total += (t.num_rows() * n_memw_r) as u64; } + total += (keccak.num_rows() * n_keccak) as u64; + total += (keccak_rnd.num_rows() * n_keccak_rnd) as u64; + total += (keccak_rc.num_rows() * n_keccak_rc) as u64; total } @@ -2322,7 +2737,7 @@ impl Traces { let mut memory_state = MemoryState::from_elf(elf); memory_state.add_private_input(private_input); let mut register_state = RegisterState::new(elf.entry_point); - let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops) = + let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) = collect_ops_from_cpu(&cpu_ops, &mut memory_state, &mut register_state); let ops = collect_all_ops( @@ -2333,6 +2748,7 @@ impl Traces { shift_ops, bitwise_ops, commit_ops, + keccak_ops, &mut register_state, ); @@ -2368,7 +2784,7 @@ impl Traces { let mut memory_state = MemoryState::new(); let entry_point = cpu_ops.first().map_or(0, |op| op.decode.pc); let mut register_state = RegisterState::new(entry_point); - let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops) = + let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) = collect_ops_from_cpu(&cpu_ops, &mut memory_state, &mut register_state); let ops = collect_all_ops( @@ -2379,6 +2795,7 @@ impl Traces { shift_ops, bitwise_ops, commit_ops, + keccak_ops, &mut register_state, ); @@ -2446,6 +2863,226 @@ impl Traces { ) -> Result { Self::from_logs_trimmed(logs, instructions, max_rows) } + + /// Like [`from_elf_and_logs`] but trims the bitwise table (TEST ONLY). + /// + /// Produces PAGE and REGISTER tables (requires ELF) while keeping the + /// bitwise table small. Same unsoundness caveats as [`from_logs_trimmed`]. + #[cfg(test)] + pub fn from_elf_and_logs_minimal( + elf: &Elf, + logs: &[Log], + max_rows: &super::MaxRowsConfig, + private_input: &[u8], + ) -> Result { + let mut traces = Self::from_elf_and_logs(elf, logs, max_rows, private_input)?; + traces.bitwise = bitwise::trim_zero_rows(traces.bitwise); + Ok(traces) + } +} + +#[cfg(test)] +mod keccak_tests { + use super::*; + use crate::tables::keccak::cols as core_cols; + use crate::tables::keccak_rnd::cols as rnd_cols; + use crate::tables::types::FE; + use executor::vm::instruction::execution::keccak_f1600; + + fn make_keccak_ops() -> (KeccakOperation, KeccakRoundOperation) { + let input = [0u64; 25]; + let mut output = input; + keccak_f1600(&mut output); + let kop = KeccakOperation { + timestamp: 42, + state_addr: 0x1000, + input, + output, + }; + let rop = KeccakRoundOperation { + timestamp: 42, + input, + output, + }; + (kop, rop) + } + + #[test] + fn test_keccak_bitwise_ops_count() { + let (kop, _) = make_keccak_ops(); + let ops = collect_bitwise_from_keccak(&[kop]); + + let xor = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::XorByte) + .count(); + let and = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::AndByte) + .count(); + let is_byte = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::IsByte) + .count(); + let hwsl = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::Hwsl) + .count(); + let is_half = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::IsHalf) + .count(); + + assert_eq!(xor, 24 * 608, "XorByte count"); + assert_eq!(and, 24 * 200 + 1, "AndByte count"); + // Cxz_right Byte→Bit (spec d75944ee): drops 40 IS_BYTE per round. + // +8 per call to range-check the addr bytes used in alignment / no-overflow. + assert_eq!(is_byte, 24 * 440 + 8, "IsByte count"); + assert_eq!(hwsl, 24 * 120, "Hwsl count"); + assert_eq!(is_half, 100, "IsHalf count"); + assert_eq!(ops.len(), 109 + 24 * 1368, "Total bitwise ops"); + } + + #[test] + fn test_keccak_round_trace_matches_f1600() { + let (_, rop) = make_keccak_ops(); + let rnd_trace = keccak_rnd::generate_keccak_rnd_trace(&[rop]); + + let mut ref_state = [0u64; 25]; + for round in 0..24 { + let rc = executor::vm::instruction::execution::KECCAK_RC[round]; + let mut c = [0u64; 5]; + for x in 0..5 { + c[x] = ref_state[x] + ^ ref_state[x + 5] + ^ ref_state[x + 10] + ^ ref_state[x + 15] + ^ ref_state[x + 20]; + } + let mut d = [0u64; 5]; + for x in 0..5 { + d[x] = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1); + } + for i in 0..25 { + ref_state[i] ^= d[i % 5]; + } + let mut b = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + b[y + 5 * ((2 * x + 3 * y) % 5)] = ref_state[x + 5 * y] + .rotate_left(executor::vm::instruction::execution::KECCAK_RHO[x][y]); + } + } + for x in 0..5 { + for y in 0..5 { + ref_state[x + 5 * y] = + b[x + 5 * y] ^ (!b[(x + 1) % 5 + 5 * y] & b[(x + 2) % 5 + 5 * y]); + } + } + ref_state[0] ^= rc; + + let base = round * rnd_cols::NUM_COLUMNS; + for (lane, &lane_val) in ref_state.iter().enumerate() { + let x = lane % 5; + let y = lane / 5; + for byte_idx in 0..8 { + let expected = FE::from((lane_val >> (byte_idx * 8)) & 0xFF); + let col = if x == 0 && y == 0 { + rnd_cols::iota(byte_idx) + } else { + rnd_cols::chi(x, y, byte_idx) + }; + let trace_val = &rnd_trace.main_table.data[base + col]; + assert_eq!( + &expected, trace_val, + "Round {round} lane ({x},{y}) byte {byte_idx}" + ); + } + } + } + } + + #[test] + fn test_keccak_core_round_state_consistency() { + let (kop, rop) = make_keccak_ops(); + let core_trace = keccak::generate_keccak_trace(&[kop]); + let rnd_trace = keccak_rnd::generate_keccak_rnd_trace(&[rop]); + + // Round 0 start == core input_state + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let core_val = &core_trace.main_table.data[core_cols::input_state(x, y, b)]; + let rnd_val = &rnd_trace.main_table.data[rnd_cols::start(x, y, b)]; + assert_eq!(core_val, rnd_val, "Round 0 start mismatch at ({x},{y},{b})"); + } + } + } + + // Round 23 out == core output_state + let rnd_base_23 = 23 * rnd_cols::NUM_COLUMNS; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let core_val = &core_trace.main_table.data[core_cols::output_state(x, y, b)]; + let rnd_val = if x == 0 && y == 0 { + &rnd_trace.main_table.data[rnd_base_23 + rnd_cols::iota(b)] + } else { + &rnd_trace.main_table.data[rnd_base_23 + rnd_cols::chi(x, y, b)] + }; + assert_eq!(core_val, rnd_val, "Round 23 out mismatch at ({x},{y},{b})"); + } + } + } + } + + #[test] + fn test_keccak_bus_interaction_counts() { + assert_eq!( + keccak::bus_interactions().len(), + 138, + "KECCAK core: 1 ECALL + 1 MEMW read_addr + 25 MEMW lanes + 100 IS_HALF + 1 AND_BYTE alignment + 8 IS_BYTE addr + 1 Keccak send + 1 Keccak recv" + ); + assert_eq!( + keccak_rnd::bus_interactions().len(), + 1371, + "KECCAK_RND: 3 IO + 460 theta + 500 rho + 400 chi + 8 iota \ + (Cxz_right Byte→Bit drops 40 IS_BYTE per spec d75944ee)" + ); + assert_eq!( + keccak_rc::bus_interactions().len(), + 1, + "KECCAK_RC: 1 receiver" + ); + } + + #[test] + fn test_keccak_column_counts() { + assert_eq!(core_cols::NUM_COLUMNS, 511, "KECCAK core columns"); + assert_eq!( + rnd_cols::NUM_COLUMNS, + 1480, + "KECCAK_RND columns (rnc/rbc inlined; pi virtual; Cxz_right Bit-typed)" + ); + assert_eq!(keccak_rc::cols::NUM_COLUMNS, 10, "KECCAK_RC columns"); + } + + #[test] + fn test_keccak_constraint_counts() { + let (core_constraints, _) = keccak::create_constraints(0); + assert_eq!( + core_constraints.len(), + 51, + "KECCAK core: 25 ADD pairs + no-overflow" + ); + + let (rnd_constraints, _) = keccak_rnd::create_constraints(0); + assert_eq!( + rnd_constraints.len(), + 20, + "KECCAK_RND: 20 IS_BIT(μ; Cxz_right_bit) per spec d75944ee" + ); + } } #[cfg(test)] diff --git a/prover/src/tables/types.rs b/prover/src/tables/types.rs index a1dcd043a..70aa6813d 100644 --- a/prover/src/tables/types.rs +++ b/prover/src/tables/types.rs @@ -110,6 +110,10 @@ pub enum BusId { /// COMMIT output bus: verifier computes the receiver contribution externally /// from `VmProof.public_output` using the shared LogUp challenges Commit, + /// Keccak core ↔ round chip: (timestamp, round, state[200 bytes]) + Keccak, + /// Keccak round ↔ RC lookup: (round, rc[8 bytes]) + KeccakRc, } impl BusId { @@ -138,6 +142,8 @@ impl BusId { BusId::Dvrm => "Dvrm", BusId::CommitNextByte => "CommitNextByte", BusId::Commit => "Commit", + BusId::Keccak => "Keccak", + BusId::KeccakRc => "KeccakRc", } } } @@ -169,6 +175,8 @@ impl TryFrom for BusId { 19 => Ok(BusId::Ecall), 20 => Ok(BusId::CommitNextByte), 21 => Ok(BusId::Commit), + 22 => Ok(BusId::Keccak), + 23 => Ok(BusId::KeccakRc), other => Err(other), } } diff --git a/prover/src/test_utils.rs b/prover/src/test_utils.rs index b47554857..1dcb768b2 100644 --- a/prover/src/test_utils.rs +++ b/prover/src/test_utils.rs @@ -43,6 +43,13 @@ use crate::tables::dvrm::{ bus_interactions as dvrm_bus_interactions, cols as dvrm_cols, dvrm_constraints, }; use crate::tables::halt::{bus_interactions as halt_bus_interactions, cols as halt_cols}; +use crate::tables::keccak::{bus_interactions as keccak_bus_interactions, cols as keccak_cols}; +use crate::tables::keccak_rc::{ + bus_interactions as keccak_rc_bus_interactions, cols as keccak_rc_cols, +}; +use crate::tables::keccak_rnd::{ + bus_interactions as keccak_rnd_bus_interactions, cols as keccak_rnd_cols, +}; use crate::tables::load::{ bus_interactions as load_bus_interactions, cols as load_cols, constraints as load_constraints, }; @@ -791,3 +798,59 @@ pub fn create_register_air(proof_options: &ProofOptions) -> VmAir { ) .with_name("REGISTER") } + +/// Create KECCAK core AIR with ADD constraints and bus interactions. +pub fn create_keccak_air(proof_options: &ProofOptions) -> VmAir { + let (constraints, _) = crate::tables::keccak::create_constraints(0); + let transition_constraints: Vec>> = constraints; + + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: keccak_bus_interactions(), + }; + + AirWithBuses::new( + keccak_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("KECCAK") +} + +/// Create KECCAK_RND AIR with pi constraints and bus interactions. +pub fn create_keccak_rnd_air(proof_options: &ProofOptions) -> VmAir { + let (constraints, _) = crate::tables::keccak_rnd::create_constraints(0); + let transition_constraints: Vec>> = constraints; + + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: keccak_rnd_bus_interactions(), + }; + + AirWithBuses::new( + keccak_rnd_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("KECCAK_RND") +} + +/// Create KECCAK_RC AIR with bus interactions (preprocessed table). +pub fn create_keccak_rc_air(proof_options: &ProofOptions) -> VmAir { + let transition_constraints: Vec>> = vec![]; + + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: keccak_rc_bus_interactions(), + }; + + AirWithBuses::new( + keccak_rc_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("KECCAK_RC") +} diff --git a/prover/src/tests/cpu_tests.rs b/prover/src/tests/cpu_tests.rs index 6b3239f43..9004d24c0 100644 --- a/prover/src/tests/cpu_tests.rs +++ b/prover/src/tests/cpu_tests.rs @@ -328,7 +328,7 @@ fn test_bus_interactions_count() { // - 1 DVRM (division/remainder) // - 1 SHIFT (shift operations) // - 1 BRANCH (branch/jump target calculation) - // - 1 ECALL (single shared bus for HALT and COMMIT, mult = ECALL) + // - 1 ECALL (shared bus for HALT, COMMIT, and KECCAK, mult = ECALL) // - 1 IS_BYTE for (RS1, RS2) paired // - 1 IS_BYTE for (RD, 0) // - 12 IS_BYTE (ARG1/ARG2/RES byte pairs: 4 pairs × 3 arrays) diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs index 7e0fbc181..736fcd78e 100644 --- a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -22,10 +22,13 @@ use stark::prover::{IsStarkProver, Prover}; use stark::traits::AIR; use stark::verifier::{IsStarkVerifier, Verifier}; +use crate::VmProof; +use crate::tables::MaxRowsConfig; use crate::tables::trace_builder::Traces; use crate::tables::types::{GoldilocksExtension, GoldilocksField}; use executor::elf::Elf; +use executor::vm::execution::Executor; // Import shared utilities use crate::VmAirs; @@ -83,6 +86,79 @@ fn prove_and_verify_vm_minimal(elf: &Elf, traces: &mut Traces) -> bool { ) } +/// Like [`crate::prove_with_options_and_inputs`] but with trimmed bitwise (TEST ONLY). +/// +/// ~100x faster than the production path. Same unsoundness caveats as +/// [`Traces::from_elf_and_logs_minimal`]. The full preprocessed bitwise +/// path is covered by `test_prove_elfs_all_instructions_64_full`. +fn prove_vm_minimal(elf_bytes: &[u8], private_inputs: &[u8], max_rows: &MaxRowsConfig) -> VmProof { + let proof_options = ProofOptions::default_test_options(); + let elf = Elf::load(elf_bytes).expect("ELF load"); + let executor = Executor::new(&elf, private_inputs.to_vec()).expect("executor"); + let result = executor.run().expect("execution"); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &result.logs, max_rows, private_inputs).unwrap(); + let table_counts = traces.table_counts(); + let airs = VmAirs::new( + &elf, + &proof_options, + true, + &traces.page_configs, + &table_counts, + ); + let runtime_page_ranges = traces.runtime_page_ranges(); + let proof = Prover::multi_prove( + airs.air_trace_pairs(&mut traces), + &mut DefaultTranscript::::new(&[]), + ) + .expect("prove"); + let num_private_input_pages = traces + .page_configs + .iter() + .filter(|c| c.is_private_input) + .count(); + VmProof { + proof, + runtime_page_ranges, + table_counts, + public_output: traces.public_output_bytes.clone(), + num_private_input_pages, + } +} + +/// Like [`crate::verify_with_options`] but matches the minimal bitwise AIR. +/// +/// Must be used to verify proofs from [`prove_vm_minimal`]. +fn verify_vm_minimal(vm_proof: &VmProof, elf_bytes: &[u8]) -> bool { + let proof_options = ProofOptions::default_test_options(); + let elf = Elf::load(elf_bytes).expect("ELF load"); + let page_configs = Traces::page_configs_from_elf_and_runtime( + &elf, + &vm_proof.runtime_page_ranges, + vm_proof.num_private_input_pages, + ); + let airs = VmAirs::new( + &elf, + &proof_options, + true, + &page_configs, + &vm_proof.table_counts, + ); + let air_refs = airs.air_refs(); + let expected_bus_balance = crate::compute_expected_commit_bus_balance( + &air_refs, + &vm_proof.proof, + &vm_proof.public_output, + ) + .expect("fingerprint collision in test"); + Verifier::multi_verify( + &air_refs, + &vm_proof.proof, + &mut DefaultTranscript::::new(&[]), + &expected_bus_balance, + ) +} + // ============================================================================= // Integration tests // ============================================================================= @@ -157,8 +233,9 @@ fn test_cpu_only_no_bus() { fn test_prove_elfs_sub_fast() { let _ = env_logger::builder().is_test(true).try_init(); let (elf, logs, _instructions) = run_asm_elf("sub"); - // Use from_elf_and_logs to get PAGE and REGISTER tables for Memory bus - let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + // Use from_elf_and_logs_minimal to get PAGE and REGISTER tables for Memory bus + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); assert!( prove_and_verify_vm_minimal(&elf, &mut traces), @@ -597,7 +674,8 @@ fn test_prove_elfs_test_xor_8() { #[test] fn test_prove_elfs_test_lb_lh_8() { let (elf, logs, _instructions) = run_asm_elf("test_lb_lh_8"); - let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); assert!( prove_and_verify_vm_minimal(&elf, &mut traces), "test_lb_lh_8 failed" @@ -607,7 +685,8 @@ fn test_prove_elfs_test_lb_lh_8() { #[test] fn test_prove_elfs_test_sb_sh_8() { let (elf, logs, _instructions) = run_asm_elf("test_sb_sh_8"); - let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); assert!( !traces.memws.is_empty(), "test_sb_sh_8 should produce MEMW rows for byte/halfword memory accesses" @@ -624,7 +703,8 @@ fn test_prove_elfs_test_sb_sh_8() { #[test] fn test_prove_elfs_lw_sw() { let (elf, logs, _instructions) = run_asm_elf("lw_sw"); - let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); assert!( !traces.memw_aligneds.is_empty(), "lw_sw should produce MEMW_A rows for aligned word accesses" @@ -644,7 +724,8 @@ fn test_prove_elfs_lw_sw() { #[test] fn test_prove_elfs_test_memw_split_ts() { let (elf, logs, _instructions) = run_asm_elf("test_memw_split_ts"); - let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); assert!( !traces.memws.is_empty(), "test_memw_split_ts should produce MEMW rows (split old_timestamps from sb+sb+lh)" @@ -683,7 +764,8 @@ fn test_prove_elfs_all_branches_16() { #[test] fn test_prove_elfs_all_loadstore_32() { let (elf, logs, _instructions) = run_asm_elf("all_loadstore_32"); - let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); assert!( prove_and_verify_vm_minimal(&elf, &mut traces), "all_loadstore_32 failed" @@ -716,6 +798,100 @@ fn test_prove_elfs_all_instructions_64() { ); } +#[test] +fn test_prove_elfs_keccak() { + let _ = env_logger::builder().is_test(true).try_init(); + + let (elf, logs, _instructions) = run_asm_elf("test_keccak"); + // Must use from_elf_and_logs (not from_logs_minimal) because keccak accesses + // RAM (stack memory), which requires PAGE tables for Memory bus balance. + let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + + assert!( + prove_and_verify_vm_minimal(&elf, &mut traces), + "keccak prove/verify failed" + ); +} + +#[test] +fn test_prove_elfs_keccak_multi_call() { + let _ = env_logger::builder().is_test(true).try_init(); + + let elf_bytes = crate::test_utils::asm_elf_bytes("test_keccak_multi"); + let elf = Elf::load(&elf_bytes).expect("Failed to load ELF"); + let executor = + executor::vm::execution::Executor::new(&elf, vec![]).expect("Failed to create executor"); + let result = executor.run().expect("Failed to run program"); + + // The guest initializes lane[i] = i + 1 and applies keccak-f[1600] three times. + // Cross-check the committed output against tiny-keccak's independent + // implementation of the permutation. + let mut expected_state: [u64; 25] = core::array::from_fn(|i| (i + 1) as u64); + for _ in 0..3 { + tiny_keccak::keccakf(&mut expected_state); + } + let mut expected_bytes = Vec::with_capacity(200); + for lane in expected_state { + expected_bytes.extend_from_slice(&lane.to_le_bytes()); + } + + assert_eq!( + result.return_values.memory_values, expected_bytes, + "committed state must match tiny-keccak after 3 keccak-f[1600] calls" + ); + + let mut traces = + Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + assert_eq!( + traces.public_output_bytes, + result.return_values.memory_values + ); + + assert!( + prove_and_verify_vm_minimal(&elf, &mut traces), + "keccak multi-call prove/verify failed" + ); +} + +/// Verifier REJECTS a forged trace where an addr byte cell is set to a +/// non-byte field element. +/// +/// Without the IS_BYTE range checks on addr(0..7), an attacker could keep +/// `addr_lo = b0 + 256·b1 + 65536·b2 + 2^24·b3` equal to an unaligned target +/// address as a field element while setting addr(0)=0 (passing the AndByte +/// alignment check) and folding the carry into addr(1) as a non-byte +/// FE-element. This test asserts that mutating addr(1) to a non-byte value +/// unbalances the verifier's bus checks and the proof is rejected. +#[test] +fn test_prove_elfs_keccak_unaligned_state_addr() { + use crate::tables::keccak::cols as keccak_cols; + + let _ = env_logger::builder().is_test(true).try_init(); + + let elf_bytes = crate::test_utils::asm_elf_bytes("test_keccak_multi"); + let elf = Elf::load(&elf_bytes).expect("Failed to load ELF"); + let executor = + executor::vm::execution::Executor::new(&elf, vec![]).expect("Failed to create executor"); + let result = executor.run().expect("Failed to run program"); + let mut traces = + Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + + // Tamper the first real keccak row: replace addr(1) (a byte cell) with a + // value outside [0, 256). The new IS_BYTE bus sender will emit this + // value with multiplicity MU=1; the IS_BYTE preprocessed table only + // contains 0..256, so the bus cannot balance. + traces.keccak.main_table.set( + 0, + keccak_cols::addr(1), + FieldElement::::from(257u64), + ); + + assert!( + !prove_and_verify_vm_minimal(&elf, &mut traces), + "Verifier must reject a keccak proof whose addr cells are not bytes" + ); +} + #[test] fn test_prove_elfs_test_commit_4() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_commit_4"); @@ -732,7 +908,7 @@ fn test_prove_elfs_test_commit_4() { ); let mut traces = - Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + Traces::from_elf_and_logs_minimal(&elf, &result.logs, &Default::default(), &[]).unwrap(); assert_eq!( traces.public_output_bytes, result.return_values.memory_values @@ -758,7 +934,7 @@ fn test_prove_elfs_test_commit_4_wrong_pages_rejected() { executor::vm::execution::Executor::new(&elf, vec![]).expect("Failed to create executor"); let result = executor.run().expect("Failed to run program"); let mut traces = - Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + Traces::from_elf_and_logs_minimal(&elf, &result.logs, &Default::default(), &[]).unwrap(); // Prover uses correct page configs let table_counts = traces.table_counts(); @@ -1459,7 +1635,8 @@ fn test_debug_memory_tokens_sb_sh() { #[test] fn test_deep_stack_passes() { let (elf, logs, _instructions) = run_asm_elf("deep_stack"); - let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); assert!( prove_and_verify_vm_minimal(&elf, &mut traces), @@ -1481,7 +1658,7 @@ fn test_deep_stack_runtime_pages_roundtrip() { executor::vm::execution::Executor::new(&elf, vec![]).expect("Failed to create executor"); let result = executor.run().expect("Failed to run program"); let mut traces = - Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + Traces::from_elf_and_logs_minimal(&elf, &result.logs, &Default::default(), &[]).unwrap(); let runtime_page_ranges = traces.runtime_page_ranges(); let table_counts = traces.table_counts(); @@ -1541,7 +1718,7 @@ fn test_deep_stack_missing_pages_rejected() { executor::vm::execution::Executor::new(&elf, vec![]).expect("Failed to create executor"); let result = executor.run().expect("Failed to run program"); let mut traces = - Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + Traces::from_elf_and_logs_minimal(&elf, &result.logs, &Default::default(), &[]).unwrap(); // Prover uses correct page configs (auto-detected from MemoryState) let table_counts = traces.table_counts(); @@ -1592,7 +1769,8 @@ fn test_deep_stack_missing_pages_rejected() { #[test] fn test_heap_alloc_passes() { let (elf, logs, _instructions) = run_asm_elf("heap_alloc"); - let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); // Verify runtime_page_ranges includes the heap page let ranges = traces.runtime_page_ranges(); @@ -1621,7 +1799,7 @@ fn test_heap_alloc_runtime_pages_roundtrip() { executor::vm::execution::Executor::new(&elf, vec![]).expect("Failed to create executor"); let result = executor.run().expect("Failed to run program"); let mut traces = - Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + Traces::from_elf_and_logs_minimal(&elf, &result.logs, &Default::default(), &[]).unwrap(); let runtime_page_ranges = traces.runtime_page_ranges(); let table_counts = traces.table_counts(); @@ -1796,7 +1974,7 @@ fn test_crafted_zero_count_proof_must_not_verify() { let airs = VmAirs::new(&elf, &proof_options, true, &[], &zero_counts); let verifier_air_refs = airs.air_refs(); - assert_eq!(verifier_air_refs.len(), 5); + assert_eq!(verifier_air_refs.len(), 8); let mut bitwise_trace = crate::tables::bitwise::generate_bitwise_trace(); @@ -1832,11 +2010,9 @@ fn test_crafted_zero_count_proof_must_not_verify() { #[test] fn test_small_max_rows_splits_tables() { let elf_bytes = crate::test_utils::asm_elf_bytes("all_instructions_64"); - let proof_options = ProofOptions::default_test_options(); let max_rows = crate::tables::MaxRowsConfig::small(); - let vm_proof = crate::prove_with_options(&elf_bytes, &proof_options, &max_rows) - .expect("Prover should succeed with small max_rows"); + let vm_proof = prove_vm_minimal(&elf_bytes, &[], &max_rows); // With 2^5 max rows and 64+ instructions, tables should have multiple chunks. assert!( @@ -1845,9 +2021,10 @@ fn test_small_max_rows_splits_tables() { vm_proof.table_counts.cpu ); - let verified = crate::verify_with_options(&vm_proof, &elf_bytes, &proof_options) - .expect("Verifier should not error"); - assert!(verified, "Proof with small max_rows should verify"); + assert!( + verify_vm_minimal(&vm_proof, &elf_bytes), + "Proof with small max_rows should verify" + ); } // ============================================================================= @@ -1912,8 +2089,11 @@ fn test_verify_rejects_inflated_table_counts() { #[test] fn test_prove_wsuffix_64bit() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_wsuffix_64bit"); - let result = crate::prove_and_verify(&elf_bytes).expect("prove_and_verify failed"); - assert!(result, "W-suffix 64-bit register test should verify"); + let vm_proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); + assert!( + verify_vm_minimal(&vm_proof, &elf_bytes), + "W-suffix 64-bit register test should verify" + ); } /// Proves a minimal Rust std program that uses `init_allocator()` and @@ -1930,9 +2110,9 @@ fn test_prove_allocator_minimal_reproducer() { let elf_bytes = std::fs::read(workspace_root.join("executor/program_artifacts/rust/allocator.elf")) .expect("allocator.elf not found — run `make compile-programs-rust`"); - let proof = crate::prove(&elf_bytes).expect("prove should succeed"); + let proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); assert!( - crate::verify(&proof, &elf_bytes).expect("verify should not error"), + verify_vm_minimal(&proof, &elf_bytes), "allocator.elf should verify" ); assert_eq!(proof.public_output, b"Hello World"); @@ -1949,9 +2129,9 @@ fn test_pure_commit_rust() { let elf_bytes = std::fs::read(workspace_root.join("executor/program_artifacts/rust/pure_commit.elf")) .expect("pure_commit.elf not found — run `make compile-programs-rust`"); - let proof = crate::prove(&elf_bytes).expect("prove should succeed"); + let proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); assert!( - crate::verify(&proof, &elf_bytes).expect("verify should not error"), + verify_vm_minimal(&proof, &elf_bytes), "pure_commit.elf should verify" ); assert_eq!(proof.public_output, vec![0xAA, 0xBB, 0xCC, 0xDD]); @@ -1974,12 +2154,8 @@ fn test_prove_with_input_empty() { fn test_prove_private_input_xpage() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_private_input_xpage"); let input: Vec = (0u8..16).collect(); - let proof = - crate::prove_with_inputs(&elf_bytes, &input).expect("prove_with_inputs should succeed"); - assert!( - crate::verify(&proof, &elf_bytes).expect("verify should not error"), - "proof should verify" - ); + let proof = prove_vm_minimal(&elf_bytes, &input, &Default::default()); + assert!(verify_vm_minimal(&proof, &elf_bytes), "proof should verify"); assert_eq!(proof.public_output, input[4..12].to_vec()); } @@ -1991,12 +2167,34 @@ fn test_prove_private_input_different_values() { 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, ]; - let proof = crate::prove_with_inputs(&elf_bytes, &input).expect("prove"); + let proof = prove_vm_minimal(&elf_bytes, &input, &Default::default()); + assert!(verify_vm_minimal(&proof, &elf_bytes), "proof should verify"); + assert_eq!(proof.public_output, input[4..12].to_vec()); +} + +/// End-to-end: EF zkVM IO interface — demo guest reads its private input via +/// `read_input` and emits it back through TWO `write_output` calls. The +/// COMMIT AIR's running `x254` index concatenates them; the resulting proof's +/// `public_output` must equal the original input. +#[test] +fn test_prove_ef_io_demo_concatenates() { + let workspace_root = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .parent() + .expect("workspace root") + .to_path_buf(); + let elf_bytes = + std::fs::read(workspace_root.join("executor/program_artifacts/rust/ef_io_demo.elf")) + .expect("ef_io_demo.elf not found — run `make compile-programs-rust`"); + let input: &[u8] = b"hello world!"; + let proof = crate::prove_with_inputs(&elf_bytes, input).expect("prove should succeed"); assert!( - crate::verify(&proof, &elf_bytes).expect("verify"), - "proof should verify" + crate::verify(&proof, &elf_bytes).expect("verify should not error"), + "ef_io_demo should verify" + ); + assert_eq!( + proof.public_output, input, + "two write_output calls must concatenate" ); - assert_eq!(proof.public_output, input[4..12].to_vec()); } /// End-to-end: Rust std program with private input. @@ -2010,9 +2208,9 @@ fn test_prove_commit_sum() { std::fs::read(workspace_root.join("executor/program_artifacts/rust/commit_sum.elf")) .expect("commit_sum.elf not found — run `make compile-programs-rust`"); let input = &[3u8, 5u8]; - let proof = crate::prove_with_inputs(&elf_bytes, input).expect("prove should succeed"); + let proof = prove_vm_minimal(&elf_bytes, input, &Default::default()); assert!( - crate::verify(&proof, &elf_bytes).expect("verify should not error"), + verify_vm_minimal(&proof, &elf_bytes), "commit_sum should verify" ); assert_eq!(proof.public_output, vec![8u8]); @@ -2128,7 +2326,7 @@ fn test_verify_rejects_private_input_with_tampered_public_output() { let vm_proof = crate::prove_with_inputs(&elf_bytes, &input).expect("prove should succeed"); assert!( - crate::verify(&vm_proof, &elf_bytes).expect("verify"), + crate::verify(&vm_proof, &elf_bytes).expect("verify should not error"), "Baseline must verify" ); @@ -2177,8 +2375,11 @@ fn test_proof_does_not_contain_private_input_field() { #[test] fn test_addiw_neg_immediate() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_addiw_neg"); - let result = crate::prove_and_verify(&elf_bytes).expect("prove_and_verify failed"); - assert!(result, "addiw with negative immediate should verify"); + let proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); + assert!( + verify_vm_minimal(&proof, &elf_bytes), + "addiw with negative immediate should verify" + ); } /// Regression test: both main and aux field element counts must be nonzero for any real ELF. diff --git a/spec/README.md b/spec/README.md index 127e528c8..da844e801 100644 --- a/spec/README.md +++ b/spec/README.md @@ -1,11 +1,19 @@ -# LambdaVM specification -This repository contains specification for [`LambdaVM`](https://github.com/yetanotherco/lambda_vm). -The specification is written in [`Typst`](https://typst.app/) and can be rendered by [`shiroa`](https://myriad-dreamin.github.io/shiroa/) as either a file (pdf) or a wiki (html). +# Lambda VM Specification -## Installation & Development setup -1. [Install `Typst`](https://github.com/typst/typst?tab=readme-ov-file#installation). -2. [Install `shiroa`](https://myriad-dreamin.github.io/shiroa/guide/installation.html). -3. Clone this repository. -4. Open the repository in a terminal and execute `shiroa serve`. +Formal specification of the Lambda VM. Covers the per-chip AIR constraints (CPU, decode, bitwise, branch, LT, shift, MUL, DVRM, MEMW, LOAD, page, register, halt, commit, keccak), the memory argument, and the LogUp lookup framework that links the tables. -At this point, the wiki version is hosted locally and is actively updated as you modify the specification files. +The specification is written in [Typst](https://typst.app/) and rendered as either a PDF or a browsable HTML wiki using [shiroa](https://myriad-dreamin.github.io/shiroa/). + +## Rendering it locally + +1. [Install Typst](https://github.com/typst/typst?tab=readme-ov-file#installation). +2. [Install shiroa](https://myriad-dreamin.github.io/shiroa/guide/installation.html). +3. From this directory, run: + + ```sh + shiroa serve + ``` + + shiroa will host the HTML wiki locally and live-reload as you edit the `.typ` source files. + +To produce a PDF instead, see the shiroa documentation for the `build` command. diff --git a/syscalls/README.md b/syscalls/README.md new file mode 100644 index 000000000..fa5758741 --- /dev/null +++ b/syscalls/README.md @@ -0,0 +1,52 @@ +# Lambda VM Syscalls (Guest SDK) + +Guest-side library for programs that run inside the Lambda VM. Provides the syscalls and the default entry point that let a Rust program interact with the host: read private input, commit public output, halt, and invoke precompiles. + +Published as `lambda-vm-syscalls`. Intended to be used from RISC-V (RV64IM) guest binaries cross-compiled with the toolchain described in the root [`README.md`](../README.md). + +## What it provides + +| Function | Purpose | +|---|---| +| `commit(bytes: &[u8])` | Append bytes to the **public output** that the verifier checks. | +| `get_private_input() -> Vec` | Read the host-supplied private input bytes (memory-mapped at `0xFF000000`). | +| `sys_halt() -> !` | Terminate execution cleanly. Called automatically after `main` by the default entry point. | +| `keccak_permute(state: &mut [u64; 25])` | Keccak-f[1600] permutation precompile. | + +The crate also provides a default `_start` that initialises the allocator, calls `main`, and halts. + +> **Note:** the `print_string` syscall is temporarily unavailable — calling it in a guest will cause proof verification to fail. Tracked as a follow-up. + +## Example + +A minimal guest that reads private input and commits a (non-secret) summary of it: + +```rust +use lambda_vm_syscalls::syscalls; + +pub fn main() { + let input = syscalls::get_private_input(); + + // Anything passed to `commit` becomes part of the proof's public output. + // Don't echo private input here — commit a derived value instead. + let len = (input.len() as u32).to_le_bytes(); + syscalls::commit(&len); +} +``` + +See [`executor/programs/rust/`](../executor/programs/rust/) for more example guests (`fibonacci`, `keccak`, `hashmap`, …). + +## Building a guest + +Guests are compiled with a pinned nightly toolchain for the custom RISC-V target `riscv64im-lambda-vm-elf.json`. The simplest path is to drop a new project under `executor/programs/rust//` and run `make compile-programs-rust` from the repo root. + +See the root [`README.md`](../README.md) for the full toolchain setup (sysroot, nightly pin, target spec). + +## Unsupported `std` functions + +The following functions are stubbed and **panic at runtime** if called — Lambda VM does not provide stdin or command-line arguments: + +- `sys_read` (so `io::Read` for `Stdin` is not available) +- `sys_argc`, `sys_argv` + +To pass data into a guest, use `get_private_input()` instead. diff --git a/syscalls/src/ef_io.rs b/syscalls/src/ef_io.rs new file mode 100644 index 000000000..dabf7818d --- /dev/null +++ b/syscalls/src/ef_io.rs @@ -0,0 +1,84 @@ +//! EF zkVM IO interface: +//! +//! Two C-callable functions that match the EF standard so portable applications +//! compile unchanged across zkVMs: +//! +//! - `read_input`: returns a zero-copy pointer + size to the private input. +//! - `write_output`: appends bytes to the public output. Multiple calls +//! concatenate. +//! +//! On Lambda VM these map to: +//! - `read_input` → memory-mapped private input region at `0xFF000000` +//! (4-byte LE length prefix at base, data at `+4`). +//! - `write_output` → ECALL #64 (Commit). The trace builder maintains a +//! running commitment index in synthetic register `x254`, so multiple +//! ECALLs naturally concatenate at the proof level. + +#[cfg(target_arch = "riscv64")] +use core::arch::asm; + +#[cfg(target_arch = "riscv64")] +use crate::syscalls::{PRIVATE_INPUT_START, SyscallNumbers}; + +/// EF IO: return a zero-copy pointer and size for the private input. +/// +/// Per the spec this function is idempotent, callable multiple times, and +/// cannot fail. If `buf_size` is 0, the value of `buf_ptr` is unspecified. +/// Privacy of the input is the guest's responsibility; the VM does not +/// enforce it. +/// +/// # Safety +/// +/// `buf_ptr` and `buf_size` must be valid, writable pointers. +#[cfg(target_arch = "riscv64")] +#[unsafe(no_mangle)] +pub unsafe extern "C" fn read_input(buf_ptr: *mut *const u8, buf_size: *mut usize) { + unsafe { + let len_ptr = PRIVATE_INPUT_START as *const u32; + let len = core::ptr::read_volatile(len_ptr) as usize; + *buf_ptr = (PRIVATE_INPUT_START + 4) as *const u8; + *buf_size = len; + } +} + +/// EF IO: append `size` bytes from `output` to the public output. +/// +/// Multiple calls concatenate. Per the spec this function cannot fail; in +/// practice the executor enforces a total-output cap (see +/// `MAX_PUBLIC_OUTPUT_TOTAL_SIZE` in `executor::vm::memory`). Exceeding it +/// causes the executor to return an error and abort proving — not a graceful +/// failure mode at the C boundary, but consistent with "cannot fail" for +/// well-formed programs that stay under the limit. +/// +/// # Safety +/// +/// `output` must point to `size` readable bytes within guest memory. +#[cfg(target_arch = "riscv64")] +#[unsafe(no_mangle)] +pub unsafe extern "C" fn write_output(output: *const u8, size: usize) { + unsafe { + asm!( + "ecall", + in("a0") 1usize, // fd = 1 (stdout) — required by the COMMIT chip + in("a1") output, + in("a2") size, + in("a7") SyscallNumbers::Commit as usize, + ); + } +} + +/// Host-side stub — Lambda VM's IO interface is only implemented for the +/// `riscv64` guest target. Not exported with C linkage on host so the +/// generic name doesn't collide with C dependencies in test builds. +#[cfg(not(target_arch = "riscv64"))] +pub fn read_input(_buf_ptr: *mut *const u8, _buf_size: *mut usize) { + unimplemented!("read_input is only implemented for riscv64 targets"); +} + +/// Host-side stub — Lambda VM's IO interface is only implemented for the +/// `riscv64` guest target. Not exported with C linkage on host so the +/// generic name doesn't collide with C dependencies in test builds. +#[cfg(not(target_arch = "riscv64"))] +pub fn write_output(_output: *const u8, _size: usize) { + unimplemented!("write_output is only implemented for riscv64 targets"); +} diff --git a/syscalls/src/lib.rs b/syscalls/src/lib.rs index 378257d18..79a420181 100644 --- a/syscalls/src/lib.rs +++ b/syscalls/src/lib.rs @@ -1,4 +1,5 @@ pub mod allocator; +pub mod ef_io; pub mod entrypoint; pub mod random; pub mod syscalls; diff --git a/syscalls/src/syscalls.rs b/syscalls/src/syscalls.rs index ae0315ff5..6451828c6 100644 --- a/syscalls/src/syscalls.rs +++ b/syscalls/src/syscalls.rs @@ -6,16 +6,20 @@ use core::arch::asm; /// The host pre-loads the input; the guest reads directly (no ecall). /// Must match `executor::vm::memory::PRIVATE_INPUT_START_INDEX`. #[cfg(target_arch = "riscv64")] -const PRIVATE_INPUT_START: usize = 0xFF000000; +pub const PRIVATE_INPUT_START: usize = 0xFF000000; #[cfg(target_arch = "riscv64")] -enum SyscallNumbers { +pub enum SyscallNumbers { Print = 1, Panic = 2, Commit = 64, Halt = 93, } +/// Syscall number for KeccakPermute (u64::MAX - 1). +#[cfg(target_arch = "riscv64")] +const KECCAK_SYSCALL_NUMBER: usize = usize::MAX - 1; + #[cfg(target_arch = "riscv64")] /// This is a template for printing in the vm pub fn print_string(s: &str) { @@ -120,6 +124,24 @@ pub fn sys_halt() -> ! { unimplemented!("syscalls are only implemented for riscv64 targets"); } +#[cfg(target_arch = "riscv64")] +/// Apply the Keccak-f[1600] permutation to a 25-element u64 state in-place. +pub fn keccak_permute(state: &mut [u64; 25]) { + unsafe { + asm!( + "ecall", + in("a0") state.as_mut_ptr(), + in("a7") KECCAK_SYSCALL_NUMBER, + ) + } +} + +#[cfg(not(target_arch = "riscv64"))] +/// Apply the Keccak-f[1600] permutation to a 25-element u64 state in-place. +pub fn keccak_permute(_state: &mut [u64; 25]) { + unimplemented!("syscalls are only implemented for riscv64 targets"); +} + // ============================================================================= // Stub implementations for unsupported std functions // These functions are required by Rust's std zkvm module but are not supported From 01aa5e4cada40bb68b3449f98140f2aa03b1effc Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 20 May 2026 15:50:33 -0300 Subject: [PATCH 07/16] comments fix --- crypto/crypto/src/merkle_tree/merkle.rs | 28 +- crypto/stark/src/gpu_lde.rs | 1056 +++++++++++------------ crypto/stark/src/prover.rs | 147 ++-- prover/src/lib.rs | 5 - prover/src/tests/prove_elfs_tests.rs | 2 +- 5 files changed, 616 insertions(+), 622 deletions(-) diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index 9077d272a..5cc203550 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -123,16 +123,14 @@ where Self::build_from_hashed_leaves(hashed_leaves) } - /// Build a `MerkleTree` from an already-filled node vector whose layout - /// matches [`build_from_hashed_leaves`] output: + /// Useful for handing a GPU-built tree to the stark prover. + /// Performs no hashing, the caller is responsible for the layout's + /// cryptographic correctness. /// + /// Expected layout (matches [`build_from_hashed_leaves`]): /// - `nodes.len() == 2 * leaves_len - 1` where `leaves_len` is a power of two /// - `nodes[0]` is the root /// - `nodes[leaves_len - 1 .. 2*leaves_len - 1]` are the leaves - /// - /// Useful when the tree was constructed elsewhere (e.g. on a GPU) and - /// the caller just wants to hand the finished layout to the stark prover. - /// Performs no hashing. pub fn from_precomputed_nodes(nodes: Vec) -> Option { if nodes.is_empty() { return None; @@ -143,8 +141,24 @@ where if !(total + 1).is_power_of_two() { return None; } + // Debug-only integrity spot-check: the root must equal hash(left, right). + // Catches GPU correctness regressions in CI without paying for a full + // tree walk on every call. + #[cfg(debug_assertions)] + if total >= 3 { + let expected_root = B::hash_new_parent(&nodes[1], &nodes[2]); + debug_assert!( + nodes[ROOT] == expected_root, + "from_precomputed_nodes: root does not hash from children", + ); + } let root = nodes[ROOT].clone(); - Some(MerkleTree { root, nodes }) + Some(MerkleTree { + root, + nodes, + #[cfg(feature = "disk-spill")] + mmap_backing: None, + }) } /// Create a Merkle tree from pre-hashed leaf nodes. diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index 05e281549..6d9a0de0a 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -6,7 +6,7 @@ //! launch overhead dominates. Produces the same natural-order, non-canonical //! LDE evaluations as the CPU path. -use core::any::type_name; +use std::any::TypeId; use math::field::element::FieldElement; use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; @@ -47,109 +47,291 @@ pub fn reset_gpu_lde_calls() { GPU_LDE_CALLS.store(0, std::sync::atomic::Ordering::Relaxed); } +/// Reset all GPU call counters at once. Useful between bench warm-up and +/// profiled passes so the numbers reported aren't doubled by the warm-up. +pub fn reset_all_gpu_call_counters() { + use std::sync::atomic::Ordering::Relaxed; + GPU_LDE_CALLS.store(0, Relaxed); + GPU_EXTEND_HALVES_CALLS.store(0, Relaxed); + GPU_LEAF_HASH_CALLS.store(0, Relaxed); + GPU_MERKLE_TREE_CALLS.store(0, Relaxed); +} + pub(crate) static GPU_EXTEND_HALVES_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); pub fn gpu_extend_halves_calls() -> u64 { GPU_EXTEND_HALVES_CALLS.load(std::sync::atomic::Ordering::Relaxed) } -/// Try to GPU-batch all columns in one pass. -/// -/// Only engaged for Goldilocks-base tables whose LDE size is above the -/// threshold. The prover's `expand_columns_to_lde` hands us every column of -/// one table at once; those columns all share twiddles and coset weights so -/// they can be processed in a single batched pipeline on one stream. -/// -/// Returns `true` if the batch was handled on GPU (and `columns` now contains -/// the LDE evaluations). Returns `false` to let the caller run the per-column -/// CPU fallback. -#[inline] -pub(crate) fn try_expand_columns_batched( - columns: &mut [Vec>], - blowup_factor: usize, - weights: &[FieldElement], -) -> bool +// ============================================================================ +// Shared dispatch helpers +// ============================================================================ +// +// Every `try_expand_*` variant runs the same prologue: empty-check, threshold +// check, two TypeId checks, equal-length check, and a column-to-u64 cast under +// the type-confirmed precondition. Centralising that here keeps each variant +// short and means a future change to (say) the threshold logic is one edit. + +/// Outcome of validating an incoming `columns` slice against the GPU dispatch +/// preconditions. +enum LayoutDispatch { + /// `columns` is empty — caller returns its own "trivially done" value + /// (`true` for `bool` callers, `Some(Vec::new())` for `Option` callers). + Empty, + /// GPU path doesn't apply (below threshold, wrong types, ragged columns). + /// Caller returns its own "fall through to CPU" value (`false`/`None`). + Skip, + /// GPU path applies. `n` is the per-column input length; `lde_size = n * + /// blowup_factor` (saturating). + Run { n: usize, lde_size: usize }, +} + +/// Validate preconditions for the base-field batched GPU path: every column +/// must be Goldilocks base-field of equal length, the LDE size must clear the +/// threshold. +fn check_base_layout(columns: &[Vec>], blowup_factor: usize) -> LayoutDispatch where - F: IsField, - E: IsField, + F: IsField + 'static, + E: IsField + 'static, { if columns.is_empty() { - return true; // nothing to do — same as CPU path + return LayoutDispatch::Empty; } let n = columns[0].len(); let lde_size = n.saturating_mul(blowup_factor); if lde_size < gpu_lde_threshold() { - return false; + return LayoutDispatch::Skip; } - if type_name::() != type_name::() { - return false; + if TypeId::of::() != TypeId::of::() { + return LayoutDispatch::Skip; + } + if TypeId::of::() != TypeId::of::() { + return LayoutDispatch::Skip; } - // All columns within one call must be the same size (invariant of the - // caller), but double-check before unsafe extraction. if columns.iter().any(|c| c.len() != n) { - return false; + return LayoutDispatch::Skip; } + LayoutDispatch::Run { n, lde_size } +} - // Ext3 fast path: decompose each ext3 column into its 3 base components - // and dispatch to the base-field batched NTT with 3×M logical columns. - // Butterflies with a base-field twiddle act componentwise on ext3, so - // this is exactly equivalent to running the NTT in the extension field. - if type_name::() == type_name::() { - return try_expand_columns_batched_ext3::(columns, blowup_factor, weights); +/// Validate preconditions for the ext3 batched GPU path: every column must be +/// `Degree3GoldilocksExtensionField` of equal length, weights must be over +/// `GoldilocksField`, LDE size must clear the threshold. +fn check_ext3_layout(columns: &[Vec>], blowup_factor: usize) -> LayoutDispatch +where + F: IsField + 'static, + E: IsField + 'static, +{ + if columns.is_empty() { + return LayoutDispatch::Empty; } - - if type_name::() != type_name::() { - return false; + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return LayoutDispatch::Skip; + } + if TypeId::of::() != TypeId::of::() { + return LayoutDispatch::Skip; + } + if TypeId::of::() != TypeId::of::() { + return LayoutDispatch::Skip; + } + if columns.iter().any(|c| c.len() != n) { + return LayoutDispatch::Skip; } + LayoutDispatch::Run { n, lde_size } +} - // Extract raw u64 slices. SAFETY: type_name above confirms - // `E == GoldilocksField`, so `FieldElement` wraps u64 one-to-one. - let raw_columns: Vec> = columns +/// Materialise base-field columns as owned `Vec>` for the GPU input +/// slice list. +/// +/// SAFETY: caller must have established `E == GoldilocksField` (e.g. via +/// [`check_base_layout`]). Each `FieldElement` is then a `#[repr(transparent)]` +/// wrapper over `u64`. +unsafe fn columns_to_u64_base(columns: &[Vec>]) -> Vec> { + columns .iter() .map(|col| { col.iter() .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) .collect() }) - .collect(); - let weights_u64: Vec = weights + .collect() +} + +/// Materialise ext3 columns as owned `Vec>` (de-interleaved into raw +/// `[u64; 3]` lanes per element) for the GPU input slice list. +/// +/// SAFETY: caller must have established `E == Degree3GoldilocksExtensionField` +/// (e.g. via [`check_ext3_layout`]). Each `FieldElement` is then a +/// `#[repr(transparent)]` wrapper over `[u64; 3]`. +unsafe fn columns_to_u64_ext3(columns: &[Vec>]) -> Vec> { + columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }) + .collect() +} + +/// Materialise weights as a raw `Vec`. +/// +/// SAFETY: caller must have established `F == GoldilocksField`. +unsafe fn weights_to_u64(weights: &[FieldElement]) -> Vec { + weights .iter() .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) - .collect(); + .collect() +} - // Pre-size caller Vecs to lde_size so the GPU path can write directly - // into the same backing allocation the caller already holds. This skips - // the intermediate `Vec>` allocation (which would page-fault - // per column) and is the main reason `coset_lde_batch_base_into` exists. +/// Pre-size each column to `lde_size` and view it as a `&mut [u64]` of length +/// `lde_size` (base-field, single-u64 layout). Asserts capacity hard so a +/// caller regression can't quietly UB in release builds. +/// +/// SAFETY: caller must have established `E == GoldilocksField`. +unsafe fn presize_and_view_base( + columns: &mut [Vec>], + lde_size: usize, +) -> Vec<&mut [u64]> { for col in columns.iter_mut() { - // SAFETY: set_len is valid here because capacity is already >= - // lde_size (the caller sized columns via `extract_columns_main(lde_size)`) - // and we're about to overwrite every slot via the GPU copy below. - debug_assert!(col.capacity() >= lde_size); + assert!( + col.capacity() >= lde_size, + "col capacity {} < lde_size {}", + col.capacity(), + lde_size + ); + // SAFETY: assert above guarantees capacity; the GPU path overwrites + // every slot before any reader sees the new length. unsafe { col.set_len(lde_size) }; } - - // Borrow each caller Vec as a raw `&mut [u64]` slice; safe because each - // FieldElement aliases a single u64 when E == GoldilocksField. - let mut raw_outputs: Vec<&mut [u64]> = columns + columns .iter_mut() .map(|col| { let ptr = col.as_mut_ptr() as *mut u64; let len = col.len(); - // SAFETY: see above — single-u64 layout, caller still owns. + // SAFETY: single-u64 layout, caller still owns the backing alloc. + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect() +} + +/// Same as [`presize_and_view_base`] but for ext3 columns: each view is +/// `3 * lde_size` u64s (de-interleaved lanes). +/// +/// SAFETY: caller must have established `E == Degree3GoldilocksExtensionField`. +unsafe fn presize_and_view_ext3( + columns: &mut [Vec>], + lde_size: usize, +) -> Vec<&mut [u64]> { + for col in columns.iter_mut() { + assert!( + col.capacity() >= lde_size, + "col capacity {} < lde_size {}", + col.capacity(), + lde_size + ); + // SAFETY: assert above + GPU path overwrites every slot. + unsafe { col.set_len(lde_size) }; + } + columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + // SAFETY: ext3 `[u64; 3]` layout, caller still owns the backing. unsafe { core::slice::from_raw_parts_mut(ptr, len) } }) - .collect(); + .collect() +} +/// Truncate each column back to `n` (trace size) after a GPU error so the +/// CPU fallback (which reads `buffer.len()` as the trace size) runs cleanly. +/// Safe because `math_cuda` writes outputs only at the final host copy, post- +/// synchronize; any `Err` returns before that copy, leaving `columns[0..n]` +/// untouched. +fn restore_columns_on_err(columns: &mut [Vec>], n: usize) { + for col in columns.iter_mut() { + col.truncate(n); + } +} + +/// Allocate the `[u8; 32]` Merkle node buffer for a tree of `lde_size` leaves +/// and return both the node `Vec` (length-initialised, contents undefined) and +/// a `&mut [u8]` byte view of total length `total_nodes * 32`. Returns `None` +/// if the layout would be invalid (`lde_size < 2` or the byte length +/// overflows). The caller must overwrite every byte via the GPU D2H below. +fn alloc_merkle_nodes(lde_size: usize) -> Option<(Vec<[u8; 32]>, usize)> { + if lde_size < 2 { + return None; + } + let total_nodes = 2usize.saturating_mul(lde_size).checked_sub(1)?; + let _byte_len = total_nodes.checked_mul(32)?; + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + // SAFETY: every byte will be overwritten via the GPU D2H before the + // contents are read; the caller computes the byte-length view from the + // returned `nodes` Vec using `total_nodes.checked_mul(32)`. + unsafe { nodes.set_len(total_nodes) }; + Some((nodes, total_nodes)) +} + +/// Try to GPU-batch all columns in one pass. +/// +/// Only engaged for Goldilocks-base tables whose LDE size is above the +/// threshold. The prover's `expand_columns_to_lde` hands us every column of +/// one table at once; those columns all share twiddles and coset weights so +/// they can be processed in a single batched pipeline on one stream. +/// +/// Returns `true` if the batch was handled on GPU (and `columns` now contains +/// the LDE evaluations). Returns `false` to let the caller run the per-column +/// CPU fallback. +#[inline] +pub(crate) fn try_expand_columns_batched( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> bool +where + F: IsField + 'static, + E: IsField + 'static, +{ + // Ext3 fast path: decompose each ext3 column into its 3 base components + // and dispatch to the base-field batched NTT with 3×M logical columns. + // Butterflies with a base-field twiddle act componentwise on ext3, so + // this is exactly equivalent to running the NTT in the extension field. + if TypeId::of::() == TypeId::of::() { + return try_expand_columns_batched_ext3::(columns, blowup_factor, weights); + } + + let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { + LayoutDispatch::Empty => return true, // nothing to do — same as CPU path + LayoutDispatch::Skip => return false, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + + // SAFETY: layout-checked above (`E == GoldilocksField`, `F == GoldilocksField`). + let raw_columns = unsafe { columns_to_u64_base::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); - math_cuda::lde::coset_lde_batch_base_into( - &slices, - blowup_factor, - &weights_u64, - &mut raw_outputs, - ) - .expect("GPU batched coset LDE failed"); + GPU_LDE_CALLS.fetch_add(num_columns as u64, std::sync::atomic::Ordering::Relaxed); + let gpu_result = { + let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; + math_cuda::lde::coset_lde_batch_base_into( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + ) + }; + if gpu_result.is_err() { + // Restore columns to trace length for the CPU fallback. `math_cuda` + // only writes outputs at the very end (post-synchronize host copy); + // on any Err the caller's `columns[0..n]` is untouched trace data. + restore_columns_on_err(columns, n); + return false; + } true } @@ -170,12 +352,11 @@ where pub(crate) fn try_extend_two_halves_gpu( h0: &[FieldElement], h1: &[FieldElement], - squared_offset: &FieldElement, domain: &Domain, ) -> Option<(Vec>, Vec>)> where - F: math::field::traits::IsFFTField + IsField, - E: IsField, + F: math::field::traits::IsFFTField + IsField + 'static, + E: IsField + 'static, F: IsSubFieldOf, { if h0.len() != h1.len() { @@ -187,16 +368,16 @@ where if lde_size < gpu_lde_threshold() { return None; } - if type_name::() != type_name::() { + if TypeId::of::() != TypeId::of::() { return None; } - if type_name::() != type_name::() { + if TypeId::of::() != TypeId::of::() { return None; } GPU_EXTEND_HALVES_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - // squared_offset should be `g²`. We recover `g` as `domain.coset_offset` - // and use it to build the `g^(-k) / N` weights. - let _ = squared_offset; // unused (we derive weights from domain) + // Weights are built from `g = domain.coset_offset` directly: the + // CPU caller previously passed `g²` redundantly. See the + // `g^(-k) / N` weight loop below. // Flatten ext3 slices to raw 3*n u64 buffers. let to_u64 = |col: &[FieldElement]| -> Vec { @@ -214,7 +395,7 @@ where let mut weights_u64 = Vec::with_capacity(n); let mut w = inv_n.clone(); for _ in 0..n { - // F == GoldilocksField by type_name check above, so value is u64. + // F == GoldilocksField by TypeId check above, so value is u64. let v: u64 = unsafe { *(w.value() as *const _ as *const u64) }; weights_u64.push(v); w = w * &g_inv; @@ -232,11 +413,17 @@ where let out1_ptr = lde_h1.as_mut_ptr() as *mut u64; // SAFETY: ext3 FieldElement is [u64; 3] in memory, and the Vec has len // = lde_size so the backing is 3*lde_size u64s. - let out0_slice = unsafe { core::slice::from_raw_parts_mut(out0_ptr, 3 * lde_size) }; - let out1_slice = unsafe { core::slice::from_raw_parts_mut(out1_ptr, 3 * lde_size) }; + let ext3_len = lde_size + .checked_mul(3) + .expect("ext3 output length overflow"); + let out0_slice = unsafe { core::slice::from_raw_parts_mut(out0_ptr, ext3_len) }; + let out1_slice = unsafe { core::slice::from_raw_parts_mut(out1_ptr, ext3_len) }; let mut outputs: [&mut [u64]; 2] = [out0_slice, out1_slice]; - math_cuda::lde::coset_lde_batch_ext3_into(&inputs, n, blowup, &weights_u64, &mut outputs) - .expect("GPU extend_half_to_lde failed"); + if math_cuda::lde::coset_lde_batch_ext3_into(&inputs, n, blowup, &weights_u64, &mut outputs) + .is_err() + { + return None; + } } Some((lde_h0, lde_h1)) @@ -259,75 +446,47 @@ pub(crate) fn try_expand_and_leaf_hash_batched( weights: &[FieldElement], ) -> Option> where - F: IsField, - E: IsField, + F: IsField + 'static, + E: IsField + 'static, { - if columns.is_empty() { - return Some(Vec::new()); - } - let n = columns[0].len(); - let lde_size = n.saturating_mul(blowup_factor); - if lde_size < gpu_lde_threshold() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if columns.iter().any(|c| c.len() != n) { - return None; - } - - let raw_columns: Vec> = columns - .iter() - .map(|col| { - col.iter() - .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) - .collect() - }) - .collect(); - let weights_u64: Vec = weights - .iter() - .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) - .collect(); - - for col in columns.iter_mut() { - debug_assert!(col.capacity() >= lde_size); - unsafe { col.set_len(lde_size) }; - } - let mut raw_outputs: Vec<&mut [u64]> = columns - .iter_mut() - .map(|col| { - let ptr = col.as_mut_ptr() as *mut u64; - let len = col.len(); - unsafe { core::slice::from_raw_parts_mut(ptr, len) } - }) - .collect(); + let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { + LayoutDispatch::Empty => return Some(Vec::new()), + LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_base::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - // Allocate as Vec<[u8; 32]> directly so we both skip the zero-fill pass - // AND avoid re-chunking afterwards. Fresh pages still fault on first - // write (inside the GPU-side memcpy), but only once each. + // Allocate the leaf-hash buffer directly as `Vec<[u8; 32]>` to skip a + // re-chunk pass; fresh pages fault on first write but only once each. let mut leaves: Vec<[u8; 32]> = Vec::with_capacity(lde_size); - // SAFETY: we fill every byte via memcpy_dtoh below. + // SAFETY: every byte will be overwritten by the GPU D2H below. unsafe { leaves.set_len(lde_size) }; - let hashed_bytes_ptr = leaves.as_mut_ptr() as *mut u8; - let hashed_bytes: &mut [u8] = - unsafe { std::slice::from_raw_parts_mut(hashed_bytes_ptr, lde_size * 32) }; + let leaf_byte_len = lde_size.checked_mul(32).expect("leaf byte length overflow"); - GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LDE_CALLS.fetch_add(num_columns as u64, std::sync::atomic::Ordering::Relaxed); GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - math_cuda::lde::coset_lde_batch_base_into_with_leaf_hash( - &slices, - blowup_factor, - &weights_u64, - &mut raw_outputs, - hashed_bytes, - ) - .expect("GPU LDE+leaf-hash failed"); + let gpu_result = { + let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; + let hashed_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(leaves.as_mut_ptr() as *mut u8, leaf_byte_len) + }; + math_cuda::lde::coset_lde_batch_base_into_with_leaf_hash( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + hashed_bytes, + ) + }; + if gpu_result.is_err() { + restore_columns_on_err(columns, n); + return None; + } Some(leaves) } @@ -349,80 +508,46 @@ pub(crate) fn try_expand_leaf_and_tree_batched( weights: &[FieldElement], ) -> Option> where - F: IsField, - E: IsField, + F: IsField + 'static, + E: IsField + 'static, B: crypto::merkle_tree::traits::IsMerkleTreeBackend, { - if columns.is_empty() { - return None; - } - let n = columns[0].len(); - let lde_size = n.saturating_mul(blowup_factor); - if lde_size < gpu_lde_threshold() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if columns.iter().any(|c| c.len() != n) { - return None; - } - // Tree layout needs `2*lde_size - 1` nodes; must be a power-of-two leaf - // count. LDE size is always pow2 here (checked above). - if lde_size < 2 { - return None; - } - - let raw_columns: Vec> = columns - .iter() - .map(|col| { - col.iter() - .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) - .collect() - }) - .collect(); - let weights_u64: Vec = weights - .iter() - .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) - .collect(); - - for col in columns.iter_mut() { - debug_assert!(col.capacity() >= lde_size); - unsafe { col.set_len(lde_size) }; - } - let mut raw_outputs: Vec<&mut [u64]> = columns - .iter_mut() - .map(|col| { - let ptr = col.as_mut_ptr() as *mut u64; - let len = col.len(); - unsafe { core::slice::from_raw_parts_mut(ptr, len) } - }) - .collect(); - + let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { + LayoutDispatch::Empty | LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; + let node_byte_len = total_nodes + .checked_mul(32) + .expect("node byte length overflow"); + + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_base::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - let total_nodes = 2 * lde_size - 1; - let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); - // SAFETY: every byte is written by the D2H below. - unsafe { nodes.set_len(total_nodes) }; - let nodes_bytes: &mut [u8] = - unsafe { core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) }; - - GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LDE_CALLS.fetch_add(num_columns as u64, std::sync::atomic::Ordering::Relaxed); GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree( - &slices, - blowup_factor, - &weights_u64, - &mut raw_outputs, - nodes_bytes, - ) - .expect("GPU LDE+leaf-hash+tree failed"); + let gpu_result = { + let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; + let nodes_bytes: &mut [u8] = unsafe { + core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) + }; + math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + }; + if gpu_result.is_err() { + restore_columns_on_err(columns, n); + return None; + } crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) } @@ -440,77 +565,49 @@ pub(crate) fn try_expand_leaf_and_tree_batched_keep( math_cuda::lde::GpuLdeBase, )> where - F: IsField, - E: IsField, + F: IsField + 'static, + E: IsField + 'static, B: crypto::merkle_tree::traits::IsMerkleTreeBackend, { - if columns.is_empty() { - return None; - } - let n = columns[0].len(); - let lde_size = n.saturating_mul(blowup_factor); - if lde_size < gpu_lde_threshold() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if columns.iter().any(|c| c.len() != n) { - return None; - } - if lde_size < 2 { - return None; - } - - let raw_columns: Vec> = columns - .iter() - .map(|col| { - col.iter() - .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) - .collect() - }) - .collect(); - let weights_u64: Vec = weights - .iter() - .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) - .collect(); - - for col in columns.iter_mut() { - debug_assert!(col.capacity() >= lde_size); - unsafe { col.set_len(lde_size) }; - } - let mut raw_outputs: Vec<&mut [u64]> = columns - .iter_mut() - .map(|col| { - let ptr = col.as_mut_ptr() as *mut u64; - let len = col.len(); - unsafe { core::slice::from_raw_parts_mut(ptr, len) } - }) - .collect(); - + let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { + LayoutDispatch::Empty | LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; + let node_byte_len = total_nodes + .checked_mul(32) + .expect("node byte length overflow"); + + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_base::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - let total_nodes = 2 * lde_size - 1; - let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); - unsafe { nodes.set_len(total_nodes) }; - let nodes_bytes: &mut [u8] = - unsafe { core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) }; - - GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LDE_CALLS.fetch_add(num_columns as u64, std::sync::atomic::Ordering::Relaxed); GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let handle = math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree_keep( - &slices, - blowup_factor, - &weights_u64, - &mut raw_outputs, - nodes_bytes, - ) - .expect("GPU LDE+leaf-hash+tree+keep failed"); + let handle_result = { + let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; + let nodes_bytes: &mut [u8] = unsafe { + core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) + }; + math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree_keep( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + }; + let handle = match handle_result { + Ok(h) => h, + Err(_) => { + restore_columns_on_err(columns, n); + return None; + } + }; let tree = crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes)?; Some((tree, handle)) @@ -527,80 +624,50 @@ pub(crate) fn try_expand_leaf_and_tree_batched_ext3( weights: &[FieldElement], ) -> Option> where - F: IsField, - E: IsField, + F: IsField + 'static, + E: IsField + 'static, B: crypto::merkle_tree::traits::IsMerkleTreeBackend, { - if columns.is_empty() { - return None; - } - let n = columns[0].len(); - let lde_size = n.saturating_mul(blowup_factor); - if lde_size < gpu_lde_threshold() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if lde_size < 2 { - return None; - } - - // SAFETY: `E == Degree3Goldilocks`; each `FieldElement` is - // memory-equivalent to `[u64; 3]`. Copy out a Vec view per column. - let raw_columns: Vec> = columns - .iter() - .map(|col| { - let len = col.len() * 3; - let ptr = col.as_ptr() as *const u64; - unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() - }) - .collect(); - let weights_u64: Vec = weights - .iter() - .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) - .collect(); - - for col in columns.iter_mut() { - debug_assert!(col.capacity() >= lde_size); - unsafe { col.set_len(lde_size) }; - } - let mut raw_outputs: Vec<&mut [u64]> = columns - .iter_mut() - .map(|col| { - let ptr = col.as_mut_ptr() as *mut u64; - let len = col.len() * 3; - unsafe { core::slice::from_raw_parts_mut(ptr, len) } - }) - .collect(); - + let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { + LayoutDispatch::Empty | LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; + let node_byte_len = total_nodes + .checked_mul(32) + .expect("node byte length overflow"); + + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_ext3::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - let total_nodes = 2 * lde_size - 1; - let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); - unsafe { nodes.set_len(total_nodes) }; - let nodes_bytes: &mut [u8] = - unsafe { core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) }; - GPU_LDE_CALLS.fetch_add( - (columns.len() * 3) as u64, + (num_columns * 3) as u64, std::sync::atomic::Ordering::Relaxed, ); GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree( - &slices, - n, - blowup_factor, - &weights_u64, - &mut raw_outputs, - nodes_bytes, - ) - .expect("GPU ext3 LDE+leaf-hash+tree failed"); + let gpu_result = { + let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; + let nodes_bytes: &mut [u8] = unsafe { + core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) + }; + math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + }; + if gpu_result.is_err() { + restore_columns_on_err(columns, n); + return None; + } crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) } @@ -617,78 +684,53 @@ pub(crate) fn try_expand_leaf_and_tree_batched_ext3_keep( math_cuda::lde::GpuLdeExt3, )> where - F: IsField, - E: IsField, + F: IsField + 'static, + E: IsField + 'static, B: crypto::merkle_tree::traits::IsMerkleTreeBackend, { - if columns.is_empty() { - return None; - } - let n = columns[0].len(); - let lde_size = n.saturating_mul(blowup_factor); - if lde_size < gpu_lde_threshold() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if lde_size < 2 { - return None; - } - - let raw_columns: Vec> = columns - .iter() - .map(|col| { - let len = col.len() * 3; - let ptr = col.as_ptr() as *const u64; - unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() - }) - .collect(); - let weights_u64: Vec = weights - .iter() - .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) - .collect(); - - for col in columns.iter_mut() { - debug_assert!(col.capacity() >= lde_size); - unsafe { col.set_len(lde_size) }; - } - let mut raw_outputs: Vec<&mut [u64]> = columns - .iter_mut() - .map(|col| { - let ptr = col.as_mut_ptr() as *mut u64; - let len = col.len() * 3; - unsafe { core::slice::from_raw_parts_mut(ptr, len) } - }) - .collect(); - + let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { + LayoutDispatch::Empty | LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; + let node_byte_len = total_nodes + .checked_mul(32) + .expect("node byte length overflow"); + + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_ext3::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - let total_nodes = 2 * lde_size - 1; - let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); - unsafe { nodes.set_len(total_nodes) }; - let nodes_bytes: &mut [u8] = - unsafe { core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) }; - GPU_LDE_CALLS.fetch_add( - (columns.len() * 3) as u64, + (num_columns * 3) as u64, std::sync::atomic::Ordering::Relaxed, ); GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let handle = math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree_keep( - &slices, - n, - blowup_factor, - &weights_u64, - &mut raw_outputs, - nodes_bytes, - ) - .expect("GPU ext3 LDE+leaf-hash+tree+keep failed"); + let handle_result = { + let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; + let nodes_bytes: &mut [u8] = unsafe { + core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) + }; + math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree_keep( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + }; + let handle = match handle_result { + Ok(h) => h, + Err(_) => { + restore_columns_on_err(columns, n); + return None; + } + }; let tree = crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes)?; Some((tree, handle)) @@ -705,176 +747,99 @@ pub(crate) fn try_expand_and_leaf_hash_batched_ext3( weights: &[FieldElement], ) -> Option> where - F: IsField, - E: IsField, + F: IsField + 'static, + E: IsField + 'static, { - if columns.is_empty() { - return Some(Vec::new()); - } - let n = columns[0].len(); - let lde_size = n.saturating_mul(blowup_factor); - if lde_size < gpu_lde_threshold() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if type_name::() != type_name::() { - return None; - } - if columns.iter().any(|c| c.len() != n) { - return None; - } - - let raw_columns: Vec> = columns - .iter() - .map(|col| { - let len = col.len() * 3; - let ptr = col.as_ptr() as *const u64; - unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() - }) - .collect(); - let weights_u64: Vec = weights - .iter() - .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) - .collect(); - - for col in columns.iter_mut() { - debug_assert!(col.capacity() >= lde_size); - unsafe { col.set_len(lde_size) }; - } - let mut raw_outputs: Vec<&mut [u64]> = columns - .iter_mut() - .map(|col| { - let ptr = col.as_mut_ptr() as *mut u64; - let len = col.len() * 3; - unsafe { core::slice::from_raw_parts_mut(ptr, len) } - }) - .collect(); + let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { + LayoutDispatch::Empty => return Some(Vec::new()), + LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_ext3::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); let mut leaves: Vec<[u8; 32]> = Vec::with_capacity(lde_size); + // SAFETY: every byte will be overwritten by the GPU D2H below. unsafe { leaves.set_len(lde_size) }; - let hashed_bytes: &mut [u8] = - unsafe { std::slice::from_raw_parts_mut(leaves.as_mut_ptr() as *mut u8, lde_size * 32) }; + let leaf_byte_len = lde_size.checked_mul(32).expect("leaf byte length overflow"); GPU_LDE_CALLS.fetch_add( - (columns.len() * 3) as u64, + (num_columns * 3) as u64, std::sync::atomic::Ordering::Relaxed, ); GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - math_cuda::lde::coset_lde_batch_ext3_into_with_leaf_hash( - &slices, - n, - blowup_factor, - &weights_u64, - &mut raw_outputs, - hashed_bytes, - ) - .expect("GPU ext3 LDE+leaf-hash failed"); + let gpu_result = { + let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; + let hashed_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(leaves.as_mut_ptr() as *mut u8, leaf_byte_len) + }; + math_cuda::lde::coset_lde_batch_ext3_into_with_leaf_hash( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + hashed_bytes, + ) + }; + if gpu_result.is_err() { + restore_columns_on_err(columns, n); + return None; + } Some(leaves) } /// Ext3 specialisation of [`try_expand_columns_batched`]. `E` is known to be -/// `Degree3GoldilocksExtensionField` by type_name match at the caller. +/// `Degree3GoldilocksExtensionField` by TypeId match at the caller. fn try_expand_columns_batched_ext3( columns: &mut [Vec>], blowup_factor: usize, weights: &[FieldElement], ) -> bool where - F: IsField, - E: IsField, + F: IsField + 'static, + E: IsField + 'static, { - if columns.is_empty() { - return true; - } - let n = columns[0].len(); - let lde_size = n.saturating_mul(blowup_factor); - - // SAFETY: caller confirmed `E == Degree3GoldilocksExtensionField` via - // type_name. That means `FieldElement` wraps `[FieldElement; 3]`, - // which is memory-equivalent to `[u64; 3]`. A `&[FieldElement]` of - // length `n` is therefore a contiguous `3 * n * 8` byte buffer. - let raw_columns: Vec> = columns - .iter() - .map(|col| { - let len = col.len() * 3; - let ptr = col.as_ptr() as *const u64; - // Copy rather than borrow: the caller still owns `col` and will - // reuse its backing storage after we resize + rewrite below. - unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() - }) - .collect(); - // F is `type_name::() == GoldilocksField` by caller precondition; - // `F::BaseType == u64`, so we can read each `w.value()` as a `*const u64`. - let weights_u64: Vec = weights - .iter() - .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) - .collect(); - - // Pre-size each ext3 column to lde_size so its backing Vec has the right - // length for the output re-interleave. Capacity must already be >= - // lde_size (caller's `extract_columns_main(lde_size)` ensures this). - for col in columns.iter_mut() { - debug_assert!(col.capacity() >= lde_size); - // SAFETY: overwritten fully by the GPU path below. - unsafe { col.set_len(lde_size) }; - } - - // View each column's backing memory as a `&mut [u64]` of length - // `3*lde_size`. Safe because ext3 elements are `[u64; 3]` layouts. - let mut raw_outputs: Vec<&mut [u64]> = columns - .iter_mut() - .map(|col| { - let ptr = col.as_mut_ptr() as *mut u64; - let len = col.len() * 3; - unsafe { core::slice::from_raw_parts_mut(ptr, len) } - }) - .collect(); + let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { + LayoutDispatch::Empty => return true, + LayoutDispatch::Skip => return false, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_ext3::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + // Account each ext3 column as 3 logical GPU LDE "calls" (base-field // components) so the counter matches the base-field batched path. GPU_LDE_CALLS.fetch_add( - (columns.len() * 3) as u64, + (num_columns * 3) as u64, std::sync::atomic::Ordering::Relaxed, ); - math_cuda::lde::coset_lde_batch_ext3_into( - &slices, - n, - blowup_factor, - &weights_u64, - &mut raw_outputs, - ) - .expect("GPU batched ext3 coset LDE failed"); + let gpu_result = { + let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; + math_cuda::lde::coset_lde_batch_ext3_into( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + ) + }; + if gpu_result.is_err() { + restore_columns_on_err(columns, n); + return false; + } true } -// ============================================================================ -// GPU barycentric OOD evaluation -// ============================================================================ -// -// Infrastructure for future use: these wrappers drive -// `math_cuda::barycentric::barycentric_{base,ext3}` and apply the trailing ext3 -// scalar on host. See the CPU reference in -// `crypto/math/src/polynomial/mod.rs::interpolate_coset_eval_*_with_g_n_inv`. -// -// NOT currently wired into the prover — a benchmark on fib_iterative_{1M, 4M} -// showed the CPU path (rayon over ~50 columns) already finishes in <1 ms wall -// because the GPU is busy with LDE and Merkle on parallel streams, so moving -// R3 OOD to the GPU just serialises work without freeing CPU wall time. -// Kept here and covered by parity tests in `crypto/math-cuda/tests/barycentric.rs` -// because it remains a net win for single-table or very-large-trace workloads. -// -// The GPU kernel returns the unscaled sum -// S = Σ_i point_i · eval_i · inv_denom_i -// per column; the final barycentric value is -// f(z) = scalar · (z^N − g^N) · S -// with `scalar = n_inv · g_n_inv` kept in the base field. - // ============================================================================ // GPU Merkle inner-tree construction // ============================================================================ @@ -917,22 +882,35 @@ where if leaves_len < gpu_merkle_tree_threshold() || !leaves_len.is_power_of_two() || leaves_len < 2 { return None; } + let leaves_byte_len = leaves_len + .checked_mul(32) + .expect("leaf byte length overflow"); + let total_nodes = 2usize + .checked_mul(leaves_len) + .and_then(|v| v.checked_sub(1)) + .expect("merkle node count overflow"); + let node_byte_len = total_nodes + .checked_mul(32) + .expect("node byte length overflow"); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); // Flatten host-side leaves into a contiguous byte buffer for the GPU // kernel. SAFETY: `[u8; 32]` is POD and the slice is contiguous. let leaves_bytes: &[u8] = unsafe { - core::slice::from_raw_parts(hashed_leaves.as_ptr() as *const u8, leaves_len * 32) + core::slice::from_raw_parts(hashed_leaves.as_ptr() as *const u8, leaves_byte_len) + }; + let nodes_bytes = match math_cuda::merkle::build_merkle_tree_on_device(leaves_bytes) { + Ok(b) => b, + Err(_) => return None, }; - let nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(leaves_bytes) - .expect("GPU merkle tree build failed"); - let total_nodes = 2 * leaves_len - 1; - debug_assert_eq!(nodes_bytes.len(), total_nodes * 32); + debug_assert_eq!(nodes_bytes.len(), node_byte_len); - // Re-chunk into `Vec<[u8; 32]>` without re-allocating. We'd need an - // explicit copy because Vec and Vec<[u8; 32]> have different - // layouts in the allocator metadata (align differs on some platforms). + // Re-chunk the flat byte buffer into `Vec<[u8; 32]>`. Alignment is + // identical (`[u8; 32]` has align 1), but `Vec` and `Vec<[u8; 32]>` + // track different element counts, so a fresh allocation + per-row copy + // is the simplest correct conversion. let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); for i in 0..total_nodes { let mut n = [0u8; 32]; diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 911bc3163..64cf2ea36 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -89,8 +89,8 @@ pub struct Prover< } impl< - Field: IsSubFieldOf + IsFFTField + Send + Sync, - FieldExtension: Send + Sync + IsField, + Field: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + FieldExtension: Send + Sync + IsField + 'static, PI, > IsStarkProver for Prover where @@ -212,18 +212,22 @@ struct Lde { /// Result of `commit_main_trace` / `commit_preprocessed_trace`. Wraps the /// commitment Merkle data plus the owned LDE columns, and — when the R1 /// fused GPU pipeline ran — the retained device LDE handle. +/// +/// Fields are `pub(crate)` because constructing one is the prover module's +/// job; external implementers of `IsStarkProver` consume it via the trait's +/// default impls and don't need to construct it themselves. pub struct MainTraceCommitResult where FieldElement: AsBytes, { - tree: BatchedMerkleTree, - root: Commitment, - precomputed_tree: Option>, - precomputed_root: Option, - num_precomputed_cols: usize, - columns: Vec>>, + pub(crate) tree: BatchedMerkleTree, + pub(crate) root: Commitment, + pub(crate) precomputed_tree: Option>, + pub(crate) precomputed_root: Option, + pub(crate) num_precomputed_cols: usize, + pub(crate) columns: Vec>>, #[cfg(feature = "cuda")] - gpu_main: Option, + pub(crate) gpu_main: Option, } impl Round1Commitments @@ -517,8 +521,8 @@ where /// The default implementation is complete and is compatible with Stone prover /// https://github.com/starkware-libs/stone-prover pub trait IsStarkProver< - Field: IsSubFieldOf + IsFFTField + Send + Sync, - FieldExtension: Send + Sync + IsField, + Field: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + FieldExtension: Send + Sync + IsField + 'static, PI, > where FieldElement: math::traits::ByteConversion, @@ -617,7 +621,7 @@ pub trait IsStarkProver< twiddles: &LdeTwiddles, ) where Field: IsSubFieldOf, - E: IsSubFieldOf + IsField + Send + Sync, + E: IsSubFieldOf + IsField + Send + Sync + 'static, FieldElement: Send + Sync, { if columns.is_empty() { @@ -833,7 +837,14 @@ pub trait IsStarkProver< }; Ok(commitment.build_round1( - Lde { main, aux }, + Lde { + main, + aux, + #[cfg(feature = "cuda")] + gpu_main: None, + #[cfg(feature = "cuda")] + gpu_aux: None, + }, air.step_size(), domain.blowup_factor, air.has_aux_trace(), @@ -982,12 +993,9 @@ pub trait IsStarkProver< // GPU fast path: batch both halves into one ext3 LDE call. Requires // `cuda` feature and a qualifying size; falls through to CPU when not. #[cfg(feature = "cuda")] - if let Some((lde_h0, lde_h1)) = crate::gpu_lde::try_extend_two_halves_gpu( - &h0_evals, - &h1_evals, - &coset_offset_squared, - domain, - ) { + if let Some((lde_h0, lde_h1)) = + crate::gpu_lde::try_extend_two_halves_gpu(&h0_evals, &h1_evals, domain) + { return vec![lde_h0, lde_h1]; } @@ -1947,24 +1955,20 @@ pub trait IsStarkProver< }) .collect(); - // Parallel aux commit in chunks of K. Fourth field is an optional - // GPU ext3 LDE handle retained when the R1 fused pipeline fires. - #[cfg(feature = "cuda")] + // Parallel aux commit in chunks of K. The optional ext3 GPU LDE handle + // (retained when the R1 fused pipeline fires) is carried in a side + // vector under `cfg(cuda)` so AuxResult stays a clean 3-tuple in both + // cfg variants. type AuxResult = ( Option>>, Option, Vec>>, - Option, - ); - #[cfg(not(feature = "cuda"))] - type AuxResult = ( - Option>>, - Option, - Vec>>, - (), ); #[allow(clippy::type_complexity)] let mut aux_results: Vec> = Vec::with_capacity(num_airs); + #[cfg(feature = "cuda")] + let mut aux_gpu_handles: Vec> = + Vec::with_capacity(num_airs); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1975,6 +1979,10 @@ pub trait IsStarkProver< #[cfg(not(feature = "parallel"))] let iter = chunk_range; + // Per-iter the closure produces `(AuxResult, Option)` + // under cuda, or `AuxResult` alone under non-cuda. Splitting them + // at the sequential collection step keeps the two-vec layout. + #[allow(clippy::type_complexity)] let chunk_aux: Vec> = iter .map(|idx| { let (air, trace, _) = &air_trace_pairs[idx]; @@ -2010,9 +2018,7 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] crate::instruments::accum_r1_aux(aux_lde_dur, zero); return Ok(( - Some(Arc::new(tree)), - Some(root), - columns, + (Some(Arc::new(tree)), Some(root), columns), Some(handle), )); } @@ -2047,27 +2053,35 @@ pub trait IsStarkProver< } #[cfg(feature = "cuda")] - let aux_gpu: Option = None; + return Ok(( + (Some(Arc::new(tree)), Some(root), columns), + None::, + )); #[cfg(not(feature = "cuda"))] - let aux_gpu: () = (); - Ok((Some(Arc::new(tree)), Some(root), columns, aux_gpu)) + Ok((Some(Arc::new(tree)), Some(root), columns)) } else { #[cfg(feature = "cuda")] - let aux_gpu: Option = None; + return Ok(((None, None, Vec::new()), None::)); #[cfg(not(feature = "cuda"))] - let aux_gpu: () = (); - Ok((None, None, Vec::new(), aux_gpu)) + Ok((None, None, Vec::new())) } }) .collect(); - // Sequential: append aux roots to forked transcripts + // Sequential: append aux roots to forked transcripts and split + // the optional GPU handle into its own side vector under cuda. for (j, result) in chunk_aux.into_iter().enumerate() { - let (aux_tree, aux_root, cached_aux, aux_gpu) = result?; + #[cfg(feature = "cuda")] + let (aux_triple, aux_gpu_h) = result?; + #[cfg(not(feature = "cuda"))] + let aux_triple = result?; + let (aux_tree, aux_root, cached_aux) = aux_triple; if let Some(ref root) = aux_root { table_transcripts[chunk_start + j].append_bytes(root); } - aux_results.push((aux_tree, aux_root, cached_aux, aux_gpu)); + aux_results.push((aux_tree, aux_root, cached_aux)); + #[cfg(feature = "cuda")] + aux_gpu_handles.push(aux_gpu_h); } } @@ -2076,25 +2090,18 @@ pub trait IsStarkProver< let mut commitments: Vec> = Vec::with_capacity(num_airs); let mut cached_ldes: Vec> = Vec::with_capacity(num_airs); - // Zip in the optional GPU handles so the Lde constructor always - // has a value for its gpu_main/gpu_aux. Under `cfg(not(cuda))` the - // handles are `()` (see AuxResult type alias) — we just discard them. + #[cfg(feature = "cuda")] - let main_gpu_iter: Box>> = - Box::new(main_gpu_handles.into_iter()); - #[cfg(not(feature = "cuda"))] - let main_gpu_iter: Box> = - Box::new(std::iter::repeat_with(|| ()).take(num_airs)); - - for ( - (((main_commit, main_lde), main_gpu_h), (aux_tree, aux_root, cached_aux, aux_gpu_h)), - bus_public_inputs, - ) in main_commits - .into_iter() - .zip(main_ldes) - .zip(main_gpu_iter) - .zip(aux_results) - .zip(bus_inputs_vec) + let mut main_gpu_iter = main_gpu_handles.into_iter(); + #[cfg(feature = "cuda")] + let mut aux_gpu_iter = aux_gpu_handles.into_iter(); + + for (((main_commit, main_lde), (aux_tree, aux_root, cached_aux)), bus_public_inputs) in + main_commits + .into_iter() + .zip(main_ldes) + .zip(aux_results) + .zip(bus_inputs_vec) { commitments.push(Round1Commitments { main_merkle_tree: main_commit.main_tree, @@ -2111,18 +2118,18 @@ pub trait IsStarkProver< cached_ldes.push(Lde { main: main_lde, aux: cached_aux, - gpu_main: main_gpu_h, - gpu_aux: aux_gpu_h, + gpu_main: main_gpu_iter + .next() + .expect("main_gpu_handles length mismatch"), + gpu_aux: aux_gpu_iter + .next() + .expect("aux_gpu_handles length mismatch"), }); #[cfg(not(feature = "cuda"))] - { - #[allow(clippy::let_unit_value)] - let _ = (main_gpu_h, aux_gpu_h); - cached_ldes.push(Lde { - main: main_lde, - aux: cached_aux, - }); - } + cached_ldes.push(Lde { + main: main_lde, + aux: cached_aux, + }); } #[cfg(feature = "instruments")] diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 77cb8ad4b..dbe13d20b 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -677,11 +677,6 @@ pub fn prove_with_options_and_inputs( .filter(|c| c.is_private_input) .count(); - debug_assert_eq!( - traces.public_output_bytes, result.return_values.memory_values, - "public output diverged between executor view and trace reconstruction" - ); - Ok(VmProof { proof, runtime_page_ranges, diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs index de0a686bf..65d9cdbec 100644 --- a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -107,7 +107,7 @@ fn prove_vm_minimal(elf_bytes: &[u8], private_inputs: &[u8], max_rows: &MaxRowsC &table_counts, ); let runtime_page_ranges = traces.runtime_page_ranges(); - let proof = Prover::multi_prove( + let proof = multi_prove_ram( airs.air_trace_pairs(&mut traces), &mut DefaultTranscript::::new(&[]), ) From ea5696f6be03f8dd9449646130b9a4767b220926 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> Date: Thu, 21 May 2026 18:32:49 -0300 Subject: [PATCH 08/16] Update crypto/stark/src/gpu_lde.rs Co-authored-by: Gabriel Bosio <38794644+gabrielbosio@users.noreply.github.com> --- crypto/stark/src/gpu_lde.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index 6d9a0de0a..119097153 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -1,5 +1,4 @@ -//! GPU dispatch layer for the per-column coset LDE. Lives in the stark crate -//! (not `math`) to avoid a dependency cycle between `math` and `math-cuda`. +//! GPU dispatch layer for the per-column coset LDE. //! //! Handles only Goldilocks base-field columns above a size threshold; falls //! back to CPU for extension-field columns and small columns where kernel From a8cf26583c4cb462c9f9f9e5ff9656763f1905e5 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> Date: Thu, 21 May 2026 18:33:33 -0300 Subject: [PATCH 09/16] Update crypto/stark/src/gpu_lde.rs Co-authored-by: Gabriel Bosio <38794644+gabrielbosio@users.noreply.github.com> --- crypto/stark/src/gpu_lde.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index 119097153..00a418303 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -1,6 +1,6 @@ //! GPU dispatch layer for the per-column coset LDE. //! -//! Handles only Goldilocks base-field columns above a size threshold; falls +//! Handles only Goldilocks base-field columns above a size threshold. Falls //! back to CPU for extension-field columns and small columns where kernel //! launch overhead dominates. Produces the same natural-order, non-canonical //! LDE evaluations as the CPU path. From fb8d31f94efaa65b72a6c70657c2944ab7f72895 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> Date: Thu, 21 May 2026 18:33:47 -0300 Subject: [PATCH 10/16] Update crypto/stark/src/gpu_lde.rs Co-authored-by: Gabriel Bosio <38794644+gabrielbosio@users.noreply.github.com> --- crypto/stark/src/gpu_lde.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index 00a418303..aac750775 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -21,7 +21,7 @@ use crate::domain::Domain; /// /// 2^19 is a conservative default calibrated against a 46-core machine where /// rayon-parallel CPU LDE is already fast. Override via env var for tuning -/// on smaller machines; see `/workspace/lambda_vm/crypto/math-cuda/tests/bench_quick.rs`. +/// on smaller machines, see `crypto/math-cuda/tests/bench_quick.rs`. const DEFAULT_GPU_LDE_THRESHOLD: usize = 1 << 19; fn gpu_lde_threshold() -> usize { From a79f2b567fcafd1b205b27c9f66c18ffc3dd1106 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> Date: Thu, 21 May 2026 18:34:08 -0300 Subject: [PATCH 11/16] Update crypto/stark/src/gpu_lde.rs Co-authored-by: Gabriel Bosio <38794644+gabrielbosio@users.noreply.github.com> --- crypto/stark/src/gpu_lde.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index aac750775..cc4c8a02d 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -71,17 +71,15 @@ pub fn gpu_extend_halves_calls() -> u64 { // the type-confirmed precondition. Centralising that here keeps each variant // short and means a future change to (say) the threshold logic is one edit. -/// Outcome of validating an incoming `columns` slice against the GPU dispatch -/// preconditions. +/// Outcome of validating an input slice against the GPU dispatch preconditions. enum LayoutDispatch { - /// `columns` is empty — caller returns its own "trivially done" value - /// (`true` for `bool` callers, `Some(Vec::new())` for `Option` callers). + /// Input slice is empty, no work to do. Empty, - /// GPU path doesn't apply (below threshold, wrong types, ragged columns). - /// Caller returns its own "fall through to CPU" value (`false`/`None`). + /// Preconditions not met: below threshold, wrong element types, or + /// columns of unequal length. Skip, - /// GPU path applies. `n` is the per-column input length; `lde_size = n * - /// blowup_factor` (saturating). + /// Preconditions met. `n` is the per-column input length: + /// `lde_size = n * blowup_factor` (saturating). Run { n: usize, lde_size: usize }, } From 761a2c0e7f00be19413060e0f3aa04569773ef03 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> Date: Thu, 21 May 2026 18:34:20 -0300 Subject: [PATCH 12/16] Update crypto/stark/src/gpu_lde.rs Co-authored-by: Gabriel Bosio <38794644+gabrielbosio@users.noreply.github.com> --- crypto/stark/src/gpu_lde.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index cc4c8a02d..d56b3ed4b 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -15,7 +15,7 @@ use math::field::traits::{IsField, IsSubFieldOf}; use crate::domain::Domain; /// Break-even LDE size. Below this, the CPU `coset_lde_full_expand` completes -/// in a few hundred microseconds and the GPU's ~37 kernel launches plus +/// in a few hundred microseconds and the GPU's tens of kernel launches plus /// H2D/D2H round-trip is a net loss. The check is on **lde size**, not trace /// length, because that's what determines the FFT workload. /// From e066e9d384d734edfb9ad6be8402abb2b2b4cfda Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 21 May 2026 18:53:05 -0300 Subject: [PATCH 13/16] address reviews --- crypto/stark/src/gpu_lde.rs | 51 +++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index d56b3ed4b..71a3f1b65 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -6,6 +6,9 @@ //! LDE evaluations as the CPU path. use std::any::TypeId; +use std::slice::{from_raw_parts, from_raw_parts_mut}; +use std::sync::OnceLock; +use std::sync::atomic::{AtomicU64, Ordering}; use math::field::element::FieldElement; use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; @@ -25,7 +28,7 @@ use crate::domain::Domain; const DEFAULT_GPU_LDE_THRESHOLD: usize = 1 << 19; fn gpu_lde_threshold() -> usize { - static CACHED: std::sync::OnceLock = std::sync::OnceLock::new(); + static CACHED: OnceLock = OnceLock::new(); *CACHED.get_or_init(|| { std::env::var("LAMBDA_VM_GPU_LDE_THRESHOLD") .ok() @@ -36,40 +39,36 @@ fn gpu_lde_threshold() -> usize { /// Atomically counted by `try_expand_column` every time it actually routes a /// column to the GPU. Used by benchmarks to confirm the GPU path fired. -static GPU_LDE_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); +static GPU_LDE_CALLS: AtomicU64 = AtomicU64::new(0); pub fn gpu_lde_calls() -> u64 { - GPU_LDE_CALLS.load(std::sync::atomic::Ordering::Relaxed) + GPU_LDE_CALLS.load(Ordering::Relaxed) } pub fn reset_gpu_lde_calls() { - GPU_LDE_CALLS.store(0, std::sync::atomic::Ordering::Relaxed); + GPU_LDE_CALLS.store(0, Ordering::Relaxed); } /// Reset all GPU call counters at once. Useful between bench warm-up and /// profiled passes so the numbers reported aren't doubled by the warm-up. pub fn reset_all_gpu_call_counters() { - use std::sync::atomic::Ordering::Relaxed; - GPU_LDE_CALLS.store(0, Relaxed); - GPU_EXTEND_HALVES_CALLS.store(0, Relaxed); - GPU_LEAF_HASH_CALLS.store(0, Relaxed); - GPU_MERKLE_TREE_CALLS.store(0, Relaxed); + GPU_LDE_CALLS.store(0, Ordering::Relaxed); + GPU_EXTEND_HALVES_CALLS.store(0, Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.store(0, Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.store(0, Ordering::Relaxed); } -pub(crate) static GPU_EXTEND_HALVES_CALLS: std::sync::atomic::AtomicU64 = - std::sync::atomic::AtomicU64::new(0); +pub(crate) static GPU_EXTEND_HALVES_CALLS: AtomicU64 = AtomicU64::new(0); pub fn gpu_extend_halves_calls() -> u64 { - GPU_EXTEND_HALVES_CALLS.load(std::sync::atomic::Ordering::Relaxed) + GPU_EXTEND_HALVES_CALLS.load(Ordering::Relaxed) } // ============================================================================ // Shared dispatch helpers // ============================================================================ // -// Every `try_expand_*` variant runs the same prologue: empty-check, threshold -// check, two TypeId checks, equal-length check, and a column-to-u64 cast under -// the type-confirmed precondition. Centralising that here keeps each variant -// short and means a future change to (say) the threshold logic is one edit. +// Common prologue for the try_expand_* variants: empty-check, threshold, +// TypeId checks, equal-length check, column-to-u64 cast. /// Outcome of validating an input slice against the GPU dispatch preconditions. enum LayoutDispatch { @@ -139,8 +138,7 @@ where LayoutDispatch::Run { n, lde_size } } -/// Materialise base-field columns as owned `Vec>` for the GPU input -/// slice list. +/// Convert base-field columns to `Vec>` for the GPU input slice list. /// /// SAFETY: caller must have established `E == GoldilocksField` (e.g. via /// [`check_base_layout`]). Each `FieldElement` is then a `#[repr(transparent)]` @@ -156,8 +154,8 @@ unsafe fn columns_to_u64_base(columns: &[Vec>]) -> V .collect() } -/// Materialise ext3 columns as owned `Vec>` (de-interleaved into raw -/// `[u64; 3]` lanes per element) for the GPU input slice list. +/// Convert ext3 columns to `Vec>` (de-interleaved into raw `[u64; 3]` +/// lanes per element) for the GPU input slice list. /// /// SAFETY: caller must have established `E == Degree3GoldilocksExtensionField` /// (e.g. via [`check_ext3_layout`]). Each `FieldElement` is then a @@ -168,12 +166,12 @@ unsafe fn columns_to_u64_ext3(columns: &[Vec>]) -> V .map(|col| { let len = col.len() * 3; let ptr = col.as_ptr() as *const u64; - unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + unsafe { from_raw_parts(ptr, len) }.to_vec() }) .collect() } -/// Materialise weights as a raw `Vec`. +/// Convert weights to raw `Vec`. /// /// SAFETY: caller must have established `F == GoldilocksField`. unsafe fn weights_to_u64(weights: &[FieldElement]) -> Vec { @@ -184,8 +182,7 @@ unsafe fn weights_to_u64(weights: &[FieldElement]) -> Vec { } /// Pre-size each column to `lde_size` and view it as a `&mut [u64]` of length -/// `lde_size` (base-field, single-u64 layout). Asserts capacity hard so a -/// caller regression can't quietly UB in release builds. +/// `lde_size` (base-field, single-u64 layout). /// /// SAFETY: caller must have established `E == GoldilocksField`. unsafe fn presize_and_view_base( @@ -199,7 +196,7 @@ unsafe fn presize_and_view_base( col.capacity(), lde_size ); - // SAFETY: assert above guarantees capacity; the GPU path overwrites + // SAFETY: assert above guarantees capacity, the GPU path overwrites // every slot before any reader sees the new length. unsafe { col.set_len(lde_size) }; } @@ -209,7 +206,7 @@ unsafe fn presize_and_view_base( let ptr = col.as_mut_ptr() as *mut u64; let len = col.len(); // SAFETY: single-u64 layout, caller still owns the backing alloc. - unsafe { core::slice::from_raw_parts_mut(ptr, len) } + unsafe { from_raw_parts_mut(ptr, len) } }) .collect() } @@ -238,7 +235,7 @@ unsafe fn presize_and_view_ext3( let ptr = col.as_mut_ptr() as *mut u64; let len = col.len() * 3; // SAFETY: ext3 `[u64; 3]` layout, caller still owns the backing. - unsafe { core::slice::from_raw_parts_mut(ptr, len) } + unsafe { from_raw_parts_mut(ptr, len) } }) .collect() } From 7d3d0f01d5013cb4e616af3b1725b043d888a0d1 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 22 May 2026 12:44:19 -0300 Subject: [PATCH 14/16] fix review comments --- crypto/stark/src/gpu_lde.rs | 423 ++++------------------------------- crypto/stark/src/prover.rs | 137 ++++++------ prover/tests/bench_single.rs | 11 + 3 files changed, 125 insertions(+), 446 deletions(-) diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index 71a3f1b65..7c9983601 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -45,10 +45,6 @@ pub fn gpu_lde_calls() -> u64 { GPU_LDE_CALLS.load(Ordering::Relaxed) } -pub fn reset_gpu_lde_calls() { - GPU_LDE_CALLS.store(0, Ordering::Relaxed); -} - /// Reset all GPU call counters at once. Useful between bench warm-up and /// profiled passes so the numbers reported aren't doubled by the warm-up. pub fn reset_all_gpu_call_counters() { @@ -277,15 +273,15 @@ fn alloc_merkle_nodes(lde_size: usize) -> Option<(Vec<[u8; 32]>, usize)> { /// one table at once; those columns all share twiddles and coset weights so /// they can be processed in a single batched pipeline on one stream. /// -/// Returns `true` if the batch was handled on GPU (and `columns` now contains -/// the LDE evaluations). Returns `false` to let the caller run the per-column -/// CPU fallback. +/// Returns `Some(())` if the batch was handled on GPU (and `columns` now +/// contains the LDE evaluations). Returns `None` to let the caller run the +/// per-column CPU fallback. #[inline] pub(crate) fn try_expand_columns_batched( columns: &mut [Vec>], blowup_factor: usize, weights: &[FieldElement], -) -> bool +) -> Option<()> where F: IsField + 'static, E: IsField + 'static, @@ -299,8 +295,8 @@ where } let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { - LayoutDispatch::Empty => return true, // nothing to do — same as CPU path - LayoutDispatch::Skip => return false, + LayoutDispatch::Empty => return Some(()), // nothing to do — same as CPU path + LayoutDispatch::Skip => return None, LayoutDispatch::Run { n, lde_size } => (n, lde_size), }; let num_columns = columns.len(); @@ -309,7 +305,7 @@ where let raw_columns = unsafe { columns_to_u64_base::(columns) }; let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - GPU_LDE_CALLS.fetch_add(num_columns as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LDE_CALLS.fetch_add(num_columns as u64, Ordering::Relaxed); let gpu_result = { let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; math_cuda::lde::coset_lde_batch_base_into( @@ -324,9 +320,9 @@ where // only writes outputs at the very end (post-synchronize host copy); // on any Err the caller's `columns[0..n]` is untouched trace data. restore_columns_on_err(columns, n); - return false; + return None; } - true + Some(()) } /// GPU path for `Prover::extend_half_to_lde`. @@ -368,7 +364,7 @@ where if TypeId::of::() != TypeId::of::() { return None; } - GPU_EXTEND_HALVES_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_EXTEND_HALVES_CALLS.fetch_add(1, Ordering::Relaxed); // Weights are built from `g = domain.coset_offset` directly: the // CPU caller previously passed `g²` redundantly. See the // `g^(-k) / N` weight loop below. @@ -377,7 +373,7 @@ where let to_u64 = |col: &[FieldElement]| -> Vec { let len = col.len() * 3; let ptr = col.as_ptr() as *const u64; - unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + unsafe { from_raw_parts(ptr, len) }.to_vec() }; let h0_raw = to_u64(h0); let h1_raw = to_u64(h1); @@ -399,7 +395,9 @@ where let mut lde_h0 = vec![FieldElement::::zero(); lde_size]; let mut lde_h1 = vec![FieldElement::::zero(); lde_size]; - GPU_LDE_CALLS.fetch_add(6, std::sync::atomic::Ordering::Relaxed); // 2 ext3 cols × 3 components + // Two ext3 columns (h0 + h1), each composed of 3 base-field components. + const NUM_COLS: usize = 2; + GPU_LDE_CALLS.fetch_add((NUM_COLS * 3) as u64, Ordering::Relaxed); { let inputs: [&[u64]; 2] = [&h0_raw, &h1_raw]; // View each output Vec> as &mut [u64] of length 3*lde_size. @@ -410,8 +408,8 @@ where let ext3_len = lde_size .checked_mul(3) .expect("ext3 output length overflow"); - let out0_slice = unsafe { core::slice::from_raw_parts_mut(out0_ptr, ext3_len) }; - let out1_slice = unsafe { core::slice::from_raw_parts_mut(out1_ptr, ext3_len) }; + let out0_slice = unsafe { from_raw_parts_mut(out0_ptr, ext3_len) }; + let out1_slice = unsafe { from_raw_parts_mut(out1_ptr, ext3_len) }; let mut outputs: [&mut [u64]; 2] = [out0_slice, out1_slice]; if math_cuda::lde::coset_lde_batch_ext3_into(&inputs, n, blowup, &weights_u64, &mut outputs) .is_err() @@ -423,133 +421,16 @@ where Some((lde_h0, lde_h1)) } -/// Combined GPU LDE + Merkle leaf hash for the base-field main trace. -/// -/// Keeps LDE output on device, runs Keccak-256 on the device buffer directly, -/// D2Hs both LDE columns (for Round 2-4 reuse) and hashed leaves (for tree -/// construction). Avoids the second H2D that a separate GPU Merkle commit -/// path would require. -/// -/// On success: resizes each `columns[c]` to `lde_size` with the LDE output, -/// and returns `Vec` — the Keccak-256 hashed leaves in natural -/// row order, ready to pass to `BatchedMerkleTree::build_from_hashed_leaves`. -#[allow(dead_code)] -pub(crate) fn try_expand_and_leaf_hash_batched( - columns: &mut [Vec>], - blowup_factor: usize, - weights: &[FieldElement], -) -> Option> -where - F: IsField + 'static, - E: IsField + 'static, -{ - let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { - LayoutDispatch::Empty => return Some(Vec::new()), - LayoutDispatch::Skip => return None, - LayoutDispatch::Run { n, lde_size } => (n, lde_size), - }; - let num_columns = columns.len(); - - // SAFETY: layout-checked above. - let raw_columns = unsafe { columns_to_u64_base::(columns) }; - let weights_u64 = unsafe { weights_to_u64::(weights) }; - let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - - // Allocate the leaf-hash buffer directly as `Vec<[u8; 32]>` to skip a - // re-chunk pass; fresh pages fault on first write but only once each. - let mut leaves: Vec<[u8; 32]> = Vec::with_capacity(lde_size); - // SAFETY: every byte will be overwritten by the GPU D2H below. - unsafe { leaves.set_len(lde_size) }; - let leaf_byte_len = lde_size.checked_mul(32).expect("leaf byte length overflow"); - - GPU_LDE_CALLS.fetch_add(num_columns as u64, std::sync::atomic::Ordering::Relaxed); - GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let gpu_result = { - let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; - let hashed_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(leaves.as_mut_ptr() as *mut u8, leaf_byte_len) - }; - math_cuda::lde::coset_lde_batch_base_into_with_leaf_hash( - &slices, - blowup_factor, - &weights_u64, - &mut raw_outputs, - hashed_bytes, - ) - }; - if gpu_result.is_err() { - restore_columns_on_err(columns, n); - return None; - } - - Some(leaves) -} - -pub(crate) static GPU_LEAF_HASH_CALLS: std::sync::atomic::AtomicU64 = - std::sync::atomic::AtomicU64::new(0); +pub(crate) static GPU_LEAF_HASH_CALLS: AtomicU64 = AtomicU64::new(0); pub fn gpu_leaf_hash_calls() -> u64 { - GPU_LEAF_HASH_CALLS.load(std::sync::atomic::Ordering::Relaxed) -} - -/// Fused variant: LDE + leaf-hash + Merkle tree build, all on device. Skips -/// the pinned→pageable→pinned leaf dance of the separate-step pipeline. -/// Returns the filled `MerkleTree` alongside populating `columns` with -/// the LDE-expanded evaluations. -#[allow(dead_code)] -pub(crate) fn try_expand_leaf_and_tree_batched( - columns: &mut [Vec>], - blowup_factor: usize, - weights: &[FieldElement], -) -> Option> -where - F: IsField + 'static, - E: IsField + 'static, - B: crypto::merkle_tree::traits::IsMerkleTreeBackend, -{ - let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { - LayoutDispatch::Empty | LayoutDispatch::Skip => return None, - LayoutDispatch::Run { n, lde_size } => (n, lde_size), - }; - let num_columns = columns.len(); - let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; - let node_byte_len = total_nodes - .checked_mul(32) - .expect("node byte length overflow"); - - // SAFETY: layout-checked above. - let raw_columns = unsafe { columns_to_u64_base::(columns) }; - let weights_u64 = unsafe { weights_to_u64::(weights) }; - let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - - GPU_LDE_CALLS.fetch_add(num_columns as u64, std::sync::atomic::Ordering::Relaxed); - GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - let gpu_result = { - let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; - let nodes_bytes: &mut [u8] = unsafe { - core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) - }; - math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree( - &slices, - blowup_factor, - &weights_u64, - &mut raw_outputs, - nodes_bytes, - ) - }; - if gpu_result.is_err() { - restore_columns_on_err(columns, n); - return None; - } - - crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) + GPU_LEAF_HASH_CALLS.load(Ordering::Relaxed) } -/// Same as [`try_expand_leaf_and_tree_batched`] but ALSO retains the LDE -/// device buffer so R2–R4 GPU paths can reuse the LDE without a re-H2D. -/// Returns `(tree, gpu_handle)` on success, `None` if the GPU path doesn't -/// apply (same gates as the non-`_keep` variant). +/// Fused base-field path: LDE + Keccak-256 leaf hash + Merkle tree build, +/// all on device, with the LDE buffer retained for R2–R4 GPU reuse. On +/// success: `columns[c]` is resized to `lde_size` with the LDE output, and +/// the returned `(tree, GpuLdeBase)` pair is the host-side tree plus a +/// device-resident handle to the LDE buffer. pub(crate) fn try_expand_leaf_and_tree_batched_keep( columns: &mut [Vec>], blowup_factor: usize, @@ -578,15 +459,14 @@ where let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - GPU_LDE_CALLS.fetch_add(num_columns as u64, std::sync::atomic::Ordering::Relaxed); - GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_LDE_CALLS.fetch_add(num_columns as u64, Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, Ordering::Relaxed); let handle_result = { let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; - let nodes_bytes: &mut [u8] = unsafe { - core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) - }; + let nodes_bytes: &mut [u8] = + unsafe { from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) }; math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree_keep( &slices, blowup_factor, @@ -607,68 +487,10 @@ where Some((tree, handle)) } -/// Ext3 variant of [`try_expand_leaf_and_tree_batched`]. Same fused flow -/// (LDE → leaf-hash → tree build) but over ext3 columns via the three-slab -/// decomposition; `B::Node = [u8; 32]` by construction for -/// `BatchKeccak256Backend`. -#[allow(dead_code)] -pub(crate) fn try_expand_leaf_and_tree_batched_ext3( - columns: &mut [Vec>], - blowup_factor: usize, - weights: &[FieldElement], -) -> Option> -where - F: IsField + 'static, - E: IsField + 'static, - B: crypto::merkle_tree::traits::IsMerkleTreeBackend, -{ - let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { - LayoutDispatch::Empty | LayoutDispatch::Skip => return None, - LayoutDispatch::Run { n, lde_size } => (n, lde_size), - }; - let num_columns = columns.len(); - let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; - let node_byte_len = total_nodes - .checked_mul(32) - .expect("node byte length overflow"); - - // SAFETY: layout-checked above. - let raw_columns = unsafe { columns_to_u64_ext3::(columns) }; - let weights_u64 = unsafe { weights_to_u64::(weights) }; - let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - - GPU_LDE_CALLS.fetch_add( - (num_columns * 3) as u64, - std::sync::atomic::Ordering::Relaxed, - ); - GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - let gpu_result = { - let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; - let nodes_bytes: &mut [u8] = unsafe { - core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) - }; - math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree( - &slices, - n, - blowup_factor, - &weights_u64, - &mut raw_outputs, - nodes_bytes, - ) - }; - if gpu_result.is_err() { - restore_columns_on_err(columns, n); - return None; - } - - crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) -} - -/// Same as [`try_expand_leaf_and_tree_batched_ext3`] but also returns the -/// ext3 LDE device buffer (de-interleaved 3-slab layout) so downstream GPU -/// rounds can reuse it. +/// Fused ext3 path: LDE + Keccak-256 leaf hash + Merkle tree build over +/// ext3 columns via the three-slab decomposition, with the ext3 LDE device +/// buffer (de-interleaved 3-slab layout) retained for downstream GPU rounds. +/// `B::Node = [u8; 32]` by construction for `BatchKeccak256Backend`. pub(crate) fn try_expand_leaf_and_tree_batched_ext3_keep( columns: &mut [Vec>], blowup_factor: usize, @@ -697,18 +519,14 @@ where let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - GPU_LDE_CALLS.fetch_add( - (num_columns * 3) as u64, - std::sync::atomic::Ordering::Relaxed, - ); - GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_LDE_CALLS.fetch_add((num_columns * 3) as u64, Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, Ordering::Relaxed); let handle_result = { let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; - let nodes_bytes: &mut [u8] = unsafe { - core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) - }; + let nodes_bytes: &mut [u8] = + unsafe { from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) }; math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree_keep( &slices, n, @@ -730,78 +548,20 @@ where Some((tree, handle)) } -/// Ext3 variant of [`try_expand_and_leaf_hash_batched`] for the aux trace. -/// Decomposes each ext3 column into three base slabs, runs the LDE + Keccak -/// ext3 kernel in one on-device pipeline, re-interleaves LDE output back to -/// ext3 layout, and returns hashed leaves. -#[allow(dead_code)] -pub(crate) fn try_expand_and_leaf_hash_batched_ext3( - columns: &mut [Vec>], - blowup_factor: usize, - weights: &[FieldElement], -) -> Option> -where - F: IsField + 'static, - E: IsField + 'static, -{ - let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { - LayoutDispatch::Empty => return Some(Vec::new()), - LayoutDispatch::Skip => return None, - LayoutDispatch::Run { n, lde_size } => (n, lde_size), - }; - let num_columns = columns.len(); - - // SAFETY: layout-checked above. - let raw_columns = unsafe { columns_to_u64_ext3::(columns) }; - let weights_u64 = unsafe { weights_to_u64::(weights) }; - let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - - let mut leaves: Vec<[u8; 32]> = Vec::with_capacity(lde_size); - // SAFETY: every byte will be overwritten by the GPU D2H below. - unsafe { leaves.set_len(lde_size) }; - let leaf_byte_len = lde_size.checked_mul(32).expect("leaf byte length overflow"); - - GPU_LDE_CALLS.fetch_add( - (num_columns * 3) as u64, - std::sync::atomic::Ordering::Relaxed, - ); - GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let gpu_result = { - let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; - let hashed_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(leaves.as_mut_ptr() as *mut u8, leaf_byte_len) - }; - math_cuda::lde::coset_lde_batch_ext3_into_with_leaf_hash( - &slices, - n, - blowup_factor, - &weights_u64, - &mut raw_outputs, - hashed_bytes, - ) - }; - if gpu_result.is_err() { - restore_columns_on_err(columns, n); - return None; - } - - Some(leaves) -} - /// Ext3 specialisation of [`try_expand_columns_batched`]. `E` is known to be /// `Degree3GoldilocksExtensionField` by TypeId match at the caller. fn try_expand_columns_batched_ext3( columns: &mut [Vec>], blowup_factor: usize, weights: &[FieldElement], -) -> bool +) -> Option<()> where F: IsField + 'static, E: IsField + 'static, { let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { - LayoutDispatch::Empty => return true, - LayoutDispatch::Skip => return false, + LayoutDispatch::Empty => return Some(()), + LayoutDispatch::Skip => return None, LayoutDispatch::Run { n, lde_size } => (n, lde_size), }; let num_columns = columns.len(); @@ -813,10 +573,7 @@ where // Account each ext3 column as 3 logical GPU LDE "calls" (base-field // components) so the counter matches the base-field batched path. - GPU_LDE_CALLS.fetch_add( - (num_columns * 3) as u64, - std::sync::atomic::Ordering::Relaxed, - ); + GPU_LDE_CALLS.fetch_add((num_columns * 3) as u64, Ordering::Relaxed); let gpu_result = { let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; math_cuda::lde::coset_lde_batch_ext3_into( @@ -829,102 +586,12 @@ where }; if gpu_result.is_err() { restore_columns_on_err(columns, n); - return false; - } - true -} - -// ============================================================================ -// GPU Merkle inner-tree construction -// ============================================================================ -// -// After the GPU keccak leaf-hash kernels produce a flat `[u8; 32]` leaf vec, -// the inner tree construction on CPU via `build_from_hashed_leaves` is a -// rayon-parallel pair-hash scan that still takes ~50-100 ms per table on a -// 46-core host. Delegating it to `math_cuda::merkle::build_merkle_tree_on_device` -// pushes it below 10 ms — the leaf buffer is already on host (it came out of -// `try_expand_and_leaf_hash_batched`), we H2D it once, the GPU does ~log₂(N) -// small kernel launches, and we D2H the full `2*leaves_len - 1` node array. - -static GPU_MERKLE_TREE_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); -pub fn gpu_merkle_tree_calls() -> u64 { - GPU_MERKLE_TREE_CALLS.load(std::sync::atomic::Ordering::Relaxed) -} - -/// Build a Merkle tree from already-hashed leaves using the GPU pair-hash -/// kernel. Returns the filled `MerkleTree` in the same layout as the CPU -/// `build_from_hashed_leaves` would produce — plug straight in anywhere the -/// prover expected that. -/// -/// Returns `None` if the GPU path is disabled by threshold (`leaves_len < -/// GPU_MERKLE_TREE_THRESHOLD`), falling back to the caller's CPU path. -/// -/// Currently unwired in the prover: benchmarking showed the savings from -/// the GPU pair-hash are eaten by the H2D of leaves + D2H of the tree -/// because the leaves are in pageable memory (they're the caller's Vec from -/// `try_expand_and_leaf_hash_batched`). A proper fusion would keep the -/// leaf buffer on device and run the tree kernel immediately on the GPU -/// copy — left as future work. -#[allow(dead_code)] -pub(crate) fn try_build_merkle_tree_gpu( - hashed_leaves: &[B::Node], -) -> Option> -where - B: crypto::merkle_tree::traits::IsMerkleTreeBackend, -{ - let leaves_len = hashed_leaves.len(); - if leaves_len < gpu_merkle_tree_threshold() || !leaves_len.is_power_of_two() || leaves_len < 2 { return None; } - let leaves_byte_len = leaves_len - .checked_mul(32) - .expect("leaf byte length overflow"); - let total_nodes = 2usize - .checked_mul(leaves_len) - .and_then(|v| v.checked_sub(1)) - .expect("merkle node count overflow"); - let node_byte_len = total_nodes - .checked_mul(32) - .expect("node byte length overflow"); - - GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - // Flatten host-side leaves into a contiguous byte buffer for the GPU - // kernel. SAFETY: `[u8; 32]` is POD and the slice is contiguous. - let leaves_bytes: &[u8] = unsafe { - core::slice::from_raw_parts(hashed_leaves.as_ptr() as *const u8, leaves_byte_len) - }; - let nodes_bytes = match math_cuda::merkle::build_merkle_tree_on_device(leaves_bytes) { - Ok(b) => b, - Err(_) => return None, - }; - - debug_assert_eq!(nodes_bytes.len(), node_byte_len); - - // Re-chunk the flat byte buffer into `Vec<[u8; 32]>`. Alignment is - // identical (`[u8; 32]` has align 1), but `Vec` and `Vec<[u8; 32]>` - // track different element counts, so a fresh allocation + per-row copy - // is the simplest correct conversion. - let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); - for i in 0..total_nodes { - let mut n = [0u8; 32]; - n.copy_from_slice(&nodes_bytes[i * 32..(i + 1) * 32]); - nodes.push(n); - } - - crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) + Some(()) } -/// Below this (tree size), stay on CPU — rayon pair-hash is already well -/// under a millisecond for small N and would lose to any PCIe round-trip. -const DEFAULT_GPU_MERKLE_TREE_THRESHOLD: usize = 1 << 15; - -fn gpu_merkle_tree_threshold() -> usize { - static CACHED: std::sync::OnceLock = std::sync::OnceLock::new(); - *CACHED.get_or_init(|| { - std::env::var("LAMBDA_VM_GPU_MERKLE_TREE_THRESHOLD") - .ok() - .and_then(|s| s.parse().ok()) - .unwrap_or(DEFAULT_GPU_MERKLE_TREE_THRESHOLD) - }) +static GPU_MERKLE_TREE_CALLS: AtomicU64 = AtomicU64::new(0); +pub fn gpu_merkle_tree_calls() -> u64 { + GPU_MERKLE_TREE_CALLS.load(Ordering::Relaxed) } diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 64cf2ea36..2fdf99d69 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -161,6 +161,11 @@ where precomputed_tree: Option>>, precomputed_root: Option, num_precomputed_cols: usize, + /// Device-side LDE buffer kept for downstream GPU rounds when the R1 fused + /// pipeline produced one. Carried alongside the commit data so the + /// .zip() chain in Phase D stays compiler-aligned by construction. + #[cfg(feature = "cuda")] + gpu_main: Option, } /// Round 1 commitment artifacts — Merkle trees, roots, challenges, and bus inputs. @@ -220,8 +225,13 @@ pub struct MainTraceCommitResult where FieldElement: AsBytes, { - pub(crate) tree: BatchedMerkleTree, - pub(crate) root: Commitment, + /// Primary commitment tree: the main trace for `commit_main_trace`, the + /// multiplicity columns for `commit_preprocessed_trace`. + pub(crate) commit_tree: BatchedMerkleTree, + /// Root of `commit_tree`. + pub(crate) commit_root: Commitment, + /// Secondary tree for preprocessed traces: the static precomputed columns. + /// `None` for non-preprocessed tables. pub(crate) precomputed_tree: Option>, pub(crate) precomputed_root: Option, pub(crate) num_precomputed_cols: usize, @@ -637,7 +647,9 @@ pub trait IsStarkProver< columns, domain.blowup_factor, &twiddles.coset_weights, - ) { + ) + .is_some() + { return; } @@ -688,14 +700,16 @@ pub trait IsStarkProver< { #[cfg(feature = "instruments")] let main_lde_dur = t_sub.elapsed(); - #[cfg(feature = "instruments")] - let zero = std::time::Duration::from_secs(0); let root = tree.root; + // Fused GPU path produces LDE + leaves + tree as one pipeline, + // so the wall-clock total lands in `main_lde_dur`. Bill the + // merkle bucket equal to LDE so the sum (lde + merkle) stays + // comparable to the non-GPU path's combined LDE+commit total. #[cfg(feature = "instruments")] - crate::instruments::accum_r1_main(main_lde_dur, zero); + crate::instruments::accum_r1_main(main_lde_dur, main_lde_dur); return Ok(MainTraceCommitResult { - tree, - root, + commit_tree: tree, + commit_root: root, precomputed_tree: None, precomputed_root: None, num_precomputed_cols: 0, @@ -730,8 +744,8 @@ pub trait IsStarkProver< } Ok(MainTraceCommitResult { - tree, - root, + commit_tree: tree, + commit_root: root, precomputed_tree: None, precomputed_root: None, num_precomputed_cols: 0, @@ -797,8 +811,8 @@ pub trait IsStarkProver< } Ok(MainTraceCommitResult { - tree: mult_tree, - root: mult_root, + commit_tree: mult_tree, + commit_root: mult_root, precomputed_tree: Some(precomputed_tree), precomputed_root: Some(precomputed_root), num_precomputed_cols, @@ -1804,9 +1818,6 @@ pub trait IsStarkProver< let mut main_commits: Vec> = Vec::with_capacity(num_airs); let mut main_ldes: Vec>>> = Vec::with_capacity(num_airs); - #[cfg(feature = "cuda")] - let mut main_gpu_handles: Vec> = - Vec::with_capacity(num_airs); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1851,17 +1862,17 @@ pub trait IsStarkProver< if let Some(ref pre_r) = r.precomputed_root { transcript.append_bytes(pre_r); } - transcript.append_bytes(&r.root); + transcript.append_bytes(&r.commit_root); main_commits.push(MainCommitData { - main_tree: Arc::new(r.tree), - main_root: r.root, + main_tree: Arc::new(r.commit_tree), + main_root: r.commit_root, precomputed_tree: r.precomputed_tree.map(Arc::new), precomputed_root: r.precomputed_root, num_precomputed_cols: r.num_precomputed_cols, + #[cfg(feature = "cuda")] + gpu_main: r.gpu_main, }); main_ldes.push(r.columns); - #[cfg(feature = "cuda")] - main_gpu_handles.push(r.gpu_main); } } @@ -1955,10 +1966,18 @@ pub trait IsStarkProver< }) .collect(); - // Parallel aux commit in chunks of K. The optional ext3 GPU LDE handle - // (retained when the R1 fused pipeline fires) is carried in a side - // vector under `cfg(cuda)` so AuxResult stays a clean 3-tuple in both - // cfg variants. + // Parallel aux commit in chunks of K. The closure returns a cfg-gated + // AuxResult — under cuda it carries the optional ext3 GPU LDE handle + // as a fourth element so the .zip() chain in Phase D stays + // compiler-aligned with no side vectors. + #[cfg(feature = "cuda")] + type AuxResult = ( + Option>>, + Option, + Vec>>, + Option, + ); + #[cfg(not(feature = "cuda"))] type AuxResult = ( Option>>, Option, @@ -1966,9 +1985,6 @@ pub trait IsStarkProver< ); #[allow(clippy::type_complexity)] let mut aux_results: Vec> = Vec::with_capacity(num_airs); - #[cfg(feature = "cuda")] - let mut aux_gpu_handles: Vec> = - Vec::with_capacity(num_airs); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1979,11 +1995,8 @@ pub trait IsStarkProver< #[cfg(not(feature = "parallel"))] let iter = chunk_range; - // Per-iter the closure produces `(AuxResult, Option)` - // under cuda, or `AuxResult` alone under non-cuda. Splitting them - // at the sequential collection step keeps the two-vec layout. #[allow(clippy::type_complexity)] - let chunk_aux: Vec> = iter + let chunk_aux: Vec, ProvingError>> = iter .map(|idx| { let (air, trace, _) = &air_trace_pairs[idx]; let domain = &domains[idx]; @@ -2012,13 +2025,16 @@ pub trait IsStarkProver< { #[cfg(feature = "instruments")] let aux_lde_dur = t_sub.elapsed(); - #[cfg(feature = "instruments")] - let zero = std::time::Duration::from_secs(0); let root = tree.root; + // Fused GPU path: bill merkle equal to LDE so + // the (lde + merkle) sum stays comparable to + // the non-GPU path's combined R1 total. #[cfg(feature = "instruments")] - crate::instruments::accum_r1_aux(aux_lde_dur, zero); + crate::instruments::accum_r1_aux(aux_lde_dur, aux_lde_dur); return Ok(( - (Some(Arc::new(tree)), Some(root), columns), + Some(Arc::new(tree)), + Some(root), + columns, Some(handle), )); } @@ -2053,35 +2069,26 @@ pub trait IsStarkProver< } #[cfg(feature = "cuda")] - return Ok(( - (Some(Arc::new(tree)), Some(root), columns), - None::, - )); + return Ok((Some(Arc::new(tree)), Some(root), columns, None)); #[cfg(not(feature = "cuda"))] Ok((Some(Arc::new(tree)), Some(root), columns)) } else { #[cfg(feature = "cuda")] - return Ok(((None, None, Vec::new()), None::)); + return Ok((None, None, Vec::new(), None)); #[cfg(not(feature = "cuda"))] Ok((None, None, Vec::new())) } }) .collect(); - // Sequential: append aux roots to forked transcripts and split - // the optional GPU handle into its own side vector under cuda. + // Sequential: append aux roots to forked transcripts. for (j, result) in chunk_aux.into_iter().enumerate() { - #[cfg(feature = "cuda")] - let (aux_triple, aux_gpu_h) = result?; - #[cfg(not(feature = "cuda"))] - let aux_triple = result?; - let (aux_tree, aux_root, cached_aux) = aux_triple; - if let Some(ref root) = aux_root { + let aux_full = result?; + // Tuple shape is cfg-gated; `.1` is the root in both variants. + if let Some(ref root) = aux_full.1 { table_transcripts[chunk_start + j].append_bytes(root); } - aux_results.push((aux_tree, aux_root, cached_aux)); - #[cfg(feature = "cuda")] - aux_gpu_handles.push(aux_gpu_h); + aux_results.push(aux_full); } } @@ -2091,18 +2098,16 @@ pub trait IsStarkProver< Vec::with_capacity(num_airs); let mut cached_ldes: Vec> = Vec::with_capacity(num_airs); - #[cfg(feature = "cuda")] - let mut main_gpu_iter = main_gpu_handles.into_iter(); - #[cfg(feature = "cuda")] - let mut aux_gpu_iter = aux_gpu_handles.into_iter(); - - for (((main_commit, main_lde), (aux_tree, aux_root, cached_aux)), bus_public_inputs) in - main_commits - .into_iter() - .zip(main_ldes) - .zip(aux_results) - .zip(bus_inputs_vec) + for (((main_commit, main_lde), aux_full), bus_public_inputs) in main_commits + .into_iter() + .zip(main_ldes) + .zip(aux_results) + .zip(bus_inputs_vec) { + #[cfg(feature = "cuda")] + let (aux_tree, aux_root, cached_aux, gpu_aux) = aux_full; + #[cfg(not(feature = "cuda"))] + let (aux_tree, aux_root, cached_aux) = aux_full; commitments.push(Round1Commitments { main_merkle_tree: main_commit.main_tree, main_merkle_root: main_commit.main_root, @@ -2118,12 +2123,8 @@ pub trait IsStarkProver< cached_ldes.push(Lde { main: main_lde, aux: cached_aux, - gpu_main: main_gpu_iter - .next() - .expect("main_gpu_handles length mismatch"), - gpu_aux: aux_gpu_iter - .next() - .expect("aux_gpu_handles length mismatch"), + gpu_main: main_commit.gpu_main, + gpu_aux, }); #[cfg(not(feature = "cuda"))] cached_ldes.push(Lde { diff --git a/prover/tests/bench_single.rs b/prover/tests/bench_single.rs index 947f0fddf..fac6ad901 100644 --- a/prover/tests/bench_single.rs +++ b/prover/tests/bench_single.rs @@ -7,6 +7,17 @@ fn prove_fib_1m_once() { let elf = asm_elf_bytes("fib_iterative_1M"); // Warm-up pays one-time costs (PTX load, pool warm-up). let _ = lambda_vm_prover::prove(&elf).expect("warm-up"); + // Reset GPU counters so the profiled-pass assert below reflects only the + // second run, not warm-up + profiled combined. + #[cfg(feature = "cuda")] + stark::gpu_lde::reset_all_gpu_call_counters(); // The profiled run: let _ = lambda_vm_prover::prove(&elf).expect("prove"); + // Catch silent regressions where the table sizes drop below the GPU LDE + // threshold and we'd be measuring CPU numbers without noticing. + #[cfg(feature = "cuda")] + assert!( + stark::gpu_lde::gpu_lde_calls() > 0, + "GPU LDE path did not fire — fib_iterative_1M may have dropped below the GPU threshold" + ); } From 71aba0d66d87e0090200be7dd9b3317bf76e7515 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 22 May 2026 16:53:49 -0300 Subject: [PATCH 15/16] address doc comment suggestions --- crypto/stark/src/gpu_lde.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index 7c9983601..ea4f04488 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -239,8 +239,7 @@ unsafe fn presize_and_view_ext3( /// Truncate each column back to `n` (trace size) after a GPU error so the /// CPU fallback (which reads `buffer.len()` as the trace size) runs cleanly. /// Safe because `math_cuda` writes outputs only at the final host copy, post- -/// synchronize; any `Err` returns before that copy, leaving `columns[0..n]` -/// untouched. +/// synchronize; any `Err` returns before that copy, leaving `columns[0..n]` untouched. fn restore_columns_on_err(columns: &mut [Vec>], n: usize) { for col in columns.iter_mut() { col.truncate(n); @@ -260,7 +259,7 @@ fn alloc_merkle_nodes(lde_size: usize) -> Option<(Vec<[u8; 32]>, usize)> { let _byte_len = total_nodes.checked_mul(32)?; let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); // SAFETY: every byte will be overwritten via the GPU D2H before the - // contents are read; the caller computes the byte-length view from the + // contents are read. The caller computes the byte-length view from the // returned `nodes` Vec using `total_nodes.checked_mul(32)`. unsafe { nodes.set_len(total_nodes) }; Some((nodes, total_nodes)) @@ -270,13 +269,12 @@ fn alloc_merkle_nodes(lde_size: usize) -> Option<(Vec<[u8; 32]>, usize)> { /// /// Only engaged for Goldilocks-base tables whose LDE size is above the /// threshold. The prover's `expand_columns_to_lde` hands us every column of -/// one table at once; those columns all share twiddles and coset weights so +/// one table at once. Those columns all share twiddles and coset weights so /// they can be processed in a single batched pipeline on one stream. /// /// Returns `Some(())` if the batch was handled on GPU (and `columns` now /// contains the LDE evaluations). Returns `None` to let the caller run the /// per-column CPU fallback. -#[inline] pub(crate) fn try_expand_columns_batched( columns: &mut [Vec>], blowup_factor: usize, @@ -286,8 +284,8 @@ where F: IsField + 'static, E: IsField + 'static, { - // Ext3 fast path: decompose each ext3 column into its 3 base components - // and dispatch to the base-field batched NTT with 3×M logical columns. + // Ext3 path: decompose each ext3 column into its 3 base components and + // dispatch to the base-field batched NTT with 3×M logical columns. // Butterflies with a base-field twiddle act componentwise on ext3, so // this is exactly equivalent to running the NTT in the extension field. if TypeId::of::() == TypeId::of::() { @@ -301,7 +299,8 @@ where }; let num_columns = columns.len(); - // SAFETY: layout-checked above (`E == GoldilocksField`, `F == GoldilocksField`). + // SAFETY: the `Run` arm of `check_base_layout::` (matched above) + // guarantees `E == GoldilocksField` and `F == GoldilocksField`. let raw_columns = unsafe { columns_to_u64_base::(columns) }; let weights_u64 = unsafe { weights_to_u64::(weights) }; let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); From 34cae4bbc3f61edc07380dd225514210216eb489 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 22 May 2026 18:39:47 -0300 Subject: [PATCH 16/16] fix --- crypto/math-cuda/src/device.rs | 69 +++++++++++++++------ crypto/math-cuda/src/lde.rs | 107 +++++++++++++++------------------ 2 files changed, 99 insertions(+), 77 deletions(-) diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index d6d5fc403..f2cc988c0 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -102,16 +102,26 @@ const STREAM_POOL_SIZE: usize = 32; pub struct Backend { pub ctx: Arc, streams: Vec>, - /// Single shared pinned staging buffer, grown to the biggest LDE size - /// seen. Concurrent batched LDE calls serialise on it; in exchange the - /// process keeps only ONE gigabyte-sized pinned allocation (per-stream - /// buffers 32×-inflated memory use and multiplied the one-time pinning - /// cost for every first use of a new table size). - pinned_staging: Mutex, - /// Separate pinned staging for Merkle leaf hashes. Sized `num_rows * 32` - /// bytes. It lives alongside the LDE staging so the GPU→host D2H for - /// hashed leaves runs at full PCIe line-rate. - pinned_hashes: Mutex, + /// Per-rayon-worker pinned staging buffers, grown lazily to the biggest + /// LDE size each worker sees. Indexed by `rayon::current_thread_index()` + /// (or 0 for non-rayon callers). + /// + /// Per-worker (not single-shared) because the LDE call holds the lock + /// across an internal rayon `par_chunks_mut`/`par_iter` window: with a + /// single shared mutex, rayon work-stealing can yield a lock-holder onto + /// another task waiting for the same lock — classic recursive-rayon + /// deadlock. Per-worker buffers eliminate cross-worker contention so + /// each `par_iter` worker hits a distinct mutex. + /// + /// Each entry starts empty (`PinnedStaging::empty()` is a zero-cost null + /// handle); only the slots actually used by the running workers ever + /// allocate pinned memory. Worst-case footprint is `N_workers × + /// max_LDE_size` of pinned host RAM. + pinned_staging: Vec>, + /// Per-worker pinned staging for Merkle leaf hashes. Same layout as + /// `pinned_staging`; sized `num_rows * 32` bytes per slot. Lives + /// alongside the LDE staging so the GPU→host D2H runs at PCIe line-rate. + pinned_hashes: Vec>, util_stream: Arc, next: AtomicUsize, @@ -166,8 +176,20 @@ impl Backend { for _ in 0..STREAM_POOL_SIZE { streams.push(ctx.new_stream()?); } - let pinned_staging = Mutex::new(PinnedStaging::empty()); - let pinned_hashes = Mutex::new(PinnedStaging::empty()); + // Size to the rayon worker count, plus one for non-rayon callers + // who land on slot 0 (`rayon::current_thread_index()` returns None + // outside a rayon context — we map that to 0). + // + // `current_num_threads()` returns the default-pool size if no custom + // pool is in use, which is the cpu count. Stable across the + // backend's lifetime since rayon's pool is fixed at first use. + let n_slots = rayon::current_num_threads().max(1); + let pinned_staging: Vec> = (0..n_slots) + .map(|_| Mutex::new(PinnedStaging::empty())) + .collect(); + let pinned_hashes: Vec> = (0..n_slots) + .map(|_| Mutex::new(PinnedStaging::empty())) + .collect(); // Separate "utility" stream for twiddle uploads and other bookkeeping; // not part of the pool that callers rotate through. let util_stream = ctx.new_stream()?; @@ -219,16 +241,29 @@ impl Backend { self.streams[idx].clone() } - /// Shared pinned staging buffer. Grows to the largest LDE the process - /// has seen so far. Concurrent callers serialise on the mutex. + /// Per-rayon-worker pinned staging buffer. Returns the slot for the + /// current worker (or slot 0 outside a rayon context). Grows lazily to + /// the largest LDE the worker has seen. See the field docs for the + /// rationale behind the per-worker split. pub fn pinned_staging(&self) -> &Mutex { - &self.pinned_staging + &self.pinned_staging[self.worker_slot(self.pinned_staging.len())] } - /// Separate pinned staging for Merkle leaf hash output. Sized in u64 + /// Per-worker pinned staging for Merkle leaf hash output. Sized in u64 /// units. Caller should reserve `(num_rows * 32 + 7) / 8` u64s. pub fn pinned_hashes(&self) -> &Mutex { - &self.pinned_hashes + &self.pinned_hashes[self.worker_slot(self.pinned_hashes.len())] + } + + /// Map `rayon::current_thread_index()` to a slot index, with a defensive + /// clamp in case the rayon pool grew past the Vec we sized at init. + fn worker_slot(&self, len: usize) -> usize { + let idx = rayon::current_thread_index().unwrap_or(0); + // Should be unreachable with rayon's fixed default pool, but if a + // larger custom pool sneaks in we still want safety — fall back to + // slot 0 (correctness preserved, just contention). + debug_assert!(idx < len, "rayon worker {idx} >= staging slots {len}"); + idx.min(len.saturating_sub(1)) } pub fn fwd_twiddles_for(&self, log_n: u64) -> Result>> { diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 02f109938..48b580994 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -13,7 +13,6 @@ use std::sync::Arc; use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; -use rayon::prelude::*; use crate::Result; use crate::device::{Backend, backend}; @@ -69,7 +68,10 @@ pub(crate) fn pack_ext3_to_pinned_slabs(columns: &[&[u64]], pinned: &mut [u64], let m = columns.len(); debug_assert!(pinned.len() >= 3 * m * n); let pinned_ptr_u = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { + // Sequential, not `par_iter`: this runs while the per-worker pinned + // staging mutex is held. Rayon inside a held mutex risks recursive + // stealing-during-wait deadlocks — see `Backend::pinned_staging` docs. + columns.iter().enumerate().for_each(|(c, col)| { // SAFETY: each task writes to disjoint `[(c*3 + k)*n .. ..+n]` regions // of `pinned`. The outer `&mut [u64]` borrow guarantees no aliasing. let slab_a = unsafe { @@ -96,7 +98,9 @@ fn unpack_pinned_slabs_to_ext3(pinned: &[u64], outputs: &mut [&mut [u64]], lde_s let m = outputs.len(); debug_assert!(pinned.len() >= 3 * m * lde_size); let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + // Sequential, not `par_iter_mut`: runs inside a held pinned-staging + // mutex; rayon-inside-mutex risks deadlock (see `Backend::pinned_staging`). + outputs.iter_mut().enumerate().for_each(|(c, dst)| { // SAFETY: each task reads from disjoint `[(c*3 + k)*lde_size .. ..+lde_size]` // regions of `pinned`. Caller borrows `pinned` for the duration of the call. let slab_a = unsafe { @@ -178,19 +182,14 @@ fn d2h_bytes_via_pinned_hashes( stream.memcpy_dtoh(dev_bytes, pinned_bytes)?; stream.synchronize()?; - // Single-threaded `copy_from_slice` faults virgin pageable pages one at - // a time; the mm_struct rwsem serialises them at prover scale. Chunk so - // ~N cores pre-fault+write in parallel. - const CHUNK: usize = 64 * 1024; - let src_ptr = pinned_bytes.as_ptr() as usize; - dst.par_chunks_mut(CHUNK).enumerate().for_each(|(i, d)| { - // SAFETY: each task reads `[i*CHUNK .. i*CHUNK + d.len()]` of - // `pinned_bytes`, which is disjoint per `i` and lives until `staging` - // is dropped below. - let src = - unsafe { std::slice::from_raw_parts((src_ptr as *const u8).add(i * CHUNK), d.len()) }; - d.copy_from_slice(src); - }); + // Sequential, not `par_chunks_mut`: this runs while the per-worker + // pinned_hashes mutex is held. Rayon inside a held mutex risks + // recursive stealing-during-wait deadlocks — see + // `Backend::pinned_staging` docs. Page-fault parallelism on virgin + // destination pages is recovered at the outer level: per-worker + // staging buffers let rayon's outer `par_iter` dispatch multiple LDE + // calls in parallel, each faulting its own destination pages. + dst.copy_from_slice(pinned_bytes); drop(staging); Ok(()) } @@ -367,18 +366,14 @@ pub fn coset_lde_batch_base( // SAFETY: staging is locked, the slice alias ends before we unlock. let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - // Pack columns into first m*n slots of the pinned buffer. Parallel: pinned - // writes are DRAM-bandwidth bound, so rayon spreads the cost across CPU - // cores. - let pinned_base_ptr = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - // SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of - // `pinned`, and the outer `staging` lock guarantees no other call is - // using the buffer concurrently. - let dst = - unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; - dst.copy_from_slice(col); - }); + // Pack columns into first m*n slots of the pinned buffer. Sequential + // (not `par_iter`) because this runs inside the held pinned-staging + // mutex — see `Backend::pinned_staging` docs. Pre-fault parallelism on + // the destination is recovered at the outer level via per-worker + // staging slots. + for (c, col) in columns.iter().enumerate() { + pinned[c * n..c * n + n].copy_from_slice(col); + } // Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size) // tail of each column is already the zero-pad the CPU path does. @@ -459,12 +454,11 @@ pub fn coset_lde_batch_base( stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; stream.synchronize()?; - // Split pinned → per-column Vecs. The first write to each virgin - // Vec page-faults, which can dominate total time. Parallelise so the - // fault cost spreads across CPU cores. - let pinned_ptr = pinned.as_ptr() as usize; + // Split pinned → per-column Vecs. Sequential (not `into_par_iter`) + // because this runs inside the held pinned-staging mutex — see + // `Backend::pinned_staging` docs. Fault-cost parallelism is recovered + // at the outer level (per-worker staging slots). let out: Vec> = (0..m) - .into_par_iter() .map(|c| { // set_len skips the O(N) zero-init that vec![0; n] would do. // copy_from_slice below writes every slot before any reader @@ -475,10 +469,7 @@ pub fn coset_lde_batch_base( unsafe { v.set_len(lde_size) }; v }; - let src = unsafe { - std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) - }; - v.copy_from_slice(src); + v.copy_from_slice(&pinned[c * lde_size..c * lde_size + lde_size]); v }) .collect(); @@ -602,15 +593,14 @@ pub fn coset_lde_batch_base_into( stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; stream.synchronize()?; - // Parallel copy pinned → caller outputs. Caller's Vecs may still fault - // on first write; we spread that cost across rayon cores. - let pinned_ptr = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) - }; - dst.copy_from_slice(src); - }); + // Sequential copy pinned → caller outputs (not `par_iter_mut`): runs + // inside the held pinned-staging mutex; rayon-inside-mutex risks + // recursive stealing-during-wait deadlocks (see + // `Backend::pinned_staging`). Fault-cost parallelism is recovered at + // the outer level via per-worker staging slots. + for (c, dst) in outputs.iter_mut().enumerate() { + dst.copy_from_slice(&pinned[c * lde_size..c * lde_size + lde_size]); + } drop(staging); Ok(()) } @@ -734,12 +724,12 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( staging.ensure_capacity(m * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - let pinned_base_ptr = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - let dst = - unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; - dst.copy_from_slice(col); - }); + // Sequential pack (not `par_iter`): runs inside the held pinned-staging + // mutex (see `Backend::pinned_staging` docs). Per-worker staging slots + // give the outer parallelism back. + for (c, col) in columns.iter().enumerate() { + pinned[c * n..c * n + n].copy_from_slice(col); + } let mut buf = stream.alloc_zeros::(m * lde_size)?; for c in 0..m { @@ -833,14 +823,11 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; - // Pinned LDE → caller outputs (post-sync host memcpy). - let pinned_ptr = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) - }; - dst.copy_from_slice(src); - }); + // Sequential pinned → caller outputs (not `par_iter_mut`): runs inside + // the held pinned-staging mutex (see `Backend::pinned_staging` docs). + for (c, dst) in outputs.iter_mut().enumerate() { + dst.copy_from_slice(&pinned[c * lde_size..c * lde_size + lde_size]); + } drop(staging); if keep_device_buf {