diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index 4ea0e5411..5cc203550 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -123,6 +123,44 @@ where Self::build_from_hashed_leaves(hashed_leaves) } + /// Useful for handing a GPU-built tree to the stark prover. + /// Performs no hashing, the caller is responsible for the layout's + /// cryptographic correctness. + /// + /// Expected layout (matches [`build_from_hashed_leaves`]): + /// - `nodes.len() == 2 * leaves_len - 1` where `leaves_len` is a power of two + /// - `nodes[0]` is the root + /// - `nodes[leaves_len - 1 .. 2*leaves_len - 1]` are the leaves + pub fn from_precomputed_nodes(nodes: Vec) -> Option { + if nodes.is_empty() { + return None; + } + // Validate (cheap) that (nodes.len() + 1) is a power of two: there + // must be `leaves_len - 1 + leaves_len = 2*leaves_len - 1` entries. + let total = nodes.len(); + if !(total + 1).is_power_of_two() { + return None; + } + // Debug-only integrity spot-check: the root must equal hash(left, right). + // Catches GPU correctness regressions in CI without paying for a full + // tree walk on every call. + #[cfg(debug_assertions)] + if total >= 3 { + let expected_root = B::hash_new_parent(&nodes[1], &nodes[2]); + debug_assert!( + nodes[ROOT] == expected_root, + "from_precomputed_nodes: root does not hash from children", + ); + } + let root = nodes[ROOT].clone(); + Some(MerkleTree { + root, + nodes, + #[cfg(feature = "disk-spill")] + mmap_backing: None, + }) + } + /// Create a Merkle tree from pre-hashed leaf nodes. 
/// /// This skips the `hash_leaves` step, useful when leaves have already been diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index d6d5fc403..f2cc988c0 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -102,16 +102,26 @@ const STREAM_POOL_SIZE: usize = 32; pub struct Backend { pub ctx: Arc, streams: Vec>, - /// Single shared pinned staging buffer, grown to the biggest LDE size - /// seen. Concurrent batched LDE calls serialise on it; in exchange the - /// process keeps only ONE gigabyte-sized pinned allocation (per-stream - /// buffers 32×-inflated memory use and multiplied the one-time pinning - /// cost for every first use of a new table size). - pinned_staging: Mutex, - /// Separate pinned staging for Merkle leaf hashes. Sized `num_rows * 32` - /// bytes. It lives alongside the LDE staging so the GPU→host D2H for - /// hashed leaves runs at full PCIe line-rate. - pinned_hashes: Mutex, + /// Per-rayon-worker pinned staging buffers, grown lazily to the biggest + /// LDE size each worker sees. Indexed by `rayon::current_thread_index()` + /// (or 0 for non-rayon callers). + /// + /// Per-worker (not single-shared) because the LDE call holds the lock + /// across an internal rayon `par_chunks_mut`/`par_iter` window: with a + /// single shared mutex, rayon work-stealing can yield a lock-holder onto + /// another task waiting for the same lock — classic recursive-rayon + /// deadlock. Per-worker buffers eliminate cross-worker contention so + /// each `par_iter` worker hits a distinct mutex. + /// + /// Each entry starts empty (`PinnedStaging::empty()` is a zero-cost null + /// handle); only the slots actually used by the running workers ever + /// allocate pinned memory. Worst-case footprint is `N_workers × + /// max_LDE_size` of pinned host RAM. + pinned_staging: Vec>, + /// Per-worker pinned staging for Merkle leaf hashes. Same layout as + /// `pinned_staging`; sized `num_rows * 32` bytes per slot. 
Lives + /// alongside the LDE staging so the GPU→host D2H runs at PCIe line-rate. + pinned_hashes: Vec>, util_stream: Arc, next: AtomicUsize, @@ -166,8 +176,20 @@ impl Backend { for _ in 0..STREAM_POOL_SIZE { streams.push(ctx.new_stream()?); } - let pinned_staging = Mutex::new(PinnedStaging::empty()); - let pinned_hashes = Mutex::new(PinnedStaging::empty()); + // Size to the rayon worker count, plus one for non-rayon callers + // who land on slot 0 (`rayon::current_thread_index()` returns None + // outside a rayon context — we map that to 0). + // + // `current_num_threads()` returns the default-pool size if no custom + // pool is in use, which is the cpu count. Stable across the + // backend's lifetime since rayon's pool is fixed at first use. + let n_slots = rayon::current_num_threads().max(1); + let pinned_staging: Vec> = (0..n_slots) + .map(|_| Mutex::new(PinnedStaging::empty())) + .collect(); + let pinned_hashes: Vec> = (0..n_slots) + .map(|_| Mutex::new(PinnedStaging::empty())) + .collect(); // Separate "utility" stream for twiddle uploads and other bookkeeping; // not part of the pool that callers rotate through. let util_stream = ctx.new_stream()?; @@ -219,16 +241,29 @@ impl Backend { self.streams[idx].clone() } - /// Shared pinned staging buffer. Grows to the largest LDE the process - /// has seen so far. Concurrent callers serialise on the mutex. + /// Per-rayon-worker pinned staging buffer. Returns the slot for the + /// current worker (or slot 0 outside a rayon context). Grows lazily to + /// the largest LDE the worker has seen. See the field docs for the + /// rationale behind the per-worker split. pub fn pinned_staging(&self) -> &Mutex { - &self.pinned_staging + &self.pinned_staging[self.worker_slot(self.pinned_staging.len())] } - /// Separate pinned staging for Merkle leaf hash output. Sized in u64 + /// Per-worker pinned staging for Merkle leaf hash output. Sized in u64 /// units. Caller should reserve `(num_rows * 32 + 7) / 8` u64s. 
pub fn pinned_hashes(&self) -> &Mutex { - &self.pinned_hashes + &self.pinned_hashes[self.worker_slot(self.pinned_hashes.len())] + } + + /// Map `rayon::current_thread_index()` to a slot index, with a defensive + /// clamp in case the rayon pool grew past the Vec we sized at init. + fn worker_slot(&self, len: usize) -> usize { + let idx = rayon::current_thread_index().unwrap_or(0); + // Should be unreachable with rayon's fixed default pool, but if a + // larger custom pool sneaks in we still want safety — fall back to + // slot 0 (correctness preserved, just contention). + debug_assert!(idx < len, "rayon worker {idx} >= staging slots {len}"); + idx.min(len.saturating_sub(1)) } pub fn fwd_twiddles_for(&self, log_n: u64) -> Result>> { diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 02f109938..48b580994 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -13,7 +13,6 @@ use std::sync::Arc; use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; -use rayon::prelude::*; use crate::Result; use crate::device::{Backend, backend}; @@ -69,7 +68,10 @@ pub(crate) fn pack_ext3_to_pinned_slabs(columns: &[&[u64]], pinned: &mut [u64], let m = columns.len(); debug_assert!(pinned.len() >= 3 * m * n); let pinned_ptr_u = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { + // Sequential, not `par_iter`: this runs while the per-worker pinned + // staging mutex is held. Rayon inside a held mutex risks recursive + // stealing-during-wait deadlocks — see `Backend::pinned_staging` docs. + columns.iter().enumerate().for_each(|(c, col)| { // SAFETY: each task writes to disjoint `[(c*3 + k)*n .. ..+n]` regions // of `pinned`. The outer `&mut [u64]` borrow guarantees no aliasing. 
let slab_a = unsafe { @@ -96,7 +98,9 @@ fn unpack_pinned_slabs_to_ext3(pinned: &[u64], outputs: &mut [&mut [u64]], lde_s let m = outputs.len(); debug_assert!(pinned.len() >= 3 * m * lde_size); let pinned_const = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + // Sequential, not `par_iter_mut`: runs inside a held pinned-staging + // mutex; rayon-inside-mutex risks deadlock (see `Backend::pinned_staging`). + outputs.iter_mut().enumerate().for_each(|(c, dst)| { // SAFETY: each task reads from disjoint `[(c*3 + k)*lde_size .. ..+lde_size]` // regions of `pinned`. Caller borrows `pinned` for the duration of the call. let slab_a = unsafe { @@ -178,19 +182,14 @@ fn d2h_bytes_via_pinned_hashes( stream.memcpy_dtoh(dev_bytes, pinned_bytes)?; stream.synchronize()?; - // Single-threaded `copy_from_slice` faults virgin pageable pages one at - // a time; the mm_struct rwsem serialises them at prover scale. Chunk so - // ~N cores pre-fault+write in parallel. - const CHUNK: usize = 64 * 1024; - let src_ptr = pinned_bytes.as_ptr() as usize; - dst.par_chunks_mut(CHUNK).enumerate().for_each(|(i, d)| { - // SAFETY: each task reads `[i*CHUNK .. i*CHUNK + d.len()]` of - // `pinned_bytes`, which is disjoint per `i` and lives until `staging` - // is dropped below. - let src = - unsafe { std::slice::from_raw_parts((src_ptr as *const u8).add(i * CHUNK), d.len()) }; - d.copy_from_slice(src); - }); + // Sequential, not `par_chunks_mut`: this runs while the per-worker + // pinned_hashes mutex is held. Rayon inside a held mutex risks + // recursive stealing-during-wait deadlocks — see + // `Backend::pinned_staging` docs. Page-fault parallelism on virgin + // destination pages is recovered at the outer level: per-worker + // staging buffers let rayon's outer `par_iter` dispatch multiple LDE + // calls in parallel, each faulting its own destination pages. 
+ dst.copy_from_slice(pinned_bytes); drop(staging); Ok(()) } @@ -367,18 +366,14 @@ pub fn coset_lde_batch_base( // SAFETY: staging is locked, the slice alias ends before we unlock. let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - // Pack columns into first m*n slots of the pinned buffer. Parallel: pinned - // writes are DRAM-bandwidth bound, so rayon spreads the cost across CPU - // cores. - let pinned_base_ptr = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - // SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of - // `pinned`, and the outer `staging` lock guarantees no other call is - // using the buffer concurrently. - let dst = - unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; - dst.copy_from_slice(col); - }); + // Pack columns into first m*n slots of the pinned buffer. Sequential + // (not `par_iter`) because this runs inside the held pinned-staging + // mutex — see `Backend::pinned_staging` docs. Pre-fault parallelism on + // the destination is recovered at the outer level via per-worker + // staging slots. + for (c, col) in columns.iter().enumerate() { + pinned[c * n..c * n + n].copy_from_slice(col); + } // Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size) // tail of each column is already the zero-pad the CPU path does. @@ -459,12 +454,11 @@ pub fn coset_lde_batch_base( stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; stream.synchronize()?; - // Split pinned → per-column Vecs. The first write to each virgin - // Vec page-faults, which can dominate total time. Parallelise so the - // fault cost spreads across CPU cores. - let pinned_ptr = pinned.as_ptr() as usize; + // Split pinned → per-column Vecs. Sequential (not `into_par_iter`) + // because this runs inside the held pinned-staging mutex — see + // `Backend::pinned_staging` docs. Fault-cost parallelism is recovered + // at the outer level (per-worker staging slots). 
let out: Vec> = (0..m) - .into_par_iter() .map(|c| { // set_len skips the O(N) zero-init that vec![0; n] would do. // copy_from_slice below writes every slot before any reader @@ -475,10 +469,7 @@ pub fn coset_lde_batch_base( unsafe { v.set_len(lde_size) }; v }; - let src = unsafe { - std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) - }; - v.copy_from_slice(src); + v.copy_from_slice(&pinned[c * lde_size..c * lde_size + lde_size]); v }) .collect(); @@ -602,15 +593,14 @@ pub fn coset_lde_batch_base_into( stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; stream.synchronize()?; - // Parallel copy pinned → caller outputs. Caller's Vecs may still fault - // on first write; we spread that cost across rayon cores. - let pinned_ptr = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) - }; - dst.copy_from_slice(src); - }); + // Sequential copy pinned → caller outputs (not `par_iter_mut`): runs + // inside the held pinned-staging mutex; rayon-inside-mutex risks + // recursive stealing-during-wait deadlocks (see + // `Backend::pinned_staging`). Fault-cost parallelism is recovered at + // the outer level via per-worker staging slots. 
+ for (c, dst) in outputs.iter_mut().enumerate() { + dst.copy_from_slice(&pinned[c * lde_size..c * lde_size + lde_size]); + } drop(staging); Ok(()) } @@ -734,12 +724,12 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( staging.ensure_capacity(m * lde_size, &be.ctx)?; let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; - let pinned_base_ptr = pinned.as_mut_ptr() as usize; - columns.par_iter().enumerate().for_each(|(c, col)| { - let dst = - unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) }; - dst.copy_from_slice(col); - }); + // Sequential pack (not `par_iter`): runs inside the held pinned-staging + // mutex (see `Backend::pinned_staging` docs). Per-worker staging slots + // give the outer parallelism back. + for (c, col) in columns.iter().enumerate() { + pinned[c * n..c * n + n].copy_from_slice(col); + } let mut buf = stream.alloc_zeros::(m * lde_size)?; for c in 0..m { @@ -833,14 +823,11 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?; - // Pinned LDE → caller outputs (post-sync host memcpy). - let pinned_ptr = pinned.as_ptr() as usize; - outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { - let src = unsafe { - std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size) - }; - dst.copy_from_slice(src); - }); + // Sequential pinned → caller outputs (not `par_iter_mut`): runs inside + // the held pinned-staging mutex (see `Backend::pinned_staging` docs). 
+ for (c, dst) in outputs.iter_mut().enumerate() { + dst.copy_from_slice(&pinned[c * lde_size..c * lde_size + lde_size]); + } drop(staging); if keep_device_buf { diff --git a/crypto/stark/Cargo.toml b/crypto/stark/Cargo.toml index 8c109ff93..ccf792745 100644 --- a/crypto/stark/Cargo.toml +++ b/crypto/stark/Cargo.toml @@ -26,6 +26,7 @@ rayon = { version = "1.8.0", optional = true } memmap2 = { version = "0.9", optional = true } tempfile = { version = "3", optional = true } libc = { version = "0.2", optional = true } + # GPU backend for trace LDE — only linked when `cuda` is enabled. math-cuda = { path = "../math-cuda", optional = true } diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs new file mode 100644 index 000000000..ea4f04488 --- /dev/null +++ b/crypto/stark/src/gpu_lde.rs @@ -0,0 +1,596 @@ +//! GPU dispatch layer for the per-column coset LDE. +//! +//! Handles only Goldilocks base-field columns above a size threshold. Falls +//! back to CPU for extension-field columns and small columns where kernel +//! launch overhead dominates. Produces the same natural-order, non-canonical +//! LDE evaluations as the CPU path. + +use std::any::TypeId; +use std::slice::{from_raw_parts, from_raw_parts_mut}; +use std::sync::OnceLock; +use std::sync::atomic::{AtomicU64, Ordering}; + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsSubFieldOf}; + +use crate::domain::Domain; + +/// Break-even LDE size. Below this, the CPU `coset_lde_full_expand` completes +/// in a few hundred microseconds and the GPU's tens of kernel launches plus +/// H2D/D2H round-trip is a net loss. The check is on **lde size**, not trace +/// length, because that's what determines the FFT workload. +/// +/// 2^19 is a conservative default calibrated against a 46-core machine where +/// rayon-parallel CPU LDE is already fast. 
/// Default minimum LDE size (2^19 evaluations) for taking the GPU path; used
/// when `LAMBDA_VM_GPU_LDE_THRESHOLD` is unset or does not parse as `usize`.
const DEFAULT_GPU_LDE_THRESHOLD: usize = 1 << 19;

/// Minimum LDE size for GPU dispatch.
///
/// The environment variable `LAMBDA_VM_GPU_LDE_THRESHOLD` is read on the
/// first call only; the result is cached for the rest of the process, so
/// later changes to the variable have no effect.
fn gpu_lde_threshold() -> usize {
    static CACHED: OnceLock<usize> = OnceLock::new();
    let resolved = CACHED.get_or_init(|| {
        match std::env::var("LAMBDA_VM_GPU_LDE_THRESHOLD") {
            // Present: use it only if it parses; otherwise keep the default.
            Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_GPU_LDE_THRESHOLD),
            // Unset (or not valid Unicode): default.
            Err(_) => DEFAULT_GPU_LDE_THRESHOLD,
        }
    });
    *resolved
}
+ Run { n: usize, lde_size: usize }, +} + +/// Validate preconditions for the base-field batched GPU path: every column +/// must be Goldilocks base-field of equal length, the LDE size must clear the +/// threshold. +fn check_base_layout(columns: &[Vec>], blowup_factor: usize) -> LayoutDispatch +where + F: IsField + 'static, + E: IsField + 'static, +{ + if columns.is_empty() { + return LayoutDispatch::Empty; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return LayoutDispatch::Skip; + } + if TypeId::of::() != TypeId::of::() { + return LayoutDispatch::Skip; + } + if TypeId::of::() != TypeId::of::() { + return LayoutDispatch::Skip; + } + if columns.iter().any(|c| c.len() != n) { + return LayoutDispatch::Skip; + } + LayoutDispatch::Run { n, lde_size } +} + +/// Validate preconditions for the ext3 batched GPU path: every column must be +/// `Degree3GoldilocksExtensionField` of equal length, weights must be over +/// `GoldilocksField`, LDE size must clear the threshold. +fn check_ext3_layout(columns: &[Vec>], blowup_factor: usize) -> LayoutDispatch +where + F: IsField + 'static, + E: IsField + 'static, +{ + if columns.is_empty() { + return LayoutDispatch::Empty; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return LayoutDispatch::Skip; + } + if TypeId::of::() != TypeId::of::() { + return LayoutDispatch::Skip; + } + if TypeId::of::() != TypeId::of::() { + return LayoutDispatch::Skip; + } + if columns.iter().any(|c| c.len() != n) { + return LayoutDispatch::Skip; + } + LayoutDispatch::Run { n, lde_size } +} + +/// Convert base-field columns to `Vec>` for the GPU input slice list. +/// +/// SAFETY: caller must have established `E == GoldilocksField` (e.g. via +/// [`check_base_layout`]). Each `FieldElement` is then a `#[repr(transparent)]` +/// wrapper over `u64`. 
+unsafe fn columns_to_u64_base(columns: &[Vec>]) -> Vec> { + columns + .iter() + .map(|col| { + col.iter() + .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) + .collect() + }) + .collect() +} + +/// Convert ext3 columns to `Vec>` (de-interleaved into raw `[u64; 3]` +/// lanes per element) for the GPU input slice list. +/// +/// SAFETY: caller must have established `E == Degree3GoldilocksExtensionField` +/// (e.g. via [`check_ext3_layout`]). Each `FieldElement` is then a +/// `#[repr(transparent)]` wrapper over `[u64; 3]`. +unsafe fn columns_to_u64_ext3(columns: &[Vec>]) -> Vec> { + columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { from_raw_parts(ptr, len) }.to_vec() + }) + .collect() +} + +/// Convert weights to raw `Vec`. +/// +/// SAFETY: caller must have established `F == GoldilocksField`. +unsafe fn weights_to_u64(weights: &[FieldElement]) -> Vec { + weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect() +} + +/// Pre-size each column to `lde_size` and view it as a `&mut [u64]` of length +/// `lde_size` (base-field, single-u64 layout). +/// +/// SAFETY: caller must have established `E == GoldilocksField`. +unsafe fn presize_and_view_base( + columns: &mut [Vec>], + lde_size: usize, +) -> Vec<&mut [u64]> { + for col in columns.iter_mut() { + assert!( + col.capacity() >= lde_size, + "col capacity {} < lde_size {}", + col.capacity(), + lde_size + ); + // SAFETY: assert above guarantees capacity, the GPU path overwrites + // every slot before any reader sees the new length. + unsafe { col.set_len(lde_size) }; + } + columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len(); + // SAFETY: single-u64 layout, caller still owns the backing alloc. 
+ unsafe { from_raw_parts_mut(ptr, len) } + }) + .collect() +} + +/// Same as [`presize_and_view_base`] but for ext3 columns: each view is +/// `3 * lde_size` u64s (de-interleaved lanes). +/// +/// SAFETY: caller must have established `E == Degree3GoldilocksExtensionField`. +unsafe fn presize_and_view_ext3( + columns: &mut [Vec>], + lde_size: usize, +) -> Vec<&mut [u64]> { + for col in columns.iter_mut() { + assert!( + col.capacity() >= lde_size, + "col capacity {} < lde_size {}", + col.capacity(), + lde_size + ); + // SAFETY: assert above + GPU path overwrites every slot. + unsafe { col.set_len(lde_size) }; + } + columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + // SAFETY: ext3 `[u64; 3]` layout, caller still owns the backing. + unsafe { from_raw_parts_mut(ptr, len) } + }) + .collect() +} + +/// Truncate each column back to `n` (trace size) after a GPU error so the +/// CPU fallback (which reads `buffer.len()` as the trace size) runs cleanly. +/// Safe because `math_cuda` writes outputs only at the final host copy, post- +/// synchronize; any `Err` returns before that copy, leaving `columns[0..n]` untouched. +fn restore_columns_on_err(columns: &mut [Vec>], n: usize) { + for col in columns.iter_mut() { + col.truncate(n); + } +} + +/// Allocate the `[u8; 32]` Merkle node buffer for a tree of `lde_size` leaves +/// and return both the node `Vec` (length-initialised, contents undefined) and +/// a `&mut [u8]` byte view of total length `total_nodes * 32`. Returns `None` +/// if the layout would be invalid (`lde_size < 2` or the byte length +/// overflows). The caller must overwrite every byte via the GPU D2H below. 
/// Allocate the `[u8; 32]` node buffer for a Merkle tree with `lde_size`
/// leaves.
///
/// Returns `(nodes, total_nodes)` where `total_nodes == 2 * lde_size - 1` and
/// `nodes.len() == total_nodes`. The buffer is zero-initialised; the caller
/// is expected to overwrite it (the GPU D2H copy in the callers) and derives
/// its byte view as `total_nodes * 32`.
///
/// Returns `None` when `lde_size < 2`, or when the node count or byte length
/// overflows `usize` or exceeds `isize::MAX` bytes. That last bound is the
/// largest allocation `Vec` supports, and checking it here avoids a
/// capacity-overflow panic.
fn alloc_merkle_nodes(lde_size: usize) -> Option<(Vec<[u8; 32]>, usize)> {
    const NODE_BYTES: usize = 32;

    if lde_size < 2 {
        return None;
    }
    // Checked arithmetic throughout: overflow yields `None`, not a wrapped
    // or saturated node count.
    let total_nodes = lde_size.checked_mul(2)?.checked_sub(1)?;
    let byte_len = total_nodes.checked_mul(NODE_BYTES)?;
    if byte_len > isize::MAX as usize {
        return None;
    }

    // Zero-initialised instead of `set_len` over uninitialised memory, so no
    // uninitialised bytes are ever exposed through the `Vec`. A zero-valued
    // `vec!` is expected to use a zeroed allocation, so untouched pages
    // should still only fault on first write — TODO confirm on the target
    // toolchain.
    let nodes = vec![[0u8; NODE_BYTES]; total_nodes];
    Some((nodes, total_nodes))
}
+ if TypeId::of::() == TypeId::of::() { + return try_expand_columns_batched_ext3::(columns, blowup_factor, weights); + } + + let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { + LayoutDispatch::Empty => return Some(()), // nothing to do — same as CPU path + LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + + // SAFETY: the `Run` arm of `check_base_layout::` (matched above) + // guarantees `E == GoldilocksField` and `F == GoldilocksField`. + let raw_columns = unsafe { columns_to_u64_base::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + GPU_LDE_CALLS.fetch_add(num_columns as u64, Ordering::Relaxed); + let gpu_result = { + let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; + math_cuda::lde::coset_lde_batch_base_into( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + ) + }; + if gpu_result.is_err() { + // Restore columns to trace length for the CPU fallback. `math_cuda` + // only writes outputs at the very end (post-synchronize host copy); + // on any Err the caller's `columns[0..n]` is untouched trace data. + restore_columns_on_err(columns, n); + return None; + } + Some(()) +} + +/// GPU path for `Prover::extend_half_to_lde`. +/// +/// Inside `decompose_and_extend_d2` (R2 quotient decomposition) the prover +/// does `rayon::join` of two calls: `iFFT(N on g²-coset) → FFT(2N on g-coset)` +/// over ext3 halves H0 and H1. They share the same domain/offset and sizes, +/// so we batch them into a single GPU call with M=2 ext3 columns. +/// +/// Weights = `[1/N, g^(-1)/N, g^(-2)/N, …, g^(-(N-1))/N]`. This bakes the +/// `(g²)^(-k)` input-coset-undo from `interpolate_offset_fft` together with +/// the `g^k` forward-coset-shift from `evaluate_polynomial_on_lde_domain` — +/// net is `g^(-k)` — plus the `1/N` iFFT normalisation. 
+/// +/// Returns `None` when the GPU path doesn't apply (too small, or CPU path +/// should be used); in that case the caller runs its existing rayon::join. +pub(crate) fn try_extend_two_halves_gpu( + h0: &[FieldElement], + h1: &[FieldElement], + domain: &Domain, +) -> Option<(Vec>, Vec>)> +where + F: math::field::traits::IsFFTField + IsField + 'static, + E: IsField + 'static, + F: IsSubFieldOf, +{ + if h0.len() != h1.len() { + return None; + } + let n = h0.len(); + let blowup = 2; // extend_half_to_lde extends N → 2N always + let lde_size = n * blowup; + if lde_size < gpu_lde_threshold() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + GPU_EXTEND_HALVES_CALLS.fetch_add(1, Ordering::Relaxed); + // Weights are built from `g = domain.coset_offset` directly: the + // CPU caller previously passed `g²` redundantly. See the + // `g^(-k) / N` weight loop below. + + // Flatten ext3 slices to raw 3*n u64 buffers. + let to_u64 = |col: &[FieldElement]| -> Vec { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { from_raw_parts(ptr, len) }.to_vec() + }; + let h0_raw = to_u64(h0); + let h1_raw = to_u64(h1); + + // weights[k] = g^(-k) / N as a u64. + let inv_n = FieldElement::::from(n as u64).inv().expect("N nonzero"); + let g = &domain.coset_offset; + let g_inv = g.inv().expect("g nonzero"); + let mut weights_u64 = Vec::with_capacity(n); + let mut w = inv_n.clone(); + for _ in 0..n { + // F == GoldilocksField by TypeId check above, so value is u64. + let v: u64 = unsafe { *(w.value() as *const _ as *const u64) }; + weights_u64.push(v); + w = w * &g_inv; + } + + // Pre-allocate outputs. + let mut lde_h0 = vec![FieldElement::::zero(); lde_size]; + let mut lde_h1 = vec![FieldElement::::zero(); lde_size]; + + // Two ext3 columns (h0 + h1), each composed of 3 base-field components. 
+ const NUM_COLS: usize = 2; + GPU_LDE_CALLS.fetch_add((NUM_COLS * 3) as u64, Ordering::Relaxed); + { + let inputs: [&[u64]; 2] = [&h0_raw, &h1_raw]; + // View each output Vec> as &mut [u64] of length 3*lde_size. + let out0_ptr = lde_h0.as_mut_ptr() as *mut u64; + let out1_ptr = lde_h1.as_mut_ptr() as *mut u64; + // SAFETY: ext3 FieldElement is [u64; 3] in memory, and the Vec has len + // = lde_size so the backing is 3*lde_size u64s. + let ext3_len = lde_size + .checked_mul(3) + .expect("ext3 output length overflow"); + let out0_slice = unsafe { from_raw_parts_mut(out0_ptr, ext3_len) }; + let out1_slice = unsafe { from_raw_parts_mut(out1_ptr, ext3_len) }; + let mut outputs: [&mut [u64]; 2] = [out0_slice, out1_slice]; + if math_cuda::lde::coset_lde_batch_ext3_into(&inputs, n, blowup, &weights_u64, &mut outputs) + .is_err() + { + return None; + } + } + + Some((lde_h0, lde_h1)) +} + +pub(crate) static GPU_LEAF_HASH_CALLS: AtomicU64 = AtomicU64::new(0); +pub fn gpu_leaf_hash_calls() -> u64 { + GPU_LEAF_HASH_CALLS.load(Ordering::Relaxed) +} + +/// Fused base-field path: LDE + Keccak-256 leaf hash + Merkle tree build, +/// all on device, with the LDE buffer retained for R2–R4 GPU reuse. On +/// success: `columns[c]` is resized to `lde_size` with the LDE output, and +/// the returned `(tree, GpuLdeBase)` pair is the host-side tree plus a +/// device-resident handle to the LDE buffer. 
+pub(crate) fn try_expand_leaf_and_tree_batched_keep( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option<( + crypto::merkle_tree::merkle::MerkleTree, + math_cuda::lde::GpuLdeBase, +)> +where + F: IsField + 'static, + E: IsField + 'static, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { + LayoutDispatch::Empty | LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; + let node_byte_len = total_nodes + .checked_mul(32) + .expect("node byte length overflow"); + + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_base::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + GPU_LDE_CALLS.fetch_add(num_columns as u64, Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, Ordering::Relaxed); + + let handle_result = { + let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; + let nodes_bytes: &mut [u8] = + unsafe { from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) }; + math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree_keep( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + }; + let handle = match handle_result { + Ok(h) => h, + Err(_) => { + restore_columns_on_err(columns, n); + return None; + } + }; + + let tree = crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes)?; + Some((tree, handle)) +} + +/// Fused ext3 path: LDE + Keccak-256 leaf hash + Merkle tree build over +/// ext3 columns via the three-slab decomposition, with the ext3 LDE device +/// buffer (de-interleaved 3-slab layout) retained for downstream GPU rounds. 
+/// `B::Node = [u8; 32]` by construction for `BatchKeccak256Backend`. +pub(crate) fn try_expand_leaf_and_tree_batched_ext3_keep( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option<( + crypto::merkle_tree::merkle::MerkleTree, + math_cuda::lde::GpuLdeExt3, +)> +where + F: IsField + 'static, + E: IsField + 'static, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { + LayoutDispatch::Empty | LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; + let node_byte_len = total_nodes + .checked_mul(32) + .expect("node byte length overflow"); + + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_ext3::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + GPU_LDE_CALLS.fetch_add((num_columns * 3) as u64, Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, Ordering::Relaxed); + + let handle_result = { + let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; + let nodes_bytes: &mut [u8] = + unsafe { from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) }; + math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree_keep( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + }; + let handle = match handle_result { + Ok(h) => h, + Err(_) => { + restore_columns_on_err(columns, n); + return None; + } + }; + + let tree = crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes)?; + Some((tree, handle)) +} + +/// Ext3 specialisation of [`try_expand_columns_batched`]. `E` is known to be +/// `Degree3GoldilocksExtensionField` by TypeId match at the caller. 
+fn try_expand_columns_batched_ext3( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option<()> +where + F: IsField + 'static, + E: IsField + 'static, +{ + let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { + LayoutDispatch::Empty => return Some(()), + LayoutDispatch::Skip => return None, + LayoutDispatch::Run { n, lde_size } => (n, lde_size), + }; + let num_columns = columns.len(); + + // SAFETY: layout-checked above. + let raw_columns = unsafe { columns_to_u64_ext3::(columns) }; + let weights_u64 = unsafe { weights_to_u64::(weights) }; + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + // Account each ext3 column as 3 logical GPU LDE "calls" (base-field + // components) so the counter matches the base-field batched path. + GPU_LDE_CALLS.fetch_add((num_columns * 3) as u64, Ordering::Relaxed); + let gpu_result = { + let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; + math_cuda::lde::coset_lde_batch_ext3_into( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + ) + }; + if gpu_result.is_err() { + restore_columns_on_err(columns, n); + return None; + } + Some(()) +} + +static GPU_MERKLE_TREE_CALLS: AtomicU64 = AtomicU64::new(0); +pub fn gpu_merkle_tree_calls() -> u64 { + GPU_MERKLE_TREE_CALLS.load(Ordering::Relaxed) +} diff --git a/crypto/stark/src/lib.rs b/crypto/stark/src/lib.rs index 7379594b4..3ae8415c1 100644 --- a/crypto/stark/src/lib.rs +++ b/crypto/stark/src/lib.rs @@ -13,6 +13,8 @@ pub mod domain; pub mod examples; pub mod frame; pub mod fri; +#[cfg(feature = "cuda")] +pub mod gpu_lde; pub mod grinding; #[cfg(feature = "instruments")] pub mod instruments; diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 460052a51..85c2168ba 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -60,8 +60,8 @@ pub struct Prover< } impl< - Field: IsSubFieldOf + IsFFTField + Send + Sync, - 
FieldExtension: Send + Sync + IsField, + Field: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + FieldExtension: Send + Sync + IsField + 'static, PI, > IsStarkProver for Prover where @@ -169,6 +169,18 @@ where pub(crate) bus_public_inputs: Option>, } +/// Tuple returned by `commit_main_trace`: the commit, the cached LDE columns, +/// and — under cuda — the optional device LDE buffer kept alive for downstream +/// rounds when the R1 fused GPU pipeline ran. +#[cfg(feature = "cuda")] +type MainCommitTuple = ( + TableCommit, + Vec>>, + Option, +); +#[cfg(not(feature = "cuda"))] +type MainCommitTuple = (TableCommit, Vec>>); + /// Round 1 commitment artifacts — Merkle trees, roots, challenges, and bus inputs. /// Borrowed (not consumed) when building `Round1` in Phase D. pub(crate) struct Round1Commitments @@ -191,6 +203,13 @@ where struct Lde { main: Vec>>, aux: Vec>>, + /// Device-side main LDE buffer, populated only when the R1 GPU fused + /// pipeline ran for this table. Kept so R2/R3/R4 GPU paths can read + /// the LDE without re-H2D. + #[cfg(feature = "cuda")] + gpu_main: Option, + #[cfg(feature = "cuda")] + gpu_aux: Option, } impl Round1Commitments @@ -208,8 +227,20 @@ where step_size: usize, blowup_factor: usize, ) -> Round1 { + #[allow(unused_mut)] + let mut lde_trace = + LDETraceTable::from_columns(lde.main, lde.aux, step_size, blowup_factor); + #[cfg(feature = "cuda")] + { + if let Some(h) = lde.gpu_main { + lde_trace.set_gpu_main(h); + } + if let Some(h) = lde.gpu_aux { + lde_trace.set_gpu_aux(h); + } + } Round1 { - lde_trace: LDETraceTable::from_columns(lde.main, lde.aux, step_size, blowup_factor), + lde_trace, main: self.main.share(), aux: self.aux.as_ref().map(TableCommit::share), rap_challenges: self.rap_challenges.clone(), @@ -477,8 +508,8 @@ where /// `private_interfaces` allow is removed once these helpers move off the trait. 
#[allow(private_interfaces)] pub trait IsStarkProver< - Field: IsSubFieldOf + IsFFTField + Send + Sync, - FieldExtension: Send + Sync + IsField, + Field: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + FieldExtension: Send + Sync + IsField + 'static, PI, > where FieldElement: math::traits::ByteConversion, @@ -580,13 +611,28 @@ pub trait IsStarkProver< twiddles: &LdeTwiddles, ) where Field: IsSubFieldOf, - E: IsSubFieldOf + IsField + Send + Sync, + E: IsSubFieldOf + IsField + Send + Sync + 'static, FieldElement: Send + Sync, { if columns.is_empty() { return; } + // GPU batched fast path: all columns at once in one pipeline on one + // stream. Falls through to per-column rayon when the table is too + // small, the element type isn't Goldilocks, or the `cuda` feature is + // off. + #[cfg(feature = "cuda")] + if crate::gpu_lde::try_expand_columns_batched::( + columns, + domain.blowup_factor, + &twiddles.coset_weights, + ) + .is_some() + { + return; + } + #[cfg(feature = "parallel")] let iter = columns.par_iter_mut(); #[cfg(not(feature = "parallel"))] @@ -604,7 +650,9 @@ pub trait IsStarkProver< } /// Compute the main-trace LDE and commit. Returns a `TableCommit` along - /// with the owned LDE columns (consumed later in Phase D). + /// with the owned LDE columns (consumed later in Phase D) and — under + /// cuda — the optional device LDE buffer kept alive for downstream rounds + /// when the R1 fused GPU pipeline ran. 
/// /// `precomputed`: if present, the leading `num_cols` columns are committed /// as a separate Merkle tree (the precomputed split for preprocessed @@ -616,13 +664,40 @@ pub trait IsStarkProver< twiddles: &LdeTwiddles, precomputed: Option<(Commitment, usize)>, #[cfg(feature = "disk-spill")] storage_mode: StorageMode, - ) -> Result<(TableCommit, Vec>>), ProvingError> + ) -> Result, ProvingError> where FieldElement: AsBytes, FieldElement: AsBytes, { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_main(lde_size); + + // Fused GPU path is only wired for non-preprocessed mains today; the + // preprocessed split runs the CPU pipeline below. + #[cfg(feature = "cuda")] + if precomputed.is_none() { + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + if let Some((tree, handle)) = + crate::gpu_lde::try_expand_leaf_and_tree_batched_keep::< + Field, + Field, + BatchedMerkleTreeBackend, + >(&mut columns, domain.blowup_factor, &twiddles.coset_weights) + { + #[cfg(feature = "instruments")] + let main_lde_dur = t_sub.elapsed(); + let root = tree.root; + // Fused GPU path produces LDE + leaves + tree as one pipeline, + // so the wall-clock total lands in `main_lde_dur`. Bill the + // merkle bucket equal to LDE so the sum (lde + merkle) stays + // comparable to the non-GPU path's combined LDE+commit total. 
+ #[cfg(feature = "instruments")] + crate::instruments::accum_r1_main(main_lde_dur, main_lde_dur); + return Ok((TableCommit::plain(tree, root), columns, Some(handle))); + } + } + #[cfg(feature = "disk-spill")] if storage_mode == StorageMode::Disk { trace.main_table.advise_drop_cache(); @@ -683,6 +758,9 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] crate::instruments::accum_r1_main(main_lde_dur, t_sub.elapsed()); + #[cfg(feature = "cuda")] + return Ok((commit, columns, None)); + #[cfg(not(feature = "cuda"))] Ok((commit, columns)) } @@ -714,7 +792,18 @@ pub trait IsStarkProver< Vec::new() }; - Ok(commitment.build_round1(Lde { main, aux }, air.step_size(), domain.blowup_factor)) + Ok(commitment.build_round1( + Lde { + main, + aux, + #[cfg(feature = "cuda")] + gpu_main: None, + #[cfg(feature = "cuda")] + gpu_aux: None, + }, + air.step_size(), + domain.blowup_factor, + )) } /// Reconstruct Round1 for every table, print the bus balance report, and @@ -834,6 +923,16 @@ pub trait IsStarkProver< // Step 3: Extend each part from N evals on g²-coset to 2N evals on g-coset. // The squared coset offset is g² (= coset_offset²). let coset_offset_squared = &domain.coset_offset * &domain.coset_offset; + + // GPU fast path: batch both halves into one ext3 LDE call. Requires + // `cuda` feature and a qualifying size; falls through to CPU when not. 
+ #[cfg(feature = "cuda")] + if let Some((lde_h0, lde_h1)) = + crate::gpu_lde::try_extend_two_halves_gpu(&h0_evals, &h1_evals, domain) + { + return vec![lde_h0, lde_h1]; + } + let (lde_h0, lde_h1) = crate::par::join( || Self::extend_half_to_lde(&h0_evals, &coset_offset_squared, domain), || Self::extend_half_to_lde(&h1_evals, &coset_offset_squared, domain), @@ -1552,6 +1651,12 @@ pub trait IsStarkProver< let mut main_commits: Vec> = Vec::with_capacity(num_airs); let mut main_ldes: Vec>>> = Vec::with_capacity(num_airs); + // Optional device-side LDE handle per table, populated only when the + // R1 fused GPU pipeline produced one. Threaded through Phase D's zip + // chain so alignment is compiler-enforced (no `.next().expect()`). + #[cfg(feature = "cuda")] + let mut main_gpu_handles: Vec> = + Vec::with_capacity(num_airs); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1584,6 +1689,9 @@ pub trait IsStarkProver< // Sequential: append roots to shared transcript (Fiat-Shamir ordering) for result in chunk_results { + #[cfg(feature = "cuda")] + let (commit, cached_main, gpu_main) = result?; + #[cfg(not(feature = "cuda"))] let (commit, cached_main) = result?; if let Some(ref pre_root) = commit.precomputed_root { transcript.append_bytes(pre_root); @@ -1591,6 +1699,8 @@ pub trait IsStarkProver< transcript.append_bytes(&commit.root); main_commits.push(commit); main_ldes.push(cached_main); + #[cfg(feature = "cuda")] + main_gpu_handles.push(gpu_main); } } @@ -1684,14 +1794,20 @@ pub trait IsStarkProver< }) .collect(); - // Parallel aux commit in chunks of K. Each entry holds the optional aux - // `TableCommit` (`None` when the AIR has no aux trace) and the cached - // aux LDE columns consumed in Phase D. + // Parallel aux commit in chunks of K. 
The closure returns a cfg-gated + // AuxResult — under cuda it carries the optional ext3 GPU LDE handle + // as a third element so the .zip() chain in Phase D stays + // compiler-aligned with no side vectors. + #[cfg(feature = "cuda")] + type AuxResult = ( + Option>, + Vec>>, + Option, + ); + #[cfg(not(feature = "cuda"))] + type AuxResult = (Option>, Vec>>); #[allow(clippy::type_complexity)] - let mut aux_results: Vec<( - Option>, - Vec>>, - )> = Vec::with_capacity(num_airs); + let mut aux_results: Vec> = Vec::with_capacity(num_airs); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1702,7 +1818,8 @@ pub trait IsStarkProver< #[cfg(not(feature = "parallel"))] let iter = chunk_range; - let chunk_aux: Vec> = iter + #[allow(clippy::type_complexity)] + let chunk_aux: Vec, ProvingError>> = iter .map(|idx| { let (air, trace, _) = &air_trace_pairs[idx]; let domain = &domains[idx]; @@ -1711,6 +1828,40 @@ pub trait IsStarkProver< if air.has_aux_trace() { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_aux(lde_size); + + // GPU combined path: ext3 LDE + Keccak-256 leaf + // hashing + Merkle tree build in one on-device + // pipeline. The fused `_keep` variant also returns + // the device LDE handle for downstream GPU rounds. + #[cfg(feature = "cuda")] + { + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + if let Some((tree, handle)) = + crate::gpu_lde::try_expand_leaf_and_tree_batched_ext3_keep::< + Field, + FieldExtension, + BatchedMerkleTreeBackend, + >( + &mut columns, domain.blowup_factor, &twiddles.coset_weights + ) + { + #[cfg(feature = "instruments")] + let aux_lde_dur = t_sub.elapsed(); + let root = tree.root; + // Fused GPU path: bill merkle equal to LDE so + // the (lde + merkle) sum stays comparable to + // the non-GPU path's combined R1 total. 
+ #[cfg(feature = "instruments")] + crate::instruments::accum_r1_aux(aux_lde_dur, aux_lde_dur); + return Ok(( + Some(TableCommit::plain(tree, root)), + columns, + Some(handle), + )); + } + } + #[cfg(feature = "disk-spill")] if storage_mode == StorageMode::Disk { trace.aux_table.advise_drop_cache(); @@ -1738,20 +1889,28 @@ pub trait IsStarkProver< ProvingError::DiskSpill(format!("aux Merkle tree: {e}")) })?; } + #[cfg(feature = "cuda")] + return Ok((Some(TableCommit::plain(tree, root)), columns, None)); + #[cfg(not(feature = "cuda"))] Ok((Some(TableCommit::plain(tree, root)), columns)) } else { + #[cfg(feature = "cuda")] + return Ok((None, Vec::new(), None)); + #[cfg(not(feature = "cuda"))] Ok((None, Vec::new())) } }) .collect(); - // Sequential: append aux roots to forked transcripts + // Sequential: append aux roots to forked transcripts. for (j, result) in chunk_aux.into_iter().enumerate() { - let (aux_commit, cached_aux) = result?; - if let Some(ref c) = aux_commit { + let aux_full = result?; + // Tuple shape is cfg-gated; `.0` is the optional TableCommit + // in both variants. + if let Some(ref c) = aux_full.0 { table_transcripts[chunk_start + j].append_bytes(&c.root); } - aux_results.push((aux_commit, cached_aux)); + aux_results.push(aux_full); } } @@ -1760,18 +1919,41 @@ pub trait IsStarkProver< let mut commitments: Vec> = Vec::with_capacity(num_airs); let mut cached_ldes: Vec> = Vec::with_capacity(num_airs); - for (((main_commit, main_lde), (aux_commit, cached_aux)), bus_public_inputs) in main_commits + // Under cuda, fold main_gpu_handles into the zip chain so alignment is + // compiler-enforced (M4: no `.next().expect()` plumbing). 
+ #[cfg(feature = "cuda")] + let main_iter = main_commits .into_iter() .zip(main_ldes) - .zip(aux_results) - .zip(bus_inputs_vec) + .zip(main_gpu_handles); + #[cfg(not(feature = "cuda"))] + let main_iter = main_commits.into_iter().zip(main_ldes); + + for ((main_pack, aux_full), bus_public_inputs) in + main_iter.zip(aux_results).zip(bus_inputs_vec) { + #[cfg(feature = "cuda")] + let ((main_commit, main_lde), gpu_main) = main_pack; + #[cfg(not(feature = "cuda"))] + let (main_commit, main_lde) = main_pack; + #[cfg(feature = "cuda")] + let (aux_commit, cached_aux, gpu_aux) = aux_full; + #[cfg(not(feature = "cuda"))] + let (aux_commit, cached_aux) = aux_full; commitments.push(Round1Commitments { main: main_commit, aux: aux_commit, rap_challenges: lookup_challenges.clone(), bus_public_inputs, }); + #[cfg(feature = "cuda")] + cached_ldes.push(Lde { + main: main_lde, + aux: cached_aux, + gpu_main, + gpu_aux, + }); + #[cfg(not(feature = "cuda"))] cached_ldes.push(Lde { main: main_lde, aux: cached_aux, diff --git a/crypto/stark/src/trace.rs b/crypto/stark/src/trace.rs index 834ffdcda..b3bef444a 100644 --- a/crypto/stark/src/trace.rs +++ b/crypto/stark/src/trace.rs @@ -220,6 +220,16 @@ where pub(crate) aux_columns: Vec>>, pub(crate) lde_step_size: usize, pub(crate) blowup_factor: usize, + /// If the main trace was LDE'd on the GPU via the fused pipeline, + /// the device buffer is retained here so downstream GPU rounds can + /// read the LDE without a re-H2D. `None` when the GPU LDE didn't + /// run (small tables, cuda feature off, fallback path). + #[cfg(feature = "cuda")] + pub(crate) gpu_main: Option, + /// Same as `gpu_main` but for the aux trace (ext3 de-interleaved + /// layout on device). 
+ #[cfg(feature = "cuda")] + pub(crate) gpu_aux: Option, } impl LDETraceTable @@ -242,9 +252,37 @@ where aux_columns, lde_step_size, blowup_factor, + #[cfg(feature = "cuda")] + gpu_main: None, + #[cfg(feature = "cuda")] + gpu_aux: None, } } + /// Attach an already-populated device LDE handle for the main columns. + /// Only set when the GPU fused pipeline produced the LDE — callers that + /// ran the CPU path should leave this alone. + #[cfg(feature = "cuda")] + pub fn set_gpu_main(&mut self, h: math_cuda::lde::GpuLdeBase) { + self.gpu_main = Some(h); + } + + /// Attach an already-populated device LDE handle for the aux columns. + #[cfg(feature = "cuda")] + pub fn set_gpu_aux(&mut self, h: math_cuda::lde::GpuLdeExt3) { + self.gpu_aux = Some(h); + } + + #[cfg(feature = "cuda")] + pub fn gpu_main(&self) -> Option<&math_cuda::lde::GpuLdeBase> { + self.gpu_main.as_ref() + } + + #[cfg(feature = "cuda")] + pub fn gpu_aux(&self) -> Option<&math_cuda::lde::GpuLdeExt3> { + self.gpu_aux.as_ref() + } + /// Consume self and return the owned column vectors. 
#[allow(clippy::type_complexity)] pub fn into_columns(self) -> (Vec>>, Vec>>) { diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 9e03da9b3..76410dbf8 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [features] default = ["parallel"] parallel = ["stark/parallel", "math/parallel", "crypto/parallel", "dep:rayon"] +cuda = ["stark/cuda"] debug-checks = ["stark/debug-checks"] instruments = ["stark/instruments"] disk-spill = ["stark/disk-spill"] diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs index 6219e766f..ff7eda5df 100644 --- a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -21,10 +21,13 @@ use stark::proof::options::ProofOptions; use stark::traits::AIR; use stark::verifier::{IsStarkVerifier, Verifier}; +use crate::VmProof; +use crate::tables::MaxRowsConfig; use crate::tables::trace_builder::Traces; use crate::tables::types::{GoldilocksExtension, GoldilocksField}; use executor::elf::Elf; +use executor::vm::execution::Executor; // Import shared utilities use crate::VmAirs; @@ -83,6 +86,79 @@ fn prove_and_verify_vm_minimal(elf: &Elf, traces: &mut Traces) -> bool { ) } +/// Like [`crate::prove_with_options_and_inputs`] but with trimmed bitwise (TEST ONLY). +/// +/// ~100x faster than the production path. Same unsoundness caveats as +/// [`Traces::from_elf_and_logs_minimal`]. The full preprocessed bitwise +/// path is covered by `test_prove_elfs_all_instructions_64_full`. 
+fn prove_vm_minimal(elf_bytes: &[u8], private_inputs: &[u8], max_rows: &MaxRowsConfig) -> VmProof { + let proof_options = ProofOptions::default_test_options(); + let elf = Elf::load(elf_bytes).expect("ELF load"); + let executor = Executor::new(&elf, private_inputs.to_vec()).expect("executor"); + let result = executor.run().expect("execution"); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &result.logs, max_rows, private_inputs).unwrap(); + let table_counts = traces.table_counts(); + let airs = VmAirs::new( + &elf, + &proof_options, + true, + &traces.page_configs, + &table_counts, + ); + let runtime_page_ranges = traces.runtime_page_ranges(); + let proof = multi_prove_ram( + airs.air_trace_pairs(&mut traces), + &mut DefaultTranscript::::new(&[]), + ) + .expect("prove"); + let num_private_input_pages = traces + .page_configs + .iter() + .filter(|c| c.is_private_input) + .count(); + VmProof { + proof, + runtime_page_ranges, + table_counts, + public_output: traces.public_output_bytes.clone(), + num_private_input_pages, + } +} + +/// Like [`crate::verify_with_options`] but matches the minimal bitwise AIR. +/// +/// Must be used to verify proofs from [`prove_vm_minimal`]. 
+fn verify_vm_minimal(vm_proof: &VmProof, elf_bytes: &[u8]) -> bool { + let proof_options = ProofOptions::default_test_options(); + let elf = Elf::load(elf_bytes).expect("ELF load"); + let page_configs = Traces::page_configs_from_elf_and_runtime( + &elf, + &vm_proof.runtime_page_ranges, + vm_proof.num_private_input_pages, + ); + let airs = VmAirs::new( + &elf, + &proof_options, + true, + &page_configs, + &vm_proof.table_counts, + ); + let air_refs = airs.air_refs(); + let expected_bus_balance = crate::compute_expected_commit_bus_balance( + &air_refs, + &vm_proof.proof, + &vm_proof.public_output, + ) + .expect("fingerprint collision in test"); + Verifier::multi_verify( + &air_refs, + &vm_proof.proof, + &mut DefaultTranscript::::new(&[]), + &expected_bus_balance, + ) +} + // ============================================================================= // Integration tests // ============================================================================= @@ -157,7 +233,7 @@ fn test_cpu_only_no_bus() { fn test_prove_elfs_sub_fast() { let _ = env_logger::builder().is_test(true).try_init(); let (elf, logs, _instructions) = run_asm_elf("sub"); - // Use from_elf_and_logs to get PAGE and REGISTER tables for Memory bus + // Use from_elf_and_logs_minimal to get PAGE and REGISTER tables for Memory bus let mut traces = Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); @@ -1943,11 +2019,9 @@ fn test_crafted_zero_count_proof_must_not_verify() { #[test] fn test_small_max_rows_splits_tables() { let elf_bytes = crate::test_utils::asm_elf_bytes("all_instructions_64"); - let proof_options = ProofOptions::default_test_options(); let max_rows = crate::tables::MaxRowsConfig::small(); - let vm_proof = crate::prove_with_options(&elf_bytes, &proof_options, &max_rows) - .expect("Prover should succeed with small max_rows"); + let vm_proof = prove_vm_minimal(&elf_bytes, &[], &max_rows); // With 2^5 max rows and 64+ instructions, tables should have multiple 
chunks. assert!( @@ -1956,9 +2030,10 @@ fn test_small_max_rows_splits_tables() { vm_proof.table_counts.cpu ); - let verified = crate::verify_with_options(&vm_proof, &elf_bytes, &proof_options) - .expect("Verifier should not error"); - assert!(verified, "Proof with small max_rows should verify"); + assert!( + verify_vm_minimal(&vm_proof, &elf_bytes), + "Proof with small max_rows should verify" + ); } // ============================================================================= @@ -2023,8 +2098,11 @@ fn test_verify_rejects_inflated_table_counts() { #[test] fn test_prove_wsuffix_64bit() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_wsuffix_64bit"); - let result = crate::prove_and_verify(&elf_bytes).expect("prove_and_verify failed"); - assert!(result, "W-suffix 64-bit register test should verify"); + let vm_proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); + assert!( + verify_vm_minimal(&vm_proof, &elf_bytes), + "W-suffix 64-bit register test should verify" + ); } /// Proves a minimal Rust std program that uses `init_allocator()` and @@ -2041,9 +2119,9 @@ fn test_prove_allocator_minimal_reproducer() { let elf_bytes = std::fs::read(workspace_root.join("executor/program_artifacts/rust/allocator.elf")) .expect("allocator.elf not found — run `make compile-programs-rust`"); - let proof = crate::prove(&elf_bytes).expect("prove should succeed"); + let proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); assert!( - crate::verify(&proof, &elf_bytes).expect("verify should not error"), + verify_vm_minimal(&proof, &elf_bytes), "allocator.elf should verify" ); assert_eq!(proof.public_output, b"Hello World"); @@ -2060,9 +2138,9 @@ fn test_pure_commit_rust() { let elf_bytes = std::fs::read(workspace_root.join("executor/program_artifacts/rust/pure_commit.elf")) .expect("pure_commit.elf not found — run `make compile-programs-rust`"); - let proof = crate::prove(&elf_bytes).expect("prove should succeed"); + let proof = prove_vm_minimal(&elf_bytes, 
&[], &Default::default()); assert!( - crate::verify(&proof, &elf_bytes).expect("verify should not error"), + verify_vm_minimal(&proof, &elf_bytes), "pure_commit.elf should verify" ); assert_eq!(proof.public_output, vec![0xAA, 0xBB, 0xCC, 0xDD]); @@ -2085,12 +2163,8 @@ fn test_prove_with_input_empty() { fn test_prove_private_input_xpage() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_private_input_xpage"); let input: Vec = (0u8..16).collect(); - let proof = - crate::prove_with_inputs(&elf_bytes, &input).expect("prove_with_inputs should succeed"); - assert!( - crate::verify(&proof, &elf_bytes).expect("verify should not error"), - "proof should verify" - ); + let proof = prove_vm_minimal(&elf_bytes, &input, &Default::default()); + assert!(verify_vm_minimal(&proof, &elf_bytes), "proof should verify"); assert_eq!(proof.public_output, input[4..12].to_vec()); } @@ -2102,11 +2176,8 @@ fn test_prove_private_input_different_values() { 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, ]; - let proof = crate::prove_with_inputs(&elf_bytes, &input).expect("prove"); - assert!( - crate::verify(&proof, &elf_bytes).expect("verify"), - "proof should verify" - ); + let proof = prove_vm_minimal(&elf_bytes, &input, &Default::default()); + assert!(verify_vm_minimal(&proof, &elf_bytes), "proof should verify"); assert_eq!(proof.public_output, input[4..12].to_vec()); } @@ -2146,9 +2217,9 @@ fn test_prove_commit_sum() { std::fs::read(workspace_root.join("executor/program_artifacts/rust/commit_sum.elf")) .expect("commit_sum.elf not found — run `make compile-programs-rust`"); let input = &[3u8, 5u8]; - let proof = crate::prove_with_inputs(&elf_bytes, input).expect("prove should succeed"); + let proof = prove_vm_minimal(&elf_bytes, input, &Default::default()); assert!( - crate::verify(&proof, &elf_bytes).expect("verify should not error"), + verify_vm_minimal(&proof, &elf_bytes), "commit_sum should verify" ); 
assert_eq!(proof.public_output, vec![8u8]); @@ -2264,7 +2335,7 @@ fn test_verify_rejects_private_input_with_tampered_public_output() { let vm_proof = crate::prove_with_inputs(&elf_bytes, &input).expect("prove should succeed"); assert!( - crate::verify(&vm_proof, &elf_bytes).expect("verify"), + crate::verify(&vm_proof, &elf_bytes).expect("verify should not error"), "Baseline must verify" ); @@ -2313,8 +2384,11 @@ fn test_proof_does_not_contain_private_input_field() { #[test] fn test_addiw_neg_immediate() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_addiw_neg"); - let result = crate::prove_and_verify(&elf_bytes).expect("prove_and_verify failed"); - assert!(result, "addiw with negative immediate should verify"); + let proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); + assert!( + verify_vm_minimal(&proof, &elf_bytes), + "addiw with negative immediate should verify" + ); } /// Regression test: both main and aux field element counts must be nonzero for any real ELF. diff --git a/prover/tests/bench_single.rs b/prover/tests/bench_single.rs new file mode 100644 index 000000000..fac6ad901 --- /dev/null +++ b/prover/tests/bench_single.rs @@ -0,0 +1,23 @@ +//! Single-prove bench for profiling with nsys / ncu. +use lambda_vm_prover::test_utils::asm_elf_bytes; + +#[test] +#[ignore = "bench; run with --ignored --nocapture"] +fn prove_fib_1m_once() { + let elf = asm_elf_bytes("fib_iterative_1M"); + // Warm-up pays one-time costs (PTX load, pool warm-up). + let _ = lambda_vm_prover::prove(&elf).expect("warm-up"); + // Reset GPU counters so the profiled-pass assert below reflects only the + // second run, not warm-up + profiled combined. + #[cfg(feature = "cuda")] + stark::gpu_lde::reset_all_gpu_call_counters(); + // The profiled run: + let _ = lambda_vm_prover::prove(&elf).expect("prove"); + // Catch silent regressions where the table sizes drop below the GPU LDE + // threshold and we'd be measuring CPU numbers without noticing. 
+ #[cfg(feature = "cuda")] + assert!( + stark::gpu_lde::gpu_lde_calls() > 0, + "GPU LDE path did not fire — fib_iterative_1M may have dropped below the GPU threshold" + ); +}