diff --git a/Cargo.lock b/Cargo.lock index 001b4b841..0f01bf090 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2129,12 +2129,14 @@ dependencies = [ name = "math-cuda" version = "0.1.0" dependencies = [ + "crypto", "cudarc", "math", "rand 0.8.5", "rand_chacha 0.3.1", "rayon", "sha3", + "stark", ] [[package]] diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index 55fa49a83..b702a846e 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -81,6 +81,13 @@ where }) } + /// Read-only access to the full node buffer in standard layout: + /// `nodes[0..leaves_len - 1]` are inner nodes (root at index 0) and + /// `nodes[leaves_len - 1..]` are the leaves. + pub fn nodes(&self) -> &[B::Node] { + &self.nodes + } + /// Returns a Merkle proof for the element/s at position pos /// For example, give me an inclusion proof for the 3rd element in the /// Merkle tree diff --git a/crypto/math-cuda/Cargo.toml b/crypto/math-cuda/Cargo.toml index 0990dd6d6..e700ec73e 100644 --- a/crypto/math-cuda/Cargo.toml +++ b/crypto/math-cuda/Cargo.toml @@ -17,7 +17,9 @@ math = { path = "../math" } rayon = "1.7" [dev-dependencies] +crypto = { path = "../crypto" } rand = { version = "0.8.5", features = ["std"] } rand_chacha = "0.3.1" rayon = "1.7" sha3 = "0.10.8" +stark = { path = "../stark" } diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index cf541b5fd..bc84f2653 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -110,4 +110,5 @@ fn main() { compile_ptx("arith.cu", "arith.ptx", have_nvcc); compile_ptx("ntt.cu", "ntt.ptx", have_nvcc); + compile_ptx("keccak.cu", "keccak.ptx", have_nvcc); } diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu new file mode 100644 index 000000000..c22bc4d05 --- /dev/null +++ b/crypto/math-cuda/kernels/keccak.cu @@ -0,0 +1,349 @@ +// Original Keccak-256 (0x01 padding) +// +// Used by the lambda-vm prover's 
// Keccak-256 (original 0x01 padding) primitives used by the Merkle commit:
//   leaf = Keccak-256(concat(col_0[br_idx].to_be_bytes(), col_1[br_idx].to_be_bytes(), ...))
// where `br_idx = bit_reverse(row_idx, log_num_rows)` and each element is
// written in BIG-ENDIAN canonical form (per `FieldElement::write_bytes_be`).
//
// Keccak state is 5x5 lanes of u64, interpreted little-endian. Rate = 136 B
// (17 lanes) for 256-bit output, capacity = 64 B (8 lanes).
//
// Since every input byte is u64-aligned (each field element is 8 or 24 bytes),
// we can absorb lane-by-lane instead of byte-by-byte. Canonicalise + byte-swap
// each u64 on read to turn a BE-serialised element into its LE-interpreted
// lane value.

// Fixed-width integer types (uint8_t / uint32_t / uint64_t) used throughout.
#include <cstdint>
#include "goldilocks.cuh"

// Round constants XORed into lane 0 by the iota step, one per round.
__device__ __constant__ uint64_t KECCAK_RC[24] = {
    0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL,
    0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL,
    0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL,
    0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL,
    0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL,
    0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL,
    0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL,
    0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL,
};

// Rotation offsets indexed by lane position x + 5*y. Standard Keccak rho.
__device__ __constant__ uint32_t KECCAK_RHO_OFFSETS[25] = {
     0,  1, 62, 28, 27,  // y=0: x=0..4
    36, 44,  6, 55, 20,  // y=1
     3, 10, 43, 25, 39,  // y=2
    41, 45, 15, 21,  8,  // y=3
    18,  2, 61, 56, 14,  // y=4
};

// Rotate `x` left by `n` bits. `n == 0` is special-cased because the
// complementary shift `x >> 64` would be undefined.
__device__ __forceinline__ uint64_t rotl64(uint64_t x, uint32_t n) {
    return (n == 0) ? x : ((x << n) | (x >> (64 - n)));
}

// Reverse byte order: turns a BE-serialised u64 into its LE-read lane.
// Three swap stages: adjacent bytes, adjacent 16-bit halves, then the two
// 32-bit halves.
__device__ __forceinline__ uint64_t bswap64(uint64_t x) {
    x = ((x & 0x00ff00ff00ff00ffULL) << 8) | ((x & 0xff00ff00ff00ff00ULL) >> 8);
    x = ((x & 0x0000ffff0000ffffULL) << 16) | ((x & 0xffff0000ffff0000ULL) >> 16);
    return (x << 32) | (x >> 32);
}

// In-place Keccak-f[1600] permutation (24 rounds) over the 25-lane state
// `st`, lanes addressed as `st[x + 5*y]`.
__device__ __forceinline__ void keccak_f1600(uint64_t st[25]) {
    uint64_t C[5], D[5], B[25];
    // No outer unroll: fully unrolling the 24 rounds slowed the kernel ~7.5% on RTX 5090.
    for (int r = 0; r < 24; ++r) {
        // Theta: column parities, then mix each column with its neighbours.
        #pragma unroll
        for (int x = 0; x < 5; ++x) {
            C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20];
        }
        #pragma unroll
        for (int x = 0; x < 5; ++x) {
            D[x] = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1);
        }
        #pragma unroll
        for (int y = 0; y < 5; ++y) {
            #pragma unroll
            for (int x = 0; x < 5; ++x) {
                st[x + 5 * y] ^= D[x];
            }
        }

        // Rho + Pi: B[pi(x,y)] = rotl(st[x,y], rho(x,y))
        // pi: (x', y') = (y, (2x + 3y) mod 5)
        #pragma unroll
        for (int y = 0; y < 5; ++y) {
            #pragma unroll
            for (int x = 0; x < 5; ++x) {
                int nx = y;
                int ny = (2 * x + 3 * y) % 5;
                B[nx + 5 * ny] = rotl64(st[x + 5 * y], KECCAK_RHO_OFFSETS[x + 5 * y]);
            }
        }

        // Chi: non-linear step, combines each lane with the next two in its row.
        #pragma unroll
        for (int y = 0; y < 5; ++y) {
            #pragma unroll
            for (int x = 0; x < 5; ++x) {
                st[x + 5 * y] =
                    B[x + 5 * y] ^ ((~B[((x + 1) % 5) + 5 * y]) & B[((x + 2) % 5) + 5 * y]);
            }
        }

        // Iota: break round symmetry with the per-round constant.
        st[0] ^= KECCAK_RC[r];
    }
}
// ---------------------------------------------------------------------------
// Sponge absorb step for a single 8-byte lane. `lane` must already be in
// Keccak's little-endian lane form (i.e. byte-swapped from the BE
// serialisation). `rate_pos` is the byte offset inside the 136-byte rate and
// is advanced by 8; once the rate is full the state is permuted and the
// offset wraps back to 0.
// ---------------------------------------------------------------------------
__device__ __forceinline__ void absorb_lane(uint64_t st[25],
                                            uint32_t &rate_pos,
                                            uint64_t lane) {
    const uint32_t lane_idx = rate_pos / 8;
    st[lane_idx] ^= lane;
    rate_pos += 8;
    if (rate_pos != 136) {
        return;
    }
    // Rate block complete: permute and start filling the next block.
    keccak_f1600(st);
    rate_pos = 0;
}

