41 changes: 38 additions & 3 deletions crypto/math/src/fft/cpu/bit_reversing.rs
@@ -1,7 +1,42 @@
/// In-place bit-reverse permutation algorithm. Requires input length to be a power of two.
pub fn in_place_bit_reverse_permute<E>(input: &mut [E]) {
for i in 0..input.len() {
let bit_reversed_index = reverse_index(i, input.len() as u64);
pub fn in_place_bit_reverse_permute<E: Send>(input: &mut [E]) {
let n = input.len();
#[cfg(feature = "parallel")]
{
// Pair-parallel swap: each pair (i, br(i)) with i < br(i) is independent of all
// other pairs (disjoint indices), so threads can swap concurrently provided they
// never touch the same memory location. `if br > i` selects exactly one owner
// per pair, so no two threads ever write the same slot.
const PARALLEL_BITREV_THRESHOLD: usize = 1 << 14;
if n >= PARALLEL_BITREV_THRESHOLD {
Comment on lines +10 to +11

Medium — Missing power-of-two guard before unsafe parallel swap

The SAFETY argument for core::ptr::swap relies on reverse_index being a bijection on [0, n), which only holds when n is a power of two. The sequential path just produces wrong output if the contract is broken; the parallel path invokes undefined behaviour (a data race) because two threads could swap the same element concurrently.

A debug_assert! catches violations in debug builds at zero release cost:

Suggested change
const PARALLEL_BITREV_THRESHOLD: usize = 1 << 14;
if n >= PARALLEL_BITREV_THRESHOLD {
const PARALLEL_BITREV_THRESHOLD: usize = 1 << 14;
debug_assert!(n.is_power_of_two(), "in_place_bit_reverse_permute requires a power-of-two length");
if n >= PARALLEL_BITREV_THRESHOLD {
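To make the power-of-two requirement concrete, here is a minimal sketch that checks whether a bit-reversal index map is a permutation of [0, n). The `reverse_index` below is a stand-in written for this example (its signature and body are assumptions, not necessarily the crate's implementation), and `is_permutation` is a hypothetical helper.

```rust
/// Stand-in for the crate's `reverse_index` (assumed behaviour): reverses the
/// low `trailing_zeros(n)` bits of `i`. Only meaningful when `n` is a power of two.
fn reverse_index(i: usize, n: usize) -> usize {
    if n <= 1 {
        return i;
    }
    let bits = n.trailing_zeros();
    // `checked_shr` avoids a shift-overflow panic when `n` is odd (bits == 0).
    i.reverse_bits().checked_shr(usize::BITS - bits).unwrap_or(0)
}

/// Hypothetical helper: true when `i -> reverse_index(i, n)` hits every index
/// in `[0, n)` exactly once.
fn is_permutation(n: usize) -> bool {
    let mut seen = vec![false; n];
    for i in 0..n {
        let br = reverse_index(i, n);
        if br >= n || seen[br] {
            return false;
        }
        seen[br] = true;
    }
    true
}
```

With this stand-in, `is_permutation(8)` holds, while `is_permutation(12)` does not: only the two low bits of each index are reversed, so several indices map to the same slot. Whether that turns into an actual race depends on the real `reverse_index`, which is why an explicit length check is the safer contract.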

use rayon::prelude::*;
struct SendPtr<E>(*mut E);
impl<E> Copy for SendPtr<E> {}
impl<E> Clone for SendPtr<E> {
fn clone(&self) -> Self {
*self
}
}
unsafe impl<E> Send for SendPtr<E> {}
unsafe impl<E> Sync for SendPtr<E> {}

Low — Sync impl is broader than needed

for_each on a rayon ParallelIterator only requires the closure to be Send. Since ptr is Copy, each closure invocation gets its own copy of the raw pointer, so the closure is Send as long as SendPtr: Send. The Sync impl (&SendPtr<E> safe to share across threads) is never exercised by this code and makes a stronger safety claim than what is actually verified. It should be removed to keep the unsafe surface minimal.

Suggested change
unsafe impl<E> Sync for SendPtr<E> {}
unsafe impl<E> Send for SendPtr<E> {}

let ptr = SendPtr(input.as_mut_ptr());
(0..n).into_par_iter().for_each(|i| {
let br = reverse_index(i, n as u64);
if br > i {
// SAFETY: (i, br) uniquely identifies this pair (smaller index is owner),
// so no two threads race on the same `ptr.0.add(k)` slot. Both indices
// are in-bounds since i < n and br < n.
let p = ptr;
unsafe {
core::ptr::swap(p.0.add(i), p.0.add(br));
}
}
});
return;
}
}
for i in 0..n {
let bit_reversed_index = reverse_index(i, n as u64);
if bit_reversed_index > i {
input.swap(i, bit_reversed_index);
}
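The ownership rule in the comment above (`br > i` picks exactly one owner per pair) can be checked on small inputs with a sequential sketch. `reverse_index` here is a stand-in with an assumed signature, and `owned_pairs` and `bit_reverse_permute` are hypothetical names used only for illustration.

```rust
/// Stand-in for the crate's `reverse_index` (assumed behaviour): reverses the
/// low log2(n) bits of `i`. Only meaningful when `n` is a power of two.
fn reverse_index(i: usize, n: usize) -> usize {
    if n <= 1 {
        return i;
    }
    let bits = n.trailing_zeros();
    i.reverse_bits().checked_shr(usize::BITS - bits).unwrap_or(0)
}

/// Lists the pairs that would be swapped; the smaller index is the owner.
fn owned_pairs(n: usize) -> Vec<(usize, usize)> {
    (0..n)
        .map(|i| (i, reverse_index(i, n)))
        .filter(|&(i, br)| br > i)
        .collect()
}

/// Sequential in-place bit-reverse permutation using the same owner rule.
fn bit_reverse_permute<T>(data: &mut [T]) {
    let n = data.len();
    for i in 0..n {
        let br = reverse_index(i, n);
        if br > i {
            data.swap(i, br);
        }
    }
}
```

For a power-of-two `n`, no index appears in more than one owned pair, which is the property the parallel path relies on. Applying `bit_reverse_permute` twice restores the original order, since bit reversal is its own inverse.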
31 changes: 31 additions & 0 deletions crypto/math/src/fft/polynomial.rs
@@ -80,6 +80,37 @@ impl<E: IsField> Polynomial<FieldElement<E>> {
evaluate_fft_cpu::<F, E>(&coeffs)
}

/// Same as `evaluate_fft` but returns the evaluations in bit-reversed order,
/// skipping the final natural-order permutation. Use when the consumer expects
/// bit-reversed input (e.g. FRI commit phase, which pairs consecutive values as
/// {f(x), f(-x)}).
pub fn evaluate_fft_bit_reversed<F: IsFFTField + IsSubFieldOf<E>>(
poly: &Polynomial<FieldElement<E>>,
blowup_factor: usize,
domain_size: Option<usize>,
) -> Result<Vec<FieldElement<E>>, FFTError>
where
E: Send + Sync,
{
let domain_size = domain_size.unwrap_or(0);
let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor;
if len.trailing_zeros() as u64 > F::TWO_ADICITY {
return Err(FFTError::DomainSizeError(len.trailing_zeros() as usize));
}
if poly.coefficients().is_empty() {
return Ok(vec![FieldElement::zero(); len]);
}

let mut coeffs = poly.coefficients().to_vec();
coeffs.resize(len, FieldElement::zero());

let order = len.trailing_zeros() as u64;
let layer_twiddles =
LayerTwiddles::<F>::new(order).ok_or(FFTError::DomainSizeError(order as usize))?;
dispatch_fft(&mut coeffs, &layer_twiddles)?;
Ok(coeffs)
}
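The doc comment's remark that bit-reversed order pairs consecutive values as {f(x), f(-x)} follows from index arithmetic: for a domain generated by a primitive n-th root of unity w, w^(n/2) = -1, so natural indices k and k + n/2 hold f(x) and f(-x). The sketch below only checks that arithmetic; `reverse_index` is a stand-in with an assumed signature and `natural_pair` is a hypothetical helper.

```rust
/// Stand-in for the crate's `reverse_index` (assumed behaviour): reverses the
/// low log2(n) bits of `i`. Only meaningful when `n` is a power of two.
fn reverse_index(i: usize, n: usize) -> usize {
    if n <= 1 {
        return i;
    }
    let bits = n.trailing_zeros();
    i.reverse_bits().checked_shr(usize::BITS - bits).unwrap_or(0)
}

/// Natural-order indices stored at positions 2j and 2j + 1 of a
/// bit-reversed evaluation vector of length `n`.
fn natural_pair(j: usize, n: usize) -> (usize, usize) {
    (reverse_index(2 * j, n), reverse_index(2 * j + 1, n))
}
```

Positions 2j and 2j + 1 differ only in their lowest bit, which becomes the highest bit after reversal, so the second natural index is always the first plus n/2.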
Comment on lines +87 to +112

Low — No test for the new public API

evaluate_fft_bit_reversed is used in a correctness-critical path (FRI commit phase) but has no unit test. The invariant to verify is simple: the output should equal evaluate_fft with in_place_bit_reverse_permute applied.

#[test]
fn evaluate_fft_bit_reversed_matches_evaluate_fft_permuted() {
    use crate::fft::cpu::bit_reversing::in_place_bit_reverse_permute;
    let coeffs: Vec<FE> = (0u64..8).map(FE::from).collect();
    let poly = Polynomial::new(&coeffs);
    let mut expected = Polynomial::evaluate_fft::<F>(&poly, 2, None).unwrap();
    in_place_bit_reverse_permute(&mut expected);
    let got = Polynomial::evaluate_fft_bit_reversed::<F>(&poly, 2, None).unwrap();
    assert_eq!(got, expected);
}

Comment on lines +87 to +112

Low — Duplicate setup code with evaluate_fft

evaluate_fft_bit_reversed is identical to evaluate_fft except it skips the final in_place_bit_reverse_permute. Both functions share ~20 lines of setup (domain-size computation, empty-poly check, zero-padding, twiddle construction). If either is changed in the future the other will silently diverge.

Consider a private helper:

fn evaluate_fft_raw<F: IsFFTField + IsSubFieldOf<E>, E: IsField + Send + Sync>(
    poly: &Polynomial<FieldElement<E>>,
    blowup_factor: usize,
    domain_size: Option<usize>,
) -> Result<Vec<FieldElement<E>>, FFTError> {
    // shared setup + dispatch_fft, no permutation
}

Then evaluate_fft calls the helper and permutes, while evaluate_fft_bit_reversed just calls the helper.


/// Returns `N` evaluations with an offset of this polynomial using FFT over a domain in a subfield F of E
/// (so the results are P(w^i), with w being a primitive root of unity).
/// `N = max(self.coeff_len(), domain_size).next_power_of_two() * blowup_factor`.
24 changes: 23 additions & 1 deletion crypto/math/src/field/element.rs
@@ -51,7 +51,29 @@ impl<F: IsField> FieldElement<F> {
/// Computes the multiplicative inverses of a slice of field elements
/// The algorithm just performs one inversion and several multiplications and should be used
/// when wanting to invert several elements together
pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> {
pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError>
where
Self: Send + Sync,

Medium — Unconditional API-breaking bound

The where Self: Send + Sync constraint is on the public function signature unconditionally, not gated by #[cfg(feature = "parallel")]. This means every caller of inplace_batch_inverse must now satisfy the bound even on no_std targets or crates that never enable the parallel feature, which is a breaking change for any downstream field implementation whose BaseType is not Send + Sync.

Consider splitting:

Suggested change
Self: Send + Sync,
pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> {

…and adding the Send + Sync bound only in the #[cfg(feature = "parallel")] branch, e.g. via an internal helper with the tighter bound, keeping the public API signature unchanged.

{
#[cfg(feature = "parallel")]
{
// Montgomery batch inverse has a serial prefix-product dependency, but
// chunks are independent — each chunk inverts its own elements without
// needing values from other chunks. Trade K-1 extra field inversions
// (negligible vs ~2N mults per chunk) for K-way parallelism.
const PARALLEL_BATCH_INV_THRESHOLD: usize = 1 << 16;
if numbers.len() >= PARALLEL_BATCH_INV_THRESHOLD {
use rayon::prelude::*;
let chunk_size = numbers.len().div_ceil(rayon::current_num_threads().max(1));
return numbers
.par_chunks_mut(chunk_size)
.try_for_each(Self::inplace_batch_inverse_sequential);
}
}
Self::inplace_batch_inverse_sequential(numbers)
}

fn inplace_batch_inverse_sequential(numbers: &mut [Self]) -> Result<(), FieldError> {
if numbers.is_empty() {
return Ok(());
}
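The chunking trade-off described in the comment above (one inversion per chunk instead of one overall) can be illustrated over a toy prime field. This is a sketch under stated assumptions: the modulus `P`, the `u64` representation, and every function name are invented for the example, and the chunks are processed sequentially with `chunks_mut` where the diff uses rayon's `par_chunks_mut`.

```rust
/// Toy prime modulus (assumption for this sketch; inputs must be < P).
const P: u64 = 1_000_000_007;

fn mul(a: u64, b: u64) -> u64 {
    a * b % P
}

/// Modular exponentiation by squaring.
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}

/// Single inversion via Fermat's little theorem; `None` for zero.
fn inv(a: u64) -> Option<u64> {
    if a % P == 0 { None } else { Some(pow(a % P, P - 2)) }
}

/// Montgomery batch inverse: one inversion plus about 3n multiplications.
fn batch_inverse(xs: &mut [u64]) -> Option<()> {
    if xs.is_empty() {
        return Some(());
    }
    // prefix[i] = xs[0] * ... * xs[i]
    let mut prefix = Vec::with_capacity(xs.len());
    let mut acc = 1;
    for &x in xs.iter() {
        acc = mul(acc, x);
        prefix.push(acc);
    }
    // Fails before any mutation if some element is zero.
    let mut inv_acc = inv(acc)?;
    for i in (1..xs.len()).rev() {
        let x = xs[i];
        xs[i] = mul(inv_acc, prefix[i - 1]);
        inv_acc = mul(inv_acc, x);
    }
    xs[0] = inv_acc;
    Some(())
}

/// Same result with `k` independent chunks: k inversions instead of one.
fn batch_inverse_chunked(xs: &mut [u64], k: usize) -> Option<()> {
    if xs.is_empty() {
        return Some(());
    }
    let chunk = xs.len().div_ceil(k.max(1));
    xs.chunks_mut(chunk).try_for_each(batch_inverse)
}
```

In this sketch the single-pass version leaves the slice untouched when it fails, whereas the chunked version may already have inverted earlier chunks by the time a later chunk reports a zero. Whether the same holds for the code in the diff depends on `inplace_batch_inverse_sequential`, which is not fully shown here.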
29 changes: 16 additions & 13 deletions crypto/stark/src/prover.rs
@@ -389,10 +389,11 @@ pub trait IsStarkProver<
/// Builds a Merkle tree commitment from column-major LDE evaluations with
/// bit-reverse permutation, without cloning the full evaluation matrix.
///
/// For each row index `i`, we hash `col_0[br(i)] || col_1[br(i)] || ...`
/// where `br(i)` is the bit-reversal of `i`. This produces the same Merkle
/// tree as the old clone + bit-reverse + columns2rows + batch_commit flow,
/// but avoids allocating the cloned and transposed matrices entirely.
/// Hashes `col_0[k] || col_1[k] || ...` for k = 0..num_rows (sequential column
/// reads, cache-friendly), then permutes the hash vector in bit-reversed order
/// so leaves[i] = hash(col_0[br(i)] || col_1[br(i)] || ...). Same Merkle tree
/// as reading at br(row_idx) inside the hashing loop, but the scattered column
/// access is replaced by a single small bit-reverse pass over 32-byte digests.
fn commit_columns_bit_reversed<E>(
columns: &[Vec<FieldElement<E>>],
) -> Option<(BatchedMerkleTree<E>, Commitment)>
@@ -420,21 +421,20 @@ pub trait IsStarkProver<
#[cfg(not(feature = "parallel"))]
let iter = 0..num_rows;

// One allocation per row (was one per field element): write all columns
// into a single buffer, then hash once.
let hashed_leaves: Vec<Commitment> = iter
.map(|row_idx| {
let br_idx = reverse_index(row_idx, num_rows as u64);
let mut hashed_leaves: Vec<Commitment> = iter
.map(|k| {
let total_bytes = num_cols * byte_len;
let mut buf = vec![0u8; total_bytes];
for col_idx in 0..num_cols {
columns[col_idx][br_idx]
columns[col_idx][k]
.write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]);
}
BatchedMerkleTreeBackend::<E>::hash_bytes(&buf)
})
.collect();

in_place_bit_reverse_permute(&mut hashed_leaves);

let tree = BatchedMerkleTree::<E>::build_from_hashed_leaves(hashed_leaves)?;
let root = tree.root;
Some((tree, root))
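The claim that hashing rows sequentially and then permuting the digests gives the same leaves as reading at `br(row_idx)` inside the loop can be checked with a toy model. Everything below is an assumption made for illustration: columns are plain `u64` vectors, `DefaultHasher` stands in for the Merkle backend's hash, and `reverse_index` is a stand-in with an assumed signature.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

/// Stand-in for the crate's `reverse_index` (assumed behaviour): reverses the
/// low log2(n) bits of `i`. Only meaningful when `n` is a power of two.
fn reverse_index(i: usize, n: usize) -> usize {
    if n <= 1 {
        return i;
    }
    let bits = n.trailing_zeros();
    i.reverse_bits().checked_shr(usize::BITS - bits).unwrap_or(0)
}

/// Toy leaf hash: feeds row `k` of every column, big-endian, into one hasher.
fn hash_row(columns: &[Vec<u64>], k: usize) -> u64 {
    let mut h = DefaultHasher::new();
    for col in columns {
        h.write(&col[k].to_be_bytes());
    }
    h.finish()
}

/// Old flow: scattered column reads at the bit-reversed row index.
fn leaves_scattered(columns: &[Vec<u64>]) -> Vec<u64> {
    let n = columns[0].len();
    (0..n).map(|i| hash_row(columns, reverse_index(i, n))).collect()
}

/// New flow: sequential column reads, then one permutation of the digests.
fn leaves_sequential_then_permuted(columns: &[Vec<u64>]) -> Vec<u64> {
    let n = columns[0].len();
    let mut leaves: Vec<u64> = (0..n).map(|k| hash_row(columns, k)).collect();
    for i in 0..n {
        let br = reverse_index(i, n);
        if br > i {
            leaves.swap(i, br);
        }
    }
    leaves
}
```

Both functions return identical vectors for power-of-two row counts, because the permutation only moves digests and never changes which row each digest was computed from.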
@@ -1109,9 +1109,12 @@ pub trait IsStarkProver<
let t_sub = Instant::now();
let deep_poly =
Polynomial::interpolate_fft::<Field>(&deep_evals).expect("iFFT should succeed");
let mut lde_evals = Polynomial::evaluate_fft::<Field>(&deep_poly, 1, Some(domain_size))
.expect("FFT should succeed");
in_place_bit_reverse_permute(&mut lde_evals);
// FRI commit_phase consumes bit-reversed evaluations natively. Request them
// directly from evaluate_fft_bit_reversed to avoid a pair of redundant permutes
// (evaluate_fft's internal natural-order permute + an external re-bit-reverse).
let lde_evals =
Polynomial::evaluate_fft_bit_reversed::<Field>(&deep_poly, 1, Some(domain_size))
.expect("FFT should succeed");
#[cfg(feature = "instruments")]
let r4_fft_dur = t_sub.elapsed();

14 changes: 11 additions & 3 deletions prover/src/tables/trace_builder.rs
@@ -1686,15 +1686,23 @@ struct CollectedOps {
}

/// Chunk raw ops and generate one trace table per chunk.
fn chunk_and_generate<T>(
fn chunk_and_generate<T: Sync>(
ops: &[T],
max_rows: usize,
generate: impl Fn(&[T]) -> TraceTable<GoldilocksField, GoldilocksExtension>,
generate: impl Fn(&[T]) -> TraceTable<GoldilocksField, GoldilocksExtension> + Sync + Send,
) -> Vec<TraceTable<GoldilocksField, GoldilocksExtension>> {
if ops.is_empty() {
vec![generate(&[])]
} else {
ops.chunks(max_rows).map(generate).collect()
#[cfg(feature = "parallel")]
{
use rayon::prelude::*;
ops.par_chunks(max_rows).map(&generate).collect()
}
#[cfg(not(feature = "parallel"))]
{
ops.chunks(max_rows).map(generate).collect()
}
}
}
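The new `T: Sync` and `+ Sync + Send` bounds are what allow chunks and the generator closure to be shared across worker threads. A self-contained sketch of the same shape, using `std::thread::scope` instead of rayon, is below. The name `chunk_and_map` and the generic result type `R` are assumptions for illustration, it spawns one thread per chunk (rayon uses a work-stealing pool instead), and it assumes `max_rows > 0`.

```rust
use std::thread;

/// Maps `generate` over fixed-size chunks of `ops` on separate threads,
/// preserving chunk order in the output.
fn chunk_and_map<T: Sync, R: Send>(
    ops: &[T],
    max_rows: usize,
    generate: impl Fn(&[T]) -> R + Sync,
) -> Vec<R> {
    if ops.is_empty() {
        // Mirror the diff: an empty input still yields one (empty) table.
        return vec![generate(&[])];
    }
    // Share the closure by reference; this is sound because it is `Sync`.
    let generate = &generate;
    thread::scope(|s| {
        let handles: Vec<_> = ops
            .chunks(max_rows)
            .map(|chunk| s.spawn(move || generate(chunk)))
            .collect();
        // Joining in spawn order keeps results aligned with chunk order.
        handles
            .into_iter()
            .map(|h| h.join().expect("worker thread panicked"))
            .collect::<Vec<R>>()
    })
}
```

For example, `chunk_and_map(&[1u64, 2, 3, 4, 5], 2, |c| c.iter().sum::<u64>())` sums each chunk independently. Removing the `Sync` bound on `T` or on the closure makes the `spawn` call fail to compile, which is the same constraint rayon's `par_chunks` imposes.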
