diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 770193020..8d58e9ef2 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -19,7 +19,10 @@ use math::field::{ traits::{IsFFTField, IsField, IsPrimeField, IsSubFieldOf}, }; #[cfg(feature = "parallel")] -use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, + ParallelSliceMut, +}; // ============================================================================= // Shift Constants for Type Combining @@ -100,6 +103,11 @@ pub const LOGUP_CHALLENGE_ALPHA: usize = 1; /// Number of challenges required by the LogUp protocol. pub const LOGUP_NUM_CHALLENGES: usize = 2; +/// Chunk size for fused chunk-local LogUp processing. +/// Each chunk processes all interactions for CHUNK_SIZE rows, fitting in L2 cache. +#[cfg(feature = "parallel")] +const LOGUP_CHUNK_SIZE: usize = 1024; + /// Split N interactions into committed batched pairs and absorbed remainder. /// /// Returns `(num_committed_pairs, absorbed_count)` where: @@ -1701,27 +1709,101 @@ where } let trace_len = term_columns[0].len(); - // Compute L = sum of all terms across all rows - let mut table_contribution = FieldElement::::zero(); - for row in 0..trace_len { - for col in term_columns { - table_contribution = &table_contribution + &col[row]; - } - } + // Precompute row_sums[row] = sum of all term_columns at that row. + // This avoids recomputing during the prefix sum and enables parallel reduction. + let row_sums: Vec> = (0..trace_len) + .map(|row| { + let mut s = FieldElement::::zero(); + for col in term_columns { + s = s + &col[row]; + } + s + }) + .collect(); + + // Compute L = sum of all row_sums (parallel when feature enabled) + #[cfg(feature = "parallel")] + let table_contribution: FieldElement = row_sums + .par_iter() + .cloned() + .reduce(FieldElement::zero, |a, b| a + b); + #[cfg(not(feature = "parallel"))] + let table_contribution: FieldElement = + row_sums.iter().fold(FieldElement::zero(), |a, b| a + b); // offset_per_row = L / N let n = FieldElement::::from(trace_len as u64); let offset_per_row = &table_contribution * n.inv().unwrap(); - // Build circular accumulated column - let mut accumulated = FieldElement::::zero(); - for row in 0..trace_len { - let mut row_sum = FieldElement::::zero(); - for col in term_columns { - row_sum = row_sum + &col[row]; + // Build circular accumulated column using 3-phase parallel prefix sum. + // + // Phase 1: Compute chunk-local prefix sums in parallel. + // Each chunk computes partial_sums[i] = Σ(row_sums[j] - offset) for j in chunk. + // Also stores the chunk's total as `chunk_totals[chunk_idx]`. + // + // Phase 2: Sequential scan of chunk_totals to compute offsets for each chunk. + // + // Phase 3: Add chunk offset to each element in the accumulated vector. + // + // Finally write the accumulated column to trace (sequential, since set_aux takes &mut). + #[cfg(feature = "parallel")] + let accumulated_col = { + let num_chunks = trace_len.div_ceil(LOGUP_CHUNK_SIZE); + + // Phase 1: Compute chunk-local prefix sums + let chunk_data: Vec<(Vec>, FieldElement)> = (0..num_chunks) + .into_par_iter() + .map(|chunk_idx| { + let start = chunk_idx * LOGUP_CHUNK_SIZE; + let end = (start + LOGUP_CHUNK_SIZE).min(trace_len); + + let mut local_prefix = Vec::with_capacity(end - start); + let mut acc = FieldElement::::zero(); + for rs in &row_sums[start..end] { + acc = &acc + rs - &offset_per_row; + local_prefix.push(acc.clone()); + } + let chunk_total = acc; + (local_prefix, chunk_total) + }) + .collect(); + + // Phase 2: Sequential scan of chunk totals to get per-chunk offsets + let mut chunk_offsets = Vec::with_capacity(num_chunks); + let mut running = FieldElement::::zero(); + for (_, chunk_total) in &chunk_data { + chunk_offsets.push(running.clone()); + running = &running + chunk_total; } - accumulated = &accumulated + &row_sum - &offset_per_row; - trace.set_aux(row, acc_column_idx, accumulated.clone()); + + // Phase 3: Build final accumulated vector (parallel across chunks) + let mut acc_col = vec![FieldElement::::zero(); trace_len]; + acc_col + .par_chunks_mut(LOGUP_CHUNK_SIZE) + .enumerate() + .for_each(|(chunk_idx, out_chunk)| { + let offset = &chunk_offsets[chunk_idx]; + for (i, out) in out_chunk.iter_mut().enumerate() { + *out = offset + &chunk_data[chunk_idx].0[i]; + } + }); + acc_col + }; + + #[cfg(not(feature = "parallel"))] + let accumulated_col = { + let mut col = Vec::with_capacity(trace_len); + let mut accumulated = FieldElement::::zero(); + for row_sum in &row_sums { + accumulated = &accumulated + row_sum - &offset_per_row; + col.push(accumulated.clone()); + } + col + }; + + // Write accumulated column to trace + for (row, value) in accumulated_col.into_iter().enumerate() { + trace.set_aux(row, acc_column_idx, value); } table_contribution