diff --git a/crypto/math/src/fft/cpu/bit_reversing.rs b/crypto/math/src/fft/cpu/bit_reversing.rs index f225dd5e0..fd6936ff7 100644 --- a/crypto/math/src/fft/cpu/bit_reversing.rs +++ b/crypto/math/src/fft/cpu/bit_reversing.rs @@ -1,7 +1,42 @@ /// In-place bit-reverse permutation algorithm. Requires input length to be a power of two. -pub fn in_place_bit_reverse_permute(input: &mut [E]) { - for i in 0..input.len() { - let bit_reversed_index = reverse_index(i, input.len() as u64); +pub fn in_place_bit_reverse_permute(input: &mut [E]) { + let n = input.len(); + #[cfg(feature = "parallel")] + { + // Pair-parallel swap: each pair (i, br(i)) with i < br(i) is independent of all + // other pairs (disjoint indices), so threads can swap concurrently provided they + // never touch the same memory location. `if br > i` selects exactly one owner + // per pair, so no two threads ever write the same slot. + const PARALLEL_BITREV_THRESHOLD: usize = 1 << 14; + if n >= PARALLEL_BITREV_THRESHOLD { + use rayon::prelude::*; + struct SendPtr(*mut E); + impl Copy for SendPtr {} + impl Clone for SendPtr { + fn clone(&self) -> Self { + *self + } + } + unsafe impl Send for SendPtr {} + unsafe impl Sync for SendPtr {} + let ptr = SendPtr(input.as_mut_ptr()); + (0..n).into_par_iter().for_each(|i| { + let br = reverse_index(i, n as u64); + if br > i { + // SAFETY: (i, br) uniquely identifies this pair (smaller index is owner), + // so no two threads race on the same `ptr.0.add(k)` slot. i < n by the range; + // disjoint pairs and br < n presume a power-of-two n, unchecked here (TODO confirm). 
+ let p = ptr; + unsafe { + core::ptr::swap(p.0.add(i), p.0.add(br)); + } + } + }); + return; + } + } + for i in 0..n { + let bit_reversed_index = reverse_index(i, n as u64); if bit_reversed_index > i { input.swap(i, bit_reversed_index); } diff --git a/crypto/math/src/fft/polynomial.rs b/crypto/math/src/fft/polynomial.rs index 9157903fd..129473207 100644 --- a/crypto/math/src/fft/polynomial.rs +++ b/crypto/math/src/fft/polynomial.rs @@ -80,6 +80,37 @@ impl Polynomial> { evaluate_fft_cpu::(&coeffs) } + /// Same as `evaluate_fft` but returns the evaluations in bit-reversed order, + /// skipping the final natural-order permutation. Use when the consumer expects + /// bit-reversed input (e.g. FRI commit phase, which pairs consecutive values as + /// {f(x), f(-x)}). Returns `FFTError::DomainSizeError` when the padded length's `trailing_zeros()` exceeds `F::TWO_ADICITY`. + pub fn evaluate_fft_bit_reversed>( + poly: &Polynomial>, + blowup_factor: usize, + domain_size: Option, + ) -> Result>, FFTError> + where + E: Send + Sync, + { + let domain_size = domain_size.unwrap_or(0); + let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor; + if len.trailing_zeros() as u64 > F::TWO_ADICITY { + return Err(FFTError::DomainSizeError(len.trailing_zeros() as usize)); + } + if poly.coefficients().is_empty() { + return Ok(vec![FieldElement::zero(); len]); + } + + let mut coeffs = poly.coefficients().to_vec(); + coeffs.resize(len, FieldElement::zero()); + + let order = len.trailing_zeros() as u64; + let layer_twiddles = + LayerTwiddles::::new(order).ok_or(FFTError::DomainSizeError(order as usize))?; + dispatch_fft(&mut coeffs, &layer_twiddles)?; + Ok(coeffs) + } + /// Returns `N` evaluations with an offset of this polynomial using FFT over a domain in a subfield F of E /// (so the results are P(w^i), with w being a primitive root of unity). /// `N = max(self.coeff_len(), domain_size).next_power_of_two() * blowup_factor`. 
diff --git a/crypto/math/src/field/element.rs b/crypto/math/src/field/element.rs index 9c2ac3258..e34ec0fb7 100644 --- a/crypto/math/src/field/element.rs +++ b/crypto/math/src/field/element.rs @@ -51,7 +51,29 @@ impl FieldElement { /// Computes the multiplicative inverses of a slice of field elements /// The algorithm just performs one inversion and several multiplications and should be used /// when wanting to invert several elements together - pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> { + pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> + where + Self: Send + Sync, + { + #[cfg(feature = "parallel")] + { + // Montgomery batch inverse has a serial prefix-product dependency, but chunks are + // independent — each chunk inverts its own elements without needing other chunks. Trade + // K-1 extra field inversions (negligible vs ~2N mults per chunk) for K-way parallelism. + // NOTE(review): on Err, other chunks may already be inverted; confirm callers tolerate that. + const PARALLEL_BATCH_INV_THRESHOLD: usize = 1 << 16; + if numbers.len() >= PARALLEL_BATCH_INV_THRESHOLD { + use rayon::prelude::*; + let chunk_size = numbers.len().div_ceil(rayon::current_num_threads().max(1)); + return numbers + .par_chunks_mut(chunk_size) + .try_for_each(Self::inplace_batch_inverse_sequential); + } + } + Self::inplace_batch_inverse_sequential(numbers) + } + + fn inplace_batch_inverse_sequential(numbers: &mut [Self]) -> Result<(), FieldError> { if numbers.is_empty() { return Ok(()); } diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 43086d4fa..d0fd356cc 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -389,10 +389,11 @@ pub trait IsStarkProver< /// Builds a Merkle tree commitment from column-major LDE evaluations with /// bit-reverse permutation, without cloning the full evaluation matrix. /// - /// For each row index `i`, we hash `col_0[br(i)] || col_1[br(i)] || ...` - /// where `br(i)` is the bit-reversal of `i`. 
This produces the same Merkle - /// tree as the old clone + bit-reverse + columns2rows + batch_commit flow, - /// but avoids allocating the cloned and transposed matrices entirely. + /// Hashes `col_0[k] || col_1[k] || ...` for k = 0..num_rows (sequential column + /// reads, cache-friendly), then permutes the hash vector in bit-reversed order + /// so leaves[i] = hash(col_0[br(i)] || col_1[br(i)] || ...). Same Merkle tree + /// as reading at br(row_idx) inside the hashing loop, but the scattered column + /// access is replaced by a single small bit-reverse pass over the digests (the permute requires power-of-two num_rows). fn commit_columns_bit_reversed( columns: &[Vec>], ) -> Option<(BatchedMerkleTree, Commitment)> @@ -420,21 +421,20 @@ pub trait IsStarkProver< #[cfg(not(feature = "parallel"))] let iter = 0..num_rows; - // One allocation per row (was one per field element): write all columns - // into a single buffer, then hash once. - let hashed_leaves: Vec = iter - .map(|row_idx| { - let br_idx = reverse_index(row_idx, num_rows as u64); + let mut hashed_leaves: Vec = iter + .map(|k| { let total_bytes = num_cols * byte_len; let mut buf = vec![0u8; total_bytes]; for col_idx in 0..num_cols { - columns[col_idx][br_idx] + columns[col_idx][k] .write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); } BatchedMerkleTreeBackend::::hash_bytes(&buf) }) .collect(); + in_place_bit_reverse_permute(&mut hashed_leaves); + let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) @@ -1109,9 +1109,12 @@ pub trait IsStarkProver< let t_sub = Instant::now(); let deep_poly = Polynomial::interpolate_fft::(&deep_evals).expect("iFFT should succeed"); - let mut lde_evals = Polynomial::evaluate_fft::(&deep_poly, 1, Some(domain_size)) - .expect("FFT should succeed"); - in_place_bit_reverse_permute(&mut lde_evals); + // FRI commit_phase consumes bit-reversed evaluations natively. 
Request them + // directly from evaluate_fft_bit_reversed to avoid a pair of redundant permutes + // (evaluate_fft's internal natural-order permute + an external re-bit-reverse). + let lde_evals = + Polynomial::evaluate_fft_bit_reversed::(&deep_poly, 1, Some(domain_size)) + .expect("FFT should succeed"); #[cfg(feature = "instruments")] let r4_fft_dur = t_sub.elapsed(); diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index d2743a1e5..7ca975970 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -1686,15 +1686,23 @@ struct CollectedOps { } /// Chunk raw ops and generate one trace table per chunk. -fn chunk_and_generate( +fn chunk_and_generate( ops: &[T], max_rows: usize, - generate: impl Fn(&[T]) -> TraceTable, + generate: impl Fn(&[T]) -> TraceTable + Sync + Send, ) -> Vec> { if ops.is_empty() { vec![generate(&[])] } else { - ops.chunks(max_rows).map(generate).collect() + #[cfg(feature = "parallel")] + { + use rayon::prelude::*; + ops.par_chunks(max_rows).map(&generate).collect() + } + #[cfg(not(feature = "parallel"))] + { + ops.chunks(max_rows).map(generate).collect() + } } }