Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 91 additions & 5 deletions crypto/math/src/fft/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,40 @@ impl<E: IsField> Polynomial<FieldElement<E>> {
evaluate_fft_cpu::<F, E>(&coeffs)
}

/// Same as `evaluate_fft` but returns the evaluations in bit-reversed order,
/// skipping the final natural-order permutation. Use when the consumer expects
/// bit-reversed input (e.g. FRI commit phase, which pairs consecutive values as
/// {f(x), f(-x)}).
///
/// Validation is stricter than in `evaluate_fft`: a non-power-of-two `len` is
/// rejected even for empty polynomials. In `evaluate_fft` that check lives inside
/// `evaluate_fft_cpu`, which the empty-polynomial fast path never reaches.
pub fn evaluate_fft_bit_reversed<F: IsFFTField + IsSubFieldOf<E>>(
poly: &Polynomial<FieldElement<E>>,
blowup_factor: usize,
domain_size: Option<usize>,
) -> Result<Vec<FieldElement<E>>, FFTError>
where
E: Send + Sync,
{
let domain_size = domain_size.unwrap_or(0);
let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Low – blowup_factor = 0 produces a confusing error

let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor;

When blowup_factor = 0, len = 0. On a 64-bit target 0usize.trailing_zeros() returns 64, so the first guard fires and returns DomainSizeError(64) — a misleading message for a caller that passed a zero blowup.

evaluate_fft has the same behaviour, so this isn't a regression, but since evaluate_fft_bit_reversed is a new public API it's a good place to add an early guard:

Suggested change
let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor;
if blowup_factor == 0 {
return Err(FFTError::InputError(0));
}
let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor;

if len.trailing_zeros() as u64 > F::TWO_ADICITY {
return Err(FFTError::DomainSizeError(len.trailing_zeros() as usize));
}
if !len.is_power_of_two() {
Comment thread
gabrielbosio marked this conversation as resolved.
return Err(FFTError::InputError(len));
}

if poly.coefficients().is_empty() {
return Ok(vec![FieldElement::zero(); len]);
}

let mut coeffs = poly.coefficients().to_vec();
coeffs.resize(len, FieldElement::zero());
evaluate_fft_cpu_inner::<F, E>(coeffs, false)
}

/// Returns `N` evaluations with an offset of this polynomial using FFT over a domain in a subfield F of E
/// (so the results are P(offset * w^i), with w being a primitive root of unity).
/// `N = max(self.coeff_len(), domain_size).next_power_of_two() * blowup_factor`.
Expand Down Expand Up @@ -278,7 +312,10 @@ where
Polynomial::interpolate_fft::<F>(values.as_slice()).unwrap()
}

pub fn evaluate_fft_cpu<F, E>(coeffs: &[FieldElement<E>]) -> Result<Vec<FieldElement<E>>, FFTError>
fn evaluate_fft_cpu_inner<F, E>(
mut coeffs: Vec<FieldElement<E>>,
permute_after: bool,
) -> Result<Vec<FieldElement<E>>, FFTError>
where
F: IsFFTField + IsSubFieldOf<E>,
E: IsField + Send + Sync,
Expand All @@ -291,10 +328,19 @@ where
let layer_twiddles =
LayerTwiddles::<F>::new(order).ok_or(FFTError::DomainSizeError(order as usize))?;

let mut result = coeffs.to_vec();
dispatch_fft(&mut result, &layer_twiddles)?;
in_place_bit_reverse_permute(&mut result);
Ok(result)
dispatch_fft(&mut coeffs, &layer_twiddles)?;
if permute_after {
in_place_bit_reverse_permute(&mut coeffs);
}
Ok(coeffs)
}

/// Evaluates the polynomial whose coefficients are `coeffs` using the CPU FFT.
///
/// The input slice is left untouched: it is copied into an owned buffer that is
/// handed to `evaluate_fft_cpu_inner` with `permute_after = true`, so the
/// bit-reverse permutation is applied to the result after the transform.
///
/// # Errors
///
/// Returns whatever `FFTError` `evaluate_fft_cpu_inner` reports.
pub fn evaluate_fft_cpu<F, E>(coeffs: &[FieldElement<E>]) -> Result<Vec<FieldElement<E>>, FFTError>
where
    F: IsFFTField + IsSubFieldOf<E>,
    E: IsField + Send + Sync,
{
    let owned_coeffs = coeffs.to_vec();
    evaluate_fft_cpu_inner::<F, E>(owned_coeffs, true)
}

