Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions crypto/crypto/src/merkle_tree/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,44 @@ where
Self::build_from_hashed_leaves(hashed_leaves)
}

/// Useful for handing a GPU-built tree to the stark prover.
/// Performs no hashing, the caller is responsible for the layout's
/// cryptographic correctness.
///
/// Expected layout (matches [`build_from_hashed_leaves`]):
/// - `nodes.len() == 2 * leaves_len - 1` where `leaves_len` is a power of two
/// - `nodes[0]` is the root
/// - `nodes[leaves_len - 1 .. 2*leaves_len - 1]` are the leaves
pub fn from_precomputed_nodes(nodes: Vec<B::Node>) -> Option<Self> {
Comment thread
ColoCarletti marked this conversation as resolved.
if nodes.is_empty() {
return None;
}
// Validate (cheap) that (nodes.len() + 1) is a power of two: there
// must be `leaves_len - 1 + leaves_len = 2*leaves_len - 1` entries.
let total = nodes.len();
if !(total + 1).is_power_of_two() {
return None;
}
// Debug-only integrity spot-check: the root must equal hash(left, right).
// Catches GPU correctness regressions in CI without paying for a full
// tree walk on every call.
#[cfg(debug_assertions)]
if total >= 3 {
let expected_root = B::hash_new_parent(&nodes[1], &nodes[2]);
debug_assert!(
nodes[ROOT] == expected_root,
"from_precomputed_nodes: root does not hash from children",
);
}
let root = nodes[ROOT].clone();
Some(MerkleTree {
root,
nodes,
#[cfg(feature = "disk-spill")]
mmap_backing: None,
})
}