// ---------------------------------------------------------------------------
// Finish the hash once every data lane has been absorbed. Applies the
// original Keccak (pre-SHA-3) padding — byte 0x01 at `rate_pos`, bit 0x80 on
// the final rate byte (byte 135, the top byte of lane 16) — runs one more
// permutation and writes the first four lanes to `out32`, each lane
// serialised little-endian (32 bytes total).
// ---------------------------------------------------------------------------
__device__ __forceinline__ void finalize_keccak256(uint64_t st[25],
                                                   uint32_t rate_pos,
                                                   uint8_t *out32) {
    // Domain/padding start byte.
    const uint32_t pad_lane = rate_pos / 8;
    const uint32_t pad_shift = (rate_pos & 7) * 8;
    st[pad_lane] ^= ((uint64_t)0x01) << pad_shift;
    // Padding end bit: most significant byte of lane 16.
    st[16] ^= ((uint64_t)0x80) << 56;
    keccak_f1600(st);

    // Squeeze: output byte `idx` is byte `idx % 8` (LE) of lane `idx / 8`.
    #pragma unroll
    for (int idx = 0; idx < 32; ++idx) {
        out32[idx] = (uint8_t)((st[idx / 8] >> ((idx % 8) * 8)) & 0xff);
    }
}
// ---------------------------------------------------------------------------
// Goldilocks BASE-FIELD leaf hashing. One thread per output leaf.
//
// For output row `row_idx` (natural order), the leaf hashes the canonical BE
// byte representation of `columns[c][bit_reverse(row_idx, log_num_rows)]` for
// `c` in `[0, num_cols)`, concatenated in column order. Writes 32 bytes to
// `hashed_leaves_out[row_idx * 32 ..]`.
//
// `columns_base_ptr` points to a `num_cols * col_stride * u64` buffer; column
// `c` is the contiguous slab `[c*col_stride .. c*col_stride + num_rows]`. The
// remaining `col_stride - num_rows` entries (if any) are ignored.
// ---------------------------------------------------------------------------
extern "C" __global__ void keccak256_leaves_base_batched(
    const uint64_t *columns_base_ptr,
    uint64_t col_stride,
    uint64_t num_cols,
    uint64_t num_rows,
    uint64_t log_num_rows,
    uint8_t *hashed_leaves_out) {
    uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= num_rows) return;

    // Bit-reverse the row index so we read columns at `br` but write the
    // hashed leaf at `tid` — matching the CPU `commit_columns_bit_reversed`.
    // A single-row table has `log_num_rows == 0`; the generic expression would
    // then shift a 64-bit value by 64, which is undefined, so that case is
    // pinned to row 0 explicitly.
    uint64_t br = (log_num_rows == 0) ? 0 : (__brevll(tid) >> (64 - log_num_rows));

    uint64_t st[25];
    #pragma unroll
    for (int i = 0; i < 25; ++i) st[i] = 0;

    uint32_t rate_pos = 0;
    for (uint64_t c = 0; c < num_cols; ++c) {
        uint64_t v = columns_base_ptr[c * col_stride + br];
        // Canonicalise to match `canonical_u64().to_be_bytes()` on host.
        uint64_t canon = goldilocks::canonical(v);
        // The serialised leaf bytes are canon.to_be_bytes(). Keccak reads
        // those as a LE lane, which equals bswap64(canon).
        uint64_t lane = bswap64(canon);
        absorb_lane(st, rate_pos, lane);
    }

    finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32);
}
// ---------------------------------------------------------------------------
// Goldilocks EXT3 leaf hashing (3 base-field components per ext3 element).
// One thread per output leaf.
//
// Components live in three separate base-field slabs (de-interleaved layout).
// Column `c` component `k` is at
// `columns_base_ptr[(c*3 + k)*col_stride + br]`, where `br` is the
// bit-reversed row index. Per-element BE bytes are `[comp0, comp1, comp2]`,
// each 8 BE bytes (matches `FieldElement::write_bytes_be`).
// Writes 32 bytes to `hashed_leaves_out[tid * 32 ..]`.
// ---------------------------------------------------------------------------
extern "C" __global__ void keccak256_leaves_ext3_batched(
    const uint64_t *columns_base_ptr,
    uint64_t col_stride,
    uint64_t num_cols,  // number of ext3 columns (NOT slabs)
    uint64_t num_rows,
    uint64_t log_num_rows,
    uint8_t *hashed_leaves_out) {
    uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= num_rows) return;

    // A single-row table has `log_num_rows == 0`; shifting a 64-bit value by
    // 64 is undefined, so that case is pinned to row 0 explicitly.
    uint64_t br = (log_num_rows == 0) ? 0 : (__brevll(tid) >> (64 - log_num_rows));

    uint64_t st[25];
    #pragma unroll
    for (int i = 0; i < 25; ++i) st[i] = 0;

    uint32_t rate_pos = 0;
    for (uint64_t c = 0; c < num_cols; ++c) {
        #pragma unroll
        for (int k = 0; k < 3; ++k) {
            uint64_t v = columns_base_ptr[(c * 3 + (uint64_t)k) * col_stride + br];
            // Canonical value, then BE-serialised bytes viewed as a LE lane.
            uint64_t canon = goldilocks::canonical(v);
            uint64_t lane = bswap64(canon);
            absorb_lane(st, rate_pos, lane);
        }
    }

    finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32);
}
// ---------------------------------------------------------------------------
// R2 composition-polynomial leaf hashing. One thread per leaf; there are
// `num_rows / 2` leaves.
//
// Leaf `j` hashes `2 * num_parts` ext3 values read from the bit-reversed rows
// `reverse_index(2*j)` and `reverse_index(2*j + 1)`: first every part
// (0..num_parts-1) of the first row, then every part of the second row. Each
// ext3 value contributes 3 base components × 8 BE bytes.
//
// Input uses the de-interleaved 3-slab layout: part `p` component `k` is at
// `parts_base_ptr[(p*3 + k) * col_stride + row]`.
// ---------------------------------------------------------------------------
extern "C" __global__ void keccak_comp_poly_leaves_ext3(
    const uint64_t *parts_base_ptr,
    uint64_t col_stride,
    uint64_t num_parts,
    uint64_t num_rows,
    uint64_t log_num_rows,
    uint8_t *leaves_out) {
    const uint64_t leaf = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (leaf >= (num_rows >> 1)) return;

    // The two source rows of this leaf, in absorb order.
    const uint64_t rows[2] = {
        __brevll(2 * leaf) >> (64 - log_num_rows),
        __brevll(2 * leaf + 1) >> (64 - log_num_rows),
    };

    uint64_t st[25];
    #pragma unroll
    for (int i = 0; i < 25; ++i) st[i] = 0;
    uint32_t rate_pos = 0;

    #pragma unroll
    for (int r = 0; r < 2; ++r) {
        const uint64_t row = rows[r];
        for (uint64_t p = 0; p < num_parts; ++p) {
            #pragma unroll
            for (int k = 0; k < 3; ++k) {
                const uint64_t raw = parts_base_ptr[(p * 3 + (uint64_t)k) * col_stride + row];
                // Canonical value, BE-serialised, read back as a LE lane.
                absorb_lane(st, rate_pos, bswap64(goldilocks::canonical(raw)));
            }
        }
    }

    finalize_keccak256(st, rate_pos, leaves_out + leaf * 32);
}
// ---------------------------------------------------------------------------
// FRI layer leaf hashing. One thread per leaf.
//
// Leaf `j` is Keccak256 over
//     evals[2j].to_bytes_be() ++ evals[2j+1].to_bytes_be()
// i.e. 48 BE bytes = 6 u64 lanes. No bit reversal and no column slabs: the
// input is one interleaved ext3 vector `[a0,a1,a2,b0,b1,b2,...]`, so the six
// words of a leaf are adjacent in memory.
// ---------------------------------------------------------------------------
extern "C" __global__ void keccak_fri_leaves_ext3(
    const uint64_t *evals_interleaved,  // 3 * num_evals u64s (ext3 interleaved)
    uint64_t num_leaves,                // = num_evals / 2
    uint8_t *leaves_out) {
    const uint64_t leaf = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (leaf >= num_leaves) return;

    uint64_t st[25];
    #pragma unroll
    for (int i = 0; i < 25; ++i) st[i] = 0;
    uint32_t rate_pos = 0;

    // Two ext3 values × 3 components, laid out back to back.
    const uint64_t *pair = evals_interleaved + 2 * leaf * 3;
    #pragma unroll
    for (int w = 0; w < 6; ++w) {
        const uint64_t canon = goldilocks::canonical(pair[w]);
        absorb_lane(st, rate_pos, bswap64(canon));
    }

    finalize_keccak256(st, rate_pos, leaves_out + leaf * 32);
}
// ---------------------------------------------------------------------------
// Merkle inner-tree pair hash: one level of the inner Merkle tree.
//
// `nodes` is the full Merkle node buffer (length `2*leaves_len - 1`, each
// element 32 bytes). `parent_begin` is the node-index offset of the first
// parent slot in this level. Children live at `parent_begin + n_pairs`.
// The layout mirrors `crypto/crypto/src/merkle_tree/merkle.rs`:
//
//   children: nodes[parent_begin + n_pairs .. parent_begin + 3 * n_pairs]
//   parents:  nodes[parent_begin           .. parent_begin + n_pairs]
//
// Each thread hashes one child pair → one parent. Keccak-256 of the
// concatenation of two 32-byte siblings, identical to
// `FieldElementVectorBackend::hash_new_parent` on host.
// ---------------------------------------------------------------------------
extern "C" __global__ void keccak_merkle_level(
    uint8_t *nodes,
    uint64_t parent_begin,  // node index (counted in 32-byte nodes)
    uint64_t n_pairs) {
    uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= n_pairs) return;

    uint64_t st[25];
    #pragma unroll
    for (int i = 0; i < 25; ++i) st[i] = 0;

    uint32_t rate_pos = 0;
    // `nodes` comes from cuMemAlloc (256-byte aligned); each 32-byte node
    // sits at a 32-byte-aligned offset, so the u64 cast is safe.
    // Digest bytes are already in hash-output order, so the raw u64 loads are
    // absorbed as lanes directly (no canonicalisation, no byte swap).
    const uint64_t *left = reinterpret_cast<const uint64_t *>(
        nodes + (parent_begin + n_pairs + 2 * tid) * 32);
    #pragma unroll
    for (int i = 0; i < 4; ++i) absorb_lane(st, rate_pos, left[i]);

    const uint64_t *right = reinterpret_cast<const uint64_t *>(
        nodes + (parent_begin + n_pairs + 2 * tid + 1) * 32);
    #pragma unroll
    for (int i = 0; i < 4; ++i) absorb_lane(st, rate_pos, right[i]);

    // Parent slot is disjoint from every child slot of this level.
    finalize_keccak256(st, rate_pos, nodes + (parent_begin + tid) * 32);
}
It lives alongside the LDE staging so the GPU→host D2H for + /// hashed leaves runs at full PCIe line-rate. + pinned_hashes: Mutex, util_stream: Arc, next: AtomicUsize, @@ -132,6 +137,13 @@ pub struct Backend { pub pointwise_mul_batched: CudaFunction, pub scalar_mul_batched: CudaFunction, + // keccak.ptx + pub keccak256_leaves_base_batched: CudaFunction, + pub keccak256_leaves_ext3_batched: CudaFunction, + pub keccak_comp_poly_leaves_ext3: CudaFunction, + pub keccak_fri_leaves_ext3: CudaFunction, + pub keccak_merkle_level: CudaFunction, + // Twiddle caches keyed by log_n. fwd_twiddles: Mutex>>>>, inv_twiddles: Mutex>>>>, @@ -148,12 +160,14 @@ impl Backend { let arith = ctx.load_module(Ptx::from_src(ARITH_PTX))?; let ntt = ctx.load_module(Ptx::from_src(NTT_PTX))?; + let keccak = ctx.load_module(Ptx::from_src(KECCAK_PTX))?; let mut streams = Vec::with_capacity(STREAM_POOL_SIZE); for _ in 0..STREAM_POOL_SIZE { streams.push(ctx.new_stream()?); } let pinned_staging = Mutex::new(PinnedStaging::empty()); + let pinned_hashes = Mutex::new(PinnedStaging::empty()); // Separate "utility" stream for twiddle uploads and other bookkeeping; // not part of the pool that callers rotate through. 
let util_stream = ctx.new_stream()?; @@ -182,11 +196,17 @@ impl Backend { ntt_dit_8_levels_batched: ntt.load_function("ntt_dit_8_levels_batched")?, pointwise_mul_batched: ntt.load_function("pointwise_mul_batched")?, scalar_mul_batched: ntt.load_function("scalar_mul_batched")?, + keccak256_leaves_base_batched: keccak.load_function("keccak256_leaves_base_batched")?, + keccak256_leaves_ext3_batched: keccak.load_function("keccak256_leaves_ext3_batched")?, + keccak_comp_poly_leaves_ext3: keccak.load_function("keccak_comp_poly_leaves_ext3")?, + keccak_fri_leaves_ext3: keccak.load_function("keccak_fri_leaves_ext3")?, + keccak_merkle_level: keccak.load_function("keccak_merkle_level")?, fwd_twiddles: Mutex::new(vec![None; max_log]), inv_twiddles: Mutex::new(vec![None; max_log]), ctx, streams, pinned_staging, + pinned_hashes, util_stream, next: AtomicUsize::new(0), }) @@ -205,6 +225,12 @@ impl Backend { &self.pinned_staging } + /// Separate pinned staging for Merkle leaf hash output. Sized in u64 + /// units. Caller should reserve `(num_rows * 32 + 7) / 8` u64s. 
/// Output shape requested from the fused LDE + Keccak entry points.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum KeccakCommit {
    /// Only the `lde_size` keccak-256 leaves; no inner-tree build. Caller
    /// receives `lde_size * 32` bytes.
    LeavesOnly,
    /// Full Merkle tree: leaves at the tail + inner nodes built on-device.
    /// Caller receives `(2*lde_size - 1) * 32` bytes.
    FullTree,
}

