Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 54 additions & 41 deletions prover/src/constraints/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ pub const BIT_FLAG_COLUMNS: &[usize] = &[
cols::ECALL,
cols::EBREAK,
// Sign bits
cols::RV1_SIGN_BIT,
cols::ARG2_SIGN_BIT,
cols::RES_SIGN_BIT,
cols::RV1_EXT_BIT,
cols::RV2_EXT_BIT,
cols::RES_EXT_BIT,
// Computed flags
cols::IS_EQUAL,
cols::BRANCH_COND,
Expand Down Expand Up @@ -391,7 +391,7 @@ impl TransitionConstraint<GoldilocksField, GoldilocksExtension> for Arg1LowerCon
}
}

/// Constraint: arg1[4:8] = rv1[2] * (1 - word_instr) + (2^32 - 1) * rv1_sign_bit * signed
/// Constraint: arg1[4:8] = rv1[2] * (1 - word_instr) + (2^32 - 1) * rv1_ext_bit * signed
///
/// Upper 32 bits of arg1 depends on word_instr and sign extension.
pub struct Arg1UpperConstraint {
Expand All @@ -418,15 +418,15 @@ impl Arg1UpperConstraint {
.get_main_evaluation_element(0, cols::WORD_INSTR)
.clone();
let signed = step.get_main_evaluation_element(0, cols::SIGNED).clone();
let rv1_sign_bit = step
.get_main_evaluation_element(0, cols::RV1_SIGN_BIT)
let rv1_ext_bit = step
.get_main_evaluation_element(0, cols::RV1_EXT_BIT)
.clone();

let one = FieldElement::<F>::one();
let mask_32: FieldElement<F> = FieldElement::from((1u64 << 32) - 1); // 2^32 - 1

// Expected: rv1_upper * (1 - word_instr) + mask_32 * rv1_sign_bit * signed
let expected = rv1_upper * (one - &word_instr) + mask_32 * rv1_sign_bit * signed;
// Expected: rv1_upper * (1 - word_instr) + mask_32 * rv1_ext_bit * signed
let expected = rv1_upper * (one - &word_instr) + mask_32 * rv1_ext_bit * signed;

// Constraint: arg1_hi - expected = 0
arg1_hi - expected
Expand All @@ -435,7 +435,7 @@ impl Arg1UpperConstraint {

impl TransitionConstraint<GoldilocksField, GoldilocksExtension> for Arg1UpperConstraint {
fn degree(&self) -> usize {
// rv1_sign_bit * signed * word_instr has degree 3
// rv1_ext_bit * signed * word_instr has degree 3
3
}

Expand Down Expand Up @@ -569,47 +569,47 @@ pub fn create_slt_res_zero_constraints(
}

// =========================================================================
// Sign Bit Constraints
// Extension Bit Constraints (SIGN template from spec)
// =========================================================================

/// Constraint: sign bits are zero when word_instr = 0
/// Constraint: ext_bit must be zero when word_instr = 0
///
/// (rv1_sign_bit + arg2_sign_bit + res_sign_bit) * (1 - word_instr) = 0
pub struct SignBitZeroConstraint {
/// (1 - word_instr) * ext_bit = 0
///
/// One instance per extension bit (rv1_ext_bit, rv2_ext_bit, res_ext_bit).
pub struct ExtBitZeroConstraint {
constraint_idx: usize,
ext_bit_col: usize,
}

impl SignBitZeroConstraint {
pub fn new(constraint_idx: usize) -> Self {
Self { constraint_idx }
impl ExtBitZeroConstraint {
pub fn new(constraint_idx: usize, ext_bit_col: usize) -> Self {
Self {
constraint_idx,
ext_bit_col,
}
}

fn compute<F, E>(&self, step: &TableView<F, E>) -> FieldElement<F>
where
F: IsSubFieldOf<E>,
E: IsField,
{
let rv1_sign_bit = step
.get_main_evaluation_element(0, cols::RV1_SIGN_BIT)
.clone();
let arg2_sign_bit = step
.get_main_evaluation_element(0, cols::ARG2_SIGN_BIT)
.clone();
let res_sign_bit = step
.get_main_evaluation_element(0, cols::RES_SIGN_BIT)
let ext_bit = step
.get_main_evaluation_element(0, self.ext_bit_col)
.clone();
let word_instr = step
.get_main_evaluation_element(0, cols::WORD_INSTR)
.clone();

let one = FieldElement::<F>::one();

// (sum of sign bits) * (1 - word_instr) = 0
(rv1_sign_bit + arg2_sign_bit + res_sign_bit) * (one - word_instr)
// (1 - word_instr) * ext_bit = 0
(one - word_instr) * ext_bit
}
}

impl TransitionConstraint<GoldilocksField, GoldilocksExtension> for SignBitZeroConstraint {
impl TransitionConstraint<GoldilocksField, GoldilocksExtension> for ExtBitZeroConstraint {
fn degree(&self) -> usize {
2
}
Expand Down Expand Up @@ -874,7 +874,7 @@ impl TransitionConstraint<GoldilocksField, GoldilocksExtension> for Arg2LowerCon
}
}

/// Constraint: arg2[4:] = (1-LOAD)*((1-word_instr)*rv2[2] + signed*arg2_sign_bit*(2^32-1)) + (1-BEQ-BLT-STORE)*imm[1]
/// Constraint: arg2[4:] = (1-LOAD)*((1-word_instr)*rv2[2] + signed*rv2_ext_bit*(2^32-1)) + (1-BEQ-BLT-STORE)*imm[1]
///
/// arg2 upper 32 bits with sign extension logic.
pub struct Arg2UpperConstraint {
Expand Down Expand Up @@ -912,13 +912,13 @@ impl Arg2UpperConstraint {
let blt = step.get_main_evaluation_element(0, cols::BLT);
let word_instr = step.get_main_evaluation_element(0, cols::WORD_INSTR);
let signed = step.get_main_evaluation_element(0, cols::SIGNED);
let arg2_sign_bit = step.get_main_evaluation_element(0, cols::ARG2_SIGN_BIT);
let rv2_ext_bit = step.get_main_evaluation_element(0, cols::RV2_EXT_BIT);

let one = FieldElement::<F>::one();
let mask_32: FieldElement<F> = FieldElement::from((1u64 << 32) - 1);

// rv2_term = (1 - word_instr) * rv2[2] + signed * arg2_sign_bit * (2^32 - 1)
let rv2_term = (&one - word_instr) * rv2_upper + signed * arg2_sign_bit * &mask_32;
// rv2_term = (1 - word_instr) * rv2[2] + signed * rv2_ext_bit * (2^32 - 1)
let rv2_term = (&one - word_instr) * rv2_upper + signed * rv2_ext_bit * &mask_32;

// expected = (1-LOAD) * rv2_term + (1-BEQ-BLT-STORE) * imm[1]
// STORE now gets rv2_term (with sign extension), not imm
Expand All @@ -931,7 +931,7 @@ impl Arg2UpperConstraint {

impl TransitionConstraint<GoldilocksField, GoldilocksExtension> for Arg2UpperConstraint {
fn degree(&self) -> usize {
// (1-LOAD) * signed * arg2_sign_bit has degree 3
// (1-LOAD) * signed * rv2_ext_bit has degree 3
3
}

Expand Down Expand Up @@ -1029,7 +1029,7 @@ impl TransitionConstraint<GoldilocksField, GoldilocksExtension> for RvdLowerCons
}
}

/// Constraint: (1-LOAD) * (rvd[1] - ((1-word_instr)*res[4:] + res_sign_bit*(2^32-1))) = 0
/// Constraint: (1-LOAD) * (rvd[1] - ((1-word_instr)*res[4:] + res_ext_bit*(2^32-1))) = 0
///
/// When not LOAD, rvd upper 32 bits equals res upper with sign extension.
/// For LOAD: rvd is the loaded value, not res (which is the address).
Expand All @@ -1056,13 +1056,13 @@ impl RvdUpperConstraint {

let load = step.get_main_evaluation_element(0, cols::LOAD);
let word_instr = step.get_main_evaluation_element(0, cols::WORD_INSTR);
let res_sign_bit = step.get_main_evaluation_element(0, cols::RES_SIGN_BIT);
let res_ext_bit = step.get_main_evaluation_element(0, cols::RES_EXT_BIT);

let one = FieldElement::<F>::one();
let mask_32: FieldElement<F> = FieldElement::from((1u64 << 32) - 1);

// expected = (1 - word_instr) * res_hi + res_sign_bit * (2^32 - 1)
let expected = (&one - word_instr) * res_hi + res_sign_bit * mask_32;
// expected = (1 - word_instr) * res_hi + res_ext_bit * (2^32 - 1)
let expected = (&one - word_instr) * res_hi + res_ext_bit * mask_32;

// (1 - LOAD) * (rvd[1] - expected) = 0
(one - load) * (rvd_1 - expected)
Expand Down Expand Up @@ -1265,14 +1265,14 @@ pub fn create_jalr_constraints(constraint_idx_start: usize) -> (Vec<AddConstrain
/// - Rvd lower: 1
/// - Rvd upper: 1
/// - SLT res zero: 7 (bytes 1-7)
/// - Sign bit zero: 1
/// - Ext bit zero (SIGN template): 3 (rv1_ext_bit, rv2_ext_bit, res_ext_bit)
/// - rv1 zero-forcing (CM48): 3 (rv1[0..2] when read_register1 = 0)
/// - rv2 zero-forcing (CM50): 3 (rv2[0..2] when read_register2 = 0)
/// - Next PC (non-branching): 2
///
/// Total: 64 constraints (32 IS_BIT + 8 ADD + 24 other)
/// Total: 66 constraints (32 IS_BIT + 8 ADD + 26 other)
pub const NUM_CPU_CONSTRAINTS: usize =
32 + 2 + 2 + 2 + 2 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 7 + 1 + 3 + 3 + 2;
32 + 2 + 2 + 2 + 2 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 7 + 3 + 3 + 3 + 2;

/// Creates all CPU constraints.
///
Expand Down Expand Up @@ -1361,8 +1361,21 @@ pub fn create_all_cpu_constraints() -> (
other.push(Box::new(c));
}

// Sign bit zero constraint
other.push(Box::new(SignBitZeroConstraint::new(next_idx)));
// Extension bit zero constraints (SIGN template: !word_instr => ext_bit = 0)
other.push(Box::new(ExtBitZeroConstraint::new(
next_idx,
cols::RV1_EXT_BIT,
)));
next_idx += 1;
other.push(Box::new(ExtBitZeroConstraint::new(
next_idx,
cols::RV2_EXT_BIT,
)));
next_idx += 1;
other.push(Box::new(ExtBitZeroConstraint::new(
next_idx,
cols::RES_EXT_BIT,
)));
next_idx += 1;

// Next PC (non-branching) constraints
Expand Down
Loading
Loading