/// Create a Merkle tree from pre-hashed leaf nodes.
///
/// This skips the `hash_leaves` step, useful when leaves have already been
Expand Down
69 changes: 52 additions & 17 deletions crypto/math-cuda/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,26 @@ const STREAM_POOL_SIZE: usize = 32;
pub struct Backend {
pub ctx: Arc<CudaContext>,
streams: Vec<Arc<CudaStream>>,
/// Single shared pinned staging buffer, grown to the biggest LDE size
/// seen. Concurrent batched LDE calls serialise on it; in exchange the
/// process keeps only ONE gigabyte-sized pinned allocation (per-stream
/// buffers 32×-inflated memory use and multiplied the one-time pinning
/// cost for every first use of a new table size).
pinned_staging: Mutex<PinnedStaging>,
/// Separate pinned staging for Merkle leaf hashes. Sized `num_rows * 32`
/// bytes. It lives alongside the LDE staging so the GPU→host D2H for
/// hashed leaves runs at full PCIe line-rate.
pinned_hashes: Mutex<PinnedStaging>,
/// Per-rayon-worker pinned staging buffers, grown lazily to the biggest
/// LDE size each worker sees. Indexed by `rayon::current_thread_index()`
/// (or 0 for non-rayon callers).
///
/// Per-worker (not single-shared) because the LDE call holds the lock
/// across an internal rayon `par_chunks_mut`/`par_iter` window: with a
/// single shared mutex, rayon work-stealing can yield a lock-holder onto
/// another task waiting for the same lock — classic recursive-rayon
/// deadlock. Per-worker buffers eliminate cross-worker contention so
/// each `par_iter` worker hits a distinct mutex.
///
/// Each entry starts empty (`PinnedStaging::empty()` is a zero-cost null
/// handle); only the slots actually used by the running workers ever
/// allocate pinned memory. Worst-case footprint is `N_workers ×
/// max_LDE_size` of pinned host RAM.
pinned_staging: Vec<Mutex<PinnedStaging>>,
/// Per-worker pinned staging for Merkle leaf hashes. Same layout as
/// `pinned_staging`; sized `num_rows * 32` bytes per slot. Lives
/// alongside the LDE staging so the GPU→host D2H runs at PCIe line-rate.
pinned_hashes: Vec<Mutex<PinnedStaging>>,
util_stream: Arc<CudaStream>,
next: AtomicUsize,

Expand Down Expand Up @@ -166,8 +176,20 @@ impl Backend {
for _ in 0..STREAM_POOL_SIZE {
streams.push(ctx.new_stream()?);
}
let pinned_staging = Mutex::new(PinnedStaging::empty());
let pinned_hashes = Mutex::new(PinnedStaging::empty());
// Size to the rayon worker count, plus one for non-rayon callers
// who land on slot 0 (`rayon::current_thread_index()` returns None
// outside a rayon context — we map that to 0).
//
// `current_num_threads()` returns the default-pool size if no custom
// pool is in use, which is the cpu count. Stable across the
// backend's lifetime since rayon's pool is fixed at first use.
let n_slots = rayon::current_num_threads().max(1);
let pinned_staging: Vec<Mutex<PinnedStaging>> = (0..n_slots)
.map(|_| Mutex::new(PinnedStaging::empty()))
.collect();
let pinned_hashes: Vec<Mutex<PinnedStaging>> = (0..n_slots)
.map(|_| Mutex::new(PinnedStaging::empty()))
.collect();
// Separate "utility" stream for twiddle uploads and other bookkeeping;
// not part of the pool that callers rotate through.
let util_stream = ctx.new_stream()?;
Expand Down Expand Up @@ -219,16 +241,29 @@ impl Backend {
self.streams[idx].clone()
}

/// Shared pinned staging buffer. Grows to the largest LDE the process
/// has seen so far. Concurrent callers serialise on the mutex.
/// Per-rayon-worker pinned staging buffer. Returns the slot for the
/// current worker (or slot 0 outside a rayon context). Grows lazily to
/// the largest LDE the worker has seen. See the field docs for the
/// rationale behind the per-worker split.
pub fn pinned_staging(&self) -> &Mutex<PinnedStaging> {
&self.pinned_staging
&self.pinned_staging[self.worker_slot(self.pinned_staging.len())]
}

/// Separate pinned staging for Merkle leaf hash output. Sized in u64
/// Per-worker pinned staging for Merkle leaf hash output. Sized in u64
/// units. Caller should reserve `(num_rows * 32 + 7) / 8` u64s.
pub fn pinned_hashes(&self) -> &Mutex<PinnedStaging> {
&self.pinned_hashes
&self.pinned_hashes[self.worker_slot(self.pinned_hashes.len())]
}

/// Map `rayon::current_thread_index()` to a slot index, with a defensive
/// clamp in case the rayon pool grew past the Vec we sized at init.
fn worker_slot(&self, len: usize) -> usize {
let idx = rayon::current_thread_index().unwrap_or(0);
// Should be unreachable with rayon's fixed default pool, but if a
// larger custom pool sneaks in we still want safety — fall back to
// slot 0 (correctness preserved, just contention).
debug_assert!(idx < len, "rayon worker {idx} >= staging slots {len}");
idx.min(len.saturating_sub(1))
}

pub fn fwd_twiddles_for(&self, log_n: u64) -> Result<Arc<CudaSlice<u64>>> {
Expand Down
107 changes: 47 additions & 60 deletions crypto/math-cuda/src/lde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
use std::sync::Arc;

use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
use rayon::prelude::*;

use crate::Result;
use crate::device::{Backend, backend};
Expand Down Expand Up @@ -69,7 +68,10 @@ pub(crate) fn pack_ext3_to_pinned_slabs(columns: &[&[u64]], pinned: &mut [u64],
let m = columns.len();
debug_assert!(pinned.len() >= 3 * m * n);
let pinned_ptr_u = pinned.as_mut_ptr() as usize;
columns.par_iter().enumerate().for_each(|(c, col)| {
// Sequential, not `par_iter`: this runs while the per-worker pinned
// staging mutex is held. Rayon inside a held mutex risks recursive
// stealing-during-wait deadlocks — see `Backend::pinned_staging` docs.
columns.iter().enumerate().for_each(|(c, col)| {
// SAFETY: each task writes to disjoint `[(c*3 + k)*n .. ..+n]` regions
// of `pinned`. The outer `&mut [u64]` borrow guarantees no aliasing.
let slab_a = unsafe {
Expand All @@ -96,7 +98,9 @@ fn unpack_pinned_slabs_to_ext3(pinned: &[u64], outputs: &mut [&mut [u64]], lde_s
let m = outputs.len();
debug_assert!(pinned.len() >= 3 * m * lde_size);
let pinned_const = pinned.as_ptr() as usize;
outputs.par_iter_mut().enumerate().for_each(|(c, dst)| {
// Sequential, not `par_iter_mut`: runs inside a held pinned-staging
// mutex; rayon-inside-mutex risks deadlock (see `Backend::pinned_staging`).
outputs.iter_mut().enumerate().for_each(|(c, dst)| {
// SAFETY: each task reads from disjoint `[(c*3 + k)*lde_size .. ..+lde_size]`
// regions of `pinned`. Caller borrows `pinned` for the duration of the call.
let slab_a = unsafe {
Expand Down Expand Up @@ -178,19 +182,14 @@ fn d2h_bytes_via_pinned_hashes(
stream.memcpy_dtoh(dev_bytes, pinned_bytes)?;
stream.synchronize()?;

// Single-threaded `copy_from_slice` faults virgin pageable pages one at
// a time; the mm_struct rwsem serialises them at prover scale. Chunk so
// ~N cores pre-fault+write in parallel.
const CHUNK: usize = 64 * 1024;
let src_ptr = pinned_bytes.as_ptr() as usize;
dst.par_chunks_mut(CHUNK).enumerate().for_each(|(i, d)| {
// SAFETY: each task reads `[i*CHUNK .. i*CHUNK + d.len()]` of
// `pinned_bytes`, which is disjoint per `i` and lives until `staging`
// is dropped below.
let src =
unsafe { std::slice::from_raw_parts((src_ptr as *const u8).add(i * CHUNK), d.len()) };
d.copy_from_slice(src);
});
// Sequential, not `par_chunks_mut`: this runs while the per-worker
// pinned_hashes mutex is held. Rayon inside a held mutex risks
// recursive stealing-during-wait deadlocks — see
// `Backend::pinned_staging` docs. Page-fault parallelism on virgin
// destination pages is recovered at the outer level: per-worker
// staging buffers let rayon's outer `par_iter` dispatch multiple LDE
// calls in parallel, each faulting its own destination pages.
dst.copy_from_slice(pinned_bytes);
drop(staging);
Ok(())
}
Expand Down Expand Up @@ -367,18 +366,14 @@ pub fn coset_lde_batch_base(
// SAFETY: staging is locked, the slice alias ends before we unlock.
let pinned = unsafe { staging.as_mut_slice(m * lde_size) };

// Pack columns into first m*n slots of the pinned buffer. Parallel: pinned
// writes are DRAM-bandwidth bound, so rayon spreads the cost across CPU
// cores.
let pinned_base_ptr = pinned.as_mut_ptr() as usize;
columns.par_iter().enumerate().for_each(|(c, col)| {
// SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of
// `pinned`, and the outer `staging` lock guarantees no other call is
// using the buffer concurrently.
let dst =
unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) };
dst.copy_from_slice(col);
});
// Pack columns into first m*n slots of the pinned buffer. Sequential
// (not `par_iter`) because this runs inside the held pinned-staging
// mutex — see `Backend::pinned_staging` docs. Pre-fault parallelism on
// the destination is recovered at the outer level via per-worker
// staging slots.
for (c, col) in columns.iter().enumerate() {
pinned[c * n..c * n + n].copy_from_slice(col);
}

// Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size)
// tail of each column is already the zero-pad the CPU path does.
Expand Down Expand Up @@ -459,12 +454,11 @@ pub fn coset_lde_batch_base(
stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?;
stream.synchronize()?;

// Split pinned → per-column Vec<u64>s. The first write to each virgin
// Vec page-faults, which can dominate total time. Parallelise so the
// fault cost spreads across CPU cores.
let pinned_ptr = pinned.as_ptr() as usize;
// Split pinned → per-column Vec<u64>s. Sequential (not `into_par_iter`)
// because this runs inside the held pinned-staging mutex — see
// `Backend::pinned_staging` docs. Fault-cost parallelism is recovered
// at the outer level (per-worker staging slots).
let out: Vec<Vec<u64>> = (0..m)
.into_par_iter()
.map(|c| {
// set_len skips the O(N) zero-init that vec![0; n] would do.
// copy_from_slice below writes every slot before any reader
Expand All @@ -475,10 +469,7 @@ pub fn coset_lde_batch_base(
unsafe { v.set_len(lde_size) };
v
};
let src = unsafe {
std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size)
};
v.copy_from_slice(src);
v.copy_from_slice(&pinned[c * lde_size..c * lde_size + lde_size]);
v
})
.collect();
Expand Down Expand Up @@ -602,15 +593,14 @@ pub fn coset_lde_batch_base_into(
stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?;
stream.synchronize()?;

// Parallel copy pinned → caller outputs. Caller's Vecs may still fault
// on first write; we spread that cost across rayon cores.
let pinned_ptr = pinned.as_ptr() as usize;
outputs.par_iter_mut().enumerate().for_each(|(c, dst)| {
let src = unsafe {
std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size)
};
dst.copy_from_slice(src);
});
// Sequential copy pinned → caller outputs (not `par_iter_mut`): runs
// inside the held pinned-staging mutex; rayon-inside-mutex risks
// recursive stealing-during-wait deadlocks (see
// `Backend::pinned_staging`). Fault-cost parallelism is recovered at
// the outer level via per-worker staging slots.
for (c, dst) in outputs.iter_mut().enumerate() {
dst.copy_from_slice(&pinned[c * lde_size..c * lde_size + lde_size]);
}
drop(staging);
Ok(())
}
Expand Down Expand Up @@ -734,12 +724,12 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner(
staging.ensure_capacity(m * lde_size, &be.ctx)?;
let pinned = unsafe { staging.as_mut_slice(m * lde_size) };

let pinned_base_ptr = pinned.as_mut_ptr() as usize;
columns.par_iter().enumerate().for_each(|(c, col)| {
let dst =
unsafe { std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) };
dst.copy_from_slice(col);
});
// Sequential pack (not `par_iter`): runs inside the held pinned-staging
// mutex (see `Backend::pinned_staging` docs). Per-worker staging slots
// give the outer parallelism back.
for (c, col) in columns.iter().enumerate() {
pinned[c * n..c * n + n].copy_from_slice(col);
}

let mut buf = stream.alloc_zeros::<u64>(m * lde_size)?;
for c in 0..m {
Expand Down Expand Up @@ -833,14 +823,11 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner(
stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?;
d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?;

// Pinned LDE → caller outputs (post-sync host memcpy).
let pinned_ptr = pinned.as_ptr() as usize;
outputs.par_iter_mut().enumerate().for_each(|(c, dst)| {
let src = unsafe {
std::slice::from_raw_parts((pinned_ptr as *const u64).add(c * lde_size), lde_size)
};
dst.copy_from_slice(src);
});
// Sequential pinned → caller outputs (not `par_iter_mut`): runs inside
// the held pinned-staging mutex (see `Backend::pinned_staging` docs).
for (c, dst) in outputs.iter_mut().enumerate() {
dst.copy_from_slice(&pinned[c * lde_size..c * lde_size + lde_size]);
}
drop(staging);

if keep_device_buf {
Expand Down
1 change: 1 addition & 0 deletions crypto/stark/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ rayon = { version = "1.8.0", optional = true }
memmap2 = { version = "0.9", optional = true }
tempfile = { version = "3", optional = true }
libc = { version = "0.2", optional = true }

# GPU backend for trace LDE — only linked when `cuda` is enabled.
math-cuda = { path = "../math-cuda", optional = true }

Expand Down
Loading
Loading