diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 2a9d2c912..7cb947f6f 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -41,8 +41,8 @@ use crate::tables::types::BusId; use crate::test_utils::{ E, F, VmAir, create_bitwise_air, create_branch_air, create_commit_air, create_cpu_air, create_decode_air, create_dvrm_air, create_halt_air, create_load_air, create_lt_air, - create_memw_air, create_memw_aligned_air, create_mul_air, create_page_air, create_register_air, - create_shift_air, + create_memw_air, create_memw_aligned_air, create_memw_register_air, create_mul_air, + create_page_air, create_register_air, create_shift_air, }; use stark::proof::options::{GoldilocksCubicProofOptions, ProofOptions}; @@ -73,6 +73,7 @@ pub struct TableCounts { pub dvrm: usize, pub shift: usize, pub branch: usize, + pub memw_register: usize, } impl TableCounts { @@ -91,6 +92,7 @@ impl TableCounts { + self.dvrm + self.shift + self.branch + + self.memw_register } /// Validate that all required tables have at least one chunk. 
@@ -108,6 +110,7 @@ impl TableCounts { ("dvrm", self.dvrm), ("shift", self.shift), ("branch", self.branch), + ("memw_register", self.memw_register), ]; for (name, count) in checks { if count == 0 { @@ -195,6 +198,7 @@ pub(crate) struct VmAirs { pub commit: VmAir, pub register: VmAir, pub pages: Vec, + pub memw_registers: Vec, } impl VmAirs { @@ -242,6 +246,13 @@ impl VmAirs { for (air, trace) in self.pages.iter().zip(traces.pages.iter_mut()) { pairs.push((air, trace, &())); } + for (air, trace) in self + .memw_registers + .iter() + .zip(traces.memw_registers.iter_mut()) + { + pairs.push((air, trace, &())); + } pairs } @@ -286,6 +297,9 @@ impl VmAirs { for air in &self.pages { refs.push(air); } + for air in &self.memw_registers { + refs.push(air); + } refs } @@ -358,6 +372,9 @@ impl VmAirs { ) }) .collect(); + let memw_registers: Vec<_> = (0..table_counts.memw_register) + .map(|i| create_memw_register_air(proof_options).with_name(&format!("MEMW_R[{}]", i))) + .collect(); #[cfg(feature = "debug-checks")] debug_report::print_bus_legend(); @@ -378,6 +395,7 @@ impl VmAirs { commit, register, pages, + memw_registers, } } } diff --git a/prover/src/tables/memw_register.rs b/prover/src/tables/memw_register.rs new file mode 100644 index 000000000..206a9c746 --- /dev/null +++ b/prover/src/tables/memw_register.rs @@ -0,0 +1,520 @@ +//! MEMW_R (Memory Write/Read -- Register) table. +//! +//! Ultra-slim fast path for register accesses. Registers are always 2 words +//! (DWordWL), always aligned, and `is_register=1`, so this table strips out +//! all memory-specific columns (address decomposition, alignment mask, width +//! flags, per-byte old_timestamps). +//! +//! ## Timestamp ordering: IS_HALF instead of LT +//! +//! The general MEMW table proves `old_timestamp < timestamp` by routing through +//! the LT table, which requires extra LT trace rows and bus interactions. +//! MEMW_R instead checks `IS_HALF[timestamp[0] - old_timestamp[0] - 1]`, +//! 
which proves the delta is in `[1, 2^16]` in a single lookup. This is safe +//! because registers are accessed very frequently — their timestamp deltas are +//! almost always small — and the routing predicate (`is_register_op`) enforces +//! the delta fits before admitting an op into this table. +//! +//! ## Column layout (10 columns) +//! +//! - `ADDRESS`: Byte (register index 0-31) +//! - `TIMESTAMP_0`: Word (low 32 bits) +//! - `TIMESTAMP_1`: Word (high 32 bits) +//! - `VAL_0`: Word (low 32 bits of register value) +//! - `VAL_1`: Word (high 32 bits of register value) +//! - `OLD_0`: Word (low 32 bits of previous value) +//! - `OLD_1`: Word (high 32 bits of previous value) +//! - `OLD_TIMESTAMP_LO`: Word (low 32 bits of old timestamp; upper limb = TIMESTAMP_1) +//! - `MU_READ`: Bit +//! - `MU_WRITE`: Bit +//! +//! ## Virtual +//! +//! - `old_timestamp = [OLD_TIMESTAMP_LO, TIMESTAMP_1]` (shares upper limb!) +//! - `mu_sum = MU_READ + MU_WRITE` +//! +//! ## Bus Interactions (7) +//! - 1 IS_HALFWORD[timestamp_0 - old_timestamp_lo - 1] +//! - 4 Memory bus tokens (read-old + write-new, per word) +//! - 2 MEMW output interactions (read + write, from CPU) + +use math::field::element::FieldElement; +use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraints::transition::TransitionConstraint; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; +use stark::table::TableView; +use stark::trace::TraceTable; +use stark::traits::TransitionEvaluationContext; + +use super::memw::MemwOperation; +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; + +// ========================================================================= +// Column indices (10 columns) +// ========================================================================= + +pub mod cols { + /// Register index (0-31). CPU sends base_address = 2*reg_index. 
+ pub const ADDRESS: usize = 0; + + /// Timestamp low 32 bits + pub const TIMESTAMP_0: usize = 1; + /// Timestamp high 32 bits + pub const TIMESTAMP_1: usize = 2; + + /// Register value low 32 bits + pub const VAL_0: usize = 3; + /// Register value high 32 bits + pub const VAL_1: usize = 4; + + /// Previous value low 32 bits + pub const OLD_0: usize = 5; + /// Previous value high 32 bits + pub const OLD_1: usize = 6; + + /// Old timestamp low 32 bits (upper limb shared with TIMESTAMP_1) + pub const OLD_TIMESTAMP_LO: usize = 7; + + /// Read multiplicity + pub const MU_READ: usize = 8; + /// Write multiplicity + pub const MU_WRITE: usize = 9; + + pub const NUM_COLUMNS: usize = 10; +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +/// Generates the MEMW_R trace table from register operations. +/// +/// Reuses `MemwOperation` -- the trace generator divides `base_address` by 2 +/// to recover the register index (CPU sends `2 * register_index`). +pub fn generate_memw_register_trace( + operations: &[MemwOperation], +) -> TraceTable { + let num_rows = operations.len().next_power_of_two().max(4); + let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; + + for (row_idx, op) in operations.iter().enumerate() { + let base = row_idx * cols::NUM_COLUMNS; + + debug_assert_eq!( + op.base_address % 2, + 0, + "register base_address must be even (got {})", + op.base_address + ); + // Both register words must have been last accessed at the same timestamp. + // MEMW_R stores a single old_timestamp_lo and shares TIMESTAMP_1 as the + // upper limb, so if the two words differ, the wrong token would be sent + // to the memory bus. The routing predicate enforces this before dispatch. 
+ debug_assert_eq!( + op.old_timestamp[0], op.old_timestamp[1], + "register words must share old_timestamp ({} != {})", + op.old_timestamp[0], op.old_timestamp[1] + ); + + // ADDRESS = base_address / 2 (CPU sends 2 * register_index) + data[base + cols::ADDRESS] = FE::from(op.base_address / 2); + + // Timestamp split into lo/hi 32-bit words + data[base + cols::TIMESTAMP_0] = FE::from(op.timestamp & 0xFFFF_FFFF); + data[base + cols::TIMESTAMP_1] = FE::from(op.timestamp >> 32); + + // Value: registers are DWordWL = 2 words + data[base + cols::VAL_0] = FE::from(op.value[0]); + data[base + cols::VAL_1] = FE::from(op.value[1]); + + // Old value + data[base + cols::OLD_0] = FE::from(op.old[0]); + data[base + cols::OLD_1] = FE::from(op.old[1]); + + // Old timestamp low (upper limb shared with TIMESTAMP_1) + data[base + cols::OLD_TIMESTAMP_LO] = FE::from(op.old_timestamp[0] & 0xFFFF_FFFF); + + // Multiplicity + data[base + cols::MU_READ] = FE::from(op.is_read as u64); + data[base + cols::MU_WRITE] = FE::from(!op.is_read as u64); + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions (7 total) +// ========================================================================= + +pub fn bus_interactions() -> Vec { + let mut interactions = Vec::with_capacity(7); + + let mu_sum = Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE); + + // ------------------------------------------------------------------------- + // IS_HALFWORD[timestamp_0 - old_timestamp_lo - 1] with mu_sum + // ------------------------------------------------------------------------- + interactions.push(BusInteraction::sender( + BusId::IsHalfword, + mu_sum.clone(), + vec![BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::TIMESTAMP_0, + }, + LinearTerm::Column { + coefficient: -1, + column: cols::OLD_TIMESTAMP_LO, + }, + LinearTerm::Constant(-1), + ])], + )); + + // 
------------------------------------------------------------------------- + // Memory bus read-old (sender, for i=0,1) + // memory[is_register=1, addr_lo=2*ADDRESS+i, addr_hi=0, + // OLD_TIMESTAMP_LO, TIMESTAMP_1, OLD[i]] + // ------------------------------------------------------------------------- + for i in 0..2 { + let addr_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 2, + column: cols::ADDRESS, + }, + LinearTerm::Constant(i as i64), + ]); + + interactions.push(BusInteraction::sender( + BusId::Memory, + mu_sum.clone(), + vec![ + BusValue::constant(1), + addr_lo, + BusValue::constant(0), + BusValue::Packed { + start_column: cols::OLD_TIMESTAMP_LO, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: if i == 0 { cols::OLD_0 } else { cols::OLD_1 }, + packing: Packing::Direct, + }, + ], + )); + } + + // ------------------------------------------------------------------------- + // Memory bus write-new (receiver, for i=0,1) + // memory[is_register=1, addr_lo=2*ADDRESS+i, addr_hi=0, + // TIMESTAMP_0, TIMESTAMP_1, VAL[i]] + // ------------------------------------------------------------------------- + for i in 0..2 { + let addr_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 2, + column: cols::ADDRESS, + }, + LinearTerm::Constant(i as i64), + ]); + + interactions.push(BusInteraction::receiver( + BusId::Memory, + mu_sum.clone(), + vec![ + BusValue::constant(1), + addr_lo, + BusValue::constant(0), + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: if i == 0 { cols::VAL_0 } else { cols::VAL_1 }, + packing: Packing::Direct, + }, + ], + )); + } + + // ------------------------------------------------------------------------- + // CO24: MEMW read receiver (from CPU M1/M3 
sender) + // ------------------------------------------------------------------------- + let addr_lo_linear = BusValue::linear(vec![LinearTerm::Column { + coefficient: 2, + column: cols::ADDRESS, + }]); + + interactions.push(BusInteraction::receiver( + BusId::Memw, + Multiplicity::Column(cols::MU_READ), + vec![ + // old[0..8] + BusValue::Packed { + start_column: cols::OLD_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OLD_1, + packing: Packing::Direct, + }, + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + // is_register = 1 + BusValue::constant(1), + // base_address = [2*ADDRESS, 0] + addr_lo_linear.clone(), + BusValue::constant(0), + // value[0..8] + BusValue::Packed { + start_column: cols::VAL_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VAL_1, + packing: Packing::Direct, + }, + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + // timestamp + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + // write flags: write2=1, write4=0, write8=0 (registers are always 2 words) + BusValue::constant(1), + BusValue::constant(0), + BusValue::constant(0), + ], + )); + + // ------------------------------------------------------------------------- + // CO25: MEMW write receiver (from CPU M5 sender — register write to rd) + // ------------------------------------------------------------------------- + interactions.push(BusInteraction::receiver( + BusId::Memw, + Multiplicity::Column(cols::MU_WRITE), + vec![ + // is_register = 1 + BusValue::constant(1), + // base_address = [2*ADDRESS, 0] + addr_lo_linear, + BusValue::constant(0), + // value[0..8] + BusValue::Packed { + start_column: 
cols::VAL_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VAL_1, + packing: Packing::Direct, + }, + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + // timestamp + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + // write flags: write2=1, write4=0, write8=0 + BusValue::constant(1), + BusValue::constant(0), + BusValue::constant(0), + ], + )); + + interactions +} + +// ========================================================================= +// Constraints (3 algebraic) +// ========================================================================= + +/// MEMW_R constraint: IS_BIT(mu_sum) = (mu_read + mu_write) * (1 - mu_read - mu_write) = 0 +pub struct MemwRegisterMuSumIsBit { + constraint_idx: usize, +} + +impl MemwRegisterMuSumIsBit { + pub fn new(constraint_idx: usize) -> Self { + Self { constraint_idx } + } + + fn compute(&self, step: &TableView) -> FieldElement + where + F: IsSubFieldOf, + E: IsField, + { + let one = FieldElement::::one(); + let mu_read = step.get_main_evaluation_element(0, cols::MU_READ).clone(); + let mu_write = step.get_main_evaluation_element(0, cols::MU_WRITE).clone(); + let mu_sum = &mu_read + &mu_write; + &mu_sum * (&one - &mu_sum) + } +} + +impl TransitionConstraint for MemwRegisterMuSumIsBit { + fn degree(&self) -> usize { + 2 + } + + fn constraint_idx(&self) -> usize { + self.constraint_idx + } + + fn end_exemptions(&self) -> usize { + 0 + } + + fn evaluate( + &self, + evaluation_context: &TransitionEvaluationContext, + transition_evaluations: &mut [FieldElement], + ) { + match evaluation_context { + TransitionEvaluationContext::Prover { + frame, + periodic_values: _, + rap_challenges: _, + .. 
+ } => { + let v = self.compute(frame.get_evaluation_step(0)); + transition_evaluations[self.constraint_idx] = v.to_extension(); + } + TransitionEvaluationContext::Verifier { + frame, + periodic_values: _, + rap_challenges: _, + .. + } => { + let v = self.compute(frame.get_evaluation_step(0)); + transition_evaluations[self.constraint_idx] = v; + } + } + } +} + +/// Creates all constraints for the MEMW_R table (3 total). +/// +/// - IS_BIT(MU_READ) -- unconditional +/// - IS_BIT(MU_WRITE) -- unconditional +/// - IS_BIT(mu_sum) = (mu_read + mu_write) * (1 - mu_read - mu_write) = 0 +pub fn constraints() -> Vec>> { + use crate::constraints::templates::IsBitConstraint; + + vec![ + Box::new(IsBitConstraint::unconditional(cols::MU_READ, 0)), + Box::new(IsBitConstraint::unconditional(cols::MU_WRITE, 1)), + Box::new(MemwRegisterMuSumIsBit::new(2)), + ] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_memw_register_trace_generation() { + // Create a simple register op (reg x1 = address 1, so base_address = 2) + let ops = vec![ + MemwOperation::new( + true, // is_register + 2, // base_address = 2 * register_index (reg x1) + [42, 7, 0, 0, 0, 0, 0, 0], + 100, + 2, // width = 2 words (registers are DWordWL) + true, + ) + .with_old([10, 3, 0, 0, 0, 0, 0, 0], [50, 50, 0, 0, 0, 0, 0, 0]), + ]; + + let trace = generate_memw_register_trace(&ops); + assert_eq!(trace.num_cols(), cols::NUM_COLUMNS); + assert!(trace.num_rows() >= 4); // minimum 4 rows + + // ADDRESS = base_address / 2 = 2 / 2 = 1 + assert_eq!(*trace.get_main(0, cols::ADDRESS), FE::from(1u64)); + + // TIMESTAMP split + assert_eq!(*trace.get_main(0, cols::TIMESTAMP_0), FE::from(100u64)); + assert_eq!(*trace.get_main(0, cols::TIMESTAMP_1), FE::from(0u64)); + + // Values + assert_eq!(*trace.get_main(0, cols::VAL_0), FE::from(42u64)); + assert_eq!(*trace.get_main(0, cols::VAL_1), FE::from(7u64)); + + // Old values + assert_eq!(*trace.get_main(0, cols::OLD_0), FE::from(10u64)); + 
assert_eq!(*trace.get_main(0, cols::OLD_1), FE::from(3u64)); + + // Old timestamp lo + assert_eq!(*trace.get_main(0, cols::OLD_TIMESTAMP_LO), FE::from(50u64)); + + // Multiplicity: is_read = true => MU_READ=1, MU_WRITE=0 + assert_eq!(*trace.get_main(0, cols::MU_READ), FE::from(1u64)); + assert_eq!(*trace.get_main(0, cols::MU_WRITE), FE::from(0u64)); + } + + #[test] + fn test_memw_register_trace_generation_write_op() { + // Write op: is_read = false => MU_WRITE=1, MU_READ=0 + let ops = vec![ + MemwOperation::new( + true, // is_register + 4, // base_address = 2 * register_index (reg x2) + [99, 55, 0, 0, 0, 0, 0, 0], + 200, + 2, // width = 2 words + false, // is_read = false (write) + ) + .with_old([11, 22, 0, 0, 0, 0, 0, 0], [180, 180, 0, 0, 0, 0, 0, 0]), + ]; + + let trace = generate_memw_register_trace(&ops); + + // ADDRESS = base_address / 2 = 4 / 2 = 2 + assert_eq!(*trace.get_main(0, cols::ADDRESS), FE::from(2u64)); + + // Values + assert_eq!(*trace.get_main(0, cols::VAL_0), FE::from(99u64)); + assert_eq!(*trace.get_main(0, cols::VAL_1), FE::from(55u64)); + + // Old values + assert_eq!(*trace.get_main(0, cols::OLD_0), FE::from(11u64)); + assert_eq!(*trace.get_main(0, cols::OLD_1), FE::from(22u64)); + + // Old timestamp lo + assert_eq!(*trace.get_main(0, cols::OLD_TIMESTAMP_LO), FE::from(180u64)); + + // Multiplicity: is_read = false => MU_WRITE=1, MU_READ=0 + assert_eq!(*trace.get_main(0, cols::MU_READ), FE::from(0u64)); + assert_eq!(*trace.get_main(0, cols::MU_WRITE), FE::from(1u64)); + } +} diff --git a/prover/src/tables/mod.rs b/prover/src/tables/mod.rs index 1bb351583..551dc4aa3 100644 --- a/prover/src/tables/mod.rs +++ b/prover/src/tables/mod.rs @@ -32,6 +32,7 @@ pub mod load; pub mod lt; pub mod memw; pub mod memw_aligned; +pub mod memw_register; pub mod mul; pub mod page; pub mod register; @@ -59,6 +60,7 @@ pub use types::BusId; /// | SHIFT | 27 | 15 | 72 | 2^20 | /// | LOAD | 18 | 5 | 33 | 2^21 | /// | BRANCH | 14 | 6 | 32 | 2^21 | +/// | MEMW_R | 10 | 7 
| 31 | 2^21 | pub mod max_rows { pub const CPU: usize = 1 << 19; // 524,288 — eff. width 194 pub const MEMW: usize = 1 << 19; // 524,288 — eff. width 127 (baseline) @@ -69,6 +71,7 @@ pub mod max_rows { pub const SHIFT: usize = 1 << 20; // 1,048,576 — eff. width 72 pub const LOAD: usize = 1 << 21; // 2,097,152 — eff. width 33 pub const BRANCH: usize = 1 << 21; // 2,097,152 — eff. width 32 + pub const MEMW_R: usize = 1 << 21; // 2,097,152 — eff. width 31 } /// Per-table maximum row limits, configurable for different environments. @@ -86,6 +89,7 @@ pub struct MaxRowsConfig { pub shift: usize, pub load: usize, pub branch: usize, + pub memw_register: usize, } impl Default for MaxRowsConfig { @@ -100,6 +104,7 @@ impl Default for MaxRowsConfig { shift: max_rows::SHIFT, load: max_rows::LOAD, branch: max_rows::BRANCH, + memw_register: max_rows::MEMW_R, } } } @@ -118,6 +123,7 @@ impl MaxRowsConfig { shift: 1 << 5, load: 1 << 5, branch: 1 << 5, + memw_register: 1 << 5, } } } diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index 14e88942d..6cbcd0735 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -10,9 +10,9 @@ //! ```text //! PHASE 0: ELF → DECODE, MEMORY_INIT (preprocessed tables) //! PHASE 1: Logs → CPU ops -//! PHASE 2: CPU ops → MEMW, MEMW_A, LOAD, LT, Bitwise (with state tracking for MEMW/LOAD) -//! PHASE 3: MEMW/MEMW_A → LT ops (timestamp ordering) -//! PHASE 4: LT, MEMW_A → Bitwise lookups +//! PHASE 2: CPU ops → MEMW, MEMW_A, MEMW_R, LOAD, LT, Bitwise (with state tracking for MEMW/LOAD) +//! PHASE 3: MEMW/MEMW_A → LT ops (timestamp ordering); MEMW_R uses IS_HALFWORD instead +//! PHASE 4: LT, MEMW_A, MEMW_R → Bitwise lookups //! PHASE 5: Generate all traces //! ``` //! 
@@ -44,6 +44,7 @@ use super::load::{self, LoadOperation}; use super::lt::{self, LtOperation}; use super::memw::{self, MemwOperation}; use super::memw_aligned; +use super::memw_register; use super::mul::{self, MulOperation}; use super::page::{self, FinalByteState, FinalStateMap, PageConfig}; use super::register::{self, FinalRegisterStateMap, FinalRegisterWordState}; @@ -902,6 +903,59 @@ fn collect_bitwise_from_memw_aligned(ops: &[MemwOperation]) -> Vec old_timestamp[0] (lower limb ordering) +/// 5. timestamp[0] - old_timestamp[0] <= 0x10000 (delta fits in IS_HALF range [1, 2^16]) +/// +/// Width-1 register ops (e.g. COMMIT x254) stay in MEMW, which has +/// dynamic write flags. MEMW_R hardcodes write2=1. +fn is_register_op(op: &MemwOperation) -> bool { + if !op.is_register || op.width != 2 { + return false; + } + // Both words must share old_timestamp (atomic register write assumption) + if op.old_timestamp[0] != op.old_timestamp[1] { + return false; + } + let ts = op.timestamp; + let old_ts = op.old_timestamp[0]; + let ts_lo = ts & 0xFFFF_FFFF; + let old_ts_lo = old_ts & 0xFFFF_FFFF; + let ts_hi = ts >> 32; + let old_ts_hi = old_ts >> 32; + ts_hi == old_ts_hi && ts_lo > old_ts_lo && (ts_lo - old_ts_lo) <= 0x10000 +} + +/// Collects IS_HALFWORD bitwise lookups for MEMW_R operations. +/// +/// For each register op: checks that `timestamp[0] - old_timestamp_lo - 1` fits +/// in a halfword (proving the timestamp delta is in range [1, 2^16]). 
+fn collect_bitwise_from_memw_register(ops: &[MemwOperation]) -> Vec { + ops.iter() + .map(|op| { + let ts_lo = op.timestamp & 0xFFFF_FFFF; + let old_ts_lo = op.old_timestamp[0] & 0xFFFF_FFFF; + debug_assert!( + ts_lo > old_ts_lo, + "ts_lo must exceed old_ts_lo (enforced by is_register_op)" + ); + let diff_minus_1 = (ts_lo - old_ts_lo - 1) as u16; + BitwiseOperation::halfword( + BitwiseOperationType::IsHalf, + (diff_minus_1 & 0xFF) as u8, + (diff_minus_1 >> 8) as u8, + ) + }) + .collect() +} + // ============================================================================= // Phase 4: All → Bitwise lookups // ============================================================================= @@ -1569,6 +1623,9 @@ pub struct Traces { /// COMMIT table for write syscall (byte-by-byte commit with recursive bus) pub commit: TraceTable, + + /// MEMW_R register-only fast-path traces (split into chunks of max_rows::MEMW_R) + pub memw_registers: Vec>, } /// Chunk raw ops and generate one trace table per chunk. @@ -1597,6 +1654,7 @@ impl Traces { dvrm: self.dvrms.len(), shift: self.shifts.len(), branch: self.branches.len(), + memw_register: self.memw_registers.len(), } } @@ -1753,7 +1811,10 @@ impl Traces { let halt_memw_ops = collect_halt_ops(&mut register_state); memw_ops.extend(halt_memw_ops); - // Route MEMW operations: aligned ops → MEMW_A, rest → MEMW + // Route MEMW_R (register fast-path) first, then MEMW_A (aligned), rest → MEMW. + // Order matters: register ops would also pass is_aligned_op, so check first. 
+ let (memw_register_ops, memw_ops): (Vec<_>, Vec<_>) = + memw_ops.into_iter().partition(is_register_op); let (memw_aligned_ops, memw_ops): (Vec<_>, Vec<_>) = memw_ops.into_iter().partition(is_aligned_op); @@ -1834,6 +1895,8 @@ impl Traces { bitwise_ops.extend(collect_bitwise_from_branch(&branch_ops)); bitwise_ops.extend(shift::collect_bitwise_from_shift(&shift_ops)); bitwise_ops.extend(collect_bitwise_from_memw_aligned(&memw_aligned_ops)); + // MEMW_R sends IS_HALFWORD[timestamp_0 - old_timestamp_lo - 1] + bitwise_ops.extend(collect_bitwise_from_memw_register(&memw_register_ops)); // PAGE tables do IS_BYTE lookups for init and fini values (C1, C2) bitwise_ops.extend(collect_bitwise_from_page(elf, &memory_state)); @@ -1872,6 +1935,11 @@ impl Traces { max_rows.memw_aligned, memw_aligned::generate_memw_aligned_trace, ); + let memw_registers = chunk_and_generate( + &memw_register_ops, + max_rows.memw_register, + memw_register::generate_memw_register_trace, + ); let loads = chunk_and_generate(&load_ops, max_rows.load, load::generate_load_trace); let lts = chunk_and_generate(<_ops, max_rows.lt, lt::generate_lt_trace); let shifts = chunk_and_generate(&shift_ops, max_rows.shift, shift::generate_shift_trace); @@ -1954,6 +2022,7 @@ impl Traces { branches, halt: halt_trace, commit: commit_trace, + memw_registers, }) } @@ -1989,7 +2058,9 @@ impl Traces { let halt_memw_ops = collect_halt_ops(&mut register_state); memw_ops.extend(halt_memw_ops); - // Route MEMW operations: aligned ops → MEMW_A, rest → MEMW + // Route MEMW_R (register fast-path) first, then MEMW_A (aligned), rest → MEMW. 
+ let (memw_register_ops, memw_ops): (Vec<_>, Vec<_>) = + memw_ops.into_iter().partition(is_register_op); let (memw_aligned_ops, memw_ops): (Vec<_>, Vec<_>) = memw_ops.into_iter().partition(is_aligned_op); @@ -2070,6 +2141,8 @@ impl Traces { bitwise_ops.extend(collect_bitwise_from_branch(&branch_ops)); bitwise_ops.extend(shift::collect_bitwise_from_shift(&shift_ops)); bitwise_ops.extend(collect_bitwise_from_memw_aligned(&memw_aligned_ops)); + // MEMW_R sends IS_HALFWORD[timestamp_0 - old_timestamp_lo - 1] + bitwise_ops.extend(collect_bitwise_from_memw_register(&memw_register_ops)); let public_output_bytes: Vec = commit_ops .iter() @@ -2105,6 +2178,11 @@ impl Traces { max_rows.memw_aligned, memw_aligned::generate_memw_aligned_trace, ); + let memw_registers = chunk_and_generate( + &memw_register_ops, + max_rows.memw_register, + memw_register::generate_memw_register_trace, + ); let loads = chunk_and_generate(&load_ops, max_rows.load, load::generate_load_trace); let lts = chunk_and_generate(<_ops, max_rows.lt, lt::generate_lt_trace); let shifts = chunk_and_generate(&shift_ops, max_rows.shift, shift::generate_shift_trace); @@ -2173,6 +2251,7 @@ impl Traces { branches, halt: halt_trace, commit: commit_trace, + memw_registers, }) } @@ -2224,3 +2303,54 @@ impl Traces { Self::from_logs_trimmed(logs, instructions, max_rows) } } + +#[cfg(test)] +mod routing_tests { + use super::*; + + fn make_register_op(timestamp: u64, old_timestamp: u64) -> MemwOperation { + MemwOperation::new(true, 2, [1, 0, 0, 0, 0, 0, 0, 0], timestamp, 2, false) + .with_old([0; 8], [old_timestamp, old_timestamp, 0, 0, 0, 0, 0, 0]) + } + + #[test] + fn test_is_register_op_delta_at_boundary_routes_in() { + // delta = 0x10000 = 2^16: spec allows this (IS_HALF[0xFFFF] is valid) + let op = make_register_op(0x10000, 0); + assert!(is_register_op(&op), "delta = 2^16 should route to MEMW_R"); + } + + #[test] + fn test_is_register_op_delta_above_boundary_falls_back() { + // delta = 0x10001: one above the IS_HALF 
range, must fall back to MEMW_A + let op = make_register_op(0x10001, 0); + assert!( + !is_register_op(&op), + "delta = 2^16 + 1 should fall back to MEMW_A" + ); + } + + #[test] + fn test_is_register_op_delta_one_routes_in() { + // delta = 1: minimum allowed value + let op = make_register_op(1, 0); + assert!(is_register_op(&op), "delta = 1 should route to MEMW_R"); + } + + #[test] + fn test_is_register_op_delta_zero_falls_back() { + // delta = 0: ts[0] not strictly greater than old_ts[0] + let op = make_register_op(5, 5); + assert!(!is_register_op(&op), "delta = 0 should not route to MEMW_R"); + } + + #[test] + fn test_is_register_op_upper_limb_mismatch_falls_back() { + // ts_hi != old_ts_hi: shared upper limb assumption violated + let op = make_register_op(0x1_0000_0001, 0x0_0000_0000); + assert!( + !is_register_op(&op), + "different upper limbs should fall back to MEMW_A" + ); + } +} diff --git a/prover/src/test_utils.rs b/prover/src/test_utils.rs index 93d6d2971..155af86cb 100644 --- a/prover/src/test_utils.rs +++ b/prover/src/test_utils.rs @@ -54,6 +54,10 @@ use crate::tables::memw_aligned::{ bus_interactions as memw_aligned_bus_interactions, cols as memw_aligned_cols, constraints as memw_aligned_constraints, }; +use crate::tables::memw_register::{ + bus_interactions as memw_register_bus_interactions, cols as memw_register_cols, + constraints as memw_register_constraints, +}; use crate::tables::mul::{bus_interactions as mul_bus_interactions, cols as mul_cols}; use crate::tables::page::{bus_interactions as page_bus_interactions, cols as page_cols}; use crate::tables::register::{ @@ -576,6 +580,24 @@ pub fn create_memw_aligned_air(proof_options: &ProofOptions) -> VmAir { .with_name("MEMW_A") } +/// Create MEMW_R (register) AIR with constraints and bus interactions. 
+pub fn create_memw_register_air(proof_options: &ProofOptions) -> VmAir { + let transition_constraints = memw_register_constraints(); + + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: memw_register_bus_interactions(), + }; + + AirWithBuses::new( + memw_register_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("MEMW_R") +} + /// Create LOAD AIR with constraints and bus interactions. pub fn create_load_air(proof_options: &ProofOptions) -> VmAir { let transition_constraints = load_constraints(); diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs index 5186a61b2..9a1b8a6d0 100644 --- a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -1663,6 +1663,33 @@ fn test_heap_alloc_runtime_pages_roundtrip() { ); } +/// Verify that register ops route to MEMW_R and a full prove/verify roundtrip +/// succeeds. Uses `test_add_8` which exercises register reads and writes. +#[test] +fn test_prove_verify_with_memw_register() { + let (elf, logs, instructions) = run_asm_elf("test_add_8"); + let mut traces = + Traces::from_logs_minimal(&logs, instructions.clone(), &Default::default()).unwrap(); + + // Register ops must go to MEMW_R, not to MEMW_A. + assert!( + !traces.memw_registers.is_empty(), + "register ops should route to MEMW_R: memw_registers must be non-empty" + ); + + // MEMW_A should still have non-register aligned ops (e.g. stack stores). + assert!( + !traces.memw_aligneds.is_empty(), + "MEMW_A should still have aligned non-register ops" + ); + + // Full prove + verify roundtrip. + assert!( + prove_and_verify_vm_minimal(&elf, &mut traces), + "prove/verify should succeed when MEMW_R handles register ops" + ); +} + /// Verify rejects table_counts with all zeros. 
#[test] fn test_verify_rejects_zero_table_counts() { @@ -1689,6 +1716,7 @@ fn test_verify_rejects_zero_table_counts() { dvrm: 0, shift: 0, branch: 0, + memw_register: 0, }, ..vm_proof }; @@ -1756,6 +1784,7 @@ fn test_crafted_zero_count_proof_must_not_verify() { dvrm: 0, shift: 0, branch: 0, + memw_register: 0, }; let airs = VmAirs::new(&elf, &proof_options, true, &[], &zero_counts); diff --git a/prover/src/tests/trace_builder_tests.rs b/prover/src/tests/trace_builder_tests.rs index 542780e51..a21f2c273 100644 --- a/prover/src/tests/trace_builder_tests.rs +++ b/prover/src/tests/trace_builder_tests.rs @@ -3,7 +3,7 @@ use crate::tables::bitwise; use crate::tables::cpu::cols; use crate::tables::lt; -use crate::tables::memw_aligned; +use crate::tables::memw_register; use crate::tables::trace_builder::Traces; use crate::tables::types::FE; use executor::vm::instruction::decoding::{ArithOp, Comparison, Instruction}; @@ -393,43 +393,39 @@ fn test_memw_generated_from_register_ops() { let traces = Traces::from_logs(&logs, instructions, &Default::default()).unwrap(); - // MEMW_A table should have register operations (register ops are always aligned) - // First instruction generates: M1 (read x2), M3 (read x3), M5 (write x1) + // Register ops should route to MEMW_R (memw_registers), not MEMW_A. + // First instruction generates: M1 (read x2), M3 (read x3), M5 (write x1). 
assert!( - !traces.memw_aligneds.is_empty(), - "MEMW_A should have at least one chunk for register ops" + !traces.memw_registers.is_empty(), + "MEMW_R should have at least one chunk for register ops" ); assert!( - traces.memw_aligneds[0].main_table.height >= 3, - "MEMW_A should have at least 3 rows for register ops" + traces.memw_registers[0].main_table.height >= 3, + "MEMW_R should have at least 3 rows for register ops (reads x2, x3 + write x1)" ); - // Find the register write to x1 in MEMW_A - // Register address for x1 = 2*1 = 2, decomposed: base_address[0]=2, base_address[1]=0, base_address[2]=0 + // Find the register write to x1 in MEMW_R. + // MEMW_R columns: ADDRESS = register_index (x1 → index 1), + // MU_WRITE = 1 for writes, VAL_0 = value low 32 bits. let mut found_write = false; - for chunk in &traces.memw_aligneds { - for row_idx in 0..chunk.main_table.height { - let row = chunk.main_table.get_row(row_idx); - // Check for register write: is_register=1, base_address[0]=2, mu_write=1 - if row[memw_aligned::cols::IS_REGISTER] == FE::one() - && row[memw_aligned::cols::BASE_ADDRESS[0]] == FE::from(2u64) - && row[memw_aligned::cols::BASE_ADDRESS[1]] == FE::zero() - && row[memw_aligned::cols::BASE_ADDRESS[2]] == FE::zero() - && row[memw_aligned::cols::MU_WRITE] == FE::one() - { - // Check value is 300 (lo32 word for register DWordWL packing) - assert_eq!(row[memw_aligned::cols::VALUE[0]], FE::from(300u64)); - found_write = true; - break; - } - } - if found_write { + for row_idx in 0..traces.memw_registers[0].main_table.height { + let row = traces.memw_registers[0].main_table.get_row(row_idx); + // ADDRESS = 1 (x1), MU_WRITE = 1, VAL_0 = 300 + if row[memw_register::cols::ADDRESS] == FE::from(1u64) + && row[memw_register::cols::MU_WRITE] == FE::one() + { + assert_eq!( + row[memw_register::cols::VAL_0], + FE::from(300u64), + "Write value for x1 should be 300" + ); + found_write = true; break; } } assert!( found_write, - "Register write to x1 not found in MEMW_A 
table" + "Register write to x1 (ADDRESS=1, MU_WRITE=1, VAL_0=300) not found in MEMW_R" ); } @@ -478,31 +474,28 @@ fn test_memw_generates_lt_for_timestamp_ordering() { let traces = Traces::from_logs(&logs, instructions, &Default::default()).unwrap(); - // LT table should have ops from MEMW timestamp ordering - // First instruction: 3 register ops (M1, M3, M5) → at least 3 LT ops for C7 - // Each LT op checks old_timestamp < timestamp - // For first access, old_timestamp=0, timestamp=4, so LT(0, 4) should exist + // Register ops route to MEMW_R (IS_HALF, not LT). + assert!( + !traces.memw_registers.is_empty(), + "Register ops should route to MEMW_R" + ); - // Find LT op with lhs=0, rhs=4 (first register read's timestamp check) - let mut found_timestamp_lt = false; - for row_idx in 0..traces.lts[0].main_table.height { - let row = traces.lts[0].main_table.get_row(row_idx); - // Check for LT(0, 4): lhs=0, rhs=4, signed=0 - if row[lt::cols::LHS_0] == FE::zero() - && row[lt::cols::LHS_1] == FE::zero() - && row[lt::cols::LHS_2] == FE::zero() - && row[lt::cols::RHS_0] == FE::from(4u64) - && row[lt::cols::RHS_1] == FE::zero() - && row[lt::cols::RHS_2] == FE::zero() - && row[lt::cols::SIGNED] == FE::zero() - { - found_timestamp_lt = true; - break; - } - } + // Register ops use IS_HALF for timestamp ordering instead of LT. + // Verify the bitwise table has at least one IS_HALF entry with non-zero + // multiplicity, proving that MEMW_R's IS_HALF lookups were emitted. + let has_is_half_entry = (0..traces.bitwise.main_table.height) + .any(|i| traces.bitwise.main_table.get_row(i)[bitwise::cols::MU_IS_HALF] != FE::zero()); + assert!( + has_is_half_entry, + "MEMW_R register ops should produce IS_HALF bitwise entries" + ); + + // The LT table should still have ops from non-register MEMW accesses + // (e.g. PC next-pc write is a non-register memory op that needs LT).
+ let total_lt_rows: usize = traces.lts.iter().map(|t| t.main_table.height).sum(); assert!( - found_timestamp_lt, - "LT op for timestamp ordering (0 < 4) not found" + total_lt_rows > 0, + "LT table should have ops from non-register MEMW timestamp ordering" ); }