Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 52 additions & 38 deletions prover/src/tables/memw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
//! - `timestamp`: DWordWL (64-bit timestamp, 2 cols)
//! - `write2/4/8`: Bit (access width flags)
//! - `old[8]`: BaseField[8] (previous values at address)
//! - `add_limb_overflow[7]`: Bit[7] (carry flags for base_address + i)
//! - `carry[7]`: Bit[7] (carry flags for base_address + i)
//! - `old_timestamp[8]`: DWordWL[8] (previous timestamps, 16 cols)
//! - `mu_read`, `mu_write`: multiplicity columns
//!
//! ## Virtual (computed inline)
//! - `address_add[i]` = (base_address_0 + i+1 - 2^32 * overflow[i], base_address_1 + overflow[i])
//! - `address_add[i]` = (base_address_0 + i+1 - 2^32 * carry[i], base_address_1 + carry[i])
//! - `w2`: write2 + write4 + write8 (writing at least 2 bytes)
//! - `w4`: write4 + write8 (writing at least 4 bytes)
//! - `μ_sum`: μ_read + μ_write
Expand All @@ -25,6 +25,8 @@
//! - 8 LT timestamp checks (old_timestamp[i] < timestamp)
//! - 16 Memory bus tokens (read old + write new, per byte)
//! - 2 MEMW output interactions (read + write, from CPU)
//!
//! ## Constraints (11 total: 2 custom + 2 IS_BIT for multiplicities + 7 IS_BIT for carry)

