diff --git a/crypto/crypto/src/merkle_tree/traits.rs b/crypto/crypto/src/merkle_tree/traits.rs index c09cff9d0..81665278b 100644 --- a/crypto/crypto/src/merkle_tree/traits.rs +++ b/crypto/crypto/src/merkle_tree/traits.rs @@ -16,11 +16,18 @@ pub trait IsMerkleTreeBackend { /// tree will be built from and converts it to a list of leaf nodes. fn hash_leaves(unhashed_leaves: &[Self::Data]) -> Vec { #[cfg(feature = "parallel")] - let iter = unhashed_leaves.par_iter(); - #[cfg(not(feature = "parallel"))] - let iter = unhashed_leaves.iter(); - - iter.map(|leaf| Self::hash_data(leaf)).collect() + { + if unhashed_leaves.len() >= 1024 { + return unhashed_leaves + .par_iter() + .map(|leaf| Self::hash_data(leaf)) + .collect(); + } + } + unhashed_leaves + .iter() + .map(|leaf| Self::hash_data(leaf)) + .collect() } /// This function takes to children nodes and builds a new parent node. diff --git a/crypto/crypto/src/merkle_tree/utils.rs b/crypto/crypto/src/merkle_tree/utils.rs index 7cc64166b..bd971bb9c 100644 --- a/crypto/crypto/src/merkle_tree/utils.rs +++ b/crypto/crypto/src/merkle_tree/utils.rs @@ -78,17 +78,36 @@ where let (new_level_iter, children_iter) = nodes[new_level_begin_index..level_end_index + 1].split_at_mut(new_level_length); + // Skip Rayon for small levels: the scheduling overhead exceeds + // computation for levels with fewer than 1024 nodes. This avoids + // hundreds of unnecessary task spawns from small FRI layer trees. #[cfg(feature = "parallel")] - let parent_and_children_zipped_iter = new_level_iter - .into_par_iter() - .zip(children_iter.par_chunks_exact(2)); + { + if new_level_length >= 1024 { + new_level_iter + .into_par_iter() + .zip(children_iter.par_chunks_exact(2)) + .for_each(|(new_parent, children)| { + *new_parent = B::hash_new_parent(&children[0], &children[1]); + }); + } else { + new_level_iter + .iter_mut() + .zip(children_iter.chunks_exact(2)) + .for_each(|(new_parent, children)| { + *new_parent = B::hash_new_parent(&children[0], &children[1]); + }); + } + } #[cfg(not(feature = "parallel"))] - let parent_and_children_zipped_iter = - new_level_iter.iter_mut().zip(children_iter.chunks_exact(2)); - - parent_and_children_zipped_iter.for_each(|(new_parent, children)| { - *new_parent = B::hash_new_parent(&children[0], &children[1]); - }); + { + new_level_iter + .iter_mut() + .zip(children_iter.chunks_exact(2)) + .for_each(|(new_parent, children)| { + *new_parent = B::hash_new_parent(&children[0], &children[1]); + }); + } level_end_index = level_begin_index - 1; level_begin_index = new_level_begin_index; diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 770193020..9992ded53 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -19,7 +19,10 @@ use math::field::{ traits::{IsFFTField, IsField, IsPrimeField, IsSubFieldOf}, }; #[cfg(feature = "parallel")] -use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, + ParallelSliceMut, +}; // ============================================================================= // Shift Constants for Type Combining @@ -100,6 +103,11 @@ pub const LOGUP_CHALLENGE_ALPHA: usize = 1; /// Number of challenges required by the LogUp protocol. pub const LOGUP_NUM_CHALLENGES: usize = 2; +/// Chunk size for fused chunk-local LogUp processing. +/// Each chunk processes all interactions for CHUNK_SIZE rows, fitting in L2 cache. +#[cfg(feature = "parallel")] +const LOGUP_CHUNK_SIZE: usize = 1024; + /// Split N interactions into committed batched pairs and absorbed remainder. /// /// Returns `(num_committed_pairs, absorbed_count)` where: @@ -1028,23 +1036,24 @@ where // Clone main columns once (shared across all interactions) let main_segment_cols = trace.columns_main(); let trace_len = trace.num_rows(); - let table_name = self.name.as_deref().unwrap_or("UNKNOWN"); + let _table_name = self.name.as_deref().unwrap_or("UNKNOWN"); // Split interactions: committed pairs get term columns, last 1-2 are absorbed (virtual) let (num_committed_pairs, absorbed_count) = split_interactions(num_interactions); - // Compute committed term columns in parallel (batched pairs only) + // Compute committed term columns (batched pairs only). + // With `parallel`: sequential over pairs, each using chunk-local parallelism + // (parallel across row-chunks, not across pairs) for better cache locality. + // Without `parallel`: sequential over pairs, sequential over rows. #[cfg(feature = "parallel")] let committed_columns: Vec>> = (0..num_committed_pairs) - .into_par_iter() .map(|i| { - compute_logup_batched_term_column( + compute_logup_batched_term_column_chunked( &self.auxiliary_trace_build_data.interactions[i * 2], &self.auxiliary_trace_build_data.interactions[i * 2 + 1], &main_segment_cols, trace_len, challenges, - table_name, ) }) .collect(); @@ -1057,12 +1066,30 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ) }) .collect(); // Compute virtual column for absorbed interactions (NOT written to trace) + #[cfg(feature = "parallel")] + let virtual_column = if absorbed_count == 2 { + compute_logup_batched_term_column_chunked( + &self.auxiliary_trace_build_data.interactions[num_interactions - 2], + &self.auxiliary_trace_build_data.interactions[num_interactions - 1], + &main_segment_cols, + trace_len, + challenges, + ) + } else { + compute_logup_term_column_chunked( + &self.auxiliary_trace_build_data.interactions[num_interactions - 1], + &main_segment_cols, + trace_len, + challenges, + ) + }; + #[cfg(not(feature = "parallel"))] let virtual_column = if absorbed_count == 2 { compute_logup_batched_term_column( &self.auxiliary_trace_build_data.interactions[num_interactions - 2], @@ -1070,7 +1097,7 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ) } else { compute_logup_term_column( @@ -1078,7 +1105,7 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ) }; @@ -1096,7 +1123,7 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ); // Build accumulated from all columns (committed + virtual) @@ -1373,6 +1400,7 @@ where /// This is a pure function that takes shared main columns and returns the computed column, /// enabling parallel computation across interactions within a table. #[allow(clippy::needless_range_loop)] +#[cfg_attr(feature = "parallel", allow(dead_code))] fn compute_logup_term_column( table_interaction: &BusInteraction, main_segment_cols: &[Vec>], @@ -1537,6 +1565,7 @@ where /// /// Uses a single batch inversion for both fingerprint vectors (2*N elements). #[allow(clippy::needless_range_loop)] +#[cfg_attr(feature = "parallel", allow(dead_code))] fn compute_logup_batched_term_column( interaction_a: &BusInteraction, interaction_b: &BusInteraction, @@ -1680,6 +1709,240 @@ where .collect() } +/// Computes the multiplicity for a single row of an interaction. +/// +/// This avoids materializing a full Vec> of length trace_len +/// when processing rows in chunks. +#[cfg(feature = "parallel")] +#[inline] +fn compute_multiplicity_for_row( + multiplicity: &Multiplicity, + main_segment_cols: &[Vec>], + row: usize, +) -> FieldElement { + match multiplicity { + Multiplicity::One => FieldElement::one(), + Multiplicity::Column(col) => main_segment_cols[*col][row].clone(), + Multiplicity::Sum(col_a, col_b) => { + &main_segment_cols[*col_a][row] + &main_segment_cols[*col_b][row] + } + Multiplicity::Negated(col) => FieldElement::::one() - &main_segment_cols[*col][row], + Multiplicity::Diff(col_a, col_b) => { + &main_segment_cols[*col_a][row] - &main_segment_cols[*col_b][row] + } + Multiplicity::Sum3(col_a, col_b, col_c) => { + &main_segment_cols[*col_a][row] + + &main_segment_cols[*col_b][row] + + &main_segment_cols[*col_c][row] + } + Multiplicity::Linear(terms) => { + let mut result = FieldElement::::zero(); + for term in terms { + match *term { + LinearTerm::Column { + coefficient, + column, + } => { + let coeff = FieldElement::::from(coefficient); + result += &main_segment_cols[column][row] * coeff; + } + LinearTerm::ColumnUnsigned { + coefficient, + column, + } => { + let coeff = FieldElement::::from(coefficient); + result += &main_segment_cols[column][row] * coeff; + } + LinearTerm::Constant(value) => { + result += FieldElement::::from(value); + } + } + } + result + } + } +} + +/// Chunk-local batched term column computation for two interactions. +/// +/// Processes rows in chunks of `LOGUP_CHUNK_SIZE`. Per chunk: +/// 1. Compute 2*CHUNK fingerprints (interaction_a and interaction_b) +/// 2. Batch-invert locally (one Montgomery inverse per chunk) +/// 3. Compute terms: m_a/fp_a +/- m_b/fp_b +/// +/// Parallelism is across row-chunks (not across interaction pairs), giving +/// much better cache locality: each thread touches only CHUNK_SIZE rows of +/// main trace data before moving to the next phase. +#[cfg(feature = "parallel")] +fn compute_logup_batched_term_column_chunked( + interaction_a: &BusInteraction, + interaction_b: &BusInteraction, + main_segment_cols: &[Vec>], + trace_len: usize, + challenges: &[FieldElement], +) -> Vec> +where + F: IsFFTField + IsSubFieldOf + IsPrimeField + Send + Sync, + E: IsField + Send + Sync, +{ + let z = &challenges[0]; + let alpha = &challenges[LOGUP_CHALLENGE_ALPHA]; + + let max_bus_elements = interaction_a + .num_bus_elements() + .max(interaction_b.num_bus_elements()); + let alpha_powers = compute_alpha_powers(alpha, max_bus_elements); + + let negate_a = !interaction_a.is_sender; + let negate_b = !interaction_b.is_sender; + + let bus_id_a = FieldElement::::from(interaction_a.bus_id); + let bus_id_b = FieldElement::::from(interaction_b.bus_id); + let shifts = PackingShifts::::new(); + + // Output: one FieldElement per row + let mut result = vec![FieldElement::::zero(); trace_len]; + + result + .par_chunks_mut(LOGUP_CHUNK_SIZE) + .enumerate() + .for_each(|(chunk_idx, result_chunk)| { + let chunk_start = chunk_idx * LOGUP_CHUNK_SIZE; + let chunk_len = result_chunk.len(); + + // Phase 1: Compute fingerprints for both interactions in this chunk. + // Layout: [fp_a[0..chunk_len], fp_b[0..chunk_len]] + let compute_chunk_fingerprints = + |interaction: &BusInteraction, + bus_id_f: &FieldElement, + fps: &mut Vec>| { + for row in chunk_start..chunk_start + chunk_len { + let mut lc = bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc, + &shifts, + ); + alpha_offset += consumed; + } + fps.push(z - &lc); + } + }; + + let mut fingerprints: Vec> = Vec::with_capacity(2 * chunk_len); + compute_chunk_fingerprints(interaction_a, &bus_id_a, &mut fingerprints); + compute_chunk_fingerprints(interaction_b, &bus_id_b, &mut fingerprints); + + // Phase 2: Batch-invert all fingerprints in this chunk + FieldElement::inplace_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + + // Phase 3: Compute terms: m_a/fp_a +/- m_b/fp_b + for (i, result_elem) in result_chunk.iter_mut().enumerate() { + let row = chunk_start + i; + let fp_a_inv = &fingerprints[i]; + let fp_b_inv = &fingerprints[chunk_len + i]; + + let m_a = compute_multiplicity_for_row( + &interaction_a.multiplicity, + main_segment_cols, + row, + ); + let m_b = compute_multiplicity_for_row( + &interaction_b.multiplicity, + main_segment_cols, + row, + ); + + let term_a = &m_a * fp_a_inv; + let term_b = &m_b * fp_b_inv; + let term_a = if negate_a { -term_a } else { term_a }; + let term_b = if negate_b { -term_b } else { term_b }; + *result_elem = term_a + term_b; + } + }); + + result +} + +/// Chunk-local single-interaction term column computation. +/// +/// Same cache-locality benefits as `compute_logup_batched_term_column_chunked` +/// but for a single interaction (used for the virtual absorbed column when +/// `absorbed_count == 1`). +#[cfg(feature = "parallel")] +fn compute_logup_term_column_chunked( + interaction: &BusInteraction, + main_segment_cols: &[Vec>], + trace_len: usize, + challenges: &[FieldElement], +) -> Vec> +where + F: IsFFTField + IsSubFieldOf + IsPrimeField + Send + Sync, + E: IsField + Send + Sync, +{ + let z = &challenges[0]; + let alpha = &challenges[LOGUP_CHALLENGE_ALPHA]; + + let num_bus_elements = interaction.num_bus_elements(); + let alpha_powers = compute_alpha_powers(alpha, num_bus_elements); + + let negate = !interaction.is_sender; + + let bus_id_f = FieldElement::::from(interaction.bus_id); + let shifts = PackingShifts::::new(); + + let mut result = vec![FieldElement::::zero(); trace_len]; + + result + .par_chunks_mut(LOGUP_CHUNK_SIZE) + .enumerate() + .for_each(|(chunk_idx, result_chunk)| { + let chunk_start = chunk_idx * LOGUP_CHUNK_SIZE; + let chunk_len = result_chunk.len(); + + // Phase 1: Compute fingerprints for this chunk + let mut fingerprints: Vec> = Vec::with_capacity(chunk_len); + + for row in chunk_start..chunk_start + chunk_len { + let mut lc = &bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc, + &shifts, + ); + alpha_offset += consumed; + } + fingerprints.push(z - &lc); + } + + // Phase 2: Batch-invert fingerprints in this chunk + FieldElement::inplace_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + + // Phase 3: Compute terms: ±(m * fp_inv) + for (i, result_elem) in result_chunk.iter_mut().enumerate() { + let row = chunk_start + i; + let m = + compute_multiplicity_for_row(&interaction.multiplicity, main_segment_cols, row); + let term = &m * &fingerprints[i]; + *result_elem = if negate { -term } else { term }; + } + }); + + result +} + /// Builds the circular accumulated column from pre-computed term columns. /// /// For the circular constraint: acc[(i+1) mod N] - acc[i] - terms[(i+1) mod N] + L/N = 0 @@ -1701,27 +1964,101 @@ where } let trace_len = term_columns[0].len(); - // Compute L = sum of all terms across all rows - let mut table_contribution = FieldElement::::zero(); - for row in 0..trace_len { - for col in term_columns { - table_contribution = &table_contribution + &col[row]; - } - } + // Precompute row_sums[row] = sum of all term_columns at that row. + // This avoids recomputing during the prefix sum and enables parallel reduction. + let row_sums: Vec> = (0..trace_len) + .map(|row| { + let mut s = FieldElement::::zero(); + for col in term_columns { + s = s + &col[row]; + } + s + }) + .collect(); + + // Compute L = sum of all row_sums (parallel when feature enabled) + #[cfg(feature = "parallel")] + let table_contribution: FieldElement = row_sums + .par_iter() + .cloned() + .reduce(FieldElement::zero, |a, b| a + b); + #[cfg(not(feature = "parallel"))] + let table_contribution: FieldElement = + row_sums.iter().fold(FieldElement::zero(), |a, b| a + b); // offset_per_row = L / N let n = FieldElement::::from(trace_len as u64); let offset_per_row = &table_contribution * n.inv().unwrap(); - // Build circular accumulated column - let mut accumulated = FieldElement::::zero(); - for row in 0..trace_len { - let mut row_sum = FieldElement::::zero(); - for col in term_columns { - row_sum = row_sum + &col[row]; + // Build circular accumulated column using 3-phase parallel prefix sum. + // + // Phase 1: Compute chunk-local prefix sums in parallel. + // Each chunk computes partial_sums[i] = Σ(row_sums[j] - offset) for j in chunk. + // Also stores the chunk's total as `chunk_totals[chunk_idx]`. + // + // Phase 2: Sequential scan of chunk_totals to compute offsets for each chunk. + // + // Phase 3: Add chunk offset to each element in the accumulated vector. + // + // Finally write the accumulated column to trace (sequential, since set_aux takes &mut). + #[cfg(feature = "parallel")] + let accumulated_col = { + let num_chunks = trace_len.div_ceil(LOGUP_CHUNK_SIZE); + + // Phase 1: Compute chunk-local prefix sums + let chunk_data: Vec<(Vec>, FieldElement)> = (0..num_chunks) + .into_par_iter() + .map(|chunk_idx| { + let start = chunk_idx * LOGUP_CHUNK_SIZE; + let end = (start + LOGUP_CHUNK_SIZE).min(trace_len); + + let mut local_prefix = Vec::with_capacity(end - start); + let mut acc = FieldElement::::zero(); + for rs in &row_sums[start..end] { + acc = &acc + rs - &offset_per_row; + local_prefix.push(acc.clone()); + } + let chunk_total = acc; + (local_prefix, chunk_total) + }) + .collect(); + + // Phase 2: Sequential scan of chunk totals to get per-chunk offsets + let mut chunk_offsets = Vec::with_capacity(num_chunks); + let mut running = FieldElement::::zero(); + for (_, chunk_total) in &chunk_data { + chunk_offsets.push(running.clone()); + running = &running + chunk_total; } - accumulated = &accumulated + &row_sum - &offset_per_row; - trace.set_aux(row, acc_column_idx, accumulated.clone()); + + // Phase 3: Build final accumulated vector (parallel across chunks) + let mut acc_col = vec![FieldElement::::zero(); trace_len]; + acc_col + .par_chunks_mut(LOGUP_CHUNK_SIZE) + .enumerate() + .for_each(|(chunk_idx, out_chunk)| { + let offset = &chunk_offsets[chunk_idx]; + for (i, out) in out_chunk.iter_mut().enumerate() { + *out = offset + &chunk_data[chunk_idx].0[i]; + } + }); + acc_col + }; + + #[cfg(not(feature = "parallel"))] + let accumulated_col = { + let mut col = Vec::with_capacity(trace_len); + let mut accumulated = FieldElement::::zero(); + for row_sum in &row_sums { + accumulated = &accumulated + row_sum - &offset_per_row; + col.push(accumulated.clone()); + } + col + }; + + // Write accumulated column to trace + for (row, value) in accumulated_col.into_iter().enumerate() { + trace.set_aux(row, acc_column_idx, value); } table_contribution diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 41ccb8366..135b5ba16 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -645,8 +645,8 @@ pub trait IsStarkProver< fn run_debug_checks( air_trace_pairs: &[AirTracePair<'_, Field, FieldExtension, PI>], commitments: &[Round1Commitments], - domains: &[Domain], - twiddle_caches: &[LdeTwiddles], + domains: &[Arc>], + twiddle_caches: &[Arc>], ) where FieldElement: AsBytes, FieldElement: AsBytes, @@ -1523,13 +1523,29 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let phase_start = Instant::now(); + // Deduplicate Domain + LdeTwiddles by (trace_length, blowup_factor). + // Many tables share the same domain size (e.g., 7+ tables at 2^20). + // Without dedup, each creates its own Domain (~24 MB) and LdeTwiddles (~32 MB). + type DomainEntry = (Arc>, Arc>); + let mut domain_cache: std::collections::HashMap<(usize, usize), DomainEntry> = + std::collections::HashMap::new(); + let mut domains = Vec::with_capacity(num_airs); - let mut twiddle_caches: Vec> = Vec::with_capacity(num_airs); + let mut twiddle_caches: Vec>> = Vec::with_capacity(num_airs); for (air, trace, _pub_inputs) in &*air_trace_pairs { let trace_length = trace.num_rows(); - let domain = new_domain(*air, trace_length); - let twiddles = LdeTwiddles::new(&domain); + let blowup = air.options().blowup_factor as usize; + let key = (trace_length, blowup); + + let (domain, twiddles) = domain_cache + .entry(key) + .or_insert_with(|| { + let d = new_domain(*air, trace_length); + let t = LdeTwiddles::new(&d); + (Arc::new(d), Arc::new(t)) + }) + .clone(); domains.push(domain); twiddle_caches.push(twiddles);