diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 02fe99a17..2b7cca179 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -22,6 +22,7 @@ pub use { r1cs::R1CS, verifier::Verifier, whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme}, + witness::PublicInputs, }; #[cfg(test)] diff --git a/provekit/common/src/noir_proof_scheme.rs b/provekit/common/src/noir_proof_scheme.rs index b8c8ff937..7552ab268 100644 --- a/provekit/common/src/noir_proof_scheme.rs +++ b/provekit/common/src/noir_proof_scheme.rs @@ -2,7 +2,7 @@ use { crate::{ whir_r1cs::{WhirR1CSProof, WhirR1CSScheme}, witness::{NoirWitnessGenerator, SplitWitnessBuilders}, - NoirElement, R1CS, + NoirElement, PublicInputs, R1CS, }, acir::circuit::Program, serde::{Deserialize, Serialize}, @@ -20,6 +20,7 @@ pub struct NoirProofScheme { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct NoirProof { + pub public_inputs: PublicInputs, pub whir_r1cs_proof: WhirR1CSProof, } diff --git a/provekit/common/src/skyscraper/mod.rs b/provekit/common/src/skyscraper/mod.rs index 3b6da92a4..2caecdc28 100644 --- a/provekit/common/src/skyscraper/mod.rs +++ b/provekit/common/src/skyscraper/mod.rs @@ -2,4 +2,8 @@ mod pow; mod sponge; mod whir; -pub use self::{pow::SkyscraperPoW, sponge::SkyscraperSponge, whir::SkyscraperMerkleConfig}; +pub use self::{ + pow::SkyscraperPoW, + sponge::SkyscraperSponge, + whir::{SkyscraperCRH, SkyscraperMerkleConfig}, +}; diff --git a/provekit/common/src/utils/mod.rs b/provekit/common/src/utils/mod.rs index a5f6aa5b7..43943c566 100644 --- a/provekit/common/src/utils/mod.rs +++ b/provekit/common/src/utils/mod.rs @@ -2,6 +2,7 @@ mod print_abi; pub mod serde_ark; pub mod serde_ark_option; +pub mod serde_ark_vec; pub mod serde_hex; pub mod serde_jsonify; pub mod sumcheck; diff --git a/provekit/common/src/utils/serde_ark_vec.rs b/provekit/common/src/utils/serde_ark_vec.rs new file mode 100644 index 000000000..b26c309ed --- /dev/null +++ b/provekit/common/src/utils/serde_ark_vec.rs @@ -0,0 +1,87 @@ +use { + crate::FieldElement, + ark_serialize::{CanonicalDeserialize, CanonicalSerialize}, + serde::{ + de::{Error as _, SeqAccess, Visitor}, + ser::{Error as _, SerializeSeq}, + Deserializer, Serializer, + }, + std::fmt, +}; + +pub fn serialize(vec: &Vec, serializer: S) -> Result +where + S: Serializer, +{ + let is_human_readable = serializer.is_human_readable(); + let mut seq = serializer.serialize_seq(Some(vec.len()))?; + for element in vec { + let mut buf = Vec::with_capacity(element.compressed_size()); + element + .serialize_compressed(&mut buf) + .map_err(|e| S::Error::custom(format!("Failed to serialize: {e}")))?; + + // Write bytes + if is_human_readable { + // ark_serialize doesn't have human-readable serialization. And Serde + // doesn't have good defaults for [u8]. So we implement hexadecimal + // serialization. + let hex = hex::encode(buf); + seq.serialize_element(&hex)?; + } else { + seq.serialize_element(&buf)?; + } + } + seq.end() +} + +pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + struct VecVisitor { + is_human_readable: bool, + } + + impl<'de> Visitor<'de> for VecVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of field elements") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut vec = Vec::new(); + if self.is_human_readable { + while let Some(hex) = seq.next_element::()? { + let bytes = hex::decode(hex) + .map_err(|e| A::Error::custom(format!("invalid hex: {e}")))?; + let mut reader = &*bytes; + let element = FieldElement::deserialize_compressed(&mut reader) + .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?; + if !reader.is_empty() { + return Err(A::Error::custom("while deserializing: trailing bytes")); + } + vec.push(element); + } + } else { + while let Some(bytes) = seq.next_element::>()? { + let mut reader = &*bytes; + let element = FieldElement::deserialize_compressed(&mut reader) + .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?; + if !reader.is_empty() { + return Err(A::Error::custom("while deserializing: trailing bytes")); + } + vec.push(element); + } + } + Ok(vec) + } + } + + let is_human_readable = deserializer.is_human_readable(); + deserializer.deserialize_seq(VecVisitor { is_human_readable }) +} diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs index 7e1c5a245..6baef51d7 100644 --- a/provekit/common/src/utils/sumcheck.rs +++ b/provekit/common/src/utils/sumcheck.rs @@ -114,6 +114,10 @@ pub trait SumcheckIOPattern { fn add_rand(self, num_rand: usize) -> Self; fn add_zk_sumcheck_polynomials(self, num_vars: usize) -> Self; + + /// Prover sends the hash of the public inputs + /// Verifier sends randomness to construct weights + fn add_public_inputs(self) -> Self; } impl SumcheckIOPattern for IOPattern @@ -136,6 +140,12 @@ where self } + fn add_public_inputs(mut self) -> Self { + self = self.add_scalars(1, "Public Inputs Hash"); + self = self.challenge_scalars(1, "Public Weights Vector Random"); + self + } + fn add_rand(self, num_rand: usize) -> Self { self.challenge_scalars(num_rand, "rand") } diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 302bfd79d..dc6ed3615 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -22,6 +22,7 @@ pub struct WhirR1CSScheme { pub m_0: usize, pub a_num_terms: usize, pub num_challenges: usize, + pub has_public_inputs: bool, pub whir_witness: WhirConfig, pub whir_for_hiding_spartan: WhirConfig, } @@ -34,10 +35,11 @@ impl WhirR1CSScheme { if self.num_challenges > 0 { // Compute total constraints: OOD + statement // OOD: 2 witnesses × committment_ood_samples each - // Statement: 2 statements × 3 constraints each = 6 + // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, + // statement_2 has 3 constraints = 3, total = 7 let num_witnesses = 2; let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; - let num_statement_constraints = 6; // 2 statements × 3 constraints + let num_statement_constraints = if self.has_public_inputs { 7 } else { 6 }; let num_constraints_total = num_ood_constraints + num_statement_constraints; io = io @@ -48,8 +50,10 @@ impl WhirR1CSScheme { .commit_statement(&self.whir_for_hiding_spartan) .add_zk_sumcheck_polynomials(self.m_0) .add_whir_proof(&self.whir_for_hiding_spartan) + .add_public_inputs() .hint("claimed_evaluations_1") .hint("claimed_evaluations_2") + .hint("public_weights_evaluations") .add_whir_batch_proof(&self.whir_witness, num_witnesses, num_constraints_total); } else { io = io @@ -58,7 +62,9 @@ impl WhirR1CSScheme { .commit_statement(&self.whir_for_hiding_spartan) .add_zk_sumcheck_polynomials(self.m_0) .add_whir_proof(&self.whir_for_hiding_spartan) + .add_public_inputs() .hint("claimed_evaluations") + .hint("public_weights_evaluations") .add_whir_proof(&self.whir_witness); } diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 4361a5dce..183c116ff 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -7,7 +7,12 @@ mod witness_generator; mod witness_io_pattern; use { - crate::{utils::serde_ark, FieldElement}, + crate::{ + skyscraper::SkyscraperCRH, + utils::{serde_ark, serde_ark_vec}, + FieldElement, + }, + ark_crypto_primitives::crh::CRHScheme, ark_ff::One, serde::{Deserialize, Serialize}, }; @@ -15,7 +20,7 @@ pub use { binops::{BINOP_ATOMIC_BITS, BINOP_BITS, NUM_DIGITS}, digits::{decompose_into_digits, DigitalDecompositionWitnesses}, ram::{SpiceMemoryOperation, SpiceWitnesses}, - scheduling::{Layer, LayerType, LayeredWitnessBuilders, SplitWitnessBuilders}, + scheduling::{Layer, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders}, witness_builder::{ ConstantTerm, ProductLinearTerm, SumTerm, WitnessBuilder, WitnessCoefficient, }, @@ -40,3 +45,43 @@ impl ConstantOrR1CSWitness { } } } + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); + +impl PublicInputs { + pub fn new() -> Self { + Self(Vec::new()) + } + + pub fn from_vec(vec: Vec) -> Self { + Self(vec) + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn hash(&self) -> FieldElement { + match self.0.len() { + 0 => FieldElement::from(0u64), + 1 => { + // For single element, hash it with zero to ensure it gets properly hashed + let padded = vec![self.0[0], FieldElement::from(0u64)]; + SkyscraperCRH::evaluate(&(), &padded[..]).expect("hash should succeed") + } + _ => SkyscraperCRH::evaluate(&(), &self.0[..]) + .expect("hash should succeed for multiple inputs"), + } + } +} + +impl Default for PublicInputs { + fn default() -> Self { + Self::new() + } +} diff --git a/provekit/common/src/witness/scheduling/mod.rs b/provekit/common/src/witness/scheduling/mod.rs index f2529253e..28d80c11b 100644 --- a/provekit/common/src/witness/scheduling/mod.rs +++ b/provekit/common/src/witness/scheduling/mod.rs @@ -9,8 +9,10 @@ mod scheduler; mod splitter; pub use { - dependency::DependencyInfo, remapper::WitnessIndexRemapper, scheduler::LayerScheduler, - splitter::WitnessSplitter, + dependency::DependencyInfo, + remapper::WitnessIndexRemapper, + scheduler::LayerScheduler, + splitter::{SplitError, WitnessSplitter}, }; /// Type of operations contained in a layer. diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index 57a44367b..dc5107a6d 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -1,8 +1,23 @@ use { crate::witness::{scheduling::DependencyInfo, WitnessBuilder}, - std::collections::{HashSet, VecDeque}, + std::{ + collections::{HashSet, VecDeque}, + fmt, + }, }; +/// Error returned when witness splitting validation fails. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SplitError; + +impl fmt::Display for SplitError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "error in splitting witnesses into w1 and w2") + } +} + +impl std::error::Error for SplitError {} + /// Analyzes witness builder dependencies and splits them into w1/w2 groups. /// /// Uses backward reachability from challenge consumers (lookup builders) to @@ -26,7 +41,10 @@ impl<'a> WitnessSplitter<'a> { /// (post-challenge). /// /// Returns (w1_builder_indices, w2_builder_indices) - pub fn split_builders(&self) -> (Vec, Vec) { + pub fn split_builders( + &self, + acir_public_inputs_indices_set: HashSet, + ) -> Result<(Vec, Vec), SplitError> { let builder_count = self.witness_builders.len(); // Step 1: Find all Challenge builders @@ -40,7 +58,11 @@ impl<'a> WitnessSplitter<'a> { .collect(); if challenge_builders.is_empty() { - return ((0..builder_count).collect(), Vec::new()); + let w1_indices = self.rearrange_w1( + (0..builder_count).collect(), + &acir_public_inputs_indices_set, + )?; + return Ok((w1_indices, Vec::new())); } // Step 2: Forward DFS from challenges to find mandatory_w2 @@ -135,6 +157,7 @@ impl<'a> WitnessSplitter<'a> { // Step 7: Assign free builders greedily while respecting dependencies // Rule: if any dependency is in w2, the builder must also be in w2 // (because w1 is solved before w2) + // A free builder for public input witnesses goes in w1. let mut w1_set = mandatory_w1; let mut w2_set = mandatory_w2; @@ -149,6 +172,15 @@ impl<'a> WitnessSplitter<'a> { let witness_count = DependencyInfo::extract_writes(&self.witness_builders[idx]).len(); + // If free builder writes a public witness, add it to w1_set. + if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] { + if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + w1_set.insert(idx); + w1_witness_count += witness_count; + continue; + } + } + if must_be_w2 { w2_set.insert(idx); w2_witness_count += witness_count; @@ -165,9 +197,53 @@ impl<'a> WitnessSplitter<'a> { let mut w1_indices: Vec = w1_set.into_iter().collect(); let mut w2_indices: Vec = w2_set.into_iter().collect(); - w1_indices.sort_unstable(); + w1_indices = self.rearrange_w1(w1_indices, &acir_public_inputs_indices_set)?; w2_indices.sort_unstable(); - (w1_indices, w2_indices) + Ok((w1_indices, w2_indices)) + } + + /// Rearranges w1 builder indices into a canonical order: + /// 1. Constant builder (index 0) first, to preserve R1CS index 0 invariant + /// 2. Public input builders next, grouped together + /// 3. All other w1 builders last, sorted by index + fn rearrange_w1( + &self, + w1_indices: Vec, + acir_public_inputs_indices_set: &HashSet, + ) -> Result, SplitError> { + let mut public_input_builder_indices = Vec::new(); + let mut rest_indices = Vec::new(); + + // Sanity Check: Make sure all public inputs and WITNESS_ONE_IDX are in + // w1_indices. + // Convert to HashSet for O(1) lookups since we're checking many times + let w1_indices_set = w1_indices.iter().copied().collect::>(); + for &idx in acir_public_inputs_indices_set.iter() { + if !w1_indices_set.contains(&(idx as usize)) { + return Err(SplitError); + } + } + + // Separate into: 0, public inputs, and rest + for builder_idx in w1_indices { + if builder_idx == 0 { + continue; // Will add 0 first + } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] { + if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + public_input_builder_indices.push(builder_idx); + continue; + } + } + rest_indices.push(builder_idx); + } + + rest_indices.sort_unstable(); + + // Reorder: 0 first, then public inputs, then rest + let mut new_w1_indices = vec![0]; + new_w1_indices.extend(public_input_builder_indices); + new_w1_indices.extend(rest_indices); + Ok(new_w1_indices) } } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index d212c9bc4..601321a70 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -6,15 +6,15 @@ use { digits::DigitalDecompositionWitnesses, ram::SpiceWitnesses, scheduling::{ - LayerScheduler, LayeredWitnessBuilders, SplitWitnessBuilders, WitnessIndexRemapper, - WitnessSplitter, + LayerScheduler, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders, + WitnessIndexRemapper, WitnessSplitter, }, ConstantOrR1CSWitness, }, FieldElement, R1CS, }, serde::{Deserialize, Serialize}, - std::num::NonZeroU32, + std::{collections::HashSet, num::NonZeroU32}, }; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -174,9 +174,10 @@ impl WitnessBuilder { witness_builders: &[WitnessBuilder], r1cs: R1CS, witness_map: Vec>, - ) -> (SplitWitnessBuilders, R1CS, Vec>, usize) { + acir_public_inputs_indices_set: HashSet, + ) -> Result<(SplitWitnessBuilders, R1CS, Vec>, usize), SplitError> { if witness_builders.is_empty() { - return ( + return Ok(( SplitWitnessBuilders { w1_layers: LayeredWitnessBuilders { layers: Vec::new() }, w2_layers: LayeredWitnessBuilders { layers: Vec::new() }, @@ -185,12 +186,12 @@ impl WitnessBuilder { r1cs, witness_map, 0, - ); + )); } // Step 1: Analyze dependencies and split into w1/w2 let splitter = WitnessSplitter::new(witness_builders); - let (w1_indices, w2_indices) = splitter.split_builders(); + let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set)?; // Step 2: Extract w1 and w2 builders in order let w1_builders: Vec = w1_indices @@ -244,7 +245,7 @@ impl WitnessBuilder { .filter(|b| matches!(b, WitnessBuilder::Challenge(_))) .count(); - ( + Ok(( SplitWitnessBuilders { w1_layers, w2_layers, @@ -253,6 +254,6 @@ impl WitnessBuilder { remapped_r1cs, remapped_witness_map, num_challenges, - ) + )) } } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 031a29dde..bb89b7905 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -6,7 +6,7 @@ use { nargo::foreign_calls::DefaultForeignCallBuilder, noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, - provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover}, + provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, std::path::Path, tracing::instrument, }; @@ -56,6 +56,7 @@ impl Prove for Prover { read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?; let acir_witness_idx_to_value_map = self.generate_witness(input_map)?; + let acir_public_inputs = self.program.functions[0].public_inputs().indices(); // Set up transcript let io: IOPattern = self.whir_for_witness.create_io_pattern(); @@ -112,14 +113,30 @@ impl Prove for Prover { self.r1cs .test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) .context("While verifying R1CS instance")?; + + // Gather public inputs from witness + let num_public_inputs = acir_public_inputs.len(); + let public_inputs = if num_public_inputs == 0 { + PublicInputs::new() + } else { + PublicInputs::from_vec( + witness[1..=num_public_inputs] + .iter() + .map(|w| w.ok_or_else(|| anyhow::anyhow!("Missing public input witness"))) + .collect::>>()?, + ) + }; drop(witness); let whir_r1cs_proof = self .whir_for_witness - .prove(merlin, self.r1cs, commitments) + .prove(merlin, self.r1cs, commitments, &public_inputs) .context("While proving R1CS instance")?; - Ok(NoirProof { whir_r1cs_proof }) + Ok(NoirProof { + public_inputs, + whir_r1cs_proof, + }) } } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index bb2b8a64b..06fe0fdd6 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -14,7 +14,7 @@ use { zk_utils::{create_masked_polynomial, generate_random_multilinear_polynomial}, HALF, }, - FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, + FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField}, @@ -54,6 +54,7 @@ pub trait WhirR1CSProver { merlin: ProverState, r1cs: R1CS, commitments: Vec, + public_inputs: &PublicInputs, ) -> Result; } @@ -121,6 +122,7 @@ impl WhirR1CSProver for WhirR1CSScheme { mut merlin: ProverState, r1cs: R1CS, mut commitments: Vec, + public_inputs: &PublicInputs, ) -> Result { ensure!(!commitments.is_empty(), "Need at least one commitment"); @@ -153,22 +155,41 @@ impl WhirR1CSProver for WhirR1CSScheme { // Compute weights from R1CS matrices let alphas = calculate_external_row_of_r1cs_matrices(alpha, r1cs); + let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); if is_single { // Single commitment path let commitment = commitments.into_iter().next().unwrap(); let alphas: [Vec; 3] = alphas.try_into().unwrap(); - let (statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( + let (mut statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( self.m, &commitment.commitment_to_witness, - commitment.masked_polynomial, - commitment.random_polynomial, + &commitment.masked_polynomial, + &commitment.random_polynomial, &alphas, ); merlin.hint::<(Vec, Vec)>(&(f_sums, g_sums))?; + let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { + // If there are no public inputs, the hint is unused by the verifier and can be + // assigned an arbitrary value. + let public_f_sum = FieldElement::zero(); + let public_g_sum = FieldElement::zero(); + (public_f_sum, public_g_sum) + } else { + update_statement_with_public_weights( + &mut statement, + &commitment.commitment_to_witness, + &commitment.masked_polynomial, + &commitment.random_polynomial, + public_weight, + ) + }; + + merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; + run_zk_whir_pcs_prover( commitment.commitment_to_witness, statement, @@ -193,12 +214,12 @@ impl WhirR1CSProver for WhirR1CSScheme { let alphas_1: [Vec; 3] = alphas_1.try_into().unwrap(); let alphas_2: [Vec; 3] = alphas_2.try_into().unwrap(); - let (statement_1, f_sums_1, g_sums_1) = + let (mut statement_1, f_sums_1, g_sums_1) = create_combined_statement_over_two_polynomials::<3>( self.m, &c1.commitment_to_witness, - c1.masked_polynomial, - c1.random_polynomial, + &c1.masked_polynomial, + &c1.random_polynomial, &alphas_1, ); drop(alphas_1); @@ -207,8 +228,8 @@ impl WhirR1CSProver for WhirR1CSScheme { create_combined_statement_over_two_polynomials::<3>( self.m, &c2.commitment_to_witness, - c2.masked_polynomial, - c2.random_polynomial, + &c2.masked_polynomial, + &c2.random_polynomial, &alphas_2, ); drop(alphas_2); @@ -216,6 +237,22 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums_1, g_sums_1))?; merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; + let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { + let public_f_sum = FieldElement::zero(); + let public_g_sum = FieldElement::zero(); + (public_f_sum, public_g_sum) + } else { + update_statement_with_public_weights( + &mut statement_1, + &c1.commitment_to_witness, + &c1.masked_polynomial, + &c1.random_polynomial, + public_weight, + ) + }; + + merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; + run_zk_whir_pcs_batch_prover( &[c1.commitment_to_witness, c2.commitment_to_witness], &[statement_1, statement_2], @@ -511,8 +548,8 @@ pub fn run_zk_sumcheck_prover( create_combined_statement_over_two_polynomials::<1>( blinding_polynomial_variables + 1, &commitment_to_blinding_polynomial, - blindings_mask_polynomial, - blindings_blind_polynomial, + &blindings_mask_polynomial, + &blindings_blind_polynomial, &[expand_powers(alpha.as_slice())], ); @@ -545,8 +582,8 @@ fn expand_powers(values: &[FieldElement]) -> Vec { fn create_combined_statement_over_two_polynomials( cfg_nv: usize, witness: &Witness, - f_polynomial: EvaluationsList, - g_polynomial: EvaluationsList, + f_polynomial: &EvaluationsList, + g_polynomial: &EvaluationsList, alphas: &[Vec; N], ) -> ( Statement, @@ -576,8 +613,8 @@ fn create_combined_statement_over_two_polynomials( w_full.resize(final_len, FieldElement::zero()); let weight = Weights::linear(EvaluationsList::new(w_full)); - let f = weight.weighted_sum(&f_polynomial); - let g = weight.weighted_sum(&g_polynomial); + let f = weight.weighted_sum(f_polynomial); + let g = weight.weighted_sum(g_polynomial); statement.add_constraint(weight, f + witness.batching_randomness * g); f_sums.push(f); @@ -628,3 +665,49 @@ pub fn run_zk_whir_pcs_batch_prover( (randomness, deferred) } + +fn update_statement_with_public_weights( + statement: &mut Statement, + witness: &Witness, + f_polynomial: &EvaluationsList, + g_polynomial: &EvaluationsList, + public_weights: Weights, +) -> (FieldElement, FieldElement) { + let f = public_weights.weighted_sum(f_polynomial); + let g = public_weights.weighted_sum(g_polynomial); + statement.add_constraint_in_front(public_weights, f + witness.batching_randomness * g); + (f, g) +} + +fn get_public_weights( + public_inputs: &PublicInputs, + merlin: &mut ProverState, + m: usize, +) -> Weights { + // Add hash to transcript + let public_inputs_hash = public_inputs.hash(); + let _ = merlin.add_scalars(&[public_inputs_hash]); + + // Get random point x + let mut x_buf = [FieldElement::zero()]; + merlin + .fill_challenge_scalars(&mut x_buf) + .expect("Failed to get challenge from Merlin"); + let x = x_buf[0]; + + let domain_size = 1 << m; + let mut public_weights = vec![FieldElement::zero(); domain_size]; + + // Set public weights for public inputs [1,x,x^2,x^3...x^n-1,0,0,0...0] + let mut current_pow = FieldElement::one(); + for (idx, _) in public_inputs.0.iter().enumerate() { + public_weights[idx] = current_pow; + current_pow = current_pow * x; + } + + Weights::geometric( + x, + public_inputs.0.len(), + EvaluationsList::new(public_weights), + ) +} diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 58e9346a7..b89a87847 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -10,7 +10,7 @@ use { witness::{NoirWitnessGenerator, WitnessBuilder}, NoirProofScheme, WhirR1CSScheme, }, - std::{fs::File, path::Path}, + std::{collections::HashSet, fs::File, path::Path}, tracing::{info, instrument}, }; @@ -61,9 +61,19 @@ impl NoirProofSchemeBuilder for NoirProofScheme { r1cs.c.num_entries() ); + // Extract ACIR public input indices set + let acir_public_inputs_indices_set: HashSet = + main.public_inputs().indices().iter().cloned().collect(); + + let has_public_inputs = !acir_public_inputs_indices_set.is_empty(); // Split witness builders and remap indices for sound challenge generation let (split_witness_builders, remapped_r1cs, remapped_witness_map, num_challenges) = - WitnessBuilder::split_and_prepare_layers(&witness_builders, r1cs, witness_map); + WitnessBuilder::split_and_prepare_layers( + &witness_builders, + r1cs, + witness_map, + acir_public_inputs_indices_set, + )?; info!( "Witness split: w1 size = {}, w2 size = {}", split_witness_builders.w1_size, @@ -82,6 +92,7 @@ impl NoirProofSchemeBuilder for NoirProofScheme { &remapped_r1cs, split_witness_builders.w1_size, num_challenges, + has_public_inputs, ); Ok(Self { diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs b/provekit/r1cs-compiler/src/whir_r1cs.rs index 4604be4a7..385b11c98 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -17,13 +17,23 @@ const MIN_WHIR_NUM_VARIABLES: usize = 12; const MIN_SUMCHECK_NUM_VARIABLES: usize = 1; pub trait WhirR1CSSchemeBuilder { - fn new_for_r1cs(r1cs: &R1CS, w1_size: usize, num_challenges: usize) -> Self; + fn new_for_r1cs( + r1cs: &R1CS, + w1_size: usize, + num_challenges: usize, + has_public_inputs: bool, + ) -> Self; fn new_whir_config_for_size(num_variables: usize, batch_size: usize) -> WhirConfig; } impl WhirR1CSSchemeBuilder for WhirR1CSScheme { - fn new_for_r1cs(r1cs: &R1CS, w1_size: usize, num_challenges: usize) -> Self { + fn new_for_r1cs( + r1cs: &R1CS, + w1_size: usize, + num_challenges: usize, + has_public_inputs: bool, + ) -> Self { let total_witnesses = r1cs.num_witnesses(); assert!( w1_size <= total_witnesses, @@ -49,6 +59,7 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { next_power_of_two(4 * m_0) + 1, 2, ), + has_public_inputs, } } diff --git a/provekit/verifier/src/lib.rs b/provekit/verifier/src/lib.rs index abdb369fc..d1ec351a4 100644 --- a/provekit/verifier/src/lib.rs +++ b/provekit/verifier/src/lib.rs @@ -17,7 +17,7 @@ impl Verify for Verifier { self.whir_for_witness .take() .unwrap() - .verify(&proof.whir_r1cs_proof)?; + .verify(&proof.whir_r1cs_proof, &proof.public_inputs)?; Ok(()) } diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 8a668fa03..068fd1d0a 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -4,7 +4,7 @@ use { provekit_common::{ skyscraper::SkyscraperSponge, utils::sumcheck::{calculate_eq, eval_cubic_poly}, - FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, + FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitDeserialize, UnitToField}, @@ -29,13 +29,13 @@ pub struct DataFromSumcheckVerifier { } pub trait WhirR1CSVerifier { - fn verify(&self, proof: &WhirR1CSProof) -> Result<()>; + fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs) -> Result<()>; } impl WhirR1CSVerifier for WhirR1CSScheme { #[instrument(skip_all)] #[allow(unused)] - fn verify(&self, proof: &WhirR1CSProof) -> Result<()> { + fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs) -> Result<()> { let io = self.create_io_pattern(); let mut arthur = io.to_verifier_state(&proof.transcript); @@ -56,6 +56,19 @@ impl WhirR1CSVerifier for WhirR1CSScheme { run_sumcheck_verifier(&mut arthur, self.m_0, &self.whir_for_hiding_spartan) .context("while verifying sumcheck")?; + // Verify public inputs hash + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + // Read hints and verify WHIR proof let (az_at_alpha, bz_at_alpha, cz_at_alpha) = if let Some(parsed_commitment_2) = parsed_commitment_2 { @@ -68,7 +81,7 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_sums_2: ([FieldElement; 3], [FieldElement; 3]) = (sums_2.0.try_into().unwrap(), sums_2.1.try_into().unwrap()); - let statement_1 = prepare_statement_for_witness_verifier::<3>( + let mut statement_1 = prepare_statement_for_witness_verifier::<3>( self.m, &parsed_commitment_1, &whir_sums_1, @@ -79,6 +92,19 @@ impl WhirR1CSVerifier for WhirR1CSScheme { &whir_sums_2, ); + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; + + if !public_inputs.is_empty() { + update_statement_for_witness_verifier( + self.m, + &mut statement_1, + &parsed_commitment_1, + whir_public_weights_query_answer, + ); + } + run_whir_pcs_batch_verifier( &mut arthur, &self.whir_witness, @@ -98,12 +124,24 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_sums: ([FieldElement; 3], [FieldElement; 3]) = (sums.0.try_into().unwrap(), sums.1.try_into().unwrap()); - let statement = prepare_statement_for_witness_verifier::<3>( + let mut statement = prepare_statement_for_witness_verifier::<3>( self.m, &parsed_commitment_1, &whir_sums, ); + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; + if !public_inputs.is_empty() { + update_statement_for_witness_verifier( + self.m, + &mut statement, + &parsed_commitment_1, + whir_public_weights_query_answer, + ); + } + run_whir_pcs_verifier( &mut arthur, &parsed_commitment_1, @@ -147,6 +185,20 @@ fn prepare_statement_for_witness_verifier( statement_verifier } +fn update_statement_for_witness_verifier( + m: usize, + statement_verifier: &mut Statement, + parsed_commitment: &ParsedCommitment, + whir_public_weights_query_answer: (FieldElement, FieldElement), +) { + let (public_f_sum, public_g_sum) = whir_public_weights_query_answer; + let public_weight = Weights::linear(EvaluationsList::new(vec![FieldElement::zero(); 1 << m])); + statement_verifier.add_constraint_in_front( + public_weight, + public_f_sum + public_g_sum * parsed_commitment.batching_randomness, + ); +} + #[instrument(skip_all)] pub fn run_sumcheck_verifier( arthur: &mut VerifierState, diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index 2f95d5e6e..c543d20ed 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -39,6 +39,9 @@ type Circuit struct { WitnessClaimedEvaluations [][]frontend.Variable // [commitment_idx][eval_idx] WitnessBlindingEvaluations [][]frontend.Variable + // For public_f_sum and public_g_sum + PubWitnessEvaluations []frontend.Variable + // Batch mode only: batched polynomial for rounds 1+ WitnessMerkle Merkle @@ -46,9 +49,9 @@ type Circuit struct { MatrixB []MatrixCell MatrixC []MatrixCell - // Public Input - IO []byte - Transcript []uints.U8 `gnark:",public"` + IO []byte + Transcript []uints.U8 `gnark:",public"` + PublicInputs PublicInputs } func (circuit *Circuit) Define(api frontend.API) error { @@ -95,12 +98,56 @@ func (circuit *Circuit) Define(api frontend.API) error { return err } + // Read public inputs hash from transcript + publicInputsHashBuf := make([]frontend.Variable, 1) + if err := arthur.FillNextScalars(publicInputsHashBuf); err != nil { + return fmt.Errorf("failed to read public inputs hash: %w", err) + } + + expectedHash, err := hashPublicInputs(sc, circuit.PublicInputs) + if err != nil { + return fmt.Errorf("failed to compute public inputs hash: %w", err) + } + + api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) + + // Squeeze rand for public weights + publicWeightsChallenge := make([]frontend.Variable, 1) + if err := arthur.FillChallengeScalars(publicWeightsChallenge); err != nil { + return fmt.Errorf("failed to read public weights challenge: %w", err) + } + // WHIR verification var whirFoldingRandomness []frontend.Variable var az, bz, cz frontend.Variable if circuit.NumChallenges > 0 { - // Dual commitment mode: batch WHIR verification + // Only statement_1 (first commitment) gets extended with public weights, statement_2 remains unchanged + extendedLinearStatementEvalsBatch := make([][][]frontend.Variable, 2) + + if !circuit.PublicInputs.IsEmpty() { + extendedLinearStatementEvalsBatch[0] = extendLinearStatement( + circuit, + [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, + circuit.PubWitnessEvaluations, + ) + + extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[1], + circuit.WitnessBlindingEvaluations[1], + } + } else { + // Use original arrays as before, no public inputs + extendedLinearStatementEvalsBatch[0] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[0], + circuit.WitnessBlindingEvaluations[0], + } + extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[1], + circuit.WitnessBlindingEvaluations[1], + } + } + whirFoldingRandomness, err = RunZKWhirBatch( api, arthur, uapi, sc, circuit.WitnessFirstRounds, // firstRounds []Merkle @@ -109,12 +156,10 @@ func (circuit *Circuit) Define(api frontend.API) error { [][][]frontend.Variable{initialOODAnswers1, initialOODAnswers2}, // initialOODAnswers []frontend.Variable{rootHash1, rootHash2}, // rootHashes circuit.WitnessMerkle, // batchedMerkle - [][][]frontend.Variable{ // linearStatementEvals - {circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, - {circuit.WitnessClaimedEvaluations[1], circuit.WitnessBlindingEvaluations[1]}, - }, - circuit.WHIRParamsWitness, // whirParams - circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints + extendedLinearStatementEvalsBatch, // linearStatementEvals (extended for first commitment) + circuit.WHIRParamsWitness, // whirParams + circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints + circuit.PublicInputs, // publicInputs ) if err != nil { return err @@ -125,12 +170,14 @@ func (circuit *Circuit) Define(api frontend.API) error { bz = api.Add(circuit.WitnessClaimedEvaluations[0][1], circuit.WitnessClaimedEvaluations[1][1]) cz = api.Add(circuit.WitnessClaimedEvaluations[0][2], circuit.WitnessClaimedEvaluations[1][2]) } else { + extendedLinearStatementEvals := extendLinearStatement(circuit, [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, circuit.PubWitnessEvaluations) + // Single commitment mode whirFoldingRandomness, err = RunZKWhir( api, arthur, uapi, sc, circuit.WitnessMerkle, circuit.WitnessFirstRounds[0], circuit.WHIRParamsWitness, - [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, + extendedLinearStatementEvals, circuit.WitnessLinearStatementEvaluations, batchingRandomness1, initialOODQueries1, @@ -150,23 +197,74 @@ func (circuit *Circuit) Define(api frontend.API) error { x := api.Mul(api.Sub(api.Mul(az, bz), cz), calculateEQ(api, spartanSumcheckRand, tRand)) api.AssertIsEqual(spartanSumcheckLastValue, x) + offset := 0 + if !circuit.PublicInputs.IsEmpty() { + // can be generalized later on if we have more different kinds of statements + offset = 1 + } + if circuit.NumChallenges > 0 { // Batch mode - check 6 deferred values matrixExtensionEvals := evaluateR1CSMatrixExtensionBatch(api, circuit, spartanSumcheckRand, whirFoldingRandomness, circuit.W1Size) for i := 0; i < 6; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) } } else { + // Single mode - existing logic matrixExtensionEvals := evaluateR1CSMatrixExtension(api, circuit, spartanSumcheckRand, whirFoldingRandomness) for i := 0; i < 3; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) } } + // Geometric weights for public inputs + if !circuit.PublicInputs.IsEmpty() { + publicWeightEval := computePublicWeightEvaluation( + api, circuit.PublicInputs, whirFoldingRandomness, + circuit.WHIRParamsWitness.MVParamsNumberOfVariables, publicWeightsChallenge[0], + ) + + api.AssertIsEqual(publicWeightEval, circuit.WitnessLinearStatementEvaluations[0]) + } + return nil } +func computePublicWeightEvaluation( + api frontend.API, + publicInputs PublicInputs, + foldingRandomness []frontend.Variable, + m int, // domain size = 2^m + x frontend.Variable, +) frontend.Variable { + // Build public weight vector: [1, x, x^2, ..., x^(n-1), 0, 0, ..., 0] where n = len(publicInputs.Values) and total length = 2^m + domainSize := 1 << m + publicWeights := make([]frontend.Variable, domainSize) + + for i := 0; i < domainSize; i++ { + publicWeights[i] = 0 + } + + // Set public weights: [1, x, x^2, ..., x^(n-1), 0, 0, ..., 0] + currentPower := frontend.Variable(1) + for i := 0; i < len(publicInputs.Values); i++ { + publicWeights[i] = currentPower + currentPower = api.Mul(currentPower, x) + } + + // TODO : Replace it with geometric_till algo + // Evaluate the multilinear extension of publicWeights at foldingRandomness + // Formula: f(r) = Σ_{i=0}^{2^m-1} f[i] * eq_i(r) + // where eq_i(r) is the i-th Lagrange basis polynomial + eqPolys := calculateEQOverBooleanHypercube(api, foldingRandomness) + result := frontend.Variable(0) + for i := 0; i < len(publicWeights); i++ { + result = api.Add(result, api.Mul(publicWeights[i], eqPolys[i])) + } + return result +} + func verifyCircuit( deferred []Fp256, cfg Config, @@ -175,9 +273,11 @@ func verifyCircuit( vk *groth16.VerifyingKey, claimedEvaluations ClaimedEvaluations, claimedEvaluations2 ClaimedEvaluations, + publicWeightsClaimedEvaluation [2]Fp256, internedR1CS R1CS, interner Interner, buildOps common.BuildOps, + publicInputs PublicInputs, ) error { transcriptT := make([]uints.U8, cfg.TranscriptLen) contTranscript := make([]uints.U8, cfg.TranscriptLen) @@ -189,9 +289,18 @@ func verifyCircuit( // Determine witness linear statement evals size based on mode var witnessLinearStatementEvalsSize int if cfg.NumChallenges > 0 { - witnessLinearStatementEvalsSize = 6 // 3 per commitment in batch mode + if !cfg.PublicInputs.IsEmpty() { + // 3 per commitment in batch mode + 1 public_input (geometric statement as a subset of linear statement) + witnessLinearStatementEvalsSize = 7 + } else { + witnessLinearStatementEvalsSize = 6 + } } else { - witnessLinearStatementEvalsSize = 3 + if !cfg.PublicInputs.IsEmpty() { + witnessLinearStatementEvalsSize = 4 + } else { + witnessLinearStatementEvalsSize = 3 + } } witnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) @@ -199,6 +308,9 @@ func verifyCircuit( contWitnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) contHidingSpartanLinearStatementEvaluations := make([]frontend.Variable, 1) + if len(deferred) < 1+witnessLinearStatementEvalsSize { + return fmt.Errorf("deferred array too short: expected at least %d elements, got %d", 1+witnessLinearStatementEvalsSize, len(deferred)) + } hidingSpartanLinearStatementEvaluations[0] = typeConverters.LimbsToBigIntMod(deferred[0].Limbs) for i := 0; i < witnessLinearStatementEvalsSize; i++ { witnessLinearStatementEvaluations[i] = typeConverters.LimbsToBigIntMod(deferred[1+i].Limbs) @@ -258,6 +370,10 @@ func verifyCircuit( fSums2, gSums2 = parseClaimedEvaluations(claimedEvaluations2, true) } + // Parse public weights claimed evaluation + fSumPublicWeights, gSumPublicWeights := parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, true) + pubWitnessEvaluations := []frontend.Variable{fSumPublicWeights, gSumPublicWeights} + // Build witness slices conditionally var witnessClaimedEvals, witnessBlindingEvals [][]frontend.Variable if cfg.NumChallenges > 0 { @@ -268,6 +384,11 @@ func verifyCircuit( witnessBlindingEvals = [][]frontend.Variable{gSums} } + // Empty container while circuit creation + publicInputsContainer := PublicInputs{ + Values: make([]frontend.Variable, len(publicInputs.Values)), + } + circuit := Circuit{ IO: []byte(cfg.IOPattern), Transcript: contTranscript, @@ -276,6 +397,7 @@ func verifyCircuit( LogANumTerms: cfg.LogANumTerms, WitnessClaimedEvaluations: witnessClaimedEvals, WitnessBlindingEvaluations: witnessBlindingEvals, + PubWitnessEvaluations: pubWitnessEvaluations, WitnessLinearStatementEvaluations: contWitnessLinearStatementEvaluations, HidingSpartanLinearStatementEvaluations: contHidingSpartanLinearStatementEvaluations, HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, true), @@ -289,6 +411,7 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + PublicInputs: publicInputsContainer, } ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) @@ -377,13 +500,19 @@ func verifyCircuit( witnessBlindingEvals = [][]frontend.Variable{gSums} } + fSumPublicWeights, gSumPublicWeights = parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, false) + pubWitnessEvaluations = []frontend.Variable{fSumPublicWeights, gSumPublicWeights} + assignment := Circuit{ IO: []byte(cfg.IOPattern), Transcript: transcriptT, LogNumConstraints: cfg.LogNumConstraints, + LogNumVariables: cfg.LogNumVariables, + LogANumTerms: cfg.LogANumTerms, WitnessClaimedEvaluations: witnessClaimedEvals, WitnessBlindingEvaluations: witnessBlindingEvals, WitnessLinearStatementEvaluations: witnessLinearStatementEvaluations, + PubWitnessEvaluations: pubWitnessEvaluations, HidingSpartanLinearStatementEvaluations: hidingSpartanLinearStatementEvaluations, HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, false), HidingSpartanMerkle: newMerkle(hints.spartanHidingHint.roundHints, false), @@ -396,10 +525,15 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + PublicInputs: publicInputs, } witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) - publicWitness, _ := witness.Public() + publicWitness, err := witness.Public() + if err != nil { + log.Printf("Failed witness, Public(): %v", err) + return err + } opts := []backend.ProverOption{ backend.WithSolverOptions(solver.WithHints(utilities.IndexOf)), @@ -436,3 +570,42 @@ func witnessFirstRounds(hints Hints, isContainer bool) []Merkle { } return result } + +func parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation [2]Fp256, isContainer bool) (frontend.Variable, frontend.Variable) { + var fSumPublicWeights, gSumPublicWeights frontend.Variable + + if !isContainer { + fSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[0].Limbs) + gSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[1].Limbs) + } + + return fSumPublicWeights, gSumPublicWeights +} + +func extendLinearStatement( + circuit *Circuit, + linearStatementEvaluations [][]frontend.Variable, + pubWitnessEvaluations []frontend.Variable, +) [][]frontend.Variable { + var extendedLinearStatementEvals [][]frontend.Variable + + if !circuit.PublicInputs.IsEmpty() { + // Extend the statement equivalent array by prepending the public constraint (public constraint is added in starting at prover side) + extendedLinearStatementEvals = make([][]frontend.Variable, 2) + + // f_sums: [public_f_sum, f_sums[0], f_sums[1]... ] + extendedLinearStatementEvals[0] = make([]frontend.Variable, len(linearStatementEvaluations[0])+1) + extendedLinearStatementEvals[0][0] = pubWitnessEvaluations[0] + copy(extendedLinearStatementEvals[0][1:], linearStatementEvaluations[0]) + + // g_sums: [public_g_sum, g_sums[0], g_sums[1]... ] + extendedLinearStatementEvals[1] = make([]frontend.Variable, len(linearStatementEvaluations[1])+1) + extendedLinearStatementEvals[1][0] = pubWitnessEvaluations[1] + copy(extendedLinearStatementEvals[1][1:], linearStatementEvaluations[1]) + } else { + // No public inputs, use original arrays + extendedLinearStatementEvals = linearStatementEvaluations + } + + return extendedLinearStatementEvals +} diff --git a/recursive-verifier/app/circuit/common.go b/recursive-verifier/app/circuit/common.go index 16eea7c1f..707286738 100644 --- a/recursive-verifier/app/circuit/common.go +++ b/recursive-verifier/app/circuit/common.go @@ -30,6 +30,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v var deferred []Fp256 var claimedEvaluations ClaimedEvaluations var claimedEvaluations2 ClaimedEvaluations + var publicWeightsEvaluations [2]Fp256 for _, op := range io.Ops { switch op.Kind { @@ -128,6 +129,16 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v if err != nil { return fmt.Errorf("failed to deserialize claimed_evaluations_2: %w", err) } + + case "public_weights_evaluations": + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(config.Transcript[start:end]), + &publicWeightsEvaluations, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize public_weights_evaluations: %w", err) + } } if err != nil { @@ -204,7 +215,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v WitnessRoundHints: witnessRoundHints, } - err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, claimedEvaluations2, r1cs, interner, buildOps) + err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, claimedEvaluations2, publicWeightsEvaluations, r1cs, interner, buildOps, config.PublicInputs) if err != nil { return fmt.Errorf("verification failed: %w", err) } diff --git a/recursive-verifier/app/circuit/mtUtilities.go b/recursive-verifier/app/circuit/mtUtilities.go index 264727b38..4dc74eee2 100644 --- a/recursive-verifier/app/circuit/mtUtilities.go +++ b/recursive-verifier/app/circuit/mtUtilities.go @@ -112,3 +112,24 @@ func rlcBatchedLeaves(api frontend.API, leaves [][]frontend.Variable, foldSize i } return collapsed } + +// hashPublicInputs computes the hash of public inputs as field elements sequentially +func hashPublicInputs(sc *skyscraper.Skyscraper, publicInputs PublicInputs) (frontend.Variable, error) { + + if len(publicInputs.Values) == 0 { + return frontend.Variable(0), nil + } + + // For single element, we hash it with a zero + if len(publicInputs.Values) == 1 { + return sc.CompressV2(publicInputs.Values[0], frontend.Variable(0)), nil + } + + // For 2+ elements, use standard approach + hash := sc.CompressV2(publicInputs.Values[0], publicInputs.Values[1]) + for i := 2; i < len(publicInputs.Values); i++ { + hash = sc.CompressV2(hash, publicInputs.Values[i]) + } + + return hash, nil +} diff --git a/recursive-verifier/app/circuit/types.go b/recursive-verifier/app/circuit/types.go index 420b43e0e..f6db49d00 100644 --- a/recursive-verifier/app/circuit/types.go +++ b/recursive-verifier/app/circuit/types.go @@ -1,6 +1,8 @@ package circuit import ( + "reilabs/whir-verifier-circuit/app/utilities" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) @@ -89,18 +91,19 @@ type ProofObject struct { } type Config struct { - WHIRConfigWitness WHIRConfig `json:"whir_config_witness"` - WHIRConfigHidingSpartan WHIRConfig `json:"whir_config_hiding_spartan"` - LogNumConstraints int `json:"log_num_constraints"` - LogNumVariables int `json:"log_num_variables"` - LogANumTerms int `json:"log_a_num_terms"` - IOPattern string `json:"io_pattern"` - Transcript []byte `json:"transcript"` - TranscriptLen int `json:"transcript_len"` - WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` - BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` - NumChallenges int `json:"num_challenges"` - W1Size int `json:"w1_size"` + WHIRConfigWitness WHIRConfig `json:"whir_config_witness"` + WHIRConfigHidingSpartan WHIRConfig `json:"whir_config_hiding_spartan"` + LogNumConstraints int `json:"log_num_constraints"` + LogNumVariables int `json:"log_num_variables"` + LogANumTerms int `json:"log_a_num_terms"` + IOPattern string `json:"io_pattern"` + Transcript []byte `json:"transcript"` + TranscriptLen int `json:"transcript_len"` + WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` + BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` + NumChallenges int `json:"num_challenges"` + W1Size int `json:"w1_size"` + PublicInputs PublicInputs `json:"public_inputs"` } // Update Hints to support batch mode @@ -139,3 +142,20 @@ type DualClaimedEvaluations struct { First ClaimedEvaluations Second ClaimedEvaluations } + +type PublicInputs struct { + Values []frontend.Variable +} + +func (p *PublicInputs) UnmarshalJSON(data []byte) error { + values, err := utilities.UnmarshalPublicInputs(data) + if err != nil { + return err + } + p.Values = values + return nil +} + +func (p *PublicInputs) IsEmpty() bool { + return len(p.Values) == 0 +} diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index 0e153280e..3bdea8f19 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -238,6 +238,7 @@ func RunZKWhirBatch( // Common parameters whirParams WHIRParams, linearStatementValuesAtPoints []frontend.Variable, + publicInputs PublicInputs, ) (totalFoldingRandomness []frontend.Variable, err error) { numPolynomials := len(firstRounds) if numPolynomials == 0 { @@ -255,7 +256,15 @@ func RunZKWhirBatch( for i := 0; i < numPolynomials; i++ { numOOD += len(initialOODQueries[i]) } - numStatementConstraints := numPolynomials * 3 // 3 per commitment (Az, Bz, Cz) + + numStatementConstraints := 0 + + // w1 has 4 (pub, Az, Bz, Cz) constraints, w2 and remaining have 3 (Az, Bz, Cz) constraints + if !publicInputs.IsEmpty() { + numStatementConstraints = 4 + 3*(numPolynomials-1) + } else { + numStatementConstraints = numPolynomials * 3 + } numConstraints := numOOD + numStatementConstraints // Step 3: Read N×M evaluation matrix from transcript diff --git a/recursive-verifier/app/circuit/whir_utilities.go b/recursive-verifier/app/circuit/whir_utilities.go index 666dde44f..ac75a9c34 100644 --- a/recursive-verifier/app/circuit/whir_utilities.go +++ b/recursive-verifier/app/circuit/whir_utilities.go @@ -140,6 +140,7 @@ func computeWPoly( value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[j], utilities.EqPolyOutside(api, utilities.ExpandFromUnivariate(api, initialData.InitialOODQueries[j], numberVars), totalFoldingRandomness))) } + // Values are directly used as all linearStatements are deffered and hints were given. Checking of hints will be done later on. for j, linearStatementValueAtPoint := range linearStatementValuesAtPoints { value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[len(initialData.InitialOODQueries)+j], linearStatementValueAtPoint)) } diff --git a/recursive-verifier/app/utilities/utilities.go b/recursive-verifier/app/utilities/utilities.go index 04a8e9d2e..691af27dd 100644 --- a/recursive-verifier/app/utilities/utilities.go +++ b/recursive-verifier/app/utilities/utilities.go @@ -1,6 +1,8 @@ package utilities import ( + "encoding/hex" + "encoding/json" "fmt" "math/big" "reilabs/whir-verifier-circuit/app/typeConverters" @@ -210,3 +212,49 @@ func DotProduct(api frontend.API, a []frontend.Variable, b []frontend.Variable) } return acc } + +// ParseHexFieldElement parses a hex string representing a FieldElement (little-endian) +// and converts it to a big.Int. The hex string should be 64 characters (32 bytes). +func ParseHexFieldElement(hexStr string) (*big.Int, error) { + if len(hexStr) >= 2 && hexStr[0:2] == "0x" { + hexStr = hexStr[2:] + } + + bytes, err := hex.DecodeString(hexStr) + if err != nil { + return nil, fmt.Errorf("invalid hex string: %w", err) + } + + reversed := make([]byte, len(bytes)) + for i, b := range bytes { + reversed[len(bytes)-1-i] = b + } + + result := new(big.Int) + result.SetBytes(reversed) + + modulus := new(big.Int) + modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + result.Mod(result, modulus) + + return result, nil +} + +// UnmarshalPublicInputs parses a JSON array of hex-encoded FieldElement strings +// and returns them as frontend.Variable slice. +func UnmarshalPublicInputs(data []byte) ([]frontend.Variable, error) { + var arr []string + if err := json.Unmarshal(data, &arr); err != nil { + return nil, err + } + + values := make([]frontend.Variable, len(arr)) + for i, hexStr := range arr { + value, err := ParseHexFieldElement(hexStr) + if err != nil { + return nil, fmt.Errorf("failed to parse public input at index %d: %w", i, err) + } + values[i] = value + } + return values, nil +} diff --git a/tooling/cli/src/cmd/generate_gnark_inputs.rs b/tooling/cli/src/cmd/generate_gnark_inputs.rs index 883f52b5e..0e3cc6d0a 100644 --- a/tooling/cli/src/cmd/generate_gnark_inputs.rs +++ b/tooling/cli/src/cmd/generate_gnark_inputs.rs @@ -62,6 +62,7 @@ impl Command for Args { prover.whir_for_witness.a_num_terms, prover.whir_for_witness.num_challenges, prover.whir_for_witness.w1_size, + &proof.public_inputs, &self.params_for_recursive_verifier, ); diff --git a/tooling/provekit-gnark/src/gnark_config.rs b/tooling/provekit-gnark/src/gnark_config.rs index 015cc68f8..41439a75c 100644 --- a/tooling/provekit-gnark/src/gnark_config.rs +++ b/tooling/provekit-gnark/src/gnark_config.rs @@ -1,6 +1,6 @@ use { ark_poly::EvaluationDomain, - provekit_common::{IOPattern, WhirConfig}, + provekit_common::{IOPattern, PublicInputs, WhirConfig}, serde::{Deserialize, Serialize}, std::{fs::File, io::Write}, tracing::instrument, @@ -29,6 +29,8 @@ pub struct GnarkConfig { pub num_challenges: usize, /// size of w1 pub w1_size: usize, + /// public inputs + pub public_inputs: PublicInputs, } #[derive(Debug, Serialize, Deserialize)] @@ -114,6 +116,7 @@ pub fn gnark_parameters( a_num_terms: usize, num_challenges: usize, w1_size: usize, + public_inputs: &PublicInputs, ) -> GnarkConfig { GnarkConfig { whir_config_witness: WHIRConfigGnark::new(whir_params_witness), @@ -126,6 +129,7 @@ pub fn gnark_parameters( transcript_len: transcript.to_vec().len(), num_challenges, w1_size, + public_inputs: public_inputs.clone(), } } @@ -141,6 +145,7 @@ pub fn write_gnark_parameters_to_file( a_num_terms: usize, num_challenges: usize, w1_size: usize, + public_inputs: &PublicInputs, file_path: &str, ) { let gnark_config = gnark_parameters( @@ -153,6 +158,7 @@ pub fn write_gnark_parameters_to_file( a_num_terms, num_challenges, w1_size, + public_inputs, ); let mut file_params = File::create(file_path).unwrap(); file_params diff --git a/tooling/verifier-server/src/services/verification.rs b/tooling/verifier-server/src/services/verification.rs index f8abd1385..398b282c4 100644 --- a/tooling/verifier-server/src/services/verification.rs +++ b/tooling/verifier-server/src/services/verification.rs @@ -97,6 +97,7 @@ impl VerificationService { whir_scheme.a_num_terms, whir_scheme.num_challenges, whir_scheme.w1_size, + &proof.public_inputs, gnark_params_path, );