pub fn interpolate_fft_cpu<F, E>(
Expand Down Expand Up @@ -514,4 +560,44 @@ mod tests {
assert_eq!(reference, buffer, "Mismatch for seed {}", seed);
}
}

/// `evaluate_fft_bit_reversed` must return exactly what `evaluate_fft` returns
/// once that output is passed through `in_place_bit_reverse_permute`.
#[test]
fn evaluate_fft_bit_reversed_matches_evaluate_fft_then_permute() {
    // Polynomials with 2, 4, ..., 256 coefficients.
    for order in 1..=8 {
        let num_coeffs = 1usize << order;
        let coeffs: Vec<FE> = (0..num_coeffs)
            .map(|i| FE::from((i * 17 + 3) as u64))
            .collect();
        let poly = Polynomial::new(&coeffs);

        for blowup_factor in [1, 2, 4] {
            for domain_mult in [1, 2] {
                let requested_domain = Some(num_coeffs * domain_mult);

                let bit_reversed = Polynomial::evaluate_fft_bit_reversed::<F>(
                    &poly,
                    blowup_factor,
                    requested_domain,
                )
                .unwrap();

                // Reference: natural-order evaluation, permuted afterwards.
                let mut reference =
                    Polynomial::evaluate_fft::<F>(&poly, blowup_factor, requested_domain).unwrap();
                in_place_bit_reverse_permute(&mut reference);

                assert_eq!(
                    bit_reversed, reference,
                    "order={order}, blowup={blowup_factor}, domain_mult={domain_mult}"
                );
            }
        }
    }
}

/// A blowup factor of 3 must be rejected with `FFTError::InputError`, both for
/// a populated polynomial and for an empty one.
#[test]
fn evaluate_fft_bit_reversed_rejects_non_power_of_two_blowup() {
    // Eight coefficients with an explicit domain of 8.
    let eight_coeffs: Vec<FE> = (0..8).map(|i| FE::from(i as u64)).collect();
    let populated = Polynomial::new(&eight_coeffs);
    let populated_outcome = Polynomial::evaluate_fft_bit_reversed::<F>(&populated, 3, Some(8));
    assert!(matches!(
        populated_outcome,
        Err(FFTError::InputError(_))
    ));

    // No coefficients and no explicit domain.
    let no_coeffs = Polynomial::<FE>::new(&[]);
    let empty_outcome = Polynomial::evaluate_fft_bit_reversed::<F>(&no_coeffs, 3, None);
    assert!(matches!(empty_outcome, Err(FFTError::InputError(_))));
}
Comment thread
gabrielbosio marked this conversation as resolved.
}
29 changes: 16 additions & 13 deletions crypto/stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,11 @@ pub trait IsStarkProver<
/// Builds a Merkle tree commitment from column-major LDE evaluations with
/// bit-reverse permutation, without cloning the full evaluation matrix.
///
/// For each row index `i`, we hash `col_0[br(i)] || col_1[br(i)] || ...`
/// where `br(i)` is the bit-reversal of `i`. This produces the same Merkle
/// tree as the old clone + bit-reverse + columns2rows + batch_commit flow,
/// but avoids allocating the cloned and transposed matrices entirely.
/// Hashes `col_0[k] || col_1[k] || ...` for k = 0..num_rows (sequential column
/// reads, cache-friendly), then permutes the hash vector in bit-reversed order
/// so leaves[i] = hash(col_0[br(i)] || col_1[br(i)] || ...). Same Merkle tree
/// as reading at br(row_idx) inside the hashing loop, but the scattered column
/// access is replaced by a single small bit-reverse pass over 32-byte digests.
fn commit_columns_bit_reversed<E>(
columns: &[Vec<FieldElement<E>>],
) -> Option<(BatchedMerkleTree<E>, Commitment)>
Expand Down Expand Up @@ -392,21 +393,20 @@ pub trait IsStarkProver<
#[cfg(not(feature = "parallel"))]
let iter = 0..num_rows;

// One allocation per row (was one per field element): write all columns
// into a single buffer, then hash once.
let hashed_leaves: Vec<Commitment> = iter
.map(|row_idx| {
let br_idx = reverse_index(row_idx, num_rows as u64);
let mut hashed_leaves: Vec<Commitment> = iter
.map(|k| {
let total_bytes = num_cols * byte_len;
let mut buf = vec![0u8; total_bytes];
for col_idx in 0..num_cols {
columns[col_idx][br_idx]
columns[col_idx][k]
.write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]);
}
BatchedMerkleTreeBackend::<E>::hash_bytes(&buf)
})
.collect();

