diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 770193020..17ba7c5ec 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -19,7 +19,9 @@ use math::field::{ traits::{IsFFTField, IsField, IsPrimeField, IsSubFieldOf}, }; #[cfg(feature = "parallel")] -use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::{ + IndexedParallelIterator, IntoParallelIterator, ParallelIterator, ParallelSliceMut, +}; // ============================================================================= // Shift Constants for Type Combining @@ -100,6 +102,11 @@ pub const LOGUP_CHALLENGE_ALPHA: usize = 1; /// Number of challenges required by the LogUp protocol. pub const LOGUP_NUM_CHALLENGES: usize = 2; +/// Chunk size for fused chunk-local LogUp processing. +/// Each chunk processes all interactions for CHUNK_SIZE rows, fitting in L2 cache. +#[cfg(feature = "parallel")] +const LOGUP_CHUNK_SIZE: usize = 1024; + /// Split N interactions into committed batched pairs and absorbed remainder. /// /// Returns `(num_committed_pairs, absorbed_count)` where: @@ -1028,26 +1035,46 @@ where // Clone main columns once (shared across all interactions) let main_segment_cols = trace.columns_main(); let trace_len = trace.num_rows(); - let table_name = self.name.as_deref().unwrap_or("UNKNOWN"); + let _table_name = self.name.as_deref().unwrap_or("UNKNOWN"); // Split interactions: committed pairs get term columns, last 1-2 are absorbed (virtual) let (num_committed_pairs, absorbed_count) = split_interactions(num_interactions); - // Compute committed term columns in parallel (batched pairs only) + // Compute committed term columns (batched pairs only). + // With `parallel`: when `trace_len > LOGUP_CHUNK_SIZE` the chunk-internal + // parallelism inside each pair already saturates Rayon, so iterate pairs + // sequentially to keep cache locality. When `trace_len <= LOGUP_CHUNK_SIZE` + // each pair yields a single chunk, so parallelize across pairs to recover + // the throughput the per-pair dispatch used to provide for small-trace + // tables with many interactions. + // Without `parallel`: sequential over pairs, sequential over rows. #[cfg(feature = "parallel")] - let committed_columns: Vec>> = (0..num_committed_pairs) - .into_par_iter() - .map(|i| { - compute_logup_batched_term_column( - &self.auxiliary_trace_build_data.interactions[i * 2], - &self.auxiliary_trace_build_data.interactions[i * 2 + 1], - &main_segment_cols, - trace_len, - challenges, - table_name, - ) - }) - .collect(); + let committed_columns: Vec>> = if trace_len <= LOGUP_CHUNK_SIZE { + (0..num_committed_pairs) + .into_par_iter() + .map(|i| { + compute_logup_batched_term_column( + &self.auxiliary_trace_build_data.interactions[i * 2], + &self.auxiliary_trace_build_data.interactions[i * 2 + 1], + &main_segment_cols, + trace_len, + challenges, + ) + }) + .collect() + } else { + (0..num_committed_pairs) + .map(|i| { + compute_logup_batched_term_column( + &self.auxiliary_trace_build_data.interactions[i * 2], + &self.auxiliary_trace_build_data.interactions[i * 2 + 1], + &main_segment_cols, + trace_len, + challenges, + ) + }) + .collect() + }; #[cfg(not(feature = "parallel"))] let committed_columns: Vec>> = (0..num_committed_pairs) .map(|i| { @@ -1057,7 +1084,6 @@ where &main_segment_cols, trace_len, challenges, - table_name, ) }) .collect(); @@ -1070,7 +1096,6 @@ where &main_segment_cols, trace_len, challenges, - table_name, ) } else { compute_logup_term_column( @@ -1078,7 +1103,7 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ) }; @@ -1096,7 +1121,7 @@ where &main_segment_cols, trace_len, challenges, - table_name, + _table_name, ); // Build accumulated from all columns (committed + virtual) @@ -1216,6 +1241,58 @@ pub enum Multiplicity { Linear(Vec), } +impl Multiplicity { + /// Evaluate the multiplicity for a single row. + #[inline] + fn evaluate_at_row( + &self, + main_segment_cols: &[Vec>], + row: usize, + ) -> FieldElement { + match self { + Multiplicity::One => FieldElement::one(), + Multiplicity::Column(col) => main_segment_cols[*col][row].clone(), + Multiplicity::Sum(col_a, col_b) => { + &main_segment_cols[*col_a][row] + &main_segment_cols[*col_b][row] + } + Multiplicity::Negated(col) => FieldElement::::one() - &main_segment_cols[*col][row], + Multiplicity::Diff(col_a, col_b) => { + &main_segment_cols[*col_a][row] - &main_segment_cols[*col_b][row] + } + Multiplicity::Sum3(col_a, col_b, col_c) => { + &main_segment_cols[*col_a][row] + + &main_segment_cols[*col_b][row] + + &main_segment_cols[*col_c][row] + } + Multiplicity::Linear(terms) => { + let mut result = FieldElement::::zero(); + for term in terms { + match *term { + LinearTerm::Column { + coefficient, + column, + } => { + let coeff = FieldElement::::from(coefficient); + result += &main_segment_cols[column][row] * coeff; + } + LinearTerm::ColumnUnsigned { + coefficient, + column, + } => { + let coeff = FieldElement::::from(coefficient); + result += &main_segment_cols[column][row] * coeff; + } + LinearTerm::Constant(value) => { + result += FieldElement::::from(value); + } + } + } + result + } + } + } +} + /// Struct representing a lookup interaction for a given table. /// Contains the multiplicity and bus values involved in said interaction. /// @@ -1372,7 +1449,10 @@ where /// /// This is a pure function that takes shared main columns and returns the computed column, /// enabling parallel computation across interactions within a table. -#[allow(clippy::needless_range_loop)] +/// +/// With `parallel`: processes rows in chunks of `LOGUP_CHUNK_SIZE` via `par_chunks_mut`, +/// giving good cache locality (each thread touches only CHUNK_SIZE rows before moving on). +/// Without `parallel`: processes all rows as a single chunk (equivalent to the old sequential path). fn compute_logup_term_column( table_interaction: &BusInteraction, main_segment_cols: &[Vec>], @@ -1384,166 +1464,103 @@ where F: IsFFTField + IsSubFieldOf + IsPrimeField + Send + Sync, E: IsField + Send + Sync, { - // Handle multiplicity column(s) - let multiplicities_owned: Vec>; - let multiplicities: &[FieldElement] = match table_interaction.multiplicity { - Multiplicity::One => { - multiplicities_owned = vec![FieldElement::one(); trace_len]; - &multiplicities_owned - } - Multiplicity::Column(col) => &main_segment_cols[col], - Multiplicity::Sum(col_a, col_b) => { - multiplicities_owned = main_segment_cols[col_a] - .iter() - .zip(main_segment_cols[col_b].iter()) - .map(|(a, b)| a + b) - .collect(); - &multiplicities_owned - } - Multiplicity::Negated(col) => { - multiplicities_owned = main_segment_cols[col] - .iter() - .map(|elem| FieldElement::::one() - elem) - .collect(); - &multiplicities_owned - } - Multiplicity::Diff(col_a, col_b) => { - multiplicities_owned = main_segment_cols[col_a] - .iter() - .zip(main_segment_cols[col_b].iter()) - .map(|(a, b)| a - b) - .collect(); - &multiplicities_owned - } - Multiplicity::Sum3(col_a, col_b, col_c) => { - multiplicities_owned = (0..trace_len) - .map(|row| { - &main_segment_cols[col_a][row] - + &main_segment_cols[col_b][row] - + &main_segment_cols[col_c][row] - }) - .collect(); - &multiplicities_owned - } - Multiplicity::Linear(ref terms) => { - multiplicities_owned = (0..trace_len) - .map(|row| { - let mut result = FieldElement::::zero(); - for term in terms { - match *term { - LinearTerm::Column { - coefficient, - column, - } => { - let coeff = FieldElement::::from(coefficient); - result += &main_segment_cols[column][row] * coeff; - } - LinearTerm::ColumnUnsigned { - coefficient, - column, - } => { - let coeff = FieldElement::::from(coefficient); - result += &main_segment_cols[column][row] * coeff; - } - LinearTerm::Constant(value) => { - result += FieldElement::::from(value); - } - } - } - result - }) - .collect(); - &multiplicities_owned - } - }; - - // LogUp challenges (must be shared across all tables for bus to balance) let z = &challenges[0]; let alpha = &challenges[LOGUP_CHALLENGE_ALPHA]; - - // Precompute powers of alpha for all bus elements (using incremental multiplication) let num_bus_elements = table_interaction.num_bus_elements(); let alpha_powers = compute_alpha_powers(alpha, num_bus_elements); - let negate = !table_interaction.is_sender; - - // Batch inversion: collect all fingerprints, invert once, then multiply back. - // Compute fingerprint = z - (bus_id*α^0 + v0*α^1 + v1*α^2 + ...) using - // base-field × extension-field multiplication (F×E→E) to avoid to_extension(). - // - // Zero-allocation inner loop: accumulate the linear combination directly - // into the fingerprint without collecting bus elements into intermediate Vecs. let bus_id_f = FieldElement::::from(table_interaction.bus_id); let shifts = PackingShifts::::new(); - let mut fingerprints: Vec> = Vec::with_capacity(trace_len); - for row in 0..trace_len { - // Accumulate fingerprint directly: bus_id * α^0 + Σ element_i * α^(1+i) - let mut linear_combination = &bus_id_f * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &table_interaction.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut linear_combination, - &shifts, - ); - alpha_offset += consumed; - } - fingerprints.push(z - &linear_combination); + let mut result = vec![FieldElement::::zero(); trace_len]; + + let process_chunk = |chunk_start: usize, result_chunk: &mut [FieldElement]| { + let chunk_len = result_chunk.len(); + + // Phase 1: Compute fingerprints + let mut fingerprints: Vec> = Vec::with_capacity(chunk_len); + for row in chunk_start..chunk_start + chunk_len { + let mut lc = &bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &table_interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc, + &shifts, + ); + alpha_offset += consumed; + } + fingerprints.push(z - &lc); - #[cfg(feature = "debug-checks")] - { - // Reconstruct base_elements for debug logging - let mut base_elements: Vec> = vec![bus_id_f.clone()]; - base_elements.extend( - table_interaction - .values - .iter() - .flat_map(|bv| bv.combine_from(|col| main_segment_cols[col][row].clone())), - ); - crate::bus_debug::log_interaction( - _table_name, - row, - table_interaction.bus_id, - table_interaction.is_sender, - &multiplicities[row].canonical(), - &base_elements, - fingerprints.last().unwrap(), - ); + #[cfg(feature = "debug-checks")] + { + let mut base_elements: Vec> = vec![bus_id_f.clone()]; + base_elements.extend( + table_interaction + .values + .iter() + .flat_map(|bv| bv.combine_from(|col| main_segment_cols[col][row].clone())), + ); + let multiplicity = table_interaction + .multiplicity + .evaluate_at_row(main_segment_cols, row); + crate::bus_debug::log_interaction( + _table_name, + row, + table_interaction.bus_id, + table_interaction.is_sender, + &multiplicity.canonical(), + &base_elements, + fingerprints.last().unwrap(), + ); + } } - } - FieldElement::inplace_batch_inverse(&mut fingerprints) - .expect("fingerprint is zero - probability of sampling zero is negligible"); - - // Compute terms: term[i] = ±(multiplicity[i] * fingerprint_inv[i]) - // Use conditional negation instead of E×E sign multiplication - multiplicities - .iter() - .zip(fingerprints.iter()) - .map(|(multiplicity, fingerprint_inv)| { - let term = multiplicity * fingerprint_inv; - if negate { -term } else { term } - }) - .collect() + // Phase 2: Batch-invert + FieldElement::inplace_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + + // Phase 3: Compute terms + for (i, result_elem) in result_chunk.iter_mut().enumerate() { + let row = chunk_start + i; + let m = table_interaction + .multiplicity + .evaluate_at_row(main_segment_cols, row); + let term = &m * &fingerprints[i]; + *result_elem = if negate { -term } else { term }; + } + }; + + #[cfg(feature = "parallel")] + result + .par_chunks_mut(LOGUP_CHUNK_SIZE) + .enumerate() + .for_each(|(i, chunk)| process_chunk(i * LOGUP_CHUNK_SIZE, chunk)); + + #[cfg(not(feature = "parallel"))] + process_chunk(0, &mut result); + + result } /// Computes a batched term column for two interactions sharing one aux column. /// /// Each row contains: `term[i] = sign_a * m_a[i] / fp_a[i] + sign_b * m_b[i] / fp_b[i]` /// -/// Uses a single batch inversion for both fingerprint vectors (2*N elements). -#[allow(clippy::needless_range_loop)] +/// Uses chunk-local batch inversion for good cache locality: each chunk processes +/// both interactions for CHUNK_SIZE rows before moving on. +/// +/// With `parallel`: processes rows in chunks of `LOGUP_CHUNK_SIZE` via `par_chunks_mut`. +/// Without `parallel`: processes all rows as a single chunk (equivalent to the old sequential path). fn compute_logup_batched_term_column( interaction_a: &BusInteraction, interaction_b: &BusInteraction, main_segment_cols: &[Vec>], trace_len: usize, challenges: &[FieldElement], - #[cfg_attr(not(feature = "debug-checks"), allow(unused))] _table_name: &str, ) -> Vec> where F: IsFFTField + IsSubFieldOf + IsPrimeField + Send + Sync, @@ -1551,133 +1568,80 @@ where { let z = &challenges[0]; let alpha = &challenges[LOGUP_CHALLENGE_ALPHA]; - let max_bus_elements = interaction_a .num_bus_elements() .max(interaction_b.num_bus_elements()); let alpha_powers = compute_alpha_powers(alpha, max_bus_elements); - let negate_a = !interaction_a.is_sender; let negate_b = !interaction_b.is_sender; - - // Helper to compute multiplicities for an interaction - let compute_multiplicities = |interaction: &BusInteraction| -> Vec> { - match &interaction.multiplicity { - Multiplicity::One => vec![FieldElement::one(); trace_len], - Multiplicity::Column(col) => main_segment_cols[*col].clone(), - Multiplicity::Sum(col_a, col_b) => main_segment_cols[*col_a] - .iter() - .zip(main_segment_cols[*col_b].iter()) - .map(|(a, b)| a + b) - .collect(), - Multiplicity::Negated(col) => main_segment_cols[*col] - .iter() - .map(|elem| FieldElement::::one() - elem) - .collect(), - Multiplicity::Diff(col_a, col_b) => main_segment_cols[*col_a] - .iter() - .zip(main_segment_cols[*col_b].iter()) - .map(|(a, b)| a - b) - .collect(), - Multiplicity::Sum3(col_a, col_b, col_c) => (0..trace_len) - .map(|row| { - &main_segment_cols[*col_a][row] - + &main_segment_cols[*col_b][row] - + &main_segment_cols[*col_c][row] - }) - .collect(), - Multiplicity::Linear(terms) => (0..trace_len) - .map(|row| { - let mut result = FieldElement::::zero(); - for term in terms { - match *term { - LinearTerm::Column { - coefficient, - column, - } => { - let coeff = FieldElement::::from(coefficient); - result += &main_segment_cols[column][row] * coeff; - } - LinearTerm::ColumnUnsigned { - coefficient, - column, - } => { - let coeff = FieldElement::::from(coefficient); - result += &main_segment_cols[column][row] * coeff; - } - LinearTerm::Constant(value) => { - result += FieldElement::::from(value); - } - } - } - result - }) - .collect(), - } - }; - - let multiplicities_a = compute_multiplicities(interaction_a); - let multiplicities_b = compute_multiplicities(interaction_b); - - // Compute fingerprints for both interactions using accumulate_fingerprint - // (zero-allocation inner loop: F×E multiplication instead of to_extension()) let bus_id_a = FieldElement::::from(interaction_a.bus_id); let bus_id_b = FieldElement::::from(interaction_b.bus_id); let shifts = PackingShifts::::new(); - // Concatenate both fingerprint vectors for a single batch inversion - let mut all_fingerprints: Vec> = Vec::with_capacity(2 * trace_len); - - for row in 0..trace_len { - let mut lc_a = &bus_id_a * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_a.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_a, - &shifts, - ); - alpha_offset += consumed; - } - all_fingerprints.push(z - &lc_a); - } - for row in 0..trace_len { - let mut lc_b = &bus_id_b * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction_b.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc_b, - &shifts, - ); - alpha_offset += consumed; - } - all_fingerprints.push(z - &lc_b); - } + let mut result = vec![FieldElement::::zero(); trace_len]; + + let process_chunk = |chunk_start: usize, result_chunk: &mut [FieldElement]| { + let chunk_len = result_chunk.len(); + + // Phase 1: Compute fingerprints for both interactions + let compute_fps = |interaction: &BusInteraction, + bus_id_f: &FieldElement, + fps: &mut Vec>| { + for row in chunk_start..chunk_start + chunk_len { + let mut lc = bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc, + &shifts, + ); + alpha_offset += consumed; + } + fps.push(z - &lc); + } + }; - // Single batch inversion for all 2*N fingerprints - FieldElement::inplace_batch_inverse(&mut all_fingerprints) - .expect("fingerprint is zero - probability of sampling zero is negligible"); - - // Compute batched terms: term[i] = m_a[i] / fp_a[i] ± m_b[i] / fp_b[i] - // Use conditional negation instead of E×E sign multiplication - (0..trace_len) - .map(|row| { - let fp_a_inv = &all_fingerprints[row]; - let fp_b_inv = &all_fingerprints[trace_len + row]; - let term_a = &multiplicities_a[row] * fp_a_inv; - let term_b = &multiplicities_b[row] * fp_b_inv; + let mut fingerprints: Vec> = Vec::with_capacity(2 * chunk_len); + compute_fps(interaction_a, &bus_id_a, &mut fingerprints); + compute_fps(interaction_b, &bus_id_b, &mut fingerprints); + + // Phase 2: Batch-invert + FieldElement::inplace_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + + // Phase 3: Compute terms + for (i, result_elem) in result_chunk.iter_mut().enumerate() { + let row = chunk_start + i; + let fp_a_inv = &fingerprints[i]; + let fp_b_inv = &fingerprints[chunk_len + i]; + let m_a = interaction_a + .multiplicity + .evaluate_at_row(main_segment_cols, row); + let m_b = interaction_b + .multiplicity + .evaluate_at_row(main_segment_cols, row); + let term_a = &m_a * fp_a_inv; + let term_b = &m_b * fp_b_inv; let term_a = if negate_a { -term_a } else { term_a }; let term_b = if negate_b { -term_b } else { term_b }; - term_a + term_b - }) - .collect() + *result_elem = term_a + term_b; + } + }; + + #[cfg(feature = "parallel")] + result + .par_chunks_mut(LOGUP_CHUNK_SIZE) + .enumerate() + .for_each(|(i, chunk)| process_chunk(i * LOGUP_CHUNK_SIZE, chunk)); + + #[cfg(not(feature = "parallel"))] + process_chunk(0, &mut result); + + result } /// Builds the circular accumulated column from pre-computed term columns.