impl KeccakCommit {
    /// Size in bytes of one Keccak-256 digest / Merkle node.
    const NODE_BYTES: usize = 32;

    /// Total byte length of the node buffer produced for `lde_size` leaves.
    ///
    /// `LeavesOnly` yields `lde_size` nodes, `FullTree` yields
    /// `2 * lde_size - 1`. An empty domain (`lde_size == 0`) yields 0 bytes
    /// in both modes instead of underflowing.
    fn total_nodes_bytes(self, lde_size: usize) -> usize {
        let node_count = match self {
            KeccakCommit::LeavesOnly => lde_size,
            // Saturate so that 0 leaves means 0 nodes rather than `0 - 1`.
            KeccakCommit::FullTree => (2 * lde_size).saturating_sub(1),
        };
        node_count * Self::NODE_BYTES
    }

    /// Byte offset of the first leaf inside the node buffer.
    ///
    /// In `FullTree` mode the `lde_size - 1` inner nodes come first, so the
    /// leaves start right after them; in `LeavesOnly` mode they start at 0.
    /// An empty domain yields offset 0 instead of underflowing.
    fn leaves_offset_bytes(self, lde_size: usize) -> usize {
        match self {
            KeccakCommit::LeavesOnly => 0,
            KeccakCommit::FullTree => lde_size.saturating_sub(1) * Self::NODE_BYTES,
        }
    }
}
The caller must hold the pinned-staging lock. +pub(crate) fn pack_ext3_to_pinned_slabs(columns: &[&[u64]], pinned: &mut [u64], n: usize) { + let m = columns.len(); + debug_assert!(pinned.len() >= 3 * m * n); + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: each task writes to disjoint `[(c*3 + k)*n .. ..+n]` regions + // of `pinned`. The outer `&mut [u64]` borrow guarantees no aliasing. + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); +} + +/// Re-interleave the `3*m` base-field slabs in `pinned` (layout matches +/// `pack_ext3_to_pinned_slabs`) into `outputs`, writing each as +/// `3*lde_size` interleaved u64s. +fn unpack_pinned_slabs_to_ext3(pinned: &[u64], outputs: &mut [&mut [u64]], lde_size: usize) { + let m = outputs.len(); + debug_assert!(pinned.len() >= 3 * m * lde_size); + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + // SAFETY: each task reads from disjoint `[(c*3 + k)*lde_size .. ..+lde_size]` + // regions of `pinned`. Caller borrows `pinned` for the duration of the call. 
+ let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); +} + +/// Run `bit_reverse_permute_batched` over `m` columns of length `n` each +/// (column stride `col_stride`). 256 threads per block, grid sized to cover +/// `n` per column. +fn launch_bit_reverse_batched( + stream: &CudaStream, + be: &Backend, + buf: &mut CudaSlice, + n: u64, + log_n: u64, + col_stride: u64, + m: u32, +) -> Result<()> { + let cfg = LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(buf) + .arg(&n) + .arg(&log_n) + .arg(&col_stride) + .launch(cfg)?; + } + Ok(()) +} + +/// D2H `dst.len()` bytes from `dev_bytes` into the caller's pageable `dst` +/// via the pinned-hashes staging buffer. Synchronises the stream first (so +/// any other D2H queued on the same stream also drains), then does a rayon +/// chunked memcpy pinned → caller to spread page-fault cost across cores. +fn d2h_bytes_via_pinned_hashes( + stream: &Arc, + be: &Backend, + dev_bytes: &CudaSlice, + dst: &mut [u8], +) -> Result<()> { + let n_bytes = dst.len(); + let u64_len = n_bytes.div_ceil(8); + let staging_slot = be.pinned_hashes(); + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(u64_len, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(u64_len) }; + // Reinterpret the u64 pinned buffer as bytes — same allocation, just + // typed differently. 
SAFETY: u64 has stricter alignment than u8 and the + // byte length fits in the `u64_len` capacity (rounded up to u64). + let pinned_bytes: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(pinned.as_mut_ptr() as *mut u8, n_bytes) }; + stream.memcpy_dtoh(dev_bytes, pinned_bytes)?; + stream.synchronize()?; + + // Single-threaded `copy_from_slice` faults virgin pageable pages one at + // a time; the mm_struct rwsem serialises them at prover scale. Chunk so + // ~N cores pre-fault+write in parallel. + const CHUNK: usize = 64 * 1024; + let src_ptr = pinned_bytes.as_ptr() as usize; + dst.par_chunks_mut(CHUNK).enumerate().for_each(|(i, d)| { + // SAFETY: each task reads `[i*CHUNK .. i*CHUNK + d.len()]` of + // `pinned_bytes`, which is disjoint per `i` and lives until `staging` + // is dropped below. + let src = + unsafe { std::slice::from_raw_parts((src_ptr as *const u8).add(i * CHUNK), d.len()) }; + d.copy_from_slice(src); + }); + drop(staging); + Ok(()) +} + +/// Run `pointwise_mul_batched`: `buf[c*col_stride + i] *= weights[i]` for +/// `m` columns, `n` elements each. +fn launch_pointwise_mul_batched( + stream: &CudaStream, + be: &Backend, + buf: &mut CudaSlice, + weights: &CudaSlice, + n: u64, + col_stride: u64, + m: u32, +) -> Result<()> { + let cfg = LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(buf) + .arg(weights) + .arg(&n) + .arg(&col_stride) + .launch(cfg)?; + } + Ok(()) +} + /// Handle to a base-field LDE kept live on device after R1 commit. /// Layout: `m` columns, each `lde_size` u64s, column `c` at byte offset /// `c * lde_size * 8` within `buf`. Freed when `buf` Arc drops. @@ -208,23 +400,15 @@ pub fn coset_lde_batch_base( let m_u32 = m as u32; // === 1. 
Bit-reverse first N of every column === - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; // === 2. iNTT body over all columns === run_batched_ntt_body( @@ -238,42 +422,26 @@ pub fn coset_lde_batch_base( )?; // === 3. Pointwise multiply by coset weights (includes 1/N) === - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + m_u32, + )?; // === 4. Bit-reverse full LDE of every column === - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; // === 5. Forward NTT on full LDE of every column === run_batched_ntt_body( @@ -385,23 +553,15 @@ pub fn coset_lde_batch_base_into( let m_u32 = m as u32; // iNTT bit-reverse + body, pointwise mul, forward bit-reverse + body. 
- { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -411,40 +571,24 @@ pub fn coset_lde_batch_base_into( col_stride_u64, m_u32, )?; - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, m_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + m_u32, + )?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -460,7 +604,6 @@ pub fn coset_lde_batch_base_into( // Parallel copy pinned → caller outputs. Caller's Vecs may still fault // on first write; we spread that cost across rayon cores. - #[allow(unused_imports)] let pinned_ptr = pinned.as_ptr() as usize; outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { let src = unsafe { @@ -472,57 +615,97 @@ pub fn coset_lde_batch_base_into( Ok(()) } -/// Batched ext3 polynomial → coset evaluation. 
-/// -/// Input: M ext3 columns of `n` coefficients each (interleaved, 3n u64). -/// Output: M ext3 columns of `n * blowup_factor` evaluations each at the -/// offset-coset. +/// Fused LDE + Keccak-256 leaf hashing. Caller receives the `lde_size * 32` +/// bytes of leaf hashes in `hashed_leaves_out` (one 32-byte digest per output +/// row, in natural row order; leaves are computed reading columns at +/// bit-reversed rows, matching `commit_columns_bit_reversed` on the CPU +/// side). Thin wrapper over `coset_lde_batch_base_into_with_merkle_tree_inner` +/// with `LeavesOnly` — no inner-tree build, no device handle. +pub fn coset_lde_batch_base_into_with_leaf_hash( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + hashed_leaves_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + hashed_leaves_out, + KeccakCommit::LeavesOnly, + false, + ) + .map(|_| ()) +} + +/// Like `coset_lde_batch_base_into_with_leaf_hash`, but also builds the full +/// Merkle tree on device and returns the `2*lde_size - 1` node buffer back +/// to the caller in `merkle_nodes_out` (byte length `(2*lde_size - 1) * 32`). /// -/// Skips the iFFT stage of [`coset_lde_batch_ext3_into`] (input is -/// coefficients, not evaluations). Weights encode the coset shift: -/// `weights[k] = offset^k` (NO 1/N because iFFT normalisation doesn't apply). -pub fn evaluate_poly_coset_batch_ext3_into( - coefs: &[&[u64]], - n: usize, +/// The leaf hashes are never exposed to the caller — they stay on device and +/// feed straight into the pair-hash tree kernel, avoiding the +/// pinned→pageable→pinned round-trip that the separate-step GPU tree build +/// would pay. 
+pub fn coset_lde_batch_base_into_with_merkle_tree( + columns: &[&[u64]], blowup_factor: usize, weights: &[u64], outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], ) -> Result<()> { - evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, false) - .map(|_| ()) + coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + KeccakCommit::FullTree, + false, + ) + .map(|_| ()) } -/// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- -/// interleaved LDE device buffer as a `GpuLdeExt3` handle. Lets R2 commit -/// and R4 DEEP composition read the composition-parts LDE without -/// re-H2D'ing. -pub fn evaluate_poly_coset_batch_ext3_into_keep( - coefs: &[&[u64]], - n: usize, +/// Fused LDE + leaf-hash + Merkle tree build. If `keep_device_buf` is true, +/// returns an `Arc>` wrapping the LDE device buffer so callers +/// (R2–R4 GPU paths) can reuse the LDE without a re-H2D. +pub fn coset_lde_batch_base_into_with_merkle_tree_keep( + columns: &[&[u64]], blowup_factor: usize, weights: &[u64], outputs: &mut [&mut [u64]], -) -> Result { - let opt = - evaluate_poly_coset_batch_ext3_into_inner(coefs, n, blowup_factor, weights, outputs, true)?; - Ok(opt.expect("keep_device_buf=true must return Some")) + merkle_nodes_out: &mut [u8], +) -> Result { + let opt = coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + KeccakCommit::FullTree, + true, + )?; + let handle = opt.expect("keep_device_buf=true must return Some"); + Ok(handle) } -fn evaluate_poly_coset_batch_ext3_into_inner( - coefs: &[&[u64]], - n: usize, +fn coset_lde_batch_base_into_with_merkle_tree_inner( + columns: &[&[u64]], blowup_factor: usize, weights: &[u64], outputs: &mut [&mut [u64]], + nodes_out: &mut [u8], + commit: KeccakCommit, keep_device_buf: bool, -) -> Result> { - if coefs.is_empty() { +) -> Result> { + if columns.is_empty() { 
assert_eq!(outputs.len(), 0); return Ok(None); } - let m = coefs.len(); + let m = columns.len(); assert_eq!(outputs.len(), m); - // Empty domain must short-circuit before the power-of-two assert + let n = columns[0].len(); // (is_power_of_two returns false for 0). if n == 0 { return Ok(None); @@ -530,42 +713,270 @@ fn evaluate_poly_coset_batch_ext3_into_inner( assert!(n.is_power_of_two()); assert_eq!(weights.len(), n); assert!(blowup_factor.is_power_of_two()); - for c in coefs.iter() { - assert_eq!(c.len(), 3 * n); - } let lde_size = n * blowup_factor; + assert_u32_domain( + lde_size, + "coset_lde_batch_base_into_with_merkle_tree lde_size", + ); for o in outputs.iter() { - assert_eq!(o.len(), 3 * lde_size); + assert_eq!(o.len(), lde_size); } - assert_u32_domain(lde_size, "evaluate_poly_coset_batch_ext3_into lde_size"); + let nodes_dev_bytes = commit.total_nodes_bytes(lde_size); + assert_eq!(nodes_out.len(), nodes_dev_bytes); + let log_n = n.trailing_zeros() as u64; let log_lde = lde_size.trailing_zeros() as u64; - let mb = 3 * m; let be = backend()?; let stream = be.next_stream(); let staging_slot = be.pinned_staging(); let mut staging = staging_slot.lock().unwrap(); - staging.ensure_capacity(mb * lde_size, &be.ctx)?; - let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - coefs.par_iter().enumerate().for_each(|(c, col)| { - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let dst = + unsafe { 
std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; + dst.copy_from_slice(col); + }); + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + m_u32, + )?; + // forward NTT at LDE size + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + // Allocate the device output buffer. In `LeavesOnly` mode this is just + // `lde_size * 32` bytes (the leaves themselves); in `FullTree` mode it's + // `(2*lde_size - 1) * 32` bytes (leaves in the tail + inner nodes filled + // below). `alloc` (not `alloc_zeros`) is safe because every byte is + // written before any reader sees it: the keccak kernel fills the + // leaves slab, the inner-tree pass (when present) fills the head. 
+ let mut nodes_dev = unsafe { stream.alloc::(nodes_dev_bytes) }?; + let leaves_offset_bytes = commit.leaves_offset_bytes(lde_size); + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); + launch_keccak_base( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut leaves_view, + )?; + } + + if commit == KeccakCommit::FullTree { + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; + } + + // D2H the LDE and the tree/leaves nodes via pinned staging. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; + + // Pinned LDE → caller outputs (post-sync host memcpy). + let pinned_ptr = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } + dst.copy_from_slice(src); }); + drop(staging); + + if keep_device_buf { + Ok(Some(GpuLdeBase { + buf: Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Ext3 variant of `coset_lde_batch_base_into_with_leaf_hash`: fused +/// LDE + Keccak-256 leaf hashing over ext3 columns. Thin wrapper over +/// `coset_lde_batch_ext3_into_with_merkle_tree_inner` with `LeavesOnly`. +pub fn coset_lde_batch_ext3_into_with_leaf_hash( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + hashed_leaves_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + hashed_leaves_out, + KeccakCommit::LeavesOnly, + false, + ) + .map(|_| ()) +} + +/// Ext3 variant of the fused `coset_lde_batch_base_into_with_merkle_tree`. 
+/// LDE + leaf hashing + inner-tree build, all on device; D2Hs only the LDE +/// evaluations and the full `2*lde_size - 1` node buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + KeccakCommit::FullTree, + false, + ) + .map(|_| ()) +} + +/// Ext3 variant of [`coset_lde_batch_base_into_with_merkle_tree_keep`] — +/// returns an `Arc>` handle to the de-interleaved LDE device +/// buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree_keep( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result { + let opt = coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + KeccakCommit::FullTree, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) +} + +#[allow(clippy::too_many_arguments)] +fn coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + nodes_out: &mut [u8], + commit: KeccakCommit, + keep_device_buf: bool, +) -> Result> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + // (is_power_of_two returns false for 0). 
+ if n == 0 { + return Ok(None); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + assert_u32_domain( + lde_size, + "coset_lde_batch_ext3_into_with_merkle_tree lde_size", + ); + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + let nodes_dev_bytes = commit.total_nodes_bytes(lde_size); + assert_eq!(nodes_out.len(), nodes_dev_bytes); + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend()?; + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + pack_ext3_to_pinned_slabs(columns, pinned, n); let mut buf = stream.alloc_zeros::(mb * lde_size)?; for s in 0..mb { @@ -573,6 +984,7 @@ fn evaluate_poly_coset_batch_ext3_into_inner( stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; } + let inv_tw = be.inv_twiddles_for(log_n)?; let fwd_tw = be.fwd_twiddles_for(log_lde)?; let weights_dev = stream.clone_htod(weights)?; @@ -581,43 +993,225 @@ fn evaluate_poly_coset_batch_ext3_into_inner( let col_stride_u64 = lde_size as u64; let mb_u32 = mb as u32; - // Apply coset scaling: x[k] *= weights[k] for k in 0..n (no iFFT first). 
+ launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + mb_u32, + )?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Allocate device output buffer (LeavesOnly → lde_size*32; FullTree → + // (2*lde_size - 1)*32). Leaf kernel writes to the leaves slab; the + // inner-tree pass (when present) fills the head. + let mut nodes_dev = unsafe { stream.alloc::(nodes_dev_bytes) }?; + let leaves_offset_bytes = commit.leaves_offset_bytes(lde_size); { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); + launch_keccak_ext3( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut leaves_view, + )?; } - // Bit-reverse full lde_size slab, then forward DIT NTT. 
- { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } + if commit == KeccakCommit::FullTree { + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; + } + + // D2H LDE (mb * lde_size u64) and tree/leaves nodes. + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; + + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); + drop(staging); + + if keep_device_buf { + Ok(Some(GpuLdeExt3 { + buf: Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Batched ext3 polynomial → coset evaluation. +/// +/// Input: M ext3 columns of `n` coefficients each (interleaved, 3n u64). +/// Output: M ext3 columns of `n * blowup_factor` evaluations each at the +/// offset-coset. +/// +/// Skips the iFFT stage of [`coset_lde_batch_ext3_into`] (input is +/// coefficients, not evaluations). Weights encode the coset shift: +/// `weights[k] = offset^k` (NO 1/N because iFFT normalisation doesn't apply). +pub fn evaluate_poly_coset_batch_ext3_into( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + None, + false, + ) + .map(|_| ()) +} + +/// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- +/// interleaved LDE device buffer as a `GpuLdeExt3` handle so callers can +/// reuse the LDE without a re-H2D. 
+pub fn evaluate_poly_coset_batch_ext3_into_keep( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result { + let opt = evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + None, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) +} + +fn evaluate_poly_coset_batch_ext3_into_inner( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: Option<&mut [u8]>, + keep_device_buf: bool, +) -> Result> { + if coefs.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = coefs.len(); + assert_eq!(outputs.len(), m); + // Empty domain must short-circuit before the power-of-two assert + // (is_power_of_two returns false for 0). + if n == 0 { + return Ok(None); } + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in coefs.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + assert_u32_domain(lde_size, "evaluate_poly_coset_batch_ext3_into lde_size"); + if merkle_nodes_out.is_some() { + assert!(lde_size >= 2); + } + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend()?; + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + pack_ext3_to_pinned_slabs(coefs, pinned, n); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as 
u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + // Apply coset scaling: x[k] *= weights[k] for k in 0..n (no iFFT first). + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + mb_u32, + )?; + + // Bit-reverse full lde_size slab, then forward DIT NTT. + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -628,35 +1222,41 @@ fn evaluate_poly_coset_batch_ext3_into_inner( mb_u32, )?; - stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; - stream.synchronize()?; - - let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; + // Optional R2-style row-pair Merkle tree build on the LDE buffer. 
+ if let Some(nodes_out) = merkle_nodes_out { + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + assert_eq!(nodes_out.len(), tight_total_nodes * 32); + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let log_num_rows = log_lde; + let num_parts_u64 = m as u64; + let cfg = keccak_launch_cfg(num_leaves as u64); + unsafe { + stream + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&lde_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) + .launch(cfg)?; + } } - }); + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; + + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; + } else { + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + stream.synchronize()?; + } + + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); drop(staging); if keep_device_buf { Ok(Some(GpuLdeExt3 { @@ -670,6 +1270,32 @@ fn evaluate_poly_coset_batch_ext3_into_inner( } } +/// Fused variant of [`evaluate_poly_coset_batch_ext3_into`]: in addition to +/// the LDE output, builds the R2 composition-polynomial Merkle tree on device +/// (row-pair Keccak leaves at bit-reversed indices + pair-hash inner tree). +/// +/// Row-pair commit: each leaf hashes 2 bit-reversed rows, so the tree has +/// `lde_size / 2` leaves and `merkle_nodes_out` must have byte length +/// `(lde_size - 1) * 32`. Requires `lde_size >= 2`. 
+pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + Some(merkle_nodes_out), + false, + ) + .map(|_| ()) +} /// Batched coset LDE for Goldilocks **cubic extension** columns. /// /// A degree-3 extension element is `(a, b, c)` in memory (three contiguous @@ -735,27 +1361,7 @@ pub fn coset_lde_batch_ext3_into( staging.ensure_capacity(mb * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; - // Pack: for each ext3 column, write 3 base slabs into pinned. The slab - // for column c, component k lives at `pinned[(c*3 + k)*n .. (c*3+k)*n + n]`. - // We de-interleave from the interleaved `[a, b, c, a, b, c, ...]` input. - let pinned_ptr_u = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - // SAFETY: disjoint regions per c; staging lock held. - let slab_a = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) - }; - let slab_b = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) - }; - let slab_c = unsafe { - std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) - }; - for i in 0..n { - slab_a[i] = col[i * 3]; - slab_b[i] = col[i * 3 + 1]; - slab_c[i] = col[i * 3 + 2]; - } - }); + pack_ext3_to_pinned_slabs(columns, pinned, n); // Allocate + zero-pad device buffer holding 3M slabs of `lde_size`. let mut buf = stream.alloc_zeros::(mb * lde_size)?; @@ -776,23 +1382,15 @@ pub fn coset_lde_batch_ext3_into( // === Butterflies: identical to the base-field batched path, but with // grid.y = 3M instead of M. 
=== - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&n_u64) - .arg(&log_n) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -802,40 +1400,24 @@ pub fn coset_lde_batch_ext3_into( col_stride_u64, mb_u32, )?; - { - let grid_x = (n as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.pointwise_mul_batched) - .arg(&mut buf) - .arg(&weights_dev) - .arg(&n_u64) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } - { - let grid_x = (lde_size as u32).div_ceil(256); - let cfg = LaunchConfig { - grid_dim: (grid_x, mb_u32, 1), - block_dim: (256, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - stream - .launch_builder(&be.bit_reverse_permute_batched) - .arg(&mut buf) - .arg(&lde_u64) - .arg(&log_lde) - .arg(&col_stride_u64) - .launch(cfg)?; - } - } + launch_pointwise_mul_batched( + stream.as_ref(), + be, + &mut buf, + &weights_dev, + n_u64, + col_stride_u64, + mb_u32, + )?; + launch_bit_reverse_batched( + stream.as_ref(), + be, + &mut buf, + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; run_batched_ntt_body( stream.as_ref(), &mut buf, @@ -851,32 +1433,7 @@ pub fn coset_lde_batch_ext3_into( // Unpack: for each output column, re-interleave 3 slabs back into the // ext3-per-element layout. 
- let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let slab_a = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3) * lde_size), - lde_size, - ) - }; - let slab_b = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 1) * lde_size), - lde_size, - ) - }; - let slab_c = unsafe { - std::slice::from_raw_parts( - (pinned_const as *const u64).add((c * 3 + 2) * lde_size), - lde_size, - ) - }; - for i in 0..lde_size { - dst[i * 3] = slab_a[i]; - dst[i * 3 + 1] = slab_b[i]; - dst[i * 3 + 2] = slab_c[i]; - } - }); + unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size); drop(staging); Ok(()) } diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs index 7731fe4c3..d1b4e1210 100644 --- a/crypto/math-cuda/src/lib.rs +++ b/crypto/math-cuda/src/lib.rs @@ -6,6 +6,7 @@ pub mod device; pub mod lde; +pub mod merkle; pub mod ntt; use cudarc::driver::{LaunchConfig, PushKernelArg}; diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs new file mode 100644 index 000000000..6faf12b51 --- /dev/null +++ b/crypto/math-cuda/src/merkle.rs @@ -0,0 +1,357 @@ +//! GPU Keccak-256 leaf hashing for Merkle commits. +//! +//! Matches `FieldElementVectorBackend::hash_data` in +//! `crypto/crypto/src/merkle_tree/backends/field_element_vector.rs`, combined +//! with the `reverse_index` row read pattern used in +//! `commit_columns_bit_reversed` at `crypto/stark/src/prover.rs`. +//! +//! Caller supplies base-field column slabs already laid out as +//! `[col * col_stride + row]` (the same layout `coset_lde_batch_base_into` +//! writes to the pinned staging buffer). The kernel bit-reverses `row_idx`, +//! reads each column's canonical u64 at that row, byte-swaps it into a +//! Keccak lane, absorbs lane-by-lane, and squeezes 32 bytes per leaf. +//! +//! For ext3 columns the layout is `[col*3*col_stride + k*col_stride + row]`, +//! 
three base-field components per ext3 column, indexed by `k ∈ {0,1,2}`, +//! and the kernel reads three u64s per column in component order 0,1,2 +//! to match `FieldElement::::write_bytes_be`. + +use cudarc::driver::{CudaSlice, CudaStream, CudaViewMut, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::{Backend, backend}; +use crate::lde::pack_ext3_to_pinned_slabs; + +/// Run GPU Keccak-256 leaf hashing on a base-field column buffer. +/// +/// `columns` must hold `num_cols * col_stride` u64s with column `c`'s data +/// at `[c*col_stride .. c*col_stride + num_rows]`. Returns `num_rows * 32` +/// hash bytes in natural (non-bit-reversed) row order. +pub fn keccak_leaves_base( + columns: &[u64], + col_stride: usize, + num_cols: usize, + num_rows: usize, +) -> Result> { + assert!(num_rows.is_power_of_two()); + assert!( + col_stride >= num_rows, + "col_stride must be >= num_rows to keep per-column reads in-bounds" + ); + let total = num_cols + .checked_mul(col_stride) + .expect("num_cols * col_stride overflows usize"); + assert!(columns.len() >= total); + let be = backend()?; + let stream = be.next_stream(); + let cols_dev = stream.clone_htod(&columns[..total])?; + let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; + launch_keccak_base( + stream.as_ref(), + &cols_dev, + col_stride as u64, + num_cols as u64, + num_rows as u64, + &mut out_dev.as_view_mut(), + )?; + let out = stream.clone_dtoh(&out_dev)?; + Ok(out) +} + +/// Ext3 variant. Columns interleaved as three base slabs per ext3 column. +/// `columns.len() >= num_cols * 3 * col_stride`. 
+pub fn keccak_leaves_ext3( + columns: &[u64], + col_stride: usize, + num_cols: usize, + num_rows: usize, +) -> Result> { + assert!(num_rows.is_power_of_two()); + assert!( + col_stride >= num_rows, + "col_stride must be >= num_rows to keep per-column reads in-bounds" + ); + let total = num_cols + .checked_mul(3) + .and_then(|v| v.checked_mul(col_stride)) + .expect("num_cols * 3 * col_stride overflows usize"); + assert!(columns.len() >= total); + let be = backend()?; + let stream = be.next_stream(); + let cols_dev = stream.clone_htod(&columns[..total])?; + let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; + launch_keccak_ext3( + stream.as_ref(), + &cols_dev, + col_stride as u64, + num_cols as u64, + num_rows as u64, + &mut out_dev.as_view_mut(), + )?; + let out = stream.clone_dtoh(&out_dev)?; + Ok(out) +} + +/// Block size for Keccak kernels. Per-thread register footprint is ~60 regs +/// (25-lane state + auxiliaries). The default 256 threads/block pushes the +/// block register file past the hardware limit on sm_120 (Blackwell). 128 +/// keeps us inside the budget with some head-room. +const KECCAK_BLOCK_DIM: u32 = 128; + +pub(crate) fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { + debug_assert!( + num_rows <= u32::MAX as u64, + "keccak_launch_cfg: num_rows ({num_rows}) exceeds u32 grid range", + ); + let grid = (num_rows as u32).div_ceil(KECCAK_BLOCK_DIM); + LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (KECCAK_BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + } +} + +/// Walk the inner Merkle tree on device. `nodes_dev` already has the +/// `leaves_len` hashed leaves written into the tail; this loops +/// `log2(leaves_len)` times invoking `keccak_merkle_level` to fill in the +/// inner nodes from the bottom up. Mirrors the CPU `build(nodes, leaves_len)` +/// scan in `crypto/crypto/src/merkle_tree/merkle.rs`. 
+pub(crate) fn build_inner_tree_levels( + stream: &CudaStream, + be: &Backend, + nodes_dev: &mut CudaSlice, + leaves_len: usize, +) -> Result<()> { + let mut level_begin: u64 = (leaves_len - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let cfg = keccak_launch_cfg(n_pairs); + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut *nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + Ok(()) +} + +pub(crate) fn launch_keccak_base( + stream: &CudaStream, + cols_dev: &CudaSlice, + col_stride: u64, + num_cols: u64, + num_rows: u64, + out_dev: &mut CudaViewMut<'_, u8>, +) -> Result<()> { + // The kernel computes `__brevll(tid) >> (64 - log_num_rows)`, which is UB + // for `log_num_rows == 0` (single-row trees are degenerate anyway). + debug_assert!(num_rows >= 2, "keccak leaf kernel: num_rows must be >= 2"); + let be = backend()?; + let log_num_rows = num_rows.trailing_zeros() as u64; + let cfg = keccak_launch_cfg(num_rows); + unsafe { + stream + .launch_builder(&be.keccak256_leaves_base_batched) + .arg(cols_dev) + .arg(&col_stride) + .arg(&num_cols) + .arg(&num_rows) + .arg(&log_num_rows) + .arg(out_dev) + .launch(cfg)?; + } + Ok(()) +} + +/// Given `hashed_leaves` of length `leaves_len * 32`, build the full Merkle +/// tree on device and return the complete node buffer `(2*leaves_len - 1) * +/// 32` bytes in the standard layout: +/// +/// `nodes[0..leaves_len - 1]` are inner nodes (root at index 0), and +/// `nodes[leaves_len - 1..]` are the leaves themselves. +/// +/// Matches the CPU `crypto/crypto/src/merkle_tree/merkle.rs` construction so +/// the resulting `nodes` Vec plugs straight into `MerkleTree { root, nodes }` +/// for downstream proof generation. +/// +/// `leaves_len` must be a power of two and >= 2. 
+pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { + assert!(hashed_leaves.len().is_multiple_of(32)); + let leaves_len = hashed_leaves.len() / 32; + assert!(leaves_len >= 2, "tree needs at least two leaves"); + assert!( + leaves_len.is_power_of_two(), + "leaves_len must be a power of two" + ); + + let total_nodes = 2 * leaves_len - 1; + let be = backend()?; + let stream = be.next_stream(); + + // Allocate the full node buffer without zero-fill. We overwrite the + // leaf half via H2D immediately, and every inner node is written by the + // pair-hash kernel below. + // SAFETY: every byte is written before it is read: leaves are filled by + // the H2D below; inner nodes are filled by the level loop that follows. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (leaves_len - 1) * 32; + // SAFETY: target slice `nodes_dev[leaves_offset_bytes..]` has exactly + // `leaves_len * 32 == hashed_leaves.len()` bytes capacity. + { + let mut slice = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + hashed_leaves.len()); + stream.memcpy_htod(hashed_leaves, &mut slice)?; + } + + build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, leaves_len)?; + + let out = stream.clone_dtoh(&nodes_dev)?; + Ok(out) +} + +/// Row-pair Keccak leaf + Merkle tree build for R2 composition-polynomial +/// commit. `parts_interleaved` is `num_parts` slices, each holding an ext3 +/// LDE column interleaved as `[a0,a1,a2, b0,b1,b2, ...]` of length `3*lde_size`. +/// +/// Returns `(2*(lde_size/2) - 1) * 32` bytes of tree nodes in the standard +/// layout (root at byte offset 0, leaves in the tail). 
+pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Result> { + assert!(!parts_interleaved.is_empty()); + let m = parts_interleaved.len(); + let ext3_elems = parts_interleaved[0].len() / 3; + assert_eq!( + parts_interleaved[0].len(), + 3 * ext3_elems, + "ext3 buffer length must be 3 * lde_size" + ); + for p in parts_interleaved.iter() { + assert_eq!(p.len(), 3 * ext3_elems); + } + let lde_size = ext3_elems; + assert!(lde_size.is_power_of_two() && lde_size >= 2); + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + + let be = backend()?; + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + // Stage: de-interleave each part into 3 base slabs in pinned memory. + let mb = 3 * m; + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + pack_ext3_to_pinned_slabs(parts_interleaved, pinned, lde_size); + + // H2D the de-interleaved parts. + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + stream.memcpy_htod(&pinned[..mb * lde_size], &mut buf)?; + + // Leaves into tail of a tight node buffer. 
+ let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let col_stride_u64 = lde_size as u64; + let num_parts_u64 = m as u64; + let num_rows_u64 = lde_size as u64; + let log_num_rows = lde_size.trailing_zeros() as u64; + let cfg = keccak_launch_cfg(num_leaves as u64); + unsafe { + stream + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&num_rows_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; + + let out = stream.clone_dtoh(&nodes_dev)?; + drop(staging); + Ok(out) +} + +/// Build a FRI-layer Merkle tree on device from an interleaved ext3 eval +/// vector. Each leaf hashes two consecutive ext3 values. `num_leaves = +/// evals.len() / 6` (since each ext3 is 3 u64s). +/// +/// Returns the `(2*num_leaves - 1) * 32`-byte node buffer in standard layout. +pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { + assert!( + evals.len().is_multiple_of(6), + "evals must hold whole pair-leaves" + ); + let num_evals = evals.len() / 3; + let num_leaves = num_evals / 2; + assert!(num_leaves.is_power_of_two() && num_leaves >= 2); + let tight_total_nodes = 2 * num_leaves - 1; + + let be = backend()?; + let stream = be.next_stream(); + + let evals_dev = stream.clone_htod(evals)?; + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + + // Leaf kernel: num_leaves threads, one leaf each. 
+ let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let num_leaves_u64 = num_leaves as u64; + let cfg = keccak_launch_cfg(num_leaves as u64); + unsafe { + stream + .launch_builder(&be.keccak_fri_leaves_ext3) + .arg(&evals_dev) + .arg(&num_leaves_u64) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; + + let out = stream.clone_dtoh(&nodes_dev)?; + Ok(out) +} + +pub(crate) fn launch_keccak_ext3( + stream: &CudaStream, + cols_dev: &CudaSlice, + col_stride: u64, + num_cols: u64, + num_rows: u64, + out_dev: &mut CudaViewMut<'_, u8>, +) -> Result<()> { + // The kernel computes `__brevll(tid) >> (64 - log_num_rows)`, which is UB + // for `log_num_rows == 0` (single-row trees are degenerate anyway). + debug_assert!(num_rows >= 2, "keccak leaf kernel: num_rows must be >= 2"); + let be = backend()?; + let log_num_rows = num_rows.trailing_zeros() as u64; + let cfg = keccak_launch_cfg(num_rows); + unsafe { + stream + .launch_builder(&be.keccak256_leaves_ext3_batched) + .arg(cols_dev) + .arg(&col_stride) + .arg(&num_cols) + .arg(&num_rows) + .arg(&log_num_rows) + .arg(out_dev) + .launch(cfg)?; + } + Ok(()) +} diff --git a/crypto/math-cuda/tests/keccak_leaves.rs b/crypto/math-cuda/tests/keccak_leaves.rs new file mode 100644 index 000000000..d614e233d --- /dev/null +++ b/crypto/math-cuda/tests/keccak_leaves.rs @@ -0,0 +1,198 @@ +//! Parity: GPU Keccak-256 leaf hashes must match the CPU prover's leaf +//! hashing helpers. `stark::prover::keccak_leaves_bit_reversed` for +//! per-row commits, `keccak_leaves_row_pair_bit_reversed` for the R2 +//! composition commit, and `FriLayerMerkleTreeBackend::hash_data` for the +//! FRI commit. These are the same helpers the prover itself calls so any +//! change to the CPU leaf-hash contract surfaces here. 
+ +use crypto::merkle_tree::traits::IsMerkleTreeBackend; +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use stark::config::FriLayerMerkleTreeBackend; +use stark::prover::{keccak_leaves_bit_reversed, keccak_leaves_row_pair_bit_reversed}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +#[test] +fn keccak_leaves_base_matches_cpu() { + for log_n in [4u32, 6, 8, 10, 12] { + for num_cols in [1usize, 5, 17, 41] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(100 + log_n as u64 + num_cols as u64); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| Fp::from_raw(rng.r#gen::())).collect()) + .collect(); + + let cpu = keccak_leaves_bit_reversed(&columns); + + // Flatten columns into a contiguous base slab layout matching + // `coset_lde_batch_base_into`'s pinned staging format: + // `[col * stride + row]`. Use stride = num_rows for compactness. 
+ let mut flat = vec![0u64; num_cols * n]; + for (c, col) in columns.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + flat[c * n + r] = *e.value(); + } + } + let gpu = math_cuda::merkle::keccak_leaves_base(&flat, n, num_cols, n).unwrap(); + assert_eq!(gpu.len(), n * 32); + for i in 0..n { + assert_eq!( + &gpu[i * 32..(i + 1) * 32], + &cpu[i][..], + "base leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" + ); + } + } + } +} + +#[test] +fn keccak_leaves_ext3_matches_cpu() { + for log_n in [4u32, 6, 8, 10] { + for num_cols in [1usize, 3, 11, 20] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(200 + log_n as u64 + num_cols as u64); + let columns: Vec> = (0..num_cols) + .map(|_| { + (0..n) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() + }) + .collect(); + + let cpu = keccak_leaves_bit_reversed(&columns); + + // GPU expects 3 base slabs per ext3 column in the order + // [col*3+0 (comp a), col*3+1 (comp b), col*3+2 (comp c)], each a + // contiguous slab of n u64s (length = num_cols * 3 * n). + let mut flat = vec![0u64; num_cols * 3 * n]; + for (c, col) in columns.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + flat[(c * 3) * n + r] = *e.value()[0].value(); + flat[(c * 3 + 1) * n + r] = *e.value()[1].value(); + flat[(c * 3 + 2) * n + r] = *e.value()[2].value(); + } + } + let gpu = math_cuda::merkle::keccak_leaves_ext3(&flat, n, num_cols, n).unwrap(); + assert_eq!(gpu.len(), n * 32); + for i in 0..n { + assert_eq!( + &gpu[i * 32..(i + 1) * 32], + &cpu[i][..], + "ext3 leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" + ); + } + } + } +} + +#[test] +fn keccak_comp_poly_leaves_matches_cpu() { + // Built tree's leaves live at byte offset `(num_leaves - 1) * 32` and + // span `num_leaves * 32` bytes. Compare those to the CPU reference. 
+ for log_lde in [2u32, 4, 6, 8, 10, 12] { + for num_parts in [1usize, 2, 5, 17] { + let lde_size = 1usize << log_lde; + let mut rng = ChaCha8Rng::seed_from_u64(300 + log_lde as u64 + num_parts as u64); + let parts: Vec> = (0..num_parts) + .map(|_| { + (0..lde_size) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() + }) + .collect(); + let cpu = keccak_leaves_row_pair_bit_reversed(&parts); + + // Each part is passed as `[a0,a1,a2, b0,b1,b2, ...]` of length `3 * lde_size`. + let parts_interleaved: Vec> = parts + .iter() + .map(|p| { + let mut v = vec![0u64; 3 * lde_size]; + for (i, e) in p.iter().enumerate() { + v[i * 3] = *e.value()[0].value(); + v[i * 3 + 1] = *e.value()[1].value(); + v[i * 3 + 2] = *e.value()[2].value(); + } + v + }) + .collect(); + let parts_slices: Vec<&[u64]> = + parts_interleaved.iter().map(|v| v.as_slice()).collect(); + + let nodes = + math_cuda::merkle::build_comp_poly_tree_from_evals_ext3(&parts_slices).unwrap(); + let num_leaves = lde_size / 2; + let leaves_offset = (num_leaves - 1) * 32; + for i in 0..num_leaves { + assert_eq!( + &nodes[leaves_offset + i * 32..leaves_offset + (i + 1) * 32], + &cpu[i][..], + "comp-poly leaf mismatch at i={i} (log_lde={log_lde}, parts={num_parts})" + ); + } + } + } +} + +#[test] +fn keccak_fri_leaves_matches_cpu() { + for log_lde in [2u32, 4, 6, 8, 10, 12] { + let lde_size = 1usize << log_lde; + let mut rng = ChaCha8Rng::seed_from_u64(400 + log_lde as u64); + let evals: Vec = (0..lde_size) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect(); + + // CPU reference: consecutive ext3 pairs hashed via the prover's + // FRI-layer Merkle backend. 
+ let cpu: Vec<[u8; 32]> = evals + .chunks_exact(2) + .map(|c| { + FriLayerMerkleTreeBackend::::hash_data(&[ + c[0], c[1], + ]) + }) + .collect(); + + let mut evals_interleaved = vec![0u64; 3 * lde_size]; + for (i, e) in evals.iter().enumerate() { + evals_interleaved[i * 3] = *e.value()[0].value(); + evals_interleaved[i * 3 + 1] = *e.value()[1].value(); + evals_interleaved[i * 3 + 2] = *e.value()[2].value(); + } + let nodes = + math_cuda::merkle::build_fri_layer_tree_from_evals_ext3(&evals_interleaved).unwrap(); + let num_leaves = lde_size / 2; + let leaves_offset = (num_leaves - 1) * 32; + for i in 0..num_leaves { + assert_eq!( + &nodes[leaves_offset + i * 32..leaves_offset + (i + 1) * 32], + &cpu[i][..], + "fri leaf mismatch at i={i} (log_lde={log_lde})" + ); + } + } +} diff --git a/crypto/math-cuda/tests/merkle_tree.rs b/crypto/math-cuda/tests/merkle_tree.rs new file mode 100644 index 000000000..76fdeb919 --- /dev/null +++ b/crypto/math-cuda/tests/merkle_tree.rs @@ -0,0 +1,59 @@ +//! Parity: GPU Merkle inner-tree construction must match the CPU +//! `crypto/crypto/src/merkle_tree/merkle.rs` `build_from_hashed_leaves` +//! (Keccak-256 pair hash at each level). Uses the prover's +//! `FieldElementVectorBackend<_, Keccak256, 32>` directly so any change to +//! the CPU tree builder is automatically exercised here. + +use crypto::merkle_tree::backends::field_element_vector::FieldElementVectorBackend; +use crypto::merkle_tree::merkle::MerkleTree; +use math::field::goldilocks::GoldilocksField; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::Keccak256; + +type CpuTree = MerkleTree>; + +fn run_parity(log_n: u32, seed: u64) { + let leaves_len = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let leaves: Vec<[u8; 32]> = (0..leaves_len) + .map(|_| { + let mut arr = [0u8; 32]; + rng.fill(&mut arr[..]); + arr + }) + .collect(); + + // Flat byte layout for the GPU entry point. 
+ let mut flat = Vec::with_capacity(leaves_len * 32); + for l in &leaves { + flat.extend_from_slice(l); + } + + let gpu_nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(&flat).unwrap(); + assert_eq!(gpu_nodes_bytes.len(), (2 * leaves_len - 1) * 32); + + // CPU reference: the prover's MerkleTree builder over the same backend. + let cpu_tree = CpuTree::build_from_hashed_leaves(leaves).unwrap(); + let cpu_nodes = cpu_tree.nodes(); + + for (i, c) in cpu_nodes.iter().enumerate() { + let g = &gpu_nodes_bytes[i * 32..(i + 1) * 32]; + assert_eq!( + g, c, + "node {i} mismatch at log_n={log_n} (cpu={c:?}, gpu={g:?})" + ); + } +} + +#[test] +fn merkle_tree_small() { + for log_n in 1u32..=6 { + run_parity(log_n, 100 + log_n as u64); + } +} + +#[test] +fn merkle_tree_large() { + run_parity(18, 9999); +} diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index a5386017a..8a4577360 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -10,7 +10,7 @@ use math::fft::errors::FFTError; use log::info; use math::field::traits::{IsField, IsSubFieldOf}; -use math::traits::AsBytes; +use math::traits::{AsBytes, ByteConversion}; use math::{ field::{element::FieldElement, traits::IsFFTField}, polynomial::Polynomial, @@ -374,6 +374,102 @@ where } } +/// Compute Keccak-256 leaf hashes for `commit_columns_bit_reversed`: one +/// leaf per row, where each row is read at `reverse_index(row_idx)` and the +/// columns are concatenated as big-endian bytes before hashing. +/// +/// Returns `Vec` with the same length as `columns[0]`. Exposed +/// (instead of being a closure inside `commit_columns_bit_reversed`) so +/// parity tests in dependent crates can compare against the same code path +/// the prover uses. 
+pub fn keccak_leaves_bit_reversed<E>(columns: &[Vec<FieldElement<E>>]) -> Vec<Commitment>
+where
+    E: IsField,
+    FieldElement<E>: AsBytes + Sync + Send + ByteConversion,
+{
+    // No columns, or an empty first column, means there is nothing to hash.
+    let row_count = match columns.first() {
+        Some(first) if !first.is_empty() => first.len(),
+        _ => return Vec::new(),
+    };
+
+    // Every element serialises to the same number of bytes, so the leaf
+    // preimage size is fixed for the whole call.
+    let elem_bytes = <FieldElement<E> as ByteConversion>::BYTE_LEN;
+    let leaf_bytes = columns.len() * elem_bytes;
+
+    debug_assert!(
+        row_count.is_power_of_two(),
+        "num_rows must be a power of two for reverse_index"
+    );
+
+    #[cfg(feature = "parallel")]
+    let rows = (0..row_count).into_par_iter();
+    #[cfg(not(feature = "parallel"))]
+    let rows = 0..row_count;
+
+    rows.map(|leaf_idx| {
+        // Leaf `leaf_idx` hashes the row stored at the bit-reversed position.
+        let source_row = reverse_index(leaf_idx, row_count as u64);
+        let mut preimage = vec![0u8; leaf_bytes];
+        let mut cursor = 0;
+        // Columns are laid out back to back, each in big-endian form.
+        for column in columns {
+            let end = cursor + elem_bytes;
+            column[source_row].write_bytes_be(&mut preimage[cursor..end]);
+            cursor = end;
+        }
+        BatchedMerkleTreeBackend::<E>::hash_bytes(&preimage)
+    })
+    .collect()
+}
+
+/// Compute Keccak-256 leaf hashes for `commit_composition_polynomial`: one
+/// leaf per row-pair, where leaf `i` hashes the BE concatenation of
+/// `parts[..][br_0] ++ parts[..][br_1]` with
+/// `br_k = reverse_index(2*i + k, num_rows)`.
+///
+/// Returns `Vec<Commitment>` of length `parts[0].len() / 2`.
+pub fn keccak_leaves_row_pair_bit_reversed(parts: &[Vec>]) -> Vec +where + E: IsField, + FieldElement: AsBytes + Sync + Send + ByteConversion, +{ + let num_parts = parts.len(); + if num_parts == 0 { + return Vec::new(); + } + let num_rows = parts[0].len(); + if num_rows == 0 { + return Vec::new(); + } + + let num_leaves = num_rows / 2; + debug_assert!( + num_rows.is_power_of_two(), + "num_rows must be a power of two for reverse_index" + ); + + let byte_len = as ByteConversion>::BYTE_LEN; + + #[cfg(feature = "parallel")] + let iter = (0..num_leaves).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = 0..num_leaves; + + iter.map(|leaf_idx| { + let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); + let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); + let total_bytes = 2 * num_parts * byte_len; + let mut buf = vec![0u8; total_bytes]; + let mut offset = 0; + for part in parts.iter() { + part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + for part in parts.iter() { + part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) + .collect() +} + /// The functionality of a STARK prover providing methods to run the STARK Prove protocol /// https://lambdaclass.github.io/lambdaworks/starks/protocol.html /// The default implementation is complete and is compatible with Stone prover @@ -400,41 +496,10 @@ pub trait IsStarkProver< FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, E: IsField, { - use math::traits::ByteConversion; - if columns.is_empty() || columns[0].is_empty() { return None; } - - let num_rows = columns[0].len(); - let num_cols = columns.len(); - let byte_len = as ByteConversion>::BYTE_LEN; - - debug_assert!( - num_rows.is_power_of_two(), - "num_rows must be a power of two for reverse_index" - ); - - #[cfg(feature = "parallel")] - let iter = (0..num_rows).into_par_iter(); - #[cfg(not(feature = 
"parallel"))] - let iter = 0..num_rows; - - // One allocation per row (was one per field element): write all columns - // into a single buffer, then hash once. - let hashed_leaves: Vec = iter - .map(|row_idx| { - let br_idx = reverse_index(row_idx, num_rows as u64); - let total_bytes = num_cols * byte_len; - let mut buf = vec![0u8; total_bytes]; - for col_idx in 0..num_cols { - columns[col_idx][br_idx] - .write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) - .collect(); - + let hashed_leaves = keccak_leaves_bit_reversed(columns); let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) @@ -723,8 +788,6 @@ pub trait IsStarkProver< FieldElement: AsBytes + Sync + Send, FieldElement: AsBytes + Sync + Send + math::traits::ByteConversion, { - use math::traits::ByteConversion; - let num_parts = lde_composition_poly_parts_evaluations.len(); if num_parts == 0 { return None; @@ -733,41 +796,8 @@ pub trait IsStarkProver< if num_rows == 0 { return None; } - - let num_leaves = num_rows / 2; - debug_assert!( - num_rows.is_power_of_two(), - "num_rows must be a power of two for reverse_index" - ); - - let byte_len = as ByteConversion>::BYTE_LEN; - - // One allocation per leaf (was one per field element): write all parts - // into a single buffer. Each leaf = row_pair[2*i] ++ row_pair[2*i+1] after bit-reverse. 
- #[cfg(feature = "parallel")] - let iter = (0..num_leaves).into_par_iter(); - #[cfg(not(feature = "parallel"))] - let iter = 0..num_leaves; - - let hashed_leaves: Vec = iter - .map(|leaf_idx| { - let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); - let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); - let total_bytes = 2 * num_parts * byte_len; - let mut buf = vec![0u8; total_bytes]; - let mut offset = 0; - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) - .collect(); - + let hashed_leaves = + keccak_leaves_row_pair_bit_reversed(lde_composition_poly_parts_evaluations); let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root))