From 257c67274d18a2788d571af640c42598a1c951d4 Mon Sep 17 00:00:00 2001
From: diegokingston
Date: Fri, 13 Mar 2026 16:26:42 -0300
Subject: [PATCH 1/3] feat(prover): slim MEMW table and add MEMW_A aligned fast path
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implement the new MEMW spec (PR #398) with two major changes:

1. MEMW slimdown (70→49 cols, 57→26 interactions):
   - Replace address_add[7] DWordHL (28 cols) with add_limb_overflow[7] Bit
     (7 cols), making address_add virtual via BusValue::linear
   - Remove 28 IS_HALFWORD interactions and 3 LT overflow checks
   - Replace AddConstraint pairs with IsBitConstraint::unconditional

2. MEMW_A aligned fast path (30 cols, 22 interactions):
   - New table for aligned accesses where all bytes share the same
     old_timestamp (covers all register ops + most memory ops)
   - Address decomposed into high/mid/low parts
   - AND_BYTE interaction for alignment check
   - Single old_timestamp instead of per-byte

Operations are routed at trace build time: aligned ops with uniform
timestamps go to MEMW_A, the rest go to the slimmed MEMW.
--- prover/src/lib.rs | 30 +- prover/src/tables/memw.rs | 588 ++++++---------- prover/src/tables/memw_aligned.rs | 846 ++++++++++++++++++++++++ prover/src/tables/mod.rs | 33 +- prover/src/tables/trace_builder.rs | 146 ++-- prover/src/test_utils.rs | 36 +- prover/src/tests/prove_elfs_tests.rs | 24 +- prover/src/tests/trace_builder_tests.rs | 47 +- 8 files changed, 1244 insertions(+), 506 deletions(-) create mode 100644 prover/src/tables/memw_aligned.rs diff --git a/prover/src/lib.rs b/prover/src/lib.rs index c44dcb654..44f11f851 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -36,7 +36,8 @@ use crate::tables::trace_builder::Traces; use crate::test_utils::{ E, F, VmAir, create_bitwise_air, create_branch_air, create_commit_air, create_cpu_air, create_decode_air, create_dvrm_air, create_halt_air, create_load_air, create_lt_air, - create_memw_air, create_mul_air, create_page_air, create_register_air, create_shift_air, + create_memw_air, create_memw_aligned_air, create_mul_air, create_page_air, create_register_air, + create_shift_air, }; use stark::proof::options::{GoldilocksCubicProofOptions, ProofOptions}; @@ -61,6 +62,7 @@ pub struct TableCounts { pub cpu: usize, pub lt: usize, pub memw: usize, + pub memw_aligned: usize, pub load: usize, pub mul: usize, pub dvrm: usize, @@ -75,7 +77,15 @@ impl TableCounts { /// allowing a malicious prover to bypass soundness checks. /// Sum of all chunk counts across split tables. pub fn total(&self) -> usize { - self.cpu + self.lt + self.memw + self.load + self.mul + self.dvrm + self.shift + self.branch + self.cpu + + self.lt + + self.memw + + self.memw_aligned + + self.load + + self.mul + + self.dvrm + + self.shift + + self.branch } /// Validate that all required tables have at least one chunk. 
@@ -87,6 +97,7 @@ impl TableCounts { ("cpu", self.cpu), ("lt", self.lt), ("memw", self.memw), + ("memw_aligned", self.memw_aligned), ("load", self.load), ("mul", self.mul), ("dvrm", self.dvrm), @@ -167,6 +178,7 @@ pub(crate) struct VmAirs { pub lts: Vec, pub shifts: Vec, pub memws: Vec, + pub memw_aligneds: Vec, pub loads: Vec, pub decode: VmAir, pub muls: Vec, @@ -201,6 +213,13 @@ impl VmAirs { for (air, trace) in self.memws.iter().zip(traces.memws.iter_mut()) { pairs.push((air, trace, &())); } + for (air, trace) in self + .memw_aligneds + .iter() + .zip(traces.memw_aligneds.iter_mut()) + { + pairs.push((air, trace, &())); + } for (air, trace) in self.loads.iter().zip(traces.loads.iter_mut()) { pairs.push((air, trace, &())); } @@ -242,6 +261,9 @@ impl VmAirs { for air in &self.memws { refs.push(air); } + for air in &self.memw_aligneds { + refs.push(air); + } for air in &self.loads { refs.push(air); } @@ -294,6 +316,9 @@ impl VmAirs { let memws: Vec<_> = (0..table_counts.memw) .map(|i| create_memw_air(proof_options).with_name(&format!("MEMW[{}]", i))) .collect(); + let memw_aligneds: Vec<_> = (0..table_counts.memw_aligned) + .map(|i| create_memw_aligned_air(proof_options).with_name(&format!("MEMW_A[{}]", i))) + .collect(); let loads: Vec<_> = (0..table_counts.load) .map(|i| create_load_air(proof_options).with_name(&format!("LOAD[{}]", i))) .collect(); @@ -336,6 +361,7 @@ impl VmAirs { lts, shifts, memws, + memw_aligneds, loads, decode, muls, diff --git a/prover/src/tables/memw.rs b/prover/src/tables/memw.rs index 3bbd96e2f..3b1959275 100644 --- a/prover/src/tables/memw.rs +++ b/prover/src/tables/memw.rs @@ -1,32 +1,30 @@ -//! MEMW (Memory Write/Read) table. +//! MEMW (Memory Write/Read) table — unaligned / split-timestamp path. //! -//! This table handles memory and register read/write operations with timestamp-based -//! consistency checking. +//! This table handles memory and register read/write operations where bytes may +//! 
have different old_timestamps or the access is unaligned. +//! +//! ## Column layout (49 columns) //! -//! ## Inputs //! - `is_register`: Bit (1 = register access, 0 = memory access) -//! - `base_address`: DWordWL (64-bit address) +//! - `base_address`: DWordWL (64-bit address, 2 cols) //! - `value[8]`: BaseField[8] (8 bytes to write) -//! - `timestamp`: DWordWL (64-bit timestamp) +//! - `timestamp`: DWordWL (64-bit timestamp, 2 cols) //! - `write2/4/8`: Bit (access width flags) -//! -//! ## Output //! - `old[8]`: BaseField[8] (previous values at address) -//! -//! ## Auxiliary -//! - `address_add[7]`: DWordHL[7] (base_address + 1..7) -//! - `old_timestamp[8]`: DWordWL[8] (previous timestamps) +//! - `add_limb_overflow[7]`: Bit[7] (carry flags for base_address + i) +//! - `old_timestamp[8]`: DWordWL[8] (previous timestamps, 16 cols) +//! - `mu_read`, `mu_write`: multiplicity columns //! //! ## Virtual (computed inline) +//! - `address_add[i]` = (base_address_0 + i+1 - 2^32 * overflow[i], base_address_1 + overflow[i]) //! - `w2`: write2 + write4 + write8 (writing at least 2 bytes) //! - `w4`: write4 + write8 (writing at least 4 bytes) //! - `μ_sum`: μ_read + μ_write //! -//! ## Bus Interactions -//! - Receiver: MEMW (from CPU for LOAD/STORE operations) -//! - Sender: IS_HALFWORD (range checks for address_add) -//! - Sender: LT (timestamp ordering checks) -//! - Sender/Receiver: Memory bus (internal read/write consistency) +//! ## Bus Interactions (26) +//! - 8 LT timestamp checks (old_timestamp[i] < timestamp) +//! - 16 Memory bus tokens (read old + write new, per byte) +//! 
- 2 MEMW output interactions (read + write, from CPU) use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; @@ -37,14 +35,14 @@ use stark::trace::TraceTable; use stark::traits::TransitionEvaluationContext; use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; -use crate::constraints::templates::{AddConstraint, AddOperand}; +use crate::constraints::templates::IsBitConstraint; /// Maximum number of rows per MEMW table chunk. /// If operations exceed this, the trace is split into multiple tables. pub const MAX_ROWS: usize = super::max_rows::MEMW; // ========================================================================= -// Column indices for MEMW table +// Column indices for MEMW table (49 columns) // ========================================================================= /// Column definitions for the MEMW table. @@ -74,31 +72,21 @@ pub mod cols { pub const OLD: [usize; 8] = [16, 17, 18, 19, 20, 21, 22, 23]; // Auxiliary columns - /// address_add[7]: each is DWordHL (4 halfwords = 4 columns) - /// Total: 7 * 4 = 28 columns - pub const ADDRESS_ADD_START: usize = 24; - // address_add[i] uses columns ADDRESS_ADD_START + i*4 .. ADDRESS_ADD_START + i*4 + 4 + /// add_limb_overflow[7]: Bit columns indicating carry when adding i+1 to base_address_0 + pub const ADD_LIMB_OVERFLOW: [usize; 7] = [24, 25, 26, 27, 28, 29, 30]; /// old_timestamp[8]: each is DWordWL (2 words = 2 columns) /// Total: 8 * 2 = 16 columns - pub const OLD_TIMESTAMP_START: usize = 52; // 24 + 28 - // old_timestamp[i] uses columns OLD_TIMESTAMP_START + i*2 .. 
OLD_TIMESTAMP_START + i*2 + 2 + pub const OLD_TIMESTAMP_START: usize = 31; // Multiplicity columns /// μ_read: Whether we are performing a read - pub const MU_READ: usize = 68; // 52 + 16 + pub const MU_READ: usize = 47; /// μ_write: Whether we are performing a write - pub const MU_WRITE: usize = 69; + pub const MU_WRITE: usize = 48; /// Total number of columns - /// Note: w2, w4, μ_sum are now computed inline via Multiplicity::Linear/Sum - pub const NUM_COLUMNS: usize = 70; - - /// Helper to get address_add[i] column indices (4 halfwords each) - pub fn address_add(i: usize) -> [usize; 4] { - let base = ADDRESS_ADD_START + i * 4; - [base, base + 1, base + 2, base + 3] - } + pub const NUM_COLUMNS: usize = 49; /// Helper to get old_timestamp[i] column indices (2 words each) pub fn old_timestamp(i: usize) -> [usize; 2] { @@ -163,25 +151,12 @@ impl MemwOperation { /// Convert access width to the spec's flag representation (write2, write4, write8). /// - /// The spec uses three flags to encode access width: - /// - `write2`: set if accessing 2+ bytes (width >= 2) - /// - `write4`: set if accessing 4+ bytes (width >= 4) - /// - `write8`: set if accessing 8 bytes (width == 8) - /// - /// This encoding allows computing "at least N bytes" predicates: - /// - w2 (at least 2) = write2 + write4 + write8 - /// - w4 (at least 4) = write4 + write8 - /// /// | Width | write2 | write4 | write8 | /// |-------|--------|--------|--------| /// | 1 | 0 | 0 | 0 | /// | 2 | 1 | 0 | 0 | /// | 4 | 0 | 1 | 0 | /// | 8 | 0 | 0 | 1 | - /// - /// Note: These are "exactly N" semantics per spec, not cumulative. - /// Virtual columns w2 = write2 + write4 + write8 and w4 = write4 + write8 - /// compute "at least N" from these. pub fn write_flags(&self) -> (bool, bool, bool) { match self.width { 1 => (false, false, false), @@ -191,86 +166,6 @@ impl MemwOperation { _ => (false, false, false), } } - - /// Collect LT operations for timestamp ordering and overflow checking. 
- /// - /// Per spec constraints #7-10 and R1-R3: - /// - #7: old_timestamp[0] < timestamp (for all accesses) - /// - #8: old_timestamp[1] < timestamp (for width >= 2) - /// - #9: old_timestamp[2,3] < timestamp (for width >= 4) - /// - #10: old_timestamp[4..7] < timestamp (for width == 8) - /// - R1: base_address < base_address + 1 (overflow check for width >= 2) - /// - R2: base_address < base_address + 3 (overflow check for width >= 4) - /// - R3: base_address < base_address + 7 (overflow check for width == 8) - pub fn collect_lt_lookups(&self) -> Vec { - use super::lt::LtOperation; - - let mut ops = Vec::new(); - - // Constraint 7: old_timestamp[0] < timestamp (always, for any access) - ops.push(LtOperation::new( - self.old_timestamp[0], - self.timestamp, - false, - )); - - // Constraint 8: old_timestamp[1] < timestamp (for width >= 2) - if self.width >= 2 { - ops.push(LtOperation::new( - self.old_timestamp[1], - self.timestamp, - false, - )); - } - - // Constraint 9: old_timestamp[2,3] < timestamp (for width >= 4) - if self.width >= 4 { - ops.push(LtOperation::new( - self.old_timestamp[2], - self.timestamp, - false, - )); - ops.push(LtOperation::new( - self.old_timestamp[3], - self.timestamp, - false, - )); - } - - // Constraint 10: old_timestamp[4..7] < timestamp (for width == 8) - if self.width == 8 { - for i in 4..8 { - ops.push(LtOperation::new( - self.old_timestamp[i], - self.timestamp, - false, - )); - } - } - - // Overflow checks R1-R3: base_address < base_address + offset - // Always generate the LT operation - if overflow occurred, LT will return 0 - // and the constraint (expecting result=1) will fail, rejecting the proof. 
- // R1: for width == 2, check base_address < base_address + 1 - if self.width == 2 { - let addr_plus_1 = self.base_address.wrapping_add(1); - ops.push(LtOperation::new(self.base_address, addr_plus_1, false)); - } - - // R2: for width == 4, check base_address < base_address + 3 - if self.width == 4 { - let addr_plus_3 = self.base_address.wrapping_add(3); - ops.push(LtOperation::new(self.base_address, addr_plus_3, false)); - } - - // R3: for width == 8, check base_address < base_address + 7 - if self.width == 8 { - let addr_plus_7 = self.base_address.wrapping_add(7); - ops.push(LtOperation::new(self.base_address, addr_plus_7, false)); - } - - ops - } } /// Generates the MEMW trace table from a list of operations. @@ -287,7 +182,8 @@ pub fn generate_memw_trace( data[base + cols::IS_REGISTER] = FE::from(op.is_register as u64); // base_address as DWordWL (2 words) - data[base + cols::BASE_ADDRESS_0] = FE::from(op.base_address & 0xFFFF_FFFF); + let base_addr_lo = op.base_address & 0xFFFF_FFFF; + data[base + cols::BASE_ADDRESS_0] = FE::from(base_addr_lo); data[base + cols::BASE_ADDRESS_1] = FE::from(op.base_address >> 32); // value[8] @@ -310,14 +206,11 @@ pub fn generate_memw_trace( data[base + cols::OLD[i]] = FE::from(op.old[i]); } - // Auxiliary: address_add[7] - each as DWordHL (4 halfwords) + // Auxiliary: add_limb_overflow[7] + // overflow[i] = 1 if (base_address_lo + i+1) >= 2^32 for i in 0..7 { - let addr = op.base_address.wrapping_add(i as u64 + 1); - let cols_i = cols::address_add(i); - data[base + cols_i[0]] = FE::from(addr & 0xFFFF); - data[base + cols_i[1]] = FE::from((addr >> 16) & 0xFFFF); - data[base + cols_i[2]] = FE::from((addr >> 32) & 0xFFFF); - data[base + cols_i[3]] = FE::from((addr >> 48) & 0xFFFF); + let overflows = base_addr_lo + (i as u64 + 1) >= (1u64 << 32); + data[base + cols::ADD_LIMB_OVERFLOW[i]] = FE::from(overflows as u64); } // Auxiliary: old_timestamp[8] - each as DWordWL (2 words) @@ -330,79 +223,30 @@ pub fn generate_memw_trace( // 
Multiplicity data[base + cols::MU_READ] = FE::from(op.is_read as u64); data[base + cols::MU_WRITE] = FE::from(!op.is_read as u64); - // Note: w2, w4, μ_sum are computed inline via Multiplicity::Linear/Sum } TraceTable::new_main(data, cols::NUM_COLUMNS, 1) } // ========================================================================= -// Bus interactions +// Bus interactions (26 total) // ========================================================================= /// Creates all bus interactions for the MEMW table. /// -/// The MEMW table: -/// - **Receives** MEMW lookups from CPU (for LOAD/STORE operations) -/// - **Sends** IS_HALFWORD lookups for address_add range checks -/// - **Sends** LT lookups for timestamp ordering (old_timestamp < timestamp) +/// 26 interactions: +/// - 8 LT timestamp ordering checks +/// - 16 Memory bus tokens (read old + write new per byte) +/// - 2 MEMW output interactions (read + write from CPU) pub fn bus_interactions() -> Vec { let mut interactions = Vec::new(); // ------------------------------------------------------------------------- - // IS_HALFWORD range checks for address_add[i][j] - // ------------------------------------------------------------------------- - // ------------------------------------------------------------------------- - // IsHalfword range checks for address_add columns - // ------------------------------------------------------------------------- - // Each address_add[i] is 4 halfwords (DWordHL packing), need to range check all. - // Only check when row is active (μ_read + μ_write > 0). 
- for i in 0..7 { - let cols_i = cols::address_add(i); - for &col in &cols_i { - interactions.push(BusInteraction::sender( - BusId::IsHalfword, - // Only range check when row is active - Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE), - vec![BusValue::Packed { - start_column: col, - packing: Packing::Direct, - }], - )); - } - } - - // ------------------------------------------------------------------------- - // Memory bus interactions (M1-M8 from spec) - // ------------------------------------------------------------------------- - // DISABLED: Memory bus requires initialization and finalization: - // - Initialization: For each address accessed, an initial row at timestamp=0 - // with the starting value must exist so the first read has a matching write. - // - Finalization: Final values must be consumed to balance the bus. - // Without these, the bus won't balance. - // ------------------------------------------------------------------------- - // These ensure read/write consistency: - // - Read old value at old_timestamp (+multiplicity) - // - Write new value at current timestamp (-multiplicity) - // - // Memory bus format: memory[is_register, address, timestamp_lo, timestamp_hi, value] - // - // Register tokens (is_register=1) are balanced by the REGISTER table. - // Memory tokens (is_register=0) are balanced by PAGE tables. 
- - // ------------------------------------------------------------------------- - // Memory bus interactions per spec CM16-CM23 + // Memory bus interactions (16 total) // ------------------------------------------------------------------------- - // Token format: memory[is_register, address_lo, address_hi, ts_lo, ts_hi, value] - // - // For registers (is_register=1): value is a Word (32-bit), address is Word-indexed - // For memory (is_register=0): value is a Byte (8-bit), address is byte-indexed - // - // Multiplicities per spec: - // - CM16/17 (index 0): μ_sum - // - CM18/19 (index 1): w2 = write2 + write4 + write8 - // - CM20/21 (indices 2-3): w4 = write4 + write8 - // - CM22/23 (indices 4-7): write8 + // address_add[i] is VIRTUAL: + // lo = base_address_0 + (i+1) - 2^32 * add_limb_overflow[i] + // hi = base_address_1 + add_limb_overflow[i] // CM16: memory[is_register, base_address, old_timestamp[0], old[0]] with +μ_sum interactions.push(BusInteraction::sender( @@ -468,35 +312,50 @@ pub fn bus_interactions() -> Vec { ], )); - // Helper: address_add[0] = base_address + 1, stored as DWordHL (4 halfwords) - // Use Word2L to combine each pair of halfwords into a word - let addr_add_0_lo = BusValue::Packed { - start_column: cols::address_add(0)[0], - packing: Packing::Word2L, - }; - let addr_add_0_hi = BusValue::Packed { - start_column: cols::address_add(0)[2], - packing: Packing::Word2L, - }; - - // CM18: memory[is_register, address_add[0], old_timestamp[1], old[1]] with +w2 - // w2 = write2 + write4 + write8 + // CM18/19: byte 1, multiplicity w2 = write2 + write4 + write8 + // address_add[0] is virtual: lo = base_address_0 + 1 - 2^32 * overflow[0] + // hi = base_address_1 + overflow[0] + let w2_mult = Multiplicity::Linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE2, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE4, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE8, + }, + ]); + + let addr_add_0_lo = 
BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_0, + }, + LinearTerm::Constant(1), + LinearTerm::Column { + coefficient: -(1i64 << 32), + column: cols::ADD_LIMB_OVERFLOW[0], + }, + ]); + let addr_add_0_hi = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_1, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::ADD_LIMB_OVERFLOW[0], + }, + ]); + + // CM18: send old token for byte 1 interactions.push(BusInteraction::sender( BusId::Memory, - Multiplicity::Linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE2, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE4, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE8, - }, - ]), + w2_mult.clone(), vec![ BusValue::Packed { start_column: cols::IS_REGISTER, @@ -519,23 +378,10 @@ pub fn bus_interactions() -> Vec { ], )); - // CM19: memory[is_register, address_add[0], timestamp, value[1]] with -w2 + // CM19: receive new token for byte 1 interactions.push(BusInteraction::receiver( BusId::Memory, - Multiplicity::Linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE2, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE4, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE8, - }, - ]), + w2_mult, vec![ BusValue::Packed { start_column: cols::IS_REGISTER, @@ -558,30 +404,35 @@ pub fn bus_interactions() -> Vec { ], )); - // CM20/21: indices 2-3 with multiplicity w4 = write4 + write8 + // CM20/21: bytes 2-3 with multiplicity w4 = write4 + write8 for i in 2..=3 { - let addr_add_lo = BusValue::Packed { - start_column: cols::address_add(i - 1)[0], - packing: Packing::Word2L, - }; - let addr_add_hi = BusValue::Packed { - start_column: cols::address_add(i - 1)[2], - packing: Packing::Word2L, - }; - - // CM22.i: send old token + let overflow_col = cols::ADD_LIMB_OVERFLOW[i - 1]; + let addr_add_lo = BusValue::linear(vec![ + 
LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_0, + }, + LinearTerm::Constant(i as i64), + LinearTerm::Column { + coefficient: -(1i64 << 32), + column: overflow_col, + }, + ]); + let addr_add_hi = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_1, + }, + LinearTerm::Column { + coefficient: 1, + column: overflow_col, + }, + ]); + + // send old token interactions.push(BusInteraction::sender( BusId::Memory, - Multiplicity::Linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE4, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE8, - }, - ]), + Multiplicity::Sum(cols::WRITE4, cols::WRITE8), vec![ BusValue::Packed { start_column: cols::IS_REGISTER, @@ -604,19 +455,10 @@ pub fn bus_interactions() -> Vec { ], )); - // CM23.i: receive new token + // receive new token interactions.push(BusInteraction::receiver( BusId::Memory, - Multiplicity::Linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE4, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE8, - }, - ]), + Multiplicity::Sum(cols::WRITE4, cols::WRITE8), vec![ BusValue::Packed { start_column: cols::IS_REGISTER, @@ -640,18 +482,32 @@ pub fn bus_interactions() -> Vec { )); } - // CM22/23: indices 4-7 with multiplicity write8 + // CM22/23: bytes 4-7 with multiplicity write8 for i in 4..=7 { - let addr_add_lo = BusValue::Packed { - start_column: cols::address_add(i - 1)[0], - packing: Packing::Word2L, - }; - let addr_add_hi = BusValue::Packed { - start_column: cols::address_add(i - 1)[2], - packing: Packing::Word2L, - }; - - // CM22.i: send old token + let overflow_col = cols::ADD_LIMB_OVERFLOW[i - 1]; + let addr_add_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_0, + }, + LinearTerm::Constant(i as i64), + LinearTerm::Column { + coefficient: -(1i64 << 32), + column: overflow_col, + }, + ]); + let addr_add_hi = BusValue::linear(vec![ + 
LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_1, + }, + LinearTerm::Column { + coefficient: 1, + column: overflow_col, + }, + ]); + + // send old token interactions.push(BusInteraction::sender( BusId::Memory, Multiplicity::Column(cols::WRITE8), @@ -677,7 +533,7 @@ pub fn bus_interactions() -> Vec { ], )); - // CM23.i: receive new token + // receive new token interactions.push(BusInteraction::receiver( BusId::Memory, Multiplicity::Column(cols::WRITE8), @@ -705,17 +561,13 @@ pub fn bus_interactions() -> Vec { } // ------------------------------------------------------------------------- - // CO24: Read receiver (unified for register and memory operations) + // CO24: Read receiver (from CPU) // ------------------------------------------------------------------------- - // OLD and VALUE are 8 individual BaseField elements (Direct packing). - // For registers: [lo32_word, hi32_word, 0, 0, 0, 0, 0, 0] - // For memory: [byte0, byte1, ..., byte7] - // Both match sender format since bus compares field elements directly. 
interactions.push(BusInteraction::receiver( BusId::Memw, Multiplicity::Column(cols::MU_READ), vec![ - // old[8] - output for reads (words) + // old[8] BusValue::Packed { start_column: cols::OLD[0], packing: Packing::Direct, @@ -753,7 +605,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::IS_REGISTER, packing: Packing::Direct, }, - // base_address (DWordWL = 2 words) + // base_address BusValue::Packed { start_column: cols::BASE_ADDRESS_0, packing: Packing::Direct, @@ -762,7 +614,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::BASE_ADDRESS_1, packing: Packing::Direct, }, - // value[8] - direct reads (words) + // value[8] BusValue::Packed { start_column: cols::VALUE[0], packing: Packing::Direct, @@ -795,7 +647,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::VALUE[7], packing: Packing::Direct, }, - // timestamp (DWordWL = 2 words) + // timestamp BusValue::Packed { start_column: cols::TIMESTAMP_0, packing: Packing::Direct, @@ -821,12 +673,8 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // CO25: Write receiver (unified for register and memory operations) + // CO25: Write receiver (from CPU) // ------------------------------------------------------------------------- - // VALUE is 8 individual BaseField elements (Direct packing). - // For registers: [lo32_word, hi32_word, 0, 0, 0, 0, 0, 0] - // For memory: [byte0, byte1, ..., byte7] - // Both match sender format since bus compares field elements directly. 
interactions.push(BusInteraction::receiver( BusId::Memw, Multiplicity::Column(cols::MU_WRITE), @@ -836,7 +684,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::IS_REGISTER, packing: Packing::Direct, }, - // base_address (DWordWL = 2 words) + // base_address BusValue::Packed { start_column: cols::BASE_ADDRESS_0, packing: Packing::Direct, @@ -845,7 +693,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::BASE_ADDRESS_1, packing: Packing::Direct, }, - // value[8] - direct reads (words) + // value[8] BusValue::Packed { start_column: cols::VALUE[0], packing: Packing::Direct, @@ -878,7 +726,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::VALUE[7], packing: Packing::Direct, }, - // timestamp (DWordWL = 2 words) + // timestamp BusValue::Packed { start_column: cols::TIMESTAMP_0, packing: Packing::Direct, @@ -906,33 +754,26 @@ pub fn bus_interactions() -> Vec { // ------------------------------------------------------------------------- // LT interactions for timestamp ordering (constraints 7-10) // ------------------------------------------------------------------------- - // Verify old_timestamp[i] < timestamp for each accessed byte. - // LT bus uses 2 elements per 64-bit operand: [lo32, hi32] - // Both old_timestamp and timestamp are DWordWL, so use Packing::DWordWL. 
- // Constraint 7: LT[1; old_timestamp[0], timestamp] with μ_sum + // C7: LT[1; old_timestamp[0], timestamp] with μ_sum interactions.push(BusInteraction::sender( BusId::Lt, Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE), vec![ - // lhs = old_timestamp[0] as DWordWL (2 elements: [lo32, hi32]) BusValue::Packed { start_column: cols::old_timestamp(0)[0], packing: Packing::DWordWL, }, - // rhs = timestamp as DWordWL (2 elements: [lo32, hi32]) BusValue::Packed { start_column: cols::TIMESTAMP_0, packing: Packing::DWordWL, }, - // signed = 0 (unsigned comparison) BusValue::constant(0), - // lt = 1 (expected result: old_timestamp < timestamp) BusValue::constant(1), ], )); - // Constraint 8: LT[1; old_timestamp[1], timestamp] with w2 + // C8: LT[1; old_timestamp[1], timestamp] with w2 interactions.push(BusInteraction::sender( BusId::Lt, Multiplicity::Linear(vec![ @@ -963,7 +804,7 @@ pub fn bus_interactions() -> Vec { ], )); - // Constraint 9: LT[1; old_timestamp[i], timestamp] for i ∈ [2,3] with w4 + // C9: LT[1; old_timestamp[i], timestamp] for i ∈ [2,3] with w4 for i in 2..4 { interactions.push(BusInteraction::sender( BusId::Lt, @@ -983,7 +824,7 @@ pub fn bus_interactions() -> Vec { )); } - // Constraint 10: LT[1; old_timestamp[i], timestamp] for i ∈ [4,7] with write8 + // C10: LT[1; old_timestamp[i], timestamp] for i ∈ [4,7] with write8 for i in 4..8 { interactions.push(BusInteraction::sender( BusId::Lt, @@ -1003,69 +844,6 @@ pub fn bus_interactions() -> Vec { )); } - // ------------------------------------------------------------------------- - // LT interactions for overflow checking (constraints R1-R3) - // ------------------------------------------------------------------------- - // Verify base_address < address_add[i] (no overflow when adding offset). - // base_address is DWordWL, address_add[i] is DWordHL (cast to DWordWL). - // Both packings produce 2 elements [lo32, hi32]. 
- - // R1: LT[1; base_address, address_add[0]] with write2 - // This checks for no overflow when accessing byte 1 (width == 2) - interactions.push(BusInteraction::sender( - BusId::Lt, - Multiplicity::Column(cols::WRITE2), - vec![ - BusValue::Packed { - start_column: cols::BASE_ADDRESS_0, - packing: Packing::DWordWL, - }, - BusValue::Packed { - start_column: cols::address_add(0)[0], - packing: Packing::DWordHL, - }, - BusValue::constant(0), // unsigned - BusValue::constant(1), // lt = 1 - ], - )); - - // R2: LT[1; base_address, address_add[2]] with write4 - // This checks for no overflow when accessing byte 3 (width == 4) - interactions.push(BusInteraction::sender( - BusId::Lt, - Multiplicity::Column(cols::WRITE4), - vec![ - BusValue::Packed { - start_column: cols::BASE_ADDRESS_0, - packing: Packing::DWordWL, - }, - BusValue::Packed { - start_column: cols::address_add(2)[0], - packing: Packing::DWordHL, - }, - BusValue::constant(0), - BusValue::constant(1), - ], - )); - - // R3: LT[1; base_address, address_add[6]] with write8 - interactions.push(BusInteraction::sender( - BusId::Lt, - Multiplicity::Column(cols::WRITE8), - vec![ - BusValue::Packed { - start_column: cols::BASE_ADDRESS_0, - packing: Packing::DWordWL, - }, - BusValue::Packed { - start_column: cols::address_add(6)[0], - packing: Packing::DWordHL, - }, - BusValue::constant(0), - BusValue::constant(1), - ], - )); - interactions } @@ -1097,7 +875,7 @@ where } // ========================================================================= -// Constraints +// Constraints (9 total: 2 custom + 7 IS_BIT) // ========================================================================= /// MEMW table constraint kinds. 
@@ -1132,12 +910,10 @@ impl MemwConstraint { match self.kind { MemwConstraintKind::MuSumIsBit => { - // IS_BIT<μ_sum>: μ_sum * (1 - μ_sum) = 0 let mu_sum = compute_mu_sum(step); &mu_sum * (&one - &mu_sum) } MemwConstraintKind::W2ImpliesMuSum => { - // w2 * (1 - μ_sum) = 0 let w2 = compute_w2(step); let mu_sum = compute_mu_sum(step); &w2 * (&one - &mu_sum) @@ -1191,6 +967,11 @@ impl TransitionConstraint for MemwConstrai } /// Creates all constraints for the MEMW table. +/// +/// 9 constraints total: +/// - IS_BIT<μ_sum> (1) +/// - w2 => μ_sum (1) +/// - IS_BIT for add_limb_overflow[0..6] (7) pub fn constraints() -> Vec>> { let mut constraints: Vec>> = Vec::new(); @@ -1211,29 +992,10 @@ pub fn constraints() -> Vec vec![cols::WRITE2, cols::WRITE4, cols::WRITE8], // w2 - 1 | 2 => vec![cols::WRITE4, cols::WRITE8], // w4 - _ => vec![cols::WRITE8], // write8 (i = 3..6) - }; - - // ADD constraint produces 2 constraints (carry_0, carry_1) - let (c0, c1) = AddConstraint::new_pair(condition, lhs, rhs, sum, idx); - constraints.push(Box::new(c0)); - constraints.push(Box::new(c1)); - idx += 2; + // IS_BIT for add_limb_overflow[0..6] + for &col in &cols::ADD_LIMB_OVERFLOW { + constraints.push(Box::new(IsBitConstraint::unconditional(col, idx))); + idx += 1; } constraints @@ -1259,17 +1021,39 @@ mod tests { #[test] fn test_write_flags() { - // "Exactly N" semantics per spec let op1 = MemwOperation::new(false, 0, [0; 8], 0, 1, false); - assert_eq!(op1.write_flags(), (false, false, false)); // no flags for 1 byte + assert_eq!(op1.write_flags(), (false, false, false)); let op2 = MemwOperation::new(false, 0, [0; 8], 0, 2, false); - assert_eq!(op2.write_flags(), (true, false, false)); // write2 only + assert_eq!(op2.write_flags(), (true, false, false)); let op4 = MemwOperation::new(false, 0, [0; 8], 0, 4, false); - assert_eq!(op4.write_flags(), (false, true, false)); // write4 only + assert_eq!(op4.write_flags(), (false, true, false)); let op8 = MemwOperation::new(false, 0, [0; 8], 0, 
8, false); - assert_eq!(op8.write_flags(), (false, false, true)); // write8 only + assert_eq!(op8.write_flags(), (false, false, true)); + } + + #[test] + fn test_add_limb_overflow() { + // Address 0xFFFF_FFFF should overflow when adding 1 + let op = + MemwOperation::new(false, 0xFFFF_FFFF, [0; 8], 100, 8, false).with_old([0; 8], [50; 8]); + let trace = generate_memw_trace(&[op]); + + // All 7 overflow flags should be 1 since 0xFFFF_FFFF + i >= 2^32 for i >= 1 + for i in 0..7 { + let val = trace.get_main(0, cols::ADD_LIMB_OVERFLOW[i]); + assert_eq!(*val, FE::one(), "overflow[{i}] should be 1"); + } + + // Address 0x0000_0000 should not overflow + let op2 = + MemwOperation::new(false, 0x0000_0000, [0; 8], 100, 8, false).with_old([0; 8], [50; 8]); + let trace2 = generate_memw_trace(&[op2]); + for i in 0..7 { + let val = trace2.get_main(0, cols::ADD_LIMB_OVERFLOW[i]); + assert_eq!(*val, FE::zero(), "overflow[{i}] should be 0"); + } } } diff --git a/prover/src/tables/memw_aligned.rs b/prover/src/tables/memw_aligned.rs new file mode 100644 index 000000000..8e992060f --- /dev/null +++ b/prover/src/tables/memw_aligned.rs @@ -0,0 +1,846 @@ +//! MEMW_A (Memory Write/Read — Aligned) table. +//! +//! Fast path for aligned memory/register accesses where all bytes share the +//! same old_timestamp. Most operations (aligned memory + all register accesses) +//! route here instead of the heavier MEMW table. +//! +//! ## Column layout (30 columns) +//! +//! - `is_register`: Bit +//! - `base_address_high`: Word (32-bit high word) +//! - `base_address_mid`: Half (16-bit mid) +//! - `base_address_low[2]`: Bytes (low 2 bytes) +//! - `value[8]`: BaseField[8] +//! - `timestamp`: DWordWL (2 cols) +//! - `write2/4/8`: Bit (access width flags) +//! - `old[8]`: BaseField[8] +//! - `old_timestamp`: DWordWL (2 cols — single, not 8!) +//! - `mu_read`, `mu_write`: multiplicity columns +//! +//! ## Bus Interactions (22) +//! - 1 IS_HALFWORD[base_address_mid] +//! 
- 1 IS_BYTE[base_address_low[1]] +//! - 1 AND_BYTE[base_address_low[0], mask] → 0 (alignment check) +//! - 1 LT[old_timestamp, timestamp, 0] → 1 +//! - 16 Memory bus tokens +//! - 2 MEMW output interactions (read + write) + +use math::field::element::FieldElement; +use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraints::transition::TransitionConstraint; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; +use stark::table::TableView; +use stark::trace::TraceTable; +use stark::traits::TransitionEvaluationContext; + +use super::memw::MemwOperation; +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; + +/// Maximum number of rows per MEMW_A table chunk. +pub const MAX_ROWS: usize = super::max_rows::MEMW_A; + +// ========================================================================= +// Column indices (30 columns) +// ========================================================================= + +pub mod cols { + pub const IS_REGISTER: usize = 0; + + /// base_address decomposed: high = addr >> 32 (Word, 32-bit) + pub const BASE_ADDRESS_HIGH: usize = 1; + /// base_address decomposed: mid = (addr >> 16) & 0xFFFF (Half, 16-bit) + pub const BASE_ADDRESS_MID: usize = 2; + /// base_address decomposed: low bytes + /// low[0] = addr & 0xFF, low[1] = (addr >> 8) & 0xFF + pub const BASE_ADDRESS_LOW: [usize; 2] = [3, 4]; + + pub const VALUE: [usize; 8] = [5, 6, 7, 8, 9, 10, 11, 12]; + + pub const TIMESTAMP_0: usize = 13; + pub const TIMESTAMP_1: usize = 14; + + pub const WRITE2: usize = 15; + pub const WRITE4: usize = 16; + pub const WRITE8: usize = 17; + + pub const OLD: [usize; 8] = [18, 19, 20, 21, 22, 23, 24, 25]; + + /// Single old_timestamp (shared across all bytes, since they're aligned) + pub const OLD_TIMESTAMP_0: usize = 26; + pub const OLD_TIMESTAMP_1: usize = 27; + + pub const MU_READ: usize = 28; + pub const MU_WRITE: usize = 29; + + pub const NUM_COLUMNS: usize = 30; +} + +// 
========================================================================= +// Trace generation +// ========================================================================= + +/// Generates the MEMW_A trace table from aligned operations. +/// +/// Reuses `MemwOperation` — the trace generator uses `old_timestamp[0]` +/// (verified equal for all accessed bytes by the routing logic). +pub fn generate_memw_aligned_trace( + operations: &[MemwOperation], +) -> TraceTable { + let num_rows = operations.len().next_power_of_two().max(4); + let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; + + for (row_idx, op) in operations.iter().enumerate() { + let base = row_idx * cols::NUM_COLUMNS; + + data[base + cols::IS_REGISTER] = FE::from(op.is_register as u64); + + // Decompose base_address + let addr = op.base_address; + let high = addr >> 32; + let mid = (addr >> 16) & 0xFFFF; + let low_1 = (addr >> 8) & 0xFF; + let low_0 = addr & 0xFF; + + data[base + cols::BASE_ADDRESS_HIGH] = FE::from(high); + data[base + cols::BASE_ADDRESS_MID] = FE::from(mid); + data[base + cols::BASE_ADDRESS_LOW[0]] = FE::from(low_0); + data[base + cols::BASE_ADDRESS_LOW[1]] = FE::from(low_1); + + for i in 0..8 { + data[base + cols::VALUE[i]] = FE::from(op.value[i]); + } + + data[base + cols::TIMESTAMP_0] = FE::from(op.timestamp & 0xFFFF_FFFF); + data[base + cols::TIMESTAMP_1] = FE::from(op.timestamp >> 32); + + let (w2, w4, w8) = op.write_flags(); + data[base + cols::WRITE2] = FE::from(w2 as u64); + data[base + cols::WRITE4] = FE::from(w4 as u64); + data[base + cols::WRITE8] = FE::from(w8 as u64); + + for i in 0..8 { + data[base + cols::OLD[i]] = FE::from(op.old[i]); + } + + // Single old_timestamp (from old_timestamp[0], verified equal for all bytes) + data[base + cols::OLD_TIMESTAMP_0] = FE::from(op.old_timestamp[0] & 0xFFFF_FFFF); + data[base + cols::OLD_TIMESTAMP_1] = FE::from(op.old_timestamp[0] >> 32); + + data[base + cols::MU_READ] = FE::from(op.is_read as u64); + data[base + 
cols::MU_WRITE] = FE::from(!op.is_read as u64); + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions (22 total) +// ========================================================================= + +pub fn bus_interactions() -> Vec { + let mut interactions = Vec::new(); + + let mu_sum = Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE); + + // ------------------------------------------------------------------------- + // IS_HALFWORD[base_address_mid] with μ_sum + // ------------------------------------------------------------------------- + interactions.push(BusInteraction::sender( + BusId::IsHalfword, + mu_sum.clone(), + vec![BusValue::Packed { + start_column: cols::BASE_ADDRESS_MID, + packing: Packing::Direct, + }], + )); + + // ------------------------------------------------------------------------- + // IS_BYTE[base_address_low[1]] with μ_sum + // ------------------------------------------------------------------------- + interactions.push(BusInteraction::sender( + BusId::IsByte, + mu_sum.clone(), + vec![BusValue::Packed { + start_column: cols::BASE_ADDRESS_LOW[1], + packing: Packing::Direct, + }], + )); + + // ------------------------------------------------------------------------- + // AND_BYTE[base_address_low[0], mask] → 0 with μ_sum + // mask = write2*1 + write4*3 + write8*7 + // This implicitly range-checks low[0] to [0, 256) AND checks alignment. 
+ // ------------------------------------------------------------------------- + interactions.push(BusInteraction::sender( + BusId::AndByte, + mu_sum.clone(), + vec![ + // x = base_address_low[0] + BusValue::Packed { + start_column: cols::BASE_ADDRESS_LOW[0], + packing: Packing::Direct, + }, + // y = mask = write2*1 + write4*3 + write8*7 + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE2, + }, + LinearTerm::Column { + coefficient: 3, + column: cols::WRITE4, + }, + LinearTerm::Column { + coefficient: 7, + column: cols::WRITE8, + }, + ]), + // result = 0 (alignment constraint: low bits must be 0) + BusValue::constant(0), + ], + )); + + // ------------------------------------------------------------------------- + // LT[old_timestamp, timestamp, 0] → 1 with μ_sum + // ------------------------------------------------------------------------- + interactions.push(BusInteraction::sender( + BusId::Lt, + mu_sum.clone(), + vec![ + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::DWordWL, + }, + BusValue::constant(0), + BusValue::constant(1), + ], + )); + + // ------------------------------------------------------------------------- + // Memory bus interactions (16 total) + // ------------------------------------------------------------------------- + // For aligned accesses, address for byte i: + // lo = 2^16 * MID + 2^8 * LOW[1] + LOW[0] + i + // hi = HIGH + // All old_timestamp references use the single old_timestamp columns. + + // Virtual base_address_lo = 2^16 * MID + 2^8 * LOW[1] + LOW[0] + // For byte 0, the address is exactly this. 
+ let base_addr_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1 << 16, + column: cols::BASE_ADDRESS_MID, + }, + LinearTerm::Column { + coefficient: 1 << 8, + column: cols::BASE_ADDRESS_LOW[1], + }, + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_LOW[0], + }, + ]); + + let base_addr_hi = BusValue::Packed { + start_column: cols::BASE_ADDRESS_HIGH, + packing: Packing::Direct, + }; + + // CM16: memory[is_register, base_address, old_timestamp, old[0]] with +μ_sum + interactions.push(BusInteraction::sender( + BusId::Memory, + mu_sum.clone(), + vec![ + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + base_addr_lo.clone(), + base_addr_hi.clone(), + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[0], + packing: Packing::Direct, + }, + ], + )); + + // CM17: memory[is_register, base_address, timestamp, value[0]] with -μ_sum + interactions.push(BusInteraction::receiver( + BusId::Memory, + mu_sum.clone(), + vec![ + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + base_addr_lo.clone(), + base_addr_hi.clone(), + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[0], + packing: Packing::Direct, + }, + ], + )); + + // w2 multiplicity + let w2_mult = Multiplicity::Linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE2, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE4, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE8, + }, + ]); + + // CM18/19: byte 1 with w2 + // For aligned accesses, adding 1 to the low byte never overflows to hi word + // (since 
alignment guarantees base_address_lo + width-1 < 2^32). + let addr_1_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1 << 16, + column: cols::BASE_ADDRESS_MID, + }, + LinearTerm::Column { + coefficient: 1 << 8, + column: cols::BASE_ADDRESS_LOW[1], + }, + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_LOW[0], + }, + LinearTerm::Constant(1), + ]); + + interactions.push(BusInteraction::sender( + BusId::Memory, + w2_mult.clone(), + vec![ + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + addr_1_lo.clone(), + base_addr_hi.clone(), + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[1], + packing: Packing::Direct, + }, + ], + )); + + interactions.push(BusInteraction::receiver( + BusId::Memory, + w2_mult, + vec![ + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + addr_1_lo, + base_addr_hi.clone(), + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[1], + packing: Packing::Direct, + }, + ], + )); + + // CM20/21: bytes 2-3 with w4 + for i in 2..=3 { + let addr_i_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1 << 16, + column: cols::BASE_ADDRESS_MID, + }, + LinearTerm::Column { + coefficient: 1 << 8, + column: cols::BASE_ADDRESS_LOW[1], + }, + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_LOW[0], + }, + LinearTerm::Constant(i as i64), + ]); + + interactions.push(BusInteraction::sender( + BusId::Memory, + Multiplicity::Sum(cols::WRITE4, cols::WRITE8), + vec![ + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + addr_i_lo.clone(), + 
base_addr_hi.clone(), + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[i], + packing: Packing::Direct, + }, + ], + )); + + interactions.push(BusInteraction::receiver( + BusId::Memory, + Multiplicity::Sum(cols::WRITE4, cols::WRITE8), + vec![ + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + addr_i_lo, + base_addr_hi.clone(), + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[i], + packing: Packing::Direct, + }, + ], + )); + } + + // CM22/23: bytes 4-7 with write8 + for i in 4..=7 { + let addr_i_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1 << 16, + column: cols::BASE_ADDRESS_MID, + }, + LinearTerm::Column { + coefficient: 1 << 8, + column: cols::BASE_ADDRESS_LOW[1], + }, + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_LOW[0], + }, + LinearTerm::Constant(i as i64), + ]); + + interactions.push(BusInteraction::sender( + BusId::Memory, + Multiplicity::Column(cols::WRITE8), + vec![ + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + addr_i_lo.clone(), + base_addr_hi.clone(), + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[i], + packing: Packing::Direct, + }, + ], + )); + + interactions.push(BusInteraction::receiver( + BusId::Memory, + Multiplicity::Column(cols::WRITE8), + vec![ + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + addr_i_lo, + base_addr_hi.clone(), + 
BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[i], + packing: Packing::Direct, + }, + ], + )); + } + + // ------------------------------------------------------------------------- + // CO24: Read receiver (from CPU) + // ------------------------------------------------------------------------- + // The MEMW output bus fingerprint uses base_address as [lo32, hi32]. + // Reconstruct: lo32 = 2^16*MID + 2^8*LOW[1] + LOW[0], hi32 = HIGH + interactions.push(BusInteraction::receiver( + BusId::Memw, + Multiplicity::Column(cols::MU_READ), + vec![ + // old[8] + BusValue::Packed { + start_column: cols::OLD[0], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[1], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[2], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[3], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[4], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[5], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[6], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD[7], + packing: Packing::Direct, + }, + // is_register + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + // base_address reconstructed as [lo32, hi32] + base_addr_lo.clone(), + base_addr_hi.clone(), + // value[8] + BusValue::Packed { + start_column: cols::VALUE[0], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[1], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[2], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[3], + packing: Packing::Direct, + }, + BusValue::Packed { + 
start_column: cols::VALUE[4], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[5], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[6], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[7], + packing: Packing::Direct, + }, + // timestamp + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + // write flags + BusValue::Packed { + start_column: cols::WRITE2, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::WRITE4, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::WRITE8, + packing: Packing::Direct, + }, + ], + )); + + // ------------------------------------------------------------------------- + // CO25: Write receiver (from CPU) + // ------------------------------------------------------------------------- + interactions.push(BusInteraction::receiver( + BusId::Memw, + Multiplicity::Column(cols::MU_WRITE), + vec![ + // is_register + BusValue::Packed { + start_column: cols::IS_REGISTER, + packing: Packing::Direct, + }, + // base_address reconstructed + base_addr_lo, + base_addr_hi, + // value[8] + BusValue::Packed { + start_column: cols::VALUE[0], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[1], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[2], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[3], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[4], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[5], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[6], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[7], + packing: Packing::Direct, + }, + // timestamp + 
BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + // write flags + BusValue::Packed { + start_column: cols::WRITE2, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::WRITE4, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::WRITE8, + packing: Packing::Direct, + }, + ], + )); + + interactions +} + +// ========================================================================= +// Constraints (2 algebraic) +// ========================================================================= + +/// MEMW_A constraint kinds. +#[derive(Debug, Clone, Copy)] +pub enum MemwAlignedConstraintKind { + /// IS_BIT<μ_sum>: multiplicity sum is 0 or 1 + MuSumIsBit, + /// w2 => μ_sum: if accessing 2+ bytes, must be active row + W2ImpliesMuSum, +} + +pub struct MemwAlignedConstraint { + constraint_idx: usize, + kind: MemwAlignedConstraintKind, +} + +impl MemwAlignedConstraint { + pub fn new(kind: MemwAlignedConstraintKind, constraint_idx: usize) -> Self { + Self { + constraint_idx, + kind, + } + } + + fn compute(&self, step: &TableView) -> FieldElement + where + F: IsSubFieldOf, + E: IsField, + { + let one = FieldElement::::one(); + let mu_read = step.get_main_evaluation_element(0, cols::MU_READ).clone(); + let mu_write = step.get_main_evaluation_element(0, cols::MU_WRITE).clone(); + let mu_sum = &mu_read + &mu_write; + + match self.kind { + MemwAlignedConstraintKind::MuSumIsBit => &mu_sum * (&one - &mu_sum), + MemwAlignedConstraintKind::W2ImpliesMuSum => { + let write2 = step.get_main_evaluation_element(0, cols::WRITE2).clone(); + let write4 = step.get_main_evaluation_element(0, cols::WRITE4).clone(); + let write8 = step.get_main_evaluation_element(0, cols::WRITE8).clone(); + let w2 = write2 + write4 + write8; + &w2 * (&one - &mu_sum) + } + } + } +} + +impl TransitionConstraint for MemwAlignedConstraint { + fn 
degree(&self) -> usize { + 2 + } + + fn constraint_idx(&self) -> usize { + self.constraint_idx + } + + fn end_exemptions(&self) -> usize { + 0 + } + + fn evaluate( + &self, + evaluation_context: &TransitionEvaluationContext, + transition_evaluations: &mut [FieldElement], + ) { + match evaluation_context { + TransitionEvaluationContext::Prover { + frame, + periodic_values: _, + rap_challenges: _, + .. + } => { + let v = self.compute(frame.get_evaluation_step(0)); + transition_evaluations[self.constraint_idx] = v.to_extension(); + } + TransitionEvaluationContext::Verifier { + frame, + periodic_values: _, + rap_challenges: _, + .. + } => { + let v = self.compute(frame.get_evaluation_step(0)); + transition_evaluations[self.constraint_idx] = v; + } + } + } +} + +/// Creates all constraints for the MEMW_A table (2 total). +pub fn constraints() -> Vec>> { + vec![ + Box::new(MemwAlignedConstraint::new( + MemwAlignedConstraintKind::MuSumIsBit, + 0, + )), + Box::new(MemwAlignedConstraint::new( + MemwAlignedConstraintKind::W2ImpliesMuSum, + 1, + )), + ] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_memw_aligned_trace_generation() { + let ops = vec![ + MemwOperation::new(true, 4, [42, 0, 0, 0, 0, 0, 0, 0], 100, 2, true) + .with_old([42, 0, 0, 0, 0, 0, 0, 0], [50, 50, 0, 0, 0, 0, 0, 0]), + MemwOperation::new(false, 0x1000, [1, 2, 3, 4, 0, 0, 0, 0], 200, 4, false) + .with_old([0; 8], [100; 8]), + ]; + + let trace = generate_memw_aligned_trace(&ops); + assert_eq!(trace.num_cols(), cols::NUM_COLUMNS); + assert!(trace.num_rows() >= 2); + + // Check address decomposition for op[1]: addr = 0x1000 + // high = 0, mid = 0, low[1] = 0x10, low[0] = 0x00 + assert_eq!(*trace.get_main(1, cols::BASE_ADDRESS_HIGH), FE::from(0u64)); + assert_eq!(*trace.get_main(1, cols::BASE_ADDRESS_MID), FE::from(0u64)); + assert_eq!( + *trace.get_main(1, cols::BASE_ADDRESS_LOW[1]), + FE::from(0x10u64) + ); + assert_eq!( + *trace.get_main(1, cols::BASE_ADDRESS_LOW[0]), + 
FE::from(0x00u64) + ); + } +} diff --git a/prover/src/tables/mod.rs b/prover/src/tables/mod.rs index 2000ae414..4c1e1ddb0 100644 --- a/prover/src/tables/mod.rs +++ b/prover/src/tables/mod.rs @@ -13,7 +13,8 @@ //! //! ## Memory Tables //! -//! - **MEMW**: Memory word read/write table +//! - **MEMW**: Memory word read/write table (unaligned/split-timestamp path, 49 cols, 26 interactions) +//! - **MEMW_A**: Memory word read/write table (aligned fast path, 30 cols, 22 interactions) //! - **LOAD**: Memory load with extension table //! - **PAGE**: Paged memory init/final table (one per used page) //! - **REGISTER**: Register init/final table (32 registers × 8 bytes = 256 rows) @@ -30,6 +31,7 @@ pub mod halt; pub mod load; pub mod lt; pub mod memw; +pub mod memw_aligned; pub mod mul; pub mod page; pub mod register; @@ -41,22 +43,22 @@ pub use types::BusId; /// Per-table maximum rows, sized so each chunk uses roughly the same memory. /// /// Effective width = main_cols + 3 × bus_interactions (extension field = 3× cost). -/// MEMW (effective width 241) at 2^19 is the baseline; other tables are scaled -/// proportionally: max_rows = (241 × 2^19) / effective_width, rounded to 2^N. 
/// -/// | Table | Main | Bus | Eff.width | Max rows | -/// |--------|------|-----|-----------|----------| -/// | MEMW | 70 | 57 | 241 | 2^19 | -/// | CPU | 74 | 40 | 194 | 2^19 | -/// | DVRM | 34 | 34 | 136 | 2^19 | -/// | MUL | 26 | 16 | 74 | 2^20 | -/// | LT | 15 | 9 | 42 | 2^21 | -/// | SHIFT | 27 | 15 | 72 | 2^20 | -/// | LOAD | 18 | 5 | 33 | 2^21 | -/// | BRANCH | 14 | 6 | 32 | 2^21 | +/// | Table | Main | Bus | Eff.width | Max rows | +/// |---------|------|-----|-----------|----------| +/// | MEMW | 49 | 26 | 127 | 2^19 | +/// | MEMW_A | 30 | 22 | 96 | 2^20 | +/// | CPU | 74 | 40 | 194 | 2^19 | +/// | DVRM | 34 | 34 | 136 | 2^19 | +/// | MUL | 26 | 16 | 74 | 2^20 | +/// | LT | 15 | 9 | 42 | 2^21 | +/// | SHIFT | 27 | 15 | 72 | 2^20 | +/// | LOAD | 18 | 5 | 33 | 2^21 | +/// | BRANCH | 14 | 6 | 32 | 2^21 | pub mod max_rows { pub const CPU: usize = 1 << 19; // 524,288 — eff. width 194 - pub const MEMW: usize = 1 << 19; // 524,288 — eff. width 241 (baseline) + pub const MEMW: usize = 1 << 19; // 524,288 — eff. width 127 + pub const MEMW_A: usize = 1 << 20; // 1,048,576 — eff. width 96 pub const DVRM: usize = 1 << 19; // 524,288 — eff. width 136 pub const MUL: usize = 1 << 20; // 1,048,576 — eff. width 74 pub const LT: usize = 1 << 21; // 2,097,152 — eff. 
width 42 @@ -73,6 +75,7 @@ pub mod max_rows { pub struct MaxRowsConfig { pub cpu: usize, pub memw: usize, + pub memw_aligned: usize, pub dvrm: usize, pub mul: usize, pub lt: usize, @@ -86,6 +89,7 @@ impl Default for MaxRowsConfig { Self { cpu: max_rows::CPU, memw: max_rows::MEMW, + memw_aligned: max_rows::MEMW_A, dvrm: max_rows::DVRM, mul: max_rows::MUL, lt: max_rows::LT, @@ -103,6 +107,7 @@ impl MaxRowsConfig { Self { cpu: 1 << 5, memw: 1 << 5, + memw_aligned: 1 << 5, dvrm: 1 << 5, mul: 1 << 5, lt: 1 << 5, diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index 31e4a1d5f..e7773db57 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -10,9 +10,9 @@ //! ```text //! PHASE 0: ELF → DECODE, MEMORY_INIT (preprocessed tables) //! PHASE 1: Logs → CPU ops -//! PHASE 2: CPU ops → MEMW, LOAD, LT, Bitwise (with state tracking for MEMW/LOAD) -//! PHASE 3: MEMW → LT ops (timestamp ordering, overflow checks) -//! PHASE 4: LT, MEMW → Bitwise lookups +//! PHASE 2: CPU ops → MEMW, MEMW_A, LOAD, LT, Bitwise (with state tracking for MEMW/LOAD) +//! PHASE 3: MEMW/MEMW_A → LT ops (timestamp ordering) +//! PHASE 4: LT, MEMW_A → Bitwise lookups //! PHASE 5: Generate all traces //! ``` //! @@ -43,6 +43,7 @@ use super::halt; use super::load::{self, LoadOperation}; use super::lt::{self, LtOperation}; use super::memw::{self, MemwOperation}; +use super::memw_aligned; use super::mul::{self, MulOperation}; use super::page::{self, FinalByteState, FinalStateMap, PageConfig}; use super::register::{self, FinalRegisterStateMap, FinalRegisterWordState}; @@ -692,11 +693,10 @@ fn collect_halt_ops(register_state: &mut RegisterState) -> Vec { // Phase 3: MEMW → LT // ============================================================================= -/// Collects LT operations from MEMW for timestamp ordering and overflow checks. +/// Collects LT operations from MEMW for timestamp ordering. 
/// /// From spec memw.md: /// - C7-C10: old_timestamp[i] < timestamp (based on width) -/// - R1-R3: base_address < base_address + offset (overflow checks) /// /// Returns: Vec of LT operations fn collect_lt_from_memw(memw_ops: &[MemwOperation]) -> Vec { @@ -743,25 +743,43 @@ fn collect_lt_from_memw(memw_ops: &[MemwOperation]) -> Vec { )); } } + } - // R1-R3: Address overflow checks (unconditional per MEMW-CR13/14/15) - // If overflow occurs, LT returns lt=0 and the constraint (expecting lt=1) - // rejects the proof via value mismatch. - if memw_op.width == 2 { - let addr_plus_1 = memw_op.base_address.wrapping_add(1); - lt_ops.push(LtOperation::new(memw_op.base_address, addr_plus_1, false)); - } - if memw_op.width == 4 { - let addr_plus_3 = memw_op.base_address.wrapping_add(3); - lt_ops.push(LtOperation::new(memw_op.base_address, addr_plus_3, false)); - } - if memw_op.width == 8 { - let addr_plus_7 = memw_op.base_address.wrapping_add(7); - lt_ops.push(LtOperation::new(memw_op.base_address, addr_plus_7, false)); + lt_ops +} + +/// Collects LT operations from MEMW_A for timestamp ordering. +/// +/// Each aligned operation has a single old_timestamp < timestamp check. +fn collect_lt_from_memw_aligned(memw_aligned_ops: &[MemwOperation]) -> Vec { + memw_aligned_ops + .iter() + .map(|op| LtOperation::new(op.old_timestamp[0], op.timestamp, false)) + .collect() +} + +/// Checks whether a MEMW operation qualifies for the aligned fast path (MEMW_A). +/// +/// An operation is aligned if: +/// 1. For width > 1: base_address is aligned to width (low bits are zero) +/// 2. 
All accessed bytes share the same old_timestamp +fn is_aligned_op(op: &MemwOperation) -> bool { + let low = (op.base_address & 0xFFFF_FFFF) as u32; + let width = op.width as u32; + + // Check alignment (trivially true for width=1) + if width > 1 && (low & (width - 1)) != 0 { + return false; + } + + // Check uniform old_timestamp + for i in 1..op.width as usize { + if op.old_timestamp[i] != op.old_timestamp[0] { + return false; } } - lt_ops + true } // ============================================================================= @@ -1062,25 +1080,45 @@ fn collect_bitwise_from_dvrm(dvrm_ops: &[(DvrmOperation, bool)]) -> Vec Vec { - let mut bitwise_ops = Vec::with_capacity(memw_ops.len() * 28); // 7 addresses * 4 halfwords +/// Per operation: +/// - 1 IS_HALFWORD for base_address_mid +/// - 1 IS_BYTE for base_address_low[1] +/// - 1 AND_BYTE for alignment check (low[0] & mask == 0) +fn collect_bitwise_from_memw_aligned(ops: &[MemwOperation]) -> Vec { + let mut bitwise_ops = Vec::with_capacity(ops.len() * 3); + + for op in ops { + let low_0 = (op.base_address & 0xFF) as u8; + let low_1 = ((op.base_address >> 8) & 0xFF) as u8; + let mid = ((op.base_address >> 16) & 0xFFFF) as u16; + let mask: u8 = match op.width { + 2 => 1, + 4 => 3, + 8 => 7, + _ => 0, + }; - for memw_op in memw_ops { - for i in 0..7u64 { - let addr_add = memw_op.base_address.wrapping_add(i + 1); - // Extract 4 halfwords (DWordHL packing) - for shift in [0, 16, 32, 48] { - let half = ((addr_add >> shift) & 0xFFFF) as u16; - bitwise_ops.push(BitwiseOperation::halfword( - BitwiseOperationType::IsHalf, - (half & 0xFF) as u8, - (half >> 8) as u8, - )); - } - } + // IS_HALFWORD[mid] + bitwise_ops.push(BitwiseOperation::halfword( + BitwiseOperationType::IsHalf, + (mid & 0xFF) as u8, + (mid >> 8) as u8, + )); + + // IS_BYTE[low_1] + bitwise_ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + low_1, + )); + + // AND_BYTE[low_0, mask] → expects result 0 + 
bitwise_ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::AndByte, + low_0, + mask, + )); } bitwise_ops @@ -1440,6 +1478,9 @@ pub struct Traces { /// MEMW memory/register read/write traces (split into chunks of max_rows::MEMW) pub memws: Vec>, + /// MEMW_A aligned memory/register read/write traces (split into chunks of max_rows::MEMW_A) + pub memw_aligneds: Vec>, + /// LOAD memory load with extension traces (split into chunks of max_rows::LOAD) pub loads: Vec>, @@ -1491,6 +1532,7 @@ impl Traces { cpu: self.cpus.len(), lt: self.lts.len(), memw: self.memws.len(), + memw_aligned: self.memw_aligneds.len(), load: self.loads.len(), mul: self.muls.len(), dvrm: self.dvrms.len(), @@ -1652,6 +1694,10 @@ impl Traces { let halt_memw_ops = collect_halt_ops(&mut register_state); memw_ops.extend(halt_memw_ops); + // Route MEMW operations: aligned ops → MEMW_A, rest → MEMW + let (memw_aligned_ops, memw_ops): (Vec<_>, Vec<_>) = + memw_ops.into_iter().partition(is_aligned_op); + // Collect BRANCH operations from CPU ops where branch_cond = true let branch_ops: Vec = cpu_ops .iter() @@ -1715,15 +1761,16 @@ impl Traces { } // ===================================================================== - // PHASE 3: MEMW → LT (timestamp ordering and overflow checks) + // PHASE 3: MEMW/MEMW_A → LT (timestamp ordering) // ===================================================================== lt_ops.extend(collect_lt_from_memw(&memw_ops)); + lt_ops.extend(collect_lt_from_memw_aligned(&memw_aligned_ops)); // ===================================================================== // PHASE 4: All → Bitwise lookups // ===================================================================== bitwise_ops.extend(collect_bitwise_from_lt(<_ops)); - bitwise_ops.extend(collect_bitwise_from_memw(&memw_ops)); + bitwise_ops.extend(collect_bitwise_from_memw_aligned(&memw_aligned_ops)); bitwise_ops.extend(collect_bitwise_from_mul(&mul_ops)); bitwise_ops.extend(collect_bitwise_from_dvrm(&dvrm_ops)); 
bitwise_ops.extend(collect_bitwise_from_branch(&branch_ops)); @@ -1758,6 +1805,11 @@ impl Traces { let cpus = chunk_and_generate(&cpu_ops, max_rows.cpu, cpu::generate_cpu_trace); let memws = chunk_and_generate(&memw_ops, max_rows.memw, memw::generate_memw_trace); + let memw_aligneds = chunk_and_generate( + &memw_aligned_ops, + max_rows.memw_aligned, + memw_aligned::generate_memw_aligned_trace, + ); let loads = chunk_and_generate(&load_ops, max_rows.load, load::generate_load_trace); let lts = chunk_and_generate(<_ops, max_rows.lt, lt::generate_lt_trace); let shifts = chunk_and_generate(&shift_ops, max_rows.shift, shift::generate_shift_trace); @@ -1828,6 +1880,7 @@ impl Traces { lts, shifts, memws, + memw_aligneds, loads, decode, muls, @@ -1873,6 +1926,10 @@ impl Traces { let halt_memw_ops = collect_halt_ops(&mut register_state); memw_ops.extend(halt_memw_ops); + // Route MEMW operations: aligned ops → MEMW_A, rest → MEMW + let (memw_aligned_ops, memw_ops): (Vec<_>, Vec<_>) = + memw_ops.into_iter().partition(is_aligned_op); + // Collect MUL operations from CPU ops where op_mul = true let mut mul_ops: Vec<(MulOperation, bool)> = cpu_ops .iter() @@ -1936,15 +1993,16 @@ impl Traces { } // ===================================================================== - // PHASE 3: MEMW → LT (timestamp ordering and overflow checks) + // PHASE 3: MEMW/MEMW_A → LT (timestamp ordering) // ===================================================================== lt_ops.extend(collect_lt_from_memw(&memw_ops)); + lt_ops.extend(collect_lt_from_memw_aligned(&memw_aligned_ops)); // ===================================================================== // PHASE 4: All → Bitwise lookups // ===================================================================== bitwise_ops.extend(collect_bitwise_from_lt(<_ops)); - bitwise_ops.extend(collect_bitwise_from_memw(&memw_ops)); + bitwise_ops.extend(collect_bitwise_from_memw_aligned(&memw_aligned_ops)); 
bitwise_ops.extend(collect_bitwise_from_mul(&mul_ops)); bitwise_ops.extend(collect_bitwise_from_dvrm(&dvrm_ops)); bitwise_ops.extend(collect_bitwise_from_branch(&branch_ops)); @@ -1976,6 +2034,11 @@ impl Traces { let cpus = chunk_and_generate(&cpu_ops, max_rows.cpu, cpu::generate_cpu_trace); let memws = chunk_and_generate(&memw_ops, max_rows.memw, memw::generate_memw_trace); + let memw_aligneds = chunk_and_generate( + &memw_aligned_ops, + max_rows.memw_aligned, + memw_aligned::generate_memw_aligned_trace, + ); let loads = chunk_and_generate(&load_ops, max_rows.load, load::generate_load_trace); let lts = chunk_and_generate(<_ops, max_rows.lt, lt::generate_lt_trace); let shifts = chunk_and_generate(&shift_ops, max_rows.shift, shift::generate_shift_trace); @@ -2032,6 +2095,7 @@ impl Traces { lts, shifts, memws, + memw_aligneds, loads, decode, muls, diff --git a/prover/src/test_utils.rs b/prover/src/test_utils.rs index 113b0ce2b..4ecae60f9 100644 --- a/prover/src/test_utils.rs +++ b/prover/src/test_utils.rs @@ -50,6 +50,10 @@ use crate::tables::lt::{LtOperation, bus_interactions as lt_bus_interactions, co use crate::tables::memw::{ bus_interactions as memw_bus_interactions, cols as memw_cols, constraints as memw_constraints, }; +use crate::tables::memw_aligned::{ + bus_interactions as memw_aligned_bus_interactions, cols as memw_aligned_cols, + constraints as memw_aligned_constraints, +}; use crate::tables::mul::{bus_interactions as mul_bus_interactions, cols as mul_cols}; use crate::tables::page::{bus_interactions as page_bus_interactions, cols as page_cols}; use crate::tables::register::{ @@ -352,20 +356,6 @@ pub fn collect_bitwise_ops_from_load( .collect() } -/// Collect LT operations from MEMW operations. 
-/// -/// The MEMW table sends LT lookups for: -/// - Timestamp ordering: old_timestamp[i] < timestamp -/// - Overflow checking: base_address < base_address + offset -pub fn collect_lt_lookups_from_memw( - memw_ops: &[crate::tables::memw::MemwOperation], -) -> Vec { - memw_ops - .iter() - .flat_map(|op| op.collect_lt_lookups()) - .collect() -} - // ============================================================================= // Minimal Trace Generation (for testing/benchmarking only) // ============================================================================= @@ -570,6 +560,24 @@ pub fn create_memw_air(proof_options: &ProofOptions) -> VmAir { .with_name("MEMW") } +/// Create MEMW_A (aligned) AIR with constraints and bus interactions. +pub fn create_memw_aligned_air(proof_options: &ProofOptions) -> VmAir { + let transition_constraints = memw_aligned_constraints(); + + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: memw_aligned_bus_interactions(), + }; + + AirWithBuses::new( + memw_aligned_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("MEMW_A") +} + /// Create LOAD AIR with constraints and bus interactions. 
pub fn create_load_air(proof_options: &ProofOptions) -> VmAir { let transition_constraints = load_constraints(); diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs index 21383185f..f3a43e795 100644 --- a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -1109,25 +1109,13 @@ fn test_debug_memory_tokens_sb_sh() { let val1 = memw.main_table.get(row, memw_cols::VALUE[1]).to_raw(); let old1 = memw.main_table.get(row, memw_cols::OLD[1]).to_raw(); - // address_add(0) = base + 1, stored as DWordHL - let addr1_lo0 = memw + // address_add(0) = base + 1, now virtual (computed from base + overflow) + let overflow0 = memw .main_table - .get(row, memw_cols::address_add(0)[0]) + .get(row, memw_cols::ADD_LIMB_OVERFLOW[0]) .to_raw(); - let addr1_lo1 = memw - .main_table - .get(row, memw_cols::address_add(0)[1]) - .to_raw(); - let addr1_hi0 = memw - .main_table - .get(row, memw_cols::address_add(0)[2]) - .to_raw(); - let addr1_hi1 = memw - .main_table - .get(row, memw_cols::address_add(0)[3]) - .to_raw(); - let addr1_lo = addr1_lo0 + (addr1_lo1 << 16); - let addr1_hi = addr1_hi0 + (addr1_hi1 << 16); + let addr1_lo = base_lo + 1 - overflow0 * (1u64 << 32); + let addr1_hi = base_hi + overflow0; // CM16: SEND old token for byte 1 let send_token1: Token = (is_reg, addr1_lo, addr1_hi, old_ts1_lo, old_ts1_hi, old1); @@ -1514,6 +1502,7 @@ fn test_verify_rejects_zero_table_counts() { cpu: 0, lt: 0, memw: 0, + memw_aligned: 0, load: 0, mul: 0, dvrm: 0, @@ -1580,6 +1569,7 @@ fn test_crafted_zero_count_proof_must_not_verify() { cpu: 0, lt: 0, memw: 0, + memw_aligned: 0, load: 0, mul: 0, dvrm: 0, diff --git a/prover/src/tests/trace_builder_tests.rs b/prover/src/tests/trace_builder_tests.rs index e4c6c5992..ba2de9c3b 100644 --- a/prover/src/tests/trace_builder_tests.rs +++ b/prover/src/tests/trace_builder_tests.rs @@ -3,7 +3,7 @@ use crate::tables::bitwise; use crate::tables::cpu::cols; use crate::tables::lt; -use 
crate::tables::memw; +use crate::tables::memw_aligned; use crate::tables::trace_builder::Traces; use crate::tables::types::FE; use executor::vm::instruction::decoding::{ArithOp, Comparison, Instruction}; @@ -393,29 +393,44 @@ fn test_memw_generated_from_register_ops() { let traces = Traces::from_logs(&logs, instructions, &Default::default()).unwrap(); - // MEMW table should have register operations + // MEMW_A table should have register operations (register ops are always aligned) // First instruction generates: M1 (read x2), M3 (read x3), M5 (write x1) assert!( - traces.memws[0].main_table.height >= 3, - "MEMW should have at least 3 rows for register ops" + !traces.memw_aligneds.is_empty(), + "MEMW_A should have at least one chunk for register ops" + ); + assert!( + traces.memw_aligneds[0].main_table.height >= 3, + "MEMW_A should have at least 3 rows for register ops" ); - // Find the register write to x1 (address = 2 * 1 = 2, is_register = 1) + // Find the register write to x1 in MEMW_A + // Register address for x1 = 2*1 = 2, decomposed: high=0, mid=0, low=[2,0] let mut found_write = false; - for row_idx in 0..traces.memws[0].main_table.height { - let row = traces.memws[0].main_table.get_row(row_idx); - // Check for register write: is_register=1, address=2 (x1), mu_write=1 - if row[memw::cols::IS_REGISTER] == FE::one() - && row[memw::cols::BASE_ADDRESS_0] == FE::from(2u64) - && row[memw::cols::MU_WRITE] == FE::one() - { - // Check value is 300 (lo32=300, hi32=0) - assert_eq!(row[memw::cols::VALUE[0]], FE::from(300u64)); - found_write = true; + for chunk in &traces.memw_aligneds { + for row_idx in 0..chunk.main_table.height { + let row = chunk.main_table.get_row(row_idx); + // Check for register write: is_register=1, base_address_low[0]=2, mu_write=1 + if row[memw_aligned::cols::IS_REGISTER] == FE::one() + && row[memw_aligned::cols::BASE_ADDRESS_LOW[0]] == FE::from(2u64) + && row[memw_aligned::cols::BASE_ADDRESS_MID] == FE::zero() + && 
row[memw_aligned::cols::BASE_ADDRESS_HIGH] == FE::zero() + && row[memw_aligned::cols::MU_WRITE] == FE::one() + { + // Check value is 300 (lo32 word for register DWordWL packing) + assert_eq!(row[memw_aligned::cols::VALUE[0]], FE::from(300u64)); + found_write = true; + break; + } + } + if found_write { break; } } - assert!(found_write, "Register write to x1 not found in MEMW table"); + assert!( + found_write, + "Register write to x1 not found in MEMW_A table" + ); } // ============================================================================= From 726542945b60427aecf65e88a8ee2ac5f3dcd285 Mon Sep 17 00:00:00 2001 From: diegokingston Date: Fri, 13 Mar 2026 17:13:04 -0300 Subject: [PATCH 2/3] perf: pre-allocate bus interaction vectors in MEMW and MEMW_A Use Vec::with_capacity for known interaction counts (26 and 22) to avoid reallocations. Remove unnecessary clone on last mu_sum use. --- prover/src/tables/memw.rs | 2 +- prover/src/tables/memw_aligned.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/prover/src/tables/memw.rs b/prover/src/tables/memw.rs index 3b1959275..d91fa1bbb 100644 --- a/prover/src/tables/memw.rs +++ b/prover/src/tables/memw.rs @@ -239,7 +239,7 @@ pub fn generate_memw_trace( /// - 16 Memory bus tokens (read old + write new per byte) /// - 2 MEMW output interactions (read + write from CPU) pub fn bus_interactions() -> Vec { - let mut interactions = Vec::new(); + let mut interactions = Vec::with_capacity(26); // ------------------------------------------------------------------------- // Memory bus interactions (16 total) diff --git a/prover/src/tables/memw_aligned.rs b/prover/src/tables/memw_aligned.rs index 8e992060f..40ed2a4f4 100644 --- a/prover/src/tables/memw_aligned.rs +++ b/prover/src/tables/memw_aligned.rs @@ -138,7 +138,7 @@ pub fn generate_memw_aligned_trace( // ========================================================================= pub fn bus_interactions() -> Vec { - let mut interactions = Vec::new(); + let 
mut interactions = Vec::with_capacity(22); let mu_sum = Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE); @@ -279,7 +279,7 @@ pub fn bus_interactions() -> Vec { // CM17: memory[is_register, base_address, timestamp, value[0]] with -μ_sum interactions.push(BusInteraction::receiver( BusId::Memory, - mu_sum.clone(), + mu_sum, vec![ BusValue::Packed { start_column: cols::IS_REGISTER, From 66ee8c66017c5560fb676f72c8ea8e1a6cdc128e Mon Sep 17 00:00:00 2001 From: diegokingston Date: Mon, 16 Mar 2026 20:33:06 -0300 Subject: [PATCH 3/3] fix: remove redundant MEMW_A range checks and update stale comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove IS_HALFWORD[mid] and IS_BYTE[low[1]] bus interactions from MEMW_A: these are spec assumptions (MEMW_A-A2, MEMW_A-A3) satisfied by the CPU's IS_BYTE range checks on its byte decomposition plus Memw bus propagation. MID and LOW[1] only appear inside the linear combination 2^16*MID + 2^8*LOW[1] + LOW[0], which the bus constrains to equal the CPU's base_address_0, so individual range checks are redundant. Saves 2 bus interactions per MEMW_A row. - Remove corresponding IS_HALFWORD and IS_BYTE bitwise lookups from collect_bitwise_from_memw_aligned in trace_builder. - Update stale LT comment labels in memw.rs: C7-C10 → MEMW-C4 through MEMW-C7 to match current spec numbering. 
--- prover/src/tables/memw.rs | 10 ++++----- prover/src/tables/memw_aligned.rs | 36 +++++++----------------------- prover/src/tables/trace_builder.rs | 20 ++++------------- 3 files changed, 17 insertions(+), 49 deletions(-) diff --git a/prover/src/tables/memw.rs b/prover/src/tables/memw.rs index d91fa1bbb..b89cd7a9c 100644 --- a/prover/src/tables/memw.rs +++ b/prover/src/tables/memw.rs @@ -752,10 +752,10 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // LT interactions for timestamp ordering (constraints 7-10) + // LT interactions for timestamp ordering (MEMW-C4 through MEMW-C7) // ------------------------------------------------------------------------- - // C7: LT[1; old_timestamp[0], timestamp] with μ_sum + // MEMW-C4: LT[1; old_timestamp[0], timestamp] with μ_sum interactions.push(BusInteraction::sender( BusId::Lt, Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE), @@ -773,7 +773,7 @@ pub fn bus_interactions() -> Vec { ], )); - // C8: LT[1; old_timestamp[1], timestamp] with w2 + // MEMW-C5: LT[1; old_timestamp[1], timestamp] with w2 interactions.push(BusInteraction::sender( BusId::Lt, Multiplicity::Linear(vec![ @@ -804,7 +804,7 @@ pub fn bus_interactions() -> Vec { ], )); - // C9: LT[1; old_timestamp[i], timestamp] for i ∈ [2,3] with w4 + // MEMW-C6: LT[1; old_timestamp[i], timestamp] for i ∈ [2,3] with w4 for i in 2..4 { interactions.push(BusInteraction::sender( BusId::Lt, @@ -824,7 +824,7 @@ pub fn bus_interactions() -> Vec { )); } - // C10: LT[1; old_timestamp[i], timestamp] for i ∈ [4,7] with write8 + // MEMW-C7: LT[1; old_timestamp[i], timestamp] for i ∈ [4,7] with write8 for i in 4..8 { interactions.push(BusInteraction::sender( BusId::Lt, diff --git a/prover/src/tables/memw_aligned.rs b/prover/src/tables/memw_aligned.rs index 40ed2a4f4..dac367a71 100644 --- a/prover/src/tables/memw_aligned.rs +++ b/prover/src/tables/memw_aligned.rs @@ -17,10 +17,8 @@ //! 
- `old_timestamp`: DWordWL (2 cols — single, not 8!) //! - `mu_read`, `mu_write`: multiplicity columns //! -//! ## Bus Interactions (22) -//! - 1 IS_HALFWORD[base_address_mid] -//! - 1 IS_BYTE[base_address_low[1]] -//! - 1 AND_BYTE[base_address_low[0], mask] → 0 (alignment check) +//! ## Bus Interactions (20) +//! - 1 AND_BYTE[base_address_low[0], mask] → 0 (alignment check + implicit IS_BYTE) //! - 1 LT[old_timestamp, timestamp, 0] → 1 //! - 16 Memory bus tokens //! - 2 MEMW output interactions (read + write) @@ -138,33 +136,15 @@ pub fn generate_memw_aligned_trace( // ========================================================================= pub fn bus_interactions() -> Vec { - let mut interactions = Vec::with_capacity(22); + let mut interactions = Vec::with_capacity(20); let mu_sum = Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE); - // ------------------------------------------------------------------------- - // IS_HALFWORD[base_address_mid] with μ_sum - // ------------------------------------------------------------------------- - interactions.push(BusInteraction::sender( - BusId::IsHalfword, - mu_sum.clone(), - vec![BusValue::Packed { - start_column: cols::BASE_ADDRESS_MID, - packing: Packing::Direct, - }], - )); - - // ------------------------------------------------------------------------- - // IS_BYTE[base_address_low[1]] with μ_sum - // ------------------------------------------------------------------------- - interactions.push(BusInteraction::sender( - BusId::IsByte, - mu_sum.clone(), - vec![BusValue::Packed { - start_column: cols::BASE_ADDRESS_LOW[1], - packing: Packing::Direct, - }], - )); + // MEMW_A-A2 (IS_HALF[mid]) and MEMW_A-A3 (IS_BYTE[low]) are assumptions: + // the CPU constrains base_address_0 ∈ [0, 2^32-1] via IS_BYTE on its bytes, + // and the Memw bus forces 2^16*MID + 2^8*LOW[1] + LOW[0] = base_address_0. 
+ // Since MID and LOW[1] only appear inside this linear combination (never + // independently), their individual range checks are redundant. // ------------------------------------------------------------------------- // AND_BYTE[base_address_low[0], mask] → 0 with μ_sum diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index e7773db57..f0a10c8bc 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -1087,12 +1087,13 @@ fn collect_bitwise_from_dvrm(dvrm_ops: &[(DvrmOperation, bool)]) -> Vec Vec { - let mut bitwise_ops = Vec::with_capacity(ops.len() * 3); + // Only AND_BYTE for alignment check (also implicitly IS_BYTE-checks low[0]). + // IS_HALF[mid] and IS_BYTE[low[1]] are assumptions satisfied by the CPU's + // IS_BYTE range checks on its byte decomposition + Memw bus propagation. + let mut bitwise_ops = Vec::with_capacity(ops.len()); for op in ops { let low_0 = (op.base_address & 0xFF) as u8; - let low_1 = ((op.base_address >> 8) & 0xFF) as u8; - let mid = ((op.base_address >> 16) & 0xFFFF) as u16; let mask: u8 = match op.width { 2 => 1, 4 => 3, @@ -1100,19 +1101,6 @@ fn collect_bitwise_from_memw_aligned(ops: &[MemwOperation]) -> Vec 0, }; - // IS_HALFWORD[mid] - bitwise_ops.push(BitwiseOperation::halfword( - BitwiseOperationType::IsHalf, - (mid & 0xFF) as u8, - (mid >> 8) as u8, - )); - - // IS_BYTE[low_1] - bitwise_ops.push(BitwiseOperation::single_byte( - BitwiseOperationType::IsByte, - low_1, - )); - // AND_BYTE[low_0, mask] → expects result 0 bitwise_ops.push(BitwiseOperation::byte_op( BitwiseOperationType::AndByte,