diff --git a/executor/programs/asm/test_keccak.s b/executor/programs/asm/test_keccak.s new file mode 100644 index 000000000..31cd93be6 --- /dev/null +++ b/executor/programs/asm/test_keccak.s @@ -0,0 +1,38 @@ + .attribute 5, "rv64i2p1_m2p0_zmmul1p0" +.Lfunc_end0: + .globl main +main: + # Allocate 200 bytes on the stack for the Keccak state (25 × u64) + addi sp, sp, -200 + + # Zero out the state (200 bytes = 25 doublewords) + mv t0, sp + li t1, 25 +.Lzero_loop: + sd zero, 0(t0) + addi t0, t0, 8 + addi t1, t1, -1 + bnez t1, .Lzero_loop + + # Call keccak-f[1600] permutation + # a0 = pointer to 200-byte state + # a7 = syscall number (0xFFFFFFFFFFFFFFFE = u64::MAX - 1) + mv a0, sp + li a7, -2 + ecall + + # Commit the post-permutation state so the test can verify the KAT. + # Commit syscall: a0=fd(1), a1=buf_addr, a2=count, a7=64 + li a0, 1 + mv a1, sp + li a2, 200 + li a7, 64 + ecall + + # Restore stack and halt + addi sp, sp, 200 + li a0, 0 + li a7, 93 + ecall +.Lfunc_end1: + .size main, .Lfunc_end1-main diff --git a/executor/programs/asm/test_keccak_multi.s b/executor/programs/asm/test_keccak_multi.s new file mode 100644 index 000000000..fcd192de7 --- /dev/null +++ b/executor/programs/asm/test_keccak_multi.s @@ -0,0 +1,48 @@ + .attribute 5, "rv64i2p1_m2p0_zmmul1p0" +.Lfunc_end0: + .globl main +main: + # Allocate 200 bytes on the stack for the Keccak state (25 × u64). + addi sp, sp, -200 + + # Initialize a non-zero, deterministic state: lane[i] = i + 1. + # Used by the host test as the initial state for tiny-keccak::keccakf + # cross-checking. + mv t0, sp + li t1, 1 + li t2, 26 +.Linit_loop: + sd t1, 0(t0) + addi t0, t0, 8 + addi t1, t1, 1 + bne t1, t2, .Linit_loop + + # First keccak-f[1600] call. + mv a0, sp + li a7, -2 + ecall + + # Second keccak-f[1600] call on the result. + mv a0, sp + li a7, -2 + ecall + + # Third keccak-f[1600] call on the result. + mv a0, sp + li a7, -2 + ecall + + # Commit the final 200-byte state. + li a0, 1 + mv a1, sp + li a2, 200 + li a7, 64 + ecall + + # Restore stack and halt. + addi sp, sp, 200 + li a0, 0 + li a7, 93 + ecall +.Lfunc_end1: + .size main, .Lfunc_end1-main diff --git a/executor/src/vm/instruction/execution.rs b/executor/src/vm/instruction/execution.rs index a5222557a..04502645b 100644 --- a/executor/src/vm/instruction/execution.rs +++ b/executor/src/vm/instruction/execution.rs @@ -8,12 +8,20 @@ use crate::vm::{ const REGULAR_PC_UPDATE: u64 = 4; pub enum SyscallNumbers { + // Placeholder discriminant. The actual syscall value is KECCAK_SYSCALL_NUMBER. + KeccakPermute = 0, Print = 1, Panic = 2, Commit = 64, Halt = 93, } +/// Syscall number for KeccakPermute (u64::MAX - 1 = 0xFFFF_FFFF_FFFF_FFFE). +/// +/// Cannot be an enum discriminant because it exceeds isize::MAX. +pub const KECCAK_SYSCALL_NUMBER: u64 = u64::MAX - 1; +const KECCAK_STATE_BYTES: u64 = 25 * 8; + impl TryFrom for SyscallNumbers { type Error = (); fn try_from(value: u64) -> Result { @@ -22,6 +30,7 @@ impl TryFrom for SyscallNumbers { 2 => Ok(SyscallNumbers::Panic), 64 => Ok(SyscallNumbers::Commit), 93 => Ok(SyscallNumbers::Halt), + v if v == KECCAK_SYSCALL_NUMBER => Ok(SyscallNumbers::KeccakPermute), _ => Err(()), } } @@ -324,6 +333,32 @@ impl Instruction { src2_val = buf_addr; dst_val = count; } + SyscallNumbers::KeccakPermute => { + // keccak-f[1600] permutation on 200 bytes (25 × u64) at address in x10 + let state_addr = registers.read(10)?; + if !state_addr.is_multiple_of(8) { + return Err(ExecutionError::UnalignedKeccakStateAddress(state_addr)); + } + state_addr + .checked_add(KECCAK_STATE_BYTES - 1) + .ok_or(ExecutionError::KeccakStateAddressOverflow(state_addr))?; + + let mut state = [0u64; 25]; + for (i, lane) in state.iter_mut().enumerate() { + let lane_addr = state_addr + .checked_add((i as u64) * 8) + .ok_or(ExecutionError::KeccakStateAddressOverflow(state_addr))?; + *lane = memory.load_doubleword(lane_addr)?; + } + keccak_f1600(&mut state); + for (i, &lane) in state.iter().enumerate() { + let lane_addr = state_addr + .checked_add((i as u64) * 8) + .ok_or(ExecutionError::KeccakStateAddressOverflow(state_addr))?; + memory.store_doubleword(lane_addr, lane)?; + } + src2_val = state_addr; + } SyscallNumbers::Halt => { // halt return Ok(Log { @@ -496,4 +531,177 @@ pub enum ExecutionError { InvalidWSuffixOperation(ArithOp), #[error("Invalid commit fd: expected 1 (stdout), got {0}")] InvalidCommitFd(u64), + #[error("Unaligned Keccak state address: {0:#018x}")] + UnalignedKeccakStateAddress(u64), + #[error("Keccak state address range overflows: {0:#018x}")] + KeccakStateAddressOverflow(u64), +} + +// ============================================================================= +// Keccak-f[1600] permutation +// ============================================================================= + +/// Round constants for Keccak-f[1600] (24 rounds). +pub const KECCAK_RC: [u64; 24] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, +]; + +/// Rotation offsets R[x][y] for the rho step of Keccak-f[1600]. +pub const KECCAK_RHO: [[u32; 5]; 5] = [ + [0, 36, 3, 41, 18], + [1, 44, 10, 45, 2], + [62, 6, 43, 15, 61], + [28, 55, 25, 21, 56], + [27, 20, 39, 8, 14], +]; + +/// Apply the Keccak-f[1600] permutation (24 rounds) to a 25-word state. +/// +/// The state is indexed as `state[x + 5*y]` where `x, y ∈ {0..4}`. +pub fn keccak_f1600(state: &mut [u64; 25]) { + for &rc in &KECCAK_RC { + // θ (theta) + let mut c = [0u64; 5]; + for x in 0..5 { + c[x] = state[x] ^ state[x + 5] ^ state[x + 10] ^ state[x + 15] ^ state[x + 20]; + } + let mut d = [0u64; 5]; + for x in 0..5 { + d[x] = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1); + } + for x in 0..5 { + for y in 0..5 { + state[x + 5 * y] ^= d[x]; + } + } + + // ρ (rho) and π (pi) + let mut b = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + b[y + 5 * ((2 * x + 3 * y) % 5)] = state[x + 5 * y].rotate_left(KECCAK_RHO[x][y]); + } + } + + // χ (chi) + for x in 0..5 { + for y in 0..5 { + state[x + 5 * y] = + b[x + 5 * y] ^ (!b[(x + 1) % 5 + 5 * y] & b[(x + 2) % 5 + 5 * y]); + } + } + + // ι (iota) + state[0] ^= rc; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_keccak_f1600_zero_input() { + let mut state = [0u64; 25]; + keccak_f1600(&mut state); + + let expected: [u64; 25] = [ + 0xF1258F7940E1DDE7, + 0x84D5CCF933C0478A, + 0xD598261EA65AA9EE, + 0xBD1547306F80494D, + 0x8B284E056253D057, + 0xFF97A42D7F8E6FD4, + 0x90FEE5A0A44647C4, + 0x8C5BDA0CD6192E76, + 0xAD30A6F71B19059C, + 0x30935AB7D08FFC64, + 0xEB5AA93F2317D635, + 0xA9A6E6260D712103, + 0x81A57C16DBCF555F, + 0x43B831CD0347C826, + 0x01F22F1A11A5569F, + 0x05E5635A21D9AE61, + 0x64BEFEF28CC970F2, + 0x613670957BC46611, + 0xB87C5A554FD00ECB, + 0x8C3EE88A1CCF32C8, + 0x940C7922AE3A2614, + 0x1841F924A2C509E4, + 0x16F53526E70465C2, + 0x75F644E97F30A13B, + 0xEAF1FF7B5CECA249, + ]; + + assert_eq!(state, expected, "keccak-f[1600] on zero input mismatch"); + } + + #[test] + fn test_keccak_f1600_nonzero_input() { + let mut state = [0u64; 25]; + state[0] = 1; + let original = state; + keccak_f1600(&mut state); + assert_ne!(state, original); + assert!(state.iter().any(|&x| x != 0)); + } + + #[test] + fn test_keccak_syscall_rejects_unaligned_state_addr() { + let mut pc = 0; + let mut registers = Registers::default(); + let mut memory = Memory::default(); + + registers.write(17, KECCAK_SYSCALL_NUMBER).unwrap(); + registers.write(10, 0x1001).unwrap(); + + let err = Instruction::EcallEbreak + .run(&mut pc, &mut registers, &mut memory) + .unwrap_err(); + assert!(matches!( + err, + ExecutionError::UnalignedKeccakStateAddress(0x1001) + )); + } + + #[test] + fn test_keccak_syscall_rejects_overflowing_state_range() { + let mut pc = 0; + let mut registers = Registers::default(); + let mut memory = Memory::default(); + + registers.write(17, KECCAK_SYSCALL_NUMBER).unwrap(); + registers.write(10, u64::MAX - 191).unwrap(); + + let err = Instruction::EcallEbreak + .run(&mut pc, &mut registers, &mut memory) + .unwrap_err(); + assert!(matches!( + err, + ExecutionError::KeccakStateAddressOverflow(addr) if addr == u64::MAX - 191 + )); + } } diff --git a/executor/tests/asm.rs b/executor/tests/asm.rs index cbc1adec5..86722b82c 100644 --- a/executor/tests/asm.rs +++ b/executor/tests/asm.rs @@ -801,3 +801,49 @@ fn test_sub_64bit() { fn test_sub_underflow() { run_program("./program_artifacts/asm/sub_underflow.elf"); } + +// ==================== Keccak Precompile ==================== + +#[test] +fn test_keccak() { + // Runs keccak-f[1600] on a zeroed state and commits the 200-byte result. + // Expected output is the FIPS-202 zero-input KAT. + let elf_data = std::fs::read("./program_artifacts/asm/test_keccak.elf").unwrap(); + let program = Elf::load(&elf_data).unwrap(); + let executor = Executor::new(&program, vec![]).expect("Failed to create executor"); + let result = executor.run().expect("Failed to run program"); + + let expected_state: [u64; 25] = [ + 0xF1258F7940E1DDE7, + 0x84D5CCF933C0478A, + 0xD598261EA65AA9EE, + 0xBD1547306F80494D, + 0x8B284E056253D057, + 0xFF97A42D7F8E6FD4, + 0x90FEE5A0A44647C4, + 0x8C5BDA0CD6192E76, + 0xAD30A6F71B19059C, + 0x30935AB7D08FFC64, + 0xEB5AA93F2317D635, + 0xA9A6E6260D712103, + 0x81A57C16DBCF555F, + 0x43B831CD0347C826, + 0x01F22F1A11A5569F, + 0x05E5635A21D9AE61, + 0x64BEFEF28CC970F2, + 0x613670957BC46611, + 0xB87C5A554FD00ECB, + 0x8C3EE88A1CCF32C8, + 0x940C7922AE3A2614, + 0x1841F924A2C509E4, + 0x16F53526E70465C2, + 0x75F644E97F30A13B, + 0xEAF1FF7B5CECA249, + ]; + let mut expected_bytes = Vec::with_capacity(200); + for lane in expected_state { + expected_bytes.extend_from_slice(&lane.to_le_bytes()); + } + assert_eq!(result.return_values.memory_values, expected_bytes); + assert_eq!(result.return_values.register_values.0, 0); +} diff --git a/prover/Cargo.toml b/prover/Cargo.toml index dac711002..60ed39c0c 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -20,6 +20,7 @@ rayon = { version = "1.8.0", optional = true } [dev-dependencies] env_logger = "*" criterion = { version = "0.5", default-features = false } +tiny-keccak = { version = "2.0", features = ["keccak"] } [[bench]] name = "vm_prover_benchmark" diff --git a/prover/src/constraints/cpu.rs b/prover/src/constraints/cpu.rs index 64d6b7e3e..546f2f2a4 100644 --- a/prover/src/constraints/cpu.rs +++ b/prover/src/constraints/cpu.rs @@ -1033,7 +1033,7 @@ pub fn create_jalr_constraints(constraint_idx_start: usize) -> (Vec, pub halt: VmAir, pub commit: VmAir, + pub keccak: VmAir, + pub keccak_rnd: VmAir, + pub keccak_rc: VmAir, pub register: VmAir, pub pages: Vec, pub memw_registers: Vec, @@ -213,6 +217,9 @@ impl VmAirs { (&self.decode, &mut traces.decode, &()), (&self.halt, &mut traces.halt, &()), (&self.commit, &mut traces.commit, &()), + (&self.keccak, &mut traces.keccak, &()), + (&self.keccak_rnd, &mut traces.keccak_rnd, &()), + (&self.keccak_rc, &mut traces.keccak_rc, &()), (&self.register, &mut traces.register, &()), ]; @@ -268,6 +275,9 @@ impl VmAirs { &self.decode, &self.halt, &self.commit, + &self.keccak, + &self.keccak_rnd, + &self.keccak_rc, &self.register, ]; @@ -363,6 +373,12 @@ impl VmAirs { .collect(); let halt = create_halt_air(proof_options); let commit = create_commit_air(proof_options); + let keccak = create_keccak_air(proof_options); + let keccak_rnd = create_keccak_rnd_air(proof_options); + let keccak_rc = create_keccak_rc_air(proof_options).with_preprocessed( + tables::keccak_rc::preprocessed_commitment(proof_options), + tables::keccak_rc::NUM_PRECOMPUTED_COLS, + ); let register = create_register_air(proof_options).with_preprocessed( register::preprocessed_commitment(proof_options, elf.entry_point), register::NUM_PREPROCESSED_COLS, @@ -406,6 +422,9 @@ impl VmAirs { branches, halt, commit, + keccak, + keccak_rnd, + keccak_rc, register, pages, memw_registers, @@ -690,11 +709,11 @@ pub fn verify_with_options( ); // Cross-check: table_counts must match the number of sub-proofs. - // Fixed tables (bitwise, decode, halt, commit, register) = 5, plus page tables. - let expected_proof_count = vm_proof.table_counts.total() + 5 + page_configs.len(); + // Fixed tables (bitwise, decode, halt, commit, keccak, keccak_rnd, keccak_rc, register) = 8, plus page tables. + let expected_proof_count = vm_proof.table_counts.total() + 8 + page_configs.len(); if expected_proof_count != vm_proof.proof.proofs.len() { return Err(Error::InvalidTableCounts(format!( - "table_counts total ({}) + 5 fixed + {} pages = {}, but proof contains {} sub-proofs", + "table_counts total ({}) + 8 fixed + {} pages = {}, but proof contains {} sub-proofs", vm_proof.table_counts.total(), page_configs.len(), expected_proof_count, diff --git a/prover/src/tables/cpu.rs b/prover/src/tables/cpu.rs index 70ae8c501..57f207d4d 100644 --- a/prover/src/tables/cpu.rs +++ b/prover/src/tables/cpu.rs @@ -306,6 +306,12 @@ pub struct CpuOperation { /// For Commit ECALLs: byte count from x12 pub commit_count: u64, + + /// Whether this ECALL is a KeccakPermute syscall + pub ecall_keccak: bool, + + /// For KeccakPermute ECALLs: state address from x10 + pub keccak_state_addr: u64, } impl CpuOperation { @@ -641,6 +647,9 @@ impl CpuOperation { } else { (0, 0) }; + let ecall_keccak = decode.op_ecall + && log.src1_val == executor::vm::instruction::execution::KECCAK_SYSCALL_NUMBER; + let keccak_state_addr = if ecall_keccak { log.src2_val } else { 0 }; // CM50: (1 - read_register2) * rv2[i] = 0. When read_register2=0, rv2 must be 0. // For example, ECALL has read_register2=0 (rs2 defaults to 0). The commit buf_addr is // carried separately in commit_buf_addr and does not go through rv2. @@ -663,6 +672,8 @@ impl CpuOperation { ecall_commit, commit_buf_addr, commit_count, + ecall_keccak, + keccak_state_addr, }; // Compute runtime-specific values based on instruction type @@ -2035,12 +2046,9 @@ pub fn bus_interactions() -> Vec { } } - // ECALL interaction (single shared bus for HALT and COMMIT) + // ECALL interaction (shared bus for HALT, COMMIT, and KECCAK) // ------------------------------------------------------------------------- - // Sends to both HALT and COMMIT tables. Each receiver pattern-matches on - // the syscall number in the payload. - // multiplicity = ECALL - // rv1 = value of a7 register (syscall number). + // multiplicity = ECALL (all ECALLs, each receiver matches on syscall number) interactions.push(BusInteraction::sender( BusId::Ecall, Multiplicity::Column(cols::ECALL), diff --git a/prover/src/tables/keccak.rs b/prover/src/tables/keccak.rs new file mode 100644 index 000000000..87e8dc122 --- /dev/null +++ b/prover/src/tables/keccak.rs @@ -0,0 +1,567 @@ +//! KECCAK core chip — handles ECALL, memory I/O, and delegation to the round chip. +//! +//! One row per keccak permutation call. Reads/writes 25 u64 lanes from/to memory, +//! sends input state to the round chip via the Keccak bus, and receives the output +//! state after 24 rounds. +//! +//! ## Column layout (~511 columns) +//! +//! | Group | Size | Description | +//! |----------------|------|------------------------------------------------| +//! | timestamp | 2 | DWordWL | +//! | addr | 8 | State address as DWordBL (8 bytes) | +//! | input_state | 200 | Input state bytes [5][5][8] | +//! | output_state | 200 | Output state bytes [5][5][8] | +//! | state_ptr | 100 | Per-lane DWordHL addresses [25][4] | +//! | mu | 1 | Multiplicity flag | + +use executor::vm::instruction::execution::KECCAK_SYSCALL_NUMBER; +use math::field::element::FieldElement; +use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; +use stark::table::TableView; +use stark::trace::TraceTable; + +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; +use crate::constraints::templates::{AddConstraint, AddOperand, INV_SHIFT_32}; + +// ========================================================================= +// Column indices +// ========================================================================= + +pub mod cols { + pub const TIMESTAMP_0: usize = 0; + pub const TIMESTAMP_1: usize = 1; + + // addr[8] — state address as 8 bytes (DWordBL) + pub const ADDR: usize = 2; + + // input_state[5][5][8] = 200 bytes + pub const INPUT_STATE: usize = ADDR + 8; // 10 + + // output_state[5][5][8] = 200 bytes + pub const OUTPUT_STATE: usize = INPUT_STATE + 200; // 210 + + // state_ptr[25][4] = 100 halfwords (DWordHL per lane) + pub const STATE_PTR: usize = OUTPUT_STATE + 200; // 410 + + pub const MU: usize = STATE_PTR + 100; // 510 + + pub const NUM_COLUMNS: usize = MU + 1; // 511 + + // ------------------------------------------------------------------------- + // Index helpers + // ------------------------------------------------------------------------- + + #[inline] + pub const fn addr(byte: usize) -> usize { + ADDR + byte + } + + /// Index into input_state[x][y][byte] + #[inline] + pub const fn input_state(x: usize, y: usize, byte: usize) -> usize { + INPUT_STATE + (x + 5 * y) * 8 + byte + } + + /// Index into output_state[x][y][byte] + #[inline] + pub const fn output_state(x: usize, y: usize, byte: usize) -> usize { + OUTPUT_STATE + (x + 5 * y) * 8 + byte + } + + /// Index into state_ptr[lane_idx][halfword] (DWordHL = 4 halfwords) + #[inline] + pub const fn state_ptr(lane_idx: usize, hw: usize) -> usize { + STATE_PTR + lane_idx * 4 + hw + } +} + +// ========================================================================= +// Operation struct +// ========================================================================= + +#[derive(Debug, Clone)] +pub struct KeccakOperation { + pub timestamp: u64, + pub state_addr: u64, + pub input: [u64; 25], + pub output: [u64; 25], +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +fn byte_of(val: u64, b: usize) -> u8 { + ((val >> (b * 8)) & 0xFF) as u8 +} + +pub fn generate_keccak_trace( + ops: &[KeccakOperation], +) -> TraceTable { + let n = ops.len(); + let num_rows = n.next_power_of_two().max(4); + let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; + + for (row_idx, op) in ops.iter().enumerate() { + let base = row_idx * cols::NUM_COLUMNS; + + // Timestamp + data[base + cols::TIMESTAMP_0] = FE::from(op.timestamp & 0xFFFF_FFFF); + data[base + cols::TIMESTAMP_1] = FE::from(op.timestamp >> 32); + + // Address as 8 bytes + for b in 0..8 { + data[base + cols::addr(b)] = FE::from(byte_of(op.state_addr, b) as u64); + } + + // Input state as bytes + for x in 0..5 { + for y in 0..5 { + let lane = op.input[x + 5 * y]; + for b in 0..8 { + data[base + cols::input_state(x, y, b)] = FE::from(byte_of(lane, b) as u64); + } + } + } + + // Output state as bytes + for x in 0..5 { + for y in 0..5 { + let lane = op.output[x + 5 * y]; + for b in 0..8 { + data[base + cols::output_state(x, y, b)] = FE::from(byte_of(lane, b) as u64); + } + } + } + + // State pointers: state_ptr[lane] = addr + 8 * lane_idx + for lane_idx in 0..25 { + let ptr = op + .state_addr + .checked_add(lane_idx as u64 * 8) + .expect("keccak state address range must be validated by the executor"); + data[base + cols::state_ptr(lane_idx, 0)] = FE::from(ptr & 0xFFFF); + data[base + cols::state_ptr(lane_idx, 1)] = FE::from((ptr >> 16) & 0xFFFF); + data[base + cols::state_ptr(lane_idx, 2)] = FE::from((ptr >> 32) & 0xFFFF); + data[base + cols::state_ptr(lane_idx, 3)] = FE::from((ptr >> 48) & 0xFFFF); + } + + // mu = 1 (real row) + data[base + cols::MU] = FE::one(); + } + + // Padding rows: state_ptr[lane][0] = 8 * lane_idx (per spec keccak.toml pad). + // Halfwords 1..3 stay zero since 8*24 = 192 fits in the low halfword. + // mu = 0 gates all bus interactions and the ADD constraint, so these values + // only need to satisfy the pad requirement, not reconstruct a real address. + for row_idx in n..num_rows { + let base = row_idx * cols::NUM_COLUMNS; + for lane_idx in 0..25 { + data[base + cols::state_ptr(lane_idx, 0)] = FE::from((lane_idx as u64) * 8); + } + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions +// ========================================================================= + +pub fn bus_interactions() -> Vec { + let syscall_lo = KECCAK_SYSCALL_NUMBER & 0xFFFF_FFFF; + let syscall_hi = KECCAK_SYSCALL_NUMBER >> 32; + let mut interactions = Vec::with_capacity(160); + + // 1. ECALL receiver (shared bus, per spec keccak:c:output) + // Payload: [ts_lo, ts_hi, syscall_lo32, syscall_hi32] in DWordWL [lo, hi] + // ordering, matching the CPU ECALL sender shared with HALT/COMMIT. + interactions.push(BusInteraction::receiver( + BusId::Ecall, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::constant(syscall_lo), + BusValue::constant(syscall_hi), + ], + )); + + // 2. MEMW read_addr: read register x10 to bind addr (per spec keccak:c:read_addr) + // Format: [old[8], is_register=1, base_addr=[20,0], value[8], ts, ts_hi, write2=1, write4=0, write8=0] + // For register read: old = value = addr as WL + 6 zeros + { + // addr as DWordWL from DWordBL bytes: lo32 = sum(addr[0..4] * 256^i), hi32 = sum(addr[4..8] * 256^i) + let addr_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::addr(0), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::addr(1), + }, + LinearTerm::Column { + coefficient: 65536, + column: cols::addr(2), + }, + LinearTerm::Column { + coefficient: 16777216, + column: cols::addr(3), + }, + ]); + let addr_hi = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::addr(4), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::addr(5), + }, + LinearTerm::Column { + coefficient: 65536, + column: cols::addr(6), + }, + LinearTerm::Column { + coefficient: 16777216, + column: cols::addr(7), + }, + ]); + let mut values = Vec::with_capacity(24); + // old[0..7] = addr as WL + 6 zeros + values.push(addr_lo.clone()); + values.push(addr_hi.clone()); + for _ in 2..8 { + values.push(BusValue::constant(0)); + } + // is_register = 1 + values.push(BusValue::constant(1)); + // base_address = 2*10 = 20 (register x10) + values.push(BusValue::constant(20)); + values.push(BusValue::constant(0)); + // value[0..7] = same as old (read) + values.push(addr_lo); + values.push(addr_hi); + for _ in 2..8 { + values.push(BusValue::constant(0)); + } + // timestamp + values.push(BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }); + values.push(BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }); + // write2=1, write4=0, write8=0 (register access) + values.push(BusValue::constant(1)); + values.push(BusValue::constant(0)); + values.push(BusValue::constant(0)); + interactions.push(BusInteraction::sender( + BusId::Memw, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 2. Keccak bus: send (timestamp, 0, input_state[200]) + // Per spec keccak.toml: input = ["timestamp", 0, "input_state"] where + // input_state is [[[Byte, 8], 5], 5] — 200 Byte elements, each its own + // bus element (no packing). + { + let mut values = vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::constant(0), // round = 0 + ]; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::input_state(x, y, b), + packing: Packing::Direct, + }); + } + } + } + interactions.push(BusInteraction::sender( + BusId::Keccak, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 3. Keccak bus: receive (timestamp, 24, output_state[200]) + { + let mut values = vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::constant(24), // round = 24 + ]; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::output_state(x, y, b), + packing: Packing::Direct, + }); + } + } + } + interactions.push(BusInteraction::receiver( + BusId::Keccak, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 4. IS_HALF range checks on state_ptr (100 interactions) + for lane_idx in 0..25 { + for hw in 0..4 { + interactions.push(BusInteraction::sender( + BusId::IsHalfword, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::state_ptr(lane_idx, hw), + packing: Packing::Direct, + }], + )); + } + } + + // 5. Alignment: addr[0] & 7 = 0, which enforces addr % 8 == 0. + interactions.push(BusInteraction::sender( + BusId::AndByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::addr(0), + packing: Packing::Direct, + }, + BusValue::constant(7), + BusValue::constant(0), + ], + )); + + // 6. Range-check every addr byte. The addr columns are reconstructed as a + // linear combination (addr_lo = b0 + 256*b1 + 65536*b2 + 2^24*b3, etc.) + // for the MEMW lookup and the no-overflow / alignment constraints. Without + // an explicit byte range check on each cell, an attacker can keep the + // field-element value of that linear combination correct while encoding + // arbitrary non-byte values in the individual cells (e.g. addr[0]=0, + // addr[1]=V_lo * 256^{-1} mod p), bypassing the alignment check. + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::IsByte, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::addr(b), + packing: Packing::Direct, + }], + )); + } + + // 7. MEMW interactions: 25 combined read+write per lane (per spec) + // Format: [old[8], is_register, addr_lo32, addr_hi32, value[8], ts[2], w2, w4, w8] = 24 + // old = input_state (read), value = output_state (write) + for lane_idx in 0..25 { + let x = lane_idx % 5; + let y = lane_idx / 5; + + // Address as DWordWL: lo32 = h0 + 2^16*h1, hi32 = h2 + 2^16*h3 + let addr_lo = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::state_ptr(lane_idx, 0), + }, + LinearTerm::Column { + coefficient: 65536, + column: cols::state_ptr(lane_idx, 1), + }, + ]); + let addr_hi = BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::state_ptr(lane_idx, 2), + }, + LinearTerm::Column { + coefficient: 65536, + column: cols::state_ptr(lane_idx, 3), + }, + ]); + + let mut values = Vec::with_capacity(24); + // old[0..8] = input_state bytes (the value being read) + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::input_state(x, y, b), + packing: Packing::Direct, + }); + } + // is_register = 0 + values.push(BusValue::constant(0)); + // address as DWordWL + values.push(addr_lo); + values.push(addr_hi); + // value[0..8] = output_state bytes (the value being written) + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::output_state(x, y, b), + packing: Packing::Direct, + }); + } + // timestamp + values.push(BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }); + values.push(BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }); + // write2=0, write4=0, write8=1 + values.push(BusValue::constant(0)); + values.push(BusValue::constant(0)); + values.push(BusValue::constant(1)); + + interactions.push(BusInteraction::sender( + BusId::Memw, + Multiplicity::Column(cols::MU), + values, + )); + } + + interactions +} + +// ========================================================================= +// Constraints +// ========================================================================= + +struct KeccakAddressNoOverflowConstraint { + constraint_idx: usize, +} + +impl KeccakAddressNoOverflowConstraint { + fn new(constraint_idx: usize) -> Self { + Self { constraint_idx } + } + + fn compute(&self, step: &TableView) -> FieldElement + where + F: IsSubFieldOf, + E: IsField, + { + let addr_lo = step.get_main_evaluation_element(0, cols::addr(0)).clone() + + step.get_main_evaluation_element(0, cols::addr(1)) * FieldElement::::from(256) + + step.get_main_evaluation_element(0, cols::addr(2)) * FieldElement::::from(65536) + + step.get_main_evaluation_element(0, cols::addr(3)) + * FieldElement::::from(16777216); + let addr_hi = step.get_main_evaluation_element(0, cols::addr(4)).clone() + + step.get_main_evaluation_element(0, cols::addr(5)) * FieldElement::::from(256) + + step.get_main_evaluation_element(0, cols::addr(6)) * FieldElement::::from(65536) + + step.get_main_evaluation_element(0, cols::addr(7)) + * FieldElement::::from(16777216); + + let ptr_lo = step + .get_main_evaluation_element(0, cols::state_ptr(24, 0)) + .clone() + + step.get_main_evaluation_element(0, cols::state_ptr(24, 1)) + * FieldElement::::from(65536); + let ptr_hi = step + .get_main_evaluation_element(0, cols::state_ptr(24, 2)) + .clone() + + step.get_main_evaluation_element(0, cols::state_ptr(24, 3)) + * FieldElement::::from(65536); + + let inv_2_32 = FieldElement::::from(INV_SHIFT_32); + let carry_0 = (addr_lo + FieldElement::::from(192) - ptr_lo) * inv_2_32.clone(); + let carry_1 = (addr_hi + carry_0 - ptr_hi) * inv_2_32; + step.get_main_evaluation_element(0, cols::MU).clone() * carry_1 + } +} + +impl TransitionConstraint + for KeccakAddressNoOverflowConstraint +{ + fn degree(&self) -> usize { + 2 + } + + fn constraint_idx(&self) -> usize { + self.constraint_idx + } + + fn evaluate(&self, step: &TableView) -> FieldElement + where + F: IsSubFieldOf, + E: IsField, + { + self.compute(step) + } +} + +/// Create constraints for the KECCAK core chip. +/// +/// Per spec (keccak:c:state_ptr): ADD template for each lane: +/// state_ptr[lane] = addr + 8 * lane_idx +/// +/// 25 lane pointers × 2 constraints per ADD + 1 top-lane no-overflow +/// constraint = 51 constraints total. +/// Conditional on mu (only real rows). +pub fn create_constraints( + constraint_idx_start: usize, +) -> ( + Vec>>, + usize, +) { + let mut constraints: Vec< + Box>, + > = Vec::with_capacity(51); + let mut idx = constraint_idx_start; + + // state_ptr[lane] = addr + 8*lane_idx + // addr is DWordBL (8 bytes), state_ptr is DWordHL (4 halfwords) + // ADD: lhs = addr (DWordBL→DWordWL), rhs = 8*lane_idx (constant), sum = state_ptr (DWordHL→DWordWL) + for lane_idx in 0..25 { + let offset = (lane_idx * 8) as i64; + let (c0, c1) = AddConstraint::new_pair( + vec![cols::MU], // conditional on mu + AddOperand::from_dword_bl(cols::ADDR), + AddOperand::constant(offset), + AddOperand::from_dword_hl(cols::state_ptr(lane_idx, 0)), + idx, + ); + constraints.push(c0.boxed()); + constraints.push(c1.boxed()); + idx += 2; + } + + constraints.push(KeccakAddressNoOverflowConstraint::new(idx).boxed()); + idx += 1; + + (constraints, idx) +} diff --git a/prover/src/tables/keccak_rc.rs b/prover/src/tables/keccak_rc.rs new file mode 100644 index 000000000..c2e14d643 --- /dev/null +++ b/prover/src/tables/keccak_rc.rs @@ -0,0 +1,190 @@ +//! KECCAK_RC: Precomputed round constant lookup table for Keccak-f[1600]. +//! +//! 24 rows (one per round), padded to 32. Each row maps a round index to its +//! 8-byte round constant. The round chip looks up `(round) → rc[8]` via the +//! `KeccakRc` bus. +//! +//! Follows the BITWISE preprocessed-table pattern: precomputed columns are +//! committed once and cached via `OnceLock`. + +use std::sync::OnceLock; + +use math::fft::cpu::bit_reversing::in_place_bit_reverse_permute; +use math::field::element::FieldElement; +use math::polynomial::Polynomial; +use stark::config::{BatchedMerkleTree, Commitment}; +use stark::lookup::{BusInteraction, BusValue, Multiplicity, Packing}; +use stark::proof::options::ProofOptions; +use stark::prover::evaluate_polynomial_on_lde_domain; +use stark::trace::{TraceTable, columns2rows}; + +use executor::vm::instruction::execution::KECCAK_RC; + +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; + +// ========================================================================= +// Column indices +// ========================================================================= + +pub mod cols { + /// Round index (0..23) + pub const ROUND: usize = 0; + /// RC bytes [0..7] — 8 bytes of the round constant (little-endian) + pub const RC: usize = 1; + pub const RC_END: usize = RC + 8; // = 9 + /// Multiplicity (how many times this row is looked up) + pub const MU: usize = 9; + + pub const NUM_COLUMNS: usize = 10; +} + +/// Number of precomputed columns (everything except MU). +pub const NUM_PRECOMPUTED_COLS: usize = 9; + +/// Number of real rows (one per keccak round). +pub const NUM_REAL_ROWS: usize = 24; + +/// Number of rows in the trace (padded to next power of 2). +pub const NUM_ROWS: usize = 32; + +/// Whether this table is preprocessed. +pub const fn is_preprocessed() -> bool { + true +} + +/// Generate one precomputed row: [round, rc_byte0, ..., rc_byte7]. +pub const fn generate_row(round: usize) -> [u64; NUM_PRECOMPUTED_COLS] { + let rc_val = if round < 24 { KECCAK_RC[round] } else { 0 }; + [ + round as u64, + rc_val & 0xFF, + (rc_val >> 8) & 0xFF, + (rc_val >> 16) & 0xFF, + (rc_val >> 24) & 0xFF, + (rc_val >> 32) & 0xFF, + (rc_val >> 40) & 0xFF, + (rc_val >> 48) & 0xFF, + (rc_val >> 56) & 0xFF, + ] +} + +// ========================================================================= +// Preprocessed commitment +// ========================================================================= + +static KECCAK_RC_COMMITMENT: OnceLock = OnceLock::new(); + +fn compute_preprocessed_commitment(options: &ProofOptions) -> Commitment { + // Generate precomputed columns + let mut columns: Vec> = (0..NUM_PRECOMPUTED_COLS) + .map(|_| Vec::with_capacity(NUM_ROWS)) + .collect(); + for idx in 0..NUM_ROWS { + let row = generate_row(idx); + for (col_idx, &value) in row.iter().enumerate() { + columns[col_idx].push(FE::from(value)); + } + } + + // Interpolate each column to a polynomial + let polys: Vec> = columns + .iter() + .map(|col| { + Polynomial::interpolate_fft::(col) + .expect("FFT interpolation failed for keccak_rc column") + }) + .collect(); + + // Evaluate on LDE domain + let blowup_factor = options.blowup_factor as usize; + let coset_offset = FE::from(options.coset_offset); + let mut lde_columns: Vec> = polys + .iter() + .map(|poly| { + evaluate_polynomial_on_lde_domain(poly, blowup_factor, NUM_ROWS, &coset_offset) + .expect("LDE evaluation failed for keccak_rc polynomial") + }) + .collect(); + + // Bit-reverse permute + for col in lde_columns.iter_mut() { + in_place_bit_reverse_permute(col); + } + + // Build Merkle tree + let lde_rows = columns2rows(lde_columns); + let tree = BatchedMerkleTree::::build(&lde_rows) + .expect("Failed to build Merkle tree for keccak_rc LDE"); + + tree.root +} + +#[inline] +pub fn preprocessed_commitment(options: &ProofOptions) -> Commitment { + *KECCAK_RC_COMMITMENT.get_or_init(|| compute_preprocessed_commitment(options)) +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +/// Generate the KECCAK_RC trace table. +/// +/// All precomputed columns are filled; MU is initialized to zero and must be +/// updated via `update_multiplicities` after all round-chip lookups are known. +pub fn generate_keccak_rc_trace() -> TraceTable { + let mut data = vec![FE::zero(); NUM_ROWS * cols::NUM_COLUMNS]; + + for idx in 0..NUM_ROWS { + let base = idx * cols::NUM_COLUMNS; + let row = generate_row(idx); + for (col_idx, &value) in row.iter().enumerate() { + data[base + col_idx] = FE::from(value); + } + // MU = 0 (will be updated later) + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +/// Increment MU for each round lookup. +/// +/// Called after the round chip's trace is generated. Each keccak permutation +/// call produces 24 round lookups (one per round), so each round row's MU +/// equals the number of keccak operations. +pub fn update_multiplicities( + trace: &mut TraceTable, + num_keccak_ops: usize, +) { + let mu = FieldElement::from(num_keccak_ops as u64); + for round in 0..NUM_REAL_ROWS { + let base = round * cols::NUM_COLUMNS; + trace.main_table.data[base + cols::MU] = mu; + } +} + +// ========================================================================= +// Bus interactions +// ========================================================================= + +/// Single receiver on the KeccakRc bus. +/// +/// Format: [round(Direct), rc[0](Direct), ..., rc[7](Direct)] +pub fn bus_interactions() -> Vec { + let mut values = vec![BusValue::Packed { + start_column: cols::ROUND, + packing: Packing::Direct, + }]; + for i in 0..8 { + values.push(BusValue::Packed { + start_column: cols::RC + i, + packing: Packing::Direct, + }); + } + + vec![BusInteraction::receiver( + BusId::KeccakRc, + Multiplicity::Column(cols::MU), + values, + )] +} diff --git a/prover/src/tables/keccak_rnd.rs b/prover/src/tables/keccak_rnd.rs new file mode 100644 index 000000000..277281583 --- /dev/null +++ b/prover/src/tables/keccak_rnd.rs @@ -0,0 +1,986 @@ +//! KECCAK_RND: Round chip for Keccak-f[1600] permutation. +//! +//! One row per round (24 rows per keccak call). All bitwise operations are +//! delegated to BITWISE lookup tables (XOR_BYTE, AND_BYTE, HWSL, IS_BYTE). +//! +//! ## Column layout (1,480 columns) +//! +//! | Group | Size | Description | +//! |----------------|------|---------------------------------------------------| +//! | timestamp | 2 | DWordWL | +//! | round | 1 | Round index (0..23) | +//! | start | 200 | Input state bytes [5][5][8] | +//! | Cxz | 160 | Column parity chain [5][4][8] | +//! | Cxz_left | 40 | Left component of rotated C [5][8] | +//! | Cxz_right | 20 | Carry bits of HWSL(C[x],1) [5][4] | +//! | Dxz | 40 | D values [5][8] | +//! | theta | 200 | State after θ [5][5][8] | +//! | rot_left | 200 | Left half of ρ rotation [5][5][8] | +//! | rot_right | 200 | Right half of ρ rotation [5][5][8] | +//! | chi_ands | 200 | AND results for χ [5][5][8] | +//! | chi | 200 | State after χ [5][5][8] | +//! | rc | 8 | Round constant bytes | +//! | iota | 8 | χ[0][0] ⊕ rc | +//! | mu | 1 | Multiplicity (1 for real, 0 for padding) | +//! +//! Note: spec [[variables.constant]] `rnc` and `rbc` are inlined as compile-time +//! constants derived from `KECCAK_RHO[x][y]`, not materialized as columns. +//! `Cxz_right` is typed `[Bit, 4]` per spec d75944ee — HWSL with shift=1 +//! produces a single-bit carry, range-checked via IS_BIT polynomial constraints. + +use executor::vm::instruction::execution::{KECCAK_RC, KECCAK_RHO}; +use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; +use stark::trace::TraceTable; + +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; + +// ========================================================================= +// Column indices +// ========================================================================= + +pub mod cols { + pub const TIMESTAMP_0: usize = 0; + pub const TIMESTAMP_1: usize = 1; + pub const ROUND: usize = 2; + + // start[5][5][8] = 200 bytes — input state for this round + pub const START: usize = 3; + + // Cxz[5][4][8] = 160 bytes — partial XOR chain for column parities + pub const CXZ: usize = START + 200; // 203 + + // Cxz_left[5][8] = 40 bytes — left shift component of rotated C + pub const CXZ_LEFT: usize = CXZ + 160; // 363 + + // Cxz_right[5][4] = 20 bits — carry bit of HWSL(C[x] halfword[hw], 1). + // For shift=1, HWSL emits a single-bit carry; one column per halfword. + pub const CXZ_RIGHT: usize = CXZ_LEFT + 40; // 403 + + // Dxz[5][8] = 40 bytes + pub const DXZ: usize = CXZ_RIGHT + 20; // 423 + + // theta[5][5][8] = 200 bytes — state after θ + pub const THETA: usize = DXZ + 40; // 463 + + // rot_left[5][5][8] = 200 bytes + pub const ROT_LEFT: usize = THETA + 200; // 663 + + // rot_right[5][5][8] = 200 bytes + pub const ROT_RIGHT: usize = ROT_LEFT + 200; // 863 + + // chi_ands[5][5][8] = 200 bytes + // (pi is a spec [[variables.virtual]] — inlined as rot_left + rot_right at + // compile-resolved offsets, not materialized as columns.) + pub const CHI_ANDS: usize = ROT_RIGHT + 200; // 1063 + + // chi[5][5][8] = 200 bytes — state after χ + pub const CHI: usize = CHI_ANDS + 200; // 1263 + + // rc[8] — round constant bytes + pub const RC: usize = CHI + 200; // 1463 + + // iota[8] — χ[0][0] ⊕ rc + pub const IOTA: usize = RC + 8; // 1471 + + // mu — multiplicity flag. + // rnc and rbc (spec [[variables.constant]]) are inlined as compile-time + // constants from KECCAK_RHO, not allocated as columns. + pub const MU: usize = IOTA + 8; // 1479 + + pub const NUM_COLUMNS: usize = MU + 1; // 1480 + + // ------------------------------------------------------------------------- + // Index helpers + // ------------------------------------------------------------------------- + + /// Index into start[x][y][byte] (200 bytes, row-major: y varies fastest) + #[inline] + pub const fn start(x: usize, y: usize, byte: usize) -> usize { + START + (x + 5 * y) * 8 + byte + } + + /// Index into Cxz[x][stage][byte] (160 bytes) + #[inline] + pub const fn cxz(x: usize, stage: usize, byte: usize) -> usize { + CXZ + (x * 4 + stage) * 8 + byte + } + + /// Index into Cxz_left[x][byte] + #[inline] + pub const fn cxz_left(x: usize, byte: usize) -> usize { + CXZ_LEFT + x * 8 + byte + } + + /// Index into Cxz_right[x][hw] — single-bit carry for halfword `hw` of x. + #[inline] + pub const fn cxz_right_bit(x: usize, hw: usize) -> usize { + CXZ_RIGHT + x * 4 + hw + } + + /// For byte `b` of the rotated_Cxz output, return Some(hw) if a Cxz_right + /// bit contributes (even b), else None (odd b → only Cxz_left contributes). + /// Spec d75944ee/9143370f: rotated_Cxz[z] = Cxz_left[z] + (1 - z%2) * + /// Cxz_right[(z/2 - 1) mod 4]. + #[inline] + pub const fn cxz_right_bit_for_byte(b: usize) -> Option { + if b.is_multiple_of(2) { + Some((b / 2 + 3) % 4) + } else { + None + } + } + + /// Index into Dxz[x][byte] + #[inline] + pub const fn dxz(x: usize, byte: usize) -> usize { + DXZ + x * 8 + byte + } + + /// Index into theta[x][y][byte] + #[inline] + pub const fn theta(x: usize, y: usize, byte: usize) -> usize { + THETA + (x + 5 * y) * 8 + byte + } + + /// Index into rot_left[x][y][byte] + #[inline] + pub const fn rot_left(x: usize, y: usize, byte: usize) -> usize { + ROT_LEFT + (x + 5 * y) * 8 + byte + } + + /// Index into rot_right[x][y][byte] + #[inline] + pub const fn rot_right(x: usize, y: usize, byte: usize) -> usize { + ROT_RIGHT + (x + 5 * y) * 8 + byte + } + + /// Resolve pi[x][y][z] (spec virtual) to the (rot_left_col, rot_right_col) + /// pair whose sum equals pi[x][y][z]. rbc is compile-time constant. + #[inline] + pub fn pi_src_cols(x: usize, y: usize, z: usize) -> (usize, usize) { + use executor::vm::instruction::execution::KECCAK_RHO; + let sx = (x + 3 * y) % 5; + let sy = x; + let rho_offset = KECCAK_RHO[sx][sy] as usize; + let rbc_val = rho_offset / 16; + let (l_byte, r_byte) = match rbc_val { + 0 => (z, (z + 6) % 8), + 1 => ((z + 6) % 8, (z + 4) % 8), + 2 => ((z + 4) % 8, (z + 2) % 8), + 3 => ((z + 2) % 8, z), + _ => unreachable!(), + }; + (rot_left(sx, sy, l_byte), rot_right(sx, sy, r_byte)) + } + + /// Index into chi_ands[x][y][byte] + #[inline] + pub const fn chi_ands(x: usize, y: usize, byte: usize) -> usize { + CHI_ANDS + (x + 5 * y) * 8 + byte + } + + /// Index into chi[x][y][byte] + #[inline] + pub const fn chi(x: usize, y: usize, byte: usize) -> usize { + CHI + (x + 5 * y) * 8 + byte + } + + /// Index into rc[byte] + #[inline] + pub const fn rc(byte: usize) -> usize { + RC + byte + } + + /// Index into iota[byte] + #[inline] + pub const fn iota(byte: usize) -> usize { + IOTA + byte + } +} + +// ========================================================================= +// Operation struct +// ========================================================================= + +/// One keccak permutation call's worth of data (produces 24 rows). +#[derive(Debug, Clone)] +pub struct KeccakRoundOperation { + pub timestamp: u64, + pub input: [u64; 25], + pub output: [u64; 25], +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +/// Extract byte `b` (0..8) from a u64 value. +#[inline] +fn byte_of(val: u64, b: usize) -> u8 { + ((val >> (b * 8)) & 0xFF) as u8 +} + +/// Compute halfword shift left: (value << shift) mod 2^16 and value >> (16 - shift). +#[inline] +fn hwsl(halfword: u16, shift: u8) -> (u16, u16) { + if shift == 0 { + (halfword, 0) + } else { + ( + halfword << shift, // u16 naturally wraps at 16 bits + halfword >> (16 - shift), + ) + } +} + +#[allow(clippy::needless_range_loop)] +/// Generate the KECCAK_RND trace table. +/// +/// Each `KeccakRoundOperation` produces 24 rows (one per round). The trace +/// computes all intermediate values (θ, ρ, π, χ, ι) at byte granularity. +pub fn generate_keccak_rnd_trace( + ops: &[KeccakRoundOperation], +) -> TraceTable { + let n_rows = (ops.len() * 24).next_power_of_two().max(4); + let mut data = vec![FE::zero(); n_rows * cols::NUM_COLUMNS]; + + for (op_idx, op) in ops.iter().enumerate() { + // Execute round-by-round, tracking the state + let mut state = op.input; + + for round in 0..24 { + let row_idx = op_idx * 24 + round; + let base = row_idx * cols::NUM_COLUMNS; + + // Timestamp & round + data[base + cols::TIMESTAMP_0] = FE::from(op.timestamp & 0xFFFF_FFFF); + data[base + cols::TIMESTAMP_1] = FE::from(op.timestamp >> 32); + data[base + cols::ROUND] = FE::from(round as u64); + + // start = current state as bytes + for x in 0..5 { + for y in 0..5 { + let lane = state[x + 5 * y]; + for b in 0..8 { + data[base + cols::start(x, y, b)] = FE::from(byte_of(lane, b) as u64); + } + } + } + + // === θ (theta) === + // Column parities: C[x] = XOR of all 5 lanes in column x + // Computed as a chain: Cxz[x][0] = start[x,0] XOR start[x,1] + // Cxz[x][k] = Cxz[x][k-1] XOR start[x,k+1] + let mut c_bytes = [[0u8; 8]; 5]; // C[x][byte] = final parity + let mut cxz = [[[0u8; 8]; 4]; 5]; // Cxz[x][stage][byte] + for x in 0..5 { + // Stage 0: XOR(start[x,0], start[x,1]) + for b in 0..8 { + let v0 = byte_of(state[x], b); + let v1 = byte_of(state[x + 5], b); + cxz[x][0][b] = v0 ^ v1; + data[base + cols::cxz(x, 0, b)] = FE::from(cxz[x][0][b] as u64); + } + // Stages 1..3: XOR(Cxz[x][k-1], start[x, k+1]) + for stage in 1..4 { + let y = stage + 1; + for b in 0..8 { + let prev = cxz[x][stage - 1][b]; + let sv = byte_of(state[x + 5 * y], b); + cxz[x][stage][b] = prev ^ sv; + data[base + cols::cxz(x, stage, b)] = FE::from(cxz[x][stage][b] as u64); + } + } + c_bytes[x] = cxz[x][3]; + } + + // Rotate C left by 1 bit using HWSL decomposition. + // HWSL shifts each halfword (u16) independently. For shift=1, the + // carry is a single bit (top bit of the halfword); we store it in + // one column per halfword (Cxz_right[x][hw], spec d75944ee). + // rotated_Cxz[z] = Cxz_left[z] + (1 - z%2) * Cxz_right[(z/2 - 1) mod 4] + let mut cxz_left_bytes = [[0u8; 8]; 5]; + let mut cxz_right_bits = [[0u8; 4]; 5]; + let mut rotated_c = [[0u8; 8]; 5]; + for x in 0..5 { + for hw in 0..4 { + let lo = c_bytes[x][hw * 2] as u16; + let hi = c_bytes[x][hw * 2 + 1] as u16; + let halfword = lo | (hi << 8); + let (shifted, carry) = hwsl(halfword, 1); + cxz_left_bytes[x][hw * 2] = (shifted & 0xFF) as u8; + cxz_left_bytes[x][hw * 2 + 1] = (shifted >> 8) as u8; + // For shift=1, carry ∈ {0, 1}. + cxz_right_bits[x][hw] = carry as u8; + data[base + cols::cxz_left(x, hw * 2)] = + FE::from(cxz_left_bytes[x][hw * 2] as u64); + data[base + cols::cxz_left(x, hw * 2 + 1)] = + FE::from(cxz_left_bytes[x][hw * 2 + 1] as u64); + data[base + cols::cxz_right_bit(x, hw)] = + FE::from(cxz_right_bits[x][hw] as u64); + } + // Reconstruct: left[b] + (1 - b%2) * right[(b/2 + 3) mod 4] + for b in 0..8 { + let right_contribution = match cols::cxz_right_bit_for_byte(b) { + Some(hw) => cxz_right_bits[x][hw], + None => 0, + }; + rotated_c[x][b] = cxz_left_bytes[x][b].wrapping_add(right_contribution); + } + } + + // D[x] = C[(x-1)%5] XOR rotated_C[(x+1)%5] + let mut d_bytes = [[0u8; 8]; 5]; + for x in 0..5 { + for b in 0..8 { + let val = c_bytes[(x + 4) % 5][b] ^ rotated_c[(x + 1) % 5][b]; + d_bytes[x][b] = val; + data[base + cols::dxz(x, b)] = FE::from(val as u64); + } + } + + // theta[x][y] = start[x][y] XOR D[x] + let mut theta_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let lane = state[x + 5 * y]; + let mut d_lane = 0u64; + for b in 0..8 { + d_lane |= (d_bytes[x][b] as u64) << (b * 8); + } + theta_lanes[x + 5 * y] = lane ^ d_lane; + for b in 0..8 { + data[base + cols::theta(x, y, b)] = + FE::from(byte_of(theta_lanes[x + 5 * y], b) as u64); + } + } + } + + // === ρ (rho) === + // For each lane, rotate theta[x][y] by KECCAK_RHO[x][y] bits. + // Decompose rotation as: rnc (nibble, 0..15) + 16*rbc[0] + 32*rbc[1]. + // rnc and rbc are inlined as compile-time constants per spec + // [[variables.constant]]; only HWSL outputs are stored in the trace. + for x in 0..5 { + for y in 0..5 { + let rho_offset = KECCAK_RHO[x][y] as usize; + let rnc_val = (rho_offset % 16) as u8; + let theta_lane = theta_lanes[x + 5 * y]; + for hw in 0..4 { + let halfword = ((theta_lane >> (hw * 16)) & 0xFFFF) as u16; + let (shifted, carry) = hwsl(halfword, rnc_val); + data[base + cols::rot_left(x, y, hw * 2)] = + FE::from((shifted & 0xFF) as u64); + data[base + cols::rot_left(x, y, hw * 2 + 1)] = + FE::from((shifted >> 8) as u64); + data[base + cols::rot_right(x, y, hw * 2)] = + FE::from((carry & 0xFF) as u64); + data[base + cols::rot_right(x, y, hw * 2 + 1)] = + FE::from((carry >> 8) as u64); + } + } + } + + // === π (pi) === + // pi[x][y] = rho[(x+3y)%5][x] where rho is the rotated theta. + // pi is a spec [[variables.virtual]] — not stored as trace columns. + // It's reconstructed inline in chi bus interactions as + // pi[x][y][z] = rot_left[sx,sy,l_byte] + rot_right[sx,sy,r_byte] + // with (sx, sy) = ((x+3y)%5, x) and (l_byte, r_byte) resolved from + // the compile-time rbc constant. pi_lanes is still computed here + // for the chi step below. + let mut pi_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let rotated = theta_lanes[x + 5 * y].rotate_left(KECCAK_RHO[x][y]); + let dst_x = y; + let dst_y = (2 * x + 3 * y) % 5; + pi_lanes[dst_x + 5 * dst_y] = rotated; + } + } + + // === χ (chi) === + let mut chi_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let not_next = !pi_lanes[(x + 1) % 5 + 5 * y]; + let next2 = pi_lanes[(x + 2) % 5 + 5 * y]; + let and_val = not_next & next2; + chi_lanes[x + 5 * y] = pi_lanes[x + 5 * y] ^ and_val; + for b in 0..8 { + data[base + cols::chi_ands(x, y, b)] = FE::from(byte_of(and_val, b) as u64); + data[base + cols::chi(x, y, b)] = + FE::from(byte_of(chi_lanes[x + 5 * y], b) as u64); + } + } + } + + // === ι (iota) === + let rc_val = KECCAK_RC[round]; + for b in 0..8 { + data[base + cols::rc(b)] = FE::from(byte_of(rc_val, b) as u64); + let iota_byte = byte_of(chi_lanes[0], b) ^ byte_of(rc_val, b); + data[base + cols::iota(b)] = FE::from(iota_byte as u64); + } + + // Update state for next round + chi_lanes[0] ^= rc_val; + state = chi_lanes; + + // mu = 1 (real row) + data[base + cols::MU] = FE::one(); + } + } + + // Padding rows have mu=0 and all zeros (default) + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions (1,371 total) +// ========================================================================= + +#[allow(clippy::needless_range_loop)] +pub fn bus_interactions() -> Vec { + let mut interactions = Vec::with_capacity(1371); + + // --- IO group (3) --- + + // 1. KECCAK bus: receive (timestamp, round, start[200]) + // Per spec keccak_round.toml: input = ["timestamp", "round", "start"] where + // start is [[[Byte, 8], 5], 5] — 200 Byte elements, each its own bus element. + { + let mut values = vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::ROUND, + packing: Packing::Direct, + }, + ]; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::start(x, y, b), + packing: Packing::Direct, + }); + } + } + } + interactions.push(BusInteraction::receiver( + BusId::Keccak, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 2. KECCAK bus: send (timestamp, round+1, out[200]) + // out[0][0] = iota, out[x][y] = chi for (x,y) != (0,0) + { + let mut values = vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::ROUND, + }, + LinearTerm::Constant(1), + ]), + ]; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let col = if x == 0 && y == 0 { + cols::IOTA + b + } else { + cols::chi(x, y, b) + }; + values.push(BusValue::Packed { + start_column: col, + packing: Packing::Direct, + }); + } + } + } + interactions.push(BusInteraction::sender( + BusId::Keccak, + Multiplicity::Column(cols::MU), + values, + )); + } + + // 3. KECCAK_RC: lookup (round) → rc[8] + { + let mut values = vec![BusValue::Packed { + start_column: cols::ROUND, + packing: Packing::Direct, + }]; + for b in 0..8 { + values.push(BusValue::Packed { + start_column: cols::rc(b), + packing: Packing::Direct, + }); + } + interactions.push(BusInteraction::sender( + BusId::KeccakRc, + Multiplicity::Column(cols::MU), + values, + )); + } + + // --- Theta: Cxz chain XOR_BYTE (160) --- + // Stage 0: XOR(start[x,0,z], start[x,1,z]) → Cxz[x,0,z] + for x in 0..5 { + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::start(x, 0, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::start(x, 1, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::cxz(x, 0, b), + packing: Packing::Direct, + }, + ], + )); + } + } + // Stages 1..3: XOR(Cxz[x,stage-1,z], start[x,stage+1,z]) → Cxz[x,stage,z] + for x in 0..5 { + for stage in 1..4usize { + let y = stage + 1; + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::cxz(x, stage - 1, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::start(x, y, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::cxz(x, stage, b), + packing: Packing::Direct, + }, + ], + )); + } + } + } + + // --- Theta: HWSL for rotated C (20) --- + // HWSL(C[x] halfword[hw], 1) → (Cxz_left, Cxz_right) + // Cxz_right is a single carry bit zero-extended to a halfword (spec d75944ee). + for x in 0..5 { + for hw in 0..4 { + interactions.push(BusInteraction::sender( + BusId::Hwsl, + Multiplicity::Column(cols::MU), + vec![ + // Input halfword: Cxz[x][3][hw*2] + 256 * Cxz[x][3][hw*2+1] + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::cxz(x, 3, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::cxz(x, 3, hw * 2 + 1), + }, + ]), + // Shift amount = 1 + BusValue::constant(1), + // Output: shifted + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::cxz_left(x, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::cxz_left(x, hw * 2 + 1), + }, + ]), + // Output: carry (single bit cast to Half — high byte = 0). + BusValue::Packed { + start_column: cols::cxz_right_bit(x, hw), + packing: Packing::Direct, + }, + ], + )); + } + } + + // --- Theta: IS_BYTE range checks on Cxz_left (40) --- + // Cxz_right uses IS_BIT polynomial constraints (see create_constraints). + for x in 0..5 { + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::IsByte, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::cxz_left(x, b), + packing: Packing::Direct, + }], + )); + } + } + + // --- Theta: Dxz XOR_BYTE (40) --- + // D[x][b] = C[(x-1)%5][b] XOR rotated_C[(x+1)%5][b] + // rotated_C[x'][b] = Cxz_left[x'][b] + (1 - b%2) * Cxz_right[x'][(b/2 - 1)%4] + // (spec d75944ee/9143370f). For odd b only Cxz_left contributes. + for x in 0..5 { + for b in 0..8 { + let mut rotated_c_terms = vec![LinearTerm::Column { + coefficient: 1, + column: cols::cxz_left((x + 1) % 5, b), + }]; + if let Some(hw) = cols::cxz_right_bit_for_byte(b) { + rotated_c_terms.push(LinearTerm::Column { + coefficient: 1, + column: cols::cxz_right_bit((x + 1) % 5, hw), + }); + } + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::cxz((x + 4) % 5, 3, b), + packing: Packing::Direct, + }, + BusValue::linear(rotated_c_terms), + BusValue::Packed { + start_column: cols::dxz(x, b), + packing: Packing::Direct, + }, + ], + )); + } + } + + // --- Theta final: XOR_BYTE (200) --- + // theta[x][y][b] = start[x][y][b] XOR D[x][b] + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::start(x, y, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::dxz(x, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::theta(x, y, b), + packing: Packing::Direct, + }, + ], + )); + } + } + } + + // --- Rho: HWSL (100) --- + // HWSL(theta[x][y] halfword[hw], rnc[x][y]) → (rot_left, rot_right) + // rnc is inlined as a constant: KECCAK_RHO[x][y] % 16. + for x in 0..5 { + for y in 0..5 { + let rnc_val = (KECCAK_RHO[x][y] % 16) as u64; + for hw in 0..4 { + interactions.push(BusInteraction::sender( + BusId::Hwsl, + Multiplicity::Column(cols::MU), + vec![ + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::theta(x, y, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::theta(x, y, hw * 2 + 1), + }, + ]), + BusValue::constant(rnc_val), + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::rot_left(x, y, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::rot_left(x, y, hw * 2 + 1), + }, + ]), + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::rot_right(x, y, hw * 2), + }, + LinearTerm::Column { + coefficient: 256, + column: cols::rot_right(x, y, hw * 2 + 1), + }, + ]), + ], + )); + } + } + } + + // --- Rho: IS_BYTE range checks on rot_left + rot_right (400) --- + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::IsByte, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::rot_left(x, y, b), + packing: Packing::Direct, + }], + )); + interactions.push(BusInteraction::sender( + BusId::IsByte, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::rot_right(x, y, b), + packing: Packing::Direct, + }], + )); + } + } + } + + // --- Chi: AND_BYTE (200) --- + // chi_ands[x][y][b] = (255 - pi[(x+1)%5][y][b]) AND pi[(x+2)%5][y][b] + // pi is virtual: pi[x][y][z] = rot_left[sx,sy,l_byte] + rot_right[sx,sy,r_byte] + // with src lane (sx,sy) = ((x+3y)%5, x) and byte offsets from KECCAK_RHO. + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let (p1_l, p1_r) = cols::pi_src_cols((x + 1) % 5, y, b); + let (p2_l, p2_r) = cols::pi_src_cols((x + 2) % 5, y, b); + interactions.push(BusInteraction::sender( + BusId::AndByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::linear(vec![ + LinearTerm::Constant(255), + LinearTerm::Column { + coefficient: -1, + column: p1_l, + }, + LinearTerm::Column { + coefficient: -1, + column: p1_r, + }, + ]), + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: p2_l, + }, + LinearTerm::Column { + coefficient: 1, + column: p2_r, + }, + ]), + BusValue::Packed { + start_column: cols::chi_ands(x, y, b), + packing: Packing::Direct, + }, + ], + )); + } + } + } + + // --- Chi: XOR_BYTE (200) --- + // chi[x][y][b] = pi[x][y][b] XOR chi_ands[x][y][b] (pi virtual). + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let (p_l, p_r) = cols::pi_src_cols(x, y, b); + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: p_l, + }, + LinearTerm::Column { + coefficient: 1, + column: p_r, + }, + ]), + BusValue::Packed { + start_column: cols::chi_ands(x, y, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::chi(x, y, b), + packing: Packing::Direct, + }, + ], + )); + } + } + } + + // --- Iota: XOR_BYTE (8) --- + // iota[b] = chi[0][0][b] XOR rc[b] + for b in 0..8 { + interactions.push(BusInteraction::sender( + BusId::XorByte, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::chi(0, 0, b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::rc(b), + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::iota(b), + packing: Packing::Direct, + }, + ], + )); + } + + interactions +} + +// ========================================================================= +// Constraints +// ========================================================================= + +/// KECCAK_RND polynomial constraints: 20 IS_BIT(μ; Cxz_right) constraints. +/// +/// Per spec d75944ee, `Cxz_right` is typed `[Bit, 4], 5` and range-checked via +/// IS_BIT polynomial constraints (kind="template", cond="μ"), not lookups: +/// μ * Cxz_right[x][hw] * (1 - Cxz_right[x][hw]) = 0 +/// +/// - pi is a spec [[variables.virtual]] inlined in chi bus interactions. +/// - rnc/rbc are spec [[variables.constant]] inlined as compile-time constants. +/// +/// All other checks (XOR, AND, HWSL, IS_BYTE, IS_HALF, KECCAK, KECCAK_RC) are +/// enforced via bus interactions against the BITWISE/KECCAK_RC chips. +pub fn create_constraints( + constraint_idx_start: usize, +) -> ( + Vec>>, + usize, +) { + use crate::constraints::templates::IsBitConstraint; + + let mut constraints: Vec< + Box>, + > = Vec::with_capacity(20); + let mut idx = constraint_idx_start; + for x in 0..5 { + for hw in 0..4 { + constraints + .push(IsBitConstraint::new(cols::MU, cols::cxz_right_bit(x, hw), idx).boxed()); + idx += 1; + } + } + (constraints, idx) +} + +#[cfg(test)] +mod tests { + use super::*; + use executor::vm::instruction::execution::keccak_f1600; + + /// pi is a spec virtual variable. Verify the inlined expression + /// (rot_left[sx,sy,l_byte] + rot_right[sx,sy,r_byte]) matches the byte of + /// rho(theta) for a non-trivial state. Uses mu=0 padding rows as a trivial + /// sanity check (all zeros), then a non-zero-input round as the real test. + #[test] + fn test_pi_virtual_matches_rotate() { + // Use a non-zero input so theta_lanes are non-trivial. + let input = [0x0102030405060708u64; 25]; + let mut output = input; + keccak_f1600(&mut output); + let op = KeccakRoundOperation { + timestamp: 42, + input, + output, + }; + let trace = generate_keccak_rnd_trace(&[op]); + let base = 0; + + // Recompute theta for round 0 in u64 to compare against virtual pi. + let mut c = [0u64; 5]; + for x in 0..5 { + c[x] = input[x] ^ input[x + 5] ^ input[x + 10] ^ input[x + 15] ^ input[x + 20]; + } + let mut d = [0u64; 5]; + for x in 0..5 { + d[x] = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1); + } + let mut theta_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + theta_lanes[x + 5 * y] = input[x + 5 * y] ^ d[x]; + } + } + + for x in 0..5 { + for y in 0..5 { + let sx = (x + 3 * y) % 5; + let sy = x; + let rotated = theta_lanes[sx + 5 * sy].rotate_left(KECCAK_RHO[sx][sy]); + for z in 0..8 { + let (l_col, r_col) = cols::pi_src_cols(x, y, z); + let virtual_pi = + &trace.main_table.data[base + l_col] + &trace.main_table.data[base + r_col]; + let expected = FE::from((rotated >> (z * 8)) & 0xFF); + assert_eq!( + virtual_pi, expected, + "virtual pi mismatch at ({x},{y},{z}): sx={sx}, sy={sy}" + ); + } + } + } + } +} diff --git a/prover/src/tables/mod.rs b/prover/src/tables/mod.rs index 19d14411d..4a6032ef2 100644 --- a/prover/src/tables/mod.rs +++ b/prover/src/tables/mod.rs @@ -28,6 +28,9 @@ pub mod cpu; pub mod decode; pub mod dvrm; pub mod halt; +pub mod keccak; +pub mod keccak_rc; +pub mod keccak_rnd; pub mod load; pub mod lt; pub mod memw; diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index d2743a1e5..e7a662502 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -10,10 +10,10 @@ //! ```text //! PHASE 0: ELF → DECODE, MEMORY_INIT (preprocessed tables) //! PHASE 1: Logs → CPU ops -//! PHASE 2: CPU ops → MEMW, MEMW_A, MEMW_R, LOAD, LT, Bitwise (with state tracking for MEMW/LOAD) +//! PHASE 2: CPU ops → MEMW, MEMW_A, MEMW_R, LOAD, LT, Bitwise, KECCAK (with state tracking for MEMW/LOAD/ECALL) //! PHASE 3: MEMW/MEMW_A → LT ops (timestamp ordering); MEMW_R uses IS_HALFWORD instead -//! PHASE 4: LT, MEMW_A, MEMW_R → Bitwise lookups -//! PHASE 5: Generate all traces +//! PHASE 4: LT, MEMW_A, MEMW_R, KECCAK → Bitwise lookups +//! PHASE 5: Generate all traces (including KECCAK core, KECCAK_RND, KECCAK_RC) //! ``` //! //! ## Usage @@ -40,6 +40,9 @@ use super::cpu::{self, CpuOperation}; use super::decode; use super::dvrm::{self, DvrmOperation}; use super::halt; +use super::keccak::{self, KeccakOperation}; +use super::keccak_rc; +use super::keccak_rnd::{self, KeccakRoundOperation}; use super::load::{self, LoadOperation}; use super::lt::{self, LtOperation}; use super::memw::{self, MemwOperation}; @@ -335,7 +338,7 @@ fn collect_cpu_ops( /// /// MEMW and LOAD collection requires sequential processing with state tracking. /// -/// Returns: (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops) +/// Returns: (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) #[allow(clippy::type_complexity)] fn collect_ops_from_cpu( cpu_ops: &[CpuOperation], @@ -348,6 +351,7 @@ fn collect_ops_from_cpu( Vec, Vec, Vec, + Vec, ) { let mut memw_ops = Vec::with_capacity(cpu_ops.len() * 3); let mut load_ops = Vec::with_capacity(cpu_ops.len() / 8 + 1); @@ -355,6 +359,7 @@ fn collect_ops_from_cpu( let mut shift_ops = Vec::with_capacity(cpu_ops.len() / 10 + 1); let mut bitwise_ops = Vec::with_capacity(cpu_ops.len() * 4); let mut commit_ops = Vec::new(); + let mut keccak_ops = Vec::new(); let mut current_commit_index = 0u32; let mut commit_ecall_count = 0u32; @@ -397,6 +402,38 @@ fn collect_ops_from_cpu( commit_ecall_count += 1; } + // Collect KeccakPermute ECALL operations + if op.ecall_keccak { + let state_addr = op.keccak_state_addr; + let mut input = [0u64; 25]; + for (i, lane) in input.iter_mut().enumerate() { + let addr = state_addr + .checked_add(i as u64 * 8) + .expect("keccak state address range must be validated by the executor"); + let mut val = 0u64; + for b in 0..8 { + let byte_addr = addr + .checked_add(b as u64) + .expect("keccak state address range must be validated by the executor"); + let (byte_val, _ts) = memory_state.read_byte(byte_addr); + val |= (byte_val as u64) << (b * 8); + } + *lane = val; + } + let mut output = input; + executor::vm::instruction::execution::keccak_f1600(&mut output); + // collect_keccak_memw_ops handles memory_state + register_state updates + let keccak_memw_ops = + collect_keccak_memw_ops(op, &input, &output, memory_state, register_state); + memw_ops.extend(keccak_memw_ops); + keccak_ops.push(KeccakOperation { + timestamp: op.timestamp, + state_addr, + input, + output, + }); + } + // --- LT, SHIFT, and Bitwise (no state tracking needed) --- // Collect LT operations from SLT/BLT instructions @@ -440,6 +477,7 @@ fn collect_ops_from_cpu( shift_ops, bitwise_ops, commit_ops, + keccak_ops, ) } @@ -781,6 +819,73 @@ fn collect_halt_ops(register_state: &mut RegisterState) -> Vec { // ============================================================================= /// Collects LT operations from MEMW for timestamp ordering. +/// Collect MEMW operations for a KeccakPermute ECALL. +/// +/// Generates 25 read operations (input lanes at timestamp) and 25 write +/// operations (output lanes at timestamp+1). Each operation is 8 bytes wide. +fn collect_keccak_memw_ops( + op: &CpuOperation, + input: &[u64; 25], + output: &[u64; 25], + memory_state: &mut MemoryState, + register_state: &mut RegisterState, +) -> Vec { + let ts = op.timestamp; + let state_addr = op.keccak_state_addr; + let mut memw_ops = Vec::with_capacity(26); // 1 register read + 25 lane ops + + // Per spec (keccak:c:read_addr): read register x10 to get state_addr + { + let reg_value = pack_register_value(state_addr); + let reg_addr = 2 * 10u64; // x10 → address 20 + let (_old_val, old_ts) = register_state.read(10); + let old_timestamps = [old_ts, old_ts, 0, 0, 0, 0, 0, 0]; + let memw_op = MemwOperation::new(true, reg_addr, reg_value, ts, 2, true) + .with_old(reg_value, old_timestamps); + memw_ops.push(memw_op); + register_state.write(10, state_addr, ts); + } + + // Per spec (keccak:c:load_store_state): single combined read+write MEMW per lane. + // input = [0, state_ptr, output_state, timestamp, 0, 0, 1], output = input_state + // The MEMW table sees: old=input_state, value=output_state, is_read=true. + for (lane_idx, (&in_lane, &out_lane)) in input.iter().zip(output.iter()).enumerate() { + let lane_addr = state_addr + .checked_add(lane_idx as u64 * 8) + .expect("keccak state address range must be validated by the executor"); + + let mut old_bytes = [0u64; 8]; + let mut old_timestamps = [0u64; 8]; + for b in 0..8 { + old_bytes[b] = (in_lane >> (b * 8)) & 0xFF; + let byte_addr = lane_addr + .checked_add(b as u64) + .expect("keccak state address range must be validated by the executor"); + let (_old_val, old_ts) = memory_state.read_byte(byte_addr); + old_timestamps[b] = old_ts; + } + + let mut value_bytes = [0u64; 8]; + for (b, byte) in value_bytes.iter_mut().enumerate() { + *byte = (out_lane >> (b * 8)) & 0xFF; + } + + let memw_op = MemwOperation::new(false, lane_addr, value_bytes, ts, 8, true) + .with_old(old_bytes, old_timestamps); + memw_ops.push(memw_op); + + // Update memory state + for (b, &val) in value_bytes.iter().enumerate() { + let byte_addr = lane_addr + .checked_add(b as u64) + .expect("keccak state address range must be validated by the executor"); + memory_state.write_byte(byte_addr, val as u8, ts); + } + } + + memw_ops +} + /// /// From spec memw.md: /// - MEMW-C4 through MEMW-C7: old_timestamp[i] < timestamp (based on width) @@ -1544,6 +1649,264 @@ fn collect_bitwise_from_commit(commit_ops: &[CommitOperation]) -> Vec Vec { + use executor::vm::instruction::execution::{KECCAK_RC, KECCAK_RHO}; + + let mut ops = Vec::new(); + + for kop in keccak_ops { + let state_addr = kop.state_addr; + + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::AndByte, + (state_addr & 0xFF) as u8, + 7, + )); + + // Range-check addr bytes (paired with the IS_BYTE sends in + // keccak::bus_interactions): without this the field-element value of + // the addr_lo / addr_hi linear combinations is unconstrained per byte. + for b in 0..8 { + let byte = ((state_addr >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + byte, + )); + } + + // IS_HALF for state_ptr halfwords (100 per call) + for lane_idx in 0..25 { + let ptr = state_addr + .checked_add(lane_idx as u64 * 8) + .expect("keccak state address range must be validated by the executor"); + for shift in [0, 16, 32, 48] { + let half = ((ptr >> shift) & 0xFFFF) as u16; + ops.push(BitwiseOperation::halfword( + BitwiseOperationType::IsHalf, + (half & 0xFF) as u8, + ((half >> 8) & 0xFF) as u8, + )); + } + } + + // Replay keccak round computation to extract bitwise lookups + let mut state = kop.input; + for round in 0..24 { + // --- theta: Cxz chain XOR_BYTE (160) --- + let mut cxz = [[[0u8; 8]; 4]; 5]; + for x in 0..5 { + for b in 0..8 { + let v0 = ((state[x] >> (b * 8)) & 0xFF) as u8; + let v1 = ((state[x + 5] >> (b * 8)) & 0xFF) as u8; + cxz[x][0][b] = v0 ^ v1; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + v0, + v1, + )); + } + for stage in 1..4usize { + let y = stage + 1; + for b in 0..8 { + let prev = cxz[x][stage - 1][b]; + let sv = ((state[x + 5 * y] >> (b * 8)) & 0xFF) as u8; + cxz[x][stage][b] = prev ^ sv; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + prev, + sv, + )); + } + } + } + + // theta: HWSL for rotated C (20) + IS_BYTE on Cxz_left (40). + // Cxz_right is range-checked via IS_BIT polynomial constraints + // on the keccak_rnd chip, not via lookups (spec d75944ee). + let mut rotated_c = [[0u8; 8]; 5]; + for x in 0..5 { + let c = cxz[x][3]; + for hw in 0..4 { + let halfword = (c[hw * 2] as u16) | ((c[hw * 2 + 1] as u16) << 8); + let shifted = halfword << 1; // u16 wraps + ops.push(BitwiseOperation::new( + BitwiseOperationType::Hwsl, + (halfword & 0xFF) as u8, + ((halfword >> 8) & 0xFF) as u8, + 1, + )); + // IS_BYTE for cxz_left bytes + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + (shifted & 0xFF) as u8, + )); + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + ((shifted >> 8) & 0xFF) as u8, + )); + } + // Reconstruct rotated_c using the bit-typed Cxz_right. + let mut left_bytes = [0u8; 8]; + let mut right_bits = [0u8; 4]; + for hw in 0..4 { + let halfword = (c[hw * 2] as u16) | ((c[hw * 2 + 1] as u16) << 8); + let shifted = halfword << 1; + left_bytes[hw * 2] = (shifted & 0xFF) as u8; + left_bytes[hw * 2 + 1] = ((shifted >> 8) & 0xFF) as u8; + right_bits[hw] = (halfword >> 15) as u8; + } + for b in 0usize..8 { + let right_contribution = if b.is_multiple_of(2) { + right_bits[(b / 2 + 3) % 4] + } else { + 0 + }; + rotated_c[x][b] = left_bytes[b].wrapping_add(right_contribution); + } + } + + // theta: Dxz XOR_BYTE (40) + let mut d_bytes = [[0u8; 8]; 5]; + for x in 0..5 { + for b in 0..8 { + let a = cxz[(x + 4) % 5][3][b]; + let rb = rotated_c[(x + 1) % 5][b]; + d_bytes[x][b] = a ^ rb; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + a, + rb, + )); + } + } + + // theta final: XOR_BYTE (200) + let mut theta_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let lane = state[x + 5 * y]; + let mut d_lane = 0u64; + for b in 0..8 { + d_lane |= (d_bytes[x][b] as u64) << (b * 8); + } + theta_lanes[x + 5 * y] = lane ^ d_lane; + for b in 0..8 { + let s = ((lane >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + s, + d_bytes[x][b], + )); + } + } + } + + // rho: HWSL (100) + IS_BYTE (400) + for x in 0..5 { + for y in 0..5 { + let rho_offset = KECCAK_RHO[x][y] as usize; + let rnc_val = (rho_offset % 16) as u8; + let theta_lane = theta_lanes[x + 5 * y]; + for hw in 0..4 { + let halfword = ((theta_lane >> (hw * 16)) & 0xFFFF) as u16; + let (shifted, carry) = if rnc_val == 0 { + (halfword, 0u16) + } else { + (halfword << rnc_val, halfword >> (16 - rnc_val)) + }; + ops.push(BitwiseOperation::new( + BitwiseOperationType::Hwsl, + (halfword & 0xFF) as u8, + ((halfword >> 8) & 0xFF) as u8, + rnc_val, + )); + // IS_BYTE for rot_left + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + (shifted & 0xFF) as u8, + )); + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + ((shifted >> 8) & 0xFF) as u8, + )); + // IS_BYTE for rot_right + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + (carry & 0xFF) as u8, + )); + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::IsByte, + ((carry >> 8) & 0xFF) as u8, + )); + } + } + } + + // pi: compute pi_lanes + let mut pi_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let rotated = theta_lanes[x + 5 * y].rotate_left(KECCAK_RHO[x][y]); + let dst_x = y; + let dst_y = (2 * x + 3 * y) % 5; + pi_lanes[dst_x + 5 * dst_y] = rotated; + } + } + + // chi: AND_BYTE (200) + XOR_BYTE (200) + let mut chi_lanes = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + let not_next = !pi_lanes[(x + 1) % 5 + 5 * y]; + let next2 = pi_lanes[(x + 2) % 5 + 5 * y]; + let and_val = not_next & next2; + chi_lanes[x + 5 * y] = pi_lanes[x + 5 * y] ^ and_val; + for b in 0..8 { + let not_byte = ((not_next >> (b * 8)) & 0xFF) as u8; + let n2_byte = ((next2 >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::AndByte, + not_byte, + n2_byte, + )); + let pi_byte = ((pi_lanes[x + 5 * y] >> (b * 8)) & 0xFF) as u8; + let and_byte = ((and_val >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + pi_byte, + and_byte, + )); + } + } + } + + // iota: XOR_BYTE (8) + let rc_val = KECCAK_RC[round]; + for b in 0..8 { + let chi_byte = ((chi_lanes[0] >> (b * 8)) & 0xFF) as u8; + let rc_byte = ((rc_val >> (b * 8)) & 0xFF) as u8; + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::XorByte, + chi_byte, + rc_byte, + )); + } + + // Update state + chi_lanes[0] ^= rc_val; + state = chi_lanes; + } + } + + ops +} + /// every address accessed during execution (ELF init + runtime stores/loads). /// ELF pages get their init data from the binary; all others are zero-init. fn generate_page_tables( @@ -1664,6 +2027,15 @@ pub struct Traces { /// COMMIT table for write syscall (byte-by-byte commit with recursive bus) pub commit: TraceTable, + /// KECCAK core table (one row per keccak permutation call) + pub keccak: TraceTable, + + /// KECCAK_RND round table (24 rows per keccak call) + pub keccak_rnd: TraceTable, + + /// KECCAK_RC precomputed round constant table (32 rows) + pub keccak_rc: TraceTable, + /// MEMW_R register-only fast-path traces (split into chunks of max_rows::MEMW_R) pub memw_registers: Vec>, } @@ -1683,6 +2055,7 @@ struct CollectedOps { mul_ops: Vec<(MulOperation, bool)>, dvrm_ops: Vec<(DvrmOperation, bool)>, commit_ops: Vec, + keccak_ops: Vec, } /// Chunk raw ops and generate one trace table per chunk. @@ -1711,6 +2084,7 @@ fn collect_all_ops( shift_ops: Vec, bitwise_ops: Vec, commit_ops: Vec, + keccak_ops: Vec, register_state: &mut RegisterState, ) -> CollectedOps { // HALT finalization: 33 register MEMW operations at timestamp u64::MAX. @@ -1800,6 +2174,7 @@ fn collect_all_ops( mul_ops, dvrm_ops, commit_ops, + keccak_ops, } } @@ -1832,6 +2207,7 @@ fn build_traces( mul_ops, dvrm_ops, commit_ops, + keccak_ops, } = ops; // ===================================================================== @@ -1863,6 +2239,8 @@ fn build_traces( .collect(); // COMMIT table sends IsByte and IsHalfword lookups bitwise_ops.extend(collect_bitwise_from_commit(&commit_ops)); + // KECCAK_RND sends XOR/AND/IS_BYTE/HWSL; KECCAK core sends IS_HALF + bitwise_ops.extend(collect_bitwise_from_keccak(&keccak_ops)); // CPU padding rows send IS_BYTE with all-zero values. // Add corresponding ops so the bitwise table multiplicities balance. @@ -1921,6 +2299,21 @@ fn build_traces( // Generate remaining traces in parallel (page, register, halt, commit). // chunk_and_generate already handled cpu, lt, memw, load, mul, dvrm, branch above. let commit_trace = commit::generate_commit_trace(&commit_ops); + + // Generate keccak traces (core table + per-round table + preprocessed RC) + let keccak_rnd_ops: Vec = keccak_ops + .iter() + .map(|op| KeccakRoundOperation { + timestamp: op.timestamp, + input: op.input, + output: op.output, + }) + .collect(); + let keccak_trace = keccak::generate_keccak_trace(&keccak_ops); + let keccak_rnd_trace = keccak_rnd::generate_keccak_rnd_trace(&keccak_rnd_ops); + let mut keccak_rc_trace = keccak_rc::generate_keccak_rc_trace(); + keccak_rc::update_multiplicities(&mut keccak_rc_trace, keccak_ops.len()); + let (pages, page_configs, register_trace, halt_trace); #[cfg(feature = "parallel")] { @@ -1977,6 +2370,9 @@ fn build_traces( branches, halt: halt_trace, commit: commit_trace, + keccak: keccak_trace, + keccak_rnd: keccak_rnd_trace, + keccak_rc: keccak_rc_trace, memw_registers, }) } @@ -1999,6 +2395,10 @@ impl Traces { use super::decode::cols::NUM_COLUMNS as DECODE_COLS; use super::dvrm::cols::NUM_COLUMNS as DVRM_COLS; use super::halt::cols::NUM_COLUMNS as HALT_COLS; + use super::keccak::cols::NUM_COLUMNS as KECCAK_COLS; + use super::keccak_rc::NUM_PRECOMPUTED_COLS as KECCAK_RC_PRECOMPUTED; + use super::keccak_rc::cols::NUM_COLUMNS as KECCAK_RC_COLS; + use super::keccak_rnd::cols::NUM_COLUMNS as KECCAK_RND_COLS; use super::load::cols::NUM_COLUMNS as LOAD_COLS; use super::lt::cols::NUM_COLUMNS as LT_COLS; use super::memw::cols::NUM_COLUMNS as MEMW_COLS; @@ -2027,6 +2427,9 @@ impl Traces { branches, halt, commit, + keccak, + keccak_rnd, + keccak_rc, memw_registers, page_configs: _, public_output_bytes: _, @@ -2071,6 +2474,9 @@ impl Traces { for t in memw_registers { total += (t.num_rows() * MEMW_R_COLS) as u64; } + total += (keccak.num_rows() * KECCAK_COLS) as u64; + total += (keccak_rnd.num_rows() * KECCAK_RND_COLS) as u64; + total += (keccak_rc.num_rows() * (KECCAK_RC_COLS - KECCAK_RC_PRECOMPUTED)) as u64; total } @@ -2103,6 +2509,9 @@ impl Traces { // page::bus_interactions count is constant regardless of page_base. let n_page = aux_cols(super::page::bus_interactions(0).len()); let n_memw_r = aux_cols(super::memw_register::bus_interactions().len()); + let n_keccak = aux_cols(super::keccak::bus_interactions().len()); + let n_keccak_rnd = aux_cols(super::keccak_rnd::bus_interactions().len()); + let n_keccak_rc = aux_cols(super::keccak_rc::bus_interactions().len()); let Traces { cpus, @@ -2120,6 +2529,9 @@ impl Traces { branches, halt, commit, + keccak, + keccak_rnd, + keccak_rc, memw_registers, page_configs: _, public_output_bytes: _, @@ -2164,6 +2576,9 @@ impl Traces { for t in memw_registers { total += (t.num_rows() * n_memw_r) as u64; } + total += (keccak.num_rows() * n_keccak) as u64; + total += (keccak_rnd.num_rows() * n_keccak_rnd) as u64; + total += (keccak_rc.num_rows() * n_keccak_rc) as u64; total } @@ -2322,7 +2737,7 @@ impl Traces { let mut memory_state = MemoryState::from_elf(elf); memory_state.add_private_input(private_input); let mut register_state = RegisterState::new(elf.entry_point); - let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops) = + let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) = collect_ops_from_cpu(&cpu_ops, &mut memory_state, &mut register_state); let ops = collect_all_ops( @@ -2333,6 +2748,7 @@ impl Traces { shift_ops, bitwise_ops, commit_ops, + keccak_ops, &mut register_state, ); @@ -2368,7 +2784,7 @@ impl Traces { let mut memory_state = MemoryState::new(); let entry_point = cpu_ops.first().map_or(0, |op| op.decode.pc); let mut register_state = RegisterState::new(entry_point); - let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops) = + let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) = collect_ops_from_cpu(&cpu_ops, &mut memory_state, &mut register_state); let ops = collect_all_ops( @@ -2379,6 +2795,7 @@ impl Traces { shift_ops, bitwise_ops, commit_ops, + keccak_ops, &mut register_state, ); @@ -2448,6 +2865,210 @@ impl Traces { } } +#[cfg(test)] +mod keccak_tests { + use super::*; + use crate::tables::keccak::cols as core_cols; + use crate::tables::keccak_rnd::cols as rnd_cols; + use crate::tables::types::FE; + use executor::vm::instruction::execution::keccak_f1600; + + fn make_keccak_ops() -> (KeccakOperation, KeccakRoundOperation) { + let input = [0u64; 25]; + let mut output = input; + keccak_f1600(&mut output); + let kop = KeccakOperation { + timestamp: 42, + state_addr: 0x1000, + input, + output, + }; + let rop = KeccakRoundOperation { + timestamp: 42, + input, + output, + }; + (kop, rop) + } + + #[test] + fn test_keccak_bitwise_ops_count() { + let (kop, _) = make_keccak_ops(); + let ops = collect_bitwise_from_keccak(&[kop]); + + let xor = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::XorByte) + .count(); + let and = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::AndByte) + .count(); + let is_byte = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::IsByte) + .count(); + let hwsl = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::Hwsl) + .count(); + let is_half = ops + .iter() + .filter(|o| o.lookup_type == BitwiseOperationType::IsHalf) + .count(); + + assert_eq!(xor, 24 * 608, "XorByte count"); + assert_eq!(and, 24 * 200 + 1, "AndByte count"); + // Cxz_right Byte→Bit (spec d75944ee): drops 40 IS_BYTE per round. + // +8 per call to range-check the addr bytes used in alignment / no-overflow. + assert_eq!(is_byte, 24 * 440 + 8, "IsByte count"); + assert_eq!(hwsl, 24 * 120, "Hwsl count"); + assert_eq!(is_half, 100, "IsHalf count"); + assert_eq!(ops.len(), 109 + 24 * 1368, "Total bitwise ops"); + } + + #[test] + fn test_keccak_round_trace_matches_f1600() { + let (_, rop) = make_keccak_ops(); + let rnd_trace = keccak_rnd::generate_keccak_rnd_trace(&[rop]); + + let mut ref_state = [0u64; 25]; + for round in 0..24 { + let rc = executor::vm::instruction::execution::KECCAK_RC[round]; + let mut c = [0u64; 5]; + for x in 0..5 { + c[x] = ref_state[x] + ^ ref_state[x + 5] + ^ ref_state[x + 10] + ^ ref_state[x + 15] + ^ ref_state[x + 20]; + } + let mut d = [0u64; 5]; + for x in 0..5 { + d[x] = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1); + } + for i in 0..25 { + ref_state[i] ^= d[i % 5]; + } + let mut b = [0u64; 25]; + for x in 0..5 { + for y in 0..5 { + b[y + 5 * ((2 * x + 3 * y) % 5)] = ref_state[x + 5 * y] + .rotate_left(executor::vm::instruction::execution::KECCAK_RHO[x][y]); + } + } + for x in 0..5 { + for y in 0..5 { + ref_state[x + 5 * y] = + b[x + 5 * y] ^ (!b[(x + 1) % 5 + 5 * y] & b[(x + 2) % 5 + 5 * y]); + } + } + ref_state[0] ^= rc; + + let base = round * rnd_cols::NUM_COLUMNS; + for (lane, &lane_val) in ref_state.iter().enumerate() { + let x = lane % 5; + let y = lane / 5; + for byte_idx in 0..8 { + let expected = FE::from((lane_val >> (byte_idx * 8)) & 0xFF); + let col = if x == 0 && y == 0 { + rnd_cols::iota(byte_idx) + } else { + rnd_cols::chi(x, y, byte_idx) + }; + let trace_val = &rnd_trace.main_table.data[base + col]; + assert_eq!( + &expected, trace_val, + "Round {round} lane ({x},{y}) byte {byte_idx}" + ); + } + } + } + } + + #[test] + fn test_keccak_core_round_state_consistency() { + let (kop, rop) = make_keccak_ops(); + let core_trace = keccak::generate_keccak_trace(&[kop]); + let rnd_trace = keccak_rnd::generate_keccak_rnd_trace(&[rop]); + + // Round 0 start == core input_state + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let core_val = &core_trace.main_table.data[core_cols::input_state(x, y, b)]; + let rnd_val = &rnd_trace.main_table.data[rnd_cols::start(x, y, b)]; + assert_eq!(core_val, rnd_val, "Round 0 start mismatch at ({x},{y},{b})"); + } + } + } + + // Round 23 out == core output_state + let rnd_base_23 = 23 * rnd_cols::NUM_COLUMNS; + for x in 0..5 { + for y in 0..5 { + for b in 0..8 { + let core_val = &core_trace.main_table.data[core_cols::output_state(x, y, b)]; + let rnd_val = if x == 0 && y == 0 { + &rnd_trace.main_table.data[rnd_base_23 + rnd_cols::iota(b)] + } else { + &rnd_trace.main_table.data[rnd_base_23 + rnd_cols::chi(x, y, b)] + }; + assert_eq!(core_val, rnd_val, "Round 23 out mismatch at ({x},{y},{b})"); + } + } + } + } + + #[test] + fn test_keccak_bus_interaction_counts() { + assert_eq!( + keccak::bus_interactions().len(), + 138, + "KECCAK core: 1 ECALL + 1 MEMW read_addr + 25 MEMW lanes + 100 IS_HALF + 1 AND_BYTE alignment + 8 IS_BYTE addr + 1 Keccak send + 1 Keccak recv" + ); + assert_eq!( + keccak_rnd::bus_interactions().len(), + 1371, + "KECCAK_RND: 3 IO + 460 theta + 500 rho + 400 chi + 8 iota \ + (Cxz_right Byte→Bit drops 40 IS_BYTE per spec d75944ee)" + ); + assert_eq!( + keccak_rc::bus_interactions().len(), + 1, + "KECCAK_RC: 1 receiver" + ); + } + + #[test] + fn test_keccak_column_counts() { + assert_eq!(core_cols::NUM_COLUMNS, 511, "KECCAK core columns"); + assert_eq!( + rnd_cols::NUM_COLUMNS, + 1480, + "KECCAK_RND columns (rnc/rbc inlined; pi virtual; Cxz_right Bit-typed)" + ); + assert_eq!(keccak_rc::cols::NUM_COLUMNS, 10, "KECCAK_RC columns"); + } + + #[test] + fn test_keccak_constraint_counts() { + let (core_constraints, _) = keccak::create_constraints(0); + assert_eq!( + core_constraints.len(), + 51, + "KECCAK core: 25 ADD pairs + no-overflow" + ); + + let (rnd_constraints, _) = keccak_rnd::create_constraints(0); + assert_eq!( + rnd_constraints.len(), + 20, + "KECCAK_RND: 20 IS_BIT(μ; Cxz_right_bit) per spec d75944ee" + ); + } +} + #[cfg(test)] mod routing_tests { use super::*; diff --git a/prover/src/tables/types.rs b/prover/src/tables/types.rs index a1dcd043a..70aa6813d 100644 --- a/prover/src/tables/types.rs +++ b/prover/src/tables/types.rs @@ -110,6 +110,10 @@ pub enum BusId { /// COMMIT output bus: verifier computes the receiver contribution externally /// from `VmProof.public_output` using the shared LogUp challenges Commit, + /// Keccak core ↔ round chip: (timestamp, round, state[200 bytes]) + Keccak, + /// Keccak round ↔ RC lookup: (round, rc[8 bytes]) + KeccakRc, } impl BusId { @@ -138,6 +142,8 @@ impl BusId { BusId::Dvrm => "Dvrm", BusId::CommitNextByte => "CommitNextByte", BusId::Commit => "Commit", + BusId::Keccak => "Keccak", + BusId::KeccakRc => "KeccakRc", } } } @@ -169,6 +175,8 @@ impl TryFrom for BusId { 19 => Ok(BusId::Ecall), 20 => Ok(BusId::CommitNextByte), 21 => Ok(BusId::Commit), + 22 => Ok(BusId::Keccak), + 23 => Ok(BusId::KeccakRc), other => Err(other), } } diff --git a/prover/src/test_utils.rs b/prover/src/test_utils.rs index b47554857..1dcb768b2 100644 --- a/prover/src/test_utils.rs +++ b/prover/src/test_utils.rs @@ -43,6 +43,13 @@ use crate::tables::dvrm::{ bus_interactions as dvrm_bus_interactions, cols as dvrm_cols, dvrm_constraints, }; use crate::tables::halt::{bus_interactions as halt_bus_interactions, cols as halt_cols}; +use crate::tables::keccak::{bus_interactions as keccak_bus_interactions, cols as keccak_cols}; +use crate::tables::keccak_rc::{ + bus_interactions as keccak_rc_bus_interactions, cols as keccak_rc_cols, +}; +use crate::tables::keccak_rnd::{ + bus_interactions as keccak_rnd_bus_interactions, cols as keccak_rnd_cols, +}; use crate::tables::load::{ bus_interactions as load_bus_interactions, cols as load_cols, constraints as load_constraints, }; @@ -791,3 +798,59 @@ pub fn create_register_air(proof_options: &ProofOptions) -> VmAir { ) .with_name("REGISTER") } + +/// Create KECCAK core AIR with ADD constraints and bus interactions. +pub fn create_keccak_air(proof_options: &ProofOptions) -> VmAir { + let (constraints, _) = crate::tables::keccak::create_constraints(0); + let transition_constraints: Vec>> = constraints; + + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: keccak_bus_interactions(), + }; + + AirWithBuses::new( + keccak_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("KECCAK") +} + +/// Create KECCAK_RND AIR with pi constraints and bus interactions. +pub fn create_keccak_rnd_air(proof_options: &ProofOptions) -> VmAir { + let (constraints, _) = crate::tables::keccak_rnd::create_constraints(0); + let transition_constraints: Vec>> = constraints; + + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: keccak_rnd_bus_interactions(), + }; + + AirWithBuses::new( + keccak_rnd_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("KECCAK_RND") +} + +/// Create KECCAK_RC AIR with bus interactions (preprocessed table). +pub fn create_keccak_rc_air(proof_options: &ProofOptions) -> VmAir { + let transition_constraints: Vec>> = vec![]; + + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: keccak_rc_bus_interactions(), + }; + + AirWithBuses::new( + keccak_rc_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("KECCAK_RC") +} diff --git a/prover/src/tests/cpu_tests.rs b/prover/src/tests/cpu_tests.rs index 6b3239f43..9004d24c0 100644 --- a/prover/src/tests/cpu_tests.rs +++ b/prover/src/tests/cpu_tests.rs @@ -328,7 +328,7 @@ fn test_bus_interactions_count() { // - 1 DVRM (division/remainder) // - 1 SHIFT (shift operations) // - 1 BRANCH (branch/jump target calculation) - // - 1 ECALL (single shared bus for HALT and COMMIT, mult = ECALL) + // - 1 ECALL (shared bus for HALT, COMMIT, and KECCAK, mult = ECALL) // - 1 IS_BYTE for (RS1, RS2) paired // - 1 IS_BYTE for (RD, 0) // - 12 IS_BYTE (ARG1/ARG2/RES byte pairs: 4 pairs × 3 arrays) diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs index 7e0fbc181..adbf02143 100644 --- a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -716,6 +716,100 @@ fn test_prove_elfs_all_instructions_64() { ); } +#[test] +fn test_prove_elfs_keccak() { + let _ = env_logger::builder().is_test(true).try_init(); + + let (elf, logs, _instructions) = run_asm_elf("test_keccak"); + // Must use from_elf_and_logs (not from_logs_minimal) because keccak accesses + // RAM (stack memory), which requires PAGE tables for Memory bus balance. + let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + + assert!( + prove_and_verify_vm_minimal(&elf, &mut traces), + "keccak prove/verify failed" + ); +} + +#[test] +fn test_prove_elfs_keccak_multi_call() { + let _ = env_logger::builder().is_test(true).try_init(); + + let elf_bytes = crate::test_utils::asm_elf_bytes("test_keccak_multi"); + let elf = Elf::load(&elf_bytes).expect("Failed to load ELF"); + let executor = + executor::vm::execution::Executor::new(&elf, vec![]).expect("Failed to create executor"); + let result = executor.run().expect("Failed to run program"); + + // The guest initializes lane[i] = i + 1 and applies keccak-f[1600] three times. + // Cross-check the committed output against tiny-keccak's independent + // implementation of the permutation. + let mut expected_state: [u64; 25] = core::array::from_fn(|i| (i + 1) as u64); + for _ in 0..3 { + tiny_keccak::keccakf(&mut expected_state); + } + let mut expected_bytes = Vec::with_capacity(200); + for lane in expected_state { + expected_bytes.extend_from_slice(&lane.to_le_bytes()); + } + + assert_eq!( + result.return_values.memory_values, expected_bytes, + "committed state must match tiny-keccak after 3 keccak-f[1600] calls" + ); + + let mut traces = + Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + assert_eq!( + traces.public_output_bytes, + result.return_values.memory_values + ); + + assert!( + prove_and_verify_vm_minimal(&elf, &mut traces), + "keccak multi-call prove/verify failed" + ); +} + +/// Verifier REJECTS a forged trace where an addr byte cell is set to a +/// non-byte field element. +/// +/// Without the IS_BYTE range checks on addr(0..7), an attacker could keep +/// `addr_lo = b0 + 256·b1 + 65536·b2 + 2^24·b3` equal to an unaligned target +/// address as a field element while setting addr(0)=0 (passing the AndByte +/// alignment check) and folding the carry into addr(1) as a non-byte +/// FE-element. This test asserts that mutating addr(1) to a non-byte value +/// unbalances the verifier's bus checks and the proof is rejected. +#[test] +fn test_prove_elfs_keccak_unaligned_state_addr() { + use crate::tables::keccak::cols as keccak_cols; + + let _ = env_logger::builder().is_test(true).try_init(); + + let elf_bytes = crate::test_utils::asm_elf_bytes("test_keccak_multi"); + let elf = Elf::load(&elf_bytes).expect("Failed to load ELF"); + let executor = + executor::vm::execution::Executor::new(&elf, vec![]).expect("Failed to create executor"); + let result = executor.run().expect("Failed to run program"); + let mut traces = + Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + + // Tamper the first real keccak row: replace addr(1) (a byte cell) with a + // value outside [0, 256). The new IS_BYTE bus sender will emit this + // value with multiplicity MU=1; the IS_BYTE preprocessed table only + // contains 0..256, so the bus cannot balance. + traces.keccak.main_table.set( + 0, + keccak_cols::addr(1), + FieldElement::::from(257u64), + ); + + assert!( + !prove_and_verify_vm_minimal(&elf, &mut traces), + "Verifier must reject a keccak proof whose addr cells are not bytes" + ); +} + #[test] fn test_prove_elfs_test_commit_4() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_commit_4"); @@ -1796,7 +1890,7 @@ fn test_crafted_zero_count_proof_must_not_verify() { let airs = VmAirs::new(&elf, &proof_options, true, &[], &zero_counts); let verifier_air_refs = airs.air_refs(); - assert_eq!(verifier_air_refs.len(), 5); + assert_eq!(verifier_air_refs.len(), 8); let mut bitwise_trace = crate::tables::bitwise::generate_bitwise_trace(); diff --git a/syscalls/src/syscalls.rs b/syscalls/src/syscalls.rs index ae0315ff5..14d5b2e6f 100644 --- a/syscalls/src/syscalls.rs +++ b/syscalls/src/syscalls.rs @@ -16,6 +16,10 @@ enum SyscallNumbers { Halt = 93, } +/// Syscall number for KeccakPermute (u64::MAX - 1). +#[cfg(target_arch = "riscv64")] +const KECCAK_SYSCALL_NUMBER: usize = usize::MAX - 1; + #[cfg(target_arch = "riscv64")] /// This is a template for printing in the vm pub fn print_string(s: &str) { @@ -120,6 +124,24 @@ pub fn sys_halt() -> ! { unimplemented!("syscalls are only implemented for riscv64 targets"); } +#[cfg(target_arch = "riscv64")] +/// Apply the Keccak-f[1600] permutation to a 25-element u64 state in-place. +pub fn keccak_permute(state: &mut [u64; 25]) { + unsafe { + asm!( + "ecall", + in("a0") state.as_mut_ptr(), + in("a7") KECCAK_SYSCALL_NUMBER, + ) + } +} + +#[cfg(not(target_arch = "riscv64"))] +/// Apply the Keccak-f[1600] permutation to a 25-element u64 state in-place. +pub fn keccak_permute(_state: &mut [u64; 25]) { + unimplemented!("syscalls are only implemented for riscv64 targets"); +} + // ============================================================================= // Stub implementations for unsupported std functions // These functions are required by Rust's std zkvm module but are not supported