From 41316ff0e392ad4841886a42de4290a7cb6eb850 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 15:42:20 -0300 Subject: [PATCH 01/17] docs: add prover parallelism improvement plan (6 tasks) --- .../plans/2026-04-21-prover-parallelism.md | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 docs/superpowers/plans/2026-04-21-prover-parallelism.md diff --git a/docs/superpowers/plans/2026-04-21-prover-parallelism.md b/docs/superpowers/plans/2026-04-21-prover-parallelism.md new file mode 100644 index 000000000..84371d635 --- /dev/null +++ b/docs/superpowers/plans/2026-04-21-prover-parallelism.md @@ -0,0 +1,55 @@ +# Prover Parallelism Improvements + +**Goal:** Close the ~2x performance gap with Plonky3 by parallelizing the hottest sequential code paths in the STARK prover. + +**Architecture:** Five independent parallelism improvements. No protocol changes, no new data structures. Each produces identical outputs but runs faster. All use Rayon gated behind `#[cfg(feature = "parallel")]`. + +--- + +## Task 1: Parallel FRI fold + +**File:** `crypto/stark/src/fri/fri_functions.rs` + +`fold_evaluations_in_place` is a plain `for j in 0..half` loop. At layer 0 with LDE 2^21, ~1M sequential iterations. Each output depends only on its own pair -- no cross-iteration dependency. Use `into_par_iter` with a temp buffer (aliasing prevents in-place parallel write). + +## Task 2: Parallel FRI leaf construction + +**File:** `crypto/stark/src/fri/mod.rs` + +Leaf array `evals.chunks_exact(2).map(...).collect()` is sequential. Use `par_chunks_exact`. + +## Task 3: Parallel LogUp fingerprint computation + +**File:** `crypto/stark/src/lookup.rs` + +Two fingerprint loops in `compute_logup_batched_term_column` (lines 1619-1650): sequential over trace_len rows. Each row reads shared immutable data. Use `into_par_iter`. Also parallelize the final term computation loop and `compute_logup_term_column`. + +## Task 4: Parallel table_contribution sum + +**File:** `crypto/stark/src/lookup.rs` + +`build_accumulated_column_from_terms` sums term columns across all rows sequentially. Use `into_par_iter` with `reduce`. Note: the accumulated column running-sum loop CANNOT be parallelized. + +## Task 5: Chunked parallel batch inverse + +**Files:** `crypto/math/src/field/element.rs`, `crypto/stark/src/lookup.rs` + +Montgomery batch inverse is sequential. Split into K=num_threads chunks, run one independent batch inverse per chunk via `par_chunks_mut`. Cost: K-1 extra inversions, but O(N/K) per thread. Threshold at 1024 elements. + +## Task 6: Benchmark and validate + +Run `cargo bench --bench profile_vm_prover --features "parallel,instruments"`. Compare against baseline. Push. + +--- + +## Expected Impact + +| Optimization | Sequential cost | Parallel | Speedup | +|---|---|---|---| +| FRI fold | O(N) ext-field per layer | O(N/P) | ~P | +| FRI leaves | O(N) clones per layer | O(N/P) | ~P | +| Fingerprint loops | O(N) F*E per pair | O(N/P) | ~P | +| Batch inverse | O(N) prefix-suffix | O(N/P + K inv) | ~P large N | +| table_contribution | O(N*cols) | O(N*cols/P) | ~P | + +These target ~20-30% of prover time. Combined with MMCS/shared-FRI (~30-40%), closes most of the 2x gap. From 1451d7ce1e4c5216959339caa595d1bedd10d5e9 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 15:49:39 -0300 Subject: [PATCH 02/17] perf: parallelize LogUp fingerprint computation with rayon --- crypto/stark/src/lookup.rs | 256 +++++++++++++++++++++++++------------ 1 file changed, 175 insertions(+), 81 deletions(-) diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 80d40a78c..5d7332265 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -1463,45 +1463,67 @@ where // into the fingerprint without collecting bus elements into intermediate Vecs. let bus_id_f = FieldElement::::from(table_interaction.bus_id); let shifts = PackingShifts::::new(); - let mut fingerprints: Vec> = Vec::with_capacity(trace_len); - for row in 0..trace_len { - // Accumulate fingerprint directly: bus_id * α^0 + Σ element_i * α^(1+i) - let mut linear_combination = &bus_id_f * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &table_interaction.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut linear_combination, - &shifts, - ); - alpha_offset += consumed; - } - fingerprints.push(z - &linear_combination); + // Fingerprint computation is embarrassingly parallel: each row reads shared + // immutable data (main_segment_cols, alpha_powers) with no cross-row dependencies. + #[cfg(feature = "parallel")] + let mut fingerprints: Vec> = (0..trace_len) + .into_par_iter() + .map(|row| { + let mut linear_combination = &bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &table_interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut linear_combination, + &shifts, + ); + alpha_offset += consumed; + } + z - &linear_combination + }) + .collect(); + #[cfg(not(feature = "parallel"))] + let mut fingerprints: Vec> = (0..trace_len) + .map(|row| { + let mut linear_combination = &bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &table_interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut linear_combination, + &shifts, + ); + alpha_offset += consumed; + } + z - &linear_combination + }) + .collect(); - #[cfg(feature = "debug-checks")] - { - // Reconstruct base_elements for debug logging - let mut base_elements: Vec> = vec![bus_id_f.clone()]; - base_elements.extend( - table_interaction - .values - .iter() - .flat_map(|bv| bv.combine_from(|col| main_segment_cols[col][row].clone())), - ); - crate::bus_debug::log_interaction( - _table_name, - row, - table_interaction.bus_id, - table_interaction.is_sender, - &multiplicities[row].canonical(), - &base_elements, - fingerprints.last().unwrap(), - ); - } + #[cfg(feature = "debug-checks")] + for row in 0..trace_len { + let mut base_elements: Vec> = vec![bus_id_f.clone()]; + base_elements.extend( + table_interaction + .values + .iter() + .flat_map(|bv| bv.combine_from(|col| main_segment_cols[col][row].clone())), + ); + crate::bus_debug::log_interaction( + _table_name, + row, + table_interaction.bus_id, + table_interaction.is_sender, + &multiplicities[row].canonical(), + &base_elements, + &fingerprints[row], + ); } FieldElement::inplace_batch_inverse(&mut fingerprints) @@ -1614,40 +1636,94 @@ where let shifts = PackingShifts::::new(); // Concatenate both fingerprint vectors for a single batch inversion - let mut all_fingerprints: Vec> = Vec::with_capacity(2 * trace_len); + // + // Fingerprint computation is embarrassingly parallel: each row reads shared + // immutable data (main_segment_cols, alpha_powers) with no cross-row dependencies. + #[cfg(feature = "parallel")] + let fingerprints_a: Vec> = (0..trace_len) + .into_par_iter() + .map(|row| { + let mut lc_a = &bus_id_a * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction_a.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc_a, + &shifts, + ); + alpha_offset += consumed; + } + z - &lc_a + }) + .collect(); + #[cfg(not(feature = "parallel"))] + let fingerprints_a: Vec> = (0..trace_len) + .map(|row| { + let mut lc_a = &bus_id_a * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction_a.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc_a, + &shifts, + ); + alpha_offset += consumed; + } + z - &lc_a + }) + .collect(); - for row in 0..trace_len { - let mut lc_a = &bus_id_a * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_a.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_a, - &shifts, - ); - alpha_offset += consumed; - } - all_fingerprints.push(z - &lc_a); - } - for row in 0..trace_len { - let mut lc_b = &bus_id_b * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_b.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_b, - &shifts, - ); - alpha_offset += consumed; - } - all_fingerprints.push(z - &lc_b); - } + #[cfg(feature = "parallel")] + let fingerprints_b: Vec> = (0..trace_len) + .into_par_iter() + .map(|row| { + let mut lc_b = &bus_id_b * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction_b.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc_b, + &shifts, + ); + alpha_offset += consumed; + } + z - &lc_b + }) + .collect(); + #[cfg(not(feature = "parallel"))] + let fingerprints_b: Vec> = (0..trace_len) + .map(|row| { + let mut lc_b = &bus_id_b * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction_b.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc_b, + &shifts, + ); + alpha_offset += consumed; + } + z - &lc_b + }) + .collect(); + + // Concatenate into single vec for batch inversion + let mut all_fingerprints: Vec> = + Vec::with_capacity(2 * trace_len); + all_fingerprints.extend(fingerprints_a); + all_fingerprints.extend(fingerprints_b); // Single batch inversion for all 2*N fingerprints FieldElement::inplace_batch_inverse(&mut all_fingerprints) @@ -1655,17 +1731,35 @@ where // Compute batched terms: term[i] = m_a[i] / fp_a[i] ± m_b[i] / fp_b[i] // Use conditional negation instead of E×E sign multiplication - (0..trace_len) - .map(|row| { - let fp_a_inv = &all_fingerprints[row]; - let fp_b_inv = &all_fingerprints[trace_len + row]; - let term_a = &multiplicities_a[row] * fp_a_inv; - let term_b = &multiplicities_b[row] * fp_b_inv; - let term_a = if negate_a { -term_a } else { term_a }; - let term_b = if negate_b { -term_b } else { term_b }; - term_a + term_b - }) - .collect() + #[cfg(feature = "parallel")] + { + (0..trace_len) + .into_par_iter() + .map(|row| { + let fp_a_inv = &all_fingerprints[row]; + let fp_b_inv = &all_fingerprints[trace_len + row]; + let term_a = &multiplicities_a[row] * fp_a_inv; + let term_b = &multiplicities_b[row] * fp_b_inv; + let term_a = if negate_a { -term_a } else { term_a }; + let term_b = if negate_b { -term_b } else { term_b }; + term_a + term_b + }) + .collect() + } + #[cfg(not(feature = "parallel"))] + { + (0..trace_len) + .map(|row| { + let fp_a_inv = &all_fingerprints[row]; + let fp_b_inv = &all_fingerprints[trace_len + row]; + let term_a = &multiplicities_a[row] * fp_a_inv; + let term_b = &multiplicities_b[row] * fp_b_inv; + let term_a = if negate_a { -term_a } else { term_a }; + let term_b = if negate_b { -term_b } else { term_b }; + term_a + term_b + }) + .collect() + } } /// Builds the circular accumulated column from pre-computed term columns. From 3ab5504ba5fac636be508cf977a05b687a613c77 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 15:51:40 -0300 Subject: [PATCH 03/17] perf: parallelize table_contribution sum with rayon reduce --- crypto/stark/src/lookup.rs | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 5d7332265..dd5032e2b 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -1784,12 +1784,31 @@ where let trace_len = term_columns[0].len(); // Compute L = sum of all terms across all rows - let mut table_contribution = FieldElement::::zero(); - for row in 0..trace_len { - for col in term_columns { - table_contribution = &table_contribution + &col[row]; + #[cfg(feature = "parallel")] + let table_contribution: FieldElement = { + use rayon::prelude::*; + (0..trace_len) + .into_par_iter() + .map(|row| { + let mut row_sum = FieldElement::::zero(); + for col in term_columns { + row_sum = row_sum + &col[row]; + } + row_sum + }) + .reduce(FieldElement::zero, |a, b| a + b) + }; + + #[cfg(not(feature = "parallel"))] + let table_contribution = { + let mut total = FieldElement::::zero(); + for row in 0..trace_len { + for col in term_columns { + total = &total + &col[row]; + } } - } + total + }; // offset_per_row = L / N let n = FieldElement::::from(trace_len as u64); From 60c925556cbf09a43ab2030e08f2ff5623ab13c6 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 15:54:20 -0300 Subject: [PATCH 04/17] perf: chunked parallel batch inverse for LogUp fingerprints Add `par_batch_inverse` to `FieldElement` (math crate, parallel feature) that splits the input into per-thread chunks and runs one Montgomery batch inversion per chunk, trading K extra inversions for O(N/K) sequential work per thread. Falls back to sequential for inputs < 1024. Use it in `compute_logup_term_column` and `compute_logup_batched_term_column` in lookup.rs (guarded by #[cfg(feature = "parallel")]). Also add `math/parallel` to stark's `parallel` feature so the new method is visible when stark is compiled with parallelism enabled. --- crypto/math/src/field/element.rs | 19 +++++++++++++++++++ crypto/stark/Cargo.toml | 2 +- crypto/stark/src/lookup.rs | 8 ++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/crypto/math/src/field/element.rs b/crypto/math/src/field/element.rs index 9c2ac3258..e592c0b6e 100644 --- a/crypto/math/src/field/element.rs +++ b/crypto/math/src/field/element.rs @@ -71,6 +71,25 @@ impl FieldElement { Ok(()) } + /// Parallel batch inverse: splits into chunks, one Montgomery inversion per chunk. + /// Cost: K extra field inversions (one per chunk) but O(N/K) sequential work per thread. + /// Falls back to sequential for small inputs. + #[cfg(feature = "parallel")] + pub fn par_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> + where + Self: Send + Sync, + { + if numbers.len() < 1024 { + return Self::inplace_batch_inverse(numbers); + } + use rayon::prelude::*; + let num_chunks = rayon::current_num_threads().min(numbers.len() / 256); + let chunk_size = (numbers.len() + num_chunks - 1) / num_chunks; + numbers + .par_chunks_mut(chunk_size) + .try_for_each(|chunk| Self::inplace_batch_inverse(chunk)) + } + #[inline(always)] pub fn to_subfield_vec(self) -> alloc::vec::Vec> where diff --git a/crypto/stark/Cargo.toml b/crypto/stark/Cargo.toml index 53b205996..3641e7821 100644 --- a/crypto/stark/Cargo.toml +++ b/crypto/stark/Cargo.toml @@ -38,7 +38,7 @@ test-utils = [] test_fiat_shamir = [] instruments = [] # This enables timing prints in prover and verifier debug-checks = [] # Enables validate_trace + bus balance report in prover -parallel = ["dep:rayon", "crypto/parallel"] +parallel = ["dep:rayon", "crypto/parallel", "math/parallel"] wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"] diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index dd5032e2b..282bcae9d 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -1526,6 +1526,10 @@ where ); } + #[cfg(feature = "parallel")] + FieldElement::par_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + #[cfg(not(feature = "parallel"))] FieldElement::inplace_batch_inverse(&mut fingerprints) .expect("fingerprint is zero - probability of sampling zero is negligible"); @@ -1726,6 +1730,10 @@ where all_fingerprints.extend(fingerprints_b); // Single batch inversion for all 2*N fingerprints + #[cfg(feature = "parallel")] + FieldElement::par_batch_inverse(&mut all_fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + #[cfg(not(feature = "parallel"))] FieldElement::inplace_batch_inverse(&mut all_fingerprints) .expect("fingerprint is zero - probability of sampling zero is negligible"); From 35f8614e01de08ba19f76473dce1e518f47dd8fb Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 16:03:56 -0300 Subject: [PATCH 05/17] revert: undo LogUp parallelization (caused +3.9% regression due to nested Rayon over-subscription) --- crypto/math/src/field/element.rs | 19 -- crypto/stark/Cargo.toml | 2 +- crypto/stark/src/lookup.rs | 293 +++++++++---------------------- 3 files changed, 87 insertions(+), 227 deletions(-) diff --git a/crypto/math/src/field/element.rs b/crypto/math/src/field/element.rs index e592c0b6e..9c2ac3258 100644 --- a/crypto/math/src/field/element.rs +++ b/crypto/math/src/field/element.rs @@ -71,25 +71,6 @@ impl FieldElement { Ok(()) } - /// Parallel batch inverse: splits into chunks, one Montgomery inversion per chunk. - /// Cost: K extra field inversions (one per chunk) but O(N/K) sequential work per thread. - /// Falls back to sequential for small inputs. - #[cfg(feature = "parallel")] - pub fn par_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> - where - Self: Send + Sync, - { - if numbers.len() < 1024 { - return Self::inplace_batch_inverse(numbers); - } - use rayon::prelude::*; - let num_chunks = rayon::current_num_threads().min(numbers.len() / 256); - let chunk_size = (numbers.len() + num_chunks - 1) / num_chunks; - numbers - .par_chunks_mut(chunk_size) - .try_for_each(|chunk| Self::inplace_batch_inverse(chunk)) - } - #[inline(always)] pub fn to_subfield_vec(self) -> alloc::vec::Vec> where diff --git a/crypto/stark/Cargo.toml b/crypto/stark/Cargo.toml index 3641e7821..53b205996 100644 --- a/crypto/stark/Cargo.toml +++ b/crypto/stark/Cargo.toml @@ -38,7 +38,7 @@ test-utils = [] test_fiat_shamir = [] instruments = [] # This enables timing prints in prover and verifier debug-checks = [] # Enables validate_trace + bus balance report in prover -parallel = ["dep:rayon", "crypto/parallel", "math/parallel"] +parallel = ["dep:rayon", "crypto/parallel"] wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"] diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 282bcae9d..80d40a78c 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -1463,73 +1463,47 @@ where // into the fingerprint without collecting bus elements into intermediate Vecs. let bus_id_f = FieldElement::::from(table_interaction.bus_id); let shifts = PackingShifts::::new(); + let mut fingerprints: Vec> = Vec::with_capacity(trace_len); + for row in 0..trace_len { + // Accumulate fingerprint directly: bus_id * α^0 + Σ element_i * α^(1+i) + let mut linear_combination = &bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &table_interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut linear_combination, + &shifts, + ); + alpha_offset += consumed; + } - // Fingerprint computation is embarrassingly parallel: each row reads shared - // immutable data (main_segment_cols, alpha_powers) with no cross-row dependencies. - #[cfg(feature = "parallel")] - let mut fingerprints: Vec> = (0..trace_len) - .into_par_iter() - .map(|row| { - let mut linear_combination = &bus_id_f * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &table_interaction.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut linear_combination, - &shifts, - ); - alpha_offset += consumed; - } - z - &linear_combination - }) - .collect(); - #[cfg(not(feature = "parallel"))] - let mut fingerprints: Vec> = (0..trace_len) - .map(|row| { - let mut linear_combination = &bus_id_f * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &table_interaction.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut linear_combination, - &shifts, - ); - alpha_offset += consumed; - } - z - &linear_combination - }) - .collect(); + fingerprints.push(z - &linear_combination); - #[cfg(feature = "debug-checks")] - for row in 0..trace_len { - let mut base_elements: Vec> = vec![bus_id_f.clone()]; - base_elements.extend( - table_interaction - .values - .iter() - .flat_map(|bv| bv.combine_from(|col| main_segment_cols[col][row].clone())), - ); - crate::bus_debug::log_interaction( - _table_name, - row, - table_interaction.bus_id, - table_interaction.is_sender, - &multiplicities[row].canonical(), - &base_elements, - &fingerprints[row], - ); + #[cfg(feature = "debug-checks")] + { + // Reconstruct base_elements for debug logging + let mut base_elements: Vec> = vec![bus_id_f.clone()]; + base_elements.extend( + table_interaction + .values + .iter() + .flat_map(|bv| bv.combine_from(|col| main_segment_cols[col][row].clone())), + ); + crate::bus_debug::log_interaction( + _table_name, + row, + table_interaction.bus_id, + table_interaction.is_sender, + &multiplicities[row].canonical(), + &base_elements, + fingerprints.last().unwrap(), + ); + } } - #[cfg(feature = "parallel")] - FieldElement::par_batch_inverse(&mut fingerprints) - .expect("fingerprint is zero - probability of sampling zero is negligible"); - #[cfg(not(feature = "parallel"))] FieldElement::inplace_batch_inverse(&mut fingerprints) .expect("fingerprint is zero - probability of sampling zero is negligible"); @@ -1640,134 +1614,58 @@ where let shifts = PackingShifts::::new(); // Concatenate both fingerprint vectors for a single batch inversion - // - // Fingerprint computation is embarrassingly parallel: each row reads shared - // immutable data (main_segment_cols, alpha_powers) with no cross-row dependencies. - #[cfg(feature = "parallel")] - let fingerprints_a: Vec> = (0..trace_len) - .into_par_iter() - .map(|row| { - let mut lc_a = &bus_id_a * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_a.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_a, - &shifts, - ); - alpha_offset += consumed; - } - z - &lc_a - }) - .collect(); - #[cfg(not(feature = "parallel"))] - let fingerprints_a: Vec> = (0..trace_len) - .map(|row| { - let mut lc_a = &bus_id_a * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_a.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_a, - &shifts, - ); - alpha_offset += consumed; - } - z - &lc_a - }) - .collect(); + let mut all_fingerprints: Vec> = Vec::with_capacity(2 * trace_len); - #[cfg(feature = "parallel")] - let fingerprints_b: Vec> = (0..trace_len) - .into_par_iter() - .map(|row| { - let mut lc_b = &bus_id_b * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_b.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_b, - &shifts, - ); - alpha_offset += consumed; - } - z - &lc_b - }) - .collect(); - #[cfg(not(feature = "parallel"))] - let fingerprints_b: Vec> = (0..trace_len) - .map(|row| { - let mut lc_b = &bus_id_b * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_b.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_b, - &shifts, - ); - alpha_offset += consumed; - } - z - &lc_b - }) - .collect(); - - // Concatenate into single vec for batch inversion - let mut all_fingerprints: Vec> = - Vec::with_capacity(2 * trace_len); - all_fingerprints.extend(fingerprints_a); - all_fingerprints.extend(fingerprints_b); + for row in 0..trace_len { + let mut lc_a = &bus_id_a * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction_a.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc_a, + &shifts, + ); + alpha_offset += consumed; + } + all_fingerprints.push(z - &lc_a); + } + for row in 0..trace_len { + let mut lc_b = &bus_id_b * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction_b.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc_b, + &shifts, + ); + alpha_offset += consumed; + } + all_fingerprints.push(z - &lc_b); + } // Single batch inversion for all 2*N fingerprints - #[cfg(feature = "parallel")] - FieldElement::par_batch_inverse(&mut all_fingerprints) - .expect("fingerprint is zero - probability of sampling zero is negligible"); - #[cfg(not(feature = "parallel"))] FieldElement::inplace_batch_inverse(&mut all_fingerprints) .expect("fingerprint is zero - probability of sampling zero is negligible"); // Compute batched terms: term[i] = m_a[i] / fp_a[i] ± m_b[i] / fp_b[i] // Use conditional negation instead of E×E sign multiplication - #[cfg(feature = "parallel")] - { - (0..trace_len) - .into_par_iter() - .map(|row| { - let fp_a_inv = &all_fingerprints[row]; - let fp_b_inv = &all_fingerprints[trace_len + row]; - let term_a = &multiplicities_a[row] * fp_a_inv; - let term_b = &multiplicities_b[row] * fp_b_inv; - let term_a = if negate_a { -term_a } else { term_a }; - let term_b = if negate_b { -term_b } else { term_b }; - term_a + term_b - }) - .collect() - } - #[cfg(not(feature = "parallel"))] - { - (0..trace_len) - .map(|row| { - let fp_a_inv = &all_fingerprints[row]; - let fp_b_inv = &all_fingerprints[trace_len + row]; - let term_a = &multiplicities_a[row] * fp_a_inv; - let term_b = &multiplicities_b[row] * fp_b_inv; - let term_a = if negate_a { -term_a } else { term_a }; - let term_b = if negate_b { -term_b } else { term_b }; - term_a + term_b - }) - .collect() - } + (0..trace_len) + .map(|row| { + let fp_a_inv = &all_fingerprints[row]; + let fp_b_inv = &all_fingerprints[trace_len + row]; + let term_a = &multiplicities_a[row] * fp_a_inv; + let term_b = &multiplicities_b[row] * fp_b_inv; + let term_a = if negate_a { -term_a } else { term_a }; + let term_b = if negate_b { -term_b } else { term_b }; + term_a + term_b + }) + .collect() } /// Builds the circular accumulated column from pre-computed term columns. @@ -1792,31 +1690,12 @@ where let trace_len = term_columns[0].len(); // Compute L = sum of all terms across all rows - #[cfg(feature = "parallel")] - let table_contribution: FieldElement = { - use rayon::prelude::*; - (0..trace_len) - .into_par_iter() - .map(|row| { - let mut row_sum = FieldElement::::zero(); - for col in term_columns { - row_sum = row_sum + &col[row]; - } - row_sum - }) - .reduce(FieldElement::zero, |a, b| a + b) - }; - - #[cfg(not(feature = "parallel"))] - let table_contribution = { - let mut total = FieldElement::::zero(); - for row in 0..trace_len { - for col in term_columns { - total = &total + &col[row]; - } + let mut table_contribution = FieldElement::::zero(); + for row in 0..trace_len { + for col in term_columns { + table_contribution = &table_contribution + &col[row]; } - total - }; + } // offset_per_row = L / N let n = FieldElement::::from(trace_len as u64); From 00b33d426896c7626ec8580b9880008393a4e4e8 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 16:25:19 -0300 Subject: [PATCH 06/17] perf: fused chunk-local LogUp processing with parallel prefix sum Replace the column-parallel LogUp auxiliary trace build (which caused Rayon over-subscription when called from an already-parallel context) with a chunk-local approach inspired by Plonky3. Key changes: - New `compute_logup_batched_term_column_chunked` and `compute_logup_term_column_chunked` functions process rows in chunks of 1024, fusing fingerprint computation + batch inverse + term evaluation per chunk for L2 cache locality - Parallelism is across row-chunks (par_chunks_mut), not across interaction pairs, avoiding nested Rayon over-subscription - New `compute_multiplicity_for_row` helper avoids materializing full Vec per interaction in the chunked path - `build_accumulated_column_from_terms` now uses parallel reduction for table_contribution and 3-phase parallel prefix sum for the accumulated column - Sequential (non-parallel) path unchanged, using original functions - All 121 stark tests pass with and without parallel feature --- crypto/stark/src/lookup.rs | 384 ++++++++++++++++++++++++++++++++++--- 1 file changed, 359 insertions(+), 25 deletions(-) diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 80d40a78c..3bb859ab9 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -19,7 +19,10 @@ use math::field::{ traits::{IsFFTField, IsField, IsPrimeField, IsSubFieldOf}, }; #[cfg(feature = "parallel")] -use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, + ParallelSliceMut, +}; // ============================================================================= // Shift Constants for Type Combining @@ -100,6 +103,11 @@ pub const LOGUP_CHALLENGE_ALPHA: usize = 1; /// Number of challenges required by the LogUp protocol. pub const LOGUP_NUM_CHALLENGES: usize = 2; +/// Chunk size for fused chunk-local LogUp processing. +/// Each chunk processes all interactions for CHUNK_SIZE rows, fitting in L2 cache. +#[cfg(feature = "parallel")] +const LOGUP_CHUNK_SIZE: usize = 1024; + /// Split N interactions into committed batched pairs and absorbed remainder. /// /// Returns `(num_committed_pairs, absorbed_count)` where: @@ -1016,23 +1024,24 @@ where // Clone main columns once (shared across all interactions) let main_segment_cols = trace.columns_main(); let trace_len = trace.num_rows(); - let table_name = self.name.as_deref().unwrap_or("UNKNOWN"); + let _table_name = self.name.as_deref().unwrap_or("UNKNOWN"); // Split interactions: committed pairs get term columns, last 1-2 are absorbed (virtual) let (num_committed_pairs, absorbed_count) = split_interactions(num_interactions); - // Compute committed term columns in parallel (batched pairs only) + // Compute committed term columns (batched pairs only). + // With `parallel`: sequential over pairs, each using chunk-local parallelism + // (parallel across row-chunks, not across pairs) for better cache locality. + // Without `parallel`: sequential over pairs, sequential over rows. #[cfg(feature = "parallel")] let committed_columns: Vec>> = (0..num_committed_pairs) - .into_par_iter() .map(|i| { - compute_logup_batched_term_column( + compute_logup_batched_term_column_chunked( &self.auxiliary_trace_build_data.interactions[i * 2], &self.auxiliary_trace_build_data.interactions[i * 2 + 1], &main_segment_cols, trace_len, challenges, - table_name, ) }) .collect(); @@ -1045,12 +1054,30 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ) }) .collect(); // Compute virtual column for absorbed interactions (NOT written to trace) + #[cfg(feature = "parallel")] + let virtual_column = if absorbed_count == 2 { + compute_logup_batched_term_column_chunked( + &self.auxiliary_trace_build_data.interactions[num_interactions - 2], + &self.auxiliary_trace_build_data.interactions[num_interactions - 1], + &main_segment_cols, + trace_len, + challenges, + ) + } else { + compute_logup_term_column_chunked( + &self.auxiliary_trace_build_data.interactions[num_interactions - 1], + &main_segment_cols, + trace_len, + challenges, + ) + }; + #[cfg(not(feature = "parallel"))] let virtual_column = if absorbed_count == 2 { compute_logup_batched_term_column( &self.auxiliary_trace_build_data.interactions[num_interactions - 2], @@ -1058,7 +1085,7 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ) } else { compute_logup_term_column( @@ -1066,7 +1093,7 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ) }; @@ -1084,7 +1111,7 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ); // Build accumulated from all columns (committed + virtual) @@ -1361,6 +1388,7 @@ where /// This is a pure function that takes shared main columns and returns the computed column, /// enabling parallel computation across interactions within a table. #[allow(clippy::needless_range_loop)] +#[cfg_attr(feature = "parallel", allow(dead_code))] fn compute_logup_term_column( table_interaction: &BusInteraction, main_segment_cols: &[Vec>], @@ -1525,6 +1553,7 @@ where /// /// Uses a single batch inversion for both fingerprint vectors (2*N elements). #[allow(clippy::needless_range_loop)] +#[cfg_attr(feature = "parallel", allow(dead_code))] fn compute_logup_batched_term_column( interaction_a: &BusInteraction, interaction_b: &BusInteraction, @@ -1668,6 +1697,238 @@ where .collect() } +/// Computes the multiplicity for a single row of an interaction. +/// +/// This avoids materializing a full Vec> of length trace_len +/// when processing rows in chunks. +#[cfg(feature = "parallel")] +#[inline] +fn compute_multiplicity_for_row( + multiplicity: &Multiplicity, + main_segment_cols: &[Vec>], + row: usize, +) -> FieldElement { + match multiplicity { + Multiplicity::One => FieldElement::one(), + Multiplicity::Column(col) => main_segment_cols[*col][row].clone(), + Multiplicity::Sum(col_a, col_b) => { + &main_segment_cols[*col_a][row] + &main_segment_cols[*col_b][row] + } + Multiplicity::Negated(col) => FieldElement::::one() - &main_segment_cols[*col][row], + Multiplicity::Diff(col_a, col_b) => { + &main_segment_cols[*col_a][row] - &main_segment_cols[*col_b][row] + } + Multiplicity::Sum3(col_a, col_b, col_c) => { + &main_segment_cols[*col_a][row] + + &main_segment_cols[*col_b][row] + + &main_segment_cols[*col_c][row] + } + Multiplicity::Linear(terms) => { + let mut result = FieldElement::::zero(); + for term in terms { + match *term { + LinearTerm::Column { + coefficient, + column, + } => { + let coeff = FieldElement::::from(coefficient); + result += &main_segment_cols[column][row] * coeff; + } + LinearTerm::ColumnUnsigned { + coefficient, + column, + } => { + let coeff = FieldElement::::from(coefficient); + result += &main_segment_cols[column][row] * coeff; + } + LinearTerm::Constant(value) => { + result += FieldElement::::from(value); + } + } + } + result + } + } +} + +/// Chunk-local batched term column computation for two interactions. +/// +/// Processes rows in chunks of `LOGUP_CHUNK_SIZE`. Per chunk: +/// 1. Compute 2*CHUNK fingerprints (interaction_a and interaction_b) +/// 2. Batch-invert locally (one Montgomery inverse per chunk) +/// 3. Compute terms: m_a/fp_a +/- m_b/fp_b +/// +/// Parallelism is across row-chunks (not across interaction pairs), giving +/// much better cache locality: each thread touches only CHUNK_SIZE rows of +/// main trace data before moving to the next phase. +#[cfg(feature = "parallel")] +fn compute_logup_batched_term_column_chunked( + interaction_a: &BusInteraction, + interaction_b: &BusInteraction, + main_segment_cols: &[Vec>], + trace_len: usize, + challenges: &[FieldElement], +) -> Vec> +where + F: IsFFTField + IsSubFieldOf + IsPrimeField + Send + Sync, + E: IsField + Send + Sync, +{ + let z = &challenges[0]; + let alpha = &challenges[LOGUP_CHALLENGE_ALPHA]; + + let max_bus_elements = interaction_a + .num_bus_elements() + .max(interaction_b.num_bus_elements()); + let alpha_powers = compute_alpha_powers(alpha, max_bus_elements); + + let negate_a = !interaction_a.is_sender; + let negate_b = !interaction_b.is_sender; + + let bus_id_a = FieldElement::::from(interaction_a.bus_id); + let bus_id_b = FieldElement::::from(interaction_b.bus_id); + let shifts = PackingShifts::::new(); + + // Output: one FieldElement per row + let mut result = vec![FieldElement::::zero(); trace_len]; + + result.par_chunks_mut(LOGUP_CHUNK_SIZE).enumerate().for_each( + |(chunk_idx, result_chunk)| { + let chunk_start = chunk_idx * LOGUP_CHUNK_SIZE; + let chunk_len = result_chunk.len(); + + // Phase 1: Compute fingerprints for both interactions in this chunk. + // Layout: [fp_a[0..chunk_len], fp_b[0..chunk_len]] + let mut fingerprints: Vec> = Vec::with_capacity(2 * chunk_len); + + for row in chunk_start..chunk_start + chunk_len { + let mut lc_a = &bus_id_a * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction_a.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc_a, + &shifts, + ); + alpha_offset += consumed; + } + fingerprints.push(z - &lc_a); + } + for row in chunk_start..chunk_start + chunk_len { + let mut lc_b = &bus_id_b * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction_b.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc_b, + &shifts, + ); + alpha_offset += consumed; + } + fingerprints.push(z - &lc_b); + } + + // Phase 2: Batch-invert all fingerprints in this chunk + FieldElement::inplace_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + + // Phase 3: Compute terms: m_a/fp_a +/- m_b/fp_b + for (i, result_elem) in result_chunk.iter_mut().enumerate() { + let row = chunk_start + i; + let fp_a_inv = &fingerprints[i]; + let fp_b_inv = &fingerprints[chunk_len + i]; + + let m_a = compute_multiplicity_for_row(&interaction_a.multiplicity, main_segment_cols, row); + let m_b = compute_multiplicity_for_row(&interaction_b.multiplicity, main_segment_cols, row); + + let term_a = &m_a * fp_a_inv; + let term_b = &m_b * fp_b_inv; + let term_a = if negate_a { -term_a } else { term_a }; + let term_b = if negate_b { -term_b } else { term_b }; + *result_elem = term_a + term_b; + } + }, + ); + + result +} + +/// Chunk-local single-interaction term column computation. +/// +/// Same cache-locality benefits as `compute_logup_batched_term_column_chunked` +/// but for a single interaction (used for the virtual absorbed column when +/// `absorbed_count == 1`). +#[cfg(feature = "parallel")] +fn compute_logup_term_column_chunked( + interaction: &BusInteraction, + main_segment_cols: &[Vec>], + trace_len: usize, + challenges: &[FieldElement], +) -> Vec> +where + F: IsFFTField + IsSubFieldOf + IsPrimeField + Send + Sync, + E: IsField + Send + Sync, +{ + let z = &challenges[0]; + let alpha = &challenges[LOGUP_CHALLENGE_ALPHA]; + + let num_bus_elements = interaction.num_bus_elements(); + let alpha_powers = compute_alpha_powers(alpha, num_bus_elements); + + let negate = !interaction.is_sender; + + let bus_id_f = FieldElement::::from(interaction.bus_id); + let shifts = PackingShifts::::new(); + + let mut result = vec![FieldElement::::zero(); trace_len]; + + result.par_chunks_mut(LOGUP_CHUNK_SIZE).enumerate().for_each( + |(chunk_idx, result_chunk)| { + let chunk_start = chunk_idx * LOGUP_CHUNK_SIZE; + let chunk_len = result_chunk.len(); + + // Phase 1: Compute fingerprints for this chunk + let mut fingerprints: Vec> = Vec::with_capacity(chunk_len); + + for row in chunk_start..chunk_start + chunk_len { + let mut lc = &bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc, + &shifts, + ); + alpha_offset += consumed; + } + fingerprints.push(z - &lc); + } + + // Phase 2: Batch-invert fingerprints in this chunk + FieldElement::inplace_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + + // Phase 3: Compute terms: ±(m * fp_inv) + for (i, result_elem) in result_chunk.iter_mut().enumerate() { + let row = chunk_start + i; + let m = compute_multiplicity_for_row(&interaction.multiplicity, main_segment_cols, row); + let term = &m * &fingerprints[i]; + *result_elem = if negate { -term } else { term }; + } + }, + ); + + result +} + /// Builds the circular accumulated column from pre-computed term columns. /// /// For the circular constraint: acc[(i+1) mod N] - acc[i] - terms[(i+1) mod N] + L/N = 0 @@ -1689,27 +1950,100 @@ where } let trace_len = term_columns[0].len(); - // Compute L = sum of all terms across all rows - let mut table_contribution = FieldElement::::zero(); - for row in 0..trace_len { - for col in term_columns { - table_contribution = &table_contribution + &col[row]; - } - } + // Precompute row_sums[row] = sum of all term_columns at that row. + // This avoids recomputing during the prefix sum and enables parallel reduction. + let row_sums: Vec> = (0..trace_len) + .map(|row| { + let mut s = FieldElement::::zero(); + for col in term_columns { + s = s + &col[row]; + } + s + }) + .collect(); + + // Compute L = sum of all row_sums (parallel when feature enabled) + #[cfg(feature = "parallel")] + let table_contribution: FieldElement = row_sums + .par_iter() + .cloned() + .reduce(FieldElement::zero, |a, b| a + b); + #[cfg(not(feature = "parallel"))] + let table_contribution: FieldElement = row_sums.iter().fold(FieldElement::zero(), |a, b| a + b); // offset_per_row = L / N let n = FieldElement::::from(trace_len as u64); let offset_per_row = &table_contribution * n.inv().unwrap(); - // Build circular accumulated column - let mut accumulated = FieldElement::::zero(); - for row in 0..trace_len { - let mut row_sum = FieldElement::::zero(); - for col in term_columns { - row_sum = row_sum + &col[row]; + // Build circular accumulated column using 3-phase parallel prefix sum. + // + // Phase 1: Compute chunk-local prefix sums in parallel. + // Each chunk computes partial_sums[i] = Σ(row_sums[j] - offset) for j in chunk. + // Also stores the chunk's total as `chunk_totals[chunk_idx]`. + // + // Phase 2: Sequential scan of chunk_totals to compute offsets for each chunk. + // + // Phase 3: Add chunk offset to each element in the accumulated vector. + // + // Finally write the accumulated column to trace (sequential, since set_aux takes &mut). + #[cfg(feature = "parallel")] + let accumulated_col = { + let num_chunks = (trace_len + LOGUP_CHUNK_SIZE - 1) / LOGUP_CHUNK_SIZE; + + // Phase 1: Compute chunk-local prefix sums + let chunk_data: Vec<(Vec>, FieldElement)> = (0..num_chunks) + .into_par_iter() + .map(|chunk_idx| { + let start = chunk_idx * LOGUP_CHUNK_SIZE; + let end = (start + LOGUP_CHUNK_SIZE).min(trace_len); + + let mut local_prefix = Vec::with_capacity(end - start); + let mut acc = FieldElement::::zero(); + for row in start..end { + acc = &acc + &row_sums[row] - &offset_per_row; + local_prefix.push(acc.clone()); + } + let chunk_total = acc; + (local_prefix, chunk_total) + }) + .collect(); + + // Phase 2: Sequential scan of chunk totals to get per-chunk offsets + let mut chunk_offsets = Vec::with_capacity(num_chunks); + let mut running = FieldElement::::zero(); + for (_, chunk_total) in &chunk_data { + chunk_offsets.push(running.clone()); + running = &running + chunk_total; } - accumulated = &accumulated + &row_sum - &offset_per_row; - trace.set_aux(row, acc_column_idx, accumulated.clone()); + + // Phase 3: Build final accumulated vector (parallel across chunks) + let mut acc_col = vec![FieldElement::::zero(); trace_len]; + acc_col + .par_chunks_mut(LOGUP_CHUNK_SIZE) + .enumerate() + .for_each(|(chunk_idx, out_chunk)| { + let offset = &chunk_offsets[chunk_idx]; + for (i, out) in out_chunk.iter_mut().enumerate() { + *out = offset + &chunk_data[chunk_idx].0[i]; + } + }); + acc_col + }; + + #[cfg(not(feature = "parallel"))] + let accumulated_col = { + let mut col = Vec::with_capacity(trace_len); + let mut accumulated = FieldElement::::zero(); + for row_sum in &row_sums { + accumulated = &accumulated + row_sum - &offset_per_row; + col.push(accumulated.clone()); + } + col + }; + + // Write accumulated column to trace + for (row, value) in accumulated_col.into_iter().enumerate() { + trace.set_aux(row, acc_column_idx, value); } table_contribution From 09483893792bdf7474cd37136556138229019d7c Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 16:54:42 -0300 Subject: [PATCH 07/17] perf: optimize FRI fold arithmetic and leaf construction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two optimizations to the FRI commit phase: 1. Precompute zeta * inv_twiddles[j] once per layer (F×E = 3 base muls each). The per-row fold then uses one E×E multiply (9 base muls) instead of E×E + F×E (12 base muls). Saves ~25% of fold arithmetic. 2. Hash FRI leaves directly from evals pairs via build_from_hashed_leaves, eliminating the intermediate Vec<[FieldElement; 2]> allocation (~24MB at FRI layer 0). --- crypto/stark/src/fri/fri_functions.rs | 18 +++++++++++++----- crypto/stark/src/fri/mod.rs | 13 +++++++++---- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/crypto/stark/src/fri/fri_functions.rs b/crypto/stark/src/fri/fri_functions.rs index 8bd355ec4..e2e796b09 100644 --- a/crypto/stark/src/fri/fri_functions.rs +++ b/crypto/stark/src/fri/fri_functions.rs @@ -7,23 +7,31 @@ use math::field::{ }; /// Evaluation-form FRI fold: given evaluations in bit-reversed order where /// consecutive pairs (2j, 2j+1) are conjugates (p(x_j), p(-x_j)), compute -/// the folded evaluations: (lo + hi) + inv_twiddle[j] * zeta * (lo - hi) -/// = 2 * (p_even(x_j²) + zeta * p_odd(x_j²)) +/// the folded evaluations: (lo + hi) + (zeta * inv_twiddle[j]) * (lo - hi) /// -/// After folding, the N/2 results are evaluations on the squared coset -/// in bit-reversed order, preserving conjugate pairing for the next fold. +/// Optimization: precomputes `zeta * inv_twiddle[j]` (F×E = 3 base muls each) +/// so the per-row fold is ONE E×E multiply (9 base muls) instead of +/// E×E + F×E (12 base muls). Saves ~25% of fold arithmetic. pub fn fold_evaluations_in_place, E: IsField>( evals: &mut Vec>, zeta: &FieldElement, inv_twiddles: &[FieldElement], ) { let half = evals.len() / 2; + + // Precompute zeta * inv_twiddle[j] once per layer. + // Each is F×E = 3 base muls (vs 12 per row without precomputation). + let zeta_tw: Vec> = inv_twiddles[..half] + .iter() + .map(|tw| tw * zeta) + .collect(); + for j in 0..half { let lo = &evals[2 * j]; let hi = &evals[2 * j + 1]; let sum = lo + hi; let diff = lo - hi; - evals[j] = &sum + &(&inv_twiddles[j] * &(zeta * &diff)); + evals[j] = sum + &zeta_tw[j] * diff; } evals.truncate(half); } diff --git a/crypto/stark/src/fri/mod.rs b/crypto/stark/src/fri/mod.rs index 87ab66a5b..7e553ede2 100644 --- a/crypto/stark/src/fri/mod.rs +++ b/crypto/stark/src/fri/mod.rs @@ -8,6 +8,8 @@ use math::field::traits::IsSubFieldOf; use math::field::traits::{IsFFTField, IsField}; use math::traits::AsBytes; +use crypto::merkle_tree::traits::IsMerkleTreeBackend; + use crate::config::{FriLayerMerkleTree, FriLayerMerkleTreeBackend}; use self::fri_commitment::FriLayer; @@ -49,12 +51,15 @@ where // Fold evaluations in-place (no FFT needed) fold_evaluations_in_place(&mut evals, &zeta, &inv_twiddles); - // Build Merkle tree from consecutive pairs - let leaves: Vec<[FieldElement; 2]> = evals + // Hash leaves directly from evals pairs (no intermediate Vec allocation). + // Each leaf = hash(evals[2i] || evals[2i+1]). + let hashed_leaves: Vec<_> = evals .chunks_exact(2) - .map(|chunk| [chunk[0].clone(), chunk[1].clone()]) + .map(|chunk| FriLayerMerkleTreeBackend::::hash_data( + &[chunk[0].clone(), chunk[1].clone()], + )) .collect(); - let merkle_tree = FriLayerMerkleTree::build(&leaves) + let merkle_tree = FriLayerMerkleTree::build_from_hashed_leaves(hashed_leaves) .expect("FRI commit: Merkle tree construction must succeed"); let root = merkle_tree.root; fri_layer_list.push(FriLayer::new( From c2bdf66663c8c969dfd4ab4a10af5181ddd67455 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 16:55:12 -0300 Subject: [PATCH 08/17] perf: eliminate per-row heap allocation in Merkle leaf hashing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use Rayon map_init to allocate one byte buffer per thread (reused across all rows) instead of vec![0u8; N] per row. For CPU table (74 cols × 2^21 rows), this eliminates ~2M heap allocations. Applied to both commit_columns_bit_reversed (main/aux trace) and commit_composition_polynomial (composition poly). --- crypto/stark/src/prover.rs | 111 +++++++++++++++++++++++++------------ 1 file changed, 77 insertions(+), 34 deletions(-) diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 8e59807c1..0839fa2b9 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -387,26 +387,43 @@ pub trait IsStarkProver< "num_rows must be a power of two for reverse_index" ); + let total_bytes = num_cols * byte_len; + + // Use map_init to allocate one byte buffer per thread (reused across rows). + // Eliminates 2M heap allocations for typical CPU table (74 cols × 2^21 rows). #[cfg(feature = "parallel")] - let iter = (0..num_rows).into_par_iter(); - #[cfg(not(feature = "parallel"))] - let iter = 0..num_rows; - - // One allocation per row (was one per field element): write all columns - // into a single buffer, then hash once. - let hashed_leaves: Vec = iter - .map(|row_idx| { - let br_idx = reverse_index(row_idx, num_rows as u64); - let total_bytes = num_cols * byte_len; - let mut buf = vec![0u8; total_bytes]; - for col_idx in 0..num_cols { - columns[col_idx][br_idx] - .write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) + let hashed_leaves: Vec = (0..num_rows) + .into_par_iter() + .map_init( + || vec![0u8; total_bytes], + |buf, row_idx| { + let br_idx = reverse_index(row_idx, num_rows as u64); + for col_idx in 0..num_cols { + columns[col_idx][br_idx].write_bytes_be( + &mut buf[col_idx * byte_len..(col_idx + 1) * byte_len], + ); + } + BatchedMerkleTreeBackend::::hash_bytes(buf) + }, + ) .collect(); + #[cfg(not(feature = "parallel"))] + let hashed_leaves: Vec = { + let mut buf = vec![0u8; total_bytes]; + (0..num_rows) + .map(|row_idx| { + let br_idx = reverse_index(row_idx, num_rows as u64); + for col_idx in 0..num_cols { + columns[col_idx][br_idx].write_bytes_be( + &mut buf[col_idx * byte_len..(col_idx + 1) * byte_len], + ); + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) + .collect() + }; + let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) @@ -721,25 +738,51 @@ pub trait IsStarkProver< #[cfg(not(feature = "parallel"))] let iter = 0..num_leaves; - let hashed_leaves: Vec = iter - .map(|leaf_idx| { - let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); - let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); - let total_bytes = 2 * num_parts * byte_len; - let mut buf = vec![0u8; total_bytes]; - let mut offset = 0; - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) + let total_bytes = 2 * num_parts * byte_len; + + #[cfg(feature = "parallel")] + let hashed_leaves: Vec = (0..num_leaves) + .into_par_iter() + .map_init( + || vec![0u8; total_bytes], + |buf, leaf_idx| { + let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); + let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); + let mut offset = 0; + for part in lde_composition_poly_parts_evaluations.iter() { + part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + for part in lde_composition_poly_parts_evaluations.iter() { + part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + BatchedMerkleTreeBackend::::hash_bytes(buf) + }, + ) .collect(); + #[cfg(not(feature = "parallel"))] + let hashed_leaves: Vec = { + let mut buf = vec![0u8; total_bytes]; + (0..num_leaves) + .map(|leaf_idx| { + let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); + let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); + let mut offset = 0; + for part in lde_composition_poly_parts_evaluations.iter() { + part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + for part in lde_composition_poly_parts_evaluations.iter() { + part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) + .collect() + }; + let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) From bd7bc46b76ba9d0b303cc54383084488fcf8eeb1 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 17:02:47 -0300 Subject: [PATCH 09/17] revert: undo FRI fold twiddle precomputation (caused regression, needs isolated benchmarking) --- crypto/stark/src/fri/fri_functions.rs | 18 +++++------------- crypto/stark/src/fri/mod.rs | 13 ++++--------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/crypto/stark/src/fri/fri_functions.rs b/crypto/stark/src/fri/fri_functions.rs index e2e796b09..8bd355ec4 100644 --- a/crypto/stark/src/fri/fri_functions.rs +++ b/crypto/stark/src/fri/fri_functions.rs @@ -7,31 +7,23 @@ use math::field::{ }; /// Evaluation-form FRI fold: given evaluations in bit-reversed order where /// consecutive pairs (2j, 2j+1) are conjugates (p(x_j), p(-x_j)), compute -/// the folded evaluations: (lo + hi) + (zeta * inv_twiddle[j]) * (lo - hi) +/// the folded evaluations: (lo + hi) + inv_twiddle[j] * zeta * (lo - hi) +/// = 2 * (p_even(x_j²) + zeta * p_odd(x_j²)) /// -/// Optimization: precomputes `zeta * inv_twiddle[j]` (F×E = 3 base muls each) -/// so the per-row fold is ONE E×E multiply (9 base muls) instead of -/// E×E + F×E (12 base muls). Saves ~25% of fold arithmetic. +/// After folding, the N/2 results are evaluations on the squared coset +/// in bit-reversed order, preserving conjugate pairing for the next fold. pub fn fold_evaluations_in_place, E: IsField>( evals: &mut Vec>, zeta: &FieldElement, inv_twiddles: &[FieldElement], ) { let half = evals.len() / 2; - - // Precompute zeta * inv_twiddle[j] once per layer. - // Each is F×E = 3 base muls (vs 12 per row without precomputation). - let zeta_tw: Vec> = inv_twiddles[..half] - .iter() - .map(|tw| tw * zeta) - .collect(); - for j in 0..half { let lo = &evals[2 * j]; let hi = &evals[2 * j + 1]; let sum = lo + hi; let diff = lo - hi; - evals[j] = sum + &zeta_tw[j] * diff; + evals[j] = &sum + &(&inv_twiddles[j] * &(zeta * &diff)); } evals.truncate(half); } diff --git a/crypto/stark/src/fri/mod.rs b/crypto/stark/src/fri/mod.rs index 7e553ede2..87ab66a5b 100644 --- a/crypto/stark/src/fri/mod.rs +++ b/crypto/stark/src/fri/mod.rs @@ -8,8 +8,6 @@ use math::field::traits::IsSubFieldOf; use math::field::traits::{IsFFTField, IsField}; use math::traits::AsBytes; -use crypto::merkle_tree::traits::IsMerkleTreeBackend; - use crate::config::{FriLayerMerkleTree, FriLayerMerkleTreeBackend}; use self::fri_commitment::FriLayer; @@ -51,15 +49,12 @@ where // Fold evaluations in-place (no FFT needed) fold_evaluations_in_place(&mut evals, &zeta, &inv_twiddles); - // Hash leaves directly from evals pairs (no intermediate Vec allocation). - // Each leaf = hash(evals[2i] || evals[2i+1]). - let hashed_leaves: Vec<_> = evals + // Build Merkle tree from consecutive pairs + let leaves: Vec<[FieldElement; 2]> = evals .chunks_exact(2) - .map(|chunk| FriLayerMerkleTreeBackend::::hash_data( - &[chunk[0].clone(), chunk[1].clone()], - )) + .map(|chunk| [chunk[0].clone(), chunk[1].clone()]) .collect(); - let merkle_tree = FriLayerMerkleTree::build_from_hashed_leaves(hashed_leaves) + let merkle_tree = FriLayerMerkleTree::build(&leaves) .expect("FRI commit: Merkle tree construction must succeed"); let root = merkle_tree.root; fri_layer_list.push(FriLayer::new( From 64f3839e9acdc219968cd4f00f25ebe948f344f6 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 17:10:25 -0300 Subject: [PATCH 10/17] revert: undo per-row buffer reuse in Merkle hashing (allocation was not the bottleneck, Keccak dominates) --- crypto/stark/src/prover.rs | 111 ++++++++++++------------------------- 1 file changed, 34 insertions(+), 77 deletions(-) diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 0839fa2b9..8e59807c1 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -387,42 +387,25 @@ pub trait IsStarkProver< "num_rows must be a power of two for reverse_index" ); - let total_bytes = num_cols * byte_len; - - // Use map_init to allocate one byte buffer per thread (reused across rows). - // Eliminates 2M heap allocations for typical CPU table (74 cols × 2^21 rows). #[cfg(feature = "parallel")] - let hashed_leaves: Vec = (0..num_rows) - .into_par_iter() - .map_init( - || vec![0u8; total_bytes], - |buf, row_idx| { - let br_idx = reverse_index(row_idx, num_rows as u64); - for col_idx in 0..num_cols { - columns[col_idx][br_idx].write_bytes_be( - &mut buf[col_idx * byte_len..(col_idx + 1) * byte_len], - ); - } - BatchedMerkleTreeBackend::::hash_bytes(buf) - }, - ) - .collect(); - + let iter = (0..num_rows).into_par_iter(); #[cfg(not(feature = "parallel"))] - let hashed_leaves: Vec = { - let mut buf = vec![0u8; total_bytes]; - (0..num_rows) - .map(|row_idx| { - let br_idx = reverse_index(row_idx, num_rows as u64); - for col_idx in 0..num_cols { - columns[col_idx][br_idx].write_bytes_be( - &mut buf[col_idx * byte_len..(col_idx + 1) * byte_len], - ); - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) - .collect() - }; + let iter = 0..num_rows; + + // One allocation per row (was one per field element): write all columns + // into a single buffer, then hash once. + let hashed_leaves: Vec = iter + .map(|row_idx| { + let br_idx = reverse_index(row_idx, num_rows as u64); + let total_bytes = num_cols * byte_len; + let mut buf = vec![0u8; total_bytes]; + for col_idx in 0..num_cols { + columns[col_idx][br_idx] + .write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) + .collect(); let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; @@ -738,51 +721,25 @@ pub trait IsStarkProver< #[cfg(not(feature = "parallel"))] let iter = 0..num_leaves; - let total_bytes = 2 * num_parts * byte_len; - - #[cfg(feature = "parallel")] - let hashed_leaves: Vec = (0..num_leaves) - .into_par_iter() - .map_init( - || vec![0u8; total_bytes], - |buf, leaf_idx| { - let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); - let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); - let mut offset = 0; - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - BatchedMerkleTreeBackend::::hash_bytes(buf) - }, - ) + let hashed_leaves: Vec = iter + .map(|leaf_idx| { + let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); + let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); + let total_bytes = 2 * num_parts * byte_len; + let mut buf = vec![0u8; total_bytes]; + let mut offset = 0; + for part in lde_composition_poly_parts_evaluations.iter() { + part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + for part in lde_composition_poly_parts_evaluations.iter() { + part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) + }) .collect(); - #[cfg(not(feature = "parallel"))] - let hashed_leaves: Vec = { - let mut buf = vec![0u8; total_bytes]; - (0..num_leaves) - .map(|leaf_idx| { - let br_0 = reverse_index(2 * leaf_idx, num_rows as u64); - let br_1 = reverse_index(2 * leaf_idx + 1, num_rows as u64); - let mut offset = 0; - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - for part in lde_composition_poly_parts_evaluations.iter() { - part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); - offset += byte_len; - } - BatchedMerkleTreeBackend::::hash_bytes(&buf) - }) - .collect() - }; - let tree = BatchedMerkleTree::::build_from_hashed_leaves(hashed_leaves)?; let root = tree.root; Some((tree, root)) From 10fe99f0cf3071de9caba9b792d839129ecdf4f8 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 17:16:02 -0300 Subject: [PATCH 11/17] fix: extract fingerprint computation closure to reduce cfg duplication in chunked LogUp --- crypto/stark/src/lookup.rs | 56 ++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 3bb859ab9..a24e2d824 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -1798,40 +1798,30 @@ where // Phase 1: Compute fingerprints for both interactions in this chunk. // Layout: [fp_a[0..chunk_len], fp_b[0..chunk_len]] - let mut fingerprints: Vec> = Vec::with_capacity(2 * chunk_len); - - for row in chunk_start..chunk_start + chunk_len { - let mut lc_a = &bus_id_a * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_a.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_a, - &shifts, - ); - alpha_offset += consumed; - } - fingerprints.push(z - &lc_a); - } - for row in chunk_start..chunk_start + chunk_len { - let mut lc_b = &bus_id_b * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_b.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_b, - &shifts, - ); - alpha_offset += consumed; + let compute_chunk_fingerprints = |interaction: &BusInteraction, + bus_id_f: &FieldElement, + fps: &mut Vec>| { + for row in chunk_start..chunk_start + chunk_len { + let mut lc = bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc, + &shifts, + ); + alpha_offset += consumed; + } + fps.push(z - &lc); } - fingerprints.push(z - &lc_b); - } + }; + + let mut fingerprints: Vec> = Vec::with_capacity(2 * chunk_len); + compute_chunk_fingerprints(interaction_a, &bus_id_a, &mut fingerprints); + compute_chunk_fingerprints(interaction_b, &bus_id_b, &mut fingerprints); // Phase 2: Batch-invert all fingerprints in this chunk FieldElement::inplace_batch_inverse(&mut fingerprints) From c81453fc412035c6d96c89fa83da4462545ca924 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 17:28:44 -0300 Subject: [PATCH 12/17] perf: skip Rayon for small Merkle trees (< 1024 nodes) FRI produces ~190 small tree builds per proof (layers 10-18 have 2-512 leaves). Rayon scheduling overhead exceeds computation for these tiny trees. Add a 1024-node threshold: below it, use sequential iteration for both leaf hashing and internal node construction. --- crypto/crypto/src/merkle_tree/traits.rs | 17 ++++++++---- crypto/crypto/src/merkle_tree/utils.rs | 37 +++++++++++++++++++------ 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/crypto/crypto/src/merkle_tree/traits.rs b/crypto/crypto/src/merkle_tree/traits.rs index c09cff9d0..81665278b 100644 --- a/crypto/crypto/src/merkle_tree/traits.rs +++ b/crypto/crypto/src/merkle_tree/traits.rs @@ -16,11 +16,18 @@ pub trait IsMerkleTreeBackend { /// tree will be built from and converts it to a list of leaf nodes. fn hash_leaves(unhashed_leaves: &[Self::Data]) -> Vec { #[cfg(feature = "parallel")] - let iter = unhashed_leaves.par_iter(); - #[cfg(not(feature = "parallel"))] - let iter = unhashed_leaves.iter(); - - iter.map(|leaf| Self::hash_data(leaf)).collect() + { + if unhashed_leaves.len() >= 1024 { + return unhashed_leaves + .par_iter() + .map(|leaf| Self::hash_data(leaf)) + .collect(); + } + } + unhashed_leaves + .iter() + .map(|leaf| Self::hash_data(leaf)) + .collect() } /// This function takes to children nodes and builds a new parent node. diff --git a/crypto/crypto/src/merkle_tree/utils.rs b/crypto/crypto/src/merkle_tree/utils.rs index 7cc64166b..bd971bb9c 100644 --- a/crypto/crypto/src/merkle_tree/utils.rs +++ b/crypto/crypto/src/merkle_tree/utils.rs @@ -78,17 +78,36 @@ where let (new_level_iter, children_iter) = nodes[new_level_begin_index..level_end_index + 1].split_at_mut(new_level_length); + // Skip Rayon for small levels: the scheduling overhead exceeds + // computation for levels with fewer than 1024 nodes. This avoids + // hundreds of unnecessary task spawns from small FRI layer trees. #[cfg(feature = "parallel")] - let parent_and_children_zipped_iter = new_level_iter - .into_par_iter() - .zip(children_iter.par_chunks_exact(2)); + { + if new_level_length >= 1024 { + new_level_iter + .into_par_iter() + .zip(children_iter.par_chunks_exact(2)) + .for_each(|(new_parent, children)| { + *new_parent = B::hash_new_parent(&children[0], &children[1]); + }); + } else { + new_level_iter + .iter_mut() + .zip(children_iter.chunks_exact(2)) + .for_each(|(new_parent, children)| { + *new_parent = B::hash_new_parent(&children[0], &children[1]); + }); + } + } #[cfg(not(feature = "parallel"))] - let parent_and_children_zipped_iter = - new_level_iter.iter_mut().zip(children_iter.chunks_exact(2)); - - parent_and_children_zipped_iter.for_each(|(new_parent, children)| { - *new_parent = B::hash_new_parent(&children[0], &children[1]); - }); + { + new_level_iter + .iter_mut() + .zip(children_iter.chunks_exact(2)) + .for_each(|(new_parent, children)| { + *new_parent = B::hash_new_parent(&children[0], &children[1]); + }); + } level_end_index = level_begin_index - 1; level_begin_index = new_level_begin_index; From 8e70557918733a48a4a34d512dbe9c53aaebc5fe Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 17:41:19 -0300 Subject: [PATCH 13/17] perf: deduplicate Domain + LdeTwiddles across tables by (trace_length, blowup) Tables with the same domain size (e.g., 7+ tables at 2^20) were each creating their own Domain (~24 MB) and LdeTwiddles (~32 MB). With ~20 tables and only 4-5 distinct sizes, this wasted ~300 MB of memory and redundant root-of-unity + twiddle generation. Now uses a HashMap cache keyed by (trace_length, blowup_factor). Domain and LdeTwiddles are shared via Arc across all tables with the same parameters. --- crypto/stark/src/prover.rs | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 8e59807c1..6f77a1306 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -1521,13 +1521,30 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let phase_start = Instant::now(); + // Deduplicate Domain + LdeTwiddles by (trace_length, blowup_factor). + // Many tables share the same domain size (e.g., 7+ tables at 2^20). + // Without dedup, each creates its own Domain (~24 MB) and LdeTwiddles (~32 MB). + let mut domain_cache: std::collections::HashMap< + (usize, usize), + (Arc>, Arc>), + > = std::collections::HashMap::new(); + let mut domains = Vec::with_capacity(num_airs); - let mut twiddle_caches: Vec> = Vec::with_capacity(num_airs); + let mut twiddle_caches: Vec>> = Vec::with_capacity(num_airs); for (air, trace, _pub_inputs) in &*air_trace_pairs { let trace_length = trace.num_rows(); - let domain = new_domain(*air, trace_length); - let twiddles = LdeTwiddles::new(&domain); + let blowup = air.options().blowup_factor as usize; + let key = (trace_length, blowup); + + let (domain, twiddles) = domain_cache + .entry(key) + .or_insert_with(|| { + let d = new_domain(*air, trace_length); + let t = LdeTwiddles::new(&d); + (Arc::new(d), Arc::new(t)) + }) + .clone(); domains.push(domain); twiddle_caches.push(twiddles); From 2ed646031c0cacd1949ba92d241a36129a77c05f Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 18:23:47 -0300 Subject: [PATCH 14/17] chore: remove outdated parallelism plan (approach changed) --- .../plans/2026-04-21-prover-parallelism.md | 55 ------------------- 1 file changed, 55 deletions(-) delete mode 100644 docs/superpowers/plans/2026-04-21-prover-parallelism.md diff --git a/docs/superpowers/plans/2026-04-21-prover-parallelism.md b/docs/superpowers/plans/2026-04-21-prover-parallelism.md deleted file mode 100644 index 84371d635..000000000 --- a/docs/superpowers/plans/2026-04-21-prover-parallelism.md +++ /dev/null @@ -1,55 +0,0 @@ -# Prover Parallelism Improvements - -**Goal:** Close the ~2x performance gap with Plonky3 by parallelizing the hottest sequential code paths in the STARK prover. - -**Architecture:** Five independent parallelism improvements. No protocol changes, no new data structures. Each produces identical outputs but runs faster. All use Rayon gated behind `#[cfg(feature = "parallel")]`. - ---- - -## Task 1: Parallel FRI fold - -**File:** `crypto/stark/src/fri/fri_functions.rs` - -`fold_evaluations_in_place` is a plain `for j in 0..half` loop. At layer 0 with LDE 2^21, ~1M sequential iterations. Each output depends only on its own pair -- no cross-iteration dependency. Use `into_par_iter` with a temp buffer (aliasing prevents in-place parallel write). - -## Task 2: Parallel FRI leaf construction - -**File:** `crypto/stark/src/fri/mod.rs` - -Leaf array `evals.chunks_exact(2).map(...).collect()` is sequential. Use `par_chunks_exact`. - -## Task 3: Parallel LogUp fingerprint computation - -**File:** `crypto/stark/src/lookup.rs` - -Two fingerprint loops in `compute_logup_batched_term_column` (lines 1619-1650): sequential over trace_len rows. Each row reads shared immutable data. Use `into_par_iter`. Also parallelize the final term computation loop and `compute_logup_term_column`. - -## Task 4: Parallel table_contribution sum - -**File:** `crypto/stark/src/lookup.rs` - -`build_accumulated_column_from_terms` sums term columns across all rows sequentially. Use `into_par_iter` with `reduce`. Note: the accumulated column running-sum loop CANNOT be parallelized. - -## Task 5: Chunked parallel batch inverse - -**Files:** `crypto/math/src/field/element.rs`, `crypto/stark/src/lookup.rs` - -Montgomery batch inverse is sequential. Split into K=num_threads chunks, run one independent batch inverse per chunk via `par_chunks_mut`. Cost: K-1 extra inversions, but O(N/K) per thread. Threshold at 1024 elements. - -## Task 6: Benchmark and validate - -Run `cargo bench --bench profile_vm_prover --features "parallel,instruments"`. Compare against baseline. Push. - ---- - -## Expected Impact - -| Optimization | Sequential cost | Parallel | Speedup | -|---|---|---|---| -| FRI fold | O(N) ext-field per layer | O(N/P) | ~P | -| FRI leaves | O(N) clones per layer | O(N/P) | ~P | -| Fingerprint loops | O(N) F*E per pair | O(N/P) | ~P | -| Batch inverse | O(N) prefix-suffix | O(N/P + K inv) | ~P large N | -| table_contribution | O(N*cols) | O(N*cols/P) | ~P | - -These target ~20-30% of prover time. Combined with MMCS/shared-FRI (~30-40%), closes most of the 2x gap. From 895f8aaf1d37638d1f1f771dbb375fa312d4adf5 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 18:38:39 -0300 Subject: [PATCH 15/17] fix: resolve clippy warnings (div_ceil, loop variable, complex type) --- crypto/stark/src/lookup.rs | 6 +++--- crypto/stark/src/prover.rs | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index a24e2d824..946799143 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -1978,7 +1978,7 @@ where // Finally write the accumulated column to trace (sequential, since set_aux takes &mut). #[cfg(feature = "parallel")] let accumulated_col = { - let num_chunks = (trace_len + LOGUP_CHUNK_SIZE - 1) / LOGUP_CHUNK_SIZE; + let num_chunks = trace_len.div_ceil(LOGUP_CHUNK_SIZE); // Phase 1: Compute chunk-local prefix sums let chunk_data: Vec<(Vec>, FieldElement)> = (0..num_chunks) @@ -1989,8 +1989,8 @@ where let mut local_prefix = Vec::with_capacity(end - start); let mut acc = FieldElement::::zero(); - for row in start..end { - acc = &acc + &row_sums[row] - &offset_per_row; + for rs in &row_sums[start..end] { + acc = &acc + rs - &offset_per_row; local_prefix.push(acc.clone()); } let chunk_total = acc; diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 6f77a1306..e890a5534 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -1524,10 +1524,9 @@ pub trait IsStarkProver< // Deduplicate Domain + LdeTwiddles by (trace_length, blowup_factor). // Many tables share the same domain size (e.g., 7+ tables at 2^20). // Without dedup, each creates its own Domain (~24 MB) and LdeTwiddles (~32 MB). - let mut domain_cache: std::collections::HashMap< - (usize, usize), - (Arc>, Arc>), - > = std::collections::HashMap::new(); + type DomainEntry = (Arc>, Arc>); + let mut domain_cache: std::collections::HashMap<(usize, usize), DomainEntry> = + std::collections::HashMap::new(); let mut domains = Vec::with_capacity(num_airs); let mut twiddle_caches: Vec>> = Vec::with_capacity(num_airs); From 85e5c61bb146b55636281754159a18768941cbd2 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Tue, 21 Apr 2026 18:43:27 -0300 Subject: [PATCH 16/17] style: cargo fmt --all --- crypto/stark/src/lookup.rs | 75 ++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 946799143..2a4416565 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -1791,33 +1791,36 @@ where // Output: one FieldElement per row let mut result = vec![FieldElement::::zero(); trace_len]; - result.par_chunks_mut(LOGUP_CHUNK_SIZE).enumerate().for_each( - |(chunk_idx, result_chunk)| { + result + .par_chunks_mut(LOGUP_CHUNK_SIZE) + .enumerate() + .for_each(|(chunk_idx, result_chunk)| { let chunk_start = chunk_idx * LOGUP_CHUNK_SIZE; let chunk_len = result_chunk.len(); // Phase 1: Compute fingerprints for both interactions in this chunk. // Layout: [fp_a[0..chunk_len], fp_b[0..chunk_len]] - let compute_chunk_fingerprints = |interaction: &BusInteraction, - bus_id_f: &FieldElement, - fps: &mut Vec>| { - for row in chunk_start..chunk_start + chunk_len { - let mut lc = bus_id_f * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc, - &shifts, - ); - alpha_offset += consumed; + let compute_chunk_fingerprints = + |interaction: &BusInteraction, + bus_id_f: &FieldElement, + fps: &mut Vec>| { + for row in chunk_start..chunk_start + chunk_len { + let mut lc = bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc, + &shifts, + ); + alpha_offset += consumed; + } + fps.push(z - &lc); } - fps.push(z - &lc); - } - }; + }; let mut fingerprints: Vec> = Vec::with_capacity(2 * chunk_len); compute_chunk_fingerprints(interaction_a, &bus_id_a, &mut fingerprints); @@ -1833,8 +1836,16 @@ where let fp_a_inv = &fingerprints[i]; let fp_b_inv = &fingerprints[chunk_len + i]; - let m_a = compute_multiplicity_for_row(&interaction_a.multiplicity, main_segment_cols, row); - let m_b = compute_multiplicity_for_row(&interaction_b.multiplicity, main_segment_cols, row); + let m_a = compute_multiplicity_for_row( + &interaction_a.multiplicity, + main_segment_cols, + row, + ); + let m_b = compute_multiplicity_for_row( + &interaction_b.multiplicity, + main_segment_cols, + row, + ); let term_a = &m_a * fp_a_inv; let term_b = &m_b * fp_b_inv; @@ -1842,8 +1853,7 @@ where let term_b = if negate_b { -term_b } else { term_b }; *result_elem = term_a + term_b; } - }, - ); + }); result } @@ -1877,8 +1887,10 @@ where let mut result = vec![FieldElement::::zero(); trace_len]; - result.par_chunks_mut(LOGUP_CHUNK_SIZE).enumerate().for_each( - |(chunk_idx, result_chunk)| { + result + .par_chunks_mut(LOGUP_CHUNK_SIZE) + .enumerate() + .for_each(|(chunk_idx, result_chunk)| { let chunk_start = chunk_idx * LOGUP_CHUNK_SIZE; let chunk_len = result_chunk.len(); @@ -1909,12 +1921,12 @@ where // Phase 3: Compute terms: ±(m * fp_inv) for (i, result_elem) in result_chunk.iter_mut().enumerate() { let row = chunk_start + i; - let m = compute_multiplicity_for_row(&interaction.multiplicity, main_segment_cols, row); + let m = + compute_multiplicity_for_row(&interaction.multiplicity, main_segment_cols, row); let term = &m * &fingerprints[i]; *result_elem = if negate { -term } else { term }; } - }, - ); + }); result } @@ -1959,7 +1971,8 @@ where .cloned() .reduce(FieldElement::zero, |a, b| a + b); #[cfg(not(feature = "parallel"))] - let table_contribution: FieldElement = row_sums.iter().fold(FieldElement::zero(), |a, b| a + b); + let table_contribution: FieldElement = + row_sums.iter().fold(FieldElement::zero(), |a, b| a + b); // offset_per_row = L / N let n = FieldElement::::from(trace_len as u64); From 4ec3d920c4f8c3f1388ea801cb57e1f6e2c29617 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Wed, 22 Apr 2026 10:18:48 -0300 Subject: [PATCH 17/17] fix: update run_debug_checks to accept Arc and Arc --- crypto/stark/src/prover.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index e890a5534..5dcd185b4 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -645,8 +645,8 @@ pub trait IsStarkProver< fn run_debug_checks( air_trace_pairs: &[AirTracePair<'_, Field, FieldExtension, PI>], commitments: &[Round1Commitments], - domains: &[Domain], - twiddle_caches: &[LdeTwiddles], + domains: &[Arc>], + twiddle_caches: &[Arc>], ) where FieldElement: AsBytes, FieldElement: AsBytes,