From d1a0abf4fff108ac80e07c3ca0b5bb00cf0eb970 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 May 2026 15:12:54 -0300 Subject: [PATCH 01/22] add first cuda files --- Cargo.lock | 32 + Cargo.toml | 1 + crypto/math-cuda/Cargo.toml | 22 + crypto/math-cuda/build.rs | 57 + crypto/math-cuda/kernels/arith.cu | 83 ++ crypto/math-cuda/kernels/ext3.cuh | 121 +++ crypto/math-cuda/kernels/goldilocks.cuh | 69 ++ crypto/math-cuda/kernels/ntt.cu | 285 +++++ crypto/math-cuda/src/device.rs | 251 +++++ crypto/math-cuda/src/lde.rs | 984 ++++++++++++++++++ crypto/math-cuda/src/lib.rs | 152 +++ crypto/math-cuda/src/ntt.rs | 211 ++++ crypto/math-cuda/tests/bench_quick.rs | 356 +++++++ crypto/math-cuda/tests/evaluate_coset_ext3.rs | 143 +++ crypto/math-cuda/tests/ext3.rs | 87 ++ crypto/math-cuda/tests/goldilocks.rs | 127 +++ crypto/math-cuda/tests/lde.rs | 112 ++ crypto/math-cuda/tests/lde_batch.rs | 96 ++ crypto/math-cuda/tests/lde_batch_ext3.rs | 161 +++ crypto/math-cuda/tests/ntt.rs | 136 +++ crypto/stark/Cargo.toml | 4 + 21 files changed, 3490 insertions(+) create mode 100644 crypto/math-cuda/Cargo.toml create mode 100644 crypto/math-cuda/build.rs create mode 100644 crypto/math-cuda/kernels/arith.cu create mode 100644 crypto/math-cuda/kernels/ext3.cuh create mode 100644 crypto/math-cuda/kernels/goldilocks.cuh create mode 100644 crypto/math-cuda/kernels/ntt.cu create mode 100644 crypto/math-cuda/src/device.rs create mode 100644 crypto/math-cuda/src/lde.rs create mode 100644 crypto/math-cuda/src/lib.rs create mode 100644 crypto/math-cuda/src/ntt.rs create mode 100644 crypto/math-cuda/tests/bench_quick.rs create mode 100644 crypto/math-cuda/tests/evaluate_coset_ext3.rs create mode 100644 crypto/math-cuda/tests/ext3.rs create mode 100644 crypto/math-cuda/tests/goldilocks.rs create mode 100644 crypto/math-cuda/tests/lde.rs create mode 100644 crypto/math-cuda/tests/lde_batch.rs create mode 100644 crypto/math-cuda/tests/lde_batch_ext3.rs create mode 100644 crypto/math-cuda/tests/ntt.rs diff --git a/Cargo.lock b/Cargo.lock index f6eea84d6..7b6ed3c62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -803,6 +803,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "cudarc" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f071cd6a7b5d51607df76aa2d426aaabc7a74bc6bdb885b8afa63a880572ad9b" +dependencies = [ + "libloading", +] + [[package]] name = "darling" version = "0.21.3" @@ -1989,6 +1998,16 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +[[package]] +name = "libloading" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.15" @@ -2105,6 +2124,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "math-cuda" +version = "0.1.0" +dependencies = [ + "cudarc", + "math", + "rand 0.8.5", + "rand_chacha 0.3.1", + "rayon", + "sha3", +] + [[package]] name = "memchr" version = "2.7.6" @@ -3172,6 +3203,7 @@ dependencies = [ "itertools 0.11.0", "log", "math", + "math-cuda", "rayon", "serde", "serde-wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 4d10b7c44..e43dc7f0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "crypto/stark", "crypto/crypto", "crypto/math", + "crypto/math-cuda", "bin/cli", ] diff --git a/crypto/math-cuda/Cargo.toml b/crypto/math-cuda/Cargo.toml new file mode 100644 index 000000000..8c22d1110 --- /dev/null +++ b/crypto/math-cuda/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "math-cuda" +description = "CUDA-accelerated FFT/NTT for Goldilocks (base field) used by the lambda-vm STARK prover" +version = "0.1.0" +edition = "2024" + +[dependencies] +cudarc = { version = "0.19", default-features = false, features = [ + "driver", + "nvrtc", + "std", + "cuda-12080", + "dynamic-loading", +] } +math = { path = "../math" } +rayon = "1.7" + +[dev-dependencies] +rand = { version = "0.8.5", features = ["std"] } +rand_chacha = "0.3.1" +rayon = "1.7" +sha3 = "0.10.8" diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs new file mode 100644 index 000000000..a6defb5ab --- /dev/null +++ b/crypto/math-cuda/build.rs @@ -0,0 +1,57 @@ +use std::env; +use std::path::PathBuf; +use std::process::Command; + +fn cuda_home() -> PathBuf { + env::var_os("CUDA_HOME") + .or_else(|| env::var_os("CUDA_PATH")) + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from("/usr/local/cuda")) +} + +fn nvcc_path() -> PathBuf { + cuda_home().join("bin").join("nvcc") +} + +fn compile_ptx(src: &str, out_name: &str) { + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let src_path = manifest_dir.join("kernels").join(src); + let out_path = out_dir.join(out_name); + + println!("cargo:rerun-if-changed=kernels/{src}"); + println!("cargo:rerun-if-env-changed=CUDA_HOME"); + println!("cargo:rerun-if-env-changed=CUDA_PATH"); + println!("cargo:rerun-if-env-changed=CUDARC_NVCC_ARCH"); + + // Emit PTX for a virtual architecture; the CUDA driver JIT-compiles it for the + // actual GPU at load time, so one PTX works across Ada/Hopper/Blackwell. Override + // with CUDARC_NVCC_ARCH to pin a specific compute capability. + let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| "compute_89".to_string()); + + let status = Command::new(nvcc_path()) + .args([ + "--ptx", + "-O3", + "-std=c++17", + "-arch", + &arch, + "-o", + ]) + .arg(&out_path) + .arg(&src_path) + .status() + .expect("failed to invoke nvcc — is CUDA installed and CUDA_HOME set?"); + + if !status.success() { + panic!("nvcc failed compiling {}", src_path.display()); + } +} + +fn main() { + // Headers are not compiled; emit rerun-if-changed so edits trigger rebuilds. + println!("cargo:rerun-if-changed=kernels/goldilocks.cuh"); + println!("cargo:rerun-if-changed=kernels/ext3.cuh"); + compile_ptx("arith.cu", "arith.ptx"); + compile_ptx("ntt.cu", "ntt.ptx"); +} diff --git a/crypto/math-cuda/kernels/arith.cu b/crypto/math-cuda/kernels/arith.cu new file mode 100644 index 000000000..4bee9b8bb --- /dev/null +++ b/crypto/math-cuda/kernels/arith.cu @@ -0,0 +1,83 @@ +// Element-wise Goldilocks kernels used by the Phase-2 parity tests. These mirror +// the CPU reference in `crypto/math/src/field/goldilocks.rs` so raw u64 outputs +// are bit-identical to the CPU path. + +#include "goldilocks.cuh" +#include "ext3.cuh" + +using goldilocks::add; +using goldilocks::sub; +using goldilocks::mul; +using goldilocks::neg; + +extern "C" __global__ void vector_add_u64(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = a[tid] + b[tid]; // plain wrapping u64 add — toolchain sanity only. +} + +extern "C" __global__ void gl_add_kernel(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = add(a[tid], b[tid]); +} + +extern "C" __global__ void gl_sub_kernel(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = sub(a[tid], b[tid]); +} + +extern "C" __global__ void gl_mul_kernel(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = mul(a[tid], b[tid]); +} + +extern "C" __global__ void gl_neg_kernel(const uint64_t *a, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = neg(a[tid]); +} + +// --------------------------------------------------------------------------- +// Ext3 (Goldilocks cubic extension) test kernels. +// Input/output arrays are interleaved [a_0, b_0, c_0, a_1, b_1, c_1, ...]. +// --------------------------------------------------------------------------- + +extern "C" __global__ void ext3_mul_kernel(const uint64_t *a_int, + const uint64_t *b_int, + uint64_t *c_int, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]); + ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]); + ext3::Fe3 r = ext3::mul(a, b); + c_int[tid*3 + 0] = r.a; + c_int[tid*3 + 1] = r.b; + c_int[tid*3 + 2] = r.c; +} + +extern "C" __global__ void ext3_add_kernel(const uint64_t *a_int, + const uint64_t *b_int, + uint64_t *c_int, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]); + ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]); + ext3::Fe3 r = ext3::add(a, b); + c_int[tid*3 + 0] = r.a; + c_int[tid*3 + 1] = r.b; + c_int[tid*3 + 2] = r.c; +} diff --git a/crypto/math-cuda/kernels/ext3.cuh b/crypto/math-cuda/kernels/ext3.cuh new file mode 100644 index 000000000..2f4040714 --- /dev/null +++ b/crypto/math-cuda/kernels/ext3.cuh @@ -0,0 +1,121 @@ +// Goldilocks cubic extension on device: Fp3 = Fp[w] / (w^3 - 2) +// where Fp is Goldilocks (2^64 - 2^32 + 1). +// +// Layout matches the CPU `Degree3GoldilocksExtensionField` (see +// `crypto/math/src/field/extensions_goldilocks.rs`): an element is a +// 3-tuple `(a, b, c)` representing `a + b*w + c*w^2`. +// +// The reducible `w^3 = 2` means cross-term products get a factor of 2: +// (a0 + a1*w + a2*w^2) * (b0 + b1*w + b2*w^2) +// = (a0*b0 + 2*(a1*b2 + a2*b1)) +// + (a0*b1 + a1*b0 + 2*a2*b2) * w +// + (a0*b2 + a1*b1 + a2*b0) * w^2 +// +// We use the same dot-product-of-three folding as the CPU (which saves +// reductions by summing u128 products before `reduce128`). CUDA has +// `__umul64hi` so we implement `dot_product_3` inline. + +#pragma once +#include "goldilocks.cuh" + +namespace ext3 { + +struct Fe3 { + uint64_t a, b, c; +}; + +__device__ __forceinline__ Fe3 make(uint64_t a, uint64_t b, uint64_t c) { + Fe3 r = {a, b, c}; + return r; +} + +__device__ __forceinline__ Fe3 zero() { return make(0, 0, 0); } +__device__ __forceinline__ Fe3 one() { return make(1, 0, 0); } + +__device__ __forceinline__ Fe3 add(const Fe3 &x, const Fe3 &y) { + return make(goldilocks::add(x.a, y.a), + goldilocks::add(x.b, y.b), + goldilocks::add(x.c, y.c)); +} + +__device__ __forceinline__ Fe3 sub(const Fe3 &x, const Fe3 &y) { + return make(goldilocks::sub(x.a, y.a), + goldilocks::sub(x.b, y.b), + goldilocks::sub(x.c, y.c)); +} + +__device__ __forceinline__ Fe3 neg(const Fe3 &x) { + return make(goldilocks::neg(x.a), + goldilocks::neg(x.b), + goldilocks::neg(x.c)); +} + +/// Mixed: base * ext3 → ext3 (componentwise). +__device__ __forceinline__ Fe3 mul_base(const Fe3 &x, uint64_t s) { + return make(goldilocks::mul(x.a, s), + goldilocks::mul(x.b, s), + goldilocks::mul(x.c, s)); +} + +/// Dot-product of three (a0*b0 + a1*b1 + a2*b2) mod p, with one reduce128 +/// on the sum of three u128 products. Matches CPU `dot_product_3`. +__device__ __forceinline__ uint64_t dot3(uint64_t a0, uint64_t b0, + uint64_t a1, uint64_t b1, + uint64_t a2, uint64_t b2) { + // Split the sum of three u128 products into hi/lo u128 halves, then + // reduce once. We track overflow-count (at most 2) and add EPSILON^2 + // per overflow, matching the CPU path. + // prod_i = a_i * b_i (u128) + uint64_t lo0 = a0 * b0, hi0 = __umul64hi(a0, b0); + uint64_t lo1 = a1 * b1, hi1 = __umul64hi(a1, b1); + uint64_t lo2 = a2 * b2, hi2 = __umul64hi(a2, b2); + + // sum01 = prod0 + prod1 (in u128 lanes) + uint64_t s01_lo = lo0 + lo1; + uint64_t carry01 = (s01_lo < lo0) ? 1ULL : 0ULL; + uint64_t s01_hi = hi0 + hi1 + carry01; + uint32_t over1 = (s01_hi < hi0 + carry01) ? 1u : 0u; // low-pass overflow + + // sum012 = sum01 + prod2 + uint64_t s012_lo = s01_lo + lo2; + uint64_t carry012 = (s012_lo < s01_lo) ? 1ULL : 0ULL; + uint64_t s012_hi = s01_hi + hi2 + carry012; + uint32_t over2 = (s012_hi < hi2 + carry012) ? 1u : 0u; + + uint64_t reduced = goldilocks::reduce128(s012_lo, s012_hi); + + uint32_t overflow_count = over1 + over2; + if (overflow_count > 0) { + // 2^128 mod p = EPSILON^2 (= (2^32 - 1)^2). + uint64_t eps = goldilocks::EPSILON; + uint64_t eps_sq = eps * eps; + reduced = goldilocks::add_no_canonicalize(reduced, eps_sq); + if (overflow_count > 1) { + reduced = goldilocks::add_no_canonicalize(reduced, eps_sq); + } + } + return reduced; +} + +/// Full ext3 × ext3 multiplication (matches CPU +/// `Degree3GoldilocksExtensionField::mul`). +__device__ __forceinline__ Fe3 mul(const Fe3 &x, const Fe3 &y) { + // c0 = x.a*y.a + x.b*(2*y.c) + x.c*(2*y.b) + // c1 = x.a*y.b + x.b*y.a + x.c*(2*y.c) + // c2 = x.a*y.c + x.b*y.b + x.c*y.a + uint64_t b1_2 = goldilocks::add(y.b, y.b); + uint64_t b2_2 = goldilocks::add(y.c, y.c); + + uint64_t c0 = dot3(x.a, y.a, x.b, b2_2, x.c, b1_2); + uint64_t c1 = dot3(x.a, y.b, x.b, y.a, x.c, b2_2); + uint64_t c2 = dot3(x.a, y.c, x.b, y.b, x.c, y.a); + return make(c0, c1, c2); +} + +__device__ __forceinline__ Fe3 canonical(const Fe3 &x) { + return make(goldilocks::canonical(x.a), + goldilocks::canonical(x.b), + goldilocks::canonical(x.c)); +} + +} // namespace ext3 diff --git a/crypto/math-cuda/kernels/goldilocks.cuh b/crypto/math-cuda/kernels/goldilocks.cuh new file mode 100644 index 000000000..5e296a390 --- /dev/null +++ b/crypto/math-cuda/kernels/goldilocks.cuh @@ -0,0 +1,69 @@ +// Goldilocks field on device. Ports `crypto/math/src/field/goldilocks.rs` one-to-one: +// - Representation: non-canonical u64 in [0, 2^64). Canonicalise only at boundaries. +// - Prime: 2^64 - 2^32 + 1. +// - Reduction: exploits 2^64 ≡ EPSILON (mod p) and 2^96 ≡ -1 (mod p). +// +// The arithmetic here must produce bit-identical u64 outputs to the CPU path so +// LDE parity tests can assert raw equality. + +#pragma once +#include + +namespace goldilocks { + +__device__ constexpr uint64_t PRIME = 0xFFFFFFFF00000001ULL; +__device__ constexpr uint64_t EPSILON = 0xFFFFFFFFULL; // 2^32 - 1 + +__device__ __forceinline__ uint64_t add_no_canonicalize(uint64_t x, uint64_t y) { + // Mirror of `add_no_canonicalize_trashing_input`: one add, one EPSILON bump on carry. + uint64_t sum = x + y; + return sum + (sum < x ? EPSILON : 0ULL); +} + +__device__ __forceinline__ uint64_t add(uint64_t a, uint64_t b) { + uint64_t sum = a + b; + uint64_t over1 = (sum < a) ? EPSILON : 0ULL; + uint64_t sum2 = sum + over1; + uint64_t over2 = (sum2 < sum) ? EPSILON : 0ULL; + return sum2 + over2; +} + +__device__ __forceinline__ uint64_t sub(uint64_t a, uint64_t b) { + uint64_t diff = a - b; + uint64_t under1 = (a < b) ? EPSILON : 0ULL; + uint64_t diff2 = diff - under1; + uint64_t under2 = (diff2 > diff) ? EPSILON : 0ULL; + return diff2 - under2; +} + +__device__ __forceinline__ uint64_t reduce128(uint64_t lo, uint64_t hi) { + uint64_t x_hi_hi = hi >> 32; + uint64_t x_hi_lo = hi & EPSILON; + + // 2^96 ≡ -1 (mod p): subtract x_hi_hi from lo, EPSILON-correct on borrow. + uint64_t t0 = lo - x_hi_hi; + if (lo < x_hi_hi) t0 -= EPSILON; + + // 2^64 ≡ EPSILON (mod p): x_hi_lo * EPSILON = (x_hi_lo << 32) - x_hi_lo. + uint64_t t1 = (x_hi_lo << 32) - x_hi_lo; + + return add_no_canonicalize(t0, t1); +} + +__device__ __forceinline__ uint64_t mul(uint64_t a, uint64_t b) { + uint64_t lo = a * b; + uint64_t hi = __umul64hi(a, b); + return reduce128(lo, hi); +} + +__device__ __forceinline__ uint64_t neg(uint64_t a) { + // `a` may be non-canonical. Canonicalise first, then p - a (or 0). + uint64_t canon = (a >= PRIME) ? (a - PRIME) : a; + return canon == 0 ? 0 : (PRIME - canon); +} + +__device__ __forceinline__ uint64_t canonical(uint64_t a) { + return (a >= PRIME) ? (a - PRIME) : a; +} + +} // namespace goldilocks diff --git a/crypto/math-cuda/kernels/ntt.cu b/crypto/math-cuda/kernels/ntt.cu new file mode 100644 index 000000000..2a5c8c786 --- /dev/null +++ b/crypto/math-cuda/kernels/ntt.cu @@ -0,0 +1,285 @@ +// Radix-2 DIT NTT over Goldilocks. One kernel per butterfly level; the caller +// runs `bit_reverse_permute` once before the first level. +// +// Input layout: bit-reversed-order coefficients (after `bit_reverse_permute`). +// Output layout: natural-order evaluations — matches the CPU `evaluate_fft` contract. +// +// Twiddle table: `tw[i] = ω^i` for i in [0, n/2). Stride-indexed per level. + +#include "goldilocks.cuh" + +using goldilocks::add; +using goldilocks::sub; +using goldilocks::mul; + +/// Reverse the low `log_n` bits of each index and swap x[i] ↔ x[rev(i)]. +/// One thread per index; guarded by `tid < rev` to avoid double-swap. +extern "C" __global__ void bit_reverse_permute(uint64_t *x, + uint64_t n, + uint64_t log_n) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + + // __brevll reverses all 64 bits; shift right so result lives in [0, n). + uint64_t rev = __brevll(tid) >> (64 - log_n); + if (tid < rev) { + uint64_t tmp = x[tid]; + x[tid] = x[rev]; + x[rev] = tmp; + } +} + +/// Pointwise multiply: x[i] *= w[i]. Used for coset scaling (w = g^i weights). +extern "C" __global__ void pointwise_mul(uint64_t *x, + const uint64_t *w, + uint64_t n) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) x[tid] = mul(x[tid], w[tid]); +} + +/// Broadcast scalar multiply: x[i] *= c. Used for the 1/n factor at the end of iNTT. +extern "C" __global__ void scalar_mul(uint64_t *x, + uint64_t c, + uint64_t n) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) x[tid] = mul(x[tid], c); +} + +// ============================================================================ +// BATCHED KERNELS +// +// One launch processes M columns at once. The device buffer holds M columns +// back-to-back; column `c` starts at `data + c * col_stride`. gridDim.y is +// the column index, so each block handles one (column, butterfly-window) pair. +// +// The same twiddle table is shared across all columns of a batch (they all +// NTT on the same domain). The coset weights are also shared. +// ============================================================================ + +extern "C" __global__ void bit_reverse_permute_batched(uint64_t *data, + uint64_t n, + uint64_t log_n, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + + uint64_t rev = __brevll(tid) >> (64 - log_n); + if (tid < rev) { + uint64_t tmp = x[tid]; + x[tid] = x[rev]; + x[rev] = tmp; + } +} + +extern "C" __global__ void ntt_dit_level_batched(uint64_t *data, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t level, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n_half = n >> 1; + if (tid >= n_half) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + + uint64_t half = 1ULL << level; + uint64_t block_size = half << 1; + uint64_t block_idx = tid >> level; + uint64_t k = tid & (half - 1); + + uint64_t i0 = block_idx * block_size + k; + uint64_t i1 = i0 + half; + + uint64_t tw_index = k << (log_n - level - 1); + uint64_t w = tw[tw_index]; + + uint64_t u = x[i0]; + uint64_t v = mul(w, x[i1]); + x[i0] = add(u, v); + x[i1] = sub(u, v); +} + +extern "C" __global__ void ntt_dit_8_levels_batched(uint64_t *data, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t base_step, + uint64_t col_stride) { + __shared__ uint64_t tile[256]; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + + uint32_t n_loc_steps = (uint32_t)min((uint64_t)8, log_n - base_step); + + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + + uint64_t group_size = 1ULL << base_step; + uint64_t n_groups = n >> base_step; + uint64_t low_bits = tid / n_groups; + uint64_t high_bits = tid & (n_groups - 1); + uint64_t row = high_bits * group_size + low_bits; + + tile[threadIdx.x] = x[row]; + __syncthreads(); + + uint32_t remaining_high_bits = (uint32_t)(log_n - base_step - 1); + uint32_t high_mask = (1u << remaining_high_bits) - 1u; + + for (uint32_t loc_step = 0; loc_step < n_loc_steps; ++loc_step) { + if (threadIdx.x < 128) { + uint32_t i = threadIdx.x; + uint32_t half = 1u << loc_step; + uint32_t grp = i >> loc_step; + uint32_t grp_pos = i & (half - 1); + uint32_t idx1 = (grp << (loc_step + 1)) + grp_pos; + uint32_t idx2 = idx1 + half; + + uint32_t gs = (uint32_t)base_step + loc_step; + uint32_t ggp = (blockIdx.x << 7) + i; + ggp = ((ggp & high_mask) << (uint32_t)base_step) + (ggp >> remaining_high_bits); + ggp = ggp & ((1u << gs) - 1u); + uint64_t factor = tw[(uint64_t)ggp * (n >> (gs + 1))]; + + uint64_t u = tile[idx1]; + uint64_t v = mul(tile[idx2], factor); + tile[idx1] = add(u, v); + tile[idx2] = sub(u, v); + } + __syncthreads(); + } + + x[row] = tile[threadIdx.x]; +} + + +/// Batched pointwise multiply: first n elements of each column multiplied by +/// the SHARED weight vector `w` (size n). Used for coset scaling — every +/// column of a table sees the same `g^i / N` weights. +extern "C" __global__ void pointwise_mul_batched(uint64_t *data, + const uint64_t *w, + uint64_t n, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + x[tid] = mul(x[tid], w[tid]); +} + +/// Batched broadcast scalar multiply — one scalar c applied to the first n +/// elements of every column. +extern "C" __global__ void scalar_mul_batched(uint64_t *data, + uint64_t c, + uint64_t n, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + x[tid] = mul(x[tid], c); +} + +/// One DIT butterfly level. Thread `tid` (of n/2 total) owns exactly one +/// butterfly pair (i0, i1 = i0 + half). Twiddle picked from the shared full +/// `tw` table at stride `n / block_size`. Kept for log_n < 8 where shmem +/// fusion is overkill. +extern "C" __global__ void ntt_dit_level(uint64_t *x, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t level) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n_half = n >> 1; + if (tid >= n_half) return; + + uint64_t half = 1ULL << level; // 2^ℓ + uint64_t block_size = half << 1; // 2^{ℓ+1} + uint64_t block_idx = tid >> level; // floor(tid / half) + uint64_t k = tid & (half - 1); // tid mod half + + uint64_t i0 = block_idx * block_size + k; + uint64_t i1 = i0 + half; + + // Stride = n / block_size = n >> (level + 1). + uint64_t tw_index = k << (log_n - level - 1); + uint64_t w = tw[tw_index]; + + uint64_t u = x[i0]; + uint64_t v = mul(w, x[i1]); + x[i0] = add(u, v); + x[i1] = sub(u, v); +} + +/// Up to 8 DIT butterfly levels fused in one kernel using shared memory. +/// +/// Ported from Zisk's `br_ntt_8_steps` (`pil2-stark/src/goldilocks/src/ntt_goldilocks.cu`), +/// simplified to single-column. Each block of 256 threads processes 256 +/// elements in on-chip shared memory, running up to 8 butterfly levels +/// without writing to global memory between them — cuts DRAM traffic by up +/// to 8× vs the per-level kernel. +/// +/// `base_step` selects which 8-level window this launch handles (0, 8, 16…). +/// For levels 0–7 the implicit DIT element layout already places all pair +/// mates inside the same 256-block; for higher base_step we remap the loaded +/// row so pair mates land in consecutive shared-memory slots. +/// +/// Expects bit-reversed input (the caller runs `bit_reverse_permute` once +/// before the first kernel launch). +/// +/// Assumes `n` is a multiple of 256, i.e. `log_n >= 8`. +extern "C" __global__ void ntt_dit_8_levels(uint64_t *x, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t base_step) { + __shared__ uint64_t tile[256]; + + uint32_t n_loc_steps = (uint32_t)min((uint64_t)8, log_n - base_step); + + // tid is the *unpermuted* flat index the block/thread would own. + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + + // Row remap: for base_step > 0, gather elements that pair at levels + // `base_step..base_step+7` so they land consecutively in the block. + uint64_t group_size = 1ULL << base_step; + uint64_t n_groups = n >> base_step; // = n / group_size + uint64_t low_bits = tid / n_groups; + uint64_t high_bits = tid & (n_groups - 1); + uint64_t row = high_bits * group_size + low_bits; + + // Load one element per thread. + tile[threadIdx.x] = x[row]; + __syncthreads(); + + // Each butterfly level uses half the threads (128 butterflies per block). + // The global butterfly index `ggp` is recovered from blockIdx + threadIdx + // and reshaped by the same row-remap to find the right twiddle. + uint32_t remaining_high_bits = (uint32_t)(log_n - base_step - 1); // log2(n_groups / 2) + uint32_t high_mask = (1u << remaining_high_bits) - 1u; + + for (uint32_t loc_step = 0; loc_step < n_loc_steps; ++loc_step) { + if (threadIdx.x < 128) { + uint32_t i = threadIdx.x; + uint32_t half = 1u << loc_step; + uint32_t grp = i >> loc_step; + uint32_t grp_pos = i & (half - 1); + uint32_t idx1 = (grp << (loc_step + 1)) + grp_pos; + uint32_t idx2 = idx1 + half; + + // Global step and butterfly position for twiddle lookup. + uint32_t gs = (uint32_t)base_step + loc_step; + uint32_t ggp = (blockIdx.x << 7) + i; // blockIdx * 128 + i + // Un-remap ggp to find its position in the natural ordering. + ggp = ((ggp & high_mask) << (uint32_t)base_step) + (ggp >> remaining_high_bits); + ggp = ggp & ((1u << gs) - 1u); + uint64_t factor = tw[(uint64_t)ggp * (n >> (gs + 1))]; + + uint64_t u = tile[idx1]; + uint64_t v = mul(tile[idx2], factor); + tile[idx1] = add(u, v); + tile[idx2] = sub(u, v); + } + __syncthreads(); + } + + // Store back to the remapped row. + x[row] = tile[threadIdx.x]; +} diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs new file mode 100644 index 000000000..2c70716a6 --- /dev/null +++ b/crypto/math-cuda/src/device.rs @@ -0,0 +1,251 @@ +//! CUDA device context, stream pool, kernel handles, and twiddle cache. +//! +//! One process-wide backend — lazy-initialised on first use. All kernels live +//! on a single CUDA context; a pool of streams lets rayon-parallel callers +//! overlap H2D / compute / D2H. + +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; + +use cudarc::driver::{CudaContext, CudaFunction, CudaSlice, CudaStream}; +use cudarc::nvrtc::Ptx; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsFFTField; + +use crate::Result; +use crate::ntt::{twiddles_forward, twiddles_inverse}; + +/// Reusable pinned host staging buffer. One per stream; the stream's LDE call +/// holds its buffer's lock across the D2H + memcpy-to-user-Vecs window. +/// +/// Allocated with `cuMemHostAlloc(flags=0)` — portable, non-write-combined, +/// so both DMA writes from device and CPU reads into user Vecs run at full +/// speed. Grows power-of-two; never shrinks. +pub struct PinnedStaging { + ptr: *mut u64, + capacity_elems: usize, +} + +// SAFETY: the raw pointer aliases host memory allocated via cuMemHostAlloc. +// We guard concurrent access with a Mutex; the pointer is valid for the +// lifetime of this struct and is freed on drop. +unsafe impl Send for PinnedStaging {} +unsafe impl Sync for PinnedStaging {} + +impl PinnedStaging { + const fn empty() -> Self { + Self { + ptr: std::ptr::null_mut(), + capacity_elems: 0, + } + } + + pub fn ensure_capacity( + &mut self, + min_elems: usize, + ctx: &CudaContext, + ) -> Result<()> { + if self.capacity_elems >= min_elems { + return Ok(()); + } + // cuMemHostAlloc requires the context to be current on this thread. + ctx.bind_to_thread()?; + // Free old (if any) before allocating the new one. + if !self.ptr.is_null() { + unsafe { + let _ = cudarc::driver::sys::cuMemFreeHost(self.ptr as *mut _); + } + self.ptr = std::ptr::null_mut(); + self.capacity_elems = 0; + } + let new_cap = min_elems.next_power_of_two().max(1 << 20); // at least 8 MB + let bytes = new_cap * std::mem::size_of::(); + let ptr = unsafe { + cudarc::driver::result::malloc_host(bytes, 0 /* flags: non-WC */)? + } as *mut u64; + self.ptr = ptr; + self.capacity_elems = new_cap; + Ok(()) + } + + /// View of the first `len` elements. Caller must hold this `PinnedStaging` + /// locked while using the slice; the slice aliases the internal pointer. + /// + /// # Safety + /// Caller must not outlive the `PinnedStaging` and must not race with + /// concurrent uses. + pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [u64] { + assert!(len <= self.capacity_elems); + if len == 0 { + return &mut []; + } + unsafe { std::slice::from_raw_parts_mut(self.ptr, len) } + } +} + +impl Drop for PinnedStaging { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { + let _ = cudarc::driver::sys::cuMemFreeHost(self.ptr as *mut _); + } + } + } +} + +const ARITH_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/arith.ptx")); +const NTT_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/ntt.ptx")); +/// Number of CUDA streams in the pool. Larger pools let many rayon-parallel +/// callers overlap on the GPU without serializing on stream ownership. The +/// default stream is deliberately excluded because it synchronises with all +/// other streams, defeating the point of the pool. +const STREAM_POOL_SIZE: usize = 32; + +pub struct Backend { + pub ctx: Arc, + streams: Vec>, + /// Single shared pinned staging buffer, grown to the biggest LDE size + /// seen. Concurrent batched LDE calls serialise on it; in exchange the + /// process keeps only ONE gigabyte-sized pinned allocation (per-stream + /// buffers 32×-inflated memory use and multiplied the one-time pinning + /// cost for every first use of a new table size). + pinned_staging: Mutex, + util_stream: Arc, + next: AtomicUsize, + + // arith.ptx + pub vector_add_u64: CudaFunction, + pub gl_add: CudaFunction, + pub gl_sub: CudaFunction, + pub gl_mul: CudaFunction, + pub gl_neg: CudaFunction, + pub ext3_mul: CudaFunction, + pub ext3_add: CudaFunction, + + // ntt.ptx + pub bit_reverse_permute: CudaFunction, + pub ntt_dit_level: CudaFunction, + pub ntt_dit_8_levels: CudaFunction, + pub pointwise_mul: CudaFunction, + pub scalar_mul: CudaFunction, + pub bit_reverse_permute_batched: CudaFunction, + pub ntt_dit_level_batched: CudaFunction, + pub ntt_dit_8_levels_batched: CudaFunction, + pub pointwise_mul_batched: CudaFunction, + pub scalar_mul_batched: CudaFunction, + + // Twiddle caches keyed by log_n. + fwd_twiddles: Mutex>>>>, + inv_twiddles: Mutex>>>>, +} + +impl Backend { + fn init() -> Result { + let ctx = CudaContext::new(0)?; + // cudarc's default per-slice CudaEvent tracking adds two driver calls + // per alloc and serialises under the context lock. We never share + // slices across streams (every call scopes its own buffers and syncs + // before returning), so the tracking is pure overhead. Disable it. + unsafe { ctx.disable_event_tracking() }; + + let arith = ctx.load_module(Ptx::from_src(ARITH_PTX))?; + let ntt = ctx.load_module(Ptx::from_src(NTT_PTX))?; + + let mut streams = Vec::with_capacity(STREAM_POOL_SIZE); + for _ in 0..STREAM_POOL_SIZE { + streams.push(ctx.new_stream()?); + } + let pinned_staging = Mutex::new(PinnedStaging::empty()); + // Separate "utility" stream for twiddle uploads and other bookkeeping; + // not part of the pool that callers rotate through. + let util_stream = ctx.new_stream()?; + + // Goldilocks TWO_ADICITY is 32, so log_n ≤ 32 covers every LDE size + // the prover can produce. Overshoot by one for safety. + let max_log = GoldilocksField::TWO_ADICITY as usize + 1; + + Ok(Self { + vector_add_u64: arith.load_function("vector_add_u64")?, + gl_add: arith.load_function("gl_add_kernel")?, + gl_sub: arith.load_function("gl_sub_kernel")?, + gl_mul: arith.load_function("gl_mul_kernel")?, + gl_neg: arith.load_function("gl_neg_kernel")?, + ext3_mul: arith.load_function("ext3_mul_kernel")?, + ext3_add: arith.load_function("ext3_add_kernel")?, + bit_reverse_permute: ntt.load_function("bit_reverse_permute")?, + ntt_dit_level: ntt.load_function("ntt_dit_level")?, + ntt_dit_8_levels: ntt.load_function("ntt_dit_8_levels")?, + pointwise_mul: ntt.load_function("pointwise_mul")?, + scalar_mul: ntt.load_function("scalar_mul")?, + bit_reverse_permute_batched: ntt.load_function("bit_reverse_permute_batched")?, + ntt_dit_level_batched: ntt.load_function("ntt_dit_level_batched")?, + ntt_dit_8_levels_batched: ntt.load_function("ntt_dit_8_levels_batched")?, + pointwise_mul_batched: ntt.load_function("pointwise_mul_batched")?, + scalar_mul_batched: ntt.load_function("scalar_mul_batched")?, + fwd_twiddles: Mutex::new(vec![None; max_log]), + inv_twiddles: Mutex::new(vec![None; max_log]), + ctx, + streams, + pinned_staging, + util_stream, + next: AtomicUsize::new(0), + }) + } + + /// Round-robin over the stream pool. Concurrent callers get different + /// streams so their kernel launches overlap on the GPU. + pub fn next_stream(&self) -> Arc { + let idx = self.next.fetch_add(1, Ordering::Relaxed) % self.streams.len(); + self.streams[idx].clone() + } + + /// Shared pinned staging buffer. Grows to the largest LDE the process + /// has seen so far. Concurrent callers serialise on the mutex. + pub fn pinned_staging(&self) -> &Mutex { + &self.pinned_staging + } + + pub fn fwd_twiddles_for(&self, log_n: u64) -> Result>> { + self.cached_twiddles(log_n, true) + } + + pub fn inv_twiddles_for(&self, log_n: u64) -> Result>> { + self.cached_twiddles(log_n, false) + } + + fn cached_twiddles(&self, log_n: u64, forward: bool) -> Result>> { + let idx = log_n as usize; + let cache = if forward { + &self.fwd_twiddles + } else { + &self.inv_twiddles + }; + { + let guard = cache.lock().unwrap(); + if let Some(t) = &guard[idx] { + return Ok(t.clone()); + } + } + // Compute on host, upload on the utility stream. Another thread may + // have populated the cache in the meantime; prefer that entry. + let host = if forward { + twiddles_forward(log_n) + } else { + twiddles_inverse(log_n) + }; + let dev = Arc::new(self.util_stream.clone_htod(&host)?); + self.util_stream.synchronize()?; + let mut guard = cache.lock().unwrap(); + if let Some(t) = &guard[idx] { + Ok(t.clone()) + } else { + guard[idx] = Some(dev.clone()); + Ok(dev) + } + } +} + +pub fn backend() -> &'static Backend { + static BACKEND: OnceLock = OnceLock::new(); + BACKEND.get_or_init(|| Backend::init().expect("failed to initialise CUDA backend")) +} diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs new file mode 100644 index 000000000..fd25d1bca --- /dev/null +++ b/crypto/math-cuda/src/lde.rs @@ -0,0 +1,984 @@ +//! Full coset LDE on device. Mirrors `Polynomial::coset_lde_full_expand` in +//! `crypto/math/src/fft/polynomial.rs` algebraically: +//! +//! Input : N evaluations (natural order) of a poly on the standard subgroup, +//! plus coset weights (size N). The weights include the `1/N` iFFT +//! normalisation, matching the `LdeTwiddles::coset_weights` format at +//! `crypto/stark/src/prover.rs:248` — i.e. `weights[i] = g^i / N`. +//! Output : N*blowup_factor evaluations (natural order) on the coset. +//! +//! On-device steps, picks a stream from the shared pool so rayon-parallel +//! callers overlap on the GPU. Twiddles are cached in the backend. + +use std::sync::Arc; + +use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; +use crate::ntt::run_ntt_body; + +/// Handle to a base-field LDE kept live on device after R1 commit. +/// Layout: `m` columns, each `lde_size` u64s, column `c` at byte offset +/// `c * lde_size * 8` within `buf`. Freed when `buf` Arc drops. +#[derive(Clone)] +pub struct GpuLdeBase { + pub buf: Arc>, + pub m: usize, + pub lde_size: usize, +} + +/// Handle to an ext3 LDE kept live on device, de-interleaved into 3 base +/// slabs per column. Column `c` component `k` at u64 offset +/// `(c*3 + k) * lde_size` within `buf`. +#[derive(Clone)] +pub struct GpuLdeExt3 { + pub buf: Arc>, + pub m: usize, + pub lde_size: usize, +} + +pub fn coset_lde_base( + evals: &[u64], + blowup_factor: usize, + weights: &[u64], +) -> Result> { + let n = evals.len(); + assert!(n.is_power_of_two(), "evals length must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match evals"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + if n == 0 { + return Ok(Vec::new()); + } + let lde_size = n * blowup_factor; + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + + // Device buffer of lde_size, zero-padded tail, first N filled by copy. + let mut buf = stream.alloc_zeros::(lde_size)?; + { + let mut head = buf.slice_mut(0..n); + stream.memcpy_htod(evals, &mut head)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + + // === 1. iNTT on first N: bit_reverse + 8-level-fused DIT body === + unsafe { + stream + .launch_builder(&be.bit_reverse_permute) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + // Note: `run_ntt_body` expects a standalone CudaSlice; we pass `buf` and + // the kernel walks the first `n_u64` elements via its own indexing. + run_ntt_body(stream.as_ref(), &mut buf, inv_tw.as_ref(), n_u64, log_n)?; + // Note: the CPU iFFT does not include 1/N — it's folded into `weights`. The + // next pointwise multiply applies both the coset shift and the 1/N factor. + + // === 2. Pointwise multiply first N by coset weights (includes 1/N) === + unsafe { + stream + .launch_builder(&be.pointwise_mul) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + + // === 3. Forward NTT on full buffer === + unsafe { + stream + .launch_builder(&be.bit_reverse_permute) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .launch(LaunchConfig::for_num_elems(lde_size as u32))?; + } + run_ntt_body(stream.as_ref(), &mut buf, fwd_tw.as_ref(), lde_u64, log_lde)?; + + let out = stream.clone_dtoh(&buf)?; + stream.synchronize()?; + Ok(out) +} + +/// Batched coset LDE: processes `m` columns (all the same domain) in a single +/// pipeline on one stream. One H2D per column, then per-level batched kernels +/// that launch with `grid.y = m` so a single launch does the butterflies for +/// every column at that level. +/// +/// Returns one `Vec` per input column, each of length `n * blowup_factor`. +pub fn coset_lde_batch_base( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], +) -> Result>> { + if columns.is_empty() { + return Ok(Vec::new()); + } + let m = columns.len(); + let n = columns[0].len(); + assert!(n.is_power_of_two(), "column length must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match column length"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + for c in columns.iter() { + assert_eq!(c.len(), n, "all columns must be the same size"); + } + + if n == 0 { + return Ok(vec![Vec::new(); m]); + } + let lde_size = n * blowup_factor; + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let debug_phases = std::env::var("MATH_CUDA_PHASE_TIMING").is_ok(); + let t_start = if debug_phases { Some(std::time::Instant::now()) } else { None }; + let phase = |label: &str, prev: &mut Option| { + if let Some(p) = prev.as_ref() { + let now = std::time::Instant::now(); + eprintln!(" [{:>6.2} ms] {}", (now - *p).as_secs_f64() * 1e3, label); + *prev = Some(now); + } + }; + let mut last = t_start; + + // Pinned staging. Lock and grow to max(m*n for upload, m*lde_size for + // download). Holding the guard across the whole call serialises concurrent + // batched calls that happened to hash to the same stream slot, but that's + // exactly what we want — one stream can only do one sequence at a time. + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + // SAFETY: staging is locked, the slice alias ends before we unlock. + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + if debug_phases { phase("staging lock + grow", &mut last); } + + // Pack columns into first m*n slots of the pinned buffer. Parallel: pinned + // writes are DRAM-bandwidth bound, saturates at ~8 cores on modern + // hardware, so rayon shaves 20+ ms at prover scale. + use rayon::prelude::*; + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of + // `pinned`, and the outer `staging` lock guarantees no other call is + // using the buffer concurrently. + let dst = unsafe { + std::slice::from_raw_parts_mut( + (pinned_base_ptr as *mut u64).add(c * n), + n, + ) + }; + dst.copy_from_slice(col); + }); + if debug_phases { phase("host pack (pinned, rayon)", &mut last); } + + // Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size) + // tail of each column is already the zero-pad the CPU path does. + let mut buf = stream.alloc_zeros::(m * lde_size)?; + if debug_phases { stream.synchronize()?; phase("alloc_zeros", &mut last); } + // One memcpy per column from the pinned buffer into the strided slots. + // The pinned source hits PCIe line-rate. + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + if debug_phases { stream.synchronize()?; phase("H2D cols (pinned)", &mut last); } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + if debug_phases { stream.synchronize()?; phase("twiddles + weights", &mut last); } + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // === 1. Bit-reverse first N of every column === + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + if debug_phases { stream.synchronize()?; phase("bit_reverse N", &mut last); } + // === 2. iNTT body over all columns === + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + if debug_phases { stream.synchronize()?; phase("iNTT body", &mut last); } + + // === 3. Pointwise multiply by coset weights (includes 1/N) === + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + // === 4. Bit-reverse full LDE of every column === + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + if debug_phases { stream.synchronize()?; phase("pointwise + bit_reverse LDE", &mut last); } + // === 5. Forward NTT on full LDE of every column === + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + if debug_phases { stream.synchronize()?; phase("forward NTT body", &mut last); } + + // Single big D2H into the reusable pinned staging buffer — pinned, one + // call to the driver, saturates PCIe. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + stream.synchronize()?; + if debug_phases { phase("D2H (one shot into pinned)", &mut last); } + + // Split pinned → per-column Vecs. The first write to each virgin + // Vec page-faults, which can dominate total time (~75 ms for 128 MB). + // Parallelise so the fault cost spreads across CPU cores. + use rayon::prelude::*; + let pinned_ptr = pinned.as_ptr() as usize; + let out: Vec> = (0..m) + .into_par_iter() + .map(|c| { + let mut v = Vec::::with_capacity(lde_size); + unsafe { v.set_len(lde_size) }; + let src = unsafe { + std::slice::from_raw_parts( + (pinned_ptr as *const u64).add(c * lde_size), + lde_size, + ) + }; + v.copy_from_slice(src); + v + }) + .collect(); + if debug_phases { phase("copy out (rayon pinned → Vecs)", &mut last); } + drop(staging); + Ok(out) +} + +/// Like `coset_lde_batch_base` but writes directly into caller-provided +/// output slices instead of allocating fresh `Vec`s. Each output slice +/// must already have length `n * blowup_factor`. Saves ~50–100 ms of pageable +/// allocator work + page faults at prover scale because the caller's Vecs +/// have been sized once and are reused across calls. +pub fn coset_lde_batch_base_into( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + if columns.is_empty() { + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m, "outputs must match columns count"); + let n = columns[0].len(); + assert!(n.is_power_of_two(), "column length must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match column length"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + for c in columns.iter() { + assert_eq!(c.len(), n, "all columns must be the same size"); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), lde_size, "each output must be lde_size"); + } + if n == 0 { + return Ok(()); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + + for (c, col) in columns.iter().enumerate() { + pinned[c * n..c * n + n].copy_from_slice(col); + } + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT bit-reverse + body, pointwise mul, forward bit-reverse + body. + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + stream.synchronize()?; + + // Parallel copy pinned → caller outputs. Caller's Vecs may still fault + // on first write; we spread that cost across rayon cores. + #[allow(unused_imports)] + use rayon::prelude::*; + let pinned_ptr = pinned.as_ptr() as usize; + outputs + .par_iter_mut() + .enumerate() + .for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (pinned_ptr as *const u64).add(c * lde_size), + lde_size, + ) + }; + dst.copy_from_slice(src); + }); + drop(staging); + Ok(()) +} + +/// Batched ext3 polynomial → coset evaluation. +/// +/// Input: M ext3 columns of `n` coefficients each (interleaved, 3n u64). +/// Output: M ext3 columns of `n * blowup_factor` evaluations each at the +/// offset-coset. +/// +/// Skips the iFFT stage of [`coset_lde_batch_ext3_into`] (input is +/// coefficients, not evaluations). Weights encode the coset shift: +/// `weights[k] = offset^k` (NO 1/N because iFFT normalisation doesn't apply). +/// +/// Used by the stark prover to GPU-accelerate +/// `evaluate_polynomial_on_lde_domain` calls inside the +/// `number_of_parts > 2` branch of the composition-polynomial LDE. +pub fn evaluate_poly_coset_batch_ext3_into( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + false, + ) + .map(|_| ()) +} + +/// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- +/// interleaved LDE device buffer as a `GpuLdeExt3` handle. Lets R2 commit +/// and R4 DEEP composition read the composition-parts LDE without +/// re-H2D'ing. +pub fn evaluate_poly_coset_batch_ext3_into_keep( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result { + let opt = evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) +} + +fn evaluate_poly_coset_batch_ext3_into_inner( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + keep_device_buf: bool, +) -> Result> { + if coefs.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = coefs.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in coefs.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + if n == 0 { + return Ok(None); + } + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + coefs.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + // Apply coset scaling: x[k] *= weights[k] for k in 0..n (no iFFT first). + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + // Bit-reverse full lde_size slab, then forward DIT NTT. + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + stream.synchronize()?; + + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + drop(staging); + if keep_device_buf { + Ok(Some(GpuLdeExt3 { + buf: std::sync::Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Batched coset LDE for Goldilocks **cubic extension** columns. +/// +/// A degree-3 extension element is `(a, b, c)` in memory (three contiguous +/// u64s). The NTT butterfly multiplies `v = (a, b, c)` by a base-field +/// twiddle `t`: `t * v = (t*a, t*b, t*c)`. Addition is componentwise. So an +/// NTT over M ext3 columns is algebraically equivalent to **3M parallel +/// base-field NTTs** sharing the same twiddles and coset weights. We +/// exploit this to reuse the base-field kernels with no modification: +/// +/// 1. Host pack de-interleaves each ext3 column into 3 consecutive +/// base-field slabs inside the pinned staging buffer (slab 0 has all the +/// a-components, slab 1 all the b's, slab 2 all the c's — 3M base slabs +/// in total). +/// 2. Existing `bit_reverse_permute_batched` / `ntt_dit_*_batched` / +/// `pointwise_mul_batched` run over those 3M base slabs on device. +/// 3. D2H, then re-interleave 3 slabs per output ext3 column. +/// +/// Input/output layout: each slice is 3*n or 3*n*blowup u64s, packed as +/// `[a0, b0, c0, a1, b1, c1, ...]` — the natural `[FieldElement]` +/// memory representation. +pub fn coset_lde_batch_ext3_into( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + if columns.is_empty() { + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m, "outputs must match columns count"); + assert!(n.is_power_of_two(), "n must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match n"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n, "each ext3 column must be 3*n u64s"); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size, "each output must be 3*lde_size u64s"); + } + if n == 0 { + return Ok(()); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + // 3 base slabs per ext3 column; slab index `c*3 + k` holds component `k`. + let mb = 3 * m; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + // Pack: for each ext3 column, write 3 base slabs into pinned. The slab + // for column c, component k lives at `pinned[(c*3 + k)*n .. (c*3+k)*n + n]`. + // We de-interleave from the interleaved `[a, b, c, a, b, c, ...]` input. + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: disjoint regions per c; staging lock held. + let slab_a = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), + n, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), + n, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), + n, + ) + }; + for i in 0..n { + slab_a[i] = col[i * 3 + 0]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + // Allocate + zero-pad device buffer holding 3M slabs of `lde_size`. + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + // H2D: slab by slab into the first N slots of each `lde_size`-slab. + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + // === Butterflies: identical to the base-field batched path, but with + // grid.y = 3M instead of M. === + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + stream.synchronize()?; + + // Unpack: for each output column, re-interleave 3 slabs back into the + // ext3-per-element layout. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 0) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3 + 0] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + drop(staging); + Ok(()) +} + +/// Run the DIT butterfly body of a bit-reversed-input NTT over `m` batched +/// columns in one device buffer. Same fusion strategy as `run_ntt_body`: +/// first 8 levels shmem-fused (coalesced), subsequent levels one kernel each. +fn run_batched_ntt_body( + stream: &cudarc::driver::CudaStream, + x_dev: &mut cudarc::driver::CudaSlice, + tw_dev: &cudarc::driver::CudaSlice, + n: u64, + log_n: u64, + col_stride: u64, + m: u32, +) -> Result<()> { + let be = backend(); + let fused = core::cmp::min(log_n, 8); + if fused >= 8 { + let grid_x = (n / 256) as u32; + let cfg = LaunchConfig { + grid_dim: (grid_x, m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + let base_step = 0u64; + unsafe { + stream + .launch_builder(&be.ntt_dit_8_levels_batched) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&base_step) + .arg(&col_stride) + .launch(cfg)?; + } + } else { + let grid_x = ((n / 2) as u32).div_ceil(256).max(1); + let cfg = LaunchConfig { + grid_dim: (grid_x, m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + for level in 0..fused { + unsafe { + stream + .launch_builder(&be.ntt_dit_level_batched) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .arg(&col_stride) + .launch(cfg)?; + } + } + } + + let grid_x = ((n / 2) as u32).div_ceil(256).max(1); + let cfg = LaunchConfig { + grid_dim: (grid_x, m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + for level in fused..log_n { + unsafe { + stream + .launch_builder(&be.ntt_dit_level_batched) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .arg(&col_stride) + .launch(cfg)?; + } + } + Ok(()) +} + diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs new file mode 100644 index 000000000..821f5bd3a --- /dev/null +++ b/crypto/math-cuda/src/lib.rs @@ -0,0 +1,152 @@ +//! GPU backend for the lambda-vm STARK prover. +//! +//! Primary entry point: [`lde::coset_lde_base`]. Everything else (`ntt`, +//! element-wise arith) is either internal to the LDE pipeline or used by the +//! parity test suite. + +pub mod device; +pub mod lde; +pub mod ntt; + +use cudarc::driver::{LaunchConfig, PushKernelArg}; + +use crate::device::{Backend, backend}; + +pub type Result = std::result::Result; + +/// Toolchain sanity: plain wrapping u64 vector add. Not a field op. +pub fn vector_add_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.vector_add_u64) +} + +/// Goldilocks field add on device, element-wise. Inputs may be non-canonical. +pub fn gl_add_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.gl_add) +} + +pub fn gl_sub_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.gl_sub) +} + +pub fn gl_mul_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.gl_mul) +} + +pub fn gl_neg_u64(a: &[u64]) -> Result> { + let n = a.len(); + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let a_dev = stream.clone_htod(a)?; + let mut c_dev = stream.alloc_zeros::(n)?; + + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.gl_neg) + .arg(&a_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Element-wise ext3 multiply. `a` and `b` are 3n u64s (interleaved +/// [a0,a1,a2,b0,b1,b2,...]). Test helper for the `ext3.cuh` header. +pub fn ext3_mul_u64(a: &[u64], b: &[u64]) -> Result> { + assert_eq!(a.len(), b.len()); + assert_eq!(a.len() % 3, 0); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(3 * n)?; + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.ext3_mul) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Element-wise ext3 add. +pub fn ext3_add_u64(a: &[u64], b: &[u64]) -> Result> { + assert_eq!(a.len(), b.len()); + assert_eq!(a.len() % 3, 0); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(3 * n)?; + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.ext3_add) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + +fn launch_binary_u64(a: &[u64], b: &[u64], pick: F) -> Result> +where + F: for<'a> Fn(&'a Backend) -> &'a cudarc::driver::CudaFunction, +{ + assert_eq!(a.len(), b.len(), "length mismatch"); + let n = a.len(); + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(n)?; + + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(pick(be)) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/src/ntt.rs b/crypto/math-cuda/src/ntt.rs new file mode 100644 index 000000000..0ebb015ea --- /dev/null +++ b/crypto/math-cuda/src/ntt.rs @@ -0,0 +1,211 @@ +//! Forward and inverse NTT over Goldilocks base field. Matches the algebraic +//! contract of `math::polynomial::Polynomial::evaluate_fft` / +//! `interpolate_fft`: +//! input = n elements in natural order +//! output = n elements in natural order. +//! +//! Parity is checked by `tests/ntt.rs` against the CPU implementation. + +use cudarc::driver::{LaunchConfig, PushKernelArg}; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsFFTField, IsField}; + +use crate::Result; +use crate::device::backend; + +/// Host-side twiddle table: `[ω^0, ω^1, …, ω^{n/2-1}]` where ω is the +/// primitive n-th root of unity. Exposed for `device::Backend::cached_twiddles` +/// and for direct use in tests / benches. +pub fn twiddles_forward(log_n: u64) -> Vec { + let omega = *GoldilocksField::get_primitive_root_of_unity(log_n) + .expect("primitive root") + .value(); + powers_of(omega, 1usize << (log_n - 1)) +} + +/// Inverse twiddle table: `[ω^{-i}]` for i in [0, n/2). +pub fn twiddles_inverse(log_n: u64) -> Vec { + let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("primitive root"); + let omega_inv = FieldElement::::inv(&omega).expect("inverse"); + powers_of(*omega_inv.value(), 1usize << (log_n - 1)) +} + +fn powers_of(base: u64, count: usize) -> Vec { + let mut out = Vec::with_capacity(count); + let mut w = 1u64; + for _ in 0..count { + out.push(w); + w = GoldilocksField::mul(&w, &base); + } + out +} + +/// Forward NTT on a slice of `n = 2^log_n` Goldilocks coefficients. Takes +/// natural-order input and returns natural-order evaluations. +pub fn forward(coeffs: &[u64]) -> Result> { + ntt_inplace(coeffs, /*forward=*/ true) +} + +/// Inverse NTT on a slice of `n = 2^log_n` Goldilocks evaluations. Takes +/// natural-order evaluations and returns natural-order coefficients. Includes +/// the 1/n scaling. +pub fn inverse(evals: &[u64]) -> Result> { + ntt_inplace(evals, /*forward=*/ false) +} + +fn ntt_inplace(input: &[u64], forward: bool) -> Result> { + let n = input.len(); + assert!(n.is_power_of_two(), "ntt length must be a power of two"); + if n <= 1 { + return Ok(input.to_vec()); + } + let log_n = n.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + + let mut x_dev = stream.clone_htod(input)?; + let tw_dev = if forward { + be.fwd_twiddles_for(log_n)? + } else { + be.inv_twiddles_for(log_n)? + }; + + let n_u64 = n as u64; + + // 1. Bit-reverse: natural → bit-reversed. + unsafe { + stream + .launch_builder(&be.bit_reverse_permute) + .arg(&mut x_dev) + .arg(&n_u64) + .arg(&log_n) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + + // 2. DIT butterfly levels. For log_n >= 8 we fuse 8 levels per kernel via + // the shmem kernel; for very small sizes (< 256 elements) we stick with + // the per-level kernel because the shmem block dimensions assume n ≥ 256. + run_ntt_body( + stream.as_ref(), + &mut x_dev, + tw_dev.as_ref(), + n_u64, + log_n, + )?; + + // 3. For iNTT, multiply by 1/n. + if !forward { + let n_fe = FieldElement::::from(n as u64); + let inv_n = *n_fe.inv().expect("n is non-zero").value(); + unsafe { + stream + .launch_builder(&be.scalar_mul) + .arg(&mut x_dev) + .arg(&inv_n) + .arg(&n_u64) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + } + + let out = stream.clone_dtoh(&x_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Run the butterfly body of a bit-reversed-input DIT NTT. Split out so the +/// LDE orchestrator can reuse it on the same device buffer. +pub(crate) fn run_ntt_body( + stream: &cudarc::driver::CudaStream, + x_dev: &mut cudarc::driver::CudaSlice, + tw_dev: &cudarc::driver::CudaSlice, + n: u64, + log_n: u64, +) -> Result<()> { + let be = backend(); + // Levels 0..min(log_n, 8): one shmem-fused launch. Loads are fully + // coalesced (base_step=0 → `row = tid`) and 8 butterfly rounds stay on + // chip. This is the big DRAM-bandwidth win. + let fused = core::cmp::min(log_n, 8); + if fused >= 8 { + let grid_x = (n / 256) as u32; + let cfg = LaunchConfig { + grid_dim: (grid_x, 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + let base_step = 0u64; + unsafe { + stream + .launch_builder(&be.ntt_dit_8_levels) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&base_step) + .launch(cfg)?; + } + } else { + // Sub-256-element NTT. Use per-level. + let half_cfg = LaunchConfig::for_num_elems((n / 2) as u32); + for level in 0..fused { + unsafe { + stream + .launch_builder(&be.ntt_dit_level) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .launch(half_cfg)?; + } + } + } + + // Levels 8..log_n: per-level kernels. Loads are fully coalesced in the + // per-level path; switching to fused-with-row-remap at base_step>0 tanks + // DRAM throughput enough to wipe out the launch savings. + let half_cfg = LaunchConfig::for_num_elems((n / 2) as u32); + for level in fused..log_n { + unsafe { + stream + .launch_builder(&be.ntt_dit_level) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .launch(half_cfg)?; + } + } + Ok(()) +} + +/// Pointwise multiply: `x[i] *= w[i]`. +pub fn pointwise_mul(x: &[u64], w: &[u64]) -> Result> { + assert_eq!(x.len(), w.len()); + let n = x.len(); + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let mut x_dev = stream.clone_htod(x)?; + let w_dev = stream.clone_htod(w)?; + + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.pointwise_mul) + .arg(&mut x_dev) + .arg(&w_dev) + .arg(&n_u64) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + + let out = stream.clone_dtoh(&x_dev)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/tests/bench_quick.rs b/crypto/math-cuda/tests/bench_quick.rs new file mode 100644 index 000000000..561331b74 --- /dev/null +++ b/crypto/math-cuda/tests/bench_quick.rs @@ -0,0 +1,356 @@ +//! Informal timing comparison for single-column and multi-column LDE. +//! Ignored by default; run with `cargo test ... -- --ignored --nocapture`. + +use std::time::Instant; + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsField; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::prelude::*; + +type Fp = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *FieldElement::::from(n as u64).inv().unwrap().value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_2_to_18_blowup_4() { + let log_n = 18; + let blowup = 4; + let n = 1usize << log_n; + let lde = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(1); + let input: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + let weights = coset_weights(n, 7); + + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new(lde.trailing_zeros() as u64).unwrap(); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + + const TRIALS: u32 = 10; + + let t0 = Instant::now(); + for _ in 0..TRIALS { + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + } + let gpu_ns = t0.elapsed().as_nanos() / TRIALS as u128; + + let t0 = Instant::now(); + for _ in 0..TRIALS { + let mut buf: Vec = input.iter().map(|&x| Fp::from_raw(x)).collect(); + Polynomial::coset_lde_full_expand::( + &mut buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + std::hint::black_box(&buf); + } + let cpu_ns = t0.elapsed().as_nanos() / TRIALS as u128; + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "single-column LDE 2^{log_n} blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x", + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_2_to_16_blowup_4() { + let log_n = 16; + let blowup = 4; + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(2); + let input: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + let weights = coset_weights(n, 7); + + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + + const TRIALS: u32 = 20; + + let t0 = Instant::now(); + for _ in 0..TRIALS { + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + } + let gpu_ns = t0.elapsed().as_nanos() / TRIALS as u128; + println!("single-column LDE 2^{log_n} blowup={blowup}: gpu={gpu_ns}ns"); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_multi_column_parallel() { + // Simulates the prover's Phase A: many columns processed via rayon. + // log_n = 16 keeps memory footprint manageable while still stressing streams. + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let lde = n * blowup; + let num_cols = 64; + + // Warm up. + let _ = math_cuda::lde::coset_lde_base( + &vec![0u64; n], + blowup, + &coset_weights(n, 7), + ) + .unwrap(); + + // Build input data. + let mut rng = ChaCha8Rng::seed_from_u64(11); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new(lde.trailing_zeros() as u64).unwrap(); + + // GPU: rayon parallel across columns, each column picks a stream. + let t0 = Instant::now(); + let _gpu_results: Vec> = columns + .par_iter() + .map(|col| math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap()) + .collect(); + let gpu_ns = t0.elapsed().as_nanos(); + + // CPU: same rayon parallel pattern as the prover's `expand_columns_to_lde`. + let mut cpu_bufs: Vec> = columns + .iter() + .map(|c| c.iter().map(|&x| Fp::from_raw(x)).collect()) + .collect(); + let t0 = Instant::now(); + cpu_bufs.par_iter_mut().for_each(|buf| { + Polynomial::coset_lde_full_expand::( + buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + }); + let cpu_ns = t0.elapsed().as_nanos(); + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "{num_cols}-column LDE 2^{log_n} blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x (cores={})", + rayon::current_num_threads(), + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_batched_prover_scale() { + // Realistic large-table shape from the 1M-fib prover: ~1M rows, blowup 4, + // a few dozen columns. This is what actually runs in expand_columns_to_lde. + let log_n = 20u32; // 1M rows + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 20; + + let mut rng = ChaCha8Rng::seed_from_u64(31); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new( + (n * blowup).trailing_zeros() as u64, + ) + .unwrap(); + + let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + for _ in 0..8 { + let _ = + math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + } + + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let mut gpu_samples = Vec::with_capacity(10); + for _ in 0..10 { + let t0 = Instant::now(); + let _ = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + gpu_samples.push(t0.elapsed().as_nanos()); + } + gpu_samples.sort(); + let gpu_ns = gpu_samples[gpu_samples.len() / 2]; // median + + let mut cpu_samples = Vec::with_capacity(10); + for _ in 0..10 { + let mut cpu_bufs: Vec> = columns + .iter() + .map(|c| c.iter().map(|&x| Fp::from_raw(x)).collect()) + .collect(); + let t0 = Instant::now(); + cpu_bufs.par_iter_mut().for_each(|buf| { + Polynomial::coset_lde_full_expand::( + buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + }); + cpu_samples.push(t0.elapsed().as_nanos()); + } + cpu_samples.sort(); + let cpu_ns = cpu_samples[cpu_samples.len() / 2]; // median + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "prover-scale batched {num_cols} cols, log_n={log_n}, blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x (median of 10)", + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_batched_vs_rayon_cpu() { + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 64; + + let mut rng = ChaCha8Rng::seed_from_u64(21); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + + // Warm up every stream slot so subsequent iterations don't pay the + // one-time pinned staging alloc cost. + let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + for _ in 0..64 { + let _ = + math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + } + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new( + (n * blowup).trailing_zeros() as u64, + ) + .unwrap(); + + // GPU batched — first run may include lazy device init; do a few to stabilise. + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let mut gpu_ns = u128::MAX; + for _ in 0..5 { + let t0 = Instant::now(); + let _ = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + gpu_ns = gpu_ns.min(t0.elapsed().as_nanos()); + } + + // CPU rayon (same pattern as prover). + let mut cpu_bufs: Vec> = columns + .iter() + .map(|c| c.iter().map(|&x| Fp::from_raw(x)).collect()) + .collect(); + let t0 = Instant::now(); + cpu_bufs.par_iter_mut().for_each(|buf| { + Polynomial::coset_lde_full_expand::( + buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + }); + let cpu_ns = t0.elapsed().as_nanos(); + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "batched {num_cols} cols, log_n={log_n}, blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x (cores={})", + rayon::current_num_threads(), + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_multi_column_serialized_gpu() { + use std::sync::Mutex; + + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 64; + + let _ = math_cuda::lde::coset_lde_base( + &vec![0u64; n], + blowup, + &coset_weights(n, 7), + ) + .unwrap(); + + let mut rng = ChaCha8Rng::seed_from_u64(13); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + + // Single global Mutex so only one thread at a time calls GPU. + let gpu_lock = Mutex::new(()); + let t0 = Instant::now(); + let _: Vec> = columns + .par_iter() + .map(|col| { + let _guard = gpu_lock.lock().unwrap(); + math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap() + }) + .collect(); + let gpu_ns = t0.elapsed().as_nanos(); + println!("GPU mutex-serialised from rayon: {gpu_ns}ns for {num_cols} cols"); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_multi_column_gpu_limited_threads() { + // Same as multi_column_parallel but forces rayon to use only 8 threads + // (matching the GPU stream pool rough capacity). Tests whether oversubscribed + // rayon + many streams is the bottleneck. + let gpu_pool = rayon::ThreadPoolBuilder::new() + .num_threads(8) + .build() + .unwrap(); + + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 64; + + let _ = math_cuda::lde::coset_lde_base( + &vec![0u64; n], + blowup, + &coset_weights(n, 7), + ) + .unwrap(); + + let mut rng = ChaCha8Rng::seed_from_u64(12); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + + let t0 = Instant::now(); + let _gpu_results: Vec> = gpu_pool.install(|| { + columns + .par_iter() + .map(|col| math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap()) + .collect() + }); + let gpu_ns = t0.elapsed().as_nanos(); + + let t0 = Instant::now(); + let _serial_gpu_results: Vec> = columns + .iter() + .map(|col| math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap()) + .collect(); + let gpu_serial_ns = t0.elapsed().as_nanos(); + + println!( + "GPU-only 8-thread: gpu-parallel={gpu_ns}ns gpu-serial={gpu_serial_ns}ns speedup={:.2}x", + gpu_serial_ns as f64 / gpu_ns as f64, + ); +} diff --git a/crypto/math-cuda/tests/evaluate_coset_ext3.rs b/crypto/math-cuda/tests/evaluate_coset_ext3.rs new file mode 100644 index 000000000..a79195291 --- /dev/null +++ b/crypto/math-cuda/tests/evaluate_coset_ext3.rs @@ -0,0 +1,143 @@ +//! Parity test for `evaluate_poly_coset_batch_ext3_into`. +//! +//! Reference: `math::polynomial::Polynomial::evaluate_offset_fft` on an ext3 +//! polynomial, then canonicalise. The GPU path should produce the same +//! evaluations on the offset-coset at `n * blowup` points. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn offset_weights(n: usize, offset: u64) -> Vec { + let mut w = Vec::with_capacity(n); + let mut cur = 1u64; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &offset); + } + w +} + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn u64s_to_ext3(raw: &[u64]) -> Vec { + let mut out = Vec::with_capacity(raw.len() / 3); + for i in 0..raw.len() / 3 { + out.push(Fp3::new([ + Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3 + 1]), + Fp::from_raw(raw[i * 3 + 2]), + ])); + } + out +} + +fn canon_fp3(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +fn assert_evaluate_coset(log_n: u64, blowup: usize, m: usize, offset: u64, seed: u64) { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + + // M ext3 polynomials, each of degree < n. + let polys: Vec> = (0..m) + .map(|_| (0..n).map(|_| rand_ext3(&mut rng)).collect()) + .collect(); + + let weights = offset_weights(n, offset); + + // CPU reference: evaluate each polynomial at `offset`-coset of size lde_size. + let offset_fp = Fp::from_raw(offset); + let cpu: Vec> = polys + .iter() + .map(|coefs| { + let p = Polynomial::new(coefs); + Polynomial::evaluate_offset_fft::( + &p, + blowup, + Some(n), + &offset_fp, + ) + .unwrap() + }) + .collect(); + + // GPU: flatten each poly to 3n u64s, pre-allocate 3*lde_size u64 outputs. + let flat_inputs: Vec> = polys.iter().map(|p| ext3_to_u64s(p)).collect(); + let input_slices: Vec<&[u64]> = flat_inputs.iter().map(|v| v.as_slice()).collect(); + let mut flat_outputs: Vec> = (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); + { + let mut out_slices: Vec<&mut [u64]> = + flat_outputs.iter_mut().map(|v| v.as_mut_slice()).collect(); + math_cuda::lde::evaluate_poly_coset_batch_ext3_into( + &input_slices, + n, + blowup, + &weights, + &mut out_slices, + ) + .unwrap(); + } + + for c in 0..m { + let gpu: Vec = u64s_to_ext3(&flat_outputs[c]); + assert_eq!(gpu.len(), cpu[c].len(), "length mismatch"); + for i in 0..gpu.len() { + let g = canon_fp3(&gpu[i]); + let cc = canon_fp3(&cpu[c][i]); + assert_eq!(g, cc, "eval mismatch col={c} row={i} log_n={log_n} blowup={blowup}"); + } + } +} + +#[test] +fn ext3_evaluate_coset_small() { + for &m in &[1usize, 4] { + for log_n in 4..=10 { + for &blowup in &[2usize, 4] { + assert_evaluate_coset(log_n, blowup, m, 7, 100 + log_n * 10 + m as u64); + } + } + } +} + +#[test] +fn ext3_evaluate_coset_medium() { + for log_n in 11..=14 { + assert_evaluate_coset(log_n, 4, 2, 7, 200 + log_n); + } +} + +#[test] +fn ext3_evaluate_coset_large_one_column() { + assert_evaluate_coset(16, 4, 1, 7, 0xCAFE); +} diff --git a/crypto/math-cuda/tests/ext3.rs b/crypto/math-cuda/tests/ext3.rs new file mode 100644 index 000000000..c9aabbc27 --- /dev/null +++ b/crypto/math-cuda/tests/ext3.rs @@ -0,0 +1,87 @@ +//! Parity: GPU ext3 arithmetic must agree (canonically) with CPU +//! `Degree3GoldilocksExtensionField` on random ext3 inputs. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +const N: usize = 10_000; + +fn random_fp3s(seed: u64, count: usize) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..count) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() +} + +fn to_u64s(col: &[Fp3]) -> Vec { + let mut v = Vec::with_capacity(col.len() * 3); + for e in col { + v.push(*e.value()[0].value()); + v.push(*e.value()[1].value()); + v.push(*e.value()[2].value()); + } + v +} + +fn canon_triplet(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +fn canon_triplet_raw(t: &[u64]) -> [u64; 3] { + [ + GoldilocksField::canonical(&t[0]), + GoldilocksField::canonical(&t[1]), + GoldilocksField::canonical(&t[2]), + ] +} + +#[test] +fn ext3_mul_matches_cpu() { + let a = random_fp3s(11, N); + let b = random_fp3s(22, N); + let a_raw = to_u64s(&a); + let b_raw = to_u64s(&b); + let gpu = math_cuda::ext3_mul_u64(&a_raw, &b_raw).unwrap(); + assert_eq!(gpu.len(), 3 * N); + for i in 0..N { + use math::field::traits::IsField; + let cpu = Degree3GoldilocksExtensionField::mul(a[i].value(), b[i].value()); + let cpu_fp3 = Fp3::new(cpu); + let g = canon_triplet_raw(&gpu[i * 3..(i + 1) * 3]); + let c = canon_triplet(&cpu_fp3); + assert_eq!(g, c, "ext3 mul mismatch at {i}"); + } +} + +#[test] +fn ext3_add_matches_cpu() { + let a = random_fp3s(33, N); + let b = random_fp3s(44, N); + let a_raw = to_u64s(&a); + let b_raw = to_u64s(&b); + let gpu = math_cuda::ext3_add_u64(&a_raw, &b_raw).unwrap(); + for i in 0..N { + let cpu = Degree3GoldilocksExtensionField::add(a[i].value(), b[i].value()); + let cpu_fp3 = Fp3::new(cpu); + let g = canon_triplet_raw(&gpu[i * 3..(i + 1) * 3]); + let c = canon_triplet(&cpu_fp3); + assert_eq!(g, c, "ext3 add mismatch at {i}"); + } +} diff --git a/crypto/math-cuda/tests/goldilocks.rs b/crypto/math-cuda/tests/goldilocks.rs new file mode 100644 index 000000000..317ffb0f8 --- /dev/null +++ b/crypto/math-cuda/tests/goldilocks.rs @@ -0,0 +1,127 @@ +//! GPU must produce bit-identical u64 outputs to `GoldilocksField` for every op. +//! Non-canonical inputs are expected (CPU operates on the full [0, 2^64) range), +//! so the test inputs include values above the prime. + +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +const N: usize = 10_000; + +fn sample_inputs(seed: u64) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..N).map(|_| rng.r#gen::()).collect() +} + +fn assert_raw_eq(op: &str, expected: &[u64], actual: &[u64]) { + assert_eq!(expected.len(), actual.len()); + for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() { + if e != a { + panic!( + "{op} mismatch at {i}: cpu={e:#018x} (canon {:#018x}), gpu={a:#018x} (canon {:#018x})", + GoldilocksField::canonical(e), + GoldilocksField::canonical(a), + ); + } + } +} + +#[test] +fn gpu_vector_add_u64_matches_wrapping() { + let a = sample_inputs(0xC0FFEE); + let b = sample_inputs(0xDEADBEEF); + let expected: Vec = a.iter().zip(&b).map(|(x, y)| x.wrapping_add(*y)).collect(); + let actual = math_cuda::vector_add_u64(&a, &b).expect("GPU vector_add_u64"); + assert_raw_eq("vector_add (wrapping)", &expected, &actual); +} + +#[test] +fn gpu_gl_add_matches_cpu() { + let a = sample_inputs(1); + let b = sample_inputs(2); + let expected: Vec = a + .iter() + .zip(&b) + .map(|(x, y)| GoldilocksField::add(x, y)) + .collect(); + let actual = math_cuda::gl_add_u64(&a, &b).expect("GPU gl_add"); + assert_raw_eq("gl_add", &expected, &actual); +} + +#[test] +fn gpu_gl_sub_matches_cpu() { + let a = sample_inputs(3); + let b = sample_inputs(4); + let expected: Vec = a + .iter() + .zip(&b) + .map(|(x, y)| GoldilocksField::sub(x, y)) + .collect(); + let actual = math_cuda::gl_sub_u64(&a, &b).expect("GPU gl_sub"); + assert_raw_eq("gl_sub", &expected, &actual); +} + +#[test] +fn gpu_gl_mul_matches_cpu() { + let a = sample_inputs(5); + let b = sample_inputs(6); + let expected: Vec = a + .iter() + .zip(&b) + .map(|(x, y)| GoldilocksField::mul(x, y)) + .collect(); + let actual = math_cuda::gl_mul_u64(&a, &b).expect("GPU gl_mul"); + assert_raw_eq("gl_mul", &expected, &actual); +} + +#[test] +fn gpu_gl_neg_matches_cpu() { + let a = sample_inputs(7); + let expected: Vec = a.iter().map(|x| GoldilocksField::neg(x)).collect(); + let actual = math_cuda::gl_neg_u64(&a).expect("GPU gl_neg"); + assert_raw_eq("gl_neg", &expected, &actual); +} + +/// Edge cases the random generator is unlikely to hit: 0, 1, p-1, p, p+1, 2p-1, +/// u64::MAX, EPSILON boundary values. Covers double-overflow / double-underflow. +#[test] +fn gpu_goldilocks_edge_cases() { + const P: u64 = 0xFFFF_FFFF_0000_0001; + const EPS: u64 = 0xFFFF_FFFF; + let edge: [u64; 11] = [ + 0, + 1, + P - 1, + P, + P + 1, + 2u64.wrapping_mul(P).wrapping_sub(1), + u64::MAX, + u64::MAX - EPS, + u64::MAX - 1, + EPS, + EPS - 1, + ]; + // All pairs via nested loops, materialised as flat a[], b[] of length edge^2. + let mut a = Vec::with_capacity(edge.len() * edge.len()); + let mut b = Vec::with_capacity(edge.len() * edge.len()); + for &x in &edge { + for &y in &edge { + a.push(x); + b.push(y); + } + } + + let cases: &[(&str, fn(&[u64], &[u64]) -> math_cuda::Result>, fn(&u64, &u64) -> u64)] = + &[ + ("gl_add", math_cuda::gl_add_u64, GoldilocksField::add), + ("gl_sub", math_cuda::gl_sub_u64, GoldilocksField::sub), + ("gl_mul", math_cuda::gl_mul_u64, GoldilocksField::mul), + ]; + + for (op, gpu_fn, cpu_fn) in cases { + let expected: Vec = a.iter().zip(&b).map(|(x, y)| cpu_fn(x, y)).collect(); + let actual = gpu_fn(&a, &b).expect("GPU op"); + assert_raw_eq(op, &expected, &actual); + } +} diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs new file mode 100644 index 000000000..9648f833a --- /dev/null +++ b/crypto/math-cuda/tests/lde.rs @@ -0,0 +1,112 @@ +//! Phase-5 parity: GPU `coset_lde_base` must match the CPU +//! `Polynomial::coset_lde_full_expand` for a sweep of realistic sizes and +//! blowup factors. + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +/// Build the coset weights `[1/N, g/N, g²/N, …, g^{n-1}/N]` — this is the +/// layout `crypto/stark/src/prover.rs:248` uses, with `1/N` pre-folded into the +/// first coefficient so the iFFT step does not need a separate scaling pass. +fn coset_weights(n: usize, coset_offset: u64) -> Vec { + let inv_n_fe = FieldElement::::from(n as u64) + .inv() + .expect("n is non-zero"); + let mut w = Vec::with_capacity(n); + let mut cur = *inv_n_fe.value(); + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &coset_offset); + } + w +} + +fn cpu_lde(evals: &[u64], blowup_factor: usize, coset_offset: u64) -> Vec { + let n = evals.len(); + let log_n = n.trailing_zeros() as u64; + let log_lde = (n * blowup_factor).trailing_zeros() as u64; + + let inv_tw = LayerTwiddles::::new_inverse(log_n).expect("inv tw"); + let fwd_tw = LayerTwiddles::::new(log_lde).expect("fwd tw"); + let weights_raw = coset_weights(n, coset_offset); + let weights: Vec = weights_raw.iter().map(|&w| Fp::from_raw(w)).collect(); + + let mut buf: Vec = evals.iter().map(|&x| Fp::from_raw(x)).collect(); + Polynomial::coset_lde_full_expand::( + &mut buf, + blowup_factor, + &weights, + &inv_tw, + &fwd_tw, + ) + .expect("cpu lde"); + + buf.into_iter().map(|e| *e.value()).collect() +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn assert_lde_match(log_n: u64, blowup_factor: usize, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let evals: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + + // Use a fixed, public coset offset. For lambda-vm the coset offset is the + // generator of Goldilocks' multiplicative subgroup; any non-trivial element + // works for an isolated correctness check. + let coset_offset: u64 = 7; + let weights = coset_weights(n, coset_offset); + + let cpu = cpu_lde(&evals, blowup_factor, coset_offset); + let gpu = math_cuda::lde::coset_lde_base(&evals, blowup_factor, &weights).expect("gpu lde"); + + assert_eq!(cpu.len(), gpu.len(), "length mismatch (log_n={log_n}, blowup={blowup_factor})"); + let cpu_c = canon(&cpu); + let gpu_c = canon(&gpu); + for (i, (e, a)) in cpu_c.iter().zip(&gpu_c).enumerate() { + if e != a { + panic!( + "lde mismatch log_n={log_n} blowup={blowup_factor} i={i}: cpu {e:#018x}, gpu {a:#018x}", + ); + } + } +} + +#[test] +fn lde_small() { + for log_n in 4..=10 { + for &blow in &[2usize, 4, 8] { + assert_lde_match(log_n, blow, 1_000 + log_n + (blow as u64)); + } + } +} + +#[test] +fn lde_medium() { + for log_n in 11..=14 { + for &blow in &[2usize, 4] { + assert_lde_match(log_n, blow, 2_000 + log_n + (blow as u64)); + } + } +} + +#[test] +fn lde_large_2_to_18() { + // 2^18 × blowup 4 = 2^20 LDE — representative of Phase A trace columns. + assert_lde_match(18, 4, 0xCAFE); +} + +#[test] +fn lde_largest_2_to_20() { + // 2^20 LDE is the hot size; blowup 2 keeps total = 2^21 (within TWO_ADICITY). + assert_lde_match(20, 2, 0xF00D); +} diff --git a/crypto/math-cuda/tests/lde_batch.rs b/crypto/math-cuda/tests/lde_batch.rs new file mode 100644 index 000000000..67f975728 --- /dev/null +++ b/crypto/math-cuda/tests/lde_batch.rs @@ -0,0 +1,96 @@ +//! Batched coset LDE must agree with running the CPU single-column LDE on +//! each column independently. Sweeps a few realistic (n, blowup, m) tuples. + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *FieldElement::::from(n as u64) + .inv() + .unwrap() + .value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +fn cpu_lde_one(col: &[u64], blowup: usize, weights_fp: &[Fp], inv_tw: &LayerTwiddles, fwd_tw: &LayerTwiddles) -> Vec { + let mut buf: Vec = col.iter().map(|&x| Fp::from_raw(x)).collect(); + Polynomial::coset_lde_full_expand::( + &mut buf, blowup, weights_fp, inv_tw, fwd_tw, + ) + .unwrap(); + buf.into_iter().map(|e| *e.value()).collect() +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn assert_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let columns: Vec> = (0..m) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + + let coset_offset: u64 = 7; + let weights = coset_weights(n, coset_offset); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + + let inv_tw = LayerTwiddles::::new_inverse(log_n).unwrap(); + let fwd_tw = + LayerTwiddles::::new((n * blowup).trailing_zeros() as u64).unwrap(); + + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let gpu_all = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + assert_eq!(gpu_all.len(), m); + + for (c, col) in columns.iter().enumerate() { + let cpu = cpu_lde_one(col, blowup, &weights_fp, &inv_tw, &fwd_tw); + assert_eq!( + canon(&gpu_all[c]), + canon(&cpu), + "batch mismatch at col {c}, log_n={log_n}, blowup={blowup}" + ); + } +} + +#[test] +fn batch_small() { + for &m in &[1usize, 4, 16] { + for log_n in 4..=10 { + assert_batch(log_n, 4, m, 100 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn batch_medium() { + for &m in &[2usize, 32] { + for log_n in 11..=14 { + assert_batch(log_n, 4, m, 200 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn batch_large_one_column() { + assert_batch(18, 4, 1, 0xCAFE); +} + +#[test] +fn batch_large_32_columns() { + assert_batch(15, 4, 32, 0xBEEF); +} diff --git a/crypto/math-cuda/tests/lde_batch_ext3.rs b/crypto/math-cuda/tests/lde_batch_ext3.rs new file mode 100644 index 000000000..0a86197a5 --- /dev/null +++ b/crypto/math-cuda/tests/lde_batch_ext3.rs @@ -0,0 +1,161 @@ +//! Ext3 batched coset LDE must agree with the CPU `coset_lde_full_expand` +//! on each column independently when run over `FieldElement`. + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *FieldElement::::from(n as u64) + .inv() + .unwrap() + .value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + // Each Fp3 is [u64; 3] in memory; we just flatten componentwise. + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn u64s_to_ext3(raw: &[u64]) -> Vec { + assert_eq!(raw.len() % 3, 0); + let mut out = Vec::with_capacity(raw.len() / 3); + for i in 0..raw.len() / 3 { + out.push(Fp3::new([ + Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3 + 1]), + Fp::from_raw(raw[i * 3 + 2]), + ])); + } + out +} + +fn cpu_lde_one_ext3( + col: &[Fp3], + blowup: usize, + weights_fp: &[Fp], + inv_tw: &LayerTwiddles, + fwd_tw: &LayerTwiddles, +) -> Vec { + let mut buf = col.to_vec(); + Polynomial::coset_lde_full_expand::( + &mut buf, blowup, weights_fp, inv_tw, fwd_tw, + ) + .unwrap(); + buf +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn assert_ext3_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let columns: Vec> = (0..m) + .map(|_| (0..n).map(|_| rand_ext3(&mut rng)).collect()) + .collect(); + + let coset_offset: u64 = 7; + let weights = coset_weights(n, coset_offset); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n).unwrap(); + let fwd_tw = LayerTwiddles::::new(lde_size.trailing_zeros() as u64).unwrap(); + + // Flatten each ext3 column to 3n u64s for the GPU API. + let flat_inputs: Vec> = columns.iter().map(|c| ext3_to_u64s(c)).collect(); + let input_slices: Vec<&[u64]> = flat_inputs.iter().map(|v| v.as_slice()).collect(); + + // Pre-allocate outputs, each 3*lde_size u64s. + let mut flat_outputs: Vec> = + (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); + { + let mut out_slices: Vec<&mut [u64]> = + flat_outputs.iter_mut().map(|v| v.as_mut_slice()).collect(); + math_cuda::lde::coset_lde_batch_ext3_into( + &input_slices, + n, + blowup, + &weights, + &mut out_slices, + ) + .unwrap(); + } + + for (c, col) in columns.iter().enumerate() { + let cpu = cpu_lde_one_ext3(col, blowup, &weights_fp, &inv_tw, &fwd_tw); + let gpu: Vec = u64s_to_ext3(&flat_outputs[c]); + assert_eq!(gpu.len(), cpu.len(), "length mismatch"); + for i in 0..cpu.len() { + for k in 0..3 { + let cv = *cpu[i].value()[k].value(); + let gv = *gpu[i].value()[k].value(); + let cc = GoldilocksField::canonical(&cv); + let gc = GoldilocksField::canonical(&gv); + if cc != gc { + panic!( + "ext3 batch mismatch col={c} row={i} comp={k} log_n={log_n} blowup={blowup}: cpu={cv:#018x} (canon {cc:#018x}), gpu={gv:#018x} (canon {gc:#018x})", + ); + } + } + } + } + // Also sanity-check raw canonical equality per column. + for (c, col) in columns.iter().enumerate() { + let cpu_raw = ext3_to_u64s(&cpu_lde_one_ext3(col, blowup, &weights_fp, &inv_tw, &fwd_tw)); + assert_eq!(canon(&cpu_raw), canon(&flat_outputs[c])); + } +} + +#[test] +fn ext3_batch_small() { + for &m in &[1usize, 4, 16] { + for log_n in 4..=10 { + assert_ext3_batch(log_n, 4, m, 100 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn ext3_batch_medium() { + for &m in &[2usize, 8] { + for log_n in 11..=14 { + assert_ext3_batch(log_n, 4, m, 300 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn ext3_batch_large_one_column() { + assert_ext3_batch(16, 4, 1, 0xCAFE); +} diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs new file mode 100644 index 000000000..d7cf3680a --- /dev/null +++ b/crypto/math-cuda/tests/ntt.rs @@ -0,0 +1,136 @@ +//! Phase-3 parity: GPU forward NTT must agree with `Polynomial::evaluate_fft` +//! as a field element, across a sweep of sizes from 2^4 to 2^20. +//! +//! Non-canonical u64s can differ between CPU and GPU while representing the +//! same element; we canonicalise both sides before comparing. + +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsPrimeField; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +fn cpu_fft(coeffs: &[u64]) -> Vec { + let elems: Vec = coeffs.iter().map(|&x| Fp::from_raw(x)).collect(); + let poly = Polynomial::new(&elems); + let evals = Polynomial::evaluate_fft::(&poly, 1, None).expect("cpu fft"); + evals.into_iter().map(|e| *e.value()).collect() +} + +fn canonicalize(xs: &[u64]) -> Vec { + xs.iter() + .map(|x| GoldilocksField::canonical(x)) + .collect() +} + +fn assert_ntt_match(log_n: u64, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let input: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + + let cpu = cpu_fft(&input); + let gpu = math_cuda::ntt::forward(&input).expect("gpu ntt"); + + assert_eq!(cpu.len(), gpu.len(), "length mismatch at log_n = {log_n}"); + let cpu_c = canonicalize(&cpu); + let gpu_c = canonicalize(&gpu); + for i in 0..n { + if cpu_c[i] != gpu_c[i] { + panic!( + "log_n={log_n} i={i}: cpu={:#018x} (canon {:#018x}), gpu={:#018x} (canon {:#018x})", + cpu[i], cpu_c[i], gpu[i], gpu_c[i], + ); + } + } +} + +#[test] +fn ntt_sizes_small() { + for log_n in 4..=10 { + assert_ntt_match(log_n, 100 + log_n); + } +} + +#[test] +fn ntt_sizes_medium() { + for log_n in 11..=16 { + assert_ntt_match(log_n, 200 + log_n); + } +} + +#[test] +fn ntt_size_2_to_20() { + // The hot LDE size. One seed is enough; any mismatch screams loudly. + assert_ntt_match(20, 0xDEAD); +} + +#[test] +fn ntt_trivial_sizes() { + // Power-of-two below the interesting range — should still pass. + assert_ntt_match(1, 1); + assert_ntt_match(2, 2); + assert_ntt_match(3, 3); +} + +fn assert_intt_match(log_n: u64, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let evals: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + + let elems: Vec = evals.iter().map(|&x| Fp::from_raw(x)).collect(); + let cpu_poly = + Polynomial::interpolate_fft::(&elems).expect("cpu intt"); + let cpu: Vec = cpu_poly.coefficients.into_iter().map(|e| *e.value()).collect(); + + let gpu = math_cuda::ntt::inverse(&evals).expect("gpu intt"); + + let cpu_c = canonicalize(&cpu); + let gpu_c = canonicalize(&gpu); + for i in 0..n { + if cpu_c[i] != gpu_c[i] { + panic!( + "iNTT log_n={log_n} i={i}: cpu canon {:#018x}, gpu canon {:#018x}", + cpu_c[i], gpu_c[i], + ); + } + } +} + +#[test] +fn intt_sizes_small() { + for log_n in 4..=10 { + assert_intt_match(log_n, 700 + log_n); + } +} + +#[test] +fn intt_sizes_medium() { + for log_n in 11..=16 { + assert_intt_match(log_n, 800 + log_n); + } +} + +#[test] +fn intt_size_2_to_20() { + assert_intt_match(20, 0xBEEF); +} + +#[test] +fn ntt_round_trip() { + // inverse(forward(x)) == x up to canonical form. + let log_n = 14; + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(42); + let x: Vec = (0..n).map(|_| rng.r#gen::() % 0xFFFF_FFFF_0000_0001).collect(); + + let evals = math_cuda::ntt::forward(&x).expect("forward"); + let back = math_cuda::ntt::inverse(&evals).expect("inverse"); + + let x_c = canonicalize(&x); + let back_c = canonicalize(&back); + assert_eq!(x_c, back_c, "round trip failed"); +} + diff --git a/crypto/stark/Cargo.toml b/crypto/stark/Cargo.toml index 53b205996..4d1f2cbca 100644 --- a/crypto/stark/Cargo.toml +++ b/crypto/stark/Cargo.toml @@ -22,6 +22,9 @@ itertools = "0.11.0" # Parallelization crates rayon = { version = "1.8.0", optional = true } +# GPU backend for trace LDE — only linked when `cuda` is enabled. +math-cuda = { path = "../math-cuda", optional = true } + # wasm wasm-bindgen = { version = "0.2", optional = true } serde-wasm-bindgen = { version = "0.5", optional = true } @@ -39,6 +42,7 @@ test_fiat_shamir = [] instruments = [] # This enables timing prints in prover and verifier debug-checks = [] # Enables validate_trace + bus balance report in prover parallel = ["dep:rayon", "crypto/parallel"] +cuda = ["dep:math-cuda"] wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"] From 79634ff2ef8e956de53326870acec43c30a81a87 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 May 2026 15:26:58 -0300 Subject: [PATCH 02/22] fmt --- crypto/math-cuda/build.rs | 9 +- crypto/math-cuda/src/device.rs | 6 +- crypto/math-cuda/src/lde.rs | 151 +++++++++--------- crypto/math-cuda/src/ntt.rs | 8 +- crypto/math-cuda/tests/bench_quick.rs | 68 ++++---- crypto/math-cuda/tests/evaluate_coset_ext3.rs | 14 +- crypto/math-cuda/tests/goldilocks.rs | 15 +- crypto/math-cuda/tests/lde.rs | 6 +- crypto/math-cuda/tests/lde_batch.rs | 8 +- crypto/math-cuda/tests/lde_batch_ext3.rs | 11 +- crypto/math-cuda/tests/ntt.rs | 18 ++- 11 files changed, 159 insertions(+), 155 deletions(-) diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index a6defb5ab..d6b803bed 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -30,14 +30,7 @@ fn compile_ptx(src: &str, out_name: &str) { let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| "compute_89".to_string()); let status = Command::new(nvcc_path()) - .args([ - "--ptx", - "-O3", - "-std=c++17", - "-arch", - &arch, - "-o", - ]) + .args(["--ptx", "-O3", "-std=c++17", "-arch", &arch, "-o"]) .arg(&out_path) .arg(&src_path) .status() diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 2c70716a6..03f0b67ca 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -40,11 +40,7 @@ impl PinnedStaging { } } - pub fn ensure_capacity( - &mut self, - min_elems: usize, - ctx: &CudaContext, - ) -> Result<()> { + pub fn ensure_capacity(&mut self, min_elems: usize, ctx: &CudaContext) -> Result<()> { if self.capacity_elems >= min_elems { return Ok(()); } diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index fd25d1bca..8b9da522c 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -38,15 +38,14 @@ pub struct GpuLdeExt3 { pub lde_size: usize, } -pub fn coset_lde_base( - evals: &[u64], - blowup_factor: usize, - weights: &[u64], -) -> Result> { +pub fn coset_lde_base(evals: &[u64], blowup_factor: usize, weights: &[u64]) -> Result> { let n = evals.len(); assert!(n.is_power_of_two(), "evals length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match evals"); - assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + assert!( + blowup_factor.is_power_of_two(), + "blowup must be power of two" + ); if n == 0 { return Ok(Vec::new()); } @@ -130,7 +129,10 @@ pub fn coset_lde_batch_base( let n = columns[0].len(); assert!(n.is_power_of_two(), "column length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match column length"); - assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + assert!( + blowup_factor.is_power_of_two(), + "blowup must be power of two" + ); for c in columns.iter() { assert_eq!(c.len(), n, "all columns must be the same size"); } @@ -147,7 +149,11 @@ pub fn coset_lde_batch_base( let staging_slot = be.pinned_staging(); let debug_phases = std::env::var("MATH_CUDA_PHASE_TIMING").is_ok(); - let t_start = if debug_phases { Some(std::time::Instant::now()) } else { None }; + let t_start = if debug_phases { + Some(std::time::Instant::now()) + } else { + None + }; let phase = |label: &str, prev: &mut Option| { if let Some(p) = prev.as_ref() { let now = std::time::Instant::now(); @@ -165,7 +171,9 @@ pub fn coset_lde_batch_base( staging.ensure_capacity(m * lde_size, &be.ctx)?; // SAFETY: staging is locked, the slice alias ends before we unlock. let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - if debug_phases { phase("staging lock + grow", &mut last); } + if debug_phases { + phase("staging lock + grow", &mut last); + } // Pack columns into first m*n slots of the pinned buffer. Parallel: pinned // writes are DRAM-bandwidth bound, saturates at ~8 cores on modern @@ -176,32 +184,39 @@ pub fn coset_lde_batch_base( // SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of // `pinned`, and the outer `staging` lock guarantees no other call is // using the buffer concurrently. - let dst = unsafe { - std::slice::from_raw_parts_mut( - (pinned_base_ptr as *mut u64).add(c * n), - n, - ) - }; + let dst = + unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; dst.copy_from_slice(col); }); - if debug_phases { phase("host pack (pinned, rayon)", &mut last); } + if debug_phases { + phase("host pack (pinned, rayon)", &mut last); + } // Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size) // tail of each column is already the zero-pad the CPU path does. let mut buf = stream.alloc_zeros::(m * lde_size)?; - if debug_phases { stream.synchronize()?; phase("alloc_zeros", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("alloc_zeros", &mut last); + } // One memcpy per column from the pinned buffer into the strided slots. // The pinned source hits PCIe line-rate. for c in 0..m { let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; } - if debug_phases { stream.synchronize()?; phase("H2D cols (pinned)", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("H2D cols (pinned)", &mut last); + } let inv_tw = be.inv_twiddles_for(log_n)?; let fwd_tw = be.fwd_twiddles_for(log_lde)?; let weights_dev = stream.clone_htod(weights)?; - if debug_phases { stream.synchronize()?; phase("twiddles + weights", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("twiddles + weights", &mut last); + } let n_u64 = n as u64; let lde_u64 = lde_size as u64; @@ -227,7 +242,10 @@ pub fn coset_lde_batch_base( } } - if debug_phases { stream.synchronize()?; phase("bit_reverse N", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("bit_reverse N", &mut last); + } // === 2. iNTT body over all columns === run_batched_ntt_body( stream.as_ref(), @@ -238,7 +256,10 @@ pub fn coset_lde_batch_base( col_stride_u64, m_u32, )?; - if debug_phases { stream.synchronize()?; phase("iNTT body", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("iNTT body", &mut last); + } // === 3. Pointwise multiply by coset weights (includes 1/N) === { @@ -278,7 +299,10 @@ pub fn coset_lde_batch_base( } } - if debug_phases { stream.synchronize()?; phase("pointwise + bit_reverse LDE", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("pointwise + bit_reverse LDE", &mut last); + } // === 5. Forward NTT on full LDE of every column === run_batched_ntt_body( stream.as_ref(), @@ -289,13 +313,18 @@ pub fn coset_lde_batch_base( col_stride_u64, m_u32, )?; - if debug_phases { stream.synchronize()?; phase("forward NTT body", &mut last); } + if debug_phases { + stream.synchronize()?; + phase("forward NTT body", &mut last); + } // Single big D2H into the reusable pinned staging buffer — pinned, one // call to the driver, saturates PCIe. stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; stream.synchronize()?; - if debug_phases { phase("D2H (one shot into pinned)", &mut last); } + if debug_phases { + phase("D2H (one shot into pinned)", &mut last); + } // Split pinned → per-column Vecs. The first write to each virgin // Vec page-faults, which can dominate total time (~75 ms for 128 MB). @@ -308,16 +337,15 @@ pub fn coset_lde_batch_base( let mut v = Vec::::with_capacity(lde_size); unsafe { v.set_len(lde_size) }; let src = unsafe { - std::slice::from_raw_parts( - (pinned_ptr as *const u64).add(c * lde_size), - lde_size, - ) + std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) }; v.copy_from_slice(src); v }) .collect(); - if debug_phases { phase("copy out (rayon pinned → Vecs)", &mut last); } + if debug_phases { + phase("copy out (rayon pinned → Vecs)", &mut last); + } drop(staging); Ok(out) } @@ -341,7 +369,10 @@ pub fn coset_lde_batch_base_into( let n = columns[0].len(); assert!(n.is_power_of_two(), "column length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match column length"); - assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + assert!( + blowup_factor.is_power_of_two(), + "blowup must be power of two" + ); for c in columns.iter() { assert_eq!(c.len(), n, "all columns must be the same size"); } @@ -461,18 +492,12 @@ pub fn coset_lde_batch_base_into( #[allow(unused_imports)] use rayon::prelude::*; let pinned_ptr = pinned.as_ptr() as usize; - outputs - .par_iter_mut() - .enumerate() - .for_each(|(c, dst)| { - let src = unsafe { - std::slice::from_raw_parts( - (pinned_ptr as *const u64).add(c * lde_size), - lde_size, - ) - }; - dst.copy_from_slice(src); - }); + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) + }; + dst.copy_from_slice(src); + }); drop(staging); Ok(()) } @@ -497,15 +522,8 @@ pub fn evaluate_poly_coset_batch_ext3_into( weights: &[u64], outputs: &mut [&mut [u64]], ) -> Result<()> { - evaluate_poly_coset_batch_ext3_into_inner( - coefs, - n, - blowup_factor, - weights, - outputs, - false, - ) - .map(|_| ()) + evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, false) + .map(|_| ()) } /// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- @@ -519,14 +537,8 @@ pub fn evaluate_poly_coset_batch_ext3_into_keep( weights: &[u64], outputs: &mut [&mut [u64]], ) -> Result { - let opt = evaluate_poly_coset_batch_ext3_into_inner( - coefs, - n, - blowup_factor, - weights, - outputs, - true, - )?; + let opt = + evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, true)?; Ok(opt.expect("keep_device_buf=true must return Some")) } @@ -724,7 +736,10 @@ pub fn coset_lde_batch_ext3_into( assert_eq!(outputs.len(), m, "outputs must match columns count"); assert!(n.is_power_of_two(), "n must be a power of two"); assert_eq!(weights.len(), n, "weights length must match n"); - assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + assert!( + blowup_factor.is_power_of_two(), + "blowup must be power of two" + ); for c in columns.iter() { assert_eq!(c.len(), 3 * n, "each ext3 column must be 3*n u64s"); } @@ -757,22 +772,13 @@ pub fn coset_lde_batch_ext3_into( columns.par_iter().enumerate().for_each(|(c, col)| { // SAFETY: disjoint regions per c; staging lock held. let slab_a = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), - n, - ) + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), n) }; let slab_b = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), - n, - ) + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) }; let slab_c = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), - n, - ) + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) }; for i in 0..n { slab_a[i] = col[i * 3 + 0]; @@ -981,4 +987,3 @@ fn run_batched_ntt_body( } Ok(()) } - diff --git a/crypto/math-cuda/src/ntt.rs b/crypto/math-cuda/src/ntt.rs index 0ebb015ea..31333c3ae 100644 --- a/crypto/math-cuda/src/ntt.rs +++ b/crypto/math-cuda/src/ntt.rs @@ -87,13 +87,7 @@ fn ntt_inplace(input: &[u64], forward: bool) -> Result> { // 2. DIT butterfly levels. For log_n >= 8 we fuse 8 levels per kernel via // the shmem kernel; for very small sizes (< 256 elements) we stick with // the per-level kernel because the shmem block dimensions assume n ≥ 256. - run_ntt_body( - stream.as_ref(), - &mut x_dev, - tw_dev.as_ref(), - n_u64, - log_n, - )?; + run_ntt_body(stream.as_ref(), &mut x_dev, tw_dev.as_ref(), n_u64, log_n)?; // 3. For iNTT, multiply by 1/n. if !forward { diff --git a/crypto/math-cuda/tests/bench_quick.rs b/crypto/math-cuda/tests/bench_quick.rs index 561331b74..41e9444c3 100644 --- a/crypto/math-cuda/tests/bench_quick.rs +++ b/crypto/math-cuda/tests/bench_quick.rs @@ -15,7 +15,10 @@ use rayon::prelude::*; type Fp = FieldElement; fn coset_weights(n: usize, g: u64) -> Vec { - let inv_n = *FieldElement::::from(n as u64).inv().unwrap().value(); + let inv_n = *FieldElement::::from(n as u64) + .inv() + .unwrap() + .value(); let mut w = Vec::with_capacity(n); let mut cur = inv_n; for _ in 0..n { @@ -54,7 +57,11 @@ fn bench_lde_2_to_18_blowup_4() { for _ in 0..TRIALS { let mut buf: Vec = input.iter().map(|&x| Fp::from_raw(x)).collect(); Polynomial::coset_lde_full_expand::( - &mut buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + &mut buf, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, ) .unwrap(); std::hint::black_box(&buf); @@ -101,12 +108,7 @@ fn bench_lde_multi_column_parallel() { let num_cols = 64; // Warm up. - let _ = math_cuda::lde::coset_lde_base( - &vec![0u64; n], - blowup, - &coset_weights(n, 7), - ) - .unwrap(); + let _ = math_cuda::lde::coset_lde_base(&vec![0u64; n], blowup, &coset_weights(n, 7)).unwrap(); // Build input data. let mut rng = ChaCha8Rng::seed_from_u64(11); @@ -134,7 +136,11 @@ fn bench_lde_multi_column_parallel() { let t0 = Instant::now(); cpu_bufs.par_iter_mut().for_each(|buf| { Polynomial::coset_lde_full_expand::( - buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + buf, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, ) .unwrap(); }); @@ -164,15 +170,12 @@ fn bench_lde_batched_prover_scale() { let weights = coset_weights(n, 7); let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); - let fwd_tw = LayerTwiddles::::new( - (n * blowup).trailing_zeros() as u64, - ) - .unwrap(); + let fwd_tw = + LayerTwiddles::::new((n * blowup).trailing_zeros() as u64).unwrap(); let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); for _ in 0..8 { - let _ = - math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + let _ = math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); } let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); @@ -194,7 +197,11 @@ fn bench_lde_batched_prover_scale() { let t0 = Instant::now(); cpu_bufs.par_iter_mut().for_each(|buf| { Polynomial::coset_lde_full_expand::( - buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + buf, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, ) .unwrap(); }); @@ -227,15 +234,12 @@ fn bench_lde_batched_vs_rayon_cpu() { // one-time pinned staging alloc cost. let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); for _ in 0..64 { - let _ = - math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + let _ = math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); } let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); - let fwd_tw = LayerTwiddles::::new( - (n * blowup).trailing_zeros() as u64, - ) - .unwrap(); + let fwd_tw = + LayerTwiddles::::new((n * blowup).trailing_zeros() as u64).unwrap(); // GPU batched — first run may include lazy device init; do a few to stabilise. let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); @@ -254,7 +258,11 @@ fn bench_lde_batched_vs_rayon_cpu() { let t0 = Instant::now(); cpu_bufs.par_iter_mut().for_each(|buf| { Polynomial::coset_lde_full_expand::( - buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + buf, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, ) .unwrap(); }); @@ -277,12 +285,7 @@ fn bench_lde_multi_column_serialized_gpu() { let n = 1usize << log_n; let num_cols = 64; - let _ = math_cuda::lde::coset_lde_base( - &vec![0u64; n], - blowup, - &coset_weights(n, 7), - ) - .unwrap(); + let _ = math_cuda::lde::coset_lde_base(&vec![0u64; n], blowup, &coset_weights(n, 7)).unwrap(); let mut rng = ChaCha8Rng::seed_from_u64(13); let columns: Vec> = (0..num_cols) @@ -320,12 +323,7 @@ fn bench_lde_multi_column_gpu_limited_threads() { let n = 1usize << log_n; let num_cols = 64; - let _ = math_cuda::lde::coset_lde_base( - &vec![0u64; n], - blowup, - &coset_weights(n, 7), - ) - .unwrap(); + let _ = math_cuda::lde::coset_lde_base(&vec![0u64; n], blowup, &coset_weights(n, 7)).unwrap(); let mut rng = ChaCha8Rng::seed_from_u64(12); let columns: Vec> = (0..num_cols) diff --git a/crypto/math-cuda/tests/evaluate_coset_ext3.rs b/crypto/math-cuda/tests/evaluate_coset_ext3.rs index a79195291..020c39241 100644 --- a/crypto/math-cuda/tests/evaluate_coset_ext3.rs +++ b/crypto/math-cuda/tests/evaluate_coset_ext3.rs @@ -81,13 +81,8 @@ fn assert_evaluate_coset(log_n: u64, blowup: usize, m: usize, offset: u64, seed: .iter() .map(|coefs| { let p = Polynomial::new(coefs); - Polynomial::evaluate_offset_fft::( - &p, - blowup, - Some(n), - &offset_fp, - ) - .unwrap() + Polynomial::evaluate_offset_fft::(&p, blowup, Some(n), &offset_fp) + .unwrap() }) .collect(); @@ -114,7 +109,10 @@ fn assert_evaluate_coset(log_n: u64, blowup: usize, m: usize, offset: u64, seed: for i in 0..gpu.len() { let g = canon_fp3(&gpu[i]); let cc = canon_fp3(&cpu[c][i]); - assert_eq!(g, cc, "eval mismatch col={c} row={i} log_n={log_n} blowup={blowup}"); + assert_eq!( + g, cc, + "eval mismatch col={c} row={i} log_n={log_n} blowup={blowup}" + ); } } } diff --git a/crypto/math-cuda/tests/goldilocks.rs b/crypto/math-cuda/tests/goldilocks.rs index 317ffb0f8..37a5d8533 100644 --- a/crypto/math-cuda/tests/goldilocks.rs +++ b/crypto/math-cuda/tests/goldilocks.rs @@ -112,12 +112,15 @@ fn gpu_goldilocks_edge_cases() { } } - let cases: &[(&str, fn(&[u64], &[u64]) -> math_cuda::Result>, fn(&u64, &u64) -> u64)] = - &[ - ("gl_add", math_cuda::gl_add_u64, GoldilocksField::add), - ("gl_sub", math_cuda::gl_sub_u64, GoldilocksField::sub), - ("gl_mul", math_cuda::gl_mul_u64, GoldilocksField::mul), - ]; + let cases: &[( + &str, + fn(&[u64], &[u64]) -> math_cuda::Result>, + fn(&u64, &u64) -> u64, + )] = &[ + ("gl_add", math_cuda::gl_add_u64, GoldilocksField::add), + ("gl_sub", math_cuda::gl_sub_u64, GoldilocksField::sub), + ("gl_mul", math_cuda::gl_mul_u64, GoldilocksField::mul), + ]; for (op, gpu_fn, cpu_fn) in cases { let expected: Vec = a.iter().zip(&b).map(|(x, y)| cpu_fn(x, y)).collect(); diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs index 9648f833a..75ea80867 100644 --- a/crypto/math-cuda/tests/lde.rs +++ b/crypto/math-cuda/tests/lde.rs @@ -69,7 +69,11 @@ fn assert_lde_match(log_n: u64, blowup_factor: usize, seed: u64) { let cpu = cpu_lde(&evals, blowup_factor, coset_offset); let gpu = math_cuda::lde::coset_lde_base(&evals, blowup_factor, &weights).expect("gpu lde"); - assert_eq!(cpu.len(), gpu.len(), "length mismatch (log_n={log_n}, blowup={blowup_factor})"); + assert_eq!( + cpu.len(), + gpu.len(), + "length mismatch (log_n={log_n}, blowup={blowup_factor})" + ); let cpu_c = canon(&cpu); let gpu_c = canon(&gpu); for (i, (e, a)) in cpu_c.iter().zip(&gpu_c).enumerate() { diff --git a/crypto/math-cuda/tests/lde_batch.rs b/crypto/math-cuda/tests/lde_batch.rs index 67f975728..153e5d3e5 100644 --- a/crypto/math-cuda/tests/lde_batch.rs +++ b/crypto/math-cuda/tests/lde_batch.rs @@ -25,7 +25,13 @@ fn coset_weights(n: usize, g: u64) -> Vec { w } -fn cpu_lde_one(col: &[u64], blowup: usize, weights_fp: &[Fp], inv_tw: &LayerTwiddles, fwd_tw: &LayerTwiddles) -> Vec { +fn cpu_lde_one( + col: &[u64], + blowup: usize, + weights_fp: &[Fp], + inv_tw: &LayerTwiddles, + fwd_tw: &LayerTwiddles, +) -> Vec { let mut buf: Vec = col.iter().map(|&x| Fp::from_raw(x)).collect(); Polynomial::coset_lde_full_expand::( &mut buf, blowup, weights_fp, inv_tw, fwd_tw, diff --git a/crypto/math-cuda/tests/lde_batch_ext3.rs b/crypto/math-cuda/tests/lde_batch_ext3.rs index 0a86197a5..ff19f6c56 100644 --- a/crypto/math-cuda/tests/lde_batch_ext3.rs +++ b/crypto/math-cuda/tests/lde_batch_ext3.rs @@ -97,8 +97,7 @@ fn assert_ext3_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { let input_slices: Vec<&[u64]> = flat_inputs.iter().map(|v| v.as_slice()).collect(); // Pre-allocate outputs, each 3*lde_size u64s. - let mut flat_outputs: Vec> = - (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); + let mut flat_outputs: Vec> = (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); { let mut out_slices: Vec<&mut [u64]> = flat_outputs.iter_mut().map(|v| v.as_mut_slice()).collect(); @@ -132,7 +131,13 @@ fn assert_ext3_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { } // Also sanity-check raw canonical equality per column. for (c, col) in columns.iter().enumerate() { - let cpu_raw = ext3_to_u64s(&cpu_lde_one_ext3(col, blowup, &weights_fp, &inv_tw, &fwd_tw)); + let cpu_raw = ext3_to_u64s(&cpu_lde_one_ext3( + col, + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, + )); assert_eq!(canon(&cpu_raw), canon(&flat_outputs[c])); } } diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs index d7cf3680a..b6697b82d 100644 --- a/crypto/math-cuda/tests/ntt.rs +++ b/crypto/math-cuda/tests/ntt.rs @@ -21,9 +21,7 @@ fn cpu_fft(coeffs: &[u64]) -> Vec { } fn canonicalize(xs: &[u64]) -> Vec { - xs.iter() - .map(|x| GoldilocksField::canonical(x)) - .collect() + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() } fn assert_ntt_match(log_n: u64, seed: u64) { @@ -81,9 +79,12 @@ fn assert_intt_match(log_n: u64, seed: u64) { let evals: Vec = (0..n).map(|_| rng.r#gen::()).collect(); let elems: Vec = evals.iter().map(|&x| Fp::from_raw(x)).collect(); - let cpu_poly = - Polynomial::interpolate_fft::(&elems).expect("cpu intt"); - let cpu: Vec = cpu_poly.coefficients.into_iter().map(|e| *e.value()).collect(); + let cpu_poly = Polynomial::interpolate_fft::(&elems).expect("cpu intt"); + let cpu: Vec = cpu_poly + .coefficients + .into_iter() + .map(|e| *e.value()) + .collect(); let gpu = math_cuda::ntt::inverse(&evals).expect("gpu intt"); @@ -124,7 +125,9 @@ fn ntt_round_trip() { let log_n = 14; let n = 1usize << log_n; let mut rng = ChaCha8Rng::seed_from_u64(42); - let x: Vec = (0..n).map(|_| rng.r#gen::() % 0xFFFF_FFFF_0000_0001).collect(); + let x: Vec = (0..n) + .map(|_| rng.r#gen::() % 0xFFFF_FFFF_0000_0001) + .collect(); let evals = math_cuda::ntt::forward(&x).expect("forward"); let back = math_cuda::ntt::inverse(&evals).expect("inverse"); @@ -133,4 +136,3 @@ fn ntt_round_trip() { let back_c = canonicalize(&back); assert_eq!(x_c, back_c, "round trip failed"); } - From ac6fbb5972cc4cd4e228bd1adb2e659760c7b6e1 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 May 2026 15:55:58 -0300 Subject: [PATCH 03/22] fix clippy --- crypto/math-cuda/build.rs | 34 +++++++++++++++++-- crypto/math-cuda/src/lde.rs | 20 +++++++---- crypto/math-cuda/tests/evaluate_coset_ext3.rs | 2 +- crypto/math-cuda/tests/goldilocks.rs | 10 +++--- crypto/math-cuda/tests/lde.rs | 2 +- crypto/math-cuda/tests/lde_batch.rs | 2 +- crypto/math-cuda/tests/lde_batch_ext3.rs | 4 +-- crypto/math-cuda/tests/ntt.rs | 2 +- 8 files changed, 54 insertions(+), 22 deletions(-) diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index d6b803bed..43edaa31c 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -1,4 +1,5 @@ use std::env; +use std::fs; use std::path::PathBuf; use std::process::Command; @@ -13,7 +14,7 @@ fn nvcc_path() -> PathBuf { cuda_home().join("bin").join("nvcc") } -fn compile_ptx(src: &str, out_name: &str) { +fn compile_ptx(src: &str, out_name: &str, have_nvcc: bool) { let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); let src_path = manifest_dir.join("kernels").join(src); @@ -24,6 +25,16 @@ fn compile_ptx(src: &str, out_name: &str) { println!("cargo:rerun-if-env-changed=CUDA_PATH"); println!("cargo:rerun-if-env-changed=CUDARC_NVCC_ARCH"); + // No nvcc on PATH → emit an empty PTX stub so the crate still compiles. + // include_str! in src/device.rs needs the file to exist at build time. + // Any runtime kernel call will then panic from cudarc when loading the + // empty module — which is the right failure mode (we can't run GPU code + // without nvcc on the build host anyway). + if !have_nvcc { + fs::write(&out_path, "").expect("failed to write empty PTX stub"); + return; + } + // Emit PTX for a virtual architecture; the CUDA driver JIT-compiles it for the // actual GPU at load time, so one PTX works across Ada/Hopper/Blackwell. Override // with CUDARC_NVCC_ARCH to pin a specific compute capability. @@ -45,6 +56,23 @@ fn main() { // Headers are not compiled; emit rerun-if-changed so edits trigger rebuilds. println!("cargo:rerun-if-changed=kernels/goldilocks.cuh"); println!("cargo:rerun-if-changed=kernels/ext3.cuh"); - compile_ptx("arith.cu", "arith.ptx"); - compile_ptx("ntt.cu", "ntt.ptx"); + + // Probe for nvcc once. Workspace consumers (clippy, fmt, CPU-only test + // runners) build math-cuda incidentally without using its kernels; allow + // those to succeed by stubbing out PTX when nvcc is unavailable. + let have_nvcc = Command::new(nvcc_path()) + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false); + if !have_nvcc { + println!( + "cargo:warning=math-cuda: nvcc not found at {} — emitting empty PTX stubs. \ + Runtime GPU calls will panic. Install CUDA and rebuild for a working backend.", + nvcc_path().display() + ); + } + + compile_ptx("arith.cu", "arith.ptx", have_nvcc); + compile_ptx("ntt.cu", "ntt.ptx", have_nvcc); } diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 8b9da522c..22d62ca3e 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -329,13 +329,19 @@ pub fn coset_lde_batch_base( // Split pinned → per-column Vecs. The first write to each virgin // Vec page-faults, which can dominate total time (~75 ms for 128 MB). // Parallelise so the fault cost spreads across CPU cores. - use rayon::prelude::*; let pinned_ptr = pinned.as_ptr() as usize; let out: Vec> = (0..m) .into_par_iter() .map(|c| { - let mut v = Vec::::with_capacity(lde_size); - unsafe { v.set_len(lde_size) }; + // set_len skips the O(N) zero-init that vec![0; n] would do + // (saves ~75 ms per 128 MB at prover scale). copy_from_slice + // below writes every slot before any reader sees the Vec. + #[allow(clippy::uninit_vec)] + let mut v = { + let mut v = Vec::::with_capacity(lde_size); + unsafe { v.set_len(lde_size) }; + v + }; let src = unsafe { std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) }; @@ -772,7 +778,7 @@ pub fn coset_lde_batch_ext3_into( columns.par_iter().enumerate().for_each(|(c, col)| { // SAFETY: disjoint regions per c; staging lock held. let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), n) + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) }; let slab_b = unsafe { std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) @@ -781,7 +787,7 @@ pub fn coset_lde_batch_ext3_into( std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) }; for i in 0..n { - slab_a[i] = col[i * 3 + 0]; + slab_a[i] = col[i * 3]; slab_b[i] = col[i * 3 + 1]; slab_c[i] = col[i * 3 + 2]; } @@ -885,7 +891,7 @@ pub fn coset_lde_batch_ext3_into( outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { let slab_a = unsafe { std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 0) * lde_size), + (pinned_const as *const u64).add((c * 3) * lde_size), lde_size, ) }; @@ -902,7 +908,7 @@ pub fn coset_lde_batch_ext3_into( ) }; for i in 0..lde_size { - dst[i * 3 + 0] = slab_a[i]; + dst[i * 3] = slab_a[i]; dst[i * 3 + 1] = slab_b[i]; dst[i * 3 + 2] = slab_c[i]; } diff --git a/crypto/math-cuda/tests/evaluate_coset_ext3.rs b/crypto/math-cuda/tests/evaluate_coset_ext3.rs index 020c39241..334b2601d 100644 --- a/crypto/math-cuda/tests/evaluate_coset_ext3.rs +++ b/crypto/math-cuda/tests/evaluate_coset_ext3.rs @@ -47,7 +47,7 @@ fn u64s_to_ext3(raw: &[u64]) -> Vec { let mut out = Vec::with_capacity(raw.len() / 3); for i in 0..raw.len() / 3 { out.push(Fp3::new([ - Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3]), Fp::from_raw(raw[i * 3 + 1]), Fp::from_raw(raw[i * 3 + 2]), ])); diff --git a/crypto/math-cuda/tests/goldilocks.rs b/crypto/math-cuda/tests/goldilocks.rs index 37a5d8533..33b7bf5e7 100644 --- a/crypto/math-cuda/tests/goldilocks.rs +++ b/crypto/math-cuda/tests/goldilocks.rs @@ -78,7 +78,7 @@ fn gpu_gl_mul_matches_cpu() { #[test] fn gpu_gl_neg_matches_cpu() { let a = sample_inputs(7); - let expected: Vec = a.iter().map(|x| GoldilocksField::neg(x)).collect(); + let expected: Vec = a.iter().map(GoldilocksField::neg).collect(); let actual = math_cuda::gl_neg_u64(&a).expect("GPU gl_neg"); assert_raw_eq("gl_neg", &expected, &actual); } @@ -112,11 +112,9 @@ fn gpu_goldilocks_edge_cases() { } } - let cases: &[( - &str, - fn(&[u64], &[u64]) -> math_cuda::Result>, - fn(&u64, &u64) -> u64, - )] = &[ + type GpuOp = fn(&[u64], &[u64]) -> math_cuda::Result>; + type CpuOp = fn(&u64, &u64) -> u64; + let cases: &[(&str, GpuOp, CpuOp)] = &[ ("gl_add", math_cuda::gl_add_u64, GoldilocksField::add), ("gl_sub", math_cuda::gl_sub_u64, GoldilocksField::sub), ("gl_mul", math_cuda::gl_mul_u64, GoldilocksField::mul), diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs index 75ea80867..facd2d861 100644 --- a/crypto/math-cuda/tests/lde.rs +++ b/crypto/math-cuda/tests/lde.rs @@ -52,7 +52,7 @@ fn cpu_lde(evals: &[u64], blowup_factor: usize, coset_offset: u64) -> Vec { } fn canon(xs: &[u64]) -> Vec { - xs.iter().map(|x| GoldilocksField::canonical(x)).collect() + xs.iter().map(GoldilocksField::canonical).collect() } fn assert_lde_match(log_n: u64, blowup_factor: usize, seed: u64) { diff --git a/crypto/math-cuda/tests/lde_batch.rs b/crypto/math-cuda/tests/lde_batch.rs index 153e5d3e5..b7120ca28 100644 --- a/crypto/math-cuda/tests/lde_batch.rs +++ b/crypto/math-cuda/tests/lde_batch.rs @@ -41,7 +41,7 @@ fn cpu_lde_one( } fn canon(xs: &[u64]) -> Vec { - xs.iter().map(|x| GoldilocksField::canonical(x)).collect() + xs.iter().map(GoldilocksField::canonical).collect() } fn assert_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { diff --git a/crypto/math-cuda/tests/lde_batch_ext3.rs b/crypto/math-cuda/tests/lde_batch_ext3.rs index ff19f6c56..c1156edaa 100644 --- a/crypto/math-cuda/tests/lde_batch_ext3.rs +++ b/crypto/math-cuda/tests/lde_batch_ext3.rs @@ -51,7 +51,7 @@ fn u64s_to_ext3(raw: &[u64]) -> Vec { let mut out = Vec::with_capacity(raw.len() / 3); for i in 0..raw.len() / 3 { out.push(Fp3::new([ - Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3]), Fp::from_raw(raw[i * 3 + 1]), Fp::from_raw(raw[i * 3 + 2]), ])); @@ -75,7 +75,7 @@ fn cpu_lde_one_ext3( } fn canon(xs: &[u64]) -> Vec { - xs.iter().map(|x| GoldilocksField::canonical(x)).collect() + xs.iter().map(GoldilocksField::canonical).collect() } fn assert_ext3_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs index b6697b82d..f3689cf94 100644 --- a/crypto/math-cuda/tests/ntt.rs +++ b/crypto/math-cuda/tests/ntt.rs @@ -21,7 +21,7 @@ fn cpu_fft(coeffs: &[u64]) -> Vec { } fn canonicalize(xs: &[u64]) -> Vec { - xs.iter().map(|x| GoldilocksField::canonical(x)).collect() + xs.iter().map(GoldilocksField::canonical).collect() } fn assert_ntt_match(log_n: u64, seed: u64) { From 2ceb3b0ef435f63702c1c4891f9116bac7e32a3f Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 May 2026 17:15:10 -0300 Subject: [PATCH 04/22] gpu 2nd part --- crypto/math-cuda/build.rs | 1 + crypto/math-cuda/kernels/keccak.cu | 347 +++++++ crypto/math-cuda/src/device.rs | 27 + crypto/math-cuda/src/lde.rs | 1189 +++++++++++++++++++++++ crypto/math-cuda/src/lib.rs | 1 + crypto/math-cuda/src/merkle.rs | 415 ++++++++ crypto/math-cuda/tests/keccak_leaves.rs | 140 +++ crypto/math-cuda/tests/merkle_tree.rs | 92 ++ 8 files changed, 2212 insertions(+) create mode 100644 crypto/math-cuda/kernels/keccak.cu create mode 100644 crypto/math-cuda/src/merkle.rs create mode 100644 crypto/math-cuda/tests/keccak_leaves.rs create mode 100644 crypto/math-cuda/tests/merkle_tree.rs diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index 43edaa31c..7c417fb9c 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -75,4 +75,5 @@ fn main() { compile_ptx("arith.cu", "arith.ptx", have_nvcc); compile_ptx("ntt.cu", "ntt.ptx", have_nvcc); + compile_ptx("keccak.cu", "keccak.ptx", have_nvcc); } diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu new file mode 100644 index 000000000..68ddce3b4 --- /dev/null +++ b/crypto/math-cuda/kernels/keccak.cu @@ -0,0 +1,347 @@ +// CUDA Keccak-256 (original Keccak, NOT SHA3-256 — uses 0x01 padding delimiter). +// +// Used by the lambda-vm prover's Merkle commit: +// leaf = Keccak-256(concat(col_0[br_idx].to_be_bytes(), col_1[br_idx].to_be_bytes(), …)) +// where `br_idx = bit_reverse(row_idx, log_num_rows)` and each element is +// written in BIG-ENDIAN canonical form (per `FieldElement::write_bytes_be`). +// +// Keccak state is 5x5 lanes of u64, interpreted little-endian. Rate = 136 B +// (17 lanes) for 256-bit output, capacity = 64 B (8 lanes). +// +// Since every input byte is u64-aligned (each field element is 8 or 24 bytes), +// we can absorb lane-by-lane instead of byte-by-byte. Canonicalise + byte-swap +// each u64 on read to turn a BE-serialised element into its LE-interpreted +// lane value. + +#include +#include "goldilocks.cuh" + +__device__ __constant__ uint64_t KECCAK_RC[24] = { + 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, + 0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL, + 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL, + 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, + 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, + 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL, + 0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL, + 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, +}; + +// Rotation offsets indexed by lane position x + 5*y. Standard Keccak rho. +__device__ __constant__ uint32_t KECCAK_RHO_OFFSETS[25] = { + 0, 1, 62, 28, 27, // y=0: x=0..4 + 36, 44, 6, 55, 20, // y=1 + 3, 10, 43, 25, 39, // y=2 + 41, 45, 15, 21, 8, // y=3 + 18, 2, 61, 56, 14, // y=4 +}; + +__device__ __forceinline__ uint64_t rotl64(uint64_t x, uint32_t n) { + return (n == 0) ? x : ((x << n) | (x >> (64 - n))); +} + +__device__ __forceinline__ uint64_t bswap64(uint64_t x) { + // Reverse byte order: turns a BE-serialised u64 into its LE-read lane. + x = ((x & 0x00ff00ff00ff00ffULL) << 8) | ((x & 0xff00ff00ff00ff00ULL) >> 8); + x = ((x & 0x0000ffff0000ffffULL) << 16) | ((x & 0xffff0000ffff0000ULL) >> 16); + return (x << 32) | (x >> 32); +} + +__device__ __forceinline__ void keccak_f1600(uint64_t st[25]) { + uint64_t C[5], D[5], B[25]; + #pragma unroll + for (int r = 0; r < 24; ++r) { + // Theta + #pragma unroll + for (int x = 0; x < 5; ++x) { + C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20]; + } + #pragma unroll + for (int x = 0; x < 5; ++x) { + D[x] = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1); + } + #pragma unroll + for (int y = 0; y < 5; ++y) { + #pragma unroll + for (int x = 0; x < 5; ++x) { + st[x + 5 * y] ^= D[x]; + } + } + + // Rho + Pi: B[pi(x,y)] = rotl(st[x,y], rho(x,y)) + // pi: (x', y') = (y, (2x + 3y) mod 5) + #pragma unroll + for (int y = 0; y < 5; ++y) { + #pragma unroll + for (int x = 0; x < 5; ++x) { + int nx = y; + int ny = (2 * x + 3 * y) % 5; + B[nx + 5 * ny] = rotl64(st[x + 5 * y], KECCAK_RHO_OFFSETS[x + 5 * y]); + } + } + + // Chi + #pragma unroll + for (int y = 0; y < 5; ++y) { + #pragma unroll + for (int x = 0; x < 5; ++x) { + st[x + 5 * y] = + B[x + 5 * y] ^ ((~B[((x + 1) % 5) + 5 * y]) & B[((x + 2) % 5) + 5 * y]); + } + } + + // Iota + st[0] ^= KECCAK_RC[r]; + } +} + +// --------------------------------------------------------------------------- +// Helper: absorb one 8-byte lane (already in lane form — i.e. LE interpretation +// of the BE-serialised u64) into the sponge at `rate_pos` (in bytes). Permutes +// when a full 136-byte block has been absorbed. +// --------------------------------------------------------------------------- +__device__ __forceinline__ void absorb_lane(uint64_t st[25], + uint32_t &rate_pos, + uint64_t lane) { + st[rate_pos / 8] ^= lane; + rate_pos += 8; + if (rate_pos == 136) { + keccak_f1600(st); + rate_pos = 0; + } +} + +// --------------------------------------------------------------------------- +// After all data lanes absorbed, apply Keccak (pre-SHA-3) padding: a single +// 0x01 byte at the current position, then bit 0x80 on the last rate byte +// (byte 135 = last byte of lane 16). Then permute and squeeze 32 bytes from +// the first four lanes in LE order. +// --------------------------------------------------------------------------- +__device__ __forceinline__ void finalize_keccak256(uint64_t st[25], + uint32_t rate_pos, + uint8_t *out32) { + // 0x01 at rate_pos + st[rate_pos / 8] ^= ((uint64_t)0x01) << ((rate_pos & 7) * 8); + // 0x80 at byte 135 (last byte of lane 16) + st[16] ^= ((uint64_t)0x80) << 56; + keccak_f1600(st); + + // Squeeze 32 bytes: 4 lanes, each LE-serialised. + #pragma unroll + for (int i = 0; i < 4; ++i) { + uint64_t lane = st[i]; + #pragma unroll + for (int b = 0; b < 8; ++b) { + out32[i * 8 + b] = (uint8_t)((lane >> (b * 8)) & 0xff); + } + } +} + +// --------------------------------------------------------------------------- +// Goldilocks BASE-FIELD leaf hashing. +// +// For output row `row_idx` (natural order), the leaf hashes the canonical BE +// byte representation of `columns[c][bit_reverse(row_idx, log_num_rows)]` for +// `c` in `[0, num_cols)`, concatenated in column order. Writes 32 bytes to +// `hashed_leaves_out[row_idx * 32 ..]`. +// +// `columns_base_ptr` points to a `num_cols * col_stride * u64` buffer; column +// `c` is the contiguous slab `[c*col_stride .. c*col_stride + num_rows]`. The +// remaining `col_stride - num_rows` entries (if any) are ignored. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak256_leaves_base_batched( + const uint64_t *columns_base_ptr, + uint64_t col_stride, + uint64_t num_cols, + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *hashed_leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_rows) return; + + // Bit-reverse the row index so we read columns at `br` but write the + // hashed leaf at `tid` — matching the CPU `commit_columns_bit_reversed`. + uint64_t br = __brevll(tid) >> (64 - log_num_rows); + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + for (uint64_t c = 0; c < num_cols; ++c) { + uint64_t v = columns_base_ptr[c * col_stride + br]; + // Canonicalise to match `canonical_u64().to_be_bytes()` on host. + uint64_t canon = goldilocks::canonical(v); + // The on-disk leaf bytes are canon.to_be_bytes(); Keccak reads those + // as a LE lane, which equals bswap64(canon). + uint64_t lane = bswap64(canon); + absorb_lane(st, rate_pos, lane); + } + + finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// Goldilocks EXT3 leaf hashing (3 base-field components per ext3 element). +// +// Components live in three separate base-field slabs (our de-interleaved +// layout). Column `c` component `k` is at `columns_base_ptr[(c*3 + k)*col_stride +// + br]`. Per-element BE bytes are `[comp0, comp1, comp2]` each 8 BE bytes +// (matches `FieldElement::::write_bytes_be`). +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak256_leaves_ext3_batched( + const uint64_t *columns_base_ptr, + uint64_t col_stride, + uint64_t num_cols, // number of ext3 columns (NOT slabs) + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *hashed_leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_rows) return; + uint64_t br = __brevll(tid) >> (64 - log_num_rows); + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + for (uint64_t c = 0; c < num_cols; ++c) { + #pragma unroll + for (int k = 0; k < 3; ++k) { + uint64_t v = columns_base_ptr[(c * 3 + (uint64_t)k) * col_stride + br]; + uint64_t canon = goldilocks::canonical(v); + uint64_t lane = bswap64(canon); + absorb_lane(st, rate_pos, lane); + } + } + + finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// R2 composition-polynomial leaf hashing. +// +// Each leaf hashes `2 * num_parts` ext3 values taken from bit-reversed rows +// `br_0 = reverse_index(2*leaf_idx)` and `br_1 = reverse_index(2*leaf_idx+1)` +// across all `num_parts` parts, in (br_0 row: part 0..K-1) then (br_1 row: +// part 0..K-1) order. Each ext3 value is 3 base components × 8 BE bytes. +// +// Columns arrive in the de-interleaved 3-slab layout: part `p` component +// `k` is at `parts_base_ptr[(p*3 + k) * col_stride + row]`. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak_comp_poly_leaves_ext3( + const uint64_t *parts_base_ptr, + uint64_t col_stride, + uint64_t num_parts, + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t num_leaves = num_rows >> 1; + if (tid >= num_leaves) return; + + uint64_t br_0 = __brevll(2 * tid) >> (64 - log_num_rows); + uint64_t br_1 = __brevll(2 * tid + 1) >> (64 - log_num_rows); + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + // First row (br_0): part 0..K-1 × 3 components each. + for (uint64_t p = 0; p < num_parts; ++p) { + #pragma unroll + for (int k = 0; k < 3; ++k) { + uint64_t v = parts_base_ptr[(p * 3 + (uint64_t)k) * col_stride + br_0]; + uint64_t canon = goldilocks::canonical(v); + absorb_lane(st, rate_pos, bswap64(canon)); + } + } + // Second row (br_1). + for (uint64_t p = 0; p < num_parts; ++p) { + #pragma unroll + for (int k = 0; k < 3; ++k) { + uint64_t v = parts_base_ptr[(p * 3 + (uint64_t)k) * col_stride + br_1]; + uint64_t canon = goldilocks::canonical(v); + absorb_lane(st, rate_pos, bswap64(canon)); + } + } + + finalize_keccak256(st, rate_pos, leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// FRI layer leaf hashing. +// +// Each leaf hashes 2 consecutive ext3 values: Keccak256 over +// evals[2j].to_bytes_be() ++ evals[2j+1].to_bytes_be() +// = 48 BE bytes = 6 u64 BE lanes. No bit reversal; no column slab layout — +// the input is a single interleaved ext3 eval vector `[a0,a1,a2,b0,b1,b2,...]`. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak_fri_leaves_ext3( + const uint64_t *evals_interleaved, // 3 * num_evals u64s (ext3 interleaved) + uint64_t num_leaves, // = num_evals / 2 + uint8_t *leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_leaves) return; + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + uint32_t rate_pos = 0; + + const uint64_t *left = evals_interleaved + 2 * tid * 3; // 3 u64s + const uint64_t *right = left + 3; + #pragma unroll + for (int i = 0; i < 3; ++i) { + uint64_t canon = goldilocks::canonical(left[i]); + absorb_lane(st, rate_pos, bswap64(canon)); + } + #pragma unroll + for (int i = 0; i < 3; ++i) { + uint64_t canon = goldilocks::canonical(right[i]); + absorb_lane(st, rate_pos, bswap64(canon)); + } + + finalize_keccak256(st, rate_pos, leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// Merkle inner-tree pair hash: one level of the inner Merkle tree. +// +// `nodes` is the full Merkle node buffer (length `2*leaves_len - 1`, each +// element 32 bytes). `parent_begin` is the node-index offset of the first +// parent slot in this level; children live at `parent_begin + n_pairs`. +// The layout mirrors `crypto/crypto/src/merkle_tree/merkle.rs`: +// +// children: nodes[parent_begin + n_pairs .. parent_begin + 3 * n_pairs] +// parents: nodes[parent_begin .. parent_begin + n_pairs] +// +// Each thread hashes one child pair → one parent. Keccak-256 of the +// concatenation of two 32-byte siblings; identical to +// `FieldElementVectorBackend::hash_new_parent` on host. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak_merkle_level( + uint8_t *nodes, + uint64_t parent_begin, // node index (counted in 32-byte nodes) + uint64_t n_pairs) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_pairs) return; + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + const uint64_t *left = reinterpret_cast( + nodes + (parent_begin + n_pairs + 2 * tid) * 32); + #pragma unroll + for (int i = 0; i < 4; ++i) absorb_lane(st, rate_pos, left[i]); + + const uint64_t *right = reinterpret_cast( + nodes + (parent_begin + n_pairs + 2 * tid + 1) * 32); + #pragma unroll + for (int i = 0; i < 4; ++i) absorb_lane(st, rate_pos, right[i]); + + finalize_keccak256(st, rate_pos, nodes + (parent_begin + tid) * 32); +} diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 03f0b67ca..8e956eef3 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -91,6 +91,7 @@ impl Drop for PinnedStaging { const ARITH_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/arith.ptx")); const NTT_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/ntt.ptx")); +const KECCAK_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/keccak.ptx")); /// Number of CUDA streams in the pool. Larger pools let many rayon-parallel /// callers overlap on the GPU without serializing on stream ownership. The /// default stream is deliberately excluded because it synchronises with all @@ -106,6 +107,11 @@ pub struct Backend { /// buffers 32×-inflated memory use and multiplied the one-time pinning /// cost for every first use of a new table size). pinned_staging: Mutex, + /// Separate pinned staging for Merkle leaf hashes. Sized `num_rows * 32` + /// bytes; lives alongside the LDE staging so the GPU→host D2H for + /// hashed leaves runs at full PCIe line-rate instead of the pageable + /// ~1.3 GB/s path that would otherwise eat ~100 ms per main-trace commit. + pinned_hashes: Mutex, util_stream: Arc, next: AtomicUsize, @@ -130,6 +136,13 @@ pub struct Backend { pub pointwise_mul_batched: CudaFunction, pub scalar_mul_batched: CudaFunction, + // keccak.ptx + pub keccak256_leaves_base_batched: CudaFunction, + pub keccak256_leaves_ext3_batched: CudaFunction, + pub keccak_comp_poly_leaves_ext3: CudaFunction, + pub keccak_fri_leaves_ext3: CudaFunction, + pub keccak_merkle_level: CudaFunction, + // Twiddle caches keyed by log_n. fwd_twiddles: Mutex>>>>, inv_twiddles: Mutex>>>>, @@ -146,12 +159,14 @@ impl Backend { let arith = ctx.load_module(Ptx::from_src(ARITH_PTX))?; let ntt = ctx.load_module(Ptx::from_src(NTT_PTX))?; + let keccak = ctx.load_module(Ptx::from_src(KECCAK_PTX))?; let mut streams = Vec::with_capacity(STREAM_POOL_SIZE); for _ in 0..STREAM_POOL_SIZE { streams.push(ctx.new_stream()?); } let pinned_staging = Mutex::new(PinnedStaging::empty()); + let pinned_hashes = Mutex::new(PinnedStaging::empty()); // Separate "utility" stream for twiddle uploads and other bookkeeping; // not part of the pool that callers rotate through. let util_stream = ctx.new_stream()?; @@ -178,11 +193,17 @@ impl Backend { ntt_dit_8_levels_batched: ntt.load_function("ntt_dit_8_levels_batched")?, pointwise_mul_batched: ntt.load_function("pointwise_mul_batched")?, scalar_mul_batched: ntt.load_function("scalar_mul_batched")?, + keccak256_leaves_base_batched: keccak.load_function("keccak256_leaves_base_batched")?, + keccak256_leaves_ext3_batched: keccak.load_function("keccak256_leaves_ext3_batched")?, + keccak_comp_poly_leaves_ext3: keccak.load_function("keccak_comp_poly_leaves_ext3")?, + keccak_fri_leaves_ext3: keccak.load_function("keccak_fri_leaves_ext3")?, + keccak_merkle_level: keccak.load_function("keccak_merkle_level")?, fwd_twiddles: Mutex::new(vec![None; max_log]), inv_twiddles: Mutex::new(vec![None; max_log]), ctx, streams, pinned_staging, + pinned_hashes, util_stream, next: AtomicUsize::new(0), }) @@ -201,6 +222,12 @@ impl Backend { &self.pinned_staging } + /// Separate pinned staging for Merkle leaf hash output. Sized in u64 + /// units; caller should reserve `(num_rows * 32 + 7) / 8` u64s. + pub fn pinned_hashes(&self) -> &Mutex { + &self.pinned_hashes + } + pub fn fwd_twiddles_for(&self, log_n: u64) -> Result>> { self.cached_twiddles(log_n, true) } diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 22d62ca3e..ccf5abb1d 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -16,6 +16,7 @@ use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; use crate::Result; use crate::device::backend; +use crate::merkle::{launch_keccak_base, launch_keccak_ext3}; use crate::ntt::run_ntt_body; /// Handle to a base-field LDE kept live on device after R1 commit. @@ -508,6 +509,963 @@ pub fn coset_lde_batch_base_into( Ok(()) } +/// Variant of `coset_lde_batch_base_into` that also emits the Keccak-256 +/// Merkle leaf hashes from the LDE output — all on GPU, no second H2D of +/// the LDE data. Leaves are computed reading columns at bit-reversed rows +/// (matching `commit_columns_bit_reversed` on the CPU side). +/// +/// `hashed_leaves_out` must be `lde_size * 32` bytes (one 32-byte digest +/// per output row, in natural row order). +pub fn coset_lde_batch_base_into_with_leaf_hash( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + hashed_leaves_out: &mut [u8], +) -> Result<()> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + let n = columns[0].len(); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), lde_size); + } + assert_eq!(hashed_leaves_out.len(), lde_size * 32); + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + + use rayon::prelude::*; + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: disjoint regions per c, outer staging lock held. + let dst = + unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; + dst.copy_from_slice(col); + }); + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + // pointwise coset scale + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + // forward NTT on full LDE slab + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + // Keccak-256 leaf hashing directly on the device LDE buffer. + let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; + launch_keccak_base( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut hashes_dev, + )?; + + // D2H the LDE into the pinned LDE staging and the hashes into a + // dedicated pinned hash staging, in parallel on the same stream. Both + // at pinned PCIe line-rate — pageable D2H of the 128 MB hash buffer + // would otherwise cost ~100 ms per main-trace commit. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + let hashes_u64_len = (lde_size * 32).div_ceil(8); + let hashes_staging_slot = be.pinned_hashes(); + let mut hashes_staging = hashes_staging_slot.lock().unwrap(); + hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; + let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; + // `memcpy_dtoh` needs a byte slice. Reinterpret the u64 pinned buffer + // as bytes — same allocation, just typed differently. + let hashes_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(hashes_pinned.as_mut_ptr() as *mut u8, lde_size * 32) + }; + stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; + stream.synchronize()?; + + // Copy pinned → caller outputs in parallel with the hash memcpy. + let pinned_ptr = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) + }; + dst.copy_from_slice(src); + }); + // Rayon-parallel memcpy of 128 MB from pinned → caller. Single-threaded + // `copy_from_slice` faults virgin pageable pages one at a time; the + // mm_struct rwsem serialises them into ~100 ms at 1M-fib scale. Chunk + // the slice so ~N cores pre-fault+write in parallel. + const CHUNK: usize = 64 * 1024; // 64 KiB ≈ 16 pages per chunk + let pinned_hash_ptr = hashes_pinned_bytes.as_ptr() as usize; + hashed_leaves_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_hash_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(hashes_staging); + drop(staging); + Ok(()) +} + +/// Like `coset_lde_batch_base_into_with_leaf_hash`, but also builds the full +/// Merkle tree on device and returns the `2*lde_size - 1` node buffer back +/// to the caller in `merkle_nodes_out` (byte length `(2*lde_size - 1) * 32`). +/// +/// The leaf hashes are never exposed to the caller — they stay on device and +/// feed straight into the pair-hash tree kernel, avoiding the +/// pinned→pageable→pinned round-trip that the separate-step GPU tree build +/// would pay. +pub fn coset_lde_batch_base_into_with_merkle_tree( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + false, + ) + .map(|_| ()) +} + +/// Fused LDE + leaf-hash + Merkle tree build. If `keep_device_buf` is true, +/// returns an `Arc>` wrapping the LDE device buffer so callers +/// (R2–R4 GPU paths) can reuse the LDE without a re-H2D. +pub fn coset_lde_batch_base_into_with_merkle_tree_keep( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result { + let opt = coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + true, + )?; + let handle = opt.expect("keep_device_buf=true must return Some"); + Ok(handle) +} + +fn coset_lde_batch_base_into_with_merkle_tree_inner( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], + keep_device_buf: bool, +) -> Result> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + let n = columns[0].len(); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), lde_size); + } + let total_nodes = 2 * lde_size - 1; + assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + + use rayon::prelude::*; + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let dst = + unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; + dst.copy_from_slice(col); + }); + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + // forward NTT at LDE size + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + // Allocate the full node buffer; leaves occupy the tail slab, inner + // nodes are written by the pair-hash level kernel below. `alloc` (not + // `alloc_zeros`) is safe because every byte is written before it is + // read: leaf kernel fills the tail, tree kernel fills the head. + // + // The leaf kernel writes to `nodes_dev` starting at byte offset + // `(lde_size - 1) * 32`; we pass the base pointer as-is because the + // kernel indexes linearly from `hashed_leaves_out[row_idx * 32]`, so we + // build an offset device slice and feed that to the launch. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (lde_size - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); + let log_num_rows_leaves = lde_size.trailing_zeros() as u64; + let num_cols_u64 = m as u64; + let grid = (lde_size as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak256_leaves_base_batched) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_cols_u64) + .arg(&lde_u64) + .arg(&log_num_rows_leaves) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree levels — mirror the CPU `build(nodes, leaves_len)` scan. + { + let mut level_begin: u64 = (lde_size - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // D2H the LDE and the tree nodes via pinned staging. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + + let tree_u64_len = (total_nodes * 32).div_ceil(8); + let tree_staging_slot = be.pinned_hashes(); + let mut tree_staging = tree_staging_slot.lock().unwrap(); + tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; + let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; + let tree_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, total_nodes * 32) + }; + stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; + stream.synchronize()?; + + // Parallel memcpy pinned → caller. + let pinned_ptr = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) + }; + dst.copy_from_slice(src); + }); + const CHUNK: usize = 64 * 1024; + let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; + merkle_nodes_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(tree_staging); + drop(staging); + + if keep_device_buf { + Ok(Some(GpuLdeBase { + buf: Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Ext3 variant of `coset_lde_batch_base_into_with_leaf_hash`: run an LDE +/// over ext3 columns AND emit Keccak-256 Merkle leaves, all in one on-device +/// pipeline. +pub fn coset_lde_batch_ext3_into_with_leaf_hash( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + hashed_leaves_out: &mut [u8], +) -> Result<()> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + assert_eq!(hashed_leaves_out.len(), lde_size * 32); + if n == 0 { + return Ok(()); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Keccak-256 on the de-interleaved device buffer (3M base slabs). + let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; + launch_keccak_ext3( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut hashes_dev, + )?; + + // D2H LDE (mb * lde_size u64) and hashes (lde_size * 32 bytes). + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + let hashes_u64_len = (lde_size * 32).div_ceil(8); + let hashes_staging_slot = be.pinned_hashes(); + let mut hashes_staging = hashes_staging_slot.lock().unwrap(); + hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; + let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; + let hashes_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(hashes_pinned.as_mut_ptr() as *mut u8, lde_size * 32) + }; + stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; + stream.synchronize()?; + + // Re-interleave pinned → caller ext3 outputs, parallel. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + + // Parallel memcpy of pinned hashes → caller. + const CHUNK: usize = 64 * 1024; + let hash_src_ptr = hashes_pinned_bytes.as_ptr() as usize; + hashed_leaves_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((hash_src_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(hashes_staging); + drop(staging); + Ok(()) +} + +/// Ext3 variant of the fused `coset_lde_batch_base_into_with_merkle_tree`. +/// LDE + leaf hashing + inner-tree build, all on device; D2Hs only the LDE +/// evaluations and the full `2*lde_size - 1` node buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + false, + ) + .map(|_| ()) +} + +/// Ext3 variant of [`coset_lde_batch_base_into_with_merkle_tree_keep`] — +/// returns an `Arc>` handle to the de-interleaved LDE device +/// buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree_keep( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result { + let opt = coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) +} + +fn coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], + keep_device_buf: bool, +) -> Result> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + let total_nodes = 2 * lde_size - 1; + assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + if n == 0 { + return Ok(None); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Allocate full tree buffer; leaf kernel writes to the tail slab. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (lde_size - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); + let log_num_rows_leaves = lde_size.trailing_zeros() as u64; + let num_cols_u64 = m as u64; + let grid = (lde_size as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak256_leaves_ext3_batched) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_cols_u64) + .arg(&lde_u64) + .arg(&log_num_rows_leaves) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree levels. + { + let mut level_begin: u64 = (lde_size - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // D2H LDE (mb * lde_size u64) and tree nodes. + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + let tree_u64_len = (total_nodes * 32).div_ceil(8); + let tree_staging_slot = be.pinned_hashes(); + let mut tree_staging = tree_staging_slot.lock().unwrap(); + tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; + let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; + let tree_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, total_nodes * 32) + }; + stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; + stream.synchronize()?; + + // Re-interleave pinned → caller ext3 outputs. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + + const CHUNK: usize = 64 * 1024; + let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; + merkle_nodes_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(tree_staging); + drop(staging); + + if keep_device_buf { + Ok(Some(GpuLdeExt3 { + buf: Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + /// Batched ext3 polynomial → coset evaluation. /// /// Input: M ext3 columns of `n` coefficients each (interleaved, 3n u64). @@ -708,6 +1666,237 @@ fn evaluate_poly_coset_batch_ext3_into_inner( } } +/// Fused variant of [`evaluate_poly_coset_batch_ext3_into`]: in addition to +/// the LDE output, builds the R2 composition-polynomial Merkle tree on device +/// (row-pair Keccak leaves at bit-reversed indices + pair-hash inner tree). +/// +/// `merkle_nodes_out` must have byte length `(2 * lde_size - 1) * 32`. +/// Requires `lde_size >= 2`. +pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + if coefs.is_empty() { + return Ok(()); + } + let m = coefs.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in coefs.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + assert!(lde_size >= 2); + let total_nodes = 2 * lde_size - 1; + assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + if n == 0 { + return Ok(()); + } + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + coefs.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Build the row-pair Merkle tree on device. + // + // Row-pair commit: each leaf hashes 2 rows (bit-reversed indices) → + // num_leaves = lde_size / 2. Tree size: 2*num_leaves - 1 = lde_size - 1. + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let log_num_rows = log_lde; + let num_parts_u64 = m as u64; + let grid = (num_leaves as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&lde_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // D2H LDE and tree. + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + let tree_u64_len = (tight_total_nodes * 32).div_ceil(8); + let tree_staging_slot = be.pinned_hashes(); + let mut tree_staging = tree_staging_slot.lock().unwrap(); + tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; + let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; + let tree_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, tight_total_nodes * 32) + }; + stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; + stream.synchronize()?; + + // Re-interleave pinned → caller ext3 outputs. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + + // Copy pinned tree → caller nodes_out. `merkle_nodes_out.len() == + // total_nodes * 32` is oversized relative to our tight tree; we write + // only the first `tight_total_nodes * 32` bytes and the caller trims. + // Expose the tight byte count via the slice length so the caller can + // construct the MerkleTree with the right node count. + assert!(merkle_nodes_out.len() >= tight_total_nodes * 32); + const CHUNK: usize = 64 * 1024; + let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; + merkle_nodes_out[..tight_total_nodes * 32] + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) + }; + dst.copy_from_slice(src); + }); + drop(tree_staging); + drop(staging); + Ok(()) +} /// Batched coset LDE for Goldilocks **cubic extension** columns. /// /// A degree-3 extension element is `(a, b, c)` in memory (three contiguous diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs index 821f5bd3a..d5d0c9fbc 100644 --- a/crypto/math-cuda/src/lib.rs +++ b/crypto/math-cuda/src/lib.rs @@ -6,6 +6,7 @@ pub mod device; pub mod lde; +pub mod merkle; pub mod ntt; use cudarc::driver::{LaunchConfig, PushKernelArg}; diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs new file mode 100644 index 000000000..0f80206a5 --- /dev/null +++ b/crypto/math-cuda/src/merkle.rs @@ -0,0 +1,415 @@ +//! GPU Keccak-256 leaf hashing for Merkle commits. +//! +//! Matches `FieldElementVectorBackend::hash_data` in +//! `crypto/crypto/src/merkle_tree/backends/field_element_vector.rs`, combined +//! with the `reverse_index` row read pattern used in +//! `commit_columns_bit_reversed` at `crypto/stark/src/prover.rs:368`. +//! +//! Caller supplies base-field column slabs already laid out as +//! `[col * col_stride + row]` (the same layout `coset_lde_batch_base_into` +//! writes to the pinned staging buffer). The kernel bit-reverses `row_idx`, +//! reads each column's canonical u64 at that row, byte-swaps it into a +//! Keccak lane, absorbs lane-by-lane, and squeezes 32 bytes per leaf. +//! +//! For ext3 columns the layout is `[col*3*col_stride + k*col_stride + row]` +//! — three base slabs per ext3 column — and the kernel reads three u64s per +//! column in component order 0,1,2 to match `FieldElement::::write_bytes_be`. + +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; + +/// Run GPU Keccak-256 leaf hashing on a base-field column buffer. +/// +/// `columns` must hold `num_cols * col_stride` u64s with column `c`'s data +/// at `[c*col_stride .. c*col_stride + num_rows]`. Returns `num_rows * 32` +/// hash bytes in natural (non-bit-reversed) row order. +pub fn keccak_leaves_base( + columns: &[u64], + col_stride: usize, + num_cols: usize, + num_rows: usize, +) -> Result> { + assert!(num_rows.is_power_of_two()); + assert!(columns.len() >= num_cols * col_stride); + let be = backend(); + let stream = be.next_stream(); + let cols_dev = stream.clone_htod(&columns[..num_cols * col_stride])?; + let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; + launch_keccak_base( + stream.as_ref(), + &cols_dev, + col_stride as u64, + num_cols as u64, + num_rows as u64, + &mut out_dev, + )?; + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Ext3 variant — columns interleaved as three base slabs per ext3 column. +/// `columns.len() >= num_cols * 3 * col_stride`. +pub fn keccak_leaves_ext3( + columns: &[u64], + col_stride: usize, + num_cols: usize, + num_rows: usize, +) -> Result> { + assert!(num_rows.is_power_of_two()); + assert!(columns.len() >= num_cols * 3 * col_stride); + let be = backend(); + let stream = be.next_stream(); + let cols_dev = stream.clone_htod(&columns[..num_cols * 3 * col_stride])?; + let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; + launch_keccak_ext3( + stream.as_ref(), + &cols_dev, + col_stride as u64, + num_cols as u64, + num_rows as u64, + &mut out_dev, + )?; + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Block size for Keccak kernels. Per-thread register footprint is ~60 regs +/// (25-lane state + auxiliaries); the default 256 threads/block pushes the +/// block register file past the hardware limit on sm_120 (Blackwell). 128 +/// keeps us inside the budget with some head-room. +const KECCAK_BLOCK_DIM: u32 = 128; + +fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { + let grid = (num_rows as u32).div_ceil(KECCAK_BLOCK_DIM); + LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (KECCAK_BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + } +} + +pub(crate) fn launch_keccak_base( + stream: &CudaStream, + cols_dev: &CudaSlice, + col_stride: u64, + num_cols: u64, + num_rows: u64, + out_dev: &mut CudaSlice, +) -> Result<()> { + let be = backend(); + let log_num_rows = num_rows.trailing_zeros() as u64; + let cfg = keccak_launch_cfg(num_rows); + unsafe { + stream + .launch_builder(&be.keccak256_leaves_base_batched) + .arg(cols_dev) + .arg(&col_stride) + .arg(&num_cols) + .arg(&num_rows) + .arg(&log_num_rows) + .arg(out_dev) + .launch(cfg)?; + } + Ok(()) +} + +/// Given `hashed_leaves` of length `leaves_len * 32`, build the full Merkle +/// tree on device and return the complete node buffer `(2*leaves_len - 1) * +/// 32` bytes in the standard layout: +/// +/// `nodes[0..leaves_len - 1]` are inner nodes (root at index 0), and +/// `nodes[leaves_len - 1..]` are the leaves themselves. +/// +/// Matches the CPU `crypto/crypto/src/merkle_tree/merkle.rs` construction so +/// the resulting `nodes` Vec plugs straight into `MerkleTree { root, nodes }` +/// for downstream proof generation. +/// +/// `leaves_len` must be a power of two and ≥ 2. +pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { + assert!(hashed_leaves.len().is_multiple_of(32)); + let leaves_len = hashed_leaves.len() / 32; + assert!(leaves_len >= 2, "tree needs at least two leaves"); + assert!( + leaves_len.is_power_of_two(), + "leaves_len must be a power of two" + ); + + let total_nodes = 2 * leaves_len - 1; + let be = backend(); + let stream = be.next_stream(); + + // Allocate the full node buffer without zero-fill — we overwrite the + // leaf half via H2D immediately, and every inner node is written by the + // pair-hash kernel below. + // SAFETY: every byte is written before it is read: leaves are filled by + // the H2D below; inner nodes are filled by the level loop that follows. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (leaves_len - 1) * 32; + // SAFETY: target slice `nodes_dev[leaves_offset_bytes..]` has exactly + // `leaves_len * 32 == hashed_leaves.len()` bytes capacity. + { + let mut slice = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + hashed_leaves.len()); + stream.memcpy_htod(hashed_leaves, &mut slice)?; + } + + // Build level by level. The CPU `build(nodes, leaves_len)` starts with + // level_begin_index = leaves_len - 1 + // level_end_index = 2 * level_begin_index + // and each iteration computes: + // new_level_begin_index = level_begin_index / 2 + // new_level_length = level_begin_index - new_level_begin_index + // The parents occupy [new_level_begin_index, level_begin_index); the + // children occupy [level_begin_index, level_end_index + 1). + let mut level_begin: u64 = (leaves_len - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + + let cfg = keccak_launch_cfg(n_pairs); + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + + let out = stream.clone_dtoh(&nodes_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Row-pair Keccak leaf + Merkle tree build for R2 composition-polynomial +/// commit. `parts_interleaved` is `num_parts` slices, each holding an ext3 +/// LDE column interleaved as `[a0,a1,a2, b0,b1,b2, ...]` of length `3*lde_size`. +/// +/// Returns `(2*(lde_size/2) - 1) * 32` bytes of tree nodes in the standard +/// layout (root at byte offset 0, leaves in the tail). +pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Result> { + assert!(!parts_interleaved.is_empty()); + let m = parts_interleaved.len(); + let ext3_elems = parts_interleaved[0].len() / 3; + assert_eq!( + parts_interleaved[0].len(), + 3 * ext3_elems, + "ext3 buffer length must be 3 * lde_size" + ); + for p in parts_interleaved.iter() { + assert_eq!(p.len(), 3 * ext3_elems); + } + let lde_size = ext3_elems; + assert!(lde_size.is_power_of_two() && lde_size >= 2); + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + // Stage: de-interleave each part into 3 base slabs in pinned memory. + let mb = 3 * m; + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + parts_interleaved + .par_iter() + .enumerate() + .for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + // H2D the de-interleaved parts. + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + stream.memcpy_htod(&pinned[..mb * lde_size], &mut buf)?; + + // Leaves into tail of a tight node buffer. + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let col_stride_u64 = lde_size as u64; + let num_parts_u64 = m as u64; + let num_rows_u64 = lde_size as u64; + let log_num_rows = lde_size.trailing_zeros() as u64; + let grid = (num_leaves as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&num_rows_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree. + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + let out = stream.clone_dtoh(&nodes_dev)?; + stream.synchronize()?; + drop(staging); + Ok(out) +} + +/// Build a FRI-layer Merkle tree on device from an interleaved ext3 eval +/// vector. Each leaf hashes two consecutive ext3 values; `num_leaves = +/// evals.len() / 6` (since each ext3 is 3 u64s). +/// +/// Returns the `(2*num_leaves - 1) * 32`-byte node buffer in standard layout. +pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { + assert!( + evals.len().is_multiple_of(6), + "evals must hold whole pair-leaves" + ); + let num_evals = evals.len() / 3; + let num_leaves = num_evals / 2; + assert!(num_leaves.is_power_of_two() && num_leaves >= 1); + let tight_total_nodes = 2 * num_leaves - 1; + if tight_total_nodes == 0 { + return Ok(Vec::new()); + } + + let be = backend(); + let stream = be.next_stream(); + + let evals_dev = stream.clone_htod(evals)?; + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + + // Leaf kernel: num_leaves threads, one leaf each. + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let num_leaves_u64 = num_leaves as u64; + let grid = (num_leaves as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_fri_leaves_ext3) + .arg(&evals_dev) + .arg(&num_leaves_u64) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree levels — identical to the R2 version. + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + let out = stream.clone_dtoh(&nodes_dev)?; + stream.synchronize()?; + Ok(out) +} + +pub(crate) fn launch_keccak_ext3( + stream: &CudaStream, + cols_dev: &CudaSlice, + col_stride: u64, + num_cols: u64, + num_rows: u64, + out_dev: &mut CudaSlice, +) -> Result<()> { + let be = backend(); + let log_num_rows = num_rows.trailing_zeros() as u64; + let cfg = keccak_launch_cfg(num_rows); + unsafe { + stream + .launch_builder(&be.keccak256_leaves_ext3_batched) + .arg(cols_dev) + .arg(&col_stride) + .arg(&num_cols) + .arg(&num_rows) + .arg(&log_num_rows) + .arg(out_dev) + .launch(cfg)?; + } + Ok(()) +} diff --git a/crypto/math-cuda/tests/keccak_leaves.rs b/crypto/math-cuda/tests/keccak_leaves.rs new file mode 100644 index 000000000..1e451b386 --- /dev/null +++ b/crypto/math-cuda/tests/keccak_leaves.rs @@ -0,0 +1,140 @@ +//! Parity: GPU Keccak-256 leaf hashes must match CPU +//! `FieldElementVectorBackend::::hash_data` applied to +//! bit-reversed rows (same pattern as `commit_columns_bit_reversed` in the +//! stark prover). + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::traits::ByteConversion; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::{Digest, Keccak256}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn reverse_index(i: u64, n: u64) -> u64 { + let log_n = n.trailing_zeros(); + i.reverse_bits() >> (64 - log_n) +} + +fn cpu_leaves_base(columns: &[Vec]) -> Vec<[u8; 32]> { + let num_rows = columns[0].len(); + let num_cols = columns.len(); + let byte_len = 8; + (0..num_rows) + .map(|row_idx| { + let br = reverse_index(row_idx as u64, num_rows as u64) as usize; + let mut buf = vec![0u8; num_cols * byte_len]; + for c in 0..num_cols { + columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); + } + let mut h = Keccak256::new(); + h.update(&buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out + }) + .collect() +} + +fn cpu_leaves_ext3(columns: &[Vec]) -> Vec<[u8; 32]> { + let num_rows = columns[0].len(); + let num_cols = columns.len(); + let byte_len = 24; + (0..num_rows) + .map(|row_idx| { + let br = reverse_index(row_idx as u64, num_rows as u64) as usize; + let mut buf = vec![0u8; num_cols * byte_len]; + for c in 0..num_cols { + columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); + } + let mut h = Keccak256::new(); + h.update(&buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out + }) + .collect() +} + +#[test] +fn keccak_leaves_base_matches_cpu() { + for log_n in [4u32, 6, 8, 10, 12] { + for num_cols in [1usize, 5, 17, 41] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(100 + log_n as u64 + num_cols as u64); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| Fp::from_raw(rng.r#gen::())).collect()) + .collect(); + + let cpu = cpu_leaves_base(&columns); + + // Flatten columns into a contiguous base slab layout matching + // `coset_lde_batch_base_into`'s pinned staging format: + // `[col * stride + row]`. Use stride = num_rows for compactness. + let mut flat = vec![0u64; num_cols * n]; + for (c, col) in columns.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + flat[c * n + r] = *e.value(); + } + } + let gpu = math_cuda::merkle::keccak_leaves_base(&flat, n, num_cols, n).unwrap(); + assert_eq!(gpu.len(), n * 32); + for i in 0..n { + assert_eq!( + &gpu[i * 32..(i + 1) * 32], + &cpu[i][..], + "base leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" + ); + } + } + } +} + +#[test] +fn keccak_leaves_ext3_matches_cpu() { + for log_n in [4u32, 6, 8, 10] { + for num_cols in [1usize, 3, 11, 20] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(200 + log_n as u64 + num_cols as u64); + let columns: Vec> = (0..num_cols) + .map(|_| { + (0..n) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() + }) + .collect(); + + let cpu = cpu_leaves_ext3(&columns); + + // GPU expects 3 base slabs per ext3 column in the order + // [col*3+0 (comp a), col*3+1 (comp b), col*3+2 (comp c)], each a + // contiguous slab of n u64s (length = num_cols * 3 * n). + let mut flat = vec![0u64; num_cols * 3 * n]; + for (c, col) in columns.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + flat[(c * 3) * n + r] = *e.value()[0].value(); + flat[(c * 3 + 1) * n + r] = *e.value()[1].value(); + flat[(c * 3 + 2) * n + r] = *e.value()[2].value(); + } + } + let gpu = math_cuda::merkle::keccak_leaves_ext3(&flat, n, num_cols, n).unwrap(); + assert_eq!(gpu.len(), n * 32); + for i in 0..n { + assert_eq!( + &gpu[i * 32..(i + 1) * 32], + &cpu[i][..], + "ext3 leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" + ); + } + } + } +} diff --git a/crypto/math-cuda/tests/merkle_tree.rs b/crypto/math-cuda/tests/merkle_tree.rs new file mode 100644 index 000000000..34d44c767 --- /dev/null +++ b/crypto/math-cuda/tests/merkle_tree.rs @@ -0,0 +1,92 @@ +//! Parity: GPU Merkle inner-tree construction must match the CPU +//! `crypto/crypto/src/merkle_tree/merkle.rs` `build_from_hashed_leaves` +//! (Keccak-256 pair hash at each level). + +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::{Digest, Keccak256}; + +fn cpu_hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { + let mut h = Keccak256::new(); + h.update(left); + h.update(right); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out +} + +/// CPU reference: same algorithm as `build_from_hashed_leaves`. +fn cpu_merkle_nodes(leaves: &[[u8; 32]]) -> Vec<[u8; 32]> { + let leaves_len = leaves.len(); + assert!(leaves_len.is_power_of_two() && leaves_len >= 2); + let total = 2 * leaves_len - 1; + + let mut nodes: Vec<[u8; 32]> = vec![[0u8; 32]; total]; + for (i, leaf) in leaves.iter().enumerate() { + nodes[leaves_len - 1 + i] = *leaf; + } + + let mut level_begin = leaves_len - 1; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + for j in 0..n_pairs { + let left = nodes[level_begin + 2 * j]; + let right = nodes[level_begin + 2 * j + 1]; + nodes[new_begin + j] = cpu_hash_pair(&left, &right); + } + level_begin = new_begin; + } + nodes +} + +fn run_parity(log_n: u32, seed: u64) { + let leaves_len = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let leaves: Vec<[u8; 32]> = (0..leaves_len) + .map(|_| { + let mut arr = [0u8; 32]; + rng.fill(&mut arr[..]); + arr + }) + .collect(); + + // Flat byte layout for the GPU entry point. + let mut flat = Vec::with_capacity(leaves_len * 32); + for l in &leaves { + flat.extend_from_slice(l); + } + + let gpu_nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(&flat).unwrap(); + assert_eq!(gpu_nodes_bytes.len(), (2 * leaves_len - 1) * 32); + + let cpu_nodes = cpu_merkle_nodes(&leaves); + + for i in 0..cpu_nodes.len() { + let g = &gpu_nodes_bytes[i * 32..(i + 1) * 32]; + let c = &cpu_nodes[i]; + assert_eq!( + g, c, + "node {i} mismatch at log_n={log_n} (cpu={c:?}, gpu={g:?})" + ); + } +} + +#[test] +fn merkle_tree_small() { + for log_n in 1u32..=6 { + run_parity(log_n, 100 + log_n as u64); + } +} + +#[test] +fn merkle_tree_medium() { + for log_n in [10u32, 12, 14] { + run_parity(log_n, 500 + log_n as u64); + } +} + +#[test] +fn merkle_tree_large() { + run_parity(18, 9999); +} From 0944a8326b2139a7e2cbe3a3fd643fd69d73706f Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 8 May 2026 12:28:34 -0300 Subject: [PATCH 05/22] fix comments --- crypto/math-cuda/kernels/arith.cu | 2 +- crypto/math-cuda/tests/lde.rs | 2 +- crypto/math-cuda/tests/ntt.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crypto/math-cuda/kernels/arith.cu b/crypto/math-cuda/kernels/arith.cu index 4bee9b8bb..ac73f7fd1 100644 --- a/crypto/math-cuda/kernels/arith.cu +++ b/crypto/math-cuda/kernels/arith.cu @@ -1,4 +1,4 @@ -// Element-wise Goldilocks kernels used by the Phase-2 parity tests. These mirror +// Element-wise Goldilocks kernels used by the parity tests. These mirror // the CPU reference in `crypto/math/src/field/goldilocks.rs` so raw u64 outputs // are bit-identical to the CPU path. diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs index facd2d861..33f98f9ae 100644 --- a/crypto/math-cuda/tests/lde.rs +++ b/crypto/math-cuda/tests/lde.rs @@ -1,4 +1,4 @@ -//! Phase-5 parity: GPU `coset_lde_base` must match the CPU +//! Parity: GPU `coset_lde_base` must match the CPU //! `Polynomial::coset_lde_full_expand` for a sweep of realistic sizes and //! blowup factors. diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs index f3689cf94..17a556c74 100644 --- a/crypto/math-cuda/tests/ntt.rs +++ b/crypto/math-cuda/tests/ntt.rs @@ -1,4 +1,4 @@ -//! Phase-3 parity: GPU forward NTT must agree with `Polynomial::evaluate_fft` +//! Parity: GPU forward NTT must agree with `Polynomial::evaluate_fft` //! as a field element, across a sweep of sizes from 2^4 to 2^20. //! //! Non-canonical u64s can differ between CPU and GPU while representing the From 12292ceeb486f838c9f05436febaefd26e5c2ab5 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 8 May 2026 12:35:28 -0300 Subject: [PATCH 06/22] fix comments --- crypto/math-cuda/kernels/ntt.cu | 10 ++++++---- crypto/math-cuda/src/device.rs | 5 +++-- crypto/math-cuda/src/lde.rs | 4 ---- crypto/math-cuda/tests/bench_quick.rs | 9 +++++---- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/crypto/math-cuda/kernels/ntt.cu b/crypto/math-cuda/kernels/ntt.cu index 2a5c8c786..a8e50d4e1 100644 --- a/crypto/math-cuda/kernels/ntt.cu +++ b/crypto/math-cuda/kernels/ntt.cu @@ -1,5 +1,6 @@ -// Radix-2 DIT NTT over Goldilocks. One kernel per butterfly level; the caller -// runs `bit_reverse_permute` once before the first level. +// Radix-2 DIT NTT over Goldilocks: per-level, fused 8-level (shmem), and +// batched (multi-column) variants. The caller runs `bit_reverse_permute` +// once before the first butterfly level. // // Input layout: bit-reversed-order coefficients (after `bit_reverse_permute`). // Output layout: natural-order evaluations — matches the CPU `evaluate_fft` contract. @@ -179,8 +180,9 @@ extern "C" __global__ void scalar_mul_batched(uint64_t *data, /// One DIT butterfly level. Thread `tid` (of n/2 total) owns exactly one /// butterfly pair (i0, i1 = i0 + half). Twiddle picked from the shared full -/// `tw` table at stride `n / block_size`. Kept for log_n < 8 where shmem -/// fusion is overkill. +/// `tw` table at stride `n / block_size`. Used for levels 0..7 when n < 256 +/// (shmem fusion needs at least 256 elements), and for levels >= 8 of any +/// size (above the shmem-fusion window). extern "C" __global__ void ntt_dit_level(uint64_t *x, const uint64_t *tw, uint64_t n, diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 03f0b67ca..efdaf1518 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -15,8 +15,9 @@ use math::field::traits::IsFFTField; use crate::Result; use crate::ntt::{twiddles_forward, twiddles_inverse}; -/// Reusable pinned host staging buffer. One per stream; the stream's LDE call -/// holds its buffer's lock across the D2H + memcpy-to-user-Vecs window. +/// Reusable pinned host staging buffer. Shared across all streams via a +/// `Mutex` (see `Backend::pinned_staging`); the LDE call holds the lock +/// across the D2H + memcpy-to-user-Vecs window. /// /// Allocated with `cuMemHostAlloc(flags=0)` — portable, non-write-combined, /// so both DMA writes from device and CPU reads into user Vecs run at full diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 22d62ca3e..5f1b95f34 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -517,10 +517,6 @@ pub fn coset_lde_batch_base_into( /// Skips the iFFT stage of [`coset_lde_batch_ext3_into`] (input is /// coefficients, not evaluations). Weights encode the coset shift: /// `weights[k] = offset^k` (NO 1/N because iFFT normalisation doesn't apply). -/// -/// Used by the stark prover to GPU-accelerate -/// `evaluate_polynomial_on_lde_domain` calls inside the -/// `number_of_parts > 2` branch of the composition-polynomial LDE. pub fn evaluate_poly_coset_batch_ext3_into( coefs: &[&[u64]], n: usize, diff --git a/crypto/math-cuda/tests/bench_quick.rs b/crypto/math-cuda/tests/bench_quick.rs index 41e9444c3..9e3883085 100644 --- a/crypto/math-cuda/tests/bench_quick.rs +++ b/crypto/math-cuda/tests/bench_quick.rs @@ -99,8 +99,9 @@ fn bench_lde_2_to_16_blowup_4() { #[test] #[ignore = "informal perf probe; run with --ignored"] fn bench_lde_multi_column_parallel() { - // Simulates the prover's Phase A: many columns processed via rayon. - // log_n = 16 keeps memory footprint manageable while still stressing streams. + // Simulates a multi-column workload processed via rayon: many columns + // dispatched concurrently to stress the stream pool. log_n = 16 keeps + // memory footprint manageable. let log_n = 16u32; let blowup = 4usize; let n = 1usize << log_n; @@ -156,8 +157,8 @@ fn bench_lde_multi_column_parallel() { #[test] #[ignore = "informal perf probe; run with --ignored"] fn bench_lde_batched_prover_scale() { - // Realistic large-table shape from the 1M-fib prover: ~1M rows, blowup 4, - // a few dozen columns. This is what actually runs in expand_columns_to_lde. + // Realistic large-table shape: ~1M rows, blowup 4, a few dozen columns. + // Exercises batched LDE at prover-scale sizes. let log_n = 20u32; // 1M rows let blowup = 4usize; let n = 1usize << log_n; From ae1ea57f1490f14e7b73904c9dfd6566157ca300 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 8 May 2026 17:38:45 -0300 Subject: [PATCH 07/22] fix --- crypto/math-cuda/src/lde.rs | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index e04e14054..0a64ff65e 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -1487,9 +1487,8 @@ pub fn evaluate_poly_coset_batch_ext3_into( } /// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- -/// interleaved LDE device buffer as a `GpuLdeExt3` handle. Lets R2 commit -/// and R4 DEEP composition read the composition-parts LDE without -/// re-H2D'ing. +/// interleaved LDE device buffer as a `GpuLdeExt3` handle so callers can +/// reuse the LDE without a re-H2D. pub fn evaluate_poly_coset_batch_ext3_into_keep( coefs: &[&[u64]], n: usize, @@ -1666,7 +1665,8 @@ fn evaluate_poly_coset_batch_ext3_into_inner( /// the LDE output, builds the R2 composition-polynomial Merkle tree on device /// (row-pair Keccak leaves at bit-reversed indices + pair-hash inner tree). /// -/// `merkle_nodes_out` must have byte length `(2 * lde_size - 1) * 32`. +/// Row-pair commit: each leaf hashes 2 rows, so the tree has `lde_size / 2` +/// leaves and `merkle_nodes_out` must have byte length `(lde_size - 1) * 32`. /// Requires `lde_size >= 2`. pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( coefs: &[&[u64]], @@ -1692,7 +1692,7 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( assert_eq!(o.len(), 3 * lde_size); } assert!(lde_size >= 2); - let total_nodes = 2 * lde_size - 1; + let total_nodes = lde_size - 1; assert_eq!(merkle_nodes_out.len(), total_nodes * 32); if n == 0 { return Ok(()); @@ -1872,15 +1872,11 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( } }); - // Copy pinned tree → caller nodes_out. `merkle_nodes_out.len() == - // total_nodes * 32` is oversized relative to our tight tree; we write - // only the first `tight_total_nodes * 32` bytes and the caller trims. - // Expose the tight byte count via the slice length so the caller can - // construct the MerkleTree with the right node count. - assert!(merkle_nodes_out.len() >= tight_total_nodes * 32); + // Copy pinned tree → caller nodes_out. + debug_assert_eq!(merkle_nodes_out.len(), tight_total_nodes * 32); const CHUNK: usize = 64 * 1024; let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; - merkle_nodes_out[..tight_total_nodes * 32] + merkle_nodes_out .par_chunks_mut(CHUNK) .enumerate() .for_each(|(i, dst)| { From 6d4c65427da360a86e782b0795acaa257dcf707b Mon Sep 17 00:00:00 2001 From: MauroFab Date: Tue, 12 May 2026 15:40:53 +0200 Subject: [PATCH 08/22] test(math-cuda): adversarial ext3, known-poly NTT, _into parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add ext3_sub_kernel + ext3_sub_u64 wrapper (test infrastructure; the ext3::sub device function was previously unreachable from Rust). - Add tests/ext3_edge.rs: 7 adversarial tests for ext3::mul dot3 overflow tracking — (p-1)^3, u64::MAX^3, non-canonical p representations, identity, and 98 base-field edge pairs in a/b/c slots. - Add tests/ext3_sub.rs: parity test for the new sub wrapper. - Add tests/ntt_known.rs: known-polynomial tests for p(x) = 1+x at sizes 16 and 256, and p(x) = x^(N/2) for the alternating ±1 pattern. - Add tests/lde_batch_into.rs: direct parity test for coset_lde_batch_base_into vs coset_lde_batch_base. 15 new tests total, all green on RTX 5090. --- crypto/math-cuda/kernels/arith.cu | 14 ++ crypto/math-cuda/src/device.rs | 2 + crypto/math-cuda/src/lib.rs | 29 ++++ crypto/math-cuda/tests/ext3_edge.rs | 176 +++++++++++++++++++++++ crypto/math-cuda/tests/ext3_sub.rs | 109 ++++++++++++++ crypto/math-cuda/tests/lde_batch_into.rs | 87 +++++++++++ crypto/math-cuda/tests/ntt_known.rs | 125 ++++++++++++++++ 7 files changed, 542 insertions(+) create mode 100644 crypto/math-cuda/tests/ext3_edge.rs create mode 100644 crypto/math-cuda/tests/ext3_sub.rs create mode 100644 crypto/math-cuda/tests/lde_batch_into.rs create mode 100644 crypto/math-cuda/tests/ntt_known.rs diff --git a/crypto/math-cuda/kernels/arith.cu b/crypto/math-cuda/kernels/arith.cu index ac73f7fd1..b1a6bb8ab 100644 --- a/crypto/math-cuda/kernels/arith.cu +++ b/crypto/math-cuda/kernels/arith.cu @@ -81,3 +81,17 @@ extern "C" __global__ void ext3_add_kernel(const uint64_t *a_int, c_int[tid*3 + 1] = r.b; c_int[tid*3 + 2] = r.c; } + +extern "C" __global__ void ext3_sub_kernel(const uint64_t *a_int, + const uint64_t *b_int, + uint64_t *c_int, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]); + ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]); + ext3::Fe3 r = ext3::sub(a, b); + c_int[tid*3 + 0] = r.a; + c_int[tid*3 + 1] = r.b; + c_int[tid*3 + 2] = r.c; +} diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index efdaf1518..af5105ee5 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -118,6 +118,7 @@ pub struct Backend { pub gl_neg: CudaFunction, pub ext3_mul: CudaFunction, pub ext3_add: CudaFunction, + pub ext3_sub: CudaFunction, // ntt.ptx pub bit_reverse_permute: CudaFunction, @@ -169,6 +170,7 @@ impl Backend { gl_neg: arith.load_function("gl_neg_kernel")?, ext3_mul: arith.load_function("ext3_mul_kernel")?, ext3_add: arith.load_function("ext3_add_kernel")?, + ext3_sub: arith.load_function("ext3_sub_kernel")?, bit_reverse_permute: ntt.load_function("bit_reverse_permute")?, ntt_dit_level: ntt.load_function("ntt_dit_level")?, ntt_dit_8_levels: ntt.load_function("ntt_dit_8_levels")?, diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs index 821f5bd3a..04ec08b1d 100644 --- a/crypto/math-cuda/src/lib.rs +++ b/crypto/math-cuda/src/lib.rs @@ -89,6 +89,35 @@ pub fn ext3_mul_u64(a: &[u64], b: &[u64]) -> Result> { Ok(out) } +/// Element-wise ext3 subtract. Test helper for `ext3::sub` in `ext3.cuh`. +pub fn ext3_sub_u64(a: &[u64], b: &[u64]) -> Result> { + assert_eq!(a.len(), b.len()); + assert_eq!(a.len() % 3, 0); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(3 * n)?; + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.ext3_sub) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + /// Element-wise ext3 add. pub fn ext3_add_u64(a: &[u64], b: &[u64]) -> Result> { assert_eq!(a.len(), b.len()); diff --git a/crypto/math-cuda/tests/ext3_edge.rs b/crypto/math-cuda/tests/ext3_edge.rs new file mode 100644 index 000000000..55b20bc92 --- /dev/null +++ b/crypto/math-cuda/tests/ext3_edge.rs @@ -0,0 +1,176 @@ +//! Adversarial edge-case parity tests for `ext3::mul` on the GPU. +//! +//! The CUDA `dot3` (kernels/ext3.cuh:62-98) manually tracks overflow when +//! summing three u128 products in split u64 hi/lo registers. The CPU +//! reference (`crypto/math/src/field/goldilocks.rs::dot_product_3` via +//! `Degree3GoldilocksExtensionField::mul`) uses native u128 and so reaches +//! the same answer via a totally different code path. These tests pick +//! inputs that maximally stress the overflow-count tracking, the +//! non-canonical input handling, and the identity/zero cases that random +//! tests are unlikely to cover. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +const P: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks prime +const EPSILON: u64 = 0xFFFF_FFFF; // 2^32 - 1 + +fn canon(x: u64) -> u64 { + GoldilocksField::canonical(&x) +} + +fn canon3(t: [u64; 3]) -> [u64; 3] { + [canon(t[0]), canon(t[1]), canon(t[2])] +} + +/// Run `pairs` of (a, b) through the GPU `ext3_mul_u64` and compare +/// canonical output to the CPU `Degree3GoldilocksExtensionField::mul`. +/// `label_fn(i)` produces a per-case label printed on failure. +fn assert_ext3_mul_pairs(pairs: &[([u64; 3], [u64; 3])], label_fn: impl Fn(usize) -> String) { + let mut a_raw = Vec::with_capacity(pairs.len() * 3); + let mut b_raw = Vec::with_capacity(pairs.len() * 3); + for (a, b) in pairs { + a_raw.extend_from_slice(a); + b_raw.extend_from_slice(b); + } + let gpu = math_cuda::ext3_mul_u64(&a_raw, &b_raw).expect("GPU ext3 mul launch"); + assert_eq!(gpu.len(), 3 * pairs.len()); + for (i, (a, b)) in pairs.iter().enumerate() { + // CPU reference. Build via from_raw so non-canonical inputs (like + // u64::MAX or p, p+1) are passed in untouched — matching what the + // GPU sees. + let ae = [Fp::from_raw(a[0]), Fp::from_raw(a[1]), Fp::from_raw(a[2])]; + let be = [Fp::from_raw(b[0]), Fp::from_raw(b[1]), Fp::from_raw(b[2])]; + let cpu = Degree3GoldilocksExtensionField::mul(&ae, &be); + let cpu_fp3 = Fp3::new(cpu); + let g = canon3([gpu[3 * i], gpu[3 * i + 1], gpu[3 * i + 2]]); + let c = canon3([ + *cpu_fp3.value()[0].value(), + *cpu_fp3.value()[1].value(), + *cpu_fp3.value()[2].value(), + ]); + assert_eq!( + g, + c, + "ext3 mul mismatch [{}]: a={:?} b={:?} gpu={:?} cpu={:?}", + label_fn(i), + a, + b, + g, + c, + ); + } +} + +#[test] +fn ext3_mul_max_canonical_inputs() { + // (p-1, p-1, p-1) * (p-1, p-1, p-1) — every base limb is (p-1), every + // dot3 product a_i*b_i = (p-1)^2 ~ 2^128, so summing three of them + // forces the overflow path twice on each component. + let m1 = [P - 1, P - 1, P - 1]; + let pairs = vec![(m1, m1)]; + assert_ext3_mul_pairs(&pairs, |_| "(p-1)^3 squared".into()); +} + +#[test] +fn ext3_mul_zero_cases() { + // (0,0,0) * (p-1,p-1,p-1) must be zero; covers the "all-zero a" path + // where every dot3 product is zero and no overflow occurs. + let z = [0u64, 0, 0]; + let m = [P - 1, P - 1, P - 1]; + let pairs = vec![(z, m), (m, z), (z, z)]; + assert_ext3_mul_pairs(&pairs, |i| format!("zero case {i}")); +} + +#[test] +fn ext3_mul_identity() { + // (1, 0, 0) * (a, b, c) == (a, b, c). One-component is multiplicative + // identity in Fp[w]/(w^3 - 2). Use varied b to also exercise small + // non-zero dot3 products. + let id = [1u64, 0, 0]; + let cases: Vec<([u64; 3], [u64; 3])> = vec![ + (id, [0, 0, 0]), + (id, [1, 0, 0]), + (id, [0, 1, 0]), + (id, [0, 0, 1]), + (id, [P - 1, 1, 2]), + (id, [123, 456, 789]), + (id, [P - 1, P - 1, P - 1]), + // Reverse order: (a, b, c) * (1, 0, 0). + ([0, 0, 0], id), + ([1, 0, 0], id), + ([0, 1, 0], id), + ([0, 0, 1], id), + ([P - 1, 1, 2], id), + ([123, 456, 789], id), + ]; + assert_ext3_mul_pairs(&cases, |i| format!("identity case {i}")); +} + +#[test] +fn ext3_mul_non_canonical_zero_p() { + // (p, p, p) is a non-canonical representation of (0, 0, 0). The CPU + // canonicalises before the dot3 in some code paths, the GPU does not. + // Either way the product must canonicalise to zero. + let p = [P, P, P]; + let some = [123u64, 456, 789]; + let pairs = vec![(p, p), (p, some), (some, p)]; + assert_ext3_mul_pairs(&pairs, |i| format!("non-canonical-p case {i}")); +} + +#[test] +fn ext3_mul_u64_max_all_overflow_paths() { + // (u64::MAX, u64::MAX, u64::MAX) for both operands. u64::MAX = p + (2^32 - 2), + // i.e., a non-canonical representation of (2^32 - 2) mod p. Every dot3 + // product is ~2^128 - small, so summing three of them is the hardest + // possible exercise of `over1`, `over2`, and the EPSILON^2 correction + // path in `dot3`. + let m = [u64::MAX, u64::MAX, u64::MAX]; + assert_ext3_mul_pairs(&[(m, m)], |_| "u64::MAX^3 squared".into()); +} + +#[test] +fn ext3_mul_base_edge_pairs_embedded() { + // Embed every base-field edge value as the `a` component of an ext3 + // element (so b = c = 0) and run all NxN pairs through GPU mul. This + // reduces to base-field multiplication on the a-component but + // exercises all the dot3 zero/non-zero combinations. + let edges: Vec = vec![0, 1, P - 1, P, P + 1, u64::MAX, EPSILON]; + let mut pairs = Vec::with_capacity(edges.len() * edges.len()); + for x in &edges { + for y in &edges { + pairs.push(([*x, 0, 0], [*y, 0, 0])); + } + } + assert_ext3_mul_pairs(&pairs, |i| { + let xi = i / edges.len(); + let yi = i % edges.len(); + format!("base-edge a={:#x} b={:#x}", edges[xi], edges[yi]) + }); +} + +#[test] +fn ext3_mul_base_edges_in_b_and_c_slots() { + // Same edge values, but placed in the b and c slots so the cross terms + // (which involve the `b1_2 = 2*y.b`, `b2_2 = 2*y.c` doubling) are also + // exercised at edge inputs. Non-canonical doubling of P-1 etc. is a + // path that random tests rarely hit. + let edges: Vec = vec![0, 1, P - 1, P, P + 1, u64::MAX, EPSILON]; + let mut pairs = Vec::with_capacity(edges.len() * edges.len()); + for x in &edges { + for y in &edges { + // Put edge values in b/c slots, a non-trivial. + pairs.push(([1, *x, *y], [1, *x, *y])); + } + } + assert_ext3_mul_pairs(&pairs, |i| { + let xi = i / edges.len(); + let yi = i % edges.len(); + format!("ext3-bc-edge b={:#x} c={:#x}", edges[xi], edges[yi]) + }); +} diff --git a/crypto/math-cuda/tests/ext3_sub.rs b/crypto/math-cuda/tests/ext3_sub.rs new file mode 100644 index 000000000..3593379d1 --- /dev/null +++ b/crypto/math-cuda/tests/ext3_sub.rs @@ -0,0 +1,109 @@ +//! Parity test for `ext3::sub` (kernels/ext3.cuh:41-45). This device +//! function is part of the public ext3 header but is not invoked by any +//! kernel in the PR — every other ext3 caller uses `mul`, `add`, or +//! `mul_base`. The PR's review test infrastructure adds an +//! `ext3_sub_kernel` so we can call it directly here for parity vs the +//! CPU `Degree3GoldilocksExtensionField::sub`. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +const N: usize = 10_000; +const P: u64 = 0xFFFF_FFFF_0000_0001; + +fn random_fp3s(seed: u64, count: usize) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..count) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() +} + +fn to_u64s(col: &[Fp3]) -> Vec { + let mut v = Vec::with_capacity(col.len() * 3); + for e in col { + v.push(*e.value()[0].value()); + v.push(*e.value()[1].value()); + v.push(*e.value()[2].value()); + } + v +} + +fn canon(x: u64) -> u64 { + GoldilocksField::canonical(&x) +} + +#[test] +fn ext3_sub_matches_cpu_random() { + let a = random_fp3s(101, N); + let b = random_fp3s(202, N); + let a_raw = to_u64s(&a); + let b_raw = to_u64s(&b); + let gpu = math_cuda::ext3_sub_u64(&a_raw, &b_raw).expect("GPU ext3 sub launch"); + assert_eq!(gpu.len(), 3 * N); + for i in 0..N { + let cpu = Degree3GoldilocksExtensionField::sub(a[i].value(), b[i].value()); + let cpu_fp3 = Fp3::new(cpu); + let g = [ + canon(gpu[3 * i]), + canon(gpu[3 * i + 1]), + canon(gpu[3 * i + 2]), + ]; + let c = [ + canon(*cpu_fp3.value()[0].value()), + canon(*cpu_fp3.value()[1].value()), + canon(*cpu_fp3.value()[2].value()), + ]; + assert_eq!(g, c, "ext3 sub mismatch at {i}"); + } +} + +#[test] +fn ext3_sub_edge_cases() { + // Underflow cases: a < b on each component, plus non-canonical p + // representations. + let cases: Vec<([u64; 3], [u64; 3])> = vec![ + ([0, 0, 0], [P - 1, P - 1, P - 1]), + ([1, 2, 3], [P - 1, P - 1, P - 1]), + ([P - 1, P - 1, P - 1], [0, 0, 0]), + ([P, P, P], [P, P, P]), // (0,0,0) - (0,0,0) + ([u64::MAX, u64::MAX, u64::MAX], [0, 0, 0]), + ([0, 0, 0], [u64::MAX, u64::MAX, u64::MAX]), + ]; + let mut a_raw = Vec::new(); + let mut b_raw = Vec::new(); + for (a, b) in &cases { + a_raw.extend_from_slice(a); + b_raw.extend_from_slice(b); + } + let gpu = math_cuda::ext3_sub_u64(&a_raw, &b_raw).expect("GPU ext3 sub launch"); + for (i, (a, b)) in cases.iter().enumerate() { + let ae = [Fp::from_raw(a[0]), Fp::from_raw(a[1]), Fp::from_raw(a[2])]; + let be = [Fp::from_raw(b[0]), Fp::from_raw(b[1]), Fp::from_raw(b[2])]; + let cpu = Degree3GoldilocksExtensionField::sub(&ae, &be); + let cpu_fp3 = Fp3::new(cpu); + let g = [ + canon(gpu[3 * i]), + canon(gpu[3 * i + 1]), + canon(gpu[3 * i + 2]), + ]; + let c = [ + canon(*cpu_fp3.value()[0].value()), + canon(*cpu_fp3.value()[1].value()), + canon(*cpu_fp3.value()[2].value()), + ]; + assert_eq!(g, c, "ext3 sub edge mismatch at {i}: a={a:?} b={b:?}"); + } +} diff --git a/crypto/math-cuda/tests/lde_batch_into.rs b/crypto/math-cuda/tests/lde_batch_into.rs new file mode 100644 index 000000000..19411d152 --- /dev/null +++ b/crypto/math-cuda/tests/lde_batch_into.rs @@ -0,0 +1,87 @@ +//! Direct parity test for `coset_lde_batch_base_into` (lde.rs:331), the +//! caller-allocated-buffer variant of `coset_lde_batch_base`. The two should +//! produce bit-identical canonical output for the same inputs; the only +//! difference between them is who owns the output Vec. +//! +//! This is otherwise covered indirectly through `coset_lde_batch_base_into_with_leaf_hash` +//! and similar, but the base `_into` variant ships as public API with no +//! direct test in the original PR. + +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *Fp::from(n as u64).inv().unwrap().value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn run_pair(log_n: u64, blowup: usize, m: usize, seed: u64) { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let columns: Vec> = (0..m) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + + let weights = coset_weights(n, 7); + + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + + // Reference: the existing batch-allocates-Vec API. Random-input tests + // already cross-validate this against the CPU single-column LDE in + // `lde_batch.rs`, so any divergence from it pinpoints `_into`. + let ref_out = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + assert_eq!(ref_out.len(), m); + + // Caller-allocated buffers. + let mut owned: Vec> = (0..m).map(|_| vec![0u64; lde_size]).collect(); + { + let mut outs: Vec<&mut [u64]> = owned.iter_mut().map(|v| v.as_mut_slice()).collect(); + math_cuda::lde::coset_lde_batch_base_into(&slices, blowup, &weights, &mut outs) + .expect("into variant"); + } + + for c in 0..m { + assert_eq!( + canon(&owned[c]), + canon(&ref_out[c]), + "_into vs _batch_base diverge at column {c}, log_n={log_n}, blowup={blowup}, m={m}" + ); + } +} + +#[test] +fn into_matches_batch_base_small() { + run_pair(8, 4, 4, 1); + run_pair(10, 4, 1, 2); + run_pair(10, 4, 16, 3); +} + +#[test] +fn into_matches_batch_base_medium() { + run_pair(14, 4, 8, 4); + run_pair(15, 4, 32, 5); +} + +#[test] +fn into_matches_batch_base_uneven_blowup() { + // Non-default blowup factors (still power of two) — confirms the + // _into variant respects blowup_factor identically. + run_pair(8, 2, 4, 6); + run_pair(8, 8, 4, 7); +} diff --git a/crypto/math-cuda/tests/ntt_known.rs b/crypto/math-cuda/tests/ntt_known.rs new file mode 100644 index 000000000..b2cb2a162 --- /dev/null +++ b/crypto/math-cuda/tests/ntt_known.rs @@ -0,0 +1,125 @@ +//! Known-answer NTT test. Random-input tests catch most bugs but can mask +//! systematic errors (sign flips, off-by-one twiddle indices, wrong-direction +//! butterflies) that would cancel under noise. This test picks a polynomial +//! with a known closed-form evaluation at every root of unity and compares +//! the GPU forward NTT to that reference, computed independently from any +//! FFT code path. + +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsFFTField, IsPrimeField}; + +type Fp = FieldElement; + +fn canon(x: u64) -> u64 { + GoldilocksField::canonical(&x) +} + +/// p(x) = 1 + x. NTT at size N is the vector [p(omega^0), p(omega^1), ..., +/// p(omega^(N-1))] for `omega` a primitive N-th root of unity. We compute +/// the reference by direct exponentiation of `omega` — no FFT involved — +/// so a bug in either the CPU or GPU NTT can't hide here. +#[test] +fn ntt_known_polynomial_x_plus_one_size_256() { + let log_n: u64 = 8; + let n: usize = 1 << log_n; + + // GoldilocksField uses bit-reversed coefficient layout for forward NTT: + // input coeffs at index `i` become `p(omega^bitrev(i))` — confirm by + // matching against the existing forward() API which random-input tests + // already validate against `Polynomial::evaluate_fft`. The known-poly + // value lets us catch systematic errors in that pipeline that random + // inputs miss. + + // Input coefficients: [1, 1, 0, 0, ..., 0]. (Natural order, lowest + // degree first — `Polynomial::new` and `math_cuda::ntt::forward` both + // expect this convention.) + let mut input = vec![0u64; n]; + input[0] = 1; + input[1] = 1; + + let gpu = math_cuda::ntt::forward(&input).expect("gpu ntt"); + assert_eq!(gpu.len(), n); + + // Reference: omega = primitive N-th root of unity in Goldilocks. + // p(omega^i) = 1 + omega^i. + let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("root of unity"); + let one = Fp::from_raw(1); + + let mut expected = Vec::with_capacity(n); + let mut omega_i = one.clone(); // omega^0 + for _ in 0..n { + let val = &one + &omega_i; + expected.push(*val.value()); + omega_i = &omega_i * ω + } + + for i in 0..n { + let g = canon(gpu[i]); + let e = canon(expected[i]); + if g != e { + panic!( + "p(omega^{i}) mismatch: gpu canon {:#018x}, expected canon {:#018x} (omega^{i} computed independently of any FFT)", + g, e, + ); + } + } +} + +/// Same idea, smaller: p(x) = 1 + x at size 2^4 = 16. A failure at this +/// size with passes at larger sizes (or vice-versa) would point at a +/// boundary bug between the recursive base case and the 8-level +/// shared-memory fused step in the GPU NTT. +#[test] +fn ntt_known_polynomial_x_plus_one_size_16() { + let log_n: u64 = 4; + let n: usize = 1 << log_n; + + let mut input = vec![0u64; n]; + input[0] = 1; + input[1] = 1; + + let gpu = math_cuda::ntt::forward(&input).expect("gpu ntt"); + + let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("root"); + let one = Fp::from_raw(1); + let mut omega_i = one.clone(); + for i in 0..n { + let exp = &one + &omega_i; + assert_eq!( + canon(gpu[i]), + canon(*exp.value()), + "p(omega^{i}) mismatch at size 16" + ); + omega_i = &omega_i * ω + } +} + +/// p(x) = x^k for k = N/2. p(omega^i) = omega^(k*i). With k = N/2, +/// omega^(k*i) = (-1)^i since omega^(N/2) = -1 in any field with a +/// primitive N-th root of unity. So evaluations alternate +1, -1, +1, -1. +/// This is a strong test of twiddle-index direction and sign. +#[test] +fn ntt_known_polynomial_x_half_alternating() { + let log_n: u64 = 8; + let n: usize = 1 << log_n; + let k = n / 2; + + let mut input = vec![0u64; n]; + input[k] = 1; // p(x) = x^(N/2) + + let gpu = math_cuda::ntt::forward(&input).expect("gpu ntt"); + + // Expected: omega^(k*i) for i = 0..N, which is (omega^k)^i = (-1)^i. + // canonical(+1) = 1; canonical(-1) = p - 1. + let p_minus_one = 0xFFFF_FFFF_0000_0001u64 - 1; + for i in 0..n { + let exp = if i % 2 == 0 { 1u64 } else { p_minus_one }; + assert_eq!( + canon(gpu[i]), + exp, + "x^(N/2) NTT alternation mismatch at i={i}: got {:#018x}", + canon(gpu[i]) + ); + } +} From 8c529718ee6b8f0ae4287c3cab39b3dd552783dc Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 13 May 2026 11:59:28 -0300 Subject: [PATCH 09/22] address PR review feedback --- Makefile | 10 +- crypto/math-cuda/Cargo.toml | 3 +- crypto/math-cuda/build.rs | 57 +++++++-- crypto/math-cuda/kernels/ntt.cu | 2 +- crypto/math-cuda/src/device.rs | 29 ++++- crypto/math-cuda/src/lde.rs | 152 +++++++++-------------- crypto/math-cuda/src/lib.rs | 10 +- crypto/math-cuda/src/ntt.rs | 21 +++- crypto/math-cuda/tests/ext3_edge.rs | 2 +- crypto/math-cuda/tests/ext3_sub.rs | 4 +- crypto/math-cuda/tests/lde.rs | 4 +- crypto/math-cuda/tests/lde_batch_into.rs | 4 +- crypto/math-cuda/tests/ntt.rs | 2 - crypto/math-cuda/tests/ntt_known.rs | 26 ++-- 14 files changed, 184 insertions(+), 142 deletions(-) diff --git a/Makefile b/Makefile index c02bffc49..10be4f07a 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY: deps deps-linux deps-macos prepare-test-data compile-programs-asm compile-programs-rust compile-bench \ compile-programs clean-asm clean-rust clean-bench clean-shared clean test test-asm test-no-compile \ test-asm-no-compile test-rust test-rust-no-compile test-executor flamegraph-prover \ -test-fast test-prover test-prover-all build check clippy fmt lint +test-fast test-prover test-prover-all test-math-cuda bench-math-cuda build check clippy fmt lint UNAME := $(shell uname) @@ -166,6 +166,14 @@ test-prover-all: test-prover-debug: cargo test -p lambda-vm-prover --features debug-checks -- --nocapture +# math-cuda parity tests (requires NVIDIA GPU + nvcc) +test-math-cuda: + cargo test -p math-cuda --release + +# math-cuda quick microbench (median of 10 runs) +bench-math-cuda: + cargo test -p math-cuda --release --test bench_quick -- --ignored --nocapture + # Build all build: cargo build --workspace diff --git a/crypto/math-cuda/Cargo.toml b/crypto/math-cuda/Cargo.toml index 8c22d1110..0990dd6d6 100644 --- a/crypto/math-cuda/Cargo.toml +++ b/crypto/math-cuda/Cargo.toml @@ -9,7 +9,8 @@ cudarc = { version = "0.19", default-features = false, features = [ "driver", "nvrtc", "std", - "cuda-12080", + "cuda-version-from-build-system", + "fallback-latest", "dynamic-loading", ] } math = { path = "../math" } diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index 43edaa31c..cf541b5fd 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -14,6 +14,38 @@ fn nvcc_path() -> PathBuf { cuda_home().join("bin").join("nvcc") } +/// Query `nvidia-smi` for the local GPU's compute capability (e.g. "12.0" +/// for Blackwell). Returns a `compute_XX` target on success, falling back +/// to `compute_89` (Ada) when no GPU is visible or the query fails. +fn detect_arch() -> String { + const FALLBACK: &str = "compute_89"; + let output = match Command::new("nvidia-smi") + .args(["--query-gpu=compute_cap", "--format=csv,noheader"]) + .output() + { + Ok(o) if o.status.success() => o, + _ => return FALLBACK.to_string(), + }; + let line = match std::str::from_utf8(&output.stdout) { + Ok(s) => s, + Err(_) => return FALLBACK.to_string(), + }; + // First line, first comma-separated value (covers multi-GPU hosts). + let cap = match line.lines().next() { + Some(l) => l.split(',').next().unwrap_or("").trim(), + None => return FALLBACK.to_string(), + }; + let (major, minor) = match cap.split_once('.') { + Some((m, n)) => (m.trim(), n.trim()), + None => return FALLBACK.to_string(), + }; + if major.chars().all(|c| c.is_ascii_digit()) && minor.chars().all(|c| c.is_ascii_digit()) { + format!("compute_{major}{minor}") + } else { + FALLBACK.to_string() + } +} + fn compile_ptx(src: &str, out_name: &str, have_nvcc: bool) { let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); @@ -25,20 +57,22 @@ fn compile_ptx(src: &str, out_name: &str, have_nvcc: bool) { println!("cargo:rerun-if-env-changed=CUDA_PATH"); println!("cargo:rerun-if-env-changed=CUDARC_NVCC_ARCH"); - // No nvcc on PATH → emit an empty PTX stub so the crate still compiles. - // include_str! in src/device.rs needs the file to exist at build time. - // Any runtime kernel call will then panic from cudarc when loading the - // empty module — which is the right failure mode (we can't run GPU code - // without nvcc on the build host anyway). + // When nvcc is missing from PATH, emit an empty PTX stub so the crate + // still compiles. include_str! in src/device.rs needs the file to exist + // at build time. Any runtime kernel call panics in cudarc when loading + // the empty module. We can't run GPU code without nvcc on the build + // host anyway. if !have_nvcc { fs::write(&out_path, "").expect("failed to write empty PTX stub"); return; } // Emit PTX for a virtual architecture; the CUDA driver JIT-compiles it for the - // actual GPU at load time, so one PTX works across Ada/Hopper/Blackwell. Override - // with CUDARC_NVCC_ARCH to pin a specific compute capability. - let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| "compute_89".to_string()); + // actual GPU at load time. Override with CUDARC_NVCC_ARCH to pin a specific + // compute capability. If unset, try `nvidia-smi` to match the host GPU + // (avoids JIT failures like nvcc-13.0 PTX rejected on Blackwell drivers); + // fall back to compute_89 (Ada) when detection fails. + let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| detect_arch()); let status = Command::new(nvcc_path()) .args(["--ptx", "-O3", "-std=c++17", "-arch", &arch, "-o"]) @@ -53,13 +87,14 @@ fn compile_ptx(src: &str, out_name: &str, have_nvcc: bool) { } fn main() { - // Headers are not compiled; emit rerun-if-changed so edits trigger rebuilds. + // Headers aren't compiled, so emit rerun-if-changed to rebuild on + // header edits. println!("cargo:rerun-if-changed=kernels/goldilocks.cuh"); println!("cargo:rerun-if-changed=kernels/ext3.cuh"); // Probe for nvcc once. Workspace consumers (clippy, fmt, CPU-only test - // runners) build math-cuda incidentally without using its kernels; allow - // those to succeed by stubbing out PTX when nvcc is unavailable. + // runners) build math-cuda incidentally without using its kernels. Stub + // out PTX when nvcc is unavailable so those builds succeed. let have_nvcc = Command::new(nvcc_path()) .arg("--version") .output() diff --git a/crypto/math-cuda/kernels/ntt.cu b/crypto/math-cuda/kernels/ntt.cu index a8e50d4e1..cf5e1df2c 100644 --- a/crypto/math-cuda/kernels/ntt.cu +++ b/crypto/math-cuda/kernels/ntt.cu @@ -218,7 +218,7 @@ extern "C" __global__ void ntt_dit_level(uint64_t *x, /// without writing to global memory between them — cuts DRAM traffic by up /// to 8× vs the per-level kernel. /// -/// `base_step` selects which 8-level window this launch handles (0, 8, 16…). +/// `base_step` selects which 8-level window this launch handles (0, 8, 16, ...). /// For levels 0–7 the implicit DIT element layout already places all pair /// mates inside the same 256-block; for higher base_step we remap the loaded /// row so pair mates land in consecutive shared-memory slots. diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index af5105ee5..b3c4a1c56 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -158,8 +158,9 @@ impl Backend { // not part of the pool that callers rotate through. let util_stream = ctx.new_stream()?; - // Goldilocks TWO_ADICITY is 32, so log_n ≤ 32 covers every LDE size - // the prover can produce. Overshoot by one for safety. + // Cache is indexed by log_n. Valid range is [0, TWO_ADICITY] since + // Goldilocks has roots of unity for orders 2^0..=2^TWO_ADICITY only. + // Length = TWO_ADICITY + 1 to allow indexing at log_n = TWO_ADICITY. let max_log = GoldilocksField::TWO_ADICITY as usize + 1; Ok(Self { @@ -219,6 +220,14 @@ impl Backend { } else { &self.inv_twiddles }; + // Cache is sized TWO_ADICITY + 1 in `Backend::init`. Callers derive + // log_n from `trailing_zeros` of valid Goldilocks domain sizes so it + // must stay in range; assert in debug to catch regressions. + debug_assert!( + log_n <= GoldilocksField::TWO_ADICITY, + "log_n {log_n} exceeds Goldilocks TWO_ADICITY ({})", + GoldilocksField::TWO_ADICITY, + ); { let guard = cache.lock().unwrap(); if let Some(t) = &guard[idx] { @@ -244,7 +253,19 @@ impl Backend { } } -pub fn backend() -> &'static Backend { +/// Returns the process-wide CUDA backend, initialising it on first call. +/// +/// Returns `Err` when CUDA initialisation fails (no driver, no GPU, PTX load +/// failure). Initialisation is retried on the next call until one succeeds — +/// only a successful `Backend` is cached. The race window where two threads +/// init concurrently is harmless: at most one extra `Backend::init()` runs +/// and the loser is dropped. +pub fn backend() -> Result<&'static Backend> { static BACKEND: OnceLock = OnceLock::new(); - BACKEND.get_or_init(|| Backend::init().expect("failed to initialise CUDA backend")) + if let Some(b) = BACKEND.get() { + return Ok(b); + } + let b = Backend::init()?; + let _ = BACKEND.set(b); + Ok(BACKEND.get().expect("backend just initialised")) } diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 5f1b95f34..719a50931 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -4,7 +4,7 @@ //! Input : N evaluations (natural order) of a poly on the standard subgroup, //! plus coset weights (size N). The weights include the `1/N` iFFT //! normalisation, matching the `LdeTwiddles::coset_weights` format at -//! `crypto/stark/src/prover.rs:248` — i.e. `weights[i] = g^i / N`. +//! `crypto/stark/src/prover.rs` — i.e. `weights[i] = g^i / N`. //! Output : N*blowup_factor evaluations (natural order) on the coset. //! //! On-device steps, picks a stream from the shared pool so rayon-parallel @@ -13,11 +13,24 @@ use std::sync::Arc; use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; +use rayon::prelude::*; use crate::Result; use crate::device::backend; use crate::ntt::run_ntt_body; +/// Goldilocks `TWO_ADICITY = 32` puts the theoretical domain ceiling at +/// `2^32`, where a downstream `as u32` cast would silently truncate to zero +/// and the corresponding kernel launch would do nothing. Assert at each +/// public entry point before any cast that depends on it. +#[inline] +fn assert_u32_domain(n: usize, what: &str) { + assert!( + n <= u32::MAX as usize, + "{what}: {n} exceeds u32 range — kernel grid would silently truncate", + ); +} + /// Handle to a base-field LDE kept live on device after R1 commit. /// Layout: `m` columns, each `lde_size` u64s, column `c` at byte offset /// `c * lde_size * 8` within `buf`. Freed when `buf` Arc drops. @@ -40,20 +53,23 @@ pub struct GpuLdeExt3 { pub fn coset_lde_base(evals: &[u64], blowup_factor: usize, weights: &[u64]) -> Result> { let n = evals.len(); + // Empty input must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(Vec::new()); + } assert!(n.is_power_of_two(), "evals length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match evals"); assert!( blowup_factor.is_power_of_two(), "blowup must be power of two" ); - if n == 0 { - return Ok(Vec::new()); - } let lde_size = n * blowup_factor; + assert_u32_domain(lde_size, "coset_lde_base lde_size"); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); // Device buffer of lde_size, zero-padded tail, first N filled by copy. @@ -127,6 +143,11 @@ pub fn coset_lde_batch_base( } let m = columns.len(); let n = columns[0].len(); + // Empty columns must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(vec![Vec::new(); m]); + } assert!(n.is_power_of_two(), "column length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match column length"); assert!( @@ -136,33 +157,15 @@ pub fn coset_lde_batch_base( for c in columns.iter() { assert_eq!(c.len(), n, "all columns must be the same size"); } - - if n == 0 { - return Ok(vec![Vec::new(); m]); - } let lde_size = n * blowup_factor; + assert_u32_domain(lde_size, "coset_lde_batch_base lde_size"); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); - let debug_phases = std::env::var("MATH_CUDA_PHASE_TIMING").is_ok(); - let t_start = if debug_phases { - Some(std::time::Instant::now()) - } else { - None - }; - let phase = |label: &str, prev: &mut Option| { - if let Some(p) = prev.as_ref() { - let now = std::time::Instant::now(); - eprintln!(" [{:>6.2} ms] {}", (now - *p).as_secs_f64() * 1e3, label); - *prev = Some(now); - } - }; - let mut last = t_start; - // Pinned staging. Lock and grow to max(m*n for upload, m*lde_size for // download). Holding the guard across the whole call serialises concurrent // batched calls that happened to hash to the same stream slot, but that's @@ -171,14 +174,10 @@ pub fn coset_lde_batch_base( staging.ensure_capacity(m * lde_size, &be.ctx)?; // SAFETY: staging is locked, the slice alias ends before we unlock. let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - if debug_phases { - phase("staging lock + grow", &mut last); - } // Pack columns into first m*n slots of the pinned buffer. Parallel: pinned - // writes are DRAM-bandwidth bound, saturates at ~8 cores on modern - // hardware, so rayon shaves 20+ ms at prover scale. - use rayon::prelude::*; + // writes are DRAM-bandwidth bound, so rayon spreads the cost across CPU + // cores. let pinned_base_ptr = pinned.as_mut_ptr() as usize; columns.par_iter().enumerate().for_each(|(c, col)| { // SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of @@ -188,35 +187,20 @@ pub fn coset_lde_batch_base( unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; dst.copy_from_slice(col); }); - if debug_phases { - phase("host pack (pinned, rayon)", &mut last); - } // Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size) // tail of each column is already the zero-pad the CPU path does. let mut buf = stream.alloc_zeros::(m * lde_size)?; - if debug_phases { - stream.synchronize()?; - phase("alloc_zeros", &mut last); - } // One memcpy per column from the pinned buffer into the strided slots. // The pinned source hits PCIe line-rate. for c in 0..m { let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; } - if debug_phases { - stream.synchronize()?; - phase("H2D cols (pinned)", &mut last); - } let inv_tw = be.inv_twiddles_for(log_n)?; let fwd_tw = be.fwd_twiddles_for(log_lde)?; let weights_dev = stream.clone_htod(weights)?; - if debug_phases { - stream.synchronize()?; - phase("twiddles + weights", &mut last); - } let n_u64 = n as u64; let lde_u64 = lde_size as u64; @@ -242,10 +226,6 @@ pub fn coset_lde_batch_base( } } - if debug_phases { - stream.synchronize()?; - phase("bit_reverse N", &mut last); - } // === 2. iNTT body over all columns === run_batched_ntt_body( stream.as_ref(), @@ -256,10 +236,6 @@ pub fn coset_lde_batch_base( col_stride_u64, m_u32, )?; - if debug_phases { - stream.synchronize()?; - phase("iNTT body", &mut last); - } // === 3. Pointwise multiply by coset weights (includes 1/N) === { @@ -299,10 +275,6 @@ pub fn coset_lde_batch_base( } } - if debug_phases { - stream.synchronize()?; - phase("pointwise + bit_reverse LDE", &mut last); - } // === 5. Forward NTT on full LDE of every column === run_batched_ntt_body( stream.as_ref(), @@ -313,29 +285,22 @@ pub fn coset_lde_batch_base( col_stride_u64, m_u32, )?; - if debug_phases { - stream.synchronize()?; - phase("forward NTT body", &mut last); - } // Single big D2H into the reusable pinned staging buffer — pinned, one // call to the driver, saturates PCIe. stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; stream.synchronize()?; - if debug_phases { - phase("D2H (one shot into pinned)", &mut last); - } // Split pinned → per-column Vecs. The first write to each virgin - // Vec page-faults, which can dominate total time (~75 ms for 128 MB). - // Parallelise so the fault cost spreads across CPU cores. + // Vec page-faults, which can dominate total time. Parallelise so the + // fault cost spreads across CPU cores. let pinned_ptr = pinned.as_ptr() as usize; let out: Vec> = (0..m) .into_par_iter() .map(|c| { - // set_len skips the O(N) zero-init that vec![0; n] would do - // (saves ~75 ms per 128 MB at prover scale). copy_from_slice - // below writes every slot before any reader sees the Vec. + // set_len skips the O(N) zero-init that vec![0; n] would do. + // copy_from_slice below writes every slot before any reader + // sees the Vec. #[allow(clippy::uninit_vec)] let mut v = { let mut v = Vec::::with_capacity(lde_size); @@ -349,18 +314,15 @@ pub fn coset_lde_batch_base( v }) .collect(); - if debug_phases { - phase("copy out (rayon pinned → Vecs)", &mut last); - } drop(staging); Ok(out) } /// Like `coset_lde_batch_base` but writes directly into caller-provided /// output slices instead of allocating fresh `Vec`s. Each output slice -/// must already have length `n * blowup_factor`. Saves ~50–100 ms of pageable -/// allocator work + page faults at prover scale because the caller's Vecs -/// have been sized once and are reused across calls. +/// must already have length `n * blowup_factor`. Avoids pageable allocator +/// work and page faults at prover scale because the caller's Vecs have been +/// sized once and are reused across calls. pub fn coset_lde_batch_base_into( columns: &[&[u64]], blowup_factor: usize, @@ -373,6 +335,11 @@ pub fn coset_lde_batch_base_into( let m = columns.len(); assert_eq!(outputs.len(), m, "outputs must match columns count"); let n = columns[0].len(); + // Empty columns must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(()); + } assert!(n.is_power_of_two(), "column length must be a power of two"); assert_eq!(weights.len(), n, "weights length must match column length"); assert!( @@ -386,13 +353,11 @@ pub fn coset_lde_batch_base_into( for o in outputs.iter() { assert_eq!(o.len(), lde_size, "each output must be lde_size"); } - if n == 0 { - return Ok(()); - } + assert_u32_domain(lde_size, "coset_lde_batch_base_into lde_size"); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); @@ -496,7 +461,6 @@ pub fn coset_lde_batch_base_into( // Parallel copy pinned → caller outputs. Caller's Vecs may still fault // on first write; we spread that cost across rayon cores. #[allow(unused_imports)] - use rayon::prelude::*; let pinned_ptr = pinned.as_ptr() as usize; outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { let src = unsafe { @@ -558,6 +522,11 @@ fn evaluate_poly_coset_batch_ext3_into_inner( } let m = coefs.len(); assert_eq!(outputs.len(), m); + // Empty domain must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(None); + } assert!(n.is_power_of_two()); assert_eq!(weights.len(), n); assert!(blowup_factor.is_power_of_two()); @@ -568,13 +537,11 @@ fn evaluate_poly_coset_batch_ext3_into_inner( for o in outputs.iter() { assert_eq!(o.len(), 3 * lde_size); } - if n == 0 { - return Ok(None); - } + assert_u32_domain(lde_size, "evaluate_poly_coset_batch_ext3_into lde_size"); let log_lde = lde_size.trailing_zeros() as u64; let mb = 3 * m; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); @@ -582,7 +549,6 @@ fn evaluate_poly_coset_batch_ext3_into_inner( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - use rayon::prelude::*; let pinned_ptr_u = pinned.as_mut_ptr() as usize; coefs.par_iter().enumerate().for_each(|(c, col)| { let slab_a = unsafe { @@ -736,6 +702,11 @@ pub fn coset_lde_batch_ext3_into( } let m = columns.len(); assert_eq!(outputs.len(), m, "outputs must match columns count"); + // Empty domain must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(()); + } assert!(n.is_power_of_two(), "n must be a power of two"); assert_eq!(weights.len(), n, "weights length must match n"); assert!( @@ -749,16 +720,14 @@ pub fn coset_lde_batch_ext3_into( for o in outputs.iter() { assert_eq!(o.len(), 3 * lde_size, "each output must be 3*lde_size u64s"); } - if n == 0 { - return Ok(()); - } + assert_u32_domain(lde_size, "coset_lde_batch_ext3_into lde_size"); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; // 3 base slabs per ext3 column; slab index `c*3 + k` holds component `k`. let mb = 3 * m; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); @@ -769,7 +738,6 @@ pub fn coset_lde_batch_ext3_into( // Pack: for each ext3 column, write 3 base slabs into pinned. The slab // for column c, component k lives at `pinned[(c*3 + k)*n .. (c*3+k)*n + n]`. // We de-interleave from the interleaved `[a, b, c, a, b, c, ...]` input. - use rayon::prelude::*; let pinned_ptr_u = pinned.as_mut_ptr() as usize; columns.par_iter().enumerate().for_each(|(c, col)| { // SAFETY: disjoint regions per c; staging lock held. @@ -925,7 +893,7 @@ fn run_batched_ntt_body( col_stride: u64, m: u32, ) -> Result<()> { - let be = backend(); + let be = backend()?; let fused = core::cmp::min(log_n, 8); if fused >= 8 { let grid_x = (n / 256) as u32; diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs index 04ec08b1d..7731fe4c3 100644 --- a/crypto/math-cuda/src/lib.rs +++ b/crypto/math-cuda/src/lib.rs @@ -37,7 +37,7 @@ pub fn gl_neg_u64(a: &[u64]) -> Result> { if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let a_dev = stream.clone_htod(a)?; @@ -68,7 +68,7 @@ pub fn ext3_mul_u64(a: &[u64], b: &[u64]) -> Result> { if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let a_dev = stream.clone_htod(a)?; let b_dev = stream.clone_htod(b)?; @@ -97,7 +97,7 @@ pub fn ext3_sub_u64(a: &[u64], b: &[u64]) -> Result> { if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let a_dev = stream.clone_htod(a)?; let b_dev = stream.clone_htod(b)?; @@ -126,7 +126,7 @@ pub fn ext3_add_u64(a: &[u64], b: &[u64]) -> Result> { if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let a_dev = stream.clone_htod(a)?; let b_dev = stream.clone_htod(b)?; @@ -156,7 +156,7 @@ where if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let a_dev = stream.clone_htod(a)?; diff --git a/crypto/math-cuda/src/ntt.rs b/crypto/math-cuda/src/ntt.rs index 31333c3ae..f5accc3b1 100644 --- a/crypto/math-cuda/src/ntt.rs +++ b/crypto/math-cuda/src/ntt.rs @@ -14,10 +14,13 @@ use math::field::traits::{IsFFTField, IsField}; use crate::Result; use crate::device::backend; -/// Host-side twiddle table: `[ω^0, ω^1, …, ω^{n/2-1}]` where ω is the +/// Host-side twiddle table: `[ω^0, ω^1, ..., ω^{n/2-1}]` where ω is the /// primitive n-th root of unity. Exposed for `device::Backend::cached_twiddles` /// and for direct use in tests / benches. pub fn twiddles_forward(log_n: u64) -> Vec { + // Smallest meaningful NTT is size 2 (log_n = 1); size-1 has nothing to + // twiddle. The shift `1 << (log_n - 1)` underflows for log_n = 0. + assert!(log_n >= 1, "twiddles_forward: log_n must be >= 1"); let omega = *GoldilocksField::get_primitive_root_of_unity(log_n) .expect("primitive root") .value(); @@ -26,6 +29,7 @@ pub fn twiddles_forward(log_n: u64) -> Vec { /// Inverse twiddle table: `[ω^{-i}]` for i in [0, n/2). pub fn twiddles_inverse(log_n: u64) -> Vec { + assert!(log_n >= 1, "twiddles_inverse: log_n must be >= 1"); let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("primitive root"); let omega_inv = FieldElement::::inv(&omega).expect("inverse"); powers_of(*omega_inv.value(), 1usize << (log_n - 1)) @@ -56,13 +60,20 @@ pub fn inverse(evals: &[u64]) -> Result> { fn ntt_inplace(input: &[u64], forward: bool) -> Result> { let n = input.len(); - assert!(n.is_power_of_two(), "ntt length must be a power of two"); + // Empty / size-1 has no work to do. `is_power_of_two()` returns false for + // 0, so this branch must come before the assert to avoid panicking on + // empty input. if n <= 1 { return Ok(input.to_vec()); } + assert!(n.is_power_of_two(), "ntt length must be a power of two"); + assert!( + n <= u32::MAX as usize, + "ntt length {n} exceeds u32 range — kernel grid would silently truncate", + ); let log_n = n.trailing_zeros() as u64; - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let mut x_dev = stream.clone_htod(input)?; @@ -117,7 +128,7 @@ pub(crate) fn run_ntt_body( n: u64, log_n: u64, ) -> Result<()> { - let be = backend(); + let be = backend()?; // Levels 0..min(log_n, 8): one shmem-fused launch. Loads are fully // coalesced (base_step=0 → `row = tid`) and 8 butterfly rounds stay on // chip. This is the big DRAM-bandwidth win. @@ -183,7 +194,7 @@ pub fn pointwise_mul(x: &[u64], w: &[u64]) -> Result> { if n == 0 { return Ok(Vec::new()); } - let be = backend(); + let be = backend()?; let stream = be.next_stream(); let mut x_dev = stream.clone_htod(x)?; diff --git a/crypto/math-cuda/tests/ext3_edge.rs b/crypto/math-cuda/tests/ext3_edge.rs index 55b20bc92..f298fe884 100644 --- a/crypto/math-cuda/tests/ext3_edge.rs +++ b/crypto/math-cuda/tests/ext3_edge.rs @@ -1,6 +1,6 @@ //! Adversarial edge-case parity tests for `ext3::mul` on the GPU. //! -//! The CUDA `dot3` (kernels/ext3.cuh:62-98) manually tracks overflow when +//! The CUDA `dot3` in kernels/ext3.cuh manually tracks overflow when //! summing three u128 products in split u64 hi/lo registers. The CPU //! reference (`crypto/math/src/field/goldilocks.rs::dot_product_3` via //! `Degree3GoldilocksExtensionField::mul`) uses native u128 and so reaches diff --git a/crypto/math-cuda/tests/ext3_sub.rs b/crypto/math-cuda/tests/ext3_sub.rs index 3593379d1..1e94682a0 100644 --- a/crypto/math-cuda/tests/ext3_sub.rs +++ b/crypto/math-cuda/tests/ext3_sub.rs @@ -1,4 +1,4 @@ -//! Parity test for `ext3::sub` (kernels/ext3.cuh:41-45). This device +//! Parity test for `ext3::sub` in kernels/ext3.cuh. This device //! function is part of the public ext3 header but is not invoked by any //! kernel in the PR — every other ext3 caller uses `mul`, `add`, or //! `mul_base`. The PR's review test infrastructure adds an @@ -78,7 +78,7 @@ fn ext3_sub_edge_cases() { ([0, 0, 0], [P - 1, P - 1, P - 1]), ([1, 2, 3], [P - 1, P - 1, P - 1]), ([P - 1, P - 1, P - 1], [0, 0, 0]), - ([P, P, P], [P, P, P]), // (0,0,0) - (0,0,0) + ([P, P, P], [P, P, P]), // (0,0,0) - (0,0,0) ([u64::MAX, u64::MAX, u64::MAX], [0, 0, 0]), ([0, 0, 0], [u64::MAX, u64::MAX, u64::MAX]), ]; diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs index 33f98f9ae..110997e6e 100644 --- a/crypto/math-cuda/tests/lde.rs +++ b/crypto/math-cuda/tests/lde.rs @@ -12,8 +12,8 @@ use rand_chacha::ChaCha8Rng; type Fp = FieldElement; -/// Build the coset weights `[1/N, g/N, g²/N, …, g^{n-1}/N]` — this is the -/// layout `crypto/stark/src/prover.rs:248` uses, with `1/N` pre-folded into the +/// Build the coset weights `[1/N, g/N, g²/N, ..., g^{n-1}/N]` — this is the +/// layout `crypto/stark/src/prover.rs` uses, with `1/N` pre-folded into the /// first coefficient so the iFFT step does not need a separate scaling pass. fn coset_weights(n: usize, coset_offset: u64) -> Vec { let inv_n_fe = FieldElement::::from(n as u64) diff --git a/crypto/math-cuda/tests/lde_batch_into.rs b/crypto/math-cuda/tests/lde_batch_into.rs index 19411d152..c3d25adbf 100644 --- a/crypto/math-cuda/tests/lde_batch_into.rs +++ b/crypto/math-cuda/tests/lde_batch_into.rs @@ -1,4 +1,4 @@ -//! Direct parity test for `coset_lde_batch_base_into` (lde.rs:331), the +//! Direct parity test for `coset_lde_batch_base_into` (lde.rs), the //! caller-allocated-buffer variant of `coset_lde_batch_base`. The two should //! produce bit-identical canonical output for the same inputs; the only //! difference between them is who owns the output Vec. @@ -27,7 +27,7 @@ fn coset_weights(n: usize, g: u64) -> Vec { } fn canon(xs: &[u64]) -> Vec { - xs.iter().map(|x| GoldilocksField::canonical(x)).collect() + xs.iter().map(GoldilocksField::canonical).collect() } fn run_pair(log_n: u64, blowup: usize, m: usize, seed: u64) { diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs index 17a556c74..c02892204 100644 --- a/crypto/math-cuda/tests/ntt.rs +++ b/crypto/math-cuda/tests/ntt.rs @@ -61,13 +61,11 @@ fn ntt_sizes_medium() { #[test] fn ntt_size_2_to_20() { - // The hot LDE size. One seed is enough; any mismatch screams loudly. assert_ntt_match(20, 0xDEAD); } #[test] fn ntt_trivial_sizes() { - // Power-of-two below the interesting range — should still pass. assert_ntt_match(1, 1); assert_ntt_match(2, 2); assert_ntt_match(3, 3); diff --git a/crypto/math-cuda/tests/ntt_known.rs b/crypto/math-cuda/tests/ntt_known.rs index b2cb2a162..f0a7b9f5c 100644 --- a/crypto/math-cuda/tests/ntt_known.rs +++ b/crypto/math-cuda/tests/ntt_known.rs @@ -17,7 +17,7 @@ fn canon(x: u64) -> u64 { /// p(x) = 1 + x. NTT at size N is the vector [p(omega^0), p(omega^1), ..., /// p(omega^(N-1))] for `omega` a primitive N-th root of unity. We compute -/// the reference by direct exponentiation of `omega` — no FFT involved — +/// the reference by direct exponentiation of `omega` (no FFT involved) /// so a bug in either the CPU or GPU NTT can't hide here. #[test] fn ntt_known_polynomial_x_plus_one_size_256() { @@ -25,14 +25,14 @@ fn ntt_known_polynomial_x_plus_one_size_256() { let n: usize = 1 << log_n; // GoldilocksField uses bit-reversed coefficient layout for forward NTT: - // input coeffs at index `i` become `p(omega^bitrev(i))` — confirm by + // input coeffs at index `i` become `p(omega^bitrev(i))`. Confirm by // matching against the existing forward() API which random-input tests // already validate against `Polynomial::evaluate_fft`. The known-poly // value lets us catch systematic errors in that pipeline that random // inputs miss. // Input coefficients: [1, 1, 0, 0, ..., 0]. (Natural order, lowest - // degree first — `Polynomial::new` and `math_cuda::ntt::forward` both + // degree first. `Polynomial::new` and `math_cuda::ntt::forward` both // expect this convention.) let mut input = vec![0u64; n]; input[0] = 1; @@ -47,16 +47,16 @@ fn ntt_known_polynomial_x_plus_one_size_256() { let one = Fp::from_raw(1); let mut expected = Vec::with_capacity(n); - let mut omega_i = one.clone(); // omega^0 + let mut omega_i = one; // omega^0 for _ in 0..n { let val = &one + &omega_i; expected.push(*val.value()); omega_i = &omega_i * ω } - for i in 0..n { - let g = canon(gpu[i]); - let e = canon(expected[i]); + for (i, (&g_raw, &e_raw)) in gpu.iter().zip(expected.iter()).enumerate() { + let g = canon(g_raw); + let e = canon(e_raw); if g != e { panic!( "p(omega^{i}) mismatch: gpu canon {:#018x}, expected canon {:#018x} (omega^{i} computed independently of any FFT)", @@ -83,11 +83,11 @@ fn ntt_known_polynomial_x_plus_one_size_16() { let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("root"); let one = Fp::from_raw(1); - let mut omega_i = one.clone(); - for i in 0..n { + let mut omega_i = one; + for (i, &g) in gpu.iter().enumerate() { let exp = &one + &omega_i; assert_eq!( - canon(gpu[i]), + canon(g), canon(*exp.value()), "p(omega^{i}) mismatch at size 16" ); @@ -113,13 +113,13 @@ fn ntt_known_polynomial_x_half_alternating() { // Expected: omega^(k*i) for i = 0..N, which is (omega^k)^i = (-1)^i. // canonical(+1) = 1; canonical(-1) = p - 1. let p_minus_one = 0xFFFF_FFFF_0000_0001u64 - 1; - for i in 0..n { + for (i, &g) in gpu.iter().enumerate() { let exp = if i % 2 == 0 { 1u64 } else { p_minus_one }; assert_eq!( - canon(gpu[i]), + canon(g), exp, "x^(N/2) NTT alternation mismatch at i={i}: got {:#018x}", - canon(gpu[i]) + canon(g) ); } } From 033fe89607f30051af403502eb14856361b83ac7 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 15 May 2026 17:34:40 -0300 Subject: [PATCH 10/22] address PR review: correctness fixes and missing parity tests --- crypto/math-cuda/src/lde.rs | 49 +++++++-- crypto/math-cuda/src/merkle.rs | 34 ++++-- crypto/math-cuda/tests/keccak_leaves.rs | 138 ++++++++++++++++++++++++ 3 files changed, 204 insertions(+), 17 deletions(-) diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index a3ebe184d..ab9105264 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -494,10 +494,18 @@ pub fn coset_lde_batch_base_into_with_leaf_hash( let m = columns.len(); assert_eq!(outputs.len(), m); let n = columns[0].len(); + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(()); + } assert!(n.is_power_of_two()); assert_eq!(weights.len(), n); assert!(blowup_factor.is_power_of_two()); let lde_size = n * blowup_factor; + assert_u32_domain( + lde_size, + "coset_lde_batch_base_into_with_leaf_hash lde_size", + ); for o in outputs.iter() { assert_eq!(o.len(), lde_size); } @@ -718,10 +726,18 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( let m = columns.len(); assert_eq!(outputs.len(), m); let n = columns[0].len(); + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(None); + } assert!(n.is_power_of_two()); assert_eq!(weights.len(), n); assert!(blowup_factor.is_power_of_two()); let lde_size = n * blowup_factor; + assert_u32_domain( + lde_size, + "coset_lde_batch_base_into_with_merkle_tree lde_size", + ); for o in outputs.iter() { assert_eq!(o.len(), lde_size); } @@ -943,6 +959,10 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( assert_eq!(outputs.len(), 0); return Ok(()); } + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(()); + } let m = columns.len(); assert_eq!(outputs.len(), m); assert!(n.is_power_of_two()); @@ -952,13 +972,14 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( assert_eq!(c.len(), 3 * n); } let lde_size = n * blowup_factor; + assert_u32_domain( + lde_size, + "coset_lde_batch_ext3_into_with_leaf_hash lde_size", + ); for o in outputs.iter() { assert_eq!(o.len(), 3 * lde_size); } assert_eq!(hashed_leaves_out.len(), lde_size * 32); - if n == 0 { - return Ok(()); - } let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; @@ -1191,6 +1212,10 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( assert_eq!(outputs.len(), 0); return Ok(None); } + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(None); + } let m = columns.len(); assert_eq!(outputs.len(), m); assert!(n.is_power_of_two()); @@ -1200,14 +1225,15 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( assert_eq!(c.len(), 3 * n); } let lde_size = n * blowup_factor; + assert_u32_domain( + lde_size, + "coset_lde_batch_ext3_into_with_merkle_tree lde_size", + ); for o in outputs.iter() { assert_eq!(o.len(), 3 * lde_size); } let total_nodes = 2 * lde_size - 1; assert_eq!(merkle_nodes_out.len(), total_nodes * 32); - if n == 0 { - return Ok(None); - } let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; @@ -1645,6 +1671,10 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( if coefs.is_empty() { return Ok(()); } + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(()); + } let m = coefs.len(); assert_eq!(outputs.len(), m); assert!(n.is_power_of_two()); @@ -1654,15 +1684,16 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( assert_eq!(c.len(), 3 * n); } let lde_size = n * blowup_factor; + assert_u32_domain( + lde_size, + "evaluate_poly_coset_batch_ext3_into_with_merkle_tree lde_size", + ); for o in outputs.iter() { assert_eq!(o.len(), 3 * lde_size); } assert!(lde_size >= 2); let total_nodes = lde_size - 1; assert_eq!(merkle_nodes_out.len(), total_nodes * 32); - if n == 0 { - return Ok(()); - } let log_lde = lde_size.trailing_zeros() as u64; let mb = 3 * m; diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index 0a547568e..5332fe8e6 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -32,10 +32,17 @@ pub fn keccak_leaves_base( num_rows: usize, ) -> Result> { assert!(num_rows.is_power_of_two()); - assert!(columns.len() >= num_cols * col_stride); + assert!( + col_stride >= num_rows, + "col_stride must be >= num_rows to keep per-column reads in-bounds" + ); + let total = num_cols + .checked_mul(col_stride) + .expect("num_cols * col_stride overflows usize"); + assert!(columns.len() >= total); let be = backend()?; let stream = be.next_stream(); - let cols_dev = stream.clone_htod(&columns[..num_cols * col_stride])?; + let cols_dev = stream.clone_htod(&columns[..total])?; let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; launch_keccak_base( stream.as_ref(), @@ -59,10 +66,18 @@ pub fn keccak_leaves_ext3( num_rows: usize, ) -> Result> { assert!(num_rows.is_power_of_two()); - assert!(columns.len() >= num_cols * 3 * col_stride); + assert!( + col_stride >= num_rows, + "col_stride must be >= num_rows to keep per-column reads in-bounds" + ); + let total = num_cols + .checked_mul(3) + .and_then(|v| v.checked_mul(col_stride)) + .expect("num_cols * 3 * col_stride overflows usize"); + assert!(columns.len() >= total); let be = backend()?; let stream = be.next_stream(); - let cols_dev = stream.clone_htod(&columns[..num_cols * 3 * col_stride])?; + let cols_dev = stream.clone_htod(&columns[..total])?; let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; launch_keccak_ext3( stream.as_ref(), @@ -100,6 +115,9 @@ pub(crate) fn launch_keccak_base( num_rows: u64, out_dev: &mut CudaSlice, ) -> Result<()> { + // The kernel computes `__brevll(tid) >> (64 - log_num_rows)`, which is UB + // for `log_num_rows == 0` (single-row trees are degenerate anyway). + debug_assert!(num_rows >= 2, "keccak leaf kernel: num_rows must be >= 2"); let be = backend()?; let log_num_rows = num_rows.trailing_zeros() as u64; let cfg = keccak_launch_cfg(num_rows); @@ -326,11 +344,8 @@ pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { ); let num_evals = evals.len() / 3; let num_leaves = num_evals / 2; - assert!(num_leaves.is_power_of_two() && num_leaves >= 1); + assert!(num_leaves.is_power_of_two() && num_leaves >= 2); let tight_total_nodes = 2 * num_leaves - 1; - if tight_total_nodes == 0 { - return Ok(Vec::new()); - } let be = backend()?; let stream = be.next_stream(); @@ -397,6 +412,9 @@ pub(crate) fn launch_keccak_ext3( num_rows: u64, out_dev: &mut CudaSlice, ) -> Result<()> { + // The kernel computes `__brevll(tid) >> (64 - log_num_rows)`, which is UB + // for `log_num_rows == 0` (single-row trees are degenerate anyway). + debug_assert!(num_rows >= 2, "keccak leaf kernel: num_rows must be >= 2"); let be = backend()?; let log_num_rows = num_rows.trailing_zeros() as u64; let cfg = keccak_launch_cfg(num_rows); diff --git a/crypto/math-cuda/tests/keccak_leaves.rs b/crypto/math-cuda/tests/keccak_leaves.rs index 1e451b386..33610aeca 100644 --- a/crypto/math-cuda/tests/keccak_leaves.rs +++ b/crypto/math-cuda/tests/keccak_leaves.rs @@ -93,6 +93,144 @@ fn keccak_leaves_base_matches_cpu() { } } +// Row-pair leaves for the R2 composition-polynomial commit. For each leaf i: +// br_0 = bit_reverse(2*i, log_lde), br_1 = bit_reverse(2*i+1, log_lde) +// hash is Keccak256 of BE bytes of every part's ext3 value at br_0 then br_1 +// (matching `commit_composition_polynomial` on the CPU side). +fn cpu_leaves_comp_poly(parts: &[Vec]) -> Vec<[u8; 32]> { + let lde_size = parts[0].len(); + let num_parts = parts.len(); + let num_leaves = lde_size / 2; + let byte_len = 24; + (0..num_leaves) + .map(|i| { + let br_0 = reverse_index((2 * i) as u64, lde_size as u64) as usize; + let br_1 = reverse_index((2 * i + 1) as u64, lde_size as u64) as usize; + let mut buf = vec![0u8; 2 * num_parts * byte_len]; + for (p, part) in parts.iter().enumerate() { + part[br_0].write_bytes_be(&mut buf[p * byte_len..(p + 1) * byte_len]); + } + let off = num_parts * byte_len; + for (p, part) in parts.iter().enumerate() { + part[br_1].write_bytes_be(&mut buf[off + p * byte_len..off + (p + 1) * byte_len]); + } + let mut h = Keccak256::new(); + h.update(&buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out + }) + .collect() +} + +// FRI leaves: each leaf hashes 2 consecutive ext3 evals, no bit reversal. +fn cpu_leaves_fri(evals: &[Fp3]) -> Vec<[u8; 32]> { + let num_leaves = evals.len() / 2; + let byte_len = 24; + (0..num_leaves) + .map(|i| { + let mut buf = vec![0u8; 2 * byte_len]; + evals[2 * i].write_bytes_be(&mut buf[..byte_len]); + evals[2 * i + 1].write_bytes_be(&mut buf[byte_len..]); + let mut h = Keccak256::new(); + h.update(&buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out + }) + .collect() +} + +#[test] +fn keccak_comp_poly_leaves_matches_cpu() { + // Built tree's leaves live at byte offset `(num_leaves - 1) * 32` and + // span `num_leaves * 32` bytes. Compare those to the CPU reference. + for log_lde in [2u32, 4, 6, 8, 10, 12] { + for num_parts in [1usize, 2, 5, 17] { + let lde_size = 1usize << log_lde; + let mut rng = ChaCha8Rng::seed_from_u64(300 + log_lde as u64 + num_parts as u64); + let parts: Vec> = (0..num_parts) + .map(|_| { + (0..lde_size) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() + }) + .collect(); + let cpu = cpu_leaves_comp_poly(&parts); + + // Each part is passed as `[a0,a1,a2, b0,b1,b2, ...]` of length `3 * lde_size`. + let parts_interleaved: Vec> = parts + .iter() + .map(|p| { + let mut v = vec![0u64; 3 * lde_size]; + for (i, e) in p.iter().enumerate() { + v[i * 3] = *e.value()[0].value(); + v[i * 3 + 1] = *e.value()[1].value(); + v[i * 3 + 2] = *e.value()[2].value(); + } + v + }) + .collect(); + let parts_slices: Vec<&[u64]> = + parts_interleaved.iter().map(|v| v.as_slice()).collect(); + + let nodes = + math_cuda::merkle::build_comp_poly_tree_from_evals_ext3(&parts_slices).unwrap(); + let num_leaves = lde_size / 2; + let leaves_offset = (num_leaves - 1) * 32; + for i in 0..num_leaves { + assert_eq!( + &nodes[leaves_offset + i * 32..leaves_offset + (i + 1) * 32], + &cpu[i][..], + "comp-poly leaf mismatch at i={i} (log_lde={log_lde}, parts={num_parts})" + ); + } + } + } +} + +#[test] +fn keccak_fri_leaves_matches_cpu() { + for log_lde in [2u32, 4, 6, 8, 10, 12] { + let lde_size = 1usize << log_lde; + let mut rng = ChaCha8Rng::seed_from_u64(400 + log_lde as u64); + let evals: Vec = (0..lde_size) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect(); + let cpu = cpu_leaves_fri(&evals); + + let mut evals_interleaved = vec![0u64; 3 * lde_size]; + for (i, e) in evals.iter().enumerate() { + evals_interleaved[i * 3] = *e.value()[0].value(); + evals_interleaved[i * 3 + 1] = *e.value()[1].value(); + evals_interleaved[i * 3 + 2] = *e.value()[2].value(); + } + let nodes = + math_cuda::merkle::build_fri_layer_tree_from_evals_ext3(&evals_interleaved).unwrap(); + let num_leaves = lde_size / 2; + let leaves_offset = (num_leaves - 1) * 32; + for i in 0..num_leaves { + assert_eq!( + &nodes[leaves_offset + i * 32..leaves_offset + (i + 1) * 32], + &cpu[i][..], + "fri leaf mismatch at i={i} (log_lde={log_lde})" + ); + } + } +} + #[test] fn keccak_leaves_ext3_matches_cpu() { for log_n in [4u32, 6, 8, 10] { From 7956c75c50037018b9c2744a2aa8dacc92938563 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 15 May 2026 17:36:25 -0300 Subject: [PATCH 11/22] address PR review: cleanups --- crypto/math-cuda/kernels/keccak.cu | 20 ++++---- crypto/math-cuda/src/device.rs | 7 ++- crypto/math-cuda/src/lde.rs | 81 +++++++++++------------------- crypto/math-cuda/src/merkle.rs | 12 ++--- 4 files changed, 47 insertions(+), 73 deletions(-) diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu index 68ddce3b4..8ea9c28b1 100644 --- a/crypto/math-cuda/kernels/keccak.cu +++ b/crypto/math-cuda/kernels/keccak.cu @@ -1,7 +1,7 @@ -// CUDA Keccak-256 (original Keccak, NOT SHA3-256 — uses 0x01 padding delimiter). +// Original Keccak-256 (0x01 padding) // // Used by the lambda-vm prover's Merkle commit: -// leaf = Keccak-256(concat(col_0[br_idx].to_be_bytes(), col_1[br_idx].to_be_bytes(), …)) +// leaf = Keccak-256(concat(col_0[br_idx].to_be_bytes(), col_1[br_idx].to_be_bytes(), ...)) // where `br_idx = bit_reverse(row_idx, log_num_rows)` and each element is // written in BIG-ENDIAN canonical form (per `FieldElement::write_bytes_be`). // @@ -96,9 +96,9 @@ __device__ __forceinline__ void keccak_f1600(uint64_t st[25]) { } // --------------------------------------------------------------------------- -// Helper: absorb one 8-byte lane (already in lane form — i.e. LE interpretation -// of the BE-serialised u64) into the sponge at `rate_pos` (in bytes). Permutes -// when a full 136-byte block has been absorbed. +// Helper: absorb one 8-byte lane (already byte-swapped from BE serialisation +// into Keccak's LE lane form) into the sponge at `rate_pos` (in bytes). +// Permutes when a full 136-byte block has been absorbed. // --------------------------------------------------------------------------- __device__ __forceinline__ void absorb_lane(uint64_t st[25], uint32_t &rate_pos, @@ -172,7 +172,7 @@ extern "C" __global__ void keccak256_leaves_base_batched( uint64_t v = columns_base_ptr[c * col_stride + br]; // Canonicalise to match `canonical_u64().to_be_bytes()` on host. uint64_t canon = goldilocks::canonical(v); - // The on-disk leaf bytes are canon.to_be_bytes(); Keccak reads those + // The on-disk leaf bytes are canon.to_be_bytes(). Keccak reads those // as a LE lane, which equals bswap64(canon). uint64_t lane = bswap64(canon); absorb_lane(st, rate_pos, lane); @@ -275,8 +275,8 @@ extern "C" __global__ void keccak_comp_poly_leaves_ext3( // // Each leaf hashes 2 consecutive ext3 values: Keccak256 over // evals[2j].to_bytes_be() ++ evals[2j+1].to_bytes_be() -// = 48 BE bytes = 6 u64 BE lanes. No bit reversal; no column slab layout — -// the input is a single interleaved ext3 eval vector `[a0,a1,a2,b0,b1,b2,...]`. +// = 48 BE bytes = 6 u64 BE lanes. No bit reversal, no column slab layout. +// The input is a single interleaved ext3 eval vector `[a0,a1,a2,b0,b1,b2,...]`. // --------------------------------------------------------------------------- extern "C" __global__ void keccak_fri_leaves_ext3( const uint64_t *evals_interleaved, // 3 * num_evals u64s (ext3 interleaved) @@ -311,14 +311,14 @@ extern "C" __global__ void keccak_fri_leaves_ext3( // // `nodes` is the full Merkle node buffer (length `2*leaves_len - 1`, each // element 32 bytes). `parent_begin` is the node-index offset of the first -// parent slot in this level; children live at `parent_begin + n_pairs`. +// parent slot in this level. Children live at `parent_begin + n_pairs`. // The layout mirrors `crypto/crypto/src/merkle_tree/merkle.rs`: // // children: nodes[parent_begin + n_pairs .. parent_begin + 3 * n_pairs] // parents: nodes[parent_begin .. parent_begin + n_pairs] // // Each thread hashes one child pair → one parent. Keccak-256 of the -// concatenation of two 32-byte siblings; identical to +// concatenation of two 32-byte siblings, identical to // `FieldElementVectorBackend::hash_new_parent` on host. // --------------------------------------------------------------------------- extern "C" __global__ void keccak_merkle_level( diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index f52c0b1e3..d6d5fc403 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -109,9 +109,8 @@ pub struct Backend { /// cost for every first use of a new table size). pinned_staging: Mutex, /// Separate pinned staging for Merkle leaf hashes. Sized `num_rows * 32` - /// bytes; lives alongside the LDE staging so the GPU→host D2H for - /// hashed leaves runs at full PCIe line-rate instead of the pageable - /// ~1.3 GB/s path that would otherwise eat ~100 ms per main-trace commit. + /// bytes. It lives alongside the LDE staging so the GPU→host D2H for + /// hashed leaves runs at full PCIe line-rate. pinned_hashes: Mutex, util_stream: Arc, next: AtomicUsize, @@ -227,7 +226,7 @@ impl Backend { } /// Separate pinned staging for Merkle leaf hash output. Sized in u64 - /// units; caller should reserve `(num_rows * 32 + 7) / 8` u64s. + /// units. Caller should reserve `(num_rows * 32 + 7) / 8` u64s. pub fn pinned_hashes(&self) -> &Mutex { &self.pinned_hashes } diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index ab9105264..323447903 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -474,9 +474,9 @@ pub fn coset_lde_batch_base_into( } /// Variant of `coset_lde_batch_base_into` that also emits the Keccak-256 -/// Merkle leaf hashes from the LDE output — all on GPU, no second H2D of -/// the LDE data. Leaves are computed reading columns at bit-reversed rows -/// (matching `commit_columns_bit_reversed` on the CPU side). +/// Merkle leaf hashes from the LDE output. All on GPU, the device LDE buffer +/// is hashed in place. Leaves are computed reading columns at bit-reversed +/// rows (matching `commit_columns_bit_reversed` on the CPU side). /// /// `hashed_leaves_out` must be `lde_size * 32` bytes (one 32-byte digest /// per output row, in natural row order). @@ -521,7 +521,6 @@ pub fn coset_lde_batch_base_into_with_leaf_hash( staging.ensure_capacity(m * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - use rayon::prelude::*; let pinned_base_ptr = pinned.as_mut_ptr() as usize; columns.par_iter().enumerate().for_each(|(c, col)| { // SAFETY: disjoint regions per c, outer staging lock held. @@ -614,13 +613,13 @@ pub fn coset_lde_batch_base_into_with_leaf_hash( col_stride_u64, m as u64, lde_u64, - &mut hashes_dev, + &mut hashes_dev.as_view_mut(), )?; - // D2H the LDE into the pinned LDE staging and the hashes into a - // dedicated pinned hash staging, in parallel on the same stream. Both - // at pinned PCIe line-rate — pageable D2H of the 128 MB hash buffer - // would otherwise cost ~100 ms per main-trace commit. + // D2H the LDE into the pinned LDE staging, then the hashes into a + // dedicated pinned hash staging. The two copies run back-to-back on the + // same stream; both go at pinned PCIe line-rate — pageable D2H of the + // 128 MB hash buffer would otherwise cost ~100 ms per main-trace commit. stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; let hashes_u64_len = (lde_size * 32).div_ceil(8); let hashes_staging_slot = be.pinned_hashes(); @@ -635,7 +634,9 @@ pub fn coset_lde_batch_base_into_with_leaf_hash( stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; stream.synchronize()?; - // Copy pinned → caller outputs in parallel with the hash memcpy. + // Copy pinned → caller outputs. Both D2H copies have already drained + // (the synchronize above); this is a rayon-parallel host memcpy from + // pinned to pageable memory, not concurrent with any GPU transfer. let pinned_ptr = pinned.as_ptr() as usize; outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { let src = unsafe { @@ -754,7 +755,6 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( staging.ensure_capacity(m * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - use rayon::prelude::*; let pinned_base_ptr = pinned.as_mut_ptr() as usize; columns.par_iter().enumerate().for_each(|(c, col)| { let dst = @@ -851,25 +851,14 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( { let mut leaves_view = nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); - let log_num_rows_leaves = lde_size.trailing_zeros() as u64; - let num_cols_u64 = m as u64; - let grid = (lde_size as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak256_leaves_base_batched) - .arg(&buf) - .arg(&col_stride_u64) - .arg(&num_cols_u64) - .arg(&lde_u64) - .arg(&log_num_rows_leaves) - .arg(&mut leaves_view) - .launch(cfg)?; - } + launch_keccak_base( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut leaves_view, + )?; } // Inner tree levels — mirror the CPU `build(nodes, leaves_len)` scan. @@ -992,7 +981,6 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - use rayon::prelude::*; let pinned_ptr_u = pinned.as_mut_ptr() as usize; columns.par_iter().enumerate().for_each(|(c, col)| { let slab_a = unsafe { @@ -1092,7 +1080,7 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( col_stride_u64, m as u64, lde_u64, - &mut hashes_dev, + &mut hashes_dev.as_view_mut(), )?; // D2H LDE (mb * lde_size u64) and hashes (lde_size * 32 bytes). @@ -1246,7 +1234,6 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - use rayon::prelude::*; let pinned_ptr_u = pinned.as_mut_ptr() as usize; columns.par_iter().enumerate().for_each(|(c, col)| { let slab_a = unsafe { @@ -1344,25 +1331,14 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( { let mut leaves_view = nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); - let log_num_rows_leaves = lde_size.trailing_zeros() as u64; - let num_cols_u64 = m as u64; - let grid = (lde_size as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak256_leaves_ext3_batched) - .arg(&buf) - .arg(&col_stride_u64) - .arg(&num_cols_u64) - .arg(&lde_u64) - .arg(&log_num_rows_leaves) - .arg(&mut leaves_view) - .launch(cfg)?; - } + launch_keccak_ext3( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut leaves_view, + )?; } // Inner tree levels. @@ -1705,7 +1681,6 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - use rayon::prelude::*; let pinned_ptr_u = pinned.as_mut_ptr() as usize; coefs.par_iter().enumerate().for_each(|(c, col)| { let slab_a = unsafe { diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index 5332fe8e6..35dc38664 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -15,7 +15,8 @@ //! — three base slabs per ext3 column — and the kernel reads three u64s per //! column in component order 0,1,2 to match `FieldElement::::write_bytes_be`. -use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; +use cudarc::driver::{CudaSlice, CudaStream, CudaViewMut, LaunchConfig, PushKernelArg}; +use rayon::prelude::*; use crate::Result; use crate::device::backend; @@ -50,7 +51,7 @@ pub fn keccak_leaves_base( col_stride as u64, num_cols as u64, num_rows as u64, - &mut out_dev, + &mut out_dev.as_view_mut(), )?; let out = stream.clone_dtoh(&out_dev)?; stream.synchronize()?; @@ -85,7 +86,7 @@ pub fn keccak_leaves_ext3( col_stride as u64, num_cols as u64, num_rows as u64, - &mut out_dev, + &mut out_dev.as_view_mut(), )?; let out = stream.clone_dtoh(&out_dev)?; stream.synchronize()?; @@ -113,7 +114,7 @@ pub(crate) fn launch_keccak_base( col_stride: u64, num_cols: u64, num_rows: u64, - out_dev: &mut CudaSlice, + out_dev: &mut CudaViewMut<'_, u8>, ) -> Result<()> { // The kernel computes `__brevll(tid) >> (64 - log_num_rows)`, which is UB // for `log_num_rows == 0` (single-row trees are degenerate anyway). @@ -238,7 +239,6 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - use rayon::prelude::*; let pinned_ptr_u = pinned.as_mut_ptr() as usize; parts_interleaved .par_iter() @@ -410,7 +410,7 @@ pub(crate) fn launch_keccak_ext3( col_stride: u64, num_cols: u64, num_rows: u64, - out_dev: &mut CudaSlice, + out_dev: &mut CudaViewMut<'_, u8>, ) -> Result<()> { // The kernel computes `__brevll(tid) >> (64 - log_num_rows)`, which is UB // for `log_num_rows == 0` (single-row trees are degenerate anyway). From 4800f577c36b685dda1673ff855ca101cecd92d5 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 18 May 2026 11:30:51 -0300 Subject: [PATCH 12/22] address PR comments --- crypto/math-cuda/src/merkle.rs | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index 35dc38664..e304c6352 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -3,7 +3,7 @@ //! Matches `FieldElementVectorBackend::hash_data` in //! `crypto/crypto/src/merkle_tree/backends/field_element_vector.rs`, combined //! with the `reverse_index` row read pattern used in -//! `commit_columns_bit_reversed` at `crypto/stark/src/prover.rs:368`. +//! `commit_columns_bit_reversed` at `crypto/stark/src/prover.rs`. //! //! Caller supplies base-field column slabs already laid out as //! `[col * col_stride + row]` (the same layout `coset_lde_batch_base_into` @@ -11,9 +11,10 @@ //! reads each column's canonical u64 at that row, byte-swaps it into a //! Keccak lane, absorbs lane-by-lane, and squeezes 32 bytes per leaf. //! -//! For ext3 columns the layout is `[col*3*col_stride + k*col_stride + row]` -//! — three base slabs per ext3 column — and the kernel reads three u64s per -//! column in component order 0,1,2 to match `FieldElement::::write_bytes_be`. +//! For ext3 columns the layout is `[col*3*col_stride + k*col_stride + row]`, +//! three base-field components per ext3 column, indexed by `k ∈ {0,1,2}`, +//! and the kernel reads three u64s per column in component order 0,1,2 +//! to match `FieldElement::::write_bytes_be`. use cudarc::driver::{CudaSlice, CudaStream, CudaViewMut, LaunchConfig, PushKernelArg}; use rayon::prelude::*; @@ -58,7 +59,7 @@ pub fn keccak_leaves_base( Ok(out) } -/// Ext3 variant — columns interleaved as three base slabs per ext3 column. +/// Ext3 variant. Columns interleaved as three base slabs per ext3 column. /// `columns.len() >= num_cols * 3 * col_stride`. pub fn keccak_leaves_ext3( columns: &[u64], @@ -94,7 +95,7 @@ pub fn keccak_leaves_ext3( } /// Block size for Keccak kernels. Per-thread register footprint is ~60 regs -/// (25-lane state + auxiliaries); the default 256 threads/block pushes the +/// (25-lane state + auxiliaries). The default 256 threads/block pushes the /// block register file past the hardware limit on sm_120 (Blackwell). 128 /// keeps us inside the budget with some head-room. const KECCAK_BLOCK_DIM: u32 = 128; @@ -147,7 +148,7 @@ pub(crate) fn launch_keccak_base( /// the resulting `nodes` Vec plugs straight into `MerkleTree { root, nodes }` /// for downstream proof generation. /// -/// `leaves_len` must be a power of two and ≥ 2. +/// `leaves_len` must be a power of two and >= 2. pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { assert!(hashed_leaves.len().is_multiple_of(32)); let leaves_len = hashed_leaves.len() / 32; @@ -161,7 +162,7 @@ pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { let be = backend()?; let stream = be.next_stream(); - // Allocate the full node buffer without zero-fill — we overwrite the + // Allocate the full node buffer without zero-fill. We overwrite the // leaf half via H2D immediately, and every inner node is written by the // pair-hash kernel below. // SAFETY: every byte is written before it is read: leaves are filled by @@ -182,7 +183,7 @@ pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { // and each iteration computes: // new_level_begin_index = level_begin_index / 2 // new_level_length = level_begin_index - new_level_begin_index - // The parents occupy [new_level_begin_index, level_begin_index); the + // The parents occupy [new_level_begin_index, level_begin_index), the // children occupy [level_begin_index, level_end_index + 1). let mut level_begin: u64 = (leaves_len - 1) as u64; while level_begin != 0 { @@ -333,7 +334,7 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res } /// Build a FRI-layer Merkle tree on device from an interleaved ext3 eval -/// vector. Each leaf hashes two consecutive ext3 values; `num_leaves = +/// vector. Each leaf hashes two consecutive ext3 values. `num_leaves = /// evals.len() / 6` (since each ext3 is 3 u64s). /// /// Returns the `(2*num_leaves - 1) * 32`-byte node buffer in standard layout. From a6037a98a4c5e898c93282519b6f99d2b674edb5 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> Date: Mon, 18 May 2026 11:35:54 -0300 Subject: [PATCH 13/22] Update crypto/math-cuda/src/merkle.rs Co-authored-by: Gabriel Bosio <38794644+gabrielbosio@users.noreply.github.com> --- crypto/math-cuda/src/merkle.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index e304c6352..e4ce09468 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -376,7 +376,7 @@ pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { } } - // Inner tree levels — identical to the R2 version. + // Inner tree levels, identical to the R2 version. { let mut level_begin: u64 = (num_leaves - 1) as u64; while level_begin != 0 { From c8e1b3dd1bba20bb966a16c6e02e3fdae1f705ce Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 18 May 2026 12:16:00 -0300 Subject: [PATCH 14/22] drop outer keccack round unroll --- crypto/math-cuda/kernels/keccak.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu index 8ea9c28b1..115b5836b 100644 --- a/crypto/math-cuda/kernels/keccak.cu +++ b/crypto/math-cuda/kernels/keccak.cu @@ -49,7 +49,7 @@ __device__ __forceinline__ uint64_t bswap64(uint64_t x) { __device__ __forceinline__ void keccak_f1600(uint64_t st[25]) { uint64_t C[5], D[5], B[25]; - #pragma unroll + // No outer unroll: fully unrolling the 24 rounds slowed the kernel ~7.5% on RTX 5090. for (int r = 0; r < 24; ++r) { // Theta #pragma unroll From d17cf916f80448ae65a895270498f0b60dc18f51 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 18 May 2026 12:59:24 -0300 Subject: [PATCH 15/22] refactor --- crypto/math-cuda/kernels/keccak.cu | 2 + crypto/math-cuda/src/lde.rs | 821 ++++++----------------------- crypto/math-cuda/src/merkle.rs | 60 ++- 3 files changed, 200 insertions(+), 683 deletions(-) diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu index 115b5836b..c22bc4d05 100644 --- a/crypto/math-cuda/kernels/keccak.cu +++ b/crypto/math-cuda/kernels/keccak.cu @@ -333,6 +333,8 @@ extern "C" __global__ void keccak_merkle_level( for (int i = 0; i < 25; ++i) st[i] = 0; uint32_t rate_pos = 0; + // `nodes` comes from cuMemAlloc (256-byte aligned); each 32-byte node + // sits at a 32-byte-aligned offset, so the u64 cast is safe. const uint64_t *left = reinterpret_cast( nodes + (parent_begin + n_pairs + 2 * tid) * 32); #pragma unroll diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 323447903..636c64952 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -12,11 +12,11 @@ use std::sync::Arc; -use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; use rayon::prelude::*; use crate::Result; -use crate::device::backend; +use crate::device::{Backend, backend}; use crate::merkle::{launch_keccak_base, launch_keccak_ext3}; use crate::ntt::run_ntt_body; @@ -32,6 +32,129 @@ fn assert_u32_domain(n: usize, what: &str) { ); } +/// De-interleave `columns` (each `3*n` u64s, ext3-per-element layout +/// `[a, b, c, a, b, c, ...]`) into `pinned` as `3*m` base-field slabs. +/// Component `k` of column `c` lands at `pinned[(c*3 + k)*n .. (c*3 + k)*n + n]`. +/// +/// Caller invariants: `pinned.len() >= 3 * columns.len() * n` and each +/// `columns[c].len() >= 3 * n`. The caller must hold the pinned-staging lock. +fn pack_ext3_to_pinned_slabs(columns: &[&[u64]], pinned: &mut [u64], n: usize) { + let m = columns.len(); + debug_assert!(pinned.len() >= 3 * m * n); + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: each task writes to disjoint `[(c*3 + k)*n .. ..+n]` regions + // of `pinned`. The outer `&mut [u64]` borrow guarantees no aliasing. + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); +} + +/// Re-interleave the `3*m` base-field slabs in `pinned` (layout matches +/// `pack_ext3_to_pinned_slabs`) into `outputs`, writing each as +/// `3*lde_size` interleaved u64s. +fn unpack_pinned_slabs_to_ext3(pinned: &[u64], outputs: &mut [&mut [u64]], lde_size: usize) { + let m = outputs.len(); + debug_assert!(pinned.len() >= 3 * m * lde_size); + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + // SAFETY: each task reads from disjoint `[(c*3 + k)*lde_size .. ..+lde_size]` + // regions of `pinned`. Caller borrows `pinned` for the duration of the call. + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); +} + +/// Run `bit_reverse_permute_batched` over `m` columns of length `n` each +/// (column stride `col_stride`). 256 threads per block, grid sized to cover +/// `n` per column. +fn launch_bit_reverse_batched( + stream: &CudaStream, + be: &Backend, + buf: &mut CudaSlice, + n: u64, + log_n: u64, + col_stride: u64, + m: u32, +) -> Result<()> { + let cfg = LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(buf) + .arg(&n) + .arg(&log_n) + .arg(&col_stride) + .launch(cfg)?; + } + Ok(()) +} + +/// Run `pointwise_mul_batched`: `buf[c*col_stride + i] *= weights[i]` for +/// `m` columns, `n` elements each. +fn launch_pointwise_mul_batched( + stream: &CudaStream, + be: &Backend, + buf: &mut CudaSlice, + weights: &CudaSlice, + n: u64, + col_stride: u64, + m: u32, +) -> Result<()> { + let cfg = LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(buf) + .arg(weights) + .arg(&n) + .arg(&col_stride) + .launch(cfg)?; + } + Ok(()) +} + /// Handle to a base-field LDE kept live on device after R1 commit. /// Layout: `m` columns, each `lde_size` u64s, column `c` at byte offset /// `c * lde_size * 8` within `buf`. Freed when `buf` Arc drops. @@ -209,23 +332,7 @@ pub fn coset_lde_batch_base( let m_u32 = m as u32; // === 1. Bit-reverse first N of every column === - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, m_u32)?; // === 2. iNTT body over all columns === run_batched_ntt_body( @@ -239,42 +346,10 @@ pub fn coset_lde_batch_base( )?; // === 3. Pointwise multiply by coset weights (includes 1/N) === - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, m_u32)?; // === 4. Bit-reverse full LDE of every column === - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, m_u32)?; // === 5. Forward NTT on full LDE of every column === run_batched_ntt_body( @@ -386,23 +461,7 @@ pub fn coset_lde_batch_base_into( let m_u32 = m as u32; // iNTT bit-reverse + body, pointwise mul, forward bit-reverse + body. - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, m_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -412,40 +471,8 @@ pub fn coset_lde_batch_base_into( col_stride_u64, m_u32, )?; - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, m_u32)?; + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, m_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -545,19 +572,7 @@ pub fn coset_lde_batch_base_into_with_leaf_hash( let m_u32 = m as u32; // iNTT - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, m_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -568,33 +583,9 @@ pub fn coset_lde_batch_base_into_with_leaf_hash( m_u32, )?; // pointwise coset scale - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, m_u32)?; // forward NTT on full LDE slab - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, m_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -778,19 +769,7 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( let m_u32 = m as u32; // iNTT - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, m_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -800,33 +779,9 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( col_stride_u64, m_u32, )?; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, m_u32)?; // forward NTT at LDE size - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, m_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -861,29 +816,7 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( )?; } - // Inner tree levels — mirror the CPU `build(nodes, leaves_len)` scan. - { - let mut level_begin: u64 = (lde_size - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } - } + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; // D2H the LDE and the tree nodes via pinned staging. stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; @@ -981,23 +914,7 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(columns, pinned, n); let mut buf = stream.alloc_zeros::(mb * lde_size)?; for s in 0..mb { @@ -1014,19 +931,7 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( let col_stride_u64 = lde_size as u64; let mb_u32 = mb as u32; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, mb_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1036,32 +941,8 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( col_stride_u64, mb_u32, )?; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1097,32 +978,7 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( stream.synchronize()?; // Re-interleave pinned → caller ext3 outputs, parallel. - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); // Parallel memcpy of pinned hashes → caller. const CHUNK: usize = 64 * 1024; @@ -1234,23 +1090,7 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(columns, pinned, n); let mut buf = stream.alloc_zeros::(mb * lde_size)?; for s in 0..mb { @@ -1267,19 +1107,7 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( let col_stride_u64 = lde_size as u64; let mb_u32 = mb as u32; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, mb_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1289,32 +1117,8 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( col_stride_u64, mb_u32, )?; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1341,29 +1145,7 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( )?; } - // Inner tree levels. - { - let mut level_begin: u64 = (lde_size - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } - } + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; // D2H LDE (mb * lde_size u64) and tree nodes. stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; @@ -1379,32 +1161,7 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( stream.synchronize()?; // Re-interleave pinned → caller ext3 outputs. - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); const CHUNK: usize = 64 * 1024; let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; @@ -1508,23 +1265,7 @@ fn evaluate_poly_coset_batch_ext3_into_inner( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - coefs.par_iter().enumerate().for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(coefs, pinned, n); let mut buf = stream.alloc_zeros::(mb * lde_size)?; for s in 0..mb { @@ -1541,42 +1282,10 @@ fn evaluate_poly_coset_batch_ext3_into_inner( let mb_u32 = mb as u32; // Apply coset scaling: x[k] *= weights[k] for k in 0..n (no iFFT first). - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; // Bit-reverse full lde_size slab, then forward DIT NTT. - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1590,32 +1299,7 @@ fn evaluate_poly_coset_batch_ext3_into_inner( stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; stream.synchronize()?; - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); drop(staging); if keep_device_buf { Ok(Some(GpuLdeExt3 { @@ -1681,23 +1365,7 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - coefs.par_iter().enumerate().for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(coefs, pinned, n); let mut buf = stream.alloc_zeros::(mb * lde_size)?; for s in 0..mb { @@ -1713,32 +1381,8 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( let col_stride_u64 = lde_size as u64; let mb_u32 = mb as u32; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(LaunchConfig { - grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - })?; - } + launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1780,28 +1424,7 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( .launch(cfg)?; } } - { - let mut level_begin: u64 = (num_leaves - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } - } + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; // D2H LDE and tree. stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; @@ -1817,32 +1440,7 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( stream.synchronize()?; // Re-interleave pinned → caller ext3 outputs. - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); // Copy pinned tree → caller nodes_out. debug_assert_eq!(merkle_nodes_out.len(), tight_total_nodes * 32); @@ -1926,27 +1524,7 @@ pub fn coset_lde_batch_ext3_into( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - // Pack: for each ext3 column, write 3 base slabs into pinned. The slab - // for column c, component k lives at `pinned[(c*3 + k)*n .. (c*3+k)*n + n]`. - // We de-interleave from the interleaved `[a, b, c, a, b, c, ...]` input. - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - // SAFETY: disjoint regions per c; staging lock held. - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(columns, pinned, n); // Allocate + zero-pad device buffer holding 3M slabs of `lde_size`. let mut buf = stream.alloc_zeros::(mb * lde_size)?; @@ -1967,23 +1545,7 @@ pub fn coset_lde_batch_ext3_into( // === Butterflies: identical to the base-field batched path, but with // grid.y = 3M instead of M. === - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, mb_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1993,40 +1555,8 @@ pub fn coset_lde_batch_ext3_into( col_stride_u64, mb_u32, )?; - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; + launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -2042,32 +1572,7 @@ pub fn coset_lde_batch_ext3_into( // Unpack: for each output column, re-interleave 3 slabs back into the // ext3-per-element layout. - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); drop(staging); Ok(()) } diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index e4ce09468..8bb0f03e9 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -20,7 +20,7 @@ use cudarc::driver::{CudaSlice, CudaStream, CudaViewMut, LaunchConfig, PushKerne use rayon::prelude::*; use crate::Result; -use crate::device::backend; +use crate::device::{Backend, backend}; /// Run GPU Keccak-256 leaf hashing on a base-field column buffer. /// @@ -101,6 +101,10 @@ pub fn keccak_leaves_ext3( const KECCAK_BLOCK_DIM: u32 = 128; fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { + debug_assert!( + num_rows <= u32::MAX as u64, + "keccak_launch_cfg: num_rows ({num_rows}) exceeds u32 grid range", + ); let grid = (num_rows as u32).div_ceil(KECCAK_BLOCK_DIM); LaunchConfig { grid_dim: (grid, 1, 1), @@ -109,6 +113,35 @@ fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { } } +/// Walk the inner Merkle tree on device. `nodes_dev` already has the +/// `leaves_len` hashed leaves written into the tail; this loops +/// `log2(leaves_len)` times invoking `keccak_merkle_level` to fill in the +/// inner nodes from the bottom up. Mirrors the CPU `build(nodes, leaves_len)` +/// scan in `crypto/crypto/src/merkle_tree/merkle.rs`. +pub(crate) fn build_inner_tree_levels( + stream: &CudaStream, + be: &Backend, + nodes_dev: &mut CudaSlice, + leaves_len: usize, +) -> Result<()> { + let mut level_begin: u64 = (leaves_len - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let cfg = keccak_launch_cfg(n_pairs); + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut *nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + Ok(()) +} + pub(crate) fn launch_keccak_base( stream: &CudaStream, cols_dev: &CudaSlice, @@ -177,30 +210,7 @@ pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { stream.memcpy_htod(hashed_leaves, &mut slice)?; } - // Build level by level. The CPU `build(nodes, leaves_len)` starts with - // level_begin_index = leaves_len - 1 - // level_end_index = 2 * level_begin_index - // and each iteration computes: - // new_level_begin_index = level_begin_index / 2 - // new_level_length = level_begin_index - new_level_begin_index - // The parents occupy [new_level_begin_index, level_begin_index), the - // children occupy [level_begin_index, level_end_index + 1). - let mut level_begin: u64 = (leaves_len - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - - let cfg = keccak_launch_cfg(n_pairs); - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } + build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, leaves_len)?; let out = stream.clone_dtoh(&nodes_dev)?; stream.synchronize()?; From 6c621a6d657481227c1a341b0890c2e4fe5d7981 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 18 May 2026 14:59:17 -0300 Subject: [PATCH 16/22] refactor --- crypto/math-cuda/src/lde.rs | 683 +++++++++++++++------------------ crypto/math-cuda/src/merkle.rs | 48 +-- 2 files changed, 308 insertions(+), 423 deletions(-) diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 636c64952..af5e0a384 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -32,6 +32,33 @@ fn assert_u32_domain(n: usize, what: &str) { ); } +/// Output shape requested from the fused LDE + Keccak entry points. +#[derive(Copy, Clone, PartialEq, Eq)] +enum KeccakCommit { + /// Only the `lde_size` keccak-256 leaves; no inner-tree build. Caller + /// receives `lde_size * 32` bytes. + LeavesOnly, + /// Full Merkle tree: leaves at the tail + inner nodes built on-device. + /// Caller receives `(2*lde_size - 1) * 32` bytes. + FullTree, +} + +impl KeccakCommit { + fn total_nodes_bytes(self, lde_size: usize) -> usize { + match self { + KeccakCommit::LeavesOnly => lde_size * 32, + KeccakCommit::FullTree => (2 * lde_size - 1) * 32, + } + } + + fn leaves_offset_bytes(self, lde_size: usize) -> usize { + match self { + KeccakCommit::LeavesOnly => 0, + KeccakCommit::FullTree => (lde_size - 1) * 32, + } + } +} + /// De-interleave `columns` (each `3*n` u64s, ext3-per-element layout /// `[a, b, c, a, b, c, ...]`) into `pinned` as `3*m` base-field slabs. /// Component `k` of column `c` lands at `pinned[(c*3 + k)*n .. (c*3 + k)*n + n]`. @@ -127,6 +154,47 @@ fn launch_bit_reverse_batched( Ok(()) } +/// D2H `dst.len()` bytes from `dev_bytes` into the caller's pageable `dst` +/// via the pinned-hashes staging buffer. Synchronises the stream first (so +/// any other D2H queued on the same stream also drains), then does a rayon +/// chunked memcpy pinned → caller to spread page-fault cost across cores. +fn d2h_bytes_via_pinned_hashes( + stream: &Arc, + be: &Backend, + dev_bytes: &CudaSlice, + dst: &mut [u8], +) -> Result<()> { + let n_bytes = dst.len(); + let u64_len = n_bytes.div_ceil(8); + let staging_slot = be.pinned_hashes(); + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(u64_len, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(u64_len) }; + // Reinterpret the u64 pinned buffer as bytes — same allocation, just + // typed differently. SAFETY: u64 has stricter alignment than u8 and the + // byte length fits in the `u64_len` capacity (rounded up to u64). + let pinned_bytes: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(pinned.as_mut_ptr() as *mut u8, n_bytes) }; + stream.memcpy_dtoh(dev_bytes, pinned_bytes)?; + stream.synchronize()?; + + // Single-threaded `copy_from_slice` faults virgin pageable pages one at + // a time; the mm_struct rwsem serialises them at prover scale. Chunk so + // ~N cores pre-fault+write in parallel. + const CHUNK: usize = 64 * 1024; + let src_ptr = pinned_bytes.as_ptr() as usize; + dst.par_chunks_mut(CHUNK).enumerate().for_each(|(i, d)| { + // SAFETY: each task reads `[i*CHUNK .. i*CHUNK + d.len()]` of + // `pinned_bytes`, which is disjoint per `i` and lives until `staging` + // is dropped below. + let src = + unsafe { std::slice::from_raw_parts((src_ptr as *const u8).add(i * CHUNK), d.len()) }; + d.copy_from_slice(src); + }); + drop(staging); + Ok(()) +} + /// Run `pointwise_mul_batched`: `buf[c*col_stride + i] *= weights[i]` for /// `m` columns, `n` elements each. fn launch_pointwise_mul_batched( @@ -332,7 +400,15 @@ pub fn coset_lde_batch_base( let m_u32 = m as u32; // === 1. Bit-reverse first N of every column === - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, m_u32)?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; // === 2. iNTT body over all columns === run_batched_ntt_body( @@ -346,10 +422,26 @@ pub fn coset_lde_batch_base( )?; // === 3. Pointwise multiply by coset weights (includes 1/N) === - launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, m_u32)?; + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + m_u32, + )?; // === 4. Bit-reverse full LDE of every column === - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, m_u32)?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; // === 5. Forward NTT on full LDE of every column === run_batched_ntt_body( @@ -461,7 +553,15 @@ pub fn coset_lde_batch_base_into( let m_u32 = m as u32; // iNTT bit-reverse + body, pointwise mul, forward bit-reverse + body. - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, m_u32)?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -471,8 +571,24 @@ pub fn coset_lde_batch_base_into( col_stride_u64, m_u32, )?; - launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, m_u32)?; - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, m_u32)?; + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + m_u32, + )?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -507,6 +623,10 @@ pub fn coset_lde_batch_base_into( /// /// `hashed_leaves_out` must be `lde_size * 32` bytes (one 32-byte digest /// per output row, in natural row order). +/// Fused LDE + Keccak-256 leaf hashing. Caller receives the `lde_size * 32` +/// bytes of leaf hashes in `hashed_leaves_out`. Thin wrapper over +/// `coset_lde_batch_base_into_with_merkle_tree_inner` with `LeavesOnly` — +/// no inner-tree build, no device handle. pub fn coset_lde_batch_base_into_with_leaf_hash( columns: &[&[u64]], blowup_factor: usize, @@ -514,145 +634,16 @@ pub fn coset_lde_batch_base_into_with_leaf_hash( outputs: &mut [&mut [u64]], hashed_leaves_out: &mut [u8], ) -> Result<()> { - if columns.is_empty() { - assert_eq!(outputs.len(), 0); - return Ok(()); - } - let m = columns.len(); - assert_eq!(outputs.len(), m); - let n = columns[0].len(); - // (is_power_of_two returns false for 0). - if n == 0 { - return Ok(()); - } - assert!(n.is_power_of_two()); - assert_eq!(weights.len(), n); - assert!(blowup_factor.is_power_of_two()); - let lde_size = n * blowup_factor; - assert_u32_domain( - lde_size, - "coset_lde_batch_base_into_with_leaf_hash lde_size", - ); - for o in outputs.iter() { - assert_eq!(o.len(), lde_size); - } - assert_eq!(hashed_leaves_out.len(), lde_size * 32); - let log_n = n.trailing_zeros() as u64; - let log_lde = lde_size.trailing_zeros() as u64; - - let be = backend()?; - let stream = be.next_stream(); - let staging_slot = be.pinned_staging(); - - let mut staging = staging_slot.lock().unwrap(); - staging.ensure_capacity(m * lde_size, &be.ctx)?; - let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - - let pinned_base_ptr = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - // SAFETY: disjoint regions per c, outer staging lock held. - let dst = - unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; - dst.copy_from_slice(col); - }); - - let mut buf = stream.alloc_zeros::(m * lde_size)?; - for c in 0..m { - let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); - stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; - } - - let inv_tw = be.inv_twiddles_for(log_n)?; - let fwd_tw = be.fwd_twiddles_for(log_lde)?; - let weights_dev = stream.clone_htod(weights)?; - - let n_u64 = n as u64; - let lde_u64 = lde_size as u64; - let col_stride_u64 = lde_size as u64; - let m_u32 = m as u32; - - // iNTT - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, m_u32)?; - run_batched_ntt_body( - stream.as_ref(), - &mut buf, - inv_tw.as_ref(), - n_u64, - log_n, - col_stride_u64, - m_u32, - )?; - // pointwise coset scale - launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, m_u32)?; - // forward NTT on full LDE slab - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, m_u32)?; - run_batched_ntt_body( - stream.as_ref(), - &mut buf, - fwd_tw.as_ref(), - lde_u64, - log_lde, - col_stride_u64, - m_u32, - )?; - - // Keccak-256 leaf hashing directly on the device LDE buffer. - let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; - launch_keccak_base( - stream.as_ref(), - &buf, - col_stride_u64, - m as u64, - lde_u64, - &mut hashes_dev.as_view_mut(), - )?; - - // D2H the LDE into the pinned LDE staging, then the hashes into a - // dedicated pinned hash staging. The two copies run back-to-back on the - // same stream; both go at pinned PCIe line-rate — pageable D2H of the - // 128 MB hash buffer would otherwise cost ~100 ms per main-trace commit. - stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; - let hashes_u64_len = (lde_size * 32).div_ceil(8); - let hashes_staging_slot = be.pinned_hashes(); - let mut hashes_staging = hashes_staging_slot.lock().unwrap(); - hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; - let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; - // `memcpy_dtoh` needs a byte slice. Reinterpret the u64 pinned buffer - // as bytes — same allocation, just typed differently. - let hashes_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(hashes_pinned.as_mut_ptr() as *mut u8, lde_size * 32) - }; - stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; - stream.synchronize()?; - - // Copy pinned → caller outputs. Both D2H copies have already drained - // (the synchronize above); this is a rayon-parallel host memcpy from - // pinned to pageable memory, not concurrent with any GPU transfer. - let pinned_ptr = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) - }; - dst.copy_from_slice(src); - }); - // Rayon-parallel memcpy of 128 MB from pinned → caller. Single-threaded - // `copy_from_slice` faults virgin pageable pages one at a time; the - // mm_struct rwsem serialises them into ~100 ms at 1M-fib scale. Chunk - // the slice so ~N cores pre-fault+write in parallel. - const CHUNK: usize = 64 * 1024; // 64 KiB ≈ 16 pages per chunk - let pinned_hash_ptr = hashes_pinned_bytes.as_ptr() as usize; - hashed_leaves_out - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_hash_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(hashes_staging); - drop(staging); - Ok(()) + coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + hashed_leaves_out, + KeccakCommit::LeavesOnly, + false, + ) + .map(|_| ()) } /// Like `coset_lde_batch_base_into_with_leaf_hash`, but also builds the full @@ -676,6 +667,7 @@ pub fn coset_lde_batch_base_into_with_merkle_tree( weights, outputs, merkle_nodes_out, + KeccakCommit::FullTree, false, ) .map(|_| ()) @@ -697,6 +689,7 @@ pub fn coset_lde_batch_base_into_with_merkle_tree_keep( weights, outputs, merkle_nodes_out, + KeccakCommit::FullTree, true, )?; let handle = opt.expect("keep_device_buf=true must return Some"); @@ -708,7 +701,8 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( blowup_factor: usize, weights: &[u64], outputs: &mut [&mut [u64]], - merkle_nodes_out: &mut [u8], + nodes_out: &mut [u8], + commit: KeccakCommit, keep_device_buf: bool, ) -> Result> { if columns.is_empty() { @@ -733,8 +727,8 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( for o in outputs.iter() { assert_eq!(o.len(), lde_size); } - let total_nodes = 2 * lde_size - 1; - assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + let nodes_dev_bytes = commit.total_nodes_bytes(lde_size); + assert_eq!(nodes_out.len(), nodes_dev_bytes); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; @@ -769,7 +763,15 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( let m_u32 = m as u32; // iNTT - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, m_u32)?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -779,9 +781,25 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( col_stride_u64, m_u32, )?; - launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, m_u32)?; + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + m_u32, + )?; // forward NTT at LDE size - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, m_u32)?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -792,17 +810,14 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( m_u32, )?; - // Allocate the full node buffer; leaves occupy the tail slab, inner - // nodes are written by the pair-hash level kernel below. `alloc` (not - // `alloc_zeros`) is safe because every byte is written before it is - // read: leaf kernel fills the tail, tree kernel fills the head. - // - // The leaf kernel writes to `nodes_dev` starting at byte offset - // `(lde_size - 1) * 32`; we pass the base pointer as-is because the - // kernel indexes linearly from `hashed_leaves_out[row_idx * 32]`, so we - // build an offset device slice and feed that to the launch. - let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; - let leaves_offset_bytes = (lde_size - 1) * 32; + // Allocate the device output buffer. In `LeavesOnly` mode this is just + // `lde_size * 32` bytes (the leaves themselves); in `FullTree` mode it's + // `(2*lde_size - 1) * 32` bytes (leaves in the tail + inner nodes filled + // below). `alloc` (not `alloc_zeros`) is safe because every byte is + // written before any reader sees it: the keccak kernel fills the + // leaves slab, the inner-tree pass (when present) fills the head. + let mut nodes_dev = unsafe { stream.alloc::(nodes_dev_bytes) }?; + let leaves_offset_bytes = commit.leaves_offset_bytes(lde_size); { let mut leaves_view = nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); @@ -816,23 +831,15 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( )?; } - crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; + if commit == KeccakCommit::FullTree { + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; + } - // D2H the LDE and the tree nodes via pinned staging. + // D2H the LDE and the tree/leaves nodes via pinned staging. stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; - let tree_u64_len = (total_nodes * 32).div_ceil(8); - let tree_staging_slot = be.pinned_hashes(); - let mut tree_staging = tree_staging_slot.lock().unwrap(); - tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; - let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; - let tree_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, total_nodes * 32) - }; - stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; - stream.synchronize()?; - - // Parallel memcpy pinned → caller. + // Pinned LDE → caller outputs (post-sync host memcpy). let pinned_ptr = pinned.as_ptr() as usize; outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { let src = unsafe { @@ -840,18 +847,6 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( }; dst.copy_from_slice(src); }); - const CHUNK: usize = 64 * 1024; - let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; - merkle_nodes_out - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(tree_staging); drop(staging); if keep_device_buf { @@ -866,9 +861,9 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( } } -/// Ext3 variant of `coset_lde_batch_base_into_with_leaf_hash`: run an LDE -/// over ext3 columns AND emit Keccak-256 Merkle leaves, all in one on-device -/// pipeline. +/// Ext3 variant of `coset_lde_batch_base_into_with_leaf_hash`: fused +/// LDE + Keccak-256 leaf hashing over ext3 columns. Thin wrapper over +/// `coset_lde_batch_ext3_into_with_merkle_tree_inner` with `LeavesOnly`. pub fn coset_lde_batch_ext3_into_with_leaf_hash( columns: &[&[u64]], n: usize, @@ -877,124 +872,17 @@ pub fn coset_lde_batch_ext3_into_with_leaf_hash( outputs: &mut [&mut [u64]], hashed_leaves_out: &mut [u8], ) -> Result<()> { - if columns.is_empty() { - assert_eq!(outputs.len(), 0); - return Ok(()); - } - // (is_power_of_two returns false for 0). - if n == 0 { - return Ok(()); - } - let m = columns.len(); - assert_eq!(outputs.len(), m); - assert!(n.is_power_of_two()); - assert_eq!(weights.len(), n); - assert!(blowup_factor.is_power_of_two()); - for c in columns.iter() { - assert_eq!(c.len(), 3 * n); - } - let lde_size = n * blowup_factor; - assert_u32_domain( - lde_size, - "coset_lde_batch_ext3_into_with_leaf_hash lde_size", - ); - for o in outputs.iter() { - assert_eq!(o.len(), 3 * lde_size); - } - assert_eq!(hashed_leaves_out.len(), lde_size * 32); - let log_n = n.trailing_zeros() as u64; - let log_lde = lde_size.trailing_zeros() as u64; - - let mb = 3 * m; - let be = backend()?; - let stream = be.next_stream(); - let staging_slot = be.pinned_staging(); - - let mut staging = staging_slot.lock().unwrap(); - staging.ensure_capacity(mb * lde_size, &be.ctx)?; - let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - - pack_ext3_to_pinned_slabs(columns, pinned, n); - - let mut buf = stream.alloc_zeros::(mb * lde_size)?; - for s in 0..mb { - let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); - stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; - } - - let inv_tw = be.inv_twiddles_for(log_n)?; - let fwd_tw = be.fwd_twiddles_for(log_lde)?; - let weights_dev = stream.clone_htod(weights)?; - - let n_u64 = n as u64; - let lde_u64 = lde_size as u64; - let col_stride_u64 = lde_size as u64; - let mb_u32 = mb as u32; - - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, mb_u32)?; - run_batched_ntt_body( - stream.as_ref(), - &mut buf, - inv_tw.as_ref(), - n_u64, - log_n, - col_stride_u64, - mb_u32, - )?; - launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; - run_batched_ntt_body( - stream.as_ref(), - &mut buf, - fwd_tw.as_ref(), - lde_u64, - log_lde, - col_stride_u64, - mb_u32, - )?; - - // Keccak-256 on the de-interleaved device buffer (3M base slabs). - let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; - launch_keccak_ext3( - stream.as_ref(), - &buf, - col_stride_u64, - m as u64, - lde_u64, - &mut hashes_dev.as_view_mut(), - )?; - - // D2H LDE (mb * lde_size u64) and hashes (lde_size * 32 bytes). - stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - let hashes_u64_len = (lde_size * 32).div_ceil(8); - let hashes_staging_slot = be.pinned_hashes(); - let mut hashes_staging = hashes_staging_slot.lock().unwrap(); - hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; - let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; - let hashes_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(hashes_pinned.as_mut_ptr() as *mut u8, lde_size * 32) - }; - stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; - stream.synchronize()?; - - // Re-interleave pinned → caller ext3 outputs, parallel. - unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); - - // Parallel memcpy of pinned hashes → caller. - const CHUNK: usize = 64 * 1024; - let hash_src_ptr = hashes_pinned_bytes.as_ptr() as usize; - hashed_leaves_out - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((hash_src_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(hashes_staging); - drop(staging); - Ok(()) + coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + hashed_leaves_out, + KeccakCommit::LeavesOnly, + false, + ) + .map(|_| ()) } /// Ext3 variant of the fused `coset_lde_batch_base_into_with_merkle_tree`. @@ -1015,6 +903,7 @@ pub fn coset_lde_batch_ext3_into_with_merkle_tree( weights, outputs, merkle_nodes_out, + KeccakCommit::FullTree, false, ) .map(|_| ()) @@ -1038,18 +927,21 @@ pub fn coset_lde_batch_ext3_into_with_merkle_tree_keep( weights, outputs, merkle_nodes_out, + KeccakCommit::FullTree, true, )?; Ok(opt.expect("keep_device_buf=true must return Some")) } +#[allow(clippy::too_many_arguments)] fn coset_lde_batch_ext3_into_with_merkle_tree_inner( columns: &[&[u64]], n: usize, blowup_factor: usize, weights: &[u64], outputs: &mut [&mut [u64]], - merkle_nodes_out: &mut [u8], + nodes_out: &mut [u8], + commit: KeccakCommit, keep_device_buf: bool, ) -> Result> { if columns.is_empty() { @@ -1076,8 +968,8 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( for o in outputs.iter() { assert_eq!(o.len(), 3 * lde_size); } - let total_nodes = 2 * lde_size - 1; - assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + let nodes_dev_bytes = commit.total_nodes_bytes(lde_size); + assert_eq!(nodes_out.len(), nodes_dev_bytes); let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; @@ -1107,7 +999,15 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( let col_stride_u64 = lde_size as u64; let mb_u32 = mb as u32; - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, mb_u32)?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1117,8 +1017,24 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( col_stride_u64, mb_u32, )?; - launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + mb_u32, + )?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1129,9 +1045,11 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( mb_u32, )?; - // Allocate full tree buffer; leaf kernel writes to the tail slab. - let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; - let leaves_offset_bytes = (lde_size - 1) * 32; + // Allocate device output buffer (LeavesOnly → lde_size*32; FullTree → + // (2*lde_size - 1)*32). Leaf kernel writes to the leaves slab; the + // inner-tree pass (when present) fills the head. + let mut nodes_dev = unsafe { stream.alloc::(nodes_dev_bytes) }?; + let leaves_offset_bytes = commit.leaves_offset_bytes(lde_size); { let mut leaves_view = nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); @@ -1145,36 +1063,15 @@ fn coset_lde_batch_ext3_into_with_merkle_tree_inner( )?; } - crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; + if commit == KeccakCommit::FullTree { + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; + } - // D2H LDE (mb * lde_size u64) and tree nodes. + // D2H LDE (mb * lde_size u64) and tree/leaves nodes. stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - let tree_u64_len = (total_nodes * 32).div_ceil(8); - let tree_staging_slot = be.pinned_hashes(); - let mut tree_staging = tree_staging_slot.lock().unwrap(); - tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; - let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; - let tree_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, total_nodes * 32) - }; - stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; - stream.synchronize()?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; - // Re-interleave pinned → caller ext3 outputs. unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); - - const CHUNK: usize = 64 * 1024; - let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; - merkle_nodes_out - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(tree_staging); drop(staging); if keep_device_buf { @@ -1282,10 +1179,26 @@ fn evaluate_poly_coset_batch_ext3_into_inner( let mb_u32 = mb as u32; // Apply coset scaling: x[k] *= weights[k] for k in 0..n (no iFFT first). - launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + mb_u32, + )?; // Bit-reverse full lde_size slab, then forward DIT NTT. - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1381,8 +1294,24 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( let col_stride_u64 = lde_size as u64; let mb_u32 = mb as u32; - launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + mb_u32, + )?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1427,35 +1356,11 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; // D2H LDE and tree. + debug_assert_eq!(merkle_nodes_out.len(), tight_total_nodes * 32); stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - let tree_u64_len = (tight_total_nodes * 32).div_ceil(8); - let tree_staging_slot = be.pinned_hashes(); - let mut tree_staging = tree_staging_slot.lock().unwrap(); - tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; - let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; - let tree_pinned_bytes: &mut [u8] = unsafe { - std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, tight_total_nodes * 32) - }; - stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; - stream.synchronize()?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, merkle_nodes_out)?; - // Re-interleave pinned → caller ext3 outputs. unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); - - // Copy pinned tree → caller nodes_out. - debug_assert_eq!(merkle_nodes_out.len(), tight_total_nodes * 32); - const CHUNK: usize = 64 * 1024; - let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; - merkle_nodes_out - .par_chunks_mut(CHUNK) - .enumerate() - .for_each(|(i, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_tree_ptr as *const u8).add(i * CHUNK), dst.len()) - }; - dst.copy_from_slice(src); - }); - drop(tree_staging); drop(staging); Ok(()) } @@ -1545,7 +1450,15 @@ pub fn coset_lde_batch_ext3_into( // === Butterflies: identical to the base-field batched path, but with // grid.y = 3M instead of M. === - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, n_u64, log_n, col_stride_u64, mb_u32)?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -1555,8 +1468,24 @@ pub fn coset_lde_batch_ext3_into( col_stride_u64, mb_u32, )?; - launch_pointwise_mul_batched(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, col_stride_u64, mb_u32)?; - launch_bit_reverse_batched(stream.as_ref(), be, &mut buf, lde_u64, log_lde, col_stride_u64, mb_u32)?; + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + mb_u32, + )?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index 8bb0f03e9..4b73d89fb 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -313,29 +313,7 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res } } - // Inner tree. - { - let mut level_begin: u64 = (num_leaves - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } - } + build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; let out = stream.clone_dtoh(&nodes_dev)?; stream.synchronize()?; @@ -386,29 +364,7 @@ pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { } } - // Inner tree levels, identical to the R2 version. - { - let mut level_begin: u64 = (num_leaves - 1) as u64; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - let grid = (n_pairs as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_merkle_level) - .arg(&mut nodes_dev) - .arg(&new_begin) - .arg(&n_pairs) - .launch(cfg)?; - } - level_begin = new_begin; - } - } + build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; let out = stream.clone_dtoh(&nodes_dev)?; stream.synchronize()?; From 55a75f126911b5861dbe5afb674d98278b10de09 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 18 May 2026 15:12:29 -0300 Subject: [PATCH 17/22] refactor tests --- Cargo.lock | 2 + crypto/crypto/src/merkle_tree/merkle.rs | 8 + crypto/math-cuda/Cargo.toml | 2 + crypto/math-cuda/tests/keccak_leaves.rs | 208 +++++++----------------- crypto/math-cuda/tests/merkle_tree.rs | 50 ++---- crypto/stark/src/prover.rs | 172 ++++++++++++-------- 6 files changed, 190 insertions(+), 252 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 001b4b841..0f01bf090 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2129,12 +2129,14 @@ dependencies = [ name = "math-cuda" version = "0.1.0" dependencies = [ + "crypto", "cudarc", "math", "rand 0.8.5", "rand_chacha 0.3.1", "rayon", "sha3", + "stark", ] [[package]] diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index 55fa49a83..4f5fe28d3 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -81,6 +81,14 @@ where }) } + /// Read-only access to the full node buffer in standard layout: + /// `nodes[0..leaves_len - 1]` are inner nodes (root at index 0) and + /// `nodes[leaves_len - 1..]` are the leaves. Useful for parity tests + /// against alternative tree builders. + pub fn nodes(&self) -> &[B::Node] { + &self.nodes + } + /// Returns a Merkle proof for the element/s at position pos /// For example, give me an inclusion proof for the 3rd element in the /// Merkle tree diff --git a/crypto/math-cuda/Cargo.toml b/crypto/math-cuda/Cargo.toml index 0990dd6d6..e700ec73e 100644 --- a/crypto/math-cuda/Cargo.toml +++ b/crypto/math-cuda/Cargo.toml @@ -17,7 +17,9 @@ math = { path = "../math" } rayon = "1.7" [dev-dependencies] +crypto = { path = "../crypto" } rand = { version = "0.8.5", features = ["std"] } rand_chacha = "0.3.1" rayon = "1.7" sha3 = "0.10.8" +stark = { path = "../stark" } diff --git a/crypto/math-cuda/tests/keccak_leaves.rs b/crypto/math-cuda/tests/keccak_leaves.rs index 33610aeca..6b54003c2 100644 --- a/crypto/math-cuda/tests/keccak_leaves.rs +++ b/crypto/math-cuda/tests/keccak_leaves.rs @@ -1,63 +1,23 @@ -//! Parity: GPU Keccak-256 leaf hashes must match CPU -//! `FieldElementVectorBackend::::hash_data` applied to -//! bit-reversed rows (same pattern as `commit_columns_bit_reversed` in the -//! stark prover). - +//! Parity: GPU Keccak-256 leaf hashes must match the CPU prover's leaf +//! hashing helpers — `stark::prover::keccak_leaves_bit_reversed` for +//! per-row commits, `keccak_leaves_row_pair_bit_reversed` for the R2 +//! composition commit, and `FieldElementPairBackend::hash_data` for the FRI +//! commit. These are the same helpers the prover itself calls so any +//! change to the CPU leaf-hash contract surfaces here. + +use crypto::merkle_tree::backends::field_element_vector::FieldElementPairBackend; +use crypto::merkle_tree::traits::IsMerkleTreeBackend; use math::field::element::FieldElement; use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; use math::field::goldilocks::GoldilocksField; -use math::traits::ByteConversion; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use sha3::{Digest, Keccak256}; +use sha3::Keccak256; +use stark::prover::{keccak_leaves_bit_reversed, keccak_leaves_row_pair_bit_reversed}; type Fp = FieldElement; type Fp3 = FieldElement; - -fn reverse_index(i: u64, n: u64) -> u64 { - let log_n = n.trailing_zeros(); - i.reverse_bits() >> (64 - log_n) -} - -fn cpu_leaves_base(columns: &[Vec]) -> Vec<[u8; 32]> { - let num_rows = columns[0].len(); - let num_cols = columns.len(); - let byte_len = 8; - (0..num_rows) - .map(|row_idx| { - let br = reverse_index(row_idx as u64, num_rows as u64) as usize; - let mut buf = vec![0u8; num_cols * byte_len]; - for c in 0..num_cols { - columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); - } - let mut h = Keccak256::new(); - h.update(&buf); - let mut out = [0u8; 32]; - out.copy_from_slice(&h.finalize()); - out - }) - .collect() -} - -fn cpu_leaves_ext3(columns: &[Vec]) -> Vec<[u8; 32]> { - let num_rows = columns[0].len(); - let num_cols = columns.len(); - let byte_len = 24; - (0..num_rows) - .map(|row_idx| { - let br = reverse_index(row_idx as u64, num_rows as u64) as usize; - let mut buf = vec![0u8; num_cols * byte_len]; - for c in 0..num_cols { - columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); - } - let mut h = Keccak256::new(); - h.update(&buf); - let mut out = [0u8; 32]; - out.copy_from_slice(&h.finalize()); - out - }) - .collect() -} +type FriPairBackend = FieldElementPairBackend; #[test] fn keccak_leaves_base_matches_cpu() { @@ -69,7 +29,7 @@ fn keccak_leaves_base_matches_cpu() { .map(|_| (0..n).map(|_| Fp::from_raw(rng.r#gen::())).collect()) .collect(); - let cpu = cpu_leaves_base(&columns); + let cpu = keccak_leaves_bit_reversed(&columns); // Flatten columns into a contiguous base slab layout matching // `coset_lde_batch_base_into`'s pinned staging format: @@ -93,52 +53,50 @@ fn keccak_leaves_base_matches_cpu() { } } -// Row-pair leaves for the R2 composition-polynomial commit. For each leaf i: -// br_0 = bit_reverse(2*i, log_lde), br_1 = bit_reverse(2*i+1, log_lde) -// hash is Keccak256 of BE bytes of every part's ext3 value at br_0 then br_1 -// (matching `commit_composition_polynomial` on the CPU side). -fn cpu_leaves_comp_poly(parts: &[Vec]) -> Vec<[u8; 32]> { - let lde_size = parts[0].len(); - let num_parts = parts.len(); - let num_leaves = lde_size / 2; - let byte_len = 24; - (0..num_leaves) - .map(|i| { - let br_0 = reverse_index((2 * i) as u64, lde_size as u64) as usize; - let br_1 = reverse_index((2 * i + 1) as u64, lde_size as u64) as usize; - let mut buf = vec![0u8; 2 * num_parts * byte_len]; - for (p, part) in parts.iter().enumerate() { - part[br_0].write_bytes_be(&mut buf[p * byte_len..(p + 1) * byte_len]); +#[test] +fn keccak_leaves_ext3_matches_cpu() { + for log_n in [4u32, 6, 8, 10] { + for num_cols in [1usize, 3, 11, 20] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(200 + log_n as u64 + num_cols as u64); + let columns: Vec> = (0..num_cols) + .map(|_| { + (0..n) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() + }) + .collect(); + + let cpu = keccak_leaves_bit_reversed(&columns); + + // GPU expects 3 base slabs per ext3 column in the order + // [col*3+0 (comp a), col*3+1 (comp b), col*3+2 (comp c)], each a + // contiguous slab of n u64s (length = num_cols * 3 * n). + let mut flat = vec![0u64; num_cols * 3 * n]; + for (c, col) in columns.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + flat[(c * 3) * n + r] = *e.value()[0].value(); + flat[(c * 3 + 1) * n + r] = *e.value()[1].value(); + flat[(c * 3 + 2) * n + r] = *e.value()[2].value(); + } } - let off = num_parts * byte_len; - for (p, part) in parts.iter().enumerate() { - part[br_1].write_bytes_be(&mut buf[off + p * byte_len..off + (p + 1) * byte_len]); + let gpu = math_cuda::merkle::keccak_leaves_ext3(&flat, n, num_cols, n).unwrap(); + assert_eq!(gpu.len(), n * 32); + for i in 0..n { + assert_eq!( + &gpu[i * 32..(i + 1) * 32], + &cpu[i][..], + "ext3 leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" + ); } - let mut h = Keccak256::new(); - h.update(&buf); - let mut out = [0u8; 32]; - out.copy_from_slice(&h.finalize()); - out - }) - .collect() -} - -// FRI leaves: each leaf hashes 2 consecutive ext3 evals, no bit reversal. -fn cpu_leaves_fri(evals: &[Fp3]) -> Vec<[u8; 32]> { - let num_leaves = evals.len() / 2; - let byte_len = 24; - (0..num_leaves) - .map(|i| { - let mut buf = vec![0u8; 2 * byte_len]; - evals[2 * i].write_bytes_be(&mut buf[..byte_len]); - evals[2 * i + 1].write_bytes_be(&mut buf[byte_len..]); - let mut h = Keccak256::new(); - h.update(&buf); - let mut out = [0u8; 32]; - out.copy_from_slice(&h.finalize()); - out - }) - .collect() + } + } } #[test] @@ -162,7 +120,7 @@ fn keccak_comp_poly_leaves_matches_cpu() { .collect() }) .collect(); - let cpu = cpu_leaves_comp_poly(&parts); + let cpu = keccak_leaves_row_pair_bit_reversed(&parts); // Each part is passed as `[a0,a1,a2, b0,b1,b2, ...]` of length `3 * lde_size`. let parts_interleaved: Vec> = parts @@ -209,7 +167,13 @@ fn keccak_fri_leaves_matches_cpu() { ]) }) .collect(); - let cpu = cpu_leaves_fri(&evals); + + // CPU reference: consecutive ext3 pairs hashed via the prover's + // `FieldElementPairBackend::hash_data`. + let cpu: Vec<[u8; 32]> = evals + .chunks_exact(2) + .map(|c| FriPairBackend::hash_data(&[c[0], c[1]])) + .collect(); let mut evals_interleaved = vec![0u64; 3 * lde_size]; for (i, e) in evals.iter().enumerate() { @@ -230,49 +194,3 @@ fn keccak_fri_leaves_matches_cpu() { } } } - -#[test] -fn keccak_leaves_ext3_matches_cpu() { - for log_n in [4u32, 6, 8, 10] { - for num_cols in [1usize, 3, 11, 20] { - let n = 1 << log_n; - let mut rng = ChaCha8Rng::seed_from_u64(200 + log_n as u64 + num_cols as u64); - let columns: Vec> = (0..num_cols) - .map(|_| { - (0..n) - .map(|_| { - Fp3::new([ - Fp::from_raw(rng.r#gen::()), - Fp::from_raw(rng.r#gen::()), - Fp::from_raw(rng.r#gen::()), - ]) - }) - .collect() - }) - .collect(); - - let cpu = cpu_leaves_ext3(&columns); - - // GPU expects 3 base slabs per ext3 column in the order - // [col*3+0 (comp a), col*3+1 (comp b), col*3+2 (comp c)], each a - // contiguous slab of n u64s (length = num_cols * 3 * n). - let mut flat = vec![0u64; num_cols * 3 * n]; - for (c, col) in columns.iter().enumerate() { - for (r, e) in col.iter().enumerate() { - flat[(c * 3) * n + r] = *e.value()[0].value(); - flat[(c * 3 + 1) * n + r] = *e.value()[1].value(); - flat[(c * 3 + 2) * n + r] = *e.value()[2].value(); - } - } - let gpu = math_cuda::merkle::keccak_leaves_ext3(&flat, n, num_cols, n).unwrap(); - assert_eq!(gpu.len(), n * 32); - for i in 0..n { - assert_eq!( - &gpu[i * 32..(i + 1) * 32], - &cpu[i][..], - "ext3 leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" - ); - } - } - } -} diff --git a/crypto/math-cuda/tests/merkle_tree.rs b/crypto/math-cuda/tests/merkle_tree.rs index 34d44c767..e5569379d 100644 --- a/crypto/math-cuda/tests/merkle_tree.rs +++ b/crypto/math-cuda/tests/merkle_tree.rs @@ -1,44 +1,17 @@ //! Parity: GPU Merkle inner-tree construction must match the CPU //! `crypto/crypto/src/merkle_tree/merkle.rs` `build_from_hashed_leaves` -//! (Keccak-256 pair hash at each level). +//! (Keccak-256 pair hash at each level). Uses the prover's +//! `FieldElementVectorBackend<_, Keccak256, 32>` directly so any change to +//! the CPU tree builder is automatically exercised here. +use crypto::merkle_tree::backends::field_element_vector::FieldElementVectorBackend; +use crypto::merkle_tree::merkle::MerkleTree; +use math::field::goldilocks::GoldilocksField; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use sha3::{Digest, Keccak256}; +use sha3::Keccak256; -fn cpu_hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { - let mut h = Keccak256::new(); - h.update(left); - h.update(right); - let mut out = [0u8; 32]; - out.copy_from_slice(&h.finalize()); - out -} - -/// CPU reference: same algorithm as `build_from_hashed_leaves`. -fn cpu_merkle_nodes(leaves: &[[u8; 32]]) -> Vec<[u8; 32]> { - let leaves_len = leaves.len(); - assert!(leaves_len.is_power_of_two() && leaves_len >= 2); - let total = 2 * leaves_len - 1; - - let mut nodes: Vec<[u8; 32]> = vec![[0u8; 32]; total]; - for (i, leaf) in leaves.iter().enumerate() { - nodes[leaves_len - 1 + i] = *leaf; - } - - let mut level_begin = leaves_len - 1; - while level_begin != 0 { - let new_begin = level_begin / 2; - let n_pairs = level_begin - new_begin; - for j in 0..n_pairs { - let left = nodes[level_begin + 2 * j]; - let right = nodes[level_begin + 2 * j + 1]; - nodes[new_begin + j] = cpu_hash_pair(&left, &right); - } - level_begin = new_begin; - } - nodes -} +type CpuTree = MerkleTree>; fn run_parity(log_n: u32, seed: u64) { let leaves_len = 1usize << log_n; @@ -60,11 +33,12 @@ fn run_parity(log_n: u32, seed: u64) { let gpu_nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(&flat).unwrap(); assert_eq!(gpu_nodes_bytes.len(), (2 * leaves_len - 1) * 32); - let cpu_nodes = cpu_merkle_nodes(&leaves); + // CPU reference: the prover's MerkleTree builder over the same backend. + let cpu_tree = CpuTree::build_from_hashed_leaves(leaves).unwrap(); + let cpu_nodes = cpu_tree.nodes(); - for i in 0..cpu_nodes.len() { + for (i, c) in cpu_nodes.iter().enumerate() { let g = &gpu_nodes_bytes[i * 32..(i + 1) * 32]; - let c = &cpu_nodes[i]; assert_eq!( g, c, "node {i} mismatch at log_n={log_n} (cpu={c:?}, gpu={g:?})" diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index a5386017a..ef7935325 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -374,6 +374,106 @@ where } } +/// Compute Keccak-256 leaf hashes for `commit_columns_bit_reversed`: one +/// leaf per row, where each row is read at `reverse_index(row_idx)` and the +/// columns are concatenated as big-endian bytes before hashing. +/// +/// Returns `Vec` with the same length as `columns[0]`. Exposed +/// (instead of being a closure inside `commit_columns_bit_reversed`) so +/// parity tests in dependent crates can compare against the same code path +/// the prover uses. +pub fn keccak_leaves_bit_reversed(columns: &[Vec>]) -> Vec +where + E: IsField, + FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, +{ + use math::traits::ByteConversion; + + if columns.is_empty() || columns[0].is_empty() { + return Vec::new(); + } + + let num_rows = columns[0].len(); + let num_cols = columns.len(); + let byte_len = as ByteConversion>::BYTE_LEN; + + debug_assert!( + num_rows.is_power_of_two(), + "num_rows must be a power of two for reverse_index" + ); + + #[cfg(feature = "parallel")] + let iter = (0..num_rows).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = 0..num_rows; + + iter.map(|row_idx| { + let br_idx = reverse_index(row_idx, num_rows as u64); + let total_bytes = num_cols * byte_len; + let mut buf = vec![0u8; total_bytes]; + for col_idx in 0..num_cols { + columns[col_idx][br_idx] + .write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) + .collect() +} + +/// Compute Keccak-256 leaf hashes for `commit_composition_polynomial`: one +/// leaf per row-pair, where leaf `i` hashes the BE concatenation of +/// `parts[..][br_0] ++ parts[..][br_1]` with +/// `br_k = reverse_index(2*i + k, num_rows)`. +/// +/// Returns `Vec` of length `parts[0].len() / 2`. +pub fn keccak_leaves_row_pair_bit_reversed(parts: &[Vec>]) -> Vec +where + E: IsField, + FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, +{ + use math::traits::ByteConversion; + + let num_parts = parts.len(); + if num_parts == 0 { + return Vec::new(); + } + let num_rows = parts[0].len(); + if num_rows == 0 { + return Vec::new(); + } + + let num_leaves = num_rows / 2; + debug_assert!( + num_rows.is_power_of_two(), + "num_rows must be a power of two for reverse_index" + ); + + let byte_len = as ByteConversion>::BYTE_LEN; + + #[cfg(feature = "parallel")] + let iter = (0..num_leaves).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = 0..num_leaves; + + iter.map(|leaf_idx| { + let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); + let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); + let total_bytes = 2 * num_parts * byte_len; + let mut buf = vec![0u8; total_bytes]; + let mut offset = 0; + for part in parts.iter() { + part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + for part in parts.iter() { + part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) + .collect() +} + /// The functionality of a STARK prover providing methods to run the STARK Prove protocol /// https://lambdaclass.github.io/lambdaworks/starks/protocol.html /// The default implementation is complete and is compatible with Stone prover @@ -400,41 +500,10 @@ pub trait IsStarkProver< FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, E: IsField, { - use math::traits::ByteConversion; - if columns.is_empty() || columns[0].is_empty() { return None; } - - let num_rows = columns[0].len(); - let num_cols = columns.len(); - let byte_len = as ByteConversion>::BYTE_LEN; - - debug_assert!( - num_rows.is_power_of_two(), - "num_rows must be a power of two for reverse_index" - ); - - #[cfg(feature = "parallel")] - let iter = (0..num_rows).into_par_iter(); - #[cfg(not(feature = "parallel"))] - let iter = 0..num_rows; - - // One allocation per row (was one per field element): write all columns - // into a single buffer, then hash once. - let hashed_leaves: Vec = iter - .map(|row_idx| { - let br_idx = reverse_index(row_idx, num_rows as u64); - let total_bytes = num_cols * byte_len; - let mut buf = vec![0u8; total_bytes]; - for col_idx in 0..num_cols { - columns[col_idx][br_idx] - .write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) - .collect(); - + let hashed_leaves = keccak_leaves_bit_reversed(columns); let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) @@ -723,8 +792,6 @@ pub trait IsStarkProver< FieldElement: AsBytes + Sync + Send, FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, { - use math::traits::ByteConversion; - let num_parts = lde_composition_poly_parts_evaluations.len(); if num_parts == 0 { return None; @@ -733,41 +800,8 @@ pub trait IsStarkProver< if num_rows == 0 { return None; } - - let num_leaves = num_rows / 2; - debug_assert!( - num_rows.is_power_of_two(), - "num_rows must be a power of two for reverse_index" - ); - - let byte_len = as ByteConversion>::BYTE_LEN; - - // One allocation per leaf (was one per field element): write all parts - // into a single buffer. Each leaf = row_pair[2*i] ++ row_pair[2*i+1] after bit-reverse. - #[cfg(feature = "parallel")] - let iter = (0..num_leaves).into_par_iter(); - #[cfg(not(feature = "parallel"))] - let iter = 0..num_leaves; - - let hashed_leaves: Vec = iter - .map(|leaf_idx| { - let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); - let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); - let total_bytes = 2 * num_parts * byte_len; - let mut buf = vec![0u8; total_bytes]; - let mut offset = 0; - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) - .collect(); - + let hashed_leaves = + keccak_leaves_row_pair_bit_reversed(lde_composition_poly_parts_evaluations); let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) From b32202bd113e0c7c2e56056ff7c09938ba4f17ab Mon Sep 17 00:00:00 2001 From: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> Date: Mon, 18 May 2026 15:58:07 -0300 Subject: [PATCH 18/22] Update crypto/crypto/src/merkle_tree/merkle.rs Co-authored-by: Gabriel Bosio <38794644+gabrielbosio@users.noreply.github.com> --- crypto/crypto/src/merkle_tree/merkle.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index 4f5fe28d3..b702a846e 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -83,8 +83,7 @@ where /// Read-only access to the full node buffer in standard layout: /// `nodes[0..leaves_len - 1]` are inner nodes (root at index 0) and - /// `nodes[leaves_len - 1..]` are the leaves. Useful for parity tests - /// against alternative tree builders. + /// `nodes[leaves_len - 1..]` are the leaves. pub fn nodes(&self) -> &[B::Node] { &self.nodes } From 302fd297d8fb30b3532b82297244c9db7224a867 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 18 May 2026 17:00:43 -0300 Subject: [PATCH 19/22] drop redundant synchronize --- crypto/math-cuda/src/merkle.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index 4b73d89fb..c1f17ed2c 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -55,7 +55,6 @@ pub fn keccak_leaves_base( &mut out_dev.as_view_mut(), )?; let out = stream.clone_dtoh(&out_dev)?; - stream.synchronize()?; Ok(out) } @@ -90,7 +89,6 @@ pub fn keccak_leaves_ext3( &mut out_dev.as_view_mut(), )?; let out = stream.clone_dtoh(&out_dev)?; - stream.synchronize()?; Ok(out) } @@ -213,7 +211,6 @@ pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, leaves_len)?; let out = stream.clone_dtoh(&nodes_dev)?; - stream.synchronize()?; Ok(out) } @@ -316,7 +313,6 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; let out = stream.clone_dtoh(&nodes_dev)?; - stream.synchronize()?; drop(staging); Ok(out) } @@ -367,7 +363,6 @@ pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; let out = stream.clone_dtoh(&nodes_dev)?; - stream.synchronize()?; Ok(out) } From ca36b83cff68e2995b1907e11e8c2e3ffa88ef3a Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 18 May 2026 17:21:23 -0300 Subject: [PATCH 20/22] refactor --- crypto/math-cuda/src/lde.rs | 220 ++++++++++++--------------------- crypto/math-cuda/src/merkle.rs | 48 +------ 2 files changed, 84 insertions(+), 184 deletions(-) diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index af5e0a384..83244152a 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -17,7 +17,7 @@ use rayon::prelude::*; use crate::Result; use crate::device::{Backend, backend}; -use crate::merkle::{launch_keccak_base, launch_keccak_ext3}; +use crate::merkle::{keccak_launch_cfg, launch_keccak_base, launch_keccak_ext3}; use crate::ntt::run_ntt_body; /// Goldilocks `TWO_ADICITY = 32` puts the theoretical domain ceiling at @@ -65,7 +65,7 @@ impl KeccakCommit { /// /// Caller invariants: `pinned.len() >= 3 * columns.len() * n` and each /// `columns[c].len() >= 3 * n`. The caller must hold the pinned-staging lock. -fn pack_ext3_to_pinned_slabs(columns: &[&[u64]], pinned: &mut [u64], n: usize) { +pub(crate) fn pack_ext3_to_pinned_slabs(columns: &[&[u64]], pinned: &mut [u64], n: usize) { let m = columns.len(); debug_assert!(pinned.len() >= 3 * m * n); let pinned_ptr_u = pinned.as_mut_ptr() as usize; @@ -604,7 +604,6 @@ pub fn coset_lde_batch_base_into( // Parallel copy pinned → caller outputs. Caller's Vecs may still fault // on first write; we spread that cost across rayon cores. - #[allow(unused_imports)] let pinned_ptr = pinned.as_ptr() as usize; outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { let src = unsafe { @@ -616,17 +615,12 @@ pub fn coset_lde_batch_base_into( Ok(()) } -/// Variant of `coset_lde_batch_base_into` that also emits the Keccak-256 -/// Merkle leaf hashes from the LDE output. All on GPU, the device LDE buffer -/// is hashed in place. Leaves are computed reading columns at bit-reversed -/// rows (matching `commit_columns_bit_reversed` on the CPU side). -/// -/// `hashed_leaves_out` must be `lde_size * 32` bytes (one 32-byte digest -/// per output row, in natural row order). /// Fused LDE + Keccak-256 leaf hashing. Caller receives the `lde_size * 32` -/// bytes of leaf hashes in `hashed_leaves_out`. Thin wrapper over -/// `coset_lde_batch_base_into_with_merkle_tree_inner` with `LeavesOnly` — -/// no inner-tree build, no device handle. +/// bytes of leaf hashes in `hashed_leaves_out` (one 32-byte digest per output +/// row, in natural row order; leaves are computed reading columns at +/// bit-reversed rows, matching `commit_columns_bit_reversed` on the CPU +/// side). Thin wrapper over `coset_lde_batch_base_into_with_merkle_tree_inner` +/// with `LeavesOnly` — no inner-tree build, no device handle. pub fn coset_lde_batch_base_into_with_leaf_hash( columns: &[&[u64]], blowup_factor: usize, @@ -1102,8 +1096,16 @@ pub fn evaluate_poly_coset_batch_ext3_into( weights: &[u64], outputs: &mut [&mut [u64]], ) -> Result<()> { - evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, false) - .map(|_| ()) + evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + None, + false, + ) + .map(|_| ()) } /// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- @@ -1116,8 +1118,15 @@ pub fn evaluate_poly_coset_batch_ext3_into_keep( weights: &[u64], outputs: &mut [&mut [u64]], ) -> Result { - let opt = - evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, true)?; + let opt = evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + None, + true, + )?; Ok(opt.expect("keep_device_buf=true must return Some")) } @@ -1127,6 +1136,7 @@ fn evaluate_poly_coset_batch_ext3_into_inner( blowup_factor: usize, weights: &[u64], outputs: &mut [&mut [u64]], + merkle_nodes_out: Option<&mut [u8]>, keep_device_buf: bool, ) -> Result> { if coefs.is_empty() { @@ -1151,6 +1161,9 @@ fn evaluate_poly_coset_batch_ext3_into_inner( assert_eq!(o.len(), 3 * lde_size); } assert_u32_domain(lde_size, "evaluate_poly_coset_batch_ext3_into lde_size"); + if merkle_nodes_out.is_some() { + assert!(lde_size >= 2); + } let log_lde = lde_size.trailing_zeros() as u64; let mb = 3 * m; @@ -1209,8 +1222,39 @@ fn evaluate_poly_coset_batch_ext3_into_inner( mb_u32, )?; - stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - stream.synchronize()?; + // Optional R2-style row-pair Merkle tree build on the LDE buffer. + if let Some(nodes_out) = merkle_nodes_out { + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + assert_eq!(nodes_out.len(), tight_total_nodes * 32); + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let log_num_rows = log_lde; + let num_parts_u64 = m as u64; + let cfg = keccak_launch_cfg(num_leaves as u64); + unsafe { + stream + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&lde_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; + + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; + } else { + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + stream.synchronize()?; + } unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); drop(staging); @@ -1233,6 +1277,12 @@ fn evaluate_poly_coset_batch_ext3_into_inner( /// Row-pair commit: each leaf hashes 2 rows, so the tree has `lde_size / 2` /// leaves and `merkle_nodes_out` must have byte length `(lde_size - 1) * 32`. /// Requires `lde_size >= 2`. +/// Variant of [`evaluate_poly_coset_batch_ext3_into`] that also builds the +/// R2 composition-polynomial Merkle tree on device. +/// +/// Row-pair commit: each leaf hashes 2 bit-reversed rows, so the tree has +/// `lde_size / 2` leaves and `merkle_nodes_out` must have byte length +/// `(lde_size - 1) * 32`. Requires `lde_size >= 2`. pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( coefs: &[&[u64]], n: usize, @@ -1241,128 +1291,16 @@ pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( outputs: &mut [&mut [u64]], merkle_nodes_out: &mut [u8], ) -> Result<()> { - if coefs.is_empty() { - return Ok(()); - } - // (is_power_of_two returns false for 0). - if n == 0 { - return Ok(()); - } - let m = coefs.len(); - assert_eq!(outputs.len(), m); - assert!(n.is_power_of_two()); - assert_eq!(weights.len(), n); - assert!(blowup_factor.is_power_of_two()); - for c in coefs.iter() { - assert_eq!(c.len(), 3 * n); - } - let lde_size = n * blowup_factor; - assert_u32_domain( - lde_size, - "evaluate_poly_coset_batch_ext3_into_with_merkle_tree lde_size", - ); - for o in outputs.iter() { - assert_eq!(o.len(), 3 * lde_size); - } - assert!(lde_size >= 2); - let total_nodes = lde_size - 1; - assert_eq!(merkle_nodes_out.len(), total_nodes * 32); - let log_lde = lde_size.trailing_zeros() as u64; - - let mb = 3 * m; - let be = backend()?; - let stream = be.next_stream(); - let staging_slot = be.pinned_staging(); - - let mut staging = staging_slot.lock().unwrap(); - staging.ensure_capacity(mb * lde_size, &be.ctx)?; - let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - - pack_ext3_to_pinned_slabs(coefs, pinned, n); - - let mut buf = stream.alloc_zeros::(mb * lde_size)?; - for s in 0..mb { - let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); - stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; - } - - let fwd_tw = be.fwd_twiddles_for(log_lde)?; - let weights_dev = stream.clone_htod(weights)?; - - let n_u64 = n as u64; - let lde_u64 = lde_size as u64; - let col_stride_u64 = lde_size as u64; - let mb_u32 = mb as u32; - - launch_pointwise_mul_batched( - stream.as_ref(), - be, - &mut buf, - &weights_dev, - n_u64, - col_stride_u64, - mb_u32, - )?; - launch_bit_reverse_batched( - stream.as_ref(), - be, - &mut buf, - lde_u64, - log_lde, - col_stride_u64, - mb_u32, - )?; - run_batched_ntt_body( - stream.as_ref(), - &mut buf, - fwd_tw.as_ref(), - lde_u64, - log_lde, - col_stride_u64, - mb_u32, - )?; - - // Build the row-pair Merkle tree on device. - // - // Row-pair commit: each leaf hashes 2 rows (bit-reversed indices) → - // num_leaves = lde_size / 2. Tree size: 2*num_leaves - 1 = lde_size - 1. - let num_leaves = lde_size / 2; - let tight_total_nodes = 2 * num_leaves - 1; - let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; - let leaves_offset_bytes = (num_leaves - 1) * 32; - { - let mut leaves_view = - nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); - let log_num_rows = log_lde; - let num_parts_u64 = m as u64; - let grid = (num_leaves as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.keccak_comp_poly_leaves_ext3) - .arg(&buf) - .arg(&col_stride_u64) - .arg(&num_parts_u64) - .arg(&lde_u64) - .arg(&log_num_rows) - .arg(&mut leaves_view) - .launch(cfg)?; - } - } - crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; - - // D2H LDE and tree. - debug_assert_eq!(merkle_nodes_out.len(), tight_total_nodes * 32); - stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, merkle_nodes_out)?; - - unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); - drop(staging); - Ok(()) + evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + Some(merkle_nodes_out), + false, + ) + .map(|_| ()) } /// Batched coset LDE for Goldilocks **cubic extension** columns. /// diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index c1f17ed2c..6faf12b51 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -17,10 +17,10 @@ //! to match `FieldElement::::write_bytes_be`. use cudarc::driver::{CudaSlice, CudaStream, CudaViewMut, LaunchConfig, PushKernelArg}; -use rayon::prelude::*; use crate::Result; use crate::device::{Backend, backend}; +use crate::lde::pack_ext3_to_pinned_slabs; /// Run GPU Keccak-256 leaf hashing on a base-field column buffer. /// @@ -98,7 +98,7 @@ pub fn keccak_leaves_ext3( /// keeps us inside the budget with some head-room. const KECCAK_BLOCK_DIM: u32 = 128; -fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { +pub(crate) fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { debug_assert!( num_rows <= u32::MAX as u64, "keccak_launch_cfg: num_rows ({num_rows}) exceeds u32 grid range", @@ -247,35 +247,7 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - parts_interleaved - .par_iter() - .enumerate() - .for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut( - (pinned_ptr_u as *mut u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(parts_interleaved, pinned, lde_size); // H2D the de-interleaved parts. let mut buf = stream.alloc_zeros::(mb * lde_size)?; @@ -291,12 +263,7 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res let num_parts_u64 = m as u64; let num_rows_u64 = lde_size as u64; let log_num_rows = lde_size.trailing_zeros() as u64; - let grid = (num_leaves as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; + let cfg = keccak_launch_cfg(num_leaves as u64); unsafe { stream .launch_builder(&be.keccak_comp_poly_leaves_ext3) @@ -344,12 +311,7 @@ pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { let mut leaves_view = nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); let num_leaves_u64 = num_leaves as u64; - let grid = (num_leaves as u32).div_ceil(128); - let cfg = LaunchConfig { - grid_dim: (grid, 1, 1), - block_dim: (128, 1, 1), - shared_mem_bytes: 0, - }; + let cfg = keccak_launch_cfg(num_leaves as u64); unsafe { stream .launch_builder(&be.keccak_fri_leaves_ext3) From 6676293981ed7ae43b2c53f25574899c3d429b0b Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 18 May 2026 17:49:54 -0300 Subject: [PATCH 21/22] fix --- crypto/math-cuda/tests/keccak_leaves.rs | 18 ++++++++++-------- crypto/math-cuda/tests/merkle_tree.rs | 7 ------- crypto/stark/src/prover.rs | 10 +++------- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/crypto/math-cuda/tests/keccak_leaves.rs b/crypto/math-cuda/tests/keccak_leaves.rs index 6b54003c2..d614e233d 100644 --- a/crypto/math-cuda/tests/keccak_leaves.rs +++ b/crypto/math-cuda/tests/keccak_leaves.rs @@ -1,23 +1,21 @@ //! Parity: GPU Keccak-256 leaf hashes must match the CPU prover's leaf -//! hashing helpers — `stark::prover::keccak_leaves_bit_reversed` for +//! hashing helpers. `stark::prover::keccak_leaves_bit_reversed` for //! per-row commits, `keccak_leaves_row_pair_bit_reversed` for the R2 -//! composition commit, and `FieldElementPairBackend::hash_data` for the FRI -//! commit. These are the same helpers the prover itself calls so any +//! composition commit, and `FriLayerMerkleTreeBackend::hash_data` for the +//! FRI commit. These are the same helpers the prover itself calls so any //! change to the CPU leaf-hash contract surfaces here. -use crypto::merkle_tree::backends::field_element_vector::FieldElementPairBackend; use crypto::merkle_tree::traits::IsMerkleTreeBackend; use math::field::element::FieldElement; use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; use math::field::goldilocks::GoldilocksField; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use sha3::Keccak256; +use stark::config::FriLayerMerkleTreeBackend; use stark::prover::{keccak_leaves_bit_reversed, keccak_leaves_row_pair_bit_reversed}; type Fp = FieldElement; type Fp3 = FieldElement; -type FriPairBackend = FieldElementPairBackend; #[test] fn keccak_leaves_base_matches_cpu() { @@ -169,10 +167,14 @@ fn keccak_fri_leaves_matches_cpu() { .collect(); // CPU reference: consecutive ext3 pairs hashed via the prover's - // `FieldElementPairBackend::hash_data`. + // FRI-layer Merkle backend. let cpu: Vec<[u8; 32]> = evals .chunks_exact(2) - .map(|c| FriPairBackend::hash_data(&[c[0], c[1]])) + .map(|c| { + FriLayerMerkleTreeBackend::::hash_data(&[ + c[0], c[1], + ]) + }) .collect(); let mut evals_interleaved = vec![0u64; 3 * lde_size]; diff --git a/crypto/math-cuda/tests/merkle_tree.rs b/crypto/math-cuda/tests/merkle_tree.rs index e5569379d..76fdeb919 100644 --- a/crypto/math-cuda/tests/merkle_tree.rs +++ b/crypto/math-cuda/tests/merkle_tree.rs @@ -53,13 +53,6 @@ fn merkle_tree_small() { } } -#[test] -fn merkle_tree_medium() { - for log_n in [10u32, 12, 14] { - run_parity(log_n, 500 + log_n as u64); - } -} - #[test] fn merkle_tree_large() { run_parity(18, 9999); diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index ef7935325..8a4577360 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -10,7 +10,7 @@ use math::fft::errors::FFTError; use log::info; use math::field::traits::{IsField, IsSubFieldOf}; -use math::traits::AsBytes; +use math::traits::{AsBytes, ByteConversion}; use math::{ field::{element::FieldElement, traits::IsFFTField}, polynomial::Polynomial, @@ -385,10 +385,8 @@ where pub fn keccak_leaves_bit_reversed(columns: &[Vec>]) -> Vec where E: IsField, - FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, + FieldElement: AsBytes + Sync + Send + ByteConversion, { - use math::traits::ByteConversion; - if columns.is_empty() || columns[0].is_empty() { return Vec::new(); } @@ -429,10 +427,8 @@ where pub fn keccak_leaves_row_pair_bit_reversed(parts: &[Vec>]) -> Vec where E: IsField, - FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, + FieldElement: AsBytes + Sync + Send + ByteConversion, { - use math::traits::ByteConversion; - let num_parts = parts.len(); if num_parts == 0 { return Vec::new(); From fde929a2ce334a403f2ea7c1ab62d16f75ae542f Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 18 May 2026 19:07:18 -0300 Subject: [PATCH 22/22] fix typo --- crypto/math-cuda/src/lde.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 83244152a..02f109938 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -1274,12 +1274,6 @@ fn evaluate_poly_coset_batch_ext3_into_inner( /// the LDE output, builds the R2 composition-polynomial Merkle tree on device /// (row-pair Keccak leaves at bit-reversed indices + pair-hash inner tree). /// -/// Row-pair commit: each leaf hashes 2 rows, so the tree has `lde_size / 2` -/// leaves and `merkle_nodes_out` must have byte length `(lde_size - 1) * 32`. -/// Requires `lde_size >= 2`. -/// Variant of [`evaluate_poly_coset_batch_ext3_into`] that also builds the -/// R2 composition-polynomial Merkle tree on device. -/// /// Row-pair commit: each leaf hashes 2 bit-reversed rows, so the tree has /// `lde_size / 2` leaves and `merkle_nodes_out` must have byte length /// `(lde_size - 1) * 32`. Requires `lde_size >= 2`.