diff --git a/crypto/math/src/fft/polynomial.rs b/crypto/math/src/fft/polynomial.rs index 9157903fd..8b90cbd1b 100644 --- a/crypto/math/src/fft/polynomial.rs +++ b/crypto/math/src/fft/polynomial.rs @@ -80,6 +80,40 @@ impl Polynomial> { evaluate_fft_cpu::(&coeffs) } + /// Same as `evaluate_fft` but returns the evaluations in bit-reversed order, + /// skipping the final natural-order permutation. Use when the consumer expects + /// bit-reversed input (e.g. FRI commit phase, which pairs consecutive values as + /// {f(x), f(-x)}). + /// + /// Stricter validation than `evaluate_fft`: non-power-of-two `len` is rejected + /// even for empty polynomials. `evaluate_fft`'s guard lives inside + /// `evaluate_fft_cpu`, which its empty-polynomial fast path skips. + pub fn evaluate_fft_bit_reversed>( + poly: &Polynomial>, + blowup_factor: usize, + domain_size: Option, + ) -> Result>, FFTError> + where + E: Send + Sync, + { + let domain_size = domain_size.unwrap_or(0); + let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor; + if len.trailing_zeros() as u64 > F::TWO_ADICITY { + return Err(FFTError::DomainSizeError(len.trailing_zeros() as usize)); + } + if !len.is_power_of_two() { + return Err(FFTError::InputError(len)); + } + + if poly.coefficients().is_empty() { + return Ok(vec![FieldElement::zero(); len]); + } + + let mut coeffs = poly.coefficients().to_vec(); + coeffs.resize(len, FieldElement::zero()); + evaluate_fft_cpu_inner::(coeffs, false) + } + /// Returns `N` evaluations with an offset of this polynomial using FFT over a domain in a subfield F of E /// (so the results are P(w^i), with w being a primitive root of unity). /// `N = max(self.coeff_len(), domain_size).next_power_of_two() * blowup_factor`. @@ -278,7 +312,10 @@ where Polynomial::interpolate_fft::(values.as_slice()).unwrap() } -pub fn evaluate_fft_cpu(coeffs: &[FieldElement]) -> Result>, FFTError> +fn evaluate_fft_cpu_inner( + mut coeffs: Vec>, + permute_after: bool, +) -> Result>, FFTError> where F: IsFFTField + IsSubFieldOf, E: IsField + Send + Sync, @@ -291,10 +328,19 @@ where let layer_twiddles = LayerTwiddles::::new(order).ok_or(FFTError::DomainSizeError(order as usize))?; - let mut result = coeffs.to_vec(); - dispatch_fft(&mut result, &layer_twiddles)?; - in_place_bit_reverse_permute(&mut result); - Ok(result) + dispatch_fft(&mut coeffs, &layer_twiddles)?; + if permute_after { + in_place_bit_reverse_permute(&mut coeffs); + } + Ok(coeffs) +} + +pub fn evaluate_fft_cpu(coeffs: &[FieldElement]) -> Result>, FFTError> +where + F: IsFFTField + IsSubFieldOf, + E: IsField + Send + Sync, +{ + evaluate_fft_cpu_inner::(coeffs.to_vec(), true) } pub fn interpolate_fft_cpu( @@ -514,4 +560,44 @@ mod tests { assert_eq!(reference, buffer, "Mismatch for seed {}", seed); } } + + #[test] + fn evaluate_fft_bit_reversed_matches_evaluate_fft_then_permute() { + for order in 1..=8 { + let n = 1usize << order; + let coeffs: Vec = (0..n).map(|i| FE::from((i * 17 + 3) as u64)).collect(); + let poly = Polynomial::new(&coeffs); + + for blowup_factor in [1, 2, 4] { + for domain_mult in [1, 2] { + let size = Some(n * domain_mult); + let mut expected = + Polynomial::evaluate_fft::(&poly, blowup_factor, size).unwrap(); + in_place_bit_reverse_permute(&mut expected); + + let got = + Polynomial::evaluate_fft_bit_reversed::(&poly, blowup_factor, size) + .unwrap(); + + assert_eq!( + got, expected, + "order={order}, blowup={blowup_factor}, domain_mult={domain_mult}" + ); + } + } + } + } + + #[test] + fn evaluate_fft_bit_reversed_rejects_non_power_of_two_blowup() { + let coeffs: Vec = (0..8).map(|i| FE::from(i as u64)).collect(); + let poly = Polynomial::new(&coeffs); + + let err = Polynomial::evaluate_fft_bit_reversed::(&poly, 3, Some(8)); + assert!(matches!(err, Err(FFTError::InputError(_)))); + + let empty = Polynomial::::new(&[]); + let err = Polynomial::evaluate_fft_bit_reversed::(&empty, 3, None); + assert!(matches!(err, Err(FFTError::InputError(_)))); + } } diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 41ccb8366..d418f7773 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -361,10 +361,11 @@ pub trait IsStarkProver< /// Builds a Merkle tree commitment from column-major LDE evaluations with /// bit-reverse permutation, without cloning the full evaluation matrix. /// - /// For each row index `i`, we hash `col_0[br(i)] || col_1[br(i)] || ...` - /// where `br(i)` is the bit-reversal of `i`. This produces the same Merkle - /// tree as the old clone + bit-reverse + columns2rows + batch_commit flow, - /// but avoids allocating the cloned and transposed matrices entirely. + /// Hashes `col_0[k] || col_1[k] || ...` for k = 0..num_rows (sequential column + /// reads, cache-friendly), then permutes the hash vector in bit-reversed order + /// so leaves[i] = hash(col_0[br(i)] || col_1[br(i)] || ...). Same Merkle tree + /// as reading at br(row_idx) inside the hashing loop, but the scattered column + /// access is replaced by a single small bit-reverse pass over 32-byte digests. fn commit_columns_bit_reversed( columns: &[Vec>], ) -> Option<(BatchedMerkleTree, Commitment)> @@ -392,21 +393,20 @@ pub trait IsStarkProver< #[cfg(not(feature = "parallel"))] let iter = 0..num_rows; - // One allocation per row (was one per field element): write all columns - // into a single buffer, then hash once. - let hashed_leaves: Vec = iter - .map(|row_idx| { - let br_idx = reverse_index(row_idx, num_rows as u64); + let mut hashed_leaves: Vec = iter + .map(|k| { let total_bytes = num_cols * byte_len; let mut buf = vec![0u8; total_bytes]; for col_idx in 0..num_cols { - columns[col_idx][br_idx] + columns[col_idx][k] .write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); } BatchedMerkleTreeBackend::::hash_bytes(&buf) }) .collect(); + in_place_bit_reverse_permute(&mut hashed_leaves); + let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) @@ -1081,9 +1081,12 @@ pub trait IsStarkProver< let t_sub = Instant::now(); let deep_poly = Polynomial::interpolate_fft::(&deep_evals).expect("iFFT should succeed"); - let mut lde_evals = Polynomial::evaluate_fft::(&deep_poly, 1, Some(domain_size)) - .expect("FFT should succeed"); - in_place_bit_reverse_permute(&mut lde_evals); + // FRI commit_phase consumes bit-reversed evaluations natively. Request them + // directly from evaluate_fft_bit_reversed to avoid a pair of redundant permutes + // (evaluate_fft's internal natural-order permute + an external re-bit-reverse). + let lde_evals = + Polynomial::evaluate_fft_bit_reversed::(&deep_poly, 1, Some(domain_size)) + .expect("FFT should succeed"); #[cfg(feature = "instruments")] let r4_fft_dur = t_sub.elapsed(); diff --git a/crypto/stark/src/tests/prover_tests.rs b/crypto/stark/src/tests/prover_tests.rs index 55d58da7e..d0e7fe8d1 100644 --- a/crypto/stark/src/tests/prover_tests.rs +++ b/crypto/stark/src/tests/prover_tests.rs @@ -233,3 +233,40 @@ fn test_decompose_and_extend_d2_matches_original() { assert_eq!(new_result[1][i], original[1][i], "H₁ mismatch at index {i}"); } } + +#[test] +fn commit_columns_bit_reversed_matches_naive_reference() { + use crate::config::{BatchedMerkleTree, BatchedMerkleTreeBackend, Commitment}; + use math::fft::cpu::bit_reversing::reverse_index; + use math::traits::ByteConversion; + + let num_rows = 8usize; + let num_cols = 3usize; + let columns: Vec> = (0..num_cols) + .map(|c| { + (0..num_rows) + .map(|r| Felt::from((c * num_rows + r + 1) as u64)) + .collect() + }) + .collect(); + + let byte_len = ::BYTE_LEN; + let reference_leaves: Vec = (0..num_rows) + .map(|i| { + let br_i = reverse_index(i, num_rows as u64); + let mut buf = vec![0u8; num_cols * byte_len]; + for c in 0..num_cols { + columns[c][br_i].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) + .collect(); + let reference_tree = + BatchedMerkleTree::::build_from_hashed_leaves(reference_leaves).unwrap(); + + let (_tree, root) = + Prover::::commit_columns_bit_reversed(&columns) + .unwrap(); + + assert_eq!(reference_tree.root, root); +}