in_place_bit_reverse_permute(&mut hashed_leaves);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Low – Sequential permute after parallel hashing

When the parallel feature is enabled, the iter.map(...).collect() runs on a Rayon thread pool, but in_place_bit_reverse_permute is always sequential. For large num_rows this single-threaded pass could become a meaningful bottleneck relative to the parallelized hash work above it.

The bit-reverse permutation is a small, branch-heavy scatter/gather that parallelises poorly at typical domain sizes, so this is only worth addressing if profiling shows it as a hot spot. Just noting it so the trade-off is explicit.


let tree = BatchedMerkleTree::<E>::build_from_hashed_leaves(hashed_leaves)?;
let root = tree.root;
Comment thread
gabrielbosio marked this conversation as resolved.
Some((tree, root))
Expand Down Expand Up @@ -1081,9 +1081,12 @@ pub trait IsStarkProver<
let t_sub = Instant::now();
let deep_poly =
Polynomial::interpolate_fft::<Field>(&deep_evals).expect("iFFT should succeed");
let mut lde_evals = Polynomial::evaluate_fft::<Field>(&deep_poly, 1, Some(domain_size))
.expect("FFT should succeed");
in_place_bit_reverse_permute(&mut lde_evals);
// FRI commit_phase consumes bit-reversed evaluations natively. Request them
// directly from evaluate_fft_bit_reversed to avoid a pair of redundant permutes
// (evaluate_fft's internal natural-order permute + an external re-bit-reverse).
let lde_evals =
Polynomial::evaluate_fft_bit_reversed::<Field>(&deep_poly, 1, Some(domain_size))
.expect("FFT should succeed");
#[cfg(feature = "instruments")]
let r4_fft_dur = t_sub.elapsed();

Expand Down
37 changes: 37 additions & 0 deletions crypto/stark/src/tests/prover_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,40 @@ fn test_decompose_and_extend_d2_matches_original() {
assert_eq!(new_result[1][i], original[1][i], "H₁ mismatch at index {i}");
}
}

/// The root returned by `commit_columns_bit_reversed` must equal the root of a
/// tree whose leaf `i` is the hash of row `reverse_index(i, num_rows)` read
/// across all columns.
#[test]
fn commit_columns_bit_reversed_matches_naive_reference() {
    use crate::config::{BatchedMerkleTree, BatchedMerkleTreeBackend, Commitment};
    use math::fft::cpu::bit_reversing::reverse_index;
    use math::traits::ByteConversion;

    let num_rows = 8usize;
    let num_cols = 3usize;

    // Column `c` holds c * num_rows + 1 ..= (c + 1) * num_rows, so every cell is distinct.
    let mut columns: Vec<Vec<Felt>> = Vec::with_capacity(num_cols);
    for c in 0..num_cols {
        let first = c * num_rows + 1;
        columns.push(
            (first..first + num_rows)
                .map(|value| Felt::from(value as u64))
                .collect(),
        );
    }

    let byte_len = <Felt as ByteConversion>::BYTE_LEN;
    let row_bytes = num_cols * byte_len;

    // Naive reference: gather the bit-reversed row inside the hashing loop.
    let mut reference_leaves: Vec<Commitment> = Vec::with_capacity(num_rows);
    for leaf_idx in 0..num_rows {
        let source_row = reverse_index(leaf_idx, num_rows as u64);
        let mut row_buf = vec![0u8; row_bytes];
        // One `byte_len`-sized slot per column, in column order.
        for (column, slot) in columns.iter().zip(row_buf.chunks_exact_mut(byte_len)) {
            column[source_row].write_bytes_be(slot);
        }
        reference_leaves.push(BatchedMerkleTreeBackend::<GoldilocksField>::hash_bytes(&row_buf));
    }
    let reference_tree =
        BatchedMerkleTree::<GoldilocksField>::build_from_hashed_leaves(reference_leaves).unwrap();

    let (_, actual_root) =
        Prover::<GoldilocksField, GoldilocksField, ()>::commit_columns_bit_reversed(&columns)
            .unwrap();

    assert_eq!(reference_tree.root, actual_root);
}
Loading