use math::field::element::FieldElement;
use math::field::traits::{IsField, IsSubFieldOf};
Expand Down Expand Up @@ -72,8 +74,8 @@ pub mod cols {
pub const OLD: [usize; 8] = [16, 17, 18, 19, 20, 21, 22, 23];

// Auxiliary columns
/// add_limb_overflow[7]: Bit columns indicating carry when adding i+1 to base_address_0
pub const ADD_LIMB_OVERFLOW: [usize; 7] = [24, 25, 26, 27, 28, 29, 30];
/// carry[7]: Bit columns indicating carry when adding i+1 to base_address_0
pub const CARRY: [usize; 7] = [24, 25, 26, 27, 28, 29, 30];

/// old_timestamp[8]: each is DWordWL (2 words = 2 columns)
/// Total: 8 * 2 = 16 columns
Expand Down Expand Up @@ -206,11 +208,11 @@ pub fn generate_memw_trace(
data[base + cols::OLD[i]] = FE::from(op.old[i]);
}

// Auxiliary: add_limb_overflow[7]
// overflow[i] = 1 if (base_address_lo + i+1) >= 2^32
// Auxiliary: carry[7]
// carry[i] = 1 if (base_address_lo + i+1) >= 2^32
for i in 0..7 {
let overflows = base_addr_lo + (i as u64 + 1) >= (1u64 << 32);
data[base + cols::ADD_LIMB_OVERFLOW[i]] = FE::from(overflows as u64);
data[base + cols::CARRY[i]] = FE::from(overflows as u64);
}

// Auxiliary: old_timestamp[8] - each as DWordWL (2 words)
Expand Down Expand Up @@ -245,18 +247,17 @@ pub fn bus_interactions() -> Vec<BusInteraction> {
// Memory bus interactions (16 total)
// -------------------------------------------------------------------------
// address_add[i] is VIRTUAL:
// lo = base_address_0 + (i+1) - 2^32 * add_limb_overflow[i]
// hi = base_address_1 + add_limb_overflow[i]
// lo = base_address_0 + (i+1) - 2^32 * carry[i]
// hi = base_address_1 + carry[i]
//
// Safety: `hi` is at most `base_address_1 + 1`. This never reaches 2^32
// because the CPU table splits addresses into (lo, hi) with both halves
// in [0, 2^32), and the Memw bus ties MEMW's base_address to the CPU's
// value. MEMW only receives accesses where base_address_1 <= 0xFFFF_FFFE
// (addresses near u64::MAX are rejected by the executor before proving).
// Consequently, `add_limb_overflow[i]` is implicitly correct: a wrong
// carry bit produces a memory token at a wrong address that has no
// matching PAGE/REGISTER token, causing multiset imbalance and an
// invalid proof.
// Consequently, `carry[i]` is implicitly correct: a wrong carry bit
// produces a memory token at a wrong address that has no matching
// PAGE/REGISTER token, causing multiset imbalance and an invalid proof.

// CM8: memory[is_register, base_address, old_timestamp[0], old[0]] with +μ_sum
interactions.push(BusInteraction::sender(
Expand Down Expand Up @@ -323,8 +324,8 @@ pub fn bus_interactions() -> Vec<BusInteraction> {
));

// CM10/11: byte 1, multiplicity w2 = write2 + write4 + write8
// address_add[0] is virtual: lo = base_address_0 + 1 - 2^32 * overflow[0]
// hi = base_address_1 + overflow[0]
// address_add[0] is virtual: lo = base_address_0 + 1 - 2^32 * carry[0]
// hi = base_address_1 + carry[0]
let addr_add_0_lo = BusValue::linear(vec![
LinearTerm::Column {
coefficient: 1,
Expand All @@ -333,7 +334,7 @@ pub fn bus_interactions() -> Vec<BusInteraction> {
LinearTerm::Constant(1),
LinearTerm::Column {
coefficient: -(1i64 << 32),
column: cols::ADD_LIMB_OVERFLOW[0],
column: cols::CARRY[0],
},
]);
let addr_add_0_hi = BusValue::linear(vec![
Expand All @@ -343,7 +344,7 @@ pub fn bus_interactions() -> Vec<BusInteraction> {
},
LinearTerm::Column {
coefficient: 1,
column: cols::ADD_LIMB_OVERFLOW[0],
column: cols::CARRY[0],
},
]);

Expand Down Expand Up @@ -401,7 +402,7 @@ pub fn bus_interactions() -> Vec<BusInteraction> {

// CM12/13: bytes 2-3 with multiplicity w4 = write4 + write8
for i in 2..=3 {
let overflow_col = cols::ADD_LIMB_OVERFLOW[i - 1];
let overflow_col = cols::CARRY[i - 1];
let addr_add_lo = BusValue::linear(vec![
LinearTerm::Column {
coefficient: 1,
Expand Down Expand Up @@ -479,7 +480,7 @@ pub fn bus_interactions() -> Vec<BusInteraction> {

// CM14/15: bytes 4-7 with multiplicity write8
for i in 4..=7 {
let overflow_col = cols::ADD_LIMB_OVERFLOW[i - 1];
let overflow_col = cols::CARRY[i - 1];
let addr_add_lo = BusValue::linear(vec![
LinearTerm::Column {
coefficient: 1,
Expand Down Expand Up @@ -857,7 +858,7 @@ where
}

// =========================================================================
// Constraints (9 total: 2 custom + 7 IS_BIT)
// Constraints (11 total: 2 custom + 2 IS_BIT for multiplicities + 7 IS_BIT for carry)
// =========================================================================

/// MEMW table constraint kinds.
Expand Down Expand Up @@ -946,10 +947,12 @@ impl TransitionConstraint<GoldilocksField, GoldilocksExtension> for MemwConstrai

/// Creates all constraints for the MEMW table.
///
/// 9 constraints total:
/// 11 constraints total:
/// - IS_BIT<μ_sum> (1)
/// - w2 => μ_sum (1)
/// - IS_BIT for add_limb_overflow[0..6] (7)
/// - IS_BIT<μ_read> (1)
/// - IS_BIT<μ_write> (1)
/// - IS_BIT for carry[0..6] (7)
pub fn constraints() -> Vec<Box<dyn TransitionConstraint<GoldilocksField, GoldilocksExtension>>> {
let mut constraints: Vec<Box<dyn TransitionConstraint<GoldilocksField, GoldilocksExtension>>> =
Vec::new();
Expand All @@ -970,8 +973,19 @@ pub fn constraints() -> Vec<Box<dyn TransitionConstraint<GoldilocksField, Goldil
)));
idx += 1;

// IS_BIT for add_limb_overflow[0..6]
for &col in &cols::ADD_LIMB_OVERFLOW {
// IS_BIT<μ_read>
constraints.push(Box::new(IsBitConstraint::unconditional(cols::MU_READ, idx)));
idx += 1;

// IS_BIT<μ_write>
constraints.push(Box::new(IsBitConstraint::unconditional(
cols::MU_WRITE,
idx,
)));
idx += 1;

// IS_BIT for carry[0..6]
for &col in &cols::CARRY {
constraints.push(Box::new(IsBitConstraint::unconditional(col, idx)));
idx += 1;
}
Expand Down Expand Up @@ -1013,45 +1027,45 @@ mod tests {
}

#[test]
fn test_add_limb_overflow() {
// Address 0xFFFF_FFFF should overflow when adding 1
fn test_carry_flags() {
// Address 0xFFFF_FFFF should carry when adding 1
let op =
MemwOperation::new(false, 0xFFFF_FFFF, [0; 8], 100, 8, false).with_old([0; 8], [50; 8]);
let trace = generate_memw_trace(&[op]);

// All 7 overflow flags should be 1 since 0xFFFF_FFFF + i >= 2^32 for i >= 1
// All 7 carry flags should be 1 since 0xFFFF_FFFF + i >= 2^32 for i >= 1
for i in 0..7 {
let val = trace.get_main(0, cols::ADD_LIMB_OVERFLOW[i]);
assert_eq!(*val, FE::one(), "overflow[{i}] should be 1");
let val = trace.get_main(0, cols::CARRY[i]);
assert_eq!(*val, FE::one(), "carry[{i}] should be 1");
}

// Address 0x0000_0000 should not overflow
// Address 0x0000_0000 should not carry
let op2 =
MemwOperation::new(false, 0x0000_0000, [0; 8], 100, 8, false).with_old([0; 8], [50; 8]);
let trace2 = generate_memw_trace(&[op2]);
for i in 0..7 {
let val = trace2.get_main(0, cols::ADD_LIMB_OVERFLOW[i]);
assert_eq!(*val, FE::zero(), "overflow[{i}] should be 0");
let val = trace2.get_main(0, cols::CARRY[i]);
assert_eq!(*val, FE::zero(), "carry[{i}] should be 0");
}

// Address 0xFFFF_FFFE with width=8 exercises mixed per-byte carry bits:
// overflow[0]=0 (0xFFFF_FFFE+1 = 0xFFFF_FFFF < 2^32)
// overflow[1..6]=1 (0xFFFF_FFFE+2..8 >= 2^32)
// carry[0]=0 (0xFFFF_FFFE+1 = 0xFFFF_FFFF < 2^32)
// carry[1..6]=1 (0xFFFF_FFFE+2..8 >= 2^32)
let op3 =
MemwOperation::new(false, 0xFFFF_FFFE, [0; 8], 100, 8, false).with_old([0; 8], [50; 8]);
let trace3 = generate_memw_trace(&[op3]);
let val0 = trace3.get_main(0, cols::ADD_LIMB_OVERFLOW[0]);
let val0 = trace3.get_main(0, cols::CARRY[0]);
assert_eq!(
*val0,
FE::zero(),
"overflow[0] should be 0 for base 0xFFFF_FFFE"
"carry[0] should be 0 for base 0xFFFF_FFFE"
);
for i in 1..7 {
let val = trace3.get_main(0, cols::ADD_LIMB_OVERFLOW[i]);
let val = trace3.get_main(0, cols::CARRY[i]);
assert_eq!(
*val,
FE::one(),
"overflow[{i}] should be 1 for base 0xFFFF_FFFE"
"carry[{i}] should be 1 for base 0xFFFF_FFFE"
);
}
}
Expand Down
Loading
Loading