diff --git a/Cargo.lock b/Cargo.lock index 70b4071e8..8fff60dcf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1961,6 +1961,7 @@ dependencies = [ "math", "rayon", "serde", + "sha3", "stark", "sysinfo", "tikv-jemalloc-ctl", diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 9e03da9b3..03f9d8e75 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -20,6 +20,7 @@ serde = { version = "1.0", features = ["derive"] } rayon = { version = "1.8.0", optional = true } sysinfo = { version = "0.31", default-features = false, features = ["system"] } log = "0.4" +sha3 = { version = "0.10.8", default-features = false } [dev-dependencies] env_logger = "*" diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 7c6ca838c..14f35cdf8 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -17,6 +17,7 @@ pub mod constraints; mod debug_report; #[cfg(feature = "instruments")] pub mod instruments; +mod statement; pub mod tables; pub mod test_utils; #[cfg(test)] @@ -35,6 +36,7 @@ use stark::storage_mode::StorageMode; use stark::traits::AIR; use stark::verifier::{IsStarkVerifier, Verifier}; +use crate::statement::absorb_statement; pub use crate::tables::MaxRowsConfig; use crate::tables::bitwise; use crate::tables::decode; @@ -445,8 +447,8 @@ impl VmAirs { pub(crate) fn replay_transcript_phase_a( airs: &[&dyn AIR], multi_proof: &MultiProof, + transcript: &mut DefaultTranscript, ) -> (FieldElement, FieldElement) { - let mut transcript = DefaultTranscript::::new(&[]); for (air, proof) in airs.iter().zip(&multi_proof.proofs) { if air.is_preprocessed() { transcript.append_bytes(&air.precomputed_commitment()); @@ -512,8 +514,9 @@ pub(crate) fn compute_expected_commit_bus_balance( airs: &[&dyn AIR], proof: &MultiProof, public_output_bytes: &[u8], + transcript: &mut DefaultTranscript, ) -> Option> { - let (z, alpha) = replay_transcript_phase_a(airs, proof); + let (z, alpha) = replay_transcript_phase_a(airs, proof, transcript); compute_commit_bus_offset(public_output_bytes, &z, &alpha) } @@ -652,10 +655,28 @@ pub fn prove_with_options_and_inputs( let runtime_page_ranges = traces.runtime_page_ranges(); + let num_private_input_pages = traces + .page_configs + .iter() + .filter(|c| c.is_private_input) + .count(); + + // Bind the full statement (program, public output, table layout) into the + // Fiat-Shamir transcript so every challenge depends on it. + let mut transcript = DefaultTranscript::::new(&[]); + absorb_statement( + &mut transcript, + elf_bytes, + &traces.public_output_bytes, + &table_counts, + num_private_input_pages, + &runtime_page_ranges, + ); + // Phase 4: Prove (multi_prove) let proof = Prover::multi_prove( airs.air_trace_pairs(&mut traces), - &mut DefaultTranscript::::new(&[]), + &mut transcript, #[cfg(feature = "disk-spill")] storage_mode, ) @@ -677,12 +698,6 @@ pub fn prove_with_options_and_inputs( ); } - let num_private_input_pages = traces - .page_configs - .iter() - .filter(|c| c.is_private_input) - .count(); - Ok(VmProof { proof, runtime_page_ranges, @@ -765,10 +780,29 @@ pub fn verify_with_options( // If public_output was tampered, the recomputed offset won't match the // actual bus total in the proof, and multi_verify will reject. let air_refs = airs.air_refs(); + + // Bind the statement into the verifier's transcript. A tampered statement + // field makes this diverge from the prover's transcript state, so every + // derived challenge differs and verification rejects. + let mut transcript = DefaultTranscript::::new(&[]); + absorb_statement( + &mut transcript, + elf_bytes, + &vm_proof.public_output, + &vm_proof.table_counts, + vm_proof.num_private_input_pages, + &vm_proof.runtime_page_ranges, + ); + + // Fork the post-absorb state: the replay helper advances through Phase A + // independently of the multi_verify transcript, but both must start from + // the same statement-bound state. + let mut transcript_for_replay = transcript.clone(); let expected_bus_balance = match compute_expected_commit_bus_balance( &air_refs, &vm_proof.proof, &vm_proof.public_output, + &mut transcript_for_replay, ) { Some(balance) => balance, None => return Ok(false), @@ -777,7 +811,7 @@ pub fn verify_with_options( Ok(Verifier::multi_verify( &air_refs, &vm_proof.proof, - &mut DefaultTranscript::::new(&[]), + &mut transcript, &expected_bus_balance, )) } diff --git a/prover/src/statement.rs b/prover/src/statement.rs new file mode 100644 index 000000000..82c41861c --- /dev/null +++ b/prover/src/statement.rs @@ -0,0 +1,85 @@ +//! Statement absorbed into the Fiat-Shamir transcript before Phase A. +//! +//! Streams a canonical, domain-separated, length-prefixed encoding directly +//! into the transcript. The transcript is itself a Keccak256 absorber +//! (`DefaultTranscript`), so a single hash suffices — no external digest +//! needed beyond the ELF. +//! +//! All three call sites (prove, verify, bus-balance replay) must absorb +//! identical bytes; any divergence makes every derived challenge differ and +//! verification reject. + +use crypto::fiat_shamir::is_transcript::IsTranscript; +use sha3::{Digest, Keccak256}; + +use crate::test_utils::E; +use crate::{RuntimePageRange, TableCounts}; + +/// Domain-separation tag. Bump the suffix (`_V2`, ...) on any encoding change. +const DOMAIN_TAG: &[u8] = b"LAMBDAVM_STARK_STATEMENT_V1"; + +fn elf_digest(elf: &[u8]) -> [u8; 32] { + let mut h = Keccak256::new(); + h.update(elf); + h.finalize().into() +} + +pub(crate) fn absorb_statement( + t: &mut impl IsTranscript, + elf_bytes: &[u8], + public_output: &[u8], + table_counts: &TableCounts, + num_private_input_pages: usize, + runtime_page_ranges: &[RuntimePageRange], +) { + t.append_bytes(DOMAIN_TAG); + + // ELF: fixed 32-byte digest — no length prefix needed. + t.append_bytes(&elf_digest(elf_bytes)); + + // public_output: variable length → length-prefix to prevent boundary collisions. + t.append_bytes(&(public_output.len() as u64).to_le_bytes()); + t.append_bytes(public_output); + + // table_counts: fixed-width u64s in declared order. The exhaustive + // destructure makes any field added to TableCounts a compile error here — + // that's the signal to extend the loop below and bump DOMAIN_TAG. + let &TableCounts { + cpu, + lt, + memw, + memw_aligned, + load, + mul, + dvrm, + shift, + branch, + memw_register, + } = table_counts; + for count in [ + cpu, + lt, + memw, + memw_aligned, + load, + mul, + dvrm, + shift, + branch, + memw_register, + ] { + t.append_bytes(&(count as u64).to_le_bytes()); + } + + t.append_bytes(&(num_private_input_pages as u64).to_le_bytes()); + + // runtime_page_ranges: count-prefixed; each entry fixed width. + t.append_bytes(&(runtime_page_ranges.len() as u64).to_le_bytes()); + for r in runtime_page_ranges { + // Exhaustive destructure: any field added to RuntimePageRange becomes + // a compile error here. + let &RuntimePageRange { base, count } = r; + t.append_bytes(&base.to_le_bytes()); + t.append_bytes(&count.to_le_bytes()); + } +} diff --git a/prover/src/tests/decode_tests.rs b/prover/src/tests/decode_tests.rs index f1f60e5ba..8f61a1d74 100644 --- a/prover/src/tests/decode_tests.rs +++ b/prover/src/tests/decode_tests.rs @@ -1064,10 +1064,12 @@ fn test_decode_soundness_same_elf_accepted() { &table_counts, ); let verifier_air_refs = verifier_airs.air_refs(); + let mut replay_transcript = DefaultTranscript::::new(&[]); let expected_bus_balance = crate::compute_expected_commit_bus_balance( &verifier_air_refs, &proof, &traces.public_output_bytes, + &mut replay_transcript, ) .expect("fingerprint collision in test"); diff --git a/prover/src/tests/mod.rs b/prover/src/tests/mod.rs index 89bab730b..91a92ad46 100644 --- a/prover/src/tests/mod.rs +++ b/prover/src/tests/mod.rs @@ -31,4 +31,6 @@ pub mod mul_tests; #[cfg(test)] pub mod prove_elfs_tests; #[cfg(test)] +pub mod statement_tests; +#[cfg(test)] pub mod trace_builder_tests; diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs index 6219e766f..fe97911b9 100644 --- a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -67,10 +67,12 @@ fn prove_and_verify_vm_minimal(elf: &Elf, traces: &mut Traces) -> bool { }; // Compute the verifier-side expected COMMIT bus balance from public output bytes + let mut replay_transcript = DefaultTranscript::::new(&[]); let expected_bus_balance = crate::compute_expected_commit_bus_balance( &airs.air_refs(), &multi_proof, &traces.public_output_bytes, + &mut replay_transcript, ) .expect("fingerprint collision in test"); @@ -881,10 +883,12 @@ fn test_prove_elfs_test_commit_4_wrong_pages_rejected() { let verifier_airs = crate::VmAirs::new(&elf, &proof_options, true, &wrong_configs, &table_counts); let verifier_air_refs = verifier_airs.air_refs(); + let mut replay_transcript = DefaultTranscript::::new(&[]); let expected_bus_balance = crate::compute_expected_commit_bus_balance( &verifier_air_refs, &proof, &traces.public_output_bytes, + &mut replay_transcript, ) .expect("fingerprint collision in test"); @@ -1617,10 +1621,12 @@ fn test_deep_stack_runtime_pages_roundtrip() { let verifier_airs = crate::VmAirs::new(&elf, &proof_options, true, &verifier_configs, &table_counts); let verifier_air_refs = verifier_airs.air_refs(); + let mut replay_transcript = DefaultTranscript::::new(&[]); let expected_bus_balance = crate::compute_expected_commit_bus_balance( &verifier_air_refs, &proof, &traces.public_output_bytes, + &mut replay_transcript, ) .expect("fingerprint collision in test"); @@ -1672,10 +1678,12 @@ fn test_deep_stack_missing_pages_rejected() { let verifier_airs = crate::VmAirs::new(&elf, &proof_options, true, &wrong_configs, &table_counts); let verifier_air_refs = verifier_airs.air_refs(); + let mut replay_transcript = DefaultTranscript::::new(&[]); let expected_bus_balance = crate::compute_expected_commit_bus_balance( &verifier_air_refs, &proof, &traces.public_output_bytes, + &mut replay_transcript, ) .expect("fingerprint collision in test"); @@ -1762,10 +1770,12 @@ fn test_heap_alloc_runtime_pages_roundtrip() { let verifier_airs = crate::VmAirs::new(&elf, &proof_options, true, &verifier_configs, &table_counts); let verifier_air_refs = verifier_airs.air_refs(); + let mut replay_transcript = DefaultTranscript::::new(&[]); let expected_bus_balance = crate::compute_expected_commit_bus_balance( &verifier_air_refs, &proof, &traces.public_output_bytes, + &mut replay_transcript, ) .expect("fingerprint collision in test"); diff --git a/prover/src/tests/statement_tests.rs b/prover/src/tests/statement_tests.rs new file mode 100644 index 000000000..659113be6 --- /dev/null +++ b/prover/src/tests/statement_tests.rs @@ -0,0 +1,118 @@ +//! Tests for statement absorption into the Fiat-Shamir transcript. + +use crypto::fiat_shamir::default_transcript::DefaultTranscript; +use crypto::fiat_shamir::is_transcript::IsTranscript; + +use crate::statement::absorb_statement; +use crate::test_utils::E; +use crate::{RuntimePageRange, TableCounts}; + +fn sample_counts() -> TableCounts { + TableCounts { + cpu: 3, + lt: 1, + memw: 2, + memw_aligned: 1, + load: 1, + mul: 1, + dvrm: 1, + shift: 1, + branch: 2, + memw_register: 1, + } +} + +fn sample_ranges() -> Vec { + vec![ + RuntimePageRange { + base: 0x1000, + count: 4, + }, + RuntimePageRange { + base: 0x8000, + count: 2, + }, + ] +} + +fn state_after_absorb( + elf: &[u8], + out: &[u8], + counts: &TableCounts, + priv_pages: usize, + ranges: &[RuntimePageRange], +) -> [u8; 32] { + let mut t = DefaultTranscript::::new(&[]); + absorb_statement(&mut t, elf, out, counts, priv_pages, ranges); + t.state() +} + +#[test] +fn state_is_deterministic() { + let a = state_after_absorb(b"elf", b"out", &sample_counts(), 3, &sample_ranges()); + let b = state_after_absorb(b"elf", b"out", &sample_counts(), 3, &sample_ranges()); + assert_eq!(a, b); +} + +#[test] +fn state_depends_on_every_field() { + let baseline = state_after_absorb(b"elf", b"out", &sample_counts(), 1, &sample_ranges()); + + assert_ne!( + baseline, + state_after_absorb( + b"different-elf", + b"out", + &sample_counts(), + 1, + &sample_ranges() + ), + "state must depend on elf", + ); + assert_ne!( + baseline, + state_after_absorb( + b"elf", + b"different-output", + &sample_counts(), + 1, + &sample_ranges() + ), + "state must depend on public_output", + ); + + let mut counts2 = sample_counts(); + counts2.branch += 1; + assert_ne!( + baseline, + state_after_absorb(b"elf", b"out", &counts2, 1, &sample_ranges()), + "state must depend on table_counts", + ); + + assert_ne!( + baseline, + state_after_absorb(b"elf", b"out", &sample_counts(), 2, &sample_ranges()), + "state must depend on num_private_input_pages", + ); + + assert_ne!( + baseline, + state_after_absorb(b"elf", b"out", &sample_counts(), 1, &[]), + "state must depend on runtime_page_ranges", + ); +} + +#[test] +fn public_output_length_prefix_prevents_collision() { + // Without the length prefix on public_output, "empty output + cpu count + // 0x41" and "output [0x41] + cpu count 0" would absorb identical bytes. + // The prefix keeps the two statements distinct. + let mut counts_a = sample_counts(); + counts_a.cpu = 0x41; + let mut counts_b = sample_counts(); + counts_b.cpu = 0; + assert_ne!( + state_after_absorb(b"elf", b"", &counts_a, 0, &[]), + state_after_absorb(b"elf", b"\x41", &counts_b, 0, &[]), + ); +}