diff --git a/prover/src/tables/memw.rs b/prover/src/tables/memw.rs index 3bbd96e2f..3d05d7d0a 100644 --- a/prover/src/tables/memw.rs +++ b/prover/src/tables/memw.rs @@ -1,32 +1,30 @@ -//! MEMW (Memory Write/Read) table. +//! MEMW (Memory Write/Read) table — unaligned / split-timestamp path. //! -//! This table handles memory and register read/write operations with timestamp-based -//! consistency checking. +//! This table handles memory and register read/write operations where bytes may +//! have different old_timestamps or the access is unaligned. +//! +//! ## Column layout (49 columns) //! -//! ## Inputs //! - `is_register`: Bit (1 = register access, 0 = memory access) -//! - `base_address`: DWordWL (64-bit address) +//! - `base_address`: DWordWL (64-bit address, 2 cols) //! - `value[8]`: BaseField[8] (8 bytes to write) -//! - `timestamp`: DWordWL (64-bit timestamp) +//! - `timestamp`: DWordWL (64-bit timestamp, 2 cols) //! - `write2/4/8`: Bit (access width flags) -//! -//! ## Output //! - `old[8]`: BaseField[8] (previous values at address) -//! -//! ## Auxiliary -//! - `address_add[7]`: DWordHL[7] (base_address + 1..7) -//! - `old_timestamp[8]`: DWordWL[8] (previous timestamps) +//! - `add_limb_overflow[7]`: Bit[7] (carry flags for base_address + i) +//! - `old_timestamp[8]`: DWordWL[8] (previous timestamps, 16 cols) +//! - `mu_read`, `mu_write`: multiplicity columns //! //! ## Virtual (computed inline) +//! - `address_add[i]` = (base_address_0 + i+1 - 2^32 * overflow[i], base_address_1 + overflow[i]) //! - `w2`: write2 + write4 + write8 (writing at least 2 bytes) //! - `w4`: write4 + write8 (writing at least 4 bytes) //! - `μ_sum`: μ_read + μ_write //! -//! ## Bus Interactions -//! - Receiver: MEMW (from CPU for LOAD/STORE operations) -//! - Sender: IS_HALFWORD (range checks for address_add) -//! - Sender: LT (timestamp ordering checks) -//! - Sender/Receiver: Memory bus (internal read/write consistency) +//! ## Bus Interactions (26) +//! 
- 8 LT timestamp checks (old_timestamp[i] < timestamp) +//! - 16 Memory bus tokens (read old + write new, per byte) +//! - 2 MEMW output interactions (read + write, from CPU) use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; @@ -37,14 +35,14 @@ use stark::trace::TraceTable; use stark::traits::TransitionEvaluationContext; use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; -use crate::constraints::templates::{AddConstraint, AddOperand}; +use crate::constraints::templates::IsBitConstraint; /// Maximum number of rows per MEMW table chunk. /// If operations exceed this, the trace is split into multiple tables. pub const MAX_ROWS: usize = super::max_rows::MEMW; // ========================================================================= -// Column indices for MEMW table +// Column indices for MEMW table (49 columns) // ========================================================================= /// Column definitions for the MEMW table. @@ -74,31 +72,21 @@ pub mod cols { pub const OLD: [usize; 8] = [16, 17, 18, 19, 20, 21, 22, 23]; // Auxiliary columns - /// address_add[7]: each is DWordHL (4 halfwords = 4 columns) - /// Total: 7 * 4 = 28 columns - pub const ADDRESS_ADD_START: usize = 24; - // address_add[i] uses columns ADDRESS_ADD_START + i*4 .. ADDRESS_ADD_START + i*4 + 4 + /// add_limb_overflow[7]: Bit columns indicating carry when adding i+1 to base_address_0 + pub const ADD_LIMB_OVERFLOW: [usize; 7] = [24, 25, 26, 27, 28, 29, 30]; /// old_timestamp[8]: each is DWordWL (2 words = 2 columns) /// Total: 8 * 2 = 16 columns - pub const OLD_TIMESTAMP_START: usize = 52; // 24 + 28 - // old_timestamp[i] uses columns OLD_TIMESTAMP_START + i*2 .. 
OLD_TIMESTAMP_START + i*2 + 2 + pub const OLD_TIMESTAMP_START: usize = 31; // Multiplicity columns /// μ_read: Whether we are performing a read - pub const MU_READ: usize = 68; // 52 + 16 + pub const MU_READ: usize = 47; /// μ_write: Whether we are performing a write - pub const MU_WRITE: usize = 69; + pub const MU_WRITE: usize = 48; /// Total number of columns - /// Note: w2, w4, μ_sum are now computed inline via Multiplicity::Linear/Sum - pub const NUM_COLUMNS: usize = 70; - - /// Helper to get address_add[i] column indices (4 halfwords each) - pub fn address_add(i: usize) -> [usize; 4] { - let base = ADDRESS_ADD_START + i * 4; - [base, base + 1, base + 2, base + 3] - } + pub const NUM_COLUMNS: usize = 49; /// Helper to get old_timestamp[i] column indices (2 words each) pub fn old_timestamp(i: usize) -> [usize; 2] { @@ -163,25 +151,12 @@ impl MemwOperation { /// Convert access width to the spec's flag representation (write2, write4, write8). /// - /// The spec uses three flags to encode access width: - /// - `write2`: set if accessing 2+ bytes (width >= 2) - /// - `write4`: set if accessing 4+ bytes (width >= 4) - /// - `write8`: set if accessing 8 bytes (width == 8) - /// - /// This encoding allows computing "at least N bytes" predicates: - /// - w2 (at least 2) = write2 + write4 + write8 - /// - w4 (at least 4) = write4 + write8 - /// /// | Width | write2 | write4 | write8 | /// |-------|--------|--------|--------| /// | 1 | 0 | 0 | 0 | /// | 2 | 1 | 0 | 0 | /// | 4 | 0 | 1 | 0 | /// | 8 | 0 | 0 | 1 | - /// - /// Note: These are "exactly N" semantics per spec, not cumulative. - /// Virtual columns w2 = write2 + write4 + write8 and w4 = write4 + write8 - /// compute "at least N" from these. pub fn write_flags(&self) -> (bool, bool, bool) { match self.width { 1 => (false, false, false), @@ -191,86 +166,6 @@ impl MemwOperation { _ => (false, false, false), } } - - /// Collect LT operations for timestamp ordering and overflow checking. 
- /// - /// Per spec constraints #7-10 and R1-R3: - /// - #7: old_timestamp[0] < timestamp (for all accesses) - /// - #8: old_timestamp[1] < timestamp (for width >= 2) - /// - #9: old_timestamp[2,3] < timestamp (for width >= 4) - /// - #10: old_timestamp[4..7] < timestamp (for width == 8) - /// - R1: base_address < base_address + 1 (overflow check for width >= 2) - /// - R2: base_address < base_address + 3 (overflow check for width >= 4) - /// - R3: base_address < base_address + 7 (overflow check for width == 8) - pub fn collect_lt_lookups(&self) -> Vec { - use super::lt::LtOperation; - - let mut ops = Vec::new(); - - // Constraint 7: old_timestamp[0] < timestamp (always, for any access) - ops.push(LtOperation::new( - self.old_timestamp[0], - self.timestamp, - false, - )); - - // Constraint 8: old_timestamp[1] < timestamp (for width >= 2) - if self.width >= 2 { - ops.push(LtOperation::new( - self.old_timestamp[1], - self.timestamp, - false, - )); - } - - // Constraint 9: old_timestamp[2,3] < timestamp (for width >= 4) - if self.width >= 4 { - ops.push(LtOperation::new( - self.old_timestamp[2], - self.timestamp, - false, - )); - ops.push(LtOperation::new( - self.old_timestamp[3], - self.timestamp, - false, - )); - } - - // Constraint 10: old_timestamp[4..7] < timestamp (for width == 8) - if self.width == 8 { - for i in 4..8 { - ops.push(LtOperation::new( - self.old_timestamp[i], - self.timestamp, - false, - )); - } - } - - // Overflow checks R1-R3: base_address < base_address + offset - // Always generate the LT operation - if overflow occurred, LT will return 0 - // and the constraint (expecting result=1) will fail, rejecting the proof. 
- // R1: for width == 2, check base_address < base_address + 1 - if self.width == 2 { - let addr_plus_1 = self.base_address.wrapping_add(1); - ops.push(LtOperation::new(self.base_address, addr_plus_1, false)); - } - - // R2: for width == 4, check base_address < base_address + 3 - if self.width == 4 { - let addr_plus_3 = self.base_address.wrapping_add(3); - ops.push(LtOperation::new(self.base_address, addr_plus_3, false)); - } - - // R3: for width == 8, check base_address < base_address + 7 - if self.width == 8 { - let addr_plus_7 = self.base_address.wrapping_add(7); - ops.push(LtOperation::new(self.base_address, addr_plus_7, false)); - } - - ops - } } /// Generates the MEMW trace table from a list of operations. @@ -287,7 +182,8 @@ pub fn generate_memw_trace( data[base + cols::IS_REGISTER] = FE::from(op.is_register as u64); // base_address as DWordWL (2 words) - data[base + cols::BASE_ADDRESS_0] = FE::from(op.base_address & 0xFFFF_FFFF); + let base_addr_lo = op.base_address & 0xFFFF_FFFF; + data[base + cols::BASE_ADDRESS_0] = FE::from(base_addr_lo); data[base + cols::BASE_ADDRESS_1] = FE::from(op.base_address >> 32); // value[8] @@ -310,14 +206,11 @@ pub fn generate_memw_trace( data[base + cols::OLD[i]] = FE::from(op.old[i]); } - // Auxiliary: address_add[7] - each as DWordHL (4 halfwords) + // Auxiliary: add_limb_overflow[7] + // overflow[i] = 1 if (base_address_lo + i+1) >= 2^32 for i in 0..7 { - let addr = op.base_address.wrapping_add(i as u64 + 1); - let cols_i = cols::address_add(i); - data[base + cols_i[0]] = FE::from(addr & 0xFFFF); - data[base + cols_i[1]] = FE::from((addr >> 16) & 0xFFFF); - data[base + cols_i[2]] = FE::from((addr >> 32) & 0xFFFF); - data[base + cols_i[3]] = FE::from((addr >> 48) & 0xFFFF); + let overflows = base_addr_lo + (i as u64 + 1) >= (1u64 << 32); + data[base + cols::ADD_LIMB_OVERFLOW[i]] = FE::from(overflows as u64); } // Auxiliary: old_timestamp[8] - each as DWordWL (2 words) @@ -330,81 +223,42 @@ pub fn generate_memw_trace( // 
Multiplicity data[base + cols::MU_READ] = FE::from(op.is_read as u64); data[base + cols::MU_WRITE] = FE::from(!op.is_read as u64); - // Note: w2, w4, μ_sum are computed inline via Multiplicity::Linear/Sum } TraceTable::new_main(data, cols::NUM_COLUMNS, 1) } // ========================================================================= -// Bus interactions +// Bus interactions (26 total) // ========================================================================= /// Creates all bus interactions for the MEMW table. /// -/// The MEMW table: -/// - **Receives** MEMW lookups from CPU (for LOAD/STORE operations) -/// - **Sends** IS_HALFWORD lookups for address_add range checks -/// - **Sends** LT lookups for timestamp ordering (old_timestamp < timestamp) +/// 26 interactions: +/// - 8 LT timestamp ordering checks +/// - 16 Memory bus tokens (read old + write new per byte) +/// - 2 MEMW output interactions (read + write from CPU) pub fn bus_interactions() -> Vec { - let mut interactions = Vec::new(); - - // ------------------------------------------------------------------------- - // IS_HALFWORD range checks for address_add[i][j] - // ------------------------------------------------------------------------- - // ------------------------------------------------------------------------- - // IsHalfword range checks for address_add columns - // ------------------------------------------------------------------------- - // Each address_add[i] is 4 halfwords (DWordHL packing), need to range check all. - // Only check when row is active (μ_read + μ_write > 0). 
- for i in 0..7 { - let cols_i = cols::address_add(i); - for &col in &cols_i { - interactions.push(BusInteraction::sender( - BusId::IsHalfword, - // Only range check when row is active - Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE), - vec![BusValue::Packed { - start_column: col, - packing: Packing::Direct, - }], - )); - } - } - - // ------------------------------------------------------------------------- - // Memory bus interactions (M1-M8 from spec) - // ------------------------------------------------------------------------- - // DISABLED: Memory bus requires initialization and finalization: - // - Initialization: For each address accessed, an initial row at timestamp=0 - // with the starting value must exist so the first read has a matching write. - // - Finalization: Final values must be consumed to balance the bus. - // Without these, the bus won't balance. - // ------------------------------------------------------------------------- - // These ensure read/write consistency: - // - Read old value at old_timestamp (+multiplicity) - // - Write new value at current timestamp (-multiplicity) - // - // Memory bus format: memory[is_register, address, timestamp_lo, timestamp_hi, value] - // - // Register tokens (is_register=1) are balanced by the REGISTER table. - // Memory tokens (is_register=0) are balanced by PAGE tables. 
+ let mut interactions = Vec::with_capacity(26); // ------------------------------------------------------------------------- - // Memory bus interactions per spec CM16-CM23 + // Memory bus interactions (16 total) // ------------------------------------------------------------------------- - // Token format: memory[is_register, address_lo, address_hi, ts_lo, ts_hi, value] - // - // For registers (is_register=1): value is a Word (32-bit), address is Word-indexed - // For memory (is_register=0): value is a Byte (8-bit), address is byte-indexed + // address_add[i] is VIRTUAL: + // lo = base_address_0 + (i+1) - 2^32 * add_limb_overflow[i] + // hi = base_address_1 + add_limb_overflow[i] // - // Multiplicities per spec: - // - CM16/17 (index 0): μ_sum - // - CM18/19 (index 1): w2 = write2 + write4 + write8 - // - CM20/21 (indices 2-3): w4 = write4 + write8 - // - CM22/23 (indices 4-7): write8 - - // CM16: memory[is_register, base_address, old_timestamp[0], old[0]] with +μ_sum + // Safety: `hi` is at most `base_address_1 + 1`. This never reaches 2^32 + // because the CPU table splits addresses into (lo, hi) with both halves + // in [0, 2^32), and the Memw bus ties MEMW's base_address to the CPU's + // value. MEMW only receives accesses where base_address_1 <= 0xFFFF_FFFE + // (addresses near u64::MAX are rejected by the executor before proving). + // Consequently, `add_limb_overflow[i]` is implicitly correct: a wrong + // carry bit produces a memory token at a wrong address that has no + // matching PAGE/REGISTER token, causing multiset imbalance and an + // invalid proof. 
+ + // CM8: memory[is_register, base_address, old_timestamp[0], old[0]] with +μ_sum interactions.push(BusInteraction::sender( BusId::Memory, Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE), @@ -436,7 +290,7 @@ pub fn bus_interactions() -> Vec { ], )); - // CM17: memory[is_register, base_address, timestamp, value[0]] with -μ_sum + // CM9: memory[is_register, base_address, timestamp, value[0]] with -μ_sum interactions.push(BusInteraction::receiver( BusId::Memory, Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE), @@ -468,35 +322,50 @@ pub fn bus_interactions() -> Vec { ], )); - // Helper: address_add[0] = base_address + 1, stored as DWordHL (4 halfwords) - // Use Word2L to combine each pair of halfwords into a word - let addr_add_0_lo = BusValue::Packed { - start_column: cols::address_add(0)[0], - packing: Packing::Word2L, - }; - let addr_add_0_hi = BusValue::Packed { - start_column: cols::address_add(0)[2], - packing: Packing::Word2L, - }; - - // CM18: memory[is_register, address_add[0], old_timestamp[1], old[1]] with +w2 - // w2 = write2 + write4 + write8 + // CM10/11: byte 1, multiplicity w2 = write2 + write4 + write8 + // address_add[0] is virtual: lo = base_address_0 + 1 - 2^32 * overflow[0] + // hi = base_address_1 + overflow[0] + let w2_mult = Multiplicity::Linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE2, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE4, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::WRITE8, + }, + ]); + + let addr_add_0_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_0, + }, + LinearTerm::Constant(1), + LinearTerm::Column { + coefficient: -(1i64 << 32), + column: cols::ADD_LIMB_OVERFLOW[0], + }, + ]); + let addr_add_0_hi = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_1, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::ADD_LIMB_OVERFLOW[0], + }, + ]); + + // CM10: send 
old token for byte 1 interactions.push(BusInteraction::sender( BusId::Memory, - Multiplicity::Linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE2, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE4, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE8, - }, - ]), + w2_mult.clone(), vec![ BusValue::Packed { start_column: cols::IS_REGISTER, @@ -519,23 +388,10 @@ pub fn bus_interactions() -> Vec { ], )); - // CM19: memory[is_register, address_add[0], timestamp, value[1]] with -w2 + // CM11: receive new token for byte 1 interactions.push(BusInteraction::receiver( BusId::Memory, - Multiplicity::Linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE2, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE4, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE8, - }, - ]), + w2_mult, vec![ BusValue::Packed { start_column: cols::IS_REGISTER, @@ -558,30 +414,35 @@ pub fn bus_interactions() -> Vec { ], )); - // CM20/21: indices 2-3 with multiplicity w4 = write4 + write8 + // CM12/13: bytes 2-3 with multiplicity w4 = write4 + write8 for i in 2..=3 { - let addr_add_lo = BusValue::Packed { - start_column: cols::address_add(i - 1)[0], - packing: Packing::Word2L, - }; - let addr_add_hi = BusValue::Packed { - start_column: cols::address_add(i - 1)[2], - packing: Packing::Word2L, - }; - - // CM22.i: send old token + let overflow_col = cols::ADD_LIMB_OVERFLOW[i - 1]; + let addr_add_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_0, + }, + LinearTerm::Constant(i as i64), + LinearTerm::Column { + coefficient: -(1i64 << 32), + column: overflow_col, + }, + ]); + let addr_add_hi = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_1, + }, + LinearTerm::Column { + coefficient: 1, + column: overflow_col, + }, + ]); + + // send old token interactions.push(BusInteraction::sender( BusId::Memory, 
- Multiplicity::Linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE4, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE8, - }, - ]), + Multiplicity::Sum(cols::WRITE4, cols::WRITE8), vec![ BusValue::Packed { start_column: cols::IS_REGISTER, @@ -604,19 +465,10 @@ pub fn bus_interactions() -> Vec { ], )); - // CM23.i: receive new token + // receive new token interactions.push(BusInteraction::receiver( BusId::Memory, - Multiplicity::Linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE4, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::WRITE8, - }, - ]), + Multiplicity::Sum(cols::WRITE4, cols::WRITE8), vec![ BusValue::Packed { start_column: cols::IS_REGISTER, @@ -640,18 +492,32 @@ pub fn bus_interactions() -> Vec { )); } - // CM22/23: indices 4-7 with multiplicity write8 + // CM14/15: bytes 4-7 with multiplicity write8 for i in 4..=7 { - let addr_add_lo = BusValue::Packed { - start_column: cols::address_add(i - 1)[0], - packing: Packing::Word2L, - }; - let addr_add_hi = BusValue::Packed { - start_column: cols::address_add(i - 1)[2], - packing: Packing::Word2L, - }; - - // CM22.i: send old token + let overflow_col = cols::ADD_LIMB_OVERFLOW[i - 1]; + let addr_add_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_0, + }, + LinearTerm::Constant(i as i64), + LinearTerm::Column { + coefficient: -(1i64 << 32), + column: overflow_col, + }, + ]); + let addr_add_hi = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::BASE_ADDRESS_1, + }, + LinearTerm::Column { + coefficient: 1, + column: overflow_col, + }, + ]); + + // send old token interactions.push(BusInteraction::sender( BusId::Memory, Multiplicity::Column(cols::WRITE8), @@ -677,7 +543,7 @@ pub fn bus_interactions() -> Vec { ], )); - // CM23.i: receive new token + // receive new token interactions.push(BusInteraction::receiver( BusId::Memory, 
Multiplicity::Column(cols::WRITE8), @@ -705,17 +571,13 @@ pub fn bus_interactions() -> Vec { } // ------------------------------------------------------------------------- - // CO24: Read receiver (unified for register and memory operations) + // CO16: Read receiver (from CPU) // ------------------------------------------------------------------------- - // OLD and VALUE are 8 individual BaseField elements (Direct packing). - // For registers: [lo32_word, hi32_word, 0, 0, 0, 0, 0, 0] - // For memory: [byte0, byte1, ..., byte7] - // Both match sender format since bus compares field elements directly. interactions.push(BusInteraction::receiver( BusId::Memw, Multiplicity::Column(cols::MU_READ), vec![ - // old[8] - output for reads (words) + // old[8] BusValue::Packed { start_column: cols::OLD[0], packing: Packing::Direct, @@ -753,7 +615,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::IS_REGISTER, packing: Packing::Direct, }, - // base_address (DWordWL = 2 words) + // base_address BusValue::Packed { start_column: cols::BASE_ADDRESS_0, packing: Packing::Direct, @@ -762,7 +624,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::BASE_ADDRESS_1, packing: Packing::Direct, }, - // value[8] - direct reads (words) + // value[8] BusValue::Packed { start_column: cols::VALUE[0], packing: Packing::Direct, @@ -795,7 +657,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::VALUE[7], packing: Packing::Direct, }, - // timestamp (DWordWL = 2 words) + // timestamp BusValue::Packed { start_column: cols::TIMESTAMP_0, packing: Packing::Direct, @@ -821,12 +683,8 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // CO25: Write receiver (unified for register and memory operations) + // CO17: Write receiver (from CPU) // ------------------------------------------------------------------------- - // VALUE is 8 individual BaseField elements (Direct packing). 
- // For registers: [lo32_word, hi32_word, 0, 0, 0, 0, 0, 0] - // For memory: [byte0, byte1, ..., byte7] - // Both match sender format since bus compares field elements directly. interactions.push(BusInteraction::receiver( BusId::Memw, Multiplicity::Column(cols::MU_WRITE), @@ -836,7 +694,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::IS_REGISTER, packing: Packing::Direct, }, - // base_address (DWordWL = 2 words) + // base_address BusValue::Packed { start_column: cols::BASE_ADDRESS_0, packing: Packing::Direct, @@ -845,7 +703,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::BASE_ADDRESS_1, packing: Packing::Direct, }, - // value[8] - direct reads (words) + // value[8] BusValue::Packed { start_column: cols::VALUE[0], packing: Packing::Direct, @@ -878,7 +736,7 @@ pub fn bus_interactions() -> Vec { start_column: cols::VALUE[7], packing: Packing::Direct, }, - // timestamp (DWordWL = 2 words) + // timestamp BusValue::Packed { start_column: cols::TIMESTAMP_0, packing: Packing::Direct, @@ -904,35 +762,28 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // LT interactions for timestamp ordering (constraints 7-10) + // LT interactions for timestamp ordering (MEMW-C4 through C7) // ------------------------------------------------------------------------- - // Verify old_timestamp[i] < timestamp for each accessed byte. - // LT bus uses 2 elements per 64-bit operand: [lo32, hi32] - // Both old_timestamp and timestamp are DWordWL, so use Packing::DWordWL. 
- // Constraint 7: LT[1; old_timestamp[0], timestamp] with μ_sum + // MEMW-C4: LT[1; old_timestamp[0], timestamp] with μ_sum interactions.push(BusInteraction::sender( BusId::Lt, Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE), vec![ - // lhs = old_timestamp[0] as DWordWL (2 elements: [lo32, hi32]) BusValue::Packed { start_column: cols::old_timestamp(0)[0], packing: Packing::DWordWL, }, - // rhs = timestamp as DWordWL (2 elements: [lo32, hi32]) BusValue::Packed { start_column: cols::TIMESTAMP_0, packing: Packing::DWordWL, }, - // signed = 0 (unsigned comparison) BusValue::constant(0), - // lt = 1 (expected result: old_timestamp < timestamp) BusValue::constant(1), ], )); - // Constraint 8: LT[1; old_timestamp[1], timestamp] with w2 + // MEMW-C5: LT[1; old_timestamp[1], timestamp] with w2 interactions.push(BusInteraction::sender( BusId::Lt, Multiplicity::Linear(vec![ @@ -963,7 +814,7 @@ pub fn bus_interactions() -> Vec { ], )); - // Constraint 9: LT[1; old_timestamp[i], timestamp] for i ∈ [2,3] with w4 + // MEMW-C6: LT[1; old_timestamp[i], timestamp] for i ∈ [2,3] with w4 for i in 2..4 { interactions.push(BusInteraction::sender( BusId::Lt, @@ -983,7 +834,7 @@ pub fn bus_interactions() -> Vec { )); } - // Constraint 10: LT[1; old_timestamp[i], timestamp] for i ∈ [4,7] with write8 + // MEMW-C7: LT[1; old_timestamp[i], timestamp] for i ∈ [4,7] with write8 for i in 4..8 { interactions.push(BusInteraction::sender( BusId::Lt, @@ -1003,69 +854,6 @@ pub fn bus_interactions() -> Vec { )); } - // ------------------------------------------------------------------------- - // LT interactions for overflow checking (constraints R1-R3) - // ------------------------------------------------------------------------- - // Verify base_address < address_add[i] (no overflow when adding offset). - // base_address is DWordWL, address_add[i] is DWordHL (cast to DWordWL). - // Both packings produce 2 elements [lo32, hi32]. 
- - // R1: LT[1; base_address, address_add[0]] with write2 - // This checks for no overflow when accessing byte 1 (width == 2) - interactions.push(BusInteraction::sender( - BusId::Lt, - Multiplicity::Column(cols::WRITE2), - vec![ - BusValue::Packed { - start_column: cols::BASE_ADDRESS_0, - packing: Packing::DWordWL, - }, - BusValue::Packed { - start_column: cols::address_add(0)[0], - packing: Packing::DWordHL, - }, - BusValue::constant(0), // unsigned - BusValue::constant(1), // lt = 1 - ], - )); - - // R2: LT[1; base_address, address_add[2]] with write4 - // This checks for no overflow when accessing byte 3 (width == 4) - interactions.push(BusInteraction::sender( - BusId::Lt, - Multiplicity::Column(cols::WRITE4), - vec![ - BusValue::Packed { - start_column: cols::BASE_ADDRESS_0, - packing: Packing::DWordWL, - }, - BusValue::Packed { - start_column: cols::address_add(2)[0], - packing: Packing::DWordHL, - }, - BusValue::constant(0), - BusValue::constant(1), - ], - )); - - // R3: LT[1; base_address, address_add[6]] with write8 - interactions.push(BusInteraction::sender( - BusId::Lt, - Multiplicity::Column(cols::WRITE8), - vec![ - BusValue::Packed { - start_column: cols::BASE_ADDRESS_0, - packing: Packing::DWordWL, - }, - BusValue::Packed { - start_column: cols::address_add(6)[0], - packing: Packing::DWordHL, - }, - BusValue::constant(0), - BusValue::constant(1), - ], - )); - interactions } @@ -1097,7 +885,7 @@ where } // ========================================================================= -// Constraints +// Constraints (9 total: 2 custom + 7 IS_BIT) // ========================================================================= /// MEMW table constraint kinds. 
@@ -1132,12 +920,10 @@ impl MemwConstraint { match self.kind { MemwConstraintKind::MuSumIsBit => { - // IS_BIT<μ_sum>: μ_sum * (1 - μ_sum) = 0 let mu_sum = compute_mu_sum(step); &mu_sum * (&one - &mu_sum) } MemwConstraintKind::W2ImpliesMuSum => { - // w2 * (1 - μ_sum) = 0 let w2 = compute_w2(step); let mu_sum = compute_mu_sum(step); &w2 * (&one - &mu_sum) @@ -1191,6 +977,11 @@ impl TransitionConstraint for MemwConstrai } /// Creates all constraints for the MEMW table. +/// +/// 9 constraints total: +/// - IS_BIT<μ_sum> (1) +/// - w2 => μ_sum (1) +/// - IS_BIT for add_limb_overflow[0..6] (7) pub fn constraints() -> Vec>> { let mut constraints: Vec>> = Vec::new(); @@ -1211,29 +1002,10 @@ pub fn constraints() -> Vec vec![cols::WRITE2, cols::WRITE4, cols::WRITE8], // w2 - 1 | 2 => vec![cols::WRITE4, cols::WRITE8], // w4 - _ => vec![cols::WRITE8], // write8 (i = 3..6) - }; - - // ADD constraint produces 2 constraints (carry_0, carry_1) - let (c0, c1) = AddConstraint::new_pair(condition, lhs, rhs, sum, idx); - constraints.push(Box::new(c0)); - constraints.push(Box::new(c1)); - idx += 2; + // IS_BIT for add_limb_overflow[0..6] + for &col in &cols::ADD_LIMB_OVERFLOW { + constraints.push(Box::new(IsBitConstraint::unconditional(col, idx))); + idx += 1; } constraints @@ -1259,17 +1031,60 @@ mod tests { #[test] fn test_write_flags() { - // "Exactly N" semantics per spec let op1 = MemwOperation::new(false, 0, [0; 8], 0, 1, false); - assert_eq!(op1.write_flags(), (false, false, false)); // no flags for 1 byte + assert_eq!(op1.write_flags(), (false, false, false)); let op2 = MemwOperation::new(false, 0, [0; 8], 0, 2, false); - assert_eq!(op2.write_flags(), (true, false, false)); // write2 only + assert_eq!(op2.write_flags(), (true, false, false)); let op4 = MemwOperation::new(false, 0, [0; 8], 0, 4, false); - assert_eq!(op4.write_flags(), (false, true, false)); // write4 only + assert_eq!(op4.write_flags(), (false, true, false)); let op8 = MemwOperation::new(false, 0, [0; 8], 0, 
8, false); - assert_eq!(op8.write_flags(), (false, false, true)); // write8 only + assert_eq!(op8.write_flags(), (false, false, true)); + } + + #[test] + fn test_add_limb_overflow() { + // Address 0xFFFF_FFFF should overflow when adding 1 + let op = + MemwOperation::new(false, 0xFFFF_FFFF, [0; 8], 100, 8, false).with_old([0; 8], [50; 8]); + let trace = generate_memw_trace(&[op]); + + // All 7 overflow flags should be 1 since 0xFFFF_FFFF + i >= 2^32 for i >= 1 + for i in 0..7 { + let val = trace.get_main(0, cols::ADD_LIMB_OVERFLOW[i]); + assert_eq!(*val, FE::one(), "overflow[{i}] should be 1"); + } + + // Address 0x0000_0000 should not overflow + let op2 = + MemwOperation::new(false, 0x0000_0000, [0; 8], 100, 8, false).with_old([0; 8], [50; 8]); + let trace2 = generate_memw_trace(&[op2]); + for i in 0..7 { + let val = trace2.get_main(0, cols::ADD_LIMB_OVERFLOW[i]); + assert_eq!(*val, FE::zero(), "overflow[{i}] should be 0"); + } + + // Address 0xFFFF_FFFE with width=8 exercises mixed per-byte carry bits: + // overflow[0]=0 (0xFFFF_FFFE+1 = 0xFFFF_FFFF < 2^32) + // overflow[1..6]=1 (0xFFFF_FFFE+2..8 >= 2^32) + let op3 = + MemwOperation::new(false, 0xFFFF_FFFE, [0; 8], 100, 8, false).with_old([0; 8], [50; 8]); + let trace3 = generate_memw_trace(&[op3]); + let val0 = trace3.get_main(0, cols::ADD_LIMB_OVERFLOW[0]); + assert_eq!( + *val0, + FE::zero(), + "overflow[0] should be 0 for base 0xFFFF_FFFE" + ); + for i in 1..7 { + let val = trace3.get_main(0, cols::ADD_LIMB_OVERFLOW[i]); + assert_eq!( + *val, + FE::one(), + "overflow[{i}] should be 1 for base 0xFFFF_FFFE" + ); + } } } diff --git a/prover/src/tables/mod.rs b/prover/src/tables/mod.rs index 2000ae414..90c0910f3 100644 --- a/prover/src/tables/mod.rs +++ b/prover/src/tables/mod.rs @@ -41,12 +41,12 @@ pub use types::BusId; /// Per-table maximum rows, sized so each chunk uses roughly the same memory. /// /// Effective width = main_cols + 3 × bus_interactions (extension field = 3× cost). 
-/// MEMW (effective width 241) at 2^19 is the baseline; other tables are scaled -/// proportionally: max_rows = (241 × 2^19) / effective_width, rounded to 2^N. +/// MEMW (effective width 127) at 2^19 is the baseline; other tables are scaled +/// proportionally: max_rows = (127 × 2^19) / effective_width, rounded to 2^N. /// /// | Table | Main | Bus | Eff.width | Max rows | /// |--------|------|-----|-----------|----------| -/// | MEMW | 70 | 57 | 241 | 2^19 | +/// | MEMW | 49 | 26 | 127 | 2^19 | /// | CPU | 74 | 40 | 194 | 2^19 | /// | DVRM | 34 | 34 | 136 | 2^19 | /// | MUL | 26 | 16 | 74 | 2^20 | @@ -56,7 +56,7 @@ pub use types::BusId; /// | BRANCH | 14 | 6 | 32 | 2^21 | pub mod max_rows { pub const CPU: usize = 1 << 19; // 524,288 — eff. width 194 - pub const MEMW: usize = 1 << 19; // 524,288 — eff. width 241 (baseline) + pub const MEMW: usize = 1 << 19; // 524,288 — eff. width 127 (baseline) pub const DVRM: usize = 1 << 19; // 524,288 — eff. width 136 pub const MUL: usize = 1 << 20; // 1,048,576 — eff. width 74 pub const LT: usize = 1 << 21; // 2,097,152 — eff. width 42 diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index 3261c31b5..8ae25b9bd 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -11,8 +11,8 @@ //! PHASE 0: ELF → DECODE, MEMORY_INIT (preprocessed tables) //! PHASE 1: Logs → CPU ops //! PHASE 2: CPU ops → MEMW, LOAD, LT, Bitwise (with state tracking for MEMW/LOAD) -//! PHASE 3: MEMW → LT ops (timestamp ordering, overflow checks) -//! PHASE 4: LT, MEMW → Bitwise lookups +//! PHASE 3: MEMW → LT ops (timestamp ordering) +//! PHASE 4: LT → Bitwise lookups //! PHASE 5: Generate all traces //! ``` //! @@ -701,25 +701,24 @@ fn collect_halt_ops(register_state: &mut RegisterState) -> Vec { // Phase 3: MEMW → LT // ============================================================================= -/// Collects LT operations from MEMW for timestamp ordering and overflow checks. 
+/// Collects LT operations from MEMW for timestamp ordering. /// /// From spec memw.md: -/// - C7-C10: old_timestamp[i] < timestamp (based on width) -/// - R1-R3: base_address < base_address + offset (overflow checks) +/// - MEMW-C4 through MEMW-C7: old_timestamp[i] < timestamp (based on width) /// /// Returns: Vec of LT operations fn collect_lt_from_memw(memw_ops: &[MemwOperation]) -> Vec<LtOperation> { let mut lt_ops = Vec::with_capacity(memw_ops.len() * 8); for memw_op in memw_ops { - // C7: old_timestamp[0] < timestamp (all accesses) + // MEMW-C4: old_timestamp[0] < timestamp (all accesses) lt_ops.push(LtOperation::new( memw_op.old_timestamp[0], memw_op.timestamp, false, )); - // C8: old_timestamp[1] < timestamp (width >= 2) + // MEMW-C5: old_timestamp[1] < timestamp (width >= 2) if memw_op.width >= 2 { lt_ops.push(LtOperation::new( memw_op.old_timestamp[1], @@ -728,7 +727,7 @@ fn collect_lt_from_memw(memw_ops: &[MemwOperation]) -> Vec<LtOperation> { )); } - // C9: old_timestamp[2,3] < timestamp (width >= 4) + // MEMW-C6: old_timestamp[2,3] < timestamp (width >= 4) if memw_op.width >= 4 { lt_ops.push(LtOperation::new( memw_op.old_timestamp[2], @@ -742,7 +741,7 @@ fn collect_lt_from_memw(memw_ops: &[MemwOperation]) -> Vec<LtOperation> { )); } - // C10: old_timestamp[4..7] < timestamp (width == 8) + // MEMW-C7: old_timestamp[4..7] < timestamp (width == 8) if memw_op.width == 8 { for i in 4..8 { lt_ops.push(LtOperation::new( @@ -752,22 +751,6 @@ fn collect_lt_from_memw(memw_ops: &[MemwOperation]) -> Vec<LtOperation> { )); } } - - // R1-R3: Address overflow checks (unconditional per MEMW-CR13/14/15) - // If overflow occurs, LT returns lt=0 and the constraint (expecting lt=1) - // rejects the proof via value mismatch. 
- if memw_op.width == 2 { - let addr_plus_1 = memw_op.base_address.wrapping_add(1); - lt_ops.push(LtOperation::new(memw_op.base_address, addr_plus_1, false)); - } - if memw_op.width == 4 { - let addr_plus_3 = memw_op.base_address.wrapping_add(3); - lt_ops.push(LtOperation::new(memw_op.base_address, addr_plus_3, false)); - } - if memw_op.width == 8 { - let addr_plus_7 = memw_op.base_address.wrapping_add(7); - lt_ops.push(LtOperation::new(memw_op.base_address, addr_plus_7, false)); - } } lt_ops @@ -1054,30 +1037,6 @@ fn collect_bitwise_from_dvrm(dvrm_ops: &[(DvrmOperation, bool)]) -> Vec Vec { - let mut bitwise_ops = Vec::with_capacity(memw_ops.len() * 28); // 7 addresses * 4 halfwords - - for memw_op in memw_ops { - for i in 0..7u64 { - let addr_add = memw_op.base_address.wrapping_add(i + 1); - // Extract 4 halfwords (DWordHL packing) - for shift in [0, 16, 32, 48] { - let half = ((addr_add >> shift) & 0xFFFF) as u16; - bitwise_ops.push(BitwiseOperation::halfword( - BitwiseOperationType::IsHalf, - (half & 0xFF) as u8, - (half >> 8) as u8, - )); - } - } - } - - bitwise_ops -} - /// Collects bitwise lookups from BRANCH operations. 
/// /// BRANCH sends: @@ -1715,7 +1674,6 @@ impl Traces { // PHASE 4: All → Bitwise lookups // ===================================================================== bitwise_ops.extend(collect_bitwise_from_lt(<_ops)); - bitwise_ops.extend(collect_bitwise_from_memw(&memw_ops)); bitwise_ops.extend(collect_bitwise_from_mul(&mul_ops)); bitwise_ops.extend(collect_bitwise_from_dvrm(&dvrm_ops)); bitwise_ops.extend(collect_bitwise_from_branch(&branch_ops)); @@ -1936,7 +1894,6 @@ impl Traces { // PHASE 4: All → Bitwise lookups // ===================================================================== bitwise_ops.extend(collect_bitwise_from_lt(<_ops)); - bitwise_ops.extend(collect_bitwise_from_memw(&memw_ops)); bitwise_ops.extend(collect_bitwise_from_mul(&mul_ops)); bitwise_ops.extend(collect_bitwise_from_dvrm(&dvrm_ops)); bitwise_ops.extend(collect_bitwise_from_branch(&branch_ops)); diff --git a/prover/src/test_utils.rs b/prover/src/test_utils.rs index c5f730559..09ba51345 100644 --- a/prover/src/test_utils.rs +++ b/prover/src/test_utils.rs @@ -352,20 +352,6 @@ pub fn collect_bitwise_ops_from_load( .collect() } -/// Collect LT operations from MEMW operations. 
-/// -/// The MEMW table sends LT lookups for: -/// - Timestamp ordering: old_timestamp[i] < timestamp -/// - Overflow checking: base_address < base_address + offset -pub fn collect_lt_lookups_from_memw( - memw_ops: &[crate::tables::memw::MemwOperation], -) -> Vec { - memw_ops - .iter() - .flat_map(|op| op.collect_lt_lookups()) - .collect() -} - // ============================================================================= // Minimal Trace Generation (for testing/benchmarking only) // ============================================================================= diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs index 9c77d24c8..d1713dc4f 100644 --- a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -593,6 +593,10 @@ fn test_prove_elfs_test_lb_lh_8() { fn test_prove_elfs_test_sb_sh_8() { let (elf, logs, _instructions) = run_asm_elf("test_sb_sh_8"); let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default()).unwrap(); + assert!( + !traces.memws.is_empty(), + "test_sb_sh_8 should produce MEMW rows for byte/halfword memory accesses" + ); assert!( prove_and_verify_vm_minimal(&elf, &mut traces), "test_sb_sh_8 failed" @@ -1173,25 +1177,13 @@ fn test_debug_memory_tokens_sb_sh() { let val1 = memw.main_table.get(row, memw_cols::VALUE[1]).to_raw(); let old1 = memw.main_table.get(row, memw_cols::OLD[1]).to_raw(); - // address_add(0) = base + 1, stored as DWordHL - let addr1_lo0 = memw - .main_table - .get(row, memw_cols::address_add(0)[0]) - .to_raw(); - let addr1_lo1 = memw - .main_table - .get(row, memw_cols::address_add(0)[1]) - .to_raw(); - let addr1_hi0 = memw - .main_table - .get(row, memw_cols::address_add(0)[2]) - .to_raw(); - let addr1_hi1 = memw + // address_add(0) = base + 1, now virtual (computed from base + overflow) + let overflow0 = memw .main_table - .get(row, memw_cols::address_add(0)[3]) + .get(row, memw_cols::ADD_LIMB_OVERFLOW[0]) .to_raw(); - let addr1_lo = addr1_lo0 
+ (addr1_lo1 << 16); - let addr1_hi = addr1_hi0 + (addr1_hi1 << 16); + let addr1_lo = base_lo + 1 - overflow0 * (1u64 << 32); + let addr1_hi = base_hi + overflow0; // CM16: SEND old token for byte 1 let send_token1: Token = (is_reg, addr1_lo, addr1_hi, old_ts1_lo, old_ts1_hi, old1);