From af8ad76688efaee5b8fab94d8d45522cf697e44d Mon Sep 17 00:00:00 2001 From: ashpect Date: Mon, 8 Dec 2025 12:50:14 +0530 Subject: [PATCH 01/19] feat: pubwit struct --- provekit/common/Cargo.toml | 1 + provekit/common/src/lib.rs | 1 + provekit/common/src/noir_proof_scheme.rs | 3 +- provekit/common/src/utils/mod.rs | 1 + provekit/common/src/utils/serde_ark_vec.rs | 87 ++++++++++++++++++++++ provekit/common/src/witness/mod.rs | 52 ++++++++++++- provekit/prover/src/lib.rs | 6 +- 7 files changed, 146 insertions(+), 5 deletions(-) create mode 100644 provekit/common/src/utils/serde_ark_vec.rs diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index 92faae9c6..34b43012e 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -38,6 +38,7 @@ ruint.workspace = true serde.workspace = true serde_json.workspace = true tracing.workspace = true +sha2.workspace = true zerocopy.workspace = true zeroize.workspace = true zstd.workspace = true diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 02fe99a17..2b7cca179 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -22,6 +22,7 @@ pub use { r1cs::R1CS, verifier::Verifier, whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme}, + witness::PublicInputs, }; #[cfg(test)] diff --git a/provekit/common/src/noir_proof_scheme.rs b/provekit/common/src/noir_proof_scheme.rs index b8c8ff937..f7e40fd22 100644 --- a/provekit/common/src/noir_proof_scheme.rs +++ b/provekit/common/src/noir_proof_scheme.rs @@ -2,7 +2,7 @@ use { crate::{ whir_r1cs::{WhirR1CSProof, WhirR1CSScheme}, witness::{NoirWitnessGenerator, SplitWitnessBuilders}, - NoirElement, R1CS, + NoirElement, R1CS, PublicInputs, }, acir::circuit::Program, serde::{Deserialize, Serialize}, @@ -20,6 +20,7 @@ pub struct NoirProofScheme { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct NoirProof { + pub public_inputs: PublicInputs, pub whir_r1cs_proof: WhirR1CSProof, } diff --git a/provekit/common/src/utils/mod.rs b/provekit/common/src/utils/mod.rs index a5f6aa5b7..43943c566 100644 --- a/provekit/common/src/utils/mod.rs +++ b/provekit/common/src/utils/mod.rs @@ -2,6 +2,7 @@ mod print_abi; pub mod serde_ark; pub mod serde_ark_option; +pub mod serde_ark_vec; pub mod serde_hex; pub mod serde_jsonify; pub mod sumcheck; diff --git a/provekit/common/src/utils/serde_ark_vec.rs b/provekit/common/src/utils/serde_ark_vec.rs new file mode 100644 index 000000000..b26c309ed --- /dev/null +++ b/provekit/common/src/utils/serde_ark_vec.rs @@ -0,0 +1,87 @@ +use { + crate::FieldElement, + ark_serialize::{CanonicalDeserialize, CanonicalSerialize}, + serde::{ + de::{Error as _, SeqAccess, Visitor}, + ser::{Error as _, SerializeSeq}, + Deserializer, Serializer, + }, + std::fmt, +}; + +pub fn serialize(vec: &Vec, serializer: S) -> Result +where + S: Serializer, +{ + let is_human_readable = serializer.is_human_readable(); + let mut seq = serializer.serialize_seq(Some(vec.len()))?; + for element in vec { + let mut buf = Vec::with_capacity(element.compressed_size()); + element + .serialize_compressed(&mut buf) + .map_err(|e| S::Error::custom(format!("Failed to serialize: {e}")))?; + + // Write bytes + if is_human_readable { + // ark_serialize doesn't have human-readable serialization. And Serde + // doesn't have good defaults for [u8]. So we implement hexadecimal + // serialization. + let hex = hex::encode(buf); + seq.serialize_element(&hex)?; + } else { + seq.serialize_element(&buf)?; + } + } + seq.end() +} + +pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + struct VecVisitor { + is_human_readable: bool, + } + + impl<'de> Visitor<'de> for VecVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of field elements") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut vec = Vec::new(); + if self.is_human_readable { + while let Some(hex) = seq.next_element::()? { + let bytes = hex::decode(hex) + .map_err(|e| A::Error::custom(format!("invalid hex: {e}")))?; + let mut reader = &*bytes; + let element = FieldElement::deserialize_compressed(&mut reader) + .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?; + if !reader.is_empty() { + return Err(A::Error::custom("while deserializing: trailing bytes")); + } + vec.push(element); + } + } else { + while let Some(bytes) = seq.next_element::>()? { + let mut reader = &*bytes; + let element = FieldElement::deserialize_compressed(&mut reader) + .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?; + if !reader.is_empty() { + return Err(A::Error::custom("while deserializing: trailing bytes")); + } + vec.push(element); + } + } + Ok(vec) + } + } + + let is_human_readable = deserializer.is_human_readable(); + deserializer.deserialize_seq(VecVisitor { is_human_readable }) +} diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 4361a5dce..549ff34ba 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -7,9 +7,10 @@ mod witness_generator; mod witness_io_pattern; use { - crate::{utils::serde_ark, FieldElement}, - ark_ff::One, + crate::{utils::{serde_ark, serde_ark_vec}, FieldElement}, + ark_ff::{BigInt, One, PrimeField}, serde::{Deserialize, Serialize}, + sha2::{Digest, Sha256}, }; pub use { binops::{BINOP_ATOMIC_BITS, BINOP_BITS, NUM_DIGITS}, @@ -40,3 +41,50 @@ impl ConstantOrR1CSWitness { } } } + + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); + +impl PublicInputs { + /// Creates a new `PublicInputs` with a constant 1 field element at the + /// start. + pub fn new() -> Self { + Self(vec![FieldElement::one()]) + } + + /// Creates a new `PublicInputs` from a vector, adding a constant 1 field + /// element at the start. To emulate the constant 1 witness in the R1CS + /// instance. + pub fn from_vec(mut vec: Vec) -> Self { + vec.insert(0, FieldElement::one()); + Self(vec) + } + + pub fn len(&self) -> usize { + self.0.len() + } + + /// Hashes the public input values using SHA-256 and converts the result to + /// a FieldElement. + pub fn hash(&self) -> FieldElement { + let mut hasher = Sha256::new(); + + // Hash all public values from witness + for value in self.0.iter() { + let bigint = value.into_bigint(); + for limb in bigint.0.iter() { + hasher.update(&limb.to_le_bytes()); + } + } + + let result = hasher.finalize(); + + let limbs = result + .chunks_exact(8) + .map(|s| u64::from_le_bytes(s.try_into().unwrap())) + .collect::>(); + + FieldElement::new(BigInt::new(limbs.try_into().unwrap())) + } +} diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 031a29dde..0fdefd4e2 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -6,7 +6,7 @@ use { nargo::foreign_calls::DefaultForeignCallBuilder, noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, - provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover}, + provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, std::path::Path, tracing::instrument, }; @@ -119,7 +119,9 @@ impl Prove for Prover { .prove(merlin, self.r1cs, commitments) .context("While proving R1CS instance")?; - Ok(NoirProof { whir_r1cs_proof }) + let public_inputs = PublicInputs::new(); + + Ok(NoirProof { public_inputs, whir_r1cs_proof }) } } From 3d5754a88fb2d602fbcb0c7880e7911071bfcec7 Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 10 Dec 2025 01:08:43 +0530 Subject: [PATCH 02/19] feat: put pub wit at starting --- .../common/src/witness/scheduling/splitter.rs | 59 ++++++++++++++++++- .../common/src/witness/witness_builder.rs | 4 +- .../r1cs-compiler/src/noir_proof_scheme.rs | 7 ++- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index 57a44367b..eb990dd2e 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -26,7 +26,7 @@ impl<'a> WitnessSplitter<'a> { /// (post-challenge). /// /// Returns (w1_builder_indices, w2_builder_indices) - pub fn split_builders(&self) -> (Vec, Vec) { + pub fn split_builders(&self, acir_public_inputs_indices_set: HashSet) -> (Vec, Vec) { let builder_count = self.witness_builders.len(); // Step 1: Find all Challenge builders @@ -40,7 +40,11 @@ impl<'a> WitnessSplitter<'a> { .collect(); if challenge_builders.is_empty() { - return ((0..builder_count).collect(), Vec::new()); + let w1_indices = self.rearrange_w1( + (0..builder_count).collect(), + &acir_public_inputs_indices_set, + ); + return (w1_indices, Vec::new()); } // Step 2: Forward DFS from challenges to find mandatory_w2 @@ -135,6 +139,7 @@ impl<'a> WitnessSplitter<'a> { // Step 7: Assign free builders greedily while respecting dependencies // Rule: if any dependency is in w2, the builder must also be in w2 // (because w1 is solved before w2) + // with the exception of public builders writing public witnesses) let mut w1_set = mandatory_w1; let mut w2_set = mandatory_w2; @@ -149,6 +154,15 @@ impl<'a> WitnessSplitter<'a> { let witness_count = DependencyInfo::extract_writes(&self.witness_builders[idx]).len(); + // If free builder writes a public witness, add it to w1_set. + if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] { + if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + w1_set.insert(idx); + w1_witness_count += witness_count; + continue; + } + } + if must_be_w2 { w2_set.insert(idx); w2_witness_count += witness_count; @@ -170,4 +184,45 @@ impl<'a> WitnessSplitter<'a> { (w1_indices, w2_indices) } + + /// Rearranges w1 indices: constant builder (0) first, then public inputs, + /// then rest. + fn rearrange_w1( + &self, + w1_indices: Vec, + acir_public_inputs_indices_set: &HashSet, + ) -> Vec { + let mut public_input_builder_indices = Vec::new(); + let mut rest_indices = Vec::new(); + + // Sanity Check: Make sure all public inputs and WITNESS_ONE_IDX are in + // w1_indices. + for &idx in acir_public_inputs_indices_set.iter() { + if !w1_indices.contains(&(idx as usize)) { + panic!("Public input {} is not in w1_indices", idx); + } + } + + // Separate into: 0, public inputs, and rest + for builder_idx in w1_indices { + if builder_idx == 0 { + continue; // Will add 0 first + } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] { + if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + public_input_builder_indices.push(builder_idx); + continue; + } + } + rest_indices.push(builder_idx); + } + + public_input_builder_indices.sort_unstable(); + rest_indices.sort_unstable(); + + // Reorder: 0 first, then public inputs, then rest + let mut new_w1_indices = vec![0]; + new_w1_indices.extend(public_input_builder_indices); + new_w1_indices.extend(rest_indices); + new_w1_indices + } } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index d212c9bc4..ee248153c 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -14,6 +14,7 @@ use { FieldElement, R1CS, }, serde::{Deserialize, Serialize}, + std::collections::HashSet, std::num::NonZeroU32, }; @@ -174,6 +175,7 @@ impl WitnessBuilder { witness_builders: &[WitnessBuilder], r1cs: R1CS, witness_map: Vec>, + acir_public_inputs_indices_set: HashSet, ) -> (SplitWitnessBuilders, R1CS, Vec>, usize) { if witness_builders.is_empty() { return ( @@ -190,7 +192,7 @@ impl WitnessBuilder { // Step 1: Analyze dependencies and split into w1/w2 let splitter = WitnessSplitter::new(witness_builders); - let (w1_indices, w2_indices) = splitter.split_builders(); + let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set); // Step 2: Extract w1 and w2 builders in order let w1_builders: Vec = w1_indices diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 58e9346a7..6c9ca3223 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -12,6 +12,7 @@ use { }, std::{fs::File, path::Path}, tracing::{info, instrument}, + std::collections::HashSet, }; pub trait NoirProofSchemeBuilder { @@ -61,9 +62,13 @@ impl NoirProofSchemeBuilder for NoirProofScheme { r1cs.c.num_entries() ); + // Extract ACIR public input indices set + let acir_public_inputs_indices_set: HashSet = + main.public_inputs().indices().iter().cloned().collect(); + // Split witness builders and remap indices for sound challenge generation let (split_witness_builders, remapped_r1cs, remapped_witness_map, num_challenges) = - WitnessBuilder::split_and_prepare_layers(&witness_builders, r1cs, witness_map); + WitnessBuilder::split_and_prepare_layers(&witness_builders, r1cs, witness_map, acir_public_inputs_indices_set); info!( "Witness split: w1 size = {}, w2 size = {}", split_witness_builders.w1_size, From 8e3a8639bf300b385ac449a41c0fe5dcbc47a3fc Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 10 Dec 2025 10:34:59 +0530 Subject: [PATCH 03/19] feat: prover addon --- provekit/common/src/utils/sumcheck.rs | 10 +++ provekit/common/src/whir_r1cs.rs | 4 ++ provekit/prover/src/lib.rs | 43 +++++++++++-- provekit/prover/src/r1cs.rs | 13 ++++ provekit/prover/src/whir_r1cs.rs | 90 +++++++++++++++++++++++---- 5 files changed, 143 insertions(+), 17 deletions(-) diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs index 7e1c5a245..6baef51d7 100644 --- a/provekit/common/src/utils/sumcheck.rs +++ b/provekit/common/src/utils/sumcheck.rs @@ -114,6 +114,10 @@ pub trait SumcheckIOPattern { fn add_rand(self, num_rand: usize) -> Self; fn add_zk_sumcheck_polynomials(self, num_vars: usize) -> Self; + + /// Prover sends the hash of the public inputs + /// Verifier sends randomness to construct weights + fn add_public_inputs(self) -> Self; } impl SumcheckIOPattern for IOPattern @@ -136,6 +140,12 @@ where self } + fn add_public_inputs(mut self) -> Self { + self = self.add_scalars(1, "Public Inputs Hash"); + self = self.challenge_scalars(1, "Public Weights Vector Random"); + self + } + fn add_rand(self, num_rand: usize) -> Self { self.challenge_scalars(num_rand, "rand") } diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 302bfd79d..9163c967c 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -50,6 +50,8 @@ impl WhirR1CSScheme { .add_whir_proof(&self.whir_for_hiding_spartan) .hint("claimed_evaluations_1") .hint("claimed_evaluations_2") + .add_public_inputs() + .hint("public_weights_evaluations") .add_whir_batch_proof(&self.whir_witness, num_witnesses, num_constraints_total); } else { io = io @@ -59,6 +61,8 @@ impl WhirR1CSScheme { .add_zk_sumcheck_polynomials(self.m_0) .add_whir_proof(&self.whir_for_hiding_spartan) .hint("claimed_evaluations") + .add_public_inputs() + .hint("public_weights_evaluations") .add_whir_proof(&self.whir_witness); } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 0fdefd4e2..7d69475b1 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -7,6 +7,7 @@ use { noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, + std::collections::{HashMap, HashSet}, std::path::Path, tracing::instrument, }; @@ -57,6 +58,15 @@ impl Prove for Prover { let acir_witness_idx_to_value_map = self.generate_witness(input_map)?; + let acir_public_inputs = self.program.functions[0].public_inputs().indices(); + let acir_public_inputs_set: HashSet = acir_public_inputs.iter().cloned().collect(); + let mut acir_to_r1cs_public_map = HashMap::new(); + + println!("DEBUG_ASH: acir_witness_idx_to_value_map: {:?}", acir_witness_idx_to_value_map); + println!("DEBUG_ASH: acir_public_inputs: {:?}", acir_public_inputs); + println!("DEBUG_ASH: acir_public_inputs_set: {:?}", acir_public_inputs_set); + println!("DEBUG_ASH: acir_to_r1cs_public_map: {:?}", acir_to_r1cs_public_map); + // Set up transcript let io: IOPattern = self.whir_for_witness.create_io_pattern(); let mut merlin = io.to_prover_state(); @@ -70,13 +80,19 @@ impl Prove for Prover { self.split_witness_builders.w1_layers, &acir_witness_idx_to_value_map, &mut merlin, + &acir_public_inputs_set, + &mut acir_to_r1cs_public_map, ); + println!("DEBUG_ASH: acir_to_r1cs_public_map after w1: {:?}", acir_to_r1cs_public_map); + + let w1 = witness[..self.whir_for_witness.w1_size] .iter() .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w1 are missing"))) .collect::>>()?; + println!("DEBUG_ASH: w1: {:?}", w1); let commitment_1 = self .whir_for_witness .commit(&mut merlin, &self.r1cs, w1, true) @@ -90,7 +106,11 @@ impl Prove for Prover { self.split_witness_builders.w2_layers, &acir_witness_idx_to_value_map, &mut merlin, - ); + &acir_public_inputs_set, + &mut acir_to_r1cs_public_map, + ); // DEBUG_ASH : if w2 didn't have pub witness, no need honestly for this + + println!("DEBUG_ASH: acir_to_r1cs_public_map after w2: {:?}", acir_to_r1cs_public_map); let w2 = witness[self.whir_for_witness.w1_size..] .iter() @@ -112,15 +132,30 @@ impl Prove for Prover { self.r1cs .test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) .context("While verifying R1CS instance")?; + + // Gather public inputs from witness + let public_indices = acir_to_r1cs_public_map + .values() + .map(|&x| x) + .collect::>(); + + let public_inputs = PublicInputs::from_vec( + public_indices + .iter() + .map(|&i| { + witness[i].ok_or_else(|| anyhow::anyhow!("Missing public input witness at index {i}")) + }) + .collect::>>()?, + ); + drop(witness); let whir_r1cs_proof = self .whir_for_witness - .prove(merlin, self.r1cs, commitments) + .prove(merlin, self.r1cs, commitments, &public_inputs) .context("While proving R1CS instance")?; - let public_inputs = PublicInputs::new(); - + println!("DEBUG_ASH: public_inputs: {:?}", public_inputs); Ok(NoirProof { public_inputs, whir_r1cs_proof }) } } diff --git a/provekit/prover/src/r1cs.rs b/provekit/prover/src/r1cs.rs index 18b3ffca0..cb752c26c 100644 --- a/provekit/prover/src/r1cs.rs +++ b/provekit/prover/src/r1cs.rs @@ -10,6 +10,7 @@ use { FieldElement, NoirElement, R1CS, }, spongefish::ProverState, + std::collections::{HashMap, HashSet}, tracing::instrument, }; @@ -20,6 +21,8 @@ pub trait R1CSSolver { plan: LayeredWitnessBuilders, acir_map: &WitnessMap, transcript: &mut ProverState, + acir_public_inputs_set: &HashSet, + acir_to_r1cs_public_map: &mut HashMap, ); #[cfg(test)] @@ -54,12 +57,22 @@ impl R1CSSolver for R1CS { plan: LayeredWitnessBuilders, acir_map: &WitnessMap, transcript: &mut ProverState, + acir_public_inputs_set: &HashSet, + acir_to_r1cs_public_map: &mut HashMap, ) { for layer in &plan.layers { match layer.typ { LayerType::Other => { // Execute regular operations for builder in &layer.witness_builders { + + if let WitnessBuilder::Acir(r1cs_witness_idx, acir_witness_idx) = builder { + if acir_public_inputs_set.contains(&(*acir_witness_idx as u32)) { + acir_to_r1cs_public_map + .insert(*acir_witness_idx as u32, *r1cs_witness_idx); + } + } + builder.solve(&acir_map, witness, transcript); } } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index bb2b8a64b..baf4238f0 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -3,6 +3,7 @@ use { ark_ff::UniformRand, ark_std::{One, Zero}, provekit_common::{ + PublicInputs, skyscraper::{SkyscraperMerkleConfig, SkyscraperSponge}, utils::{ pad_to_power_of_two, @@ -54,6 +55,7 @@ pub trait WhirR1CSProver { merlin: ProverState, r1cs: R1CS, commitments: Vec, + public_inputs: &PublicInputs, ) -> Result; } @@ -121,6 +123,7 @@ impl WhirR1CSProver for WhirR1CSScheme { mut merlin: ProverState, r1cs: R1CS, mut commitments: Vec, + public_inputs: &PublicInputs, ) -> Result { ensure!(!commitments.is_empty(), "Need at least one commitment"); @@ -141,6 +144,8 @@ impl WhirR1CSProver for WhirR1CSScheme { w }; + println!("DEBUG_ASH: full_witness: {:?}", full_witness); + // First round: ZK sumcheck to reduce R1CS to weighted evaluation let alpha = run_zk_sumcheck_prover( &r1cs, @@ -159,16 +164,28 @@ impl WhirR1CSProver for WhirR1CSScheme { let commitment = commitments.into_iter().next().unwrap(); let alphas: [Vec; 3] = alphas.try_into().unwrap(); - let (statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( + let (mut statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( self.m, &commitment.commitment_to_witness, - commitment.masked_polynomial, - commitment.random_polynomial, + &commitment.masked_polynomial, + &commitment.random_polynomial, &alphas, ); merlin.hint::<(Vec, Vec)>(&(f_sums, g_sums))?; + // VERIFY the size given by self.m + let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); + let (public_f_sum, public_g_sum) = update_statement_with_public_weights( + &mut statement, + &commitment.commitment_to_witness, + &commitment.masked_polynomial, + &commitment.random_polynomial, + public_weight, + ); + + let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); + run_zk_whir_pcs_prover( commitment.commitment_to_witness, statement, @@ -197,8 +214,8 @@ impl WhirR1CSProver for WhirR1CSScheme { create_combined_statement_over_two_polynomials::<3>( self.m, &c1.commitment_to_witness, - c1.masked_polynomial, - c1.random_polynomial, + &c1.masked_polynomial, + &c1.random_polynomial, &alphas_1, ); drop(alphas_1); @@ -207,8 +224,8 @@ impl WhirR1CSProver for WhirR1CSScheme { create_combined_statement_over_two_polynomials::<3>( self.m, &c2.commitment_to_witness, - c2.masked_polynomial, - c2.random_polynomial, + &c2.masked_polynomial, + &c2.random_polynomial, &alphas_2, ); drop(alphas_2); @@ -511,8 +528,8 @@ pub fn run_zk_sumcheck_prover( create_combined_statement_over_two_polynomials::<1>( blinding_polynomial_variables + 1, &commitment_to_blinding_polynomial, - blindings_mask_polynomial, - blindings_blind_polynomial, + &blindings_mask_polynomial, + &blindings_blind_polynomial, &[expand_powers(alpha.as_slice())], ); @@ -545,8 +562,8 @@ fn expand_powers(values: &[FieldElement]) -> Vec { fn create_combined_statement_over_two_polynomials( cfg_nv: usize, witness: &Witness, - f_polynomial: EvaluationsList, - g_polynomial: EvaluationsList, + f_polynomial: &EvaluationsList, + g_polynomial: &EvaluationsList, alphas: &[Vec; N], ) -> ( Statement, @@ -576,8 +593,8 @@ fn create_combined_statement_over_two_polynomials( w_full.resize(final_len, FieldElement::zero()); let weight = Weights::linear(EvaluationsList::new(w_full)); - let f = weight.weighted_sum(&f_polynomial); - let g = weight.weighted_sum(&g_polynomial); + let f = weight.weighted_sum(f_polynomial); + let g = weight.weighted_sum(g_polynomial); statement.add_constraint(weight, f + witness.batching_randomness * g); f_sums.push(f); @@ -628,3 +645,50 @@ pub fn run_zk_whir_pcs_batch_prover( (randomness, deferred) } + + +fn update_statement_with_public_weights( + statement: &mut Statement, + witness: &Witness, + f_polynomial: &EvaluationsList, + g_polynomial: &EvaluationsList, + public_weights: Weights, +) -> (FieldElement, FieldElement) { + let f = public_weights.weighted_sum(f_polynomial); + let g = public_weights.weighted_sum(g_polynomial); + statement.add_constraint_in_front(public_weights, f + witness.batching_randomness * g); + (f, g) +} + +fn get_public_weights( + public_inputs: &PublicInputs, + merlin: &mut ProverState, + m: usize, +) -> Weights { + // Add hash to transcript + let public_inputs_hash = public_inputs.hash(); + let _ = merlin.add_scalars(&[public_inputs_hash]); + + // Get random point x + let mut x_buf = [FieldElement::zero()]; + merlin + .fill_challenge_scalars(&mut x_buf) + .expect("Failed to get challenge from Merlin"); + let x = x_buf[0]; + + let domain_size = 1 << m; + let mut public_weights = vec![FieldElement::zero(); domain_size]; + + // Set public weights for public inputs [1,x,x^2,x^3...x^n-1,0,0,0...0] + let mut current_pow = FieldElement::one(); + for (idx, _) in public_inputs.0.iter().enumerate() { + public_weights[idx] = current_pow; + current_pow = current_pow * x; + } + + Weights::geometric( + x, + public_inputs.0.len(), + EvaluationsList::new(public_weights), + ) +} From 35214168f9969593dc38fc185a17be8d5084050a Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 10 Dec 2025 10:35:35 +0530 Subject: [PATCH 04/19] feat: verifier updates --- provekit/verifier/src/lib.rs | 2 +- provekit/verifier/src/whir_r1cs.rs | 65 ++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/provekit/verifier/src/lib.rs b/provekit/verifier/src/lib.rs index abdb369fc..d1ec351a4 100644 --- a/provekit/verifier/src/lib.rs +++ b/provekit/verifier/src/lib.rs @@ -17,7 +17,7 @@ impl Verify for Verifier { self.whir_for_witness .take() .unwrap() - .verify(&proof.whir_r1cs_proof)?; + .verify(&proof.whir_r1cs_proof, &proof.public_inputs)?; Ok(()) } diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 8a668fa03..3ae9dc93f 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -2,6 +2,7 @@ use { anyhow::{ensure, Context, Result}, ark_std::{One, Zero}, provekit_common::{ + PublicInputs, skyscraper::SkyscraperSponge, utils::sumcheck::{calculate_eq, eval_cubic_poly}, FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, @@ -29,13 +30,13 @@ pub struct DataFromSumcheckVerifier { } pub trait WhirR1CSVerifier { - fn verify(&self, proof: &WhirR1CSProof) -> Result<()>; + fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs) -> Result<()>; } impl WhirR1CSVerifier for WhirR1CSScheme { #[instrument(skip_all)] #[allow(unused)] - fn verify(&self, proof: &WhirR1CSProof) -> Result<()> { + fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs) -> Result<()> { let io = self.create_io_pattern(); let mut arthur = io.to_verifier_state(&proof.transcript); @@ -68,7 +69,7 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_sums_2: ([FieldElement; 3], [FieldElement; 3]) = (sums_2.0.try_into().unwrap(), sums_2.1.try_into().unwrap()); - let statement_1 = prepare_statement_for_witness_verifier::<3>( + let mut statement_1 = prepare_statement_for_witness_verifier::<3>( self.m, &parsed_commitment_1, &whir_sums_1, @@ -79,6 +80,27 @@ impl WhirR1CSVerifier for WhirR1CSScheme { &whir_sums_2, ); + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + + let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); + update_statement_for_witness_verifier( + self.m, + &mut statement_1, + &parsed_commitment_1, + whir_pub_weights_query_answer, + ); + run_whir_pcs_batch_verifier( &mut arthur, &self.whir_witness, @@ -98,12 +120,33 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_sums: ([FieldElement; 3], [FieldElement; 3]) = (sums.0.try_into().unwrap(), sums.1.try_into().unwrap()); - let statement = prepare_statement_for_witness_verifier::<3>( + let mut statement = prepare_statement_for_witness_verifier::<3>( self.m, &parsed_commitment_1, &whir_sums, ); + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + + let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); + update_statement_for_witness_verifier( + self.m, + &mut statement, + &parsed_commitment_1, + whir_pub_weights_query_answer, + ); + run_whir_pcs_verifier( &mut arthur, &parsed_commitment_1, @@ -147,6 +190,20 @@ fn prepare_statement_for_witness_verifier( statement_verifier } +fn update_statement_for_witness_verifier( + m: usize, + statement_verifier: &mut Statement, + parsed_commitment: &ParsedCommitment, + whir_public_weights_query_answer: (FieldElement, FieldElement), +) { + let (public_f_sum, public_g_sum) = whir_public_weights_query_answer; + let public_weight = Weights::linear(EvaluationsList::new(vec![FieldElement::zero(); 1 << m])); + statement_verifier.add_constraint_in_front( + public_weight, + public_f_sum + public_g_sum * parsed_commitment.batching_randomness, + ); +} + #[instrument(skip_all)] pub fn run_sumcheck_verifier( arthur: &mut VerifierState, From fba715e0d198a72cc70f3d310b5112f8ff767026 Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 10 Dec 2025 12:28:55 +0530 Subject: [PATCH 05/19] feat: patch batch_prove --- provekit/prover/src/whir_r1cs.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index baf4238f0..122314f62 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -210,7 +210,7 @@ impl WhirR1CSProver for WhirR1CSScheme { let alphas_1: [Vec; 3] = alphas_1.try_into().unwrap(); let alphas_2: [Vec; 3] = alphas_2.try_into().unwrap(); - let (statement_1, f_sums_1, g_sums_1) = + let (mut statement_1, f_sums_1, g_sums_1) = create_combined_statement_over_two_polynomials::<3>( self.m, &c1.commitment_to_witness, @@ -233,6 +233,18 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums_1, g_sums_1))?; merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; + // VERIFY the size given by self.m + let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); + let (public_f_sum, public_g_sum) = update_statement_with_public_weights( + &mut statement_1, + &c1.commitment_to_witness, + &c1.masked_polynomial, + &c1.random_polynomial, + public_weight, + ); + + let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); + run_zk_whir_pcs_batch_prover( &[c1.commitment_to_witness, c2.commitment_to_witness], &[statement_1, statement_2], From 09b4a3f6b0a7dc0a20e13cd647a1d24001f3c99d Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 11 Dec 2025 16:20:51 +0530 Subject: [PATCH 06/19] fix: rearrange for w2 --- provekit/common/src/whir_r1cs.rs | 4 ++-- provekit/common/src/witness/scheduling/splitter.rs | 4 +++- provekit/common/src/witness/witness_builder.rs | 3 +++ provekit/prover/src/whir_r1cs.rs | 4 ++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 9163c967c..572fa5620 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -34,10 +34,10 @@ impl WhirR1CSScheme { if self.num_challenges > 0 { // Compute total constraints: OOD + statement // OOD: 2 witnesses × committment_ood_samples each - // Statement: 2 statements × 3 constraints each = 6 + // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, statement_2 has 3 = 3, total = 7 let num_witnesses = 2; let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; - let num_statement_constraints = 6; // 2 statements × 3 constraints + let num_statement_constraints = 7; // (3+1) + (3) let num_constraints_total = num_ood_constraints + num_statement_constraints; io = io diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index eb990dd2e..b7223dc2f 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -157,6 +157,7 @@ impl<'a> WitnessSplitter<'a> { // If free builder writes a public witness, add it to w1_set. if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] { if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + println!("DEBUG_ASH w2 exists: acir_idx: {:?}, idx: {:?}", acir_idx, idx); w1_set.insert(idx); w1_witness_count += witness_count; continue; @@ -179,7 +180,7 @@ impl<'a> WitnessSplitter<'a> { let mut w1_indices: Vec = w1_set.into_iter().collect(); let mut w2_indices: Vec = w2_set.into_iter().collect(); - w1_indices.sort_unstable(); + w1_indices = self.rearrange_w1(w1_indices, &acir_public_inputs_indices_set); w2_indices.sort_unstable(); (w1_indices, w2_indices) @@ -209,6 +210,7 @@ impl<'a> WitnessSplitter<'a> { continue; // Will add 0 first } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] { if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + println!("DEBUG_ASH: acir_idx: {:?}, builder_idx: {:?}", acir_idx, builder_idx); public_input_builder_indices.push(builder_idx); continue; } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index ee248153c..f42851f17 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -194,6 +194,9 @@ impl WitnessBuilder { let splitter = WitnessSplitter::new(witness_builders); let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set); + println!("Dx {:?}", w1_indices); + println!("DEBUG_ASH: w2_indices: {:?}", w2_indices); + // Step 2: Extract w1 and w2 builders in order let w1_builders: Vec = w1_indices .iter() diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 122314f62..a01a011d7 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -144,8 +144,6 @@ impl WhirR1CSProver for WhirR1CSScheme { w }; - println!("DEBUG_ASH: full_witness: {:?}", full_witness); - // First round: ZK sumcheck to reduce R1CS to weighted evaluation let alpha = run_zk_sumcheck_prover( &r1cs, @@ -210,6 +208,8 @@ impl WhirR1CSProver for WhirR1CSScheme { let alphas_1: [Vec; 3] = alphas_1.try_into().unwrap(); let alphas_2: [Vec; 3] = alphas_2.try_into().unwrap(); + println!("DEBUG_ASH: &c1.parsed_commitment: {:?}", &c1.padded_witness); + let (mut statement_1, f_sums_1, g_sums_1) = create_combined_statement_over_two_polynomials::<3>( self.m, From 5cb683ab22b5bd3ed2d6f74c85eefbcc2058167a Mon Sep 17 00:00:00 2001 From: ashpect Date: Fri, 12 Dec 2025 12:20:14 +0530 Subject: [PATCH 07/19] fix: reduce reduendancy in prove --- provekit/common/src/witness/mod.rs | 5 +++++ provekit/prover/src/lib.rs | 36 +++++------------------------- provekit/prover/src/r1cs.rs | 12 ---------- 3 files changed, 11 insertions(+), 42 deletions(-) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 549ff34ba..0b2ec1dea 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -61,6 +61,11 @@ impl PublicInputs { Self(vec) } + /// Assuming the given vector already has a constant 1 field element at the start. + pub fn from_vec_with_constant_one(vec: Vec) -> Self { + Self(vec) + } + pub fn len(&self) -> usize { self.0.len() } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 7d69475b1..2bd0639f9 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -57,15 +57,7 @@ impl Prove for Prover { read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?; let acir_witness_idx_to_value_map = self.generate_witness(input_map)?; - let acir_public_inputs = self.program.functions[0].public_inputs().indices(); - let acir_public_inputs_set: HashSet = acir_public_inputs.iter().cloned().collect(); - let mut acir_to_r1cs_public_map = HashMap::new(); - - println!("DEBUG_ASH: acir_witness_idx_to_value_map: {:?}", acir_witness_idx_to_value_map); - println!("DEBUG_ASH: acir_public_inputs: {:?}", acir_public_inputs); - println!("DEBUG_ASH: acir_public_inputs_set: {:?}", acir_public_inputs_set); - println!("DEBUG_ASH: acir_to_r1cs_public_map: {:?}", acir_to_r1cs_public_map); // Set up transcript let io: IOPattern = self.whir_for_witness.create_io_pattern(); @@ -80,13 +72,8 @@ impl Prove for Prover { self.split_witness_builders.w1_layers, &acir_witness_idx_to_value_map, &mut merlin, - &acir_public_inputs_set, - &mut acir_to_r1cs_public_map, ); - println!("DEBUG_ASH: acir_to_r1cs_public_map after w1: {:?}", acir_to_r1cs_public_map); - - let w1 = witness[..self.whir_for_witness.w1_size] .iter() .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w1 are missing"))) @@ -106,11 +93,7 @@ impl Prove for Prover { self.split_witness_builders.w2_layers, &acir_witness_idx_to_value_map, &mut merlin, - &acir_public_inputs_set, - &mut acir_to_r1cs_public_map, - ); // DEBUG_ASH : if w2 didn't have pub witness, no need honestly for this - - println!("DEBUG_ASH: acir_to_r1cs_public_map after w2: {:?}", acir_to_r1cs_public_map); + ); let w2 = witness[self.whir_for_witness.w1_size..] .iter() @@ -133,21 +116,14 @@ impl Prove for Prover { .test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) .context("While verifying R1CS instance")?; - // Gather public inputs from witness - let public_indices = acir_to_r1cs_public_map - .values() - .map(|&x| x) - .collect::>(); - - let public_inputs = PublicInputs::from_vec( - public_indices + // Gather public inputs from witness + let num_public_inputs = acir_public_inputs.len(); + let public_inputs = PublicInputs::from_vec_with_constant_one( + witness[0..=num_public_inputs] .iter() - .map(|&i| { - witness[i].ok_or_else(|| anyhow::anyhow!("Missing public input witness at index {i}")) - }) + .map(|w| w.ok_or_else(|| anyhow::anyhow!("Missing public input witness"))) .collect::>>()?, ); - drop(witness); let whir_r1cs_proof = self diff --git a/provekit/prover/src/r1cs.rs b/provekit/prover/src/r1cs.rs index cb752c26c..967c258d6 100644 --- a/provekit/prover/src/r1cs.rs +++ b/provekit/prover/src/r1cs.rs @@ -21,8 +21,6 @@ pub trait R1CSSolver { plan: LayeredWitnessBuilders, acir_map: &WitnessMap, transcript: &mut ProverState, - acir_public_inputs_set: &HashSet, - acir_to_r1cs_public_map: &mut HashMap, ); #[cfg(test)] @@ -57,22 +55,12 @@ impl R1CSSolver for R1CS { plan: LayeredWitnessBuilders, acir_map: &WitnessMap, transcript: &mut ProverState, - acir_public_inputs_set: &HashSet, - acir_to_r1cs_public_map: &mut HashMap, ) { for layer in &plan.layers { match layer.typ { LayerType::Other => { // Execute regular operations for builder in &layer.witness_builders { - - if let WitnessBuilder::Acir(r1cs_witness_idx, acir_witness_idx) = builder { - if acir_public_inputs_set.contains(&(*acir_witness_idx as u32)) { - acir_to_r1cs_public_map - .insert(*acir_witness_idx as u32, *r1cs_witness_idx); - } - } - builder.solve(&acir_map, witness, transcript); } } From 91859f88928e9ab3cfe3a320e623c9b824317f89 Mon Sep 17 00:00:00 2001 From: ashpect Date: Fri, 12 Dec 2025 13:12:55 +0530 Subject: [PATCH 08/19] chore: cleanup, remove logging and fmt --- provekit/common/src/noir_proof_scheme.rs | 2 +- provekit/common/src/whir_r1cs.rs | 3 ++- provekit/common/src/witness/mod.rs | 9 ++++++--- .../common/src/witness/scheduling/splitter.rs | 11 ++++++----- provekit/common/src/witness/witness_builder.rs | 6 +----- provekit/prover/src/lib.rs | 17 ++++++++++------- provekit/prover/src/r1cs.rs | 1 - provekit/prover/src/whir_r1cs.rs | 12 ++++-------- provekit/r1cs-compiler/src/noir_proof_scheme.rs | 12 ++++++++---- provekit/verifier/src/whir_r1cs.rs | 17 +++++++++-------- 10 files changed, 47 insertions(+), 43 deletions(-) diff --git a/provekit/common/src/noir_proof_scheme.rs b/provekit/common/src/noir_proof_scheme.rs index f7e40fd22..7552ab268 100644 --- a/provekit/common/src/noir_proof_scheme.rs +++ b/provekit/common/src/noir_proof_scheme.rs @@ -2,7 +2,7 @@ use { crate::{ whir_r1cs::{WhirR1CSProof, WhirR1CSScheme}, witness::{NoirWitnessGenerator, SplitWitnessBuilders}, - NoirElement, R1CS, PublicInputs, + NoirElement, PublicInputs, R1CS, }, acir::circuit::Program, serde::{Deserialize, Serialize}, diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 572fa5620..43aed2c29 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -34,7 +34,8 @@ impl WhirR1CSScheme { if self.num_challenges > 0 { // Compute total constraints: OOD + statement // OOD: 2 witnesses × committment_ood_samples each - // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, statement_2 has 3 = 3, total = 7 + // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, + // statement_2 has 3 = 3, total = 7 let num_witnesses = 2; let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; let num_statement_constraints = 7; // (3+1) + (3) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 0b2ec1dea..694817206 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -7,7 +7,10 @@ mod witness_generator; mod witness_io_pattern; use { - crate::{utils::{serde_ark, serde_ark_vec}, FieldElement}, + crate::{ + utils::{serde_ark, serde_ark_vec}, + FieldElement, + }, ark_ff::{BigInt, One, PrimeField}, serde::{Deserialize, Serialize}, sha2::{Digest, Sha256}, @@ -42,7 +45,6 @@ impl ConstantOrR1CSWitness { } } - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); @@ -61,7 +63,8 @@ impl PublicInputs { Self(vec) } - /// Assuming the given vector already has a constant 1 field element at the start. + /// Assuming the given vector already has a constant 1 field element at the + /// start. pub fn from_vec_with_constant_one(vec: Vec) -> Self { Self(vec) } diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index b7223dc2f..2a30e3513 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -26,7 +26,10 @@ impl<'a> WitnessSplitter<'a> { /// (post-challenge). /// /// Returns (w1_builder_indices, w2_builder_indices) - pub fn split_builders(&self, acir_public_inputs_indices_set: HashSet) -> (Vec, Vec) { + pub fn split_builders( + &self, + acir_public_inputs_indices_set: HashSet, + ) -> (Vec, Vec) { let builder_count = self.witness_builders.len(); // Step 1: Find all Challenge builders @@ -154,10 +157,9 @@ impl<'a> WitnessSplitter<'a> { let witness_count = DependencyInfo::extract_writes(&self.witness_builders[idx]).len(); - // If free builder writes a public witness, add it to w1_set. + // If free builder writes a public witness, add it to w1_set. if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] { if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { - println!("DEBUG_ASH w2 exists: acir_idx: {:?}, idx: {:?}", acir_idx, idx); w1_set.insert(idx); w1_witness_count += witness_count; continue; @@ -186,7 +188,7 @@ impl<'a> WitnessSplitter<'a> { (w1_indices, w2_indices) } - /// Rearranges w1 indices: constant builder (0) first, then public inputs, + /// Rearranges w1 indices: constant builder (0) first, then public inputs, /// then rest. fn rearrange_w1( &self, @@ -210,7 +212,6 @@ impl<'a> WitnessSplitter<'a> { continue; // Will add 0 first } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] { if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { - println!("DEBUG_ASH: acir_idx: {:?}, builder_idx: {:?}", acir_idx, builder_idx); public_input_builder_indices.push(builder_idx); continue; } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index f42851f17..9567eac7a 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -14,8 +14,7 @@ use { FieldElement, R1CS, }, serde::{Deserialize, Serialize}, - std::collections::HashSet, - std::num::NonZeroU32, + std::{collections::HashSet, num::NonZeroU32}, }; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -194,9 +193,6 @@ impl WitnessBuilder { let splitter = WitnessSplitter::new(witness_builders); let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set); - println!("Dx {:?}", w1_indices); - println!("DEBUG_ASH: w2_indices: {:?}", w2_indices); - // Step 2: Extract w1 and w2 builders in order let w1_builders: Vec = w1_indices .iter() diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 2bd0639f9..d7e8f846a 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -7,8 +7,10 @@ use { noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, - std::collections::{HashMap, HashSet}, - std::path::Path, + std::{ + collections::{HashMap, HashSet}, + path::Path, + }, tracing::instrument, }; @@ -79,7 +81,6 @@ impl Prove for Prover { .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w1 are missing"))) .collect::>>()?; - println!("DEBUG_ASH: w1: {:?}", w1); let commitment_1 = self .whir_for_witness .commit(&mut merlin, &self.r1cs, w1, true) @@ -93,7 +94,7 @@ impl Prove for Prover { self.split_witness_builders.w2_layers, &acir_witness_idx_to_value_map, &mut merlin, - ); + ); let w2 = witness[self.whir_for_witness.w1_size..] .iter() @@ -116,7 +117,7 @@ impl Prove for Prover { .test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) .context("While verifying R1CS instance")?; - // Gather public inputs from witness + // Gather public inputs from witness let num_public_inputs = acir_public_inputs.len(); let public_inputs = PublicInputs::from_vec_with_constant_one( witness[0..=num_public_inputs] @@ -131,8 +132,10 @@ impl Prove for Prover { .prove(merlin, self.r1cs, commitments, &public_inputs) .context("While proving R1CS instance")?; - println!("DEBUG_ASH: public_inputs: {:?}", public_inputs); - Ok(NoirProof { public_inputs, whir_r1cs_proof }) + Ok(NoirProof { + public_inputs, + whir_r1cs_proof, + }) } } diff --git a/provekit/prover/src/r1cs.rs b/provekit/prover/src/r1cs.rs index 967c258d6..18b3ffca0 100644 --- a/provekit/prover/src/r1cs.rs +++ b/provekit/prover/src/r1cs.rs @@ -10,7 +10,6 @@ use { FieldElement, NoirElement, R1CS, }, spongefish::ProverState, - std::collections::{HashMap, HashSet}, tracing::instrument, }; diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index a01a011d7..69bbed0e6 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -3,7 +3,6 @@ use { ark_ff::UniformRand, ark_std::{One, Zero}, provekit_common::{ - PublicInputs, skyscraper::{SkyscraperMerkleConfig, SkyscraperSponge}, utils::{ pad_to_power_of_two, @@ -15,7 +14,7 @@ use { zk_utils::{create_masked_polynomial, generate_random_multilinear_polynomial}, HALF, }, - FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, + FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField}, @@ -181,9 +180,9 @@ impl WhirR1CSProver for WhirR1CSScheme { &commitment.random_polynomial, public_weight, ); - + let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); - + run_zk_whir_pcs_prover( commitment.commitment_to_witness, statement, @@ -208,8 +207,6 @@ impl WhirR1CSProver for WhirR1CSScheme { let alphas_1: [Vec; 3] = alphas_1.try_into().unwrap(); let alphas_2: [Vec; 3] = alphas_2.try_into().unwrap(); - println!("DEBUG_ASH: &c1.parsed_commitment: {:?}", &c1.padded_witness); - let (mut statement_1, f_sums_1, g_sums_1) = create_combined_statement_over_two_polynomials::<3>( self.m, @@ -242,7 +239,7 @@ impl WhirR1CSProver for WhirR1CSScheme { &c1.random_polynomial, public_weight, ); - + let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); run_zk_whir_pcs_batch_prover( @@ -658,7 +655,6 @@ pub fn run_zk_whir_pcs_batch_prover( (randomness, deferred) } - fn update_statement_with_public_weights( statement: &mut Statement, witness: &Witness, diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 6c9ca3223..0cde2f1bb 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -10,9 +10,8 @@ use { witness::{NoirWitnessGenerator, WitnessBuilder}, NoirProofScheme, WhirR1CSScheme, }, - std::{fs::File, path::Path}, + std::{collections::HashSet, fs::File, path::Path}, tracing::{info, instrument}, - std::collections::HashSet, }; pub trait NoirProofSchemeBuilder { @@ -64,11 +63,16 @@ impl NoirProofSchemeBuilder for NoirProofScheme { // Extract ACIR public input indices set let acir_public_inputs_indices_set: HashSet = - main.public_inputs().indices().iter().cloned().collect(); + main.public_inputs().indices().iter().cloned().collect(); // Split witness builders and remap indices for sound challenge generation let (split_witness_builders, remapped_r1cs, remapped_witness_map, num_challenges) = - WitnessBuilder::split_and_prepare_layers(&witness_builders, r1cs, witness_map, acir_public_inputs_indices_set); + WitnessBuilder::split_and_prepare_layers( + &witness_builders, + r1cs, + witness_map, + acir_public_inputs_indices_set, + ); info!( "Witness split: w1 size = {}, w2 size = {}", split_witness_builders.w1_size, diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 3ae9dc93f..22296aae5 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -2,10 +2,9 @@ use { anyhow::{ensure, Context, Result}, ark_std::{One, Zero}, provekit_common::{ - PublicInputs, skyscraper::SkyscraperSponge, utils::sumcheck::{calculate_eq, eval_cubic_poly}, - FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, + FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitDeserialize, UnitToField}, @@ -89,11 +88,12 @@ impl WhirR1CSVerifier for WhirR1CSScheme { expected_public_inputs_hash, public_inputs_hash_buf[0] ); - + let mut public_weights_vector_random_buf = [FieldElement::zero()]; arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - - let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); + + let whir_pub_weights_query_answer: (FieldElement, FieldElement) = + arthur.hint().unwrap(); update_statement_for_witness_verifier( self.m, &mut statement_1, @@ -135,11 +135,12 @@ impl WhirR1CSVerifier for WhirR1CSScheme { expected_public_inputs_hash, public_inputs_hash_buf[0] ); - + let mut public_weights_vector_random_buf = [FieldElement::zero()]; arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - - let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); + + let whir_pub_weights_query_answer: (FieldElement, FieldElement) = + arthur.hint().unwrap(); update_statement_for_witness_verifier( self.m, &mut statement, From 6545a5985d124c4824e4945ee34deff0f9b4d13f Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 24 Dec 2025 19:20:01 +0530 Subject: [PATCH 09/19] feat: fix ordering and hashing --- provekit/common/src/witness/mod.rs | 18 ++------ .../common/src/witness/scheduling/splitter.rs | 1 - provekit/prover/src/lib.rs | 21 ++++----- provekit/prover/src/whir_r1cs.rs | 45 ++++++++++++------- provekit/verifier/src/whir_r1cs.rs | 29 +++++++----- 5 files changed, 62 insertions(+), 52 deletions(-) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 694817206..ae1390326 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -49,23 +49,13 @@ impl ConstantOrR1CSWitness { pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); impl PublicInputs { - /// Creates a new `PublicInputs` with a constant 1 field element at the - /// start. + /// Creates a new `PublicInputs` with an empty vector. pub fn new() -> Self { - Self(vec![FieldElement::one()]) + Self(Vec::new()) } - /// Creates a new `PublicInputs` from a vector, adding a constant 1 field - /// element at the start. To emulate the constant 1 witness in the R1CS - /// instance. - pub fn from_vec(mut vec: Vec) -> Self { - vec.insert(0, FieldElement::one()); - Self(vec) - } - - /// Assuming the given vector already has a constant 1 field element at the - /// start. - pub fn from_vec_with_constant_one(vec: Vec) -> Self { + /// Creates a new `PublicInputs` from a vector. + pub fn from_vec(vec: Vec) -> Self { Self(vec) } diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index 2a30e3513..2dbd2ba9d 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -219,7 +219,6 @@ impl<'a> WitnessSplitter<'a> { rest_indices.push(builder_idx); } - public_input_builder_indices.sort_unstable(); rest_indices.sort_unstable(); // Reorder: 0 first, then public inputs, then rest diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index d7e8f846a..bb89b7905 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -7,10 +7,7 @@ use { noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, - std::{ - collections::{HashMap, HashSet}, - path::Path, - }, + std::path::Path, tracing::instrument, }; @@ -119,12 +116,16 @@ impl Prove for Prover { // Gather public inputs from witness let num_public_inputs = acir_public_inputs.len(); - let public_inputs = PublicInputs::from_vec_with_constant_one( - witness[0..=num_public_inputs] - .iter() - .map(|w| w.ok_or_else(|| anyhow::anyhow!("Missing public input witness"))) - .collect::>>()?, - ); + let public_inputs = if num_public_inputs == 0 { + PublicInputs::new() + } else { + PublicInputs::from_vec( + witness[1..=num_public_inputs] + .iter() + .map(|w| w.ok_or_else(|| anyhow::anyhow!("Missing public input witness"))) + .collect::>>()?, + ) + }; drop(witness); let whir_r1cs_proof = self diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 69bbed0e6..00dfbd478 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -173,13 +173,21 @@ impl WhirR1CSProver for WhirR1CSScheme { // VERIFY the size given by self.m let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); - let (public_f_sum, public_g_sum) = update_statement_with_public_weights( - &mut statement, - &commitment.commitment_to_witness, - &commitment.masked_polynomial, - &commitment.random_polynomial, - public_weight, - ); + let (public_f_sum, public_g_sum) = if public_inputs.len() == 0 { + // If there are no public inputs, the hint is unused by the verifier and can be + // assigned an arbitrary value. + let public_f_sum = FieldElement::zero(); + let public_g_sum = FieldElement::zero(); + (public_f_sum, public_g_sum) + } else { + update_statement_with_public_weights( + &mut statement, + &commitment.commitment_to_witness, + &commitment.masked_polynomial, + &commitment.random_polynomial, + public_weight, + ) + }; let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); @@ -230,15 +238,20 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums_1, g_sums_1))?; merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; - // VERIFY the size given by self.m let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); - let (public_f_sum, public_g_sum) = update_statement_with_public_weights( - &mut statement_1, - &c1.commitment_to_witness, - &c1.masked_polynomial, - &c1.random_polynomial, - public_weight, - ); + let (public_f_sum, public_g_sum) = if public_inputs.len() == 0 { + let public_f_sum = FieldElement::zero(); + let public_g_sum = FieldElement::zero(); + (public_f_sum, public_g_sum) + } else { + update_statement_with_public_weights( + &mut statement_1, + &c1.commitment_to_witness, + &c1.masked_polynomial, + &c1.random_polynomial, + public_weight, + ) + }; let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); @@ -674,7 +687,9 @@ fn get_public_weights( m: usize, ) -> Weights { // Add hash to transcript + info!("ASH_TEST : Public inputs: {:?}", public_inputs.0); let public_inputs_hash = public_inputs.hash(); + info!("ASH_TEST : Public inputs hash: {:?}", public_inputs_hash); let _ = merlin.add_scalars(&[public_inputs_hash]); // Get random point x diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 22296aae5..54cabbcb8 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -94,12 +94,15 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); - update_statement_for_witness_verifier( - self.m, - &mut statement_1, - &parsed_commitment_1, - whir_pub_weights_query_answer, - ); + + if public_inputs.len() > 0 { + update_statement_for_witness_verifier( + self.m, + &mut statement_1, + &parsed_commitment_1, + whir_pub_weights_query_answer, + ); + } run_whir_pcs_batch_verifier( &mut arthur, @@ -141,12 +144,14 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); - update_statement_for_witness_verifier( - self.m, - &mut statement, - &parsed_commitment_1, - whir_pub_weights_query_answer, - ); + if public_inputs.len() > 0 { + update_statement_for_witness_verifier( + self.m, + &mut statement, + &parsed_commitment_1, + whir_pub_weights_query_answer, + ); + } run_whir_pcs_verifier( &mut arthur, From efc2f07bf510ec4af809ba0a980d0a7d92dd477c Mon Sep 17 00:00:00 2001 From: ashpect Date: Sat, 3 Jan 2026 03:54:29 +0530 Subject: [PATCH 10/19] chore: cleanup, namin, etc --- provekit/common/src/whir_r1cs.rs | 4 ++-- provekit/common/src/witness/mod.rs | 10 ++++++++++ .../common/src/witness/scheduling/splitter.rs | 12 ++++++++---- provekit/prover/src/whir_r1cs.rs | 10 ++++------ provekit/verifier/src/whir_r1cs.rs | 18 ++++++++++-------- 5 files changed, 34 insertions(+), 20 deletions(-) diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 43aed2c29..702a25c74 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -35,10 +35,10 @@ impl WhirR1CSScheme { // Compute total constraints: OOD + statement // OOD: 2 witnesses × committment_ood_samples each // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, - // statement_2 has 3 = 3, total = 7 + // statement_2 has 3 constraints = 3, total = 7 let num_witnesses = 2; let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; - let num_statement_constraints = 7; // (3+1) + (3) + let num_statement_constraints = 7; let num_constraints_total = num_ood_constraints + num_statement_constraints; io = io diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index ae1390326..d0fa8ea62 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -63,6 +63,10 @@ impl PublicInputs { self.0.len() } + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + /// Hashes the public input values using SHA-256 and converts the result to /// a FieldElement. pub fn hash(&self) -> FieldElement { @@ -86,3 +90,9 @@ impl PublicInputs { FieldElement::new(BigInt::new(limbs.try_into().unwrap())) } } + +impl Default for PublicInputs { + fn default() -> Self { + Self::new() + } +} diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index 2dbd2ba9d..94ffffe97 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -142,7 +142,7 @@ impl<'a> WitnessSplitter<'a> { // Step 7: Assign free builders greedily while respecting dependencies // Rule: if any dependency is in w2, the builder must also be in w2 // (because w1 is solved before w2) - // with the exception of public builders writing public witnesses) + // A free builder for public input witnesses goes in w1. let mut w1_set = mandatory_w1; let mut w2_set = mandatory_w2; @@ -188,8 +188,10 @@ impl<'a> WitnessSplitter<'a> { (w1_indices, w2_indices) } - /// Rearranges w1 indices: constant builder (0) first, then public inputs, - /// then rest. + /// Rearranges w1 builder indices into a canonical order: + /// 1. Constant builder (index 0) first, to preserve R1CS index 0 invariant + /// 2. Public input builders next, grouped together + /// 3. All other w1 builders last, sorted by index fn rearrange_w1( &self, w1_indices: Vec, @@ -200,8 +202,10 @@ impl<'a> WitnessSplitter<'a> { // Sanity Check: Make sure all public inputs and WITNESS_ONE_IDX are in // w1_indices. + // Convert to HashSet for O(1) lookups since we're checking many times + let w1_indices_set = w1_indices.iter().copied().collect::>(); for &idx in acir_public_inputs_indices_set.iter() { - if !w1_indices.contains(&(idx as usize)) { + if !w1_indices_set.contains(&(idx as usize)) { panic!("Public input {} is not in w1_indices", idx); } } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 00dfbd478..fadfa0bda 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -173,7 +173,7 @@ impl WhirR1CSProver for WhirR1CSScheme { // VERIFY the size given by self.m let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); - let (public_f_sum, public_g_sum) = if public_inputs.len() == 0 { + let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { // If there are no public inputs, the hint is unused by the verifier and can be // assigned an arbitrary value. let public_f_sum = FieldElement::zero(); @@ -189,7 +189,7 @@ impl WhirR1CSProver for WhirR1CSScheme { ) }; - let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); + merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; run_zk_whir_pcs_prover( commitment.commitment_to_witness, @@ -239,7 +239,7 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); - let (public_f_sum, public_g_sum) = if public_inputs.len() == 0 { + let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { let public_f_sum = FieldElement::zero(); let public_g_sum = FieldElement::zero(); (public_f_sum, public_g_sum) @@ -253,7 +253,7 @@ impl WhirR1CSProver for WhirR1CSScheme { ) }; - let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); + merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; run_zk_whir_pcs_batch_prover( &[c1.commitment_to_witness, c2.commitment_to_witness], @@ -687,9 +687,7 @@ fn get_public_weights( m: usize, ) -> Weights { // Add hash to transcript - info!("ASH_TEST : Public inputs: {:?}", public_inputs.0); let public_inputs_hash = public_inputs.hash(); - info!("ASH_TEST : Public inputs hash: {:?}", public_inputs_hash); let _ = merlin.add_scalars(&[public_inputs_hash]); // Get random point x diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 54cabbcb8..17e102fb9 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -92,15 +92,16 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let mut public_weights_vector_random_buf = [FieldElement::zero()]; arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - let whir_pub_weights_query_answer: (FieldElement, FieldElement) = - arthur.hint().unwrap(); + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; - if public_inputs.len() > 0 { + if !public_inputs.is_empty() { update_statement_for_witness_verifier( self.m, &mut statement_1, &parsed_commitment_1, - whir_pub_weights_query_answer, + whir_public_weights_query_answer, ); } @@ -142,14 +143,15 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let mut public_weights_vector_random_buf = [FieldElement::zero()]; arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - let whir_pub_weights_query_answer: (FieldElement, FieldElement) = - arthur.hint().unwrap(); - if public_inputs.len() > 0 { + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; + if !public_inputs.is_empty() { update_statement_for_witness_verifier( self.m, &mut statement, &parsed_commitment_1, - whir_pub_weights_query_answer, + whir_public_weights_query_answer, ); } From 5bd5739e814146297cc17d4d959ad0069ea3ea59 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 06:14:52 +0530 Subject: [PATCH 11/19] feat: add gnark support for public witness --- recursive-verifier/app/circuit/circuit.go | 171 +++++++++++++++++- .../app/circuit/circuit_test.go | 4 +- recursive-verifier/app/circuit/common.go | 13 +- recursive-verifier/app/circuit/mtUtilities.go | 83 +++++++++ recursive-verifier/app/circuit/types.go | 20 ++ recursive-verifier/app/circuit/whir.go | 3 +- .../app/circuit/whir_utilities.go | 1 + recursive-verifier/app/utilities/utilities.go | 121 +++++++++++++ 8 files changed, 402 insertions(+), 14 deletions(-) diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index 2f95d5e6e..dd7c2543d 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -39,6 +39,9 @@ type Circuit struct { WitnessClaimedEvaluations [][]frontend.Variable // [commitment_idx][eval_idx] WitnessBlindingEvaluations [][]frontend.Variable + // For public_f_sum and public_g_sum + PubWitnessEvaluations []frontend.Variable + // Batch mode only: batched polynomial for rounds 1+ WitnessMerkle Merkle @@ -46,9 +49,9 @@ type Circuit struct { MatrixB []MatrixCell MatrixC []MatrixCell - // Public Input - IO []byte - Transcript []uints.U8 `gnark:",public"` + IO []byte + Transcript []uints.U8 `gnark:",public"` + PublicInputs PublicInputs } func (circuit *Circuit) Define(api frontend.API) error { @@ -95,6 +98,26 @@ func (circuit *Circuit) Define(api frontend.API) error { return err } + // Read public inputs hash from transcript + publicInputsHashBuf := make([]frontend.Variable, 1) + if err := arthur.FillNextScalars(publicInputsHashBuf); err != nil { + return fmt.Errorf("failed to read public inputs hash: %w", err) + } + + // TODO : Compute expected public inputs hash and verify + expectedHash, err := hashPublicInputs(api, sc, circuit.PublicInputs) + if err != nil { + return fmt.Errorf("failed to compute public inputs hash: %w", err) + } + + api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) + + // Squeeze rand for public weights + publicWeightsChallenge := make([]frontend.Variable, 1) + if err := arthur.FillChallengeScalars(publicWeightsChallenge); err != nil { + return fmt.Errorf("failed to read public weights challenge: %w", err) + } + // WHIR verification var whirFoldingRandomness []frontend.Variable var az, bz, cz frontend.Variable @@ -115,6 +138,7 @@ func (circuit *Circuit) Define(api frontend.API) error { }, circuit.WHIRParamsWitness, // whirParams circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints + circuit.PublicInputs, // publicInputs ) if err != nil { return err @@ -125,12 +149,15 @@ func (circuit *Circuit) Define(api frontend.API) error { bz = api.Add(circuit.WitnessClaimedEvaluations[0][1], circuit.WitnessClaimedEvaluations[1][1]) cz = api.Add(circuit.WitnessClaimedEvaluations[0][2], circuit.WitnessClaimedEvaluations[1][2]) } else { + log.Println("Single Mode") + extendedLinearStatementEvals := extendLinearStatement(circuit, [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, circuit.PubWitnessEvaluations) + // Single commitment mode whirFoldingRandomness, err = RunZKWhir( api, arthur, uapi, sc, circuit.WitnessMerkle, circuit.WitnessFirstRounds[0], circuit.WHIRParamsWitness, - [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, + extendedLinearStatementEvals, circuit.WitnessLinearStatementEvaluations, batchingRandomness1, initialOODQueries1, @@ -150,23 +177,72 @@ func (circuit *Circuit) Define(api frontend.API) error { x := api.Mul(api.Sub(api.Mul(az, bz), cz), calculateEQ(api, spartanSumcheckRand, tRand)) api.AssertIsEqual(spartanSumcheckLastValue, x) + // TODO : generalize it later on if we have more different kinds of statements + // for handling geometric weights statement added at starting + offset := 1 + if circuit.NumChallenges > 0 { // Batch mode - check 6 deferred values matrixExtensionEvals := evaluateR1CSMatrixExtensionBatch(api, circuit, spartanSumcheckRand, whirFoldingRandomness, circuit.W1Size) for i := 0; i < 6; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) } } else { + // Single mode - existing logic matrixExtensionEvals := evaluateR1CSMatrixExtension(api, circuit, spartanSumcheckRand, whirFoldingRandomness) for i := 0; i < 3; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) } } + // Geomteric weights for public inputs + if !circuit.PublicInputs.IsEmpty() { + publicWeightEval := computePublicWeightEvaluation( + api, circuit.PublicInputs, whirFoldingRandomness, + circuit.WHIRParamsWitness.MVParamsNumberOfVariables, publicWeightsChallenge[0], + ) + + api.AssertIsEqual(publicWeightEval, circuit.WitnessLinearStatementEvaluations[0]) + } + return nil } +func computePublicWeightEvaluation( + api frontend.API, + publicInputs PublicInputs, + foldingRandomness []frontend.Variable, + m int, // domain size = 2^m + x frontend.Variable, +) frontend.Variable { + // Build public weight vector: [1, x, x^2, ..., x^(n-1), 0, 0, ..., 0] where n = len(publicInputs.Values) and total length = 2^m + domainSize := 1 << m + publicWeights := make([]frontend.Variable, domainSize) + + for i := 0; i < domainSize; i++ { + publicWeights[i] = 0 + } + + // Set public weights: [1, x, x^2, ..., x^(n-1), 0, 0, ..., 0] + currentPower := frontend.Variable(1) + for i := 0; i < len(publicInputs.Values); i++ { + publicWeights[i] = currentPower + currentPower = api.Mul(currentPower, x) + } + + // TODO : Replace it with geometric_till algo + // Evaluate the multilinear extension of publicWeights at foldingRandomness + // Formula: f(r) = Σ_{i=0}^{2^m-1} f[i] * eq_i(r) + // where eq_i(r) is the i-th Lagrange basis polynomial + eqPolys := calculateEQOverBooleanHypercube(api, foldingRandomness) + result := frontend.Variable(0) + for i := 0; i < len(publicWeights); i++ { + result = api.Add(result, api.Mul(publicWeights[i], eqPolys[i])) + } + return result +} + func verifyCircuit( deferred []Fp256, cfg Config, @@ -175,9 +251,11 @@ func verifyCircuit( vk *groth16.VerifyingKey, claimedEvaluations ClaimedEvaluations, claimedEvaluations2 ClaimedEvaluations, + publicWeightsClaimedEvaluation [2]Fp256, internedR1CS R1CS, interner Interner, buildOps common.BuildOps, + publicInputs PublicInputs, ) error { transcriptT := make([]uints.U8, cfg.TranscriptLen) contTranscript := make([]uints.U8, cfg.TranscriptLen) @@ -189,9 +267,18 @@ func verifyCircuit( // Determine witness linear statement evals size based on mode var witnessLinearStatementEvalsSize int if cfg.NumChallenges > 0 { - witnessLinearStatementEvalsSize = 6 // 3 per commitment in batch mode + if !cfg.PublicInputs.IsEmpty() { + // 3 per commitment in batch mode + 1 public_input (geometric statement as a subset of linear statement) + witnessLinearStatementEvalsSize = 7 + } else { + witnessLinearStatementEvalsSize = 6 + } } else { - witnessLinearStatementEvalsSize = 3 + if !cfg.PublicInputs.IsEmpty() { + witnessLinearStatementEvalsSize = 4 + } else { + witnessLinearStatementEvalsSize = 3 + } } witnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) @@ -199,6 +286,9 @@ func verifyCircuit( contWitnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) contHidingSpartanLinearStatementEvaluations := make([]frontend.Variable, 1) + if len(deferred) < 1+witnessLinearStatementEvalsSize { + return fmt.Errorf("deferred array too short: expected at least %d elements, got %d", 1+witnessLinearStatementEvalsSize, len(deferred)) + } hidingSpartanLinearStatementEvaluations[0] = typeConverters.LimbsToBigIntMod(deferred[0].Limbs) for i := 0; i < witnessLinearStatementEvalsSize; i++ { witnessLinearStatementEvaluations[i] = typeConverters.LimbsToBigIntMod(deferred[1+i].Limbs) @@ -258,6 +348,10 @@ func verifyCircuit( fSums2, gSums2 = parseClaimedEvaluations(claimedEvaluations2, true) } + // Parse public weights claimed evaluation + fSumPublicWeights, gSumPublicWeights := parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, true) + pubWitnessEvaluations := []frontend.Variable{fSumPublicWeights, gSumPublicWeights} + // Build witness slices conditionally var witnessClaimedEvals, witnessBlindingEvals [][]frontend.Variable if cfg.NumChallenges > 0 { @@ -268,6 +362,11 @@ func verifyCircuit( witnessBlindingEvals = [][]frontend.Variable{gSums} } + // Empty container while circuit creation + publicInputsContainer := PublicInputs{ + Values: make([]frontend.Variable, len(publicInputs.Values)), + } + circuit := Circuit{ IO: []byte(cfg.IOPattern), Transcript: contTranscript, @@ -276,6 +375,7 @@ func verifyCircuit( LogANumTerms: cfg.LogANumTerms, WitnessClaimedEvaluations: witnessClaimedEvals, WitnessBlindingEvaluations: witnessBlindingEvals, + PubWitnessEvaluations: pubWitnessEvaluations, WitnessLinearStatementEvaluations: contWitnessLinearStatementEvaluations, HidingSpartanLinearStatementEvaluations: contHidingSpartanLinearStatementEvaluations, HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, true), @@ -289,6 +389,7 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + PublicInputs: publicInputsContainer, } ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) @@ -377,13 +478,19 @@ func verifyCircuit( witnessBlindingEvals = [][]frontend.Variable{gSums} } + fSumPublicWeights, gSumPublicWeights = parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, false) + pubWitnessEvaluations = []frontend.Variable{fSumPublicWeights, gSumPublicWeights} + assignment := Circuit{ IO: []byte(cfg.IOPattern), Transcript: transcriptT, LogNumConstraints: cfg.LogNumConstraints, + LogNumVariables: cfg.LogNumVariables, + LogANumTerms: cfg.LogANumTerms, WitnessClaimedEvaluations: witnessClaimedEvals, WitnessBlindingEvaluations: witnessBlindingEvals, WitnessLinearStatementEvaluations: witnessLinearStatementEvaluations, + PubWitnessEvaluations: pubWitnessEvaluations, HidingSpartanLinearStatementEvaluations: hidingSpartanLinearStatementEvaluations, HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, false), HidingSpartanMerkle: newMerkle(hints.spartanHidingHint.roundHints, false), @@ -396,13 +503,18 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + PublicInputs: publicInputs, } witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) - publicWitness, _ := witness.Public() + publicWitness, err := witness.Public() + if err != nil { + log.Printf("Failed witess,Public(): %v", err) + return err + } opts := []backend.ProverOption{ - backend.WithSolverOptions(solver.WithHints(utilities.IndexOf)), + backend.WithSolverOptions(solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)), backend.WithIcicleAcceleration(), } @@ -436,3 +548,42 @@ func witnessFirstRounds(hints Hints, isContainer bool) []Merkle { } return result } + +func parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation [2]Fp256, isContainer bool) (frontend.Variable, frontend.Variable) { + var fSumPublicWeights, gSumPublicWeights frontend.Variable + + if !isContainer { + fSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[0].Limbs) + gSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[1].Limbs) + } + + return fSumPublicWeights, gSumPublicWeights +} + +func extendLinearStatement( + circuit *Circuit, + linearStatementEvaluations [][]frontend.Variable, + pubWitnessEvaluations []frontend.Variable, +) [][]frontend.Variable { + var extendedLinearStatementEvals [][]frontend.Variable + + if !circuit.PublicInputs.IsEmpty() { + // Extend the statement equivalent array by prepending the public constraint (public constraint is added in starting at prover side) + extendedLinearStatementEvals = make([][]frontend.Variable, 2) + + // f_sums: [public_f_sum, f_sums[0], f_sums[1]... ] + extendedLinearStatementEvals[0] = make([]frontend.Variable, len(linearStatementEvaluations[0])+1) + extendedLinearStatementEvals[0][0] = pubWitnessEvaluations[0] + copy(extendedLinearStatementEvals[0][1:], linearStatementEvaluations[0]) + + // g_sums: [public_g_sum, g_sums[0], g_sums[1]... ] + extendedLinearStatementEvals[1] = make([]frontend.Variable, len(linearStatementEvaluations[1])+1) + extendedLinearStatementEvals[1][0] = pubWitnessEvaluations[1] + copy(extendedLinearStatementEvals[1][1:], linearStatementEvaluations[1]) + } else { + // No public inputs, use original arrays + extendedLinearStatementEvals = linearStatementEvaluations + } + + return extendedLinearStatementEvals +} diff --git a/recursive-verifier/app/circuit/circuit_test.go b/recursive-verifier/app/circuit/circuit_test.go index 19e2c4a04..69c1fa46e 100644 --- a/recursive-verifier/app/circuit/circuit_test.go +++ b/recursive-verifier/app/circuit/circuit_test.go @@ -70,7 +70,7 @@ func TestCircuitConstraints(t *testing.T) { test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), - test.WithSolverOpts(solver.WithHints(utilities.IndexOf)), + test.WithSolverOpts(solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)), ) } @@ -124,7 +124,7 @@ func TestCircuitConstraintsSolverOnly(t *testing.T) { } // Solve the constraint system - _, err = ccs.Solve(witness, solver.WithHints(utilities.IndexOf)) + _, err = ccs.Solve(witness, solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)) if err != nil { t.Fatalf("Constraint system not satisfied: %v", err) } diff --git a/recursive-verifier/app/circuit/common.go b/recursive-verifier/app/circuit/common.go index 16eea7c1f..707286738 100644 --- a/recursive-verifier/app/circuit/common.go +++ b/recursive-verifier/app/circuit/common.go @@ -30,6 +30,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v var deferred []Fp256 var claimedEvaluations ClaimedEvaluations var claimedEvaluations2 ClaimedEvaluations + var publicWeightsEvaluations [2]Fp256 for _, op := range io.Ops { switch op.Kind { @@ -128,6 +129,16 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v if err != nil { return fmt.Errorf("failed to deserialize claimed_evaluations_2: %w", err) } + + case "public_weights_evaluations": + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(config.Transcript[start:end]), + &publicWeightsEvaluations, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize public_weights_evaluations: %w", err) + } } if err != nil { @@ -204,7 +215,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v WitnessRoundHints: witnessRoundHints, } - err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, claimedEvaluations2, r1cs, interner, buildOps) + err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, claimedEvaluations2, publicWeightsEvaluations, r1cs, interner, buildOps, config.PublicInputs) if err != nil { return fmt.Errorf("verification failed: %w", err) } diff --git a/recursive-verifier/app/circuit/mtUtilities.go b/recursive-verifier/app/circuit/mtUtilities.go index 264727b38..47afcd222 100644 --- a/recursive-verifier/app/circuit/mtUtilities.go +++ b/recursive-verifier/app/circuit/mtUtilities.go @@ -1,6 +1,7 @@ package circuit import ( + "fmt" "reilabs/whir-verifier-circuit/app/utilities" "github.com/consensys/gnark/frontend" @@ -112,3 +113,85 @@ func rlcBatchedLeaves(api frontend.API, leaves [][]frontend.Variable, foldSize i } return collapsed } + +// hashPublicInputs computes the hash of public inputs by treating them as field elements +// This mimics the Rust PublicInputs::hash() function using SHA-256, TODO : Shift to skyscraper hash function later +func hashPublicInputs(api frontend.API, sc *skyscraper.Skyscraper, publicInputs PublicInputs) (frontend.Variable, error) { + if len(publicInputs.Values) == 0 { + // Return zero if no public inputs + return frontend.Variable(0), nil + } + + // Use hint to compute SHA-256 hash outside the circuit + // The hint function will be called during witness generation + hashResult, err := api.Compiler().NewHint(utilities.HashPublicInputsHint, 1, publicInputs.Values...) + if err != nil { + return nil, fmt.Errorf("failed to create hash hint: %w", err) + } + + return hashResult[0], nil +} + +// verifyPublicInputsAndReadWeights reads and verifies the public inputs hash from the transcript, +// then reads the public weights challenge and query answer. +// Returns (publicWeightsChallenge, publicWeightsQueryAnswer, error) +func verifyPublicInputsAndReadWeights( + api frontend.API, + sc *skyscraper.Skyscraper, + arthur gnarkNimue.Arthur, + publicInputs PublicInputs, +) (frontend.Variable, []frontend.Variable, error) { + // Read public inputs hash from transcript + publicInputsHashBuf := make([]frontend.Variable, 1) + if err := arthur.FillNextScalars(publicInputsHashBuf); err != nil { + return nil, nil, fmt.Errorf("failed to read public inputs hash: %w", err) + } + + // Compute expected public inputs hash + expectedHash, err := hashPublicInputs(api, sc, publicInputs) + if err != nil { + return nil, nil, fmt.Errorf("failed to compute public inputs hash: %w", err) + } + + // Verify hash matches + api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) + + // Read public weights vector random challenge + publicWeightsChallenge := make([]frontend.Variable, 1) + if err := arthur.FillChallengeScalars(publicWeightsChallenge); err != nil { + return nil, nil, fmt.Errorf("failed to read public weights challenge: %w", err) + } + + // Read WHIR public weights query answer (2 field elements: f_sum, g_sum) + publicWeightsQueryAnswer := make([]frontend.Variable, 2) + if err := arthur.FillNextScalars(publicWeightsQueryAnswer); err != nil { + return nil, nil, fmt.Errorf("failed to read public weights query answer: %w", err) + } + + return publicWeightsChallenge[0], publicWeightsQueryAnswer, nil +} + +// readPublicWeightsQueryAnswer reads only the public weights query answer from the transcript. +// The challenge has already been read at the circuit level to match transcript order. +// Returns (publicWeightsQueryAnswer, error) +func readPublicWeightsQueryAnswer(arthur gnarkNimue.Arthur) ([]frontend.Variable, error) { + // Read WHIR public weights query answer (2 field elements: f_sum, g_sum) + publicWeightsQueryAnswer := make([]frontend.Variable, 2) + if err := arthur.FillNextScalars(publicWeightsQueryAnswer); err != nil { + return nil, fmt.Errorf("failed to read public weights query answer: %w", err) + } + + return publicWeightsQueryAnswer, nil +} + +// computePublicWeightsClaimedSum computes the claimed sum for the public weights constraint +// This is: public_f_sum + public_g_sum * batching_randomness +func computePublicWeightsClaimedSum( + api frontend.API, + publicWeightsQueryAnswer []frontend.Variable, + batchingRandomness frontend.Variable, +) frontend.Variable { + publicFSum := publicWeightsQueryAnswer[0] + publicGSum := publicWeightsQueryAnswer[1] + return api.Add(publicFSum, api.Mul(publicGSum, batchingRandomness)) +} diff --git a/recursive-verifier/app/circuit/types.go b/recursive-verifier/app/circuit/types.go index 420b43e0e..b456231a6 100644 --- a/recursive-verifier/app/circuit/types.go +++ b/recursive-verifier/app/circuit/types.go @@ -1,6 +1,8 @@ package circuit import ( + "reilabs/whir-verifier-circuit/app/utilities" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) @@ -101,6 +103,7 @@ type Config struct { BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` NumChallenges int `json:"num_challenges"` W1Size int `json:"w1_size"` + PublicInputs PublicInputs `json:"public_inputs"` } // Update Hints to support batch mode @@ -139,3 +142,20 @@ type DualClaimedEvaluations struct { First ClaimedEvaluations Second ClaimedEvaluations } + +type PublicInputs struct { + Values []frontend.Variable +} + +func (p *PublicInputs) UnmarshalJSON(data []byte) error { + values, err := utilities.UnmarshalPublicInputs(data) + if err != nil { + return err + } + p.Values = values + return nil +} + +func (p *PublicInputs) IsEmpty() bool { + return len(p.Values) == 0 +} \ No newline at end of file diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index 0e153280e..4e367cb34 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -59,7 +59,7 @@ func RunZKWhir( firstRound Merkle, whirParams WHIRParams, linearStatementEvaluations [][]frontend.Variable, - linearStatementValuesAtPoints []frontend.Variable, + linearStatementValuesAtPoints []frontend.Variable, // weights.evaluate(random_point) - this is what needs to be done batchingRandomness frontend.Variable, initialOODQueries []frontend.Variable, initialOODAnswers [][]frontend.Variable, @@ -238,6 +238,7 @@ func RunZKWhirBatch( // Common parameters whirParams WHIRParams, linearStatementValuesAtPoints []frontend.Variable, + publicInputs PublicInputs, ) (totalFoldingRandomness []frontend.Variable, err error) { numPolynomials := len(firstRounds) if numPolynomials == 0 { diff --git a/recursive-verifier/app/circuit/whir_utilities.go b/recursive-verifier/app/circuit/whir_utilities.go index 666dde44f..ac75a9c34 100644 --- a/recursive-verifier/app/circuit/whir_utilities.go +++ b/recursive-verifier/app/circuit/whir_utilities.go @@ -140,6 +140,7 @@ func computeWPoly( value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[j], utilities.EqPolyOutside(api, utilities.ExpandFromUnivariate(api, initialData.InitialOODQueries[j], numberVars), totalFoldingRandomness))) } + // Values are directly used as all linearStatements are deffered and hints were given. Checking of hints will be done later on. for j, linearStatementValueAtPoint := range linearStatementValuesAtPoints { value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[len(initialData.InitialOODQueries)+j], linearStatementValueAtPoint)) } diff --git a/recursive-verifier/app/utilities/utilities.go b/recursive-verifier/app/utilities/utilities.go index 04a8e9d2e..7222133bb 100644 --- a/recursive-verifier/app/utilities/utilities.go +++ b/recursive-verifier/app/utilities/utilities.go @@ -1,6 +1,10 @@ package utilities import ( + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/json" "fmt" "math/big" "reilabs/whir-verifier-circuit/app/typeConverters" @@ -58,6 +62,77 @@ func IndexOf(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { return nil } +// HashPublicInputsHint is a hint function that computes SHA-256 hash of public inputs +// matching the Rust PublicInputs::hash() implementation. +// It takes public input values, converts them to BigInt, extracts limbs, hashes them, +// and returns the hash as a field element. +func HashPublicInputsHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(outputs) != 1 { + return fmt.Errorf("expecting one output") + } + + if len(inputs) == 0 { + outputs[0] = big.NewInt(0) + return nil + } + + hasher := sha256.New() + + // Process each public input value + for _, input := range inputs { + // Convert field element to BigInt (it's already a BigInt, but ensure it's in range) + value := new(big.Int).Set(input) + + // Extract limbs (u64 values) from BigInt + // Field elements are represented as 4 u64 limbs in little-endian + limbs := make([]uint64, 4) + temp := new(big.Int).Set(value) + limbs[0] = temp.Uint64() // Least significant limb + temp.Rsh(temp, 64) + limbs[1] = temp.Uint64() + temp.Rsh(temp, 64) + limbs[2] = temp.Uint64() + temp.Rsh(temp, 64) + limbs[3] = temp.Uint64() // Most significant limb + + // Hash each limb as little-endian bytes (8 bytes per limb) + for _, limb := range limbs { + limbBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(limbBytes, limb) + hasher.Write(limbBytes) + } + } + + // Get the hash result (32 bytes) + hashResult := hasher.Sum(nil) + + // Convert hash result to field element by splitting into 4 u64 limbs + // Each chunk of 8 bytes becomes a u64 (little-endian) + limbs := make([]uint64, 4) + for i := 0; i < 4; i++ { + start := i * 8 + end := start + 8 + limbs[i] = binary.LittleEndian.Uint64(hashResult[start:end]) + } + + // Reconstruct field element from limbs + result := new(big.Int).SetUint64(limbs[0]) + temp := new(big.Int).SetUint64(limbs[1]) + result.Add(result, temp.Lsh(temp, 64)) + temp.SetUint64(limbs[2]) + result.Add(result, temp.Lsh(temp, 128)) + temp.SetUint64(limbs[3]) + result.Add(result, temp.Lsh(temp, 192)) + + // Apply modulus to ensure result is in field range + modulus := new(big.Int) + modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + result.Mod(result, modulus) + + outputs[0] = result + return nil +} + func Reverse[T any](s []T) []T { res := make([]T, len(s)) copy(res, s) @@ -210,3 +285,49 @@ func DotProduct(api frontend.API, a []frontend.Variable, b []frontend.Variable) } return acc } + +// ParseHexFieldElement parses a hex string representing a FieldElement (little-endian) +// and converts it to a big.Int. The hex string should be 64 characters (32 bytes). +func ParseHexFieldElement(hexStr string) (*big.Int, error) { + if len(hexStr) >= 2 && hexStr[0:2] == "0x" { + hexStr = hexStr[2:] + } + + bytes, err := hex.DecodeString(hexStr) + if err != nil { + return nil, fmt.Errorf("invalid hex string: %w", err) + } + + reversed := make([]byte, len(bytes)) + for i, b := range bytes { + reversed[len(bytes)-1-i] = b + } + + result := new(big.Int) + result.SetBytes(reversed) + + modulus := new(big.Int) + modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + result.Mod(result, modulus) + + return result, nil +} + +// UnmarshalPublicInputs parses a JSON array of hex-encoded FieldElement strings +// and returns them as frontend.Variable slice. +func UnmarshalPublicInputs(data []byte) ([]frontend.Variable, error) { + var arr []string + if err := json.Unmarshal(data, &arr); err != nil { + return nil, err + } + + values := make([]frontend.Variable, len(arr)) + for i, hexStr := range arr { + value, err := ParseHexFieldElement(hexStr) + if err != nil { + return nil, fmt.Errorf("failed to parse public input at index %d: %w", i, err) + } + values[i] = value + } + return values, nil +} From ff7174a7b3c0a8fed87bc8a0f33b21c2583bc289 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 07:24:55 +0530 Subject: [PATCH 12/19] feat: tooling support --- tooling/cli/src/cmd/generate_gnark_inputs.rs | 1 + tooling/provekit-gnark/src/gnark_config.rs | 8 +++++++- tooling/verifier-server/src/services/verification.rs | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tooling/cli/src/cmd/generate_gnark_inputs.rs b/tooling/cli/src/cmd/generate_gnark_inputs.rs index 883f52b5e..0e3cc6d0a 100644 --- a/tooling/cli/src/cmd/generate_gnark_inputs.rs +++ b/tooling/cli/src/cmd/generate_gnark_inputs.rs @@ -62,6 +62,7 @@ impl Command for Args { prover.whir_for_witness.a_num_terms, prover.whir_for_witness.num_challenges, prover.whir_for_witness.w1_size, + &proof.public_inputs, &self.params_for_recursive_verifier, ); diff --git a/tooling/provekit-gnark/src/gnark_config.rs b/tooling/provekit-gnark/src/gnark_config.rs index 015cc68f8..41439a75c 100644 --- a/tooling/provekit-gnark/src/gnark_config.rs +++ b/tooling/provekit-gnark/src/gnark_config.rs @@ -1,6 +1,6 @@ use { ark_poly::EvaluationDomain, - provekit_common::{IOPattern, WhirConfig}, + provekit_common::{IOPattern, PublicInputs, WhirConfig}, serde::{Deserialize, Serialize}, std::{fs::File, io::Write}, tracing::instrument, @@ -29,6 +29,8 @@ pub struct GnarkConfig { pub num_challenges: usize, /// size of w1 pub w1_size: usize, + /// public inputs + pub public_inputs: PublicInputs, } #[derive(Debug, Serialize, Deserialize)] @@ -114,6 +116,7 @@ pub fn gnark_parameters( a_num_terms: usize, num_challenges: usize, w1_size: usize, + public_inputs: &PublicInputs, ) -> GnarkConfig { GnarkConfig { whir_config_witness: WHIRConfigGnark::new(whir_params_witness), @@ -126,6 +129,7 @@ pub fn gnark_parameters( transcript_len: transcript.to_vec().len(), num_challenges, w1_size, + public_inputs: public_inputs.clone(), } } @@ -141,6 +145,7 @@ pub fn write_gnark_parameters_to_file( a_num_terms: usize, num_challenges: usize, w1_size: usize, + public_inputs: &PublicInputs, file_path: &str, ) { let gnark_config = gnark_parameters( @@ -153,6 +158,7 @@ pub fn write_gnark_parameters_to_file( a_num_terms, num_challenges, w1_size, + public_inputs, ); let mut file_params = File::create(file_path).unwrap(); file_params diff --git a/tooling/verifier-server/src/services/verification.rs b/tooling/verifier-server/src/services/verification.rs index f8abd1385..398b282c4 100644 --- a/tooling/verifier-server/src/services/verification.rs +++ b/tooling/verifier-server/src/services/verification.rs @@ -97,6 +97,7 @@ impl VerificationService { whir_scheme.a_num_terms, whir_scheme.num_challenges, whir_scheme.w1_size, + &proof.public_inputs, gnark_params_path, ); From 3534a9521d68baa7331e69b5eaec2ecd536df437 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 07:37:09 +0530 Subject: [PATCH 13/19] feat: switch to skyscraper for hashing --- provekit/common/src/skyscraper/mod.rs | 6 +- provekit/common/src/witness/mod.rs | 36 ++++---- recursive-verifier/app/circuit/circuit.go | 7 +- .../app/circuit/circuit_test.go | 4 +- recursive-verifier/app/circuit/mtUtilities.go | 84 +++---------------- 5 files changed, 36 insertions(+), 101 deletions(-) diff --git a/provekit/common/src/skyscraper/mod.rs b/provekit/common/src/skyscraper/mod.rs index 3b6da92a4..2caecdc28 100644 --- a/provekit/common/src/skyscraper/mod.rs +++ b/provekit/common/src/skyscraper/mod.rs @@ -2,4 +2,8 @@ mod pow; mod sponge; mod whir; -pub use self::{pow::SkyscraperPoW, sponge::SkyscraperSponge, whir::SkyscraperMerkleConfig}; +pub use self::{ + pow::SkyscraperPoW, + sponge::SkyscraperSponge, + whir::{SkyscraperCRH, SkyscraperMerkleConfig}, +}; diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index d0fa8ea62..a0ff25d1e 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -8,12 +8,13 @@ mod witness_io_pattern; use { crate::{ + skyscraper::SkyscraperCRH, utils::{serde_ark, serde_ark_vec}, FieldElement, }, - ark_ff::{BigInt, One, PrimeField}, + ark_crypto_primitives::crh::CRHScheme, + ark_ff::One, serde::{Deserialize, Serialize}, - sha2::{Digest, Sha256}, }; pub use { binops::{BINOP_ATOMIC_BITS, BINOP_BITS, NUM_DIGITS}, @@ -49,12 +50,10 @@ impl ConstantOrR1CSWitness { pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); impl PublicInputs { - /// Creates a new `PublicInputs` with an empty vector. pub fn new() -> Self { Self(Vec::new()) } - /// Creates a new `PublicInputs` from a vector. pub fn from_vec(vec: Vec) -> Self { Self(vec) } @@ -67,27 +66,20 @@ impl PublicInputs { self.0.is_empty() } - /// Hashes the public input values using SHA-256 and converts the result to - /// a FieldElement. pub fn hash(&self) -> FieldElement { - let mut hasher = Sha256::new(); - - // Hash all public values from witness - for value in self.0.iter() { - let bigint = value.into_bigint(); - for limb in bigint.0.iter() { - hasher.update(&limb.to_le_bytes()); + match self.0.len() { + 0 => FieldElement::from(0u64), + 1 => { + // For single element, hash it with zero to ensure it gets properly hashed + let padded = vec![self.0[0], FieldElement::from(0u64)]; + SkyscraperCRH::evaluate(&(), &padded[..]) + .expect("hash should succeed") + } + _ => { + SkyscraperCRH::evaluate(&(), &self.0[..]) + .expect("hash should succeed for multiple inputs") } } - - let result = hasher.finalize(); - - let limbs = result - .chunks_exact(8) - .map(|s| u64::from_le_bytes(s.try_into().unwrap())) - .collect::>(); - - FieldElement::new(BigInt::new(limbs.try_into().unwrap())) } } diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index dd7c2543d..c5a7fbb49 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -104,8 +104,7 @@ func (circuit *Circuit) Define(api frontend.API) error { return fmt.Errorf("failed to read public inputs hash: %w", err) } - // TODO : Compute expected public inputs hash and verify - expectedHash, err := hashPublicInputs(api, sc, circuit.PublicInputs) + expectedHash, err := hashPublicInputs(sc, circuit.PublicInputs) if err != nil { return fmt.Errorf("failed to compute public inputs hash: %w", err) } @@ -367,6 +366,8 @@ func verifyCircuit( Values: make([]frontend.Variable, len(publicInputs.Values)), } + log.Println("publicInputs", publicInputs) + circuit := Circuit{ IO: []byte(cfg.IOPattern), Transcript: contTranscript, @@ -514,7 +515,7 @@ func verifyCircuit( } opts := []backend.ProverOption{ - backend.WithSolverOptions(solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)), + backend.WithSolverOptions(solver.WithHints(utilities.IndexOf)), backend.WithIcicleAcceleration(), } diff --git a/recursive-verifier/app/circuit/circuit_test.go b/recursive-verifier/app/circuit/circuit_test.go index 69c1fa46e..19e2c4a04 100644 --- a/recursive-verifier/app/circuit/circuit_test.go +++ b/recursive-verifier/app/circuit/circuit_test.go @@ -70,7 +70,7 @@ func TestCircuitConstraints(t *testing.T) { test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), - test.WithSolverOpts(solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)), + test.WithSolverOpts(solver.WithHints(utilities.IndexOf)), ) } @@ -124,7 +124,7 @@ func TestCircuitConstraintsSolverOnly(t *testing.T) { } // Solve the constraint system - _, err = ccs.Solve(witness, solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)) + _, err = ccs.Solve(witness, solver.WithHints(utilities.IndexOf)) if err != nil { t.Fatalf("Constraint system not satisfied: %v", err) } diff --git a/recursive-verifier/app/circuit/mtUtilities.go b/recursive-verifier/app/circuit/mtUtilities.go index 47afcd222..4dc74eee2 100644 --- a/recursive-verifier/app/circuit/mtUtilities.go +++ b/recursive-verifier/app/circuit/mtUtilities.go @@ -1,7 +1,6 @@ package circuit import ( - "fmt" "reilabs/whir-verifier-circuit/app/utilities" "github.com/consensys/gnark/frontend" @@ -114,84 +113,23 @@ func rlcBatchedLeaves(api frontend.API, leaves [][]frontend.Variable, foldSize i return collapsed } -// hashPublicInputs computes the hash of public inputs by treating them as field elements -// This mimics the Rust PublicInputs::hash() function using SHA-256, TODO : Shift to skyscraper hash function later -func hashPublicInputs(api frontend.API, sc *skyscraper.Skyscraper, publicInputs PublicInputs) (frontend.Variable, error) { +// hashPublicInputs computes the hash of public inputs as field elements sequentially +func hashPublicInputs(sc *skyscraper.Skyscraper, publicInputs PublicInputs) (frontend.Variable, error) { + if len(publicInputs.Values) == 0 { - // Return zero if no public inputs return frontend.Variable(0), nil } - // Use hint to compute SHA-256 hash outside the circuit - // The hint function will be called during witness generation - hashResult, err := api.Compiler().NewHint(utilities.HashPublicInputsHint, 1, publicInputs.Values...) - if err != nil { - return nil, fmt.Errorf("failed to create hash hint: %w", err) - } - - return hashResult[0], nil -} - -// verifyPublicInputsAndReadWeights reads and verifies the public inputs hash from the transcript, -// then reads the public weights challenge and query answer. -// Returns (publicWeightsChallenge, publicWeightsQueryAnswer, error) -func verifyPublicInputsAndReadWeights( - api frontend.API, - sc *skyscraper.Skyscraper, - arthur gnarkNimue.Arthur, - publicInputs PublicInputs, -) (frontend.Variable, []frontend.Variable, error) { - // Read public inputs hash from transcript - publicInputsHashBuf := make([]frontend.Variable, 1) - if err := arthur.FillNextScalars(publicInputsHashBuf); err != nil { - return nil, nil, fmt.Errorf("failed to read public inputs hash: %w", err) - } - - // Compute expected public inputs hash - expectedHash, err := hashPublicInputs(api, sc, publicInputs) - if err != nil { - return nil, nil, fmt.Errorf("failed to compute public inputs hash: %w", err) - } - - // Verify hash matches - api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) - - // Read public weights vector random challenge - publicWeightsChallenge := make([]frontend.Variable, 1) - if err := arthur.FillChallengeScalars(publicWeightsChallenge); err != nil { - return nil, nil, fmt.Errorf("failed to read public weights challenge: %w", err) + // For single element, we hash it with a zero + if len(publicInputs.Values) == 1 { + return sc.CompressV2(publicInputs.Values[0], frontend.Variable(0)), nil } - // Read WHIR public weights query answer (2 field elements: f_sum, g_sum) - publicWeightsQueryAnswer := make([]frontend.Variable, 2) - if err := arthur.FillNextScalars(publicWeightsQueryAnswer); err != nil { - return nil, nil, fmt.Errorf("failed to read public weights query answer: %w", err) + // For 2+ elements, use standard approach + hash := sc.CompressV2(publicInputs.Values[0], publicInputs.Values[1]) + for i := 2; i < len(publicInputs.Values); i++ { + hash = sc.CompressV2(hash, publicInputs.Values[i]) } - return publicWeightsChallenge[0], publicWeightsQueryAnswer, nil -} - -// readPublicWeightsQueryAnswer reads only the public weights query answer from the transcript. -// The challenge has already been read at the circuit level to match transcript order. -// Returns (publicWeightsQueryAnswer, error) -func readPublicWeightsQueryAnswer(arthur gnarkNimue.Arthur) ([]frontend.Variable, error) { - // Read WHIR public weights query answer (2 field elements: f_sum, g_sum) - publicWeightsQueryAnswer := make([]frontend.Variable, 2) - if err := arthur.FillNextScalars(publicWeightsQueryAnswer); err != nil { - return nil, fmt.Errorf("failed to read public weights query answer: %w", err) - } - - return publicWeightsQueryAnswer, nil -} - -// computePublicWeightsClaimedSum computes the claimed sum for the public weights constraint -// This is: public_f_sum + public_g_sum * batching_randomness -func computePublicWeightsClaimedSum( - api frontend.API, - publicWeightsQueryAnswer []frontend.Variable, - batchingRandomness frontend.Variable, -) frontend.Variable { - publicFSum := publicWeightsQueryAnswer[0] - publicGSum := publicWeightsQueryAnswer[1] - return api.Add(publicFSum, api.Mul(publicGSum, batchingRandomness)) + return hash, nil } From 5122790a1bf75447a4fa064960e50afd71e0d28f Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 08:06:42 +0530 Subject: [PATCH 14/19] feat: fix dualmode --- recursive-verifier/app/circuit/circuit.go | 38 ++++++++++++++++++----- recursive-verifier/app/circuit/whir.go | 12 +++++-- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index c5a7fbb49..1262d76af 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -122,7 +122,32 @@ func (circuit *Circuit) Define(api frontend.API) error { var az, bz, cz frontend.Variable if circuit.NumChallenges > 0 { - // Dual commitment mode: batch WHIR verification + // Only statement_1 (first commitment) gets extended with public weights, statement_2 remains unchanged + extendedLinearStatementEvalsBatch := make([][][]frontend.Variable, 2) + + if !circuit.PublicInputs.IsEmpty() { + extendedLinearStatementEvalsBatch[0] = extendLinearStatement( + circuit, + [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, + circuit.PubWitnessEvaluations, + ) + + extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[1], + circuit.WitnessBlindingEvaluations[1], + } + } else { + // Use original arrays as before, no public inputs + extendedLinearStatementEvalsBatch[0] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[0], + circuit.WitnessBlindingEvaluations[0], + } + extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[1], + circuit.WitnessBlindingEvaluations[1], + } + } + whirFoldingRandomness, err = RunZKWhirBatch( api, arthur, uapi, sc, circuit.WitnessFirstRounds, // firstRounds []Merkle @@ -131,13 +156,10 @@ func (circuit *Circuit) Define(api frontend.API) error { [][][]frontend.Variable{initialOODAnswers1, initialOODAnswers2}, // initialOODAnswers []frontend.Variable{rootHash1, rootHash2}, // rootHashes circuit.WitnessMerkle, // batchedMerkle - [][][]frontend.Variable{ // linearStatementEvals - {circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, - {circuit.WitnessClaimedEvaluations[1], circuit.WitnessBlindingEvaluations[1]}, - }, - circuit.WHIRParamsWitness, // whirParams - circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints - circuit.PublicInputs, // publicInputs + extendedLinearStatementEvalsBatch, // linearStatementEvals (extended for first commitment) + circuit.WHIRParamsWitness, // whirParams + circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints + circuit.PublicInputs, // publicInputs ) if err != nil { return err diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index 4e367cb34..3d780dd7d 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -59,7 +59,7 @@ func RunZKWhir( firstRound Merkle, whirParams WHIRParams, linearStatementEvaluations [][]frontend.Variable, - linearStatementValuesAtPoints []frontend.Variable, // weights.evaluate(random_point) - this is what needs to be done + linearStatementValuesAtPoints []frontend.Variable, batchingRandomness frontend.Variable, initialOODQueries []frontend.Variable, initialOODAnswers [][]frontend.Variable, @@ -256,7 +256,15 @@ func RunZKWhirBatch( for i := 0; i < numPolynomials; i++ { numOOD += len(initialOODQueries[i]) } - numStatementConstraints := numPolynomials * 3 // 3 per commitment (Az, Bz, Cz) + + numStatementConstraints := 0 + + // w1 has 4 (pub, Az, Bz, Cz) constraints, w2 and remaining have 3 (Az, Bz, Cz) constraints + if !publicInputs.IsEmpty() { + numStatementConstraints = 4 + 3*(numPolynomials-1) + } else { + numStatementConstraints = numPolynomials * 3 + } numConstraints := numOOD + numStatementConstraints // Step 3: Read N×M evaluation matrix from transcript From 0c5e3be392597b46722f4abb76b41dbcc5733d8d Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 08:38:36 +0530 Subject: [PATCH 15/19] chore: cleanup --- recursive-verifier/app/circuit/circuit.go | 8 ++++--- recursive-verifier/app/circuit/types.go | 26 +++++++++++------------ recursive-verifier/app/circuit/whir.go | 2 +- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index 1262d76af..fcc1e2165 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -198,9 +198,11 @@ func (circuit *Circuit) Define(api frontend.API) error { x := api.Mul(api.Sub(api.Mul(az, bz), cz), calculateEQ(api, spartanSumcheckRand, tRand)) api.AssertIsEqual(spartanSumcheckLastValue, x) - // TODO : generalize it later on if we have more different kinds of statements - // for handling geometric weights statement added at starting - offset := 1 + offset := 0 + if !circuit.PublicInputs.IsEmpty() { + // can be generalized later on if we have more different kinds of statements + offset = 1 + } if circuit.NumChallenges > 0 { // Batch mode - check 6 deferred values diff --git a/recursive-verifier/app/circuit/types.go b/recursive-verifier/app/circuit/types.go index b456231a6..f6db49d00 100644 --- a/recursive-verifier/app/circuit/types.go +++ b/recursive-verifier/app/circuit/types.go @@ -91,18 +91,18 @@ type ProofObject struct { } type Config struct { - WHIRConfigWitness WHIRConfig `json:"whir_config_witness"` - WHIRConfigHidingSpartan WHIRConfig `json:"whir_config_hiding_spartan"` - LogNumConstraints int `json:"log_num_constraints"` - LogNumVariables int `json:"log_num_variables"` - LogANumTerms int `json:"log_a_num_terms"` - IOPattern string `json:"io_pattern"` - Transcript []byte `json:"transcript"` - TranscriptLen int `json:"transcript_len"` - WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` - BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` - NumChallenges int `json:"num_challenges"` - W1Size int `json:"w1_size"` + WHIRConfigWitness WHIRConfig `json:"whir_config_witness"` + WHIRConfigHidingSpartan WHIRConfig `json:"whir_config_hiding_spartan"` + LogNumConstraints int `json:"log_num_constraints"` + LogNumVariables int `json:"log_num_variables"` + LogANumTerms int `json:"log_a_num_terms"` + IOPattern string `json:"io_pattern"` + Transcript []byte `json:"transcript"` + TranscriptLen int `json:"transcript_len"` + WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` + BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` + NumChallenges int `json:"num_challenges"` + W1Size int `json:"w1_size"` PublicInputs PublicInputs `json:"public_inputs"` } @@ -158,4 +158,4 @@ func (p *PublicInputs) UnmarshalJSON(data []byte) error { func (p *PublicInputs) IsEmpty() bool { return len(p.Values) == 0 -} \ No newline at end of file +} diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index 3d780dd7d..3bdea8f19 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -256,7 +256,7 @@ func RunZKWhirBatch( for i := 0; i < numPolynomials; i++ { numOOD += len(initialOODQueries[i]) } - + numStatementConstraints := 0 // w1 has 4 (pub, Az, Bz, Cz) constraints, w2 and remaining have 3 (Az, Bz, Cz) constraints From f9776b55306cf9722178d2433eb9ae974941e180 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 08:58:05 +0530 Subject: [PATCH 16/19] chore: cargofmt --- provekit/common/src/witness/mod.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index a0ff25d1e..a5e9a8c4f 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -72,13 +72,10 @@ impl PublicInputs { 1 => { // For single element, hash it with zero to ensure it gets properly hashed let padded = vec![self.0[0], FieldElement::from(0u64)]; - SkyscraperCRH::evaluate(&(), &padded[..]) - .expect("hash should succeed") - } - _ => { - SkyscraperCRH::evaluate(&(), &self.0[..]) - .expect("hash should succeed for multiple inputs") + SkyscraperCRH::evaluate(&(), &padded[..]).expect("hash should succeed") } + _ => SkyscraperCRH::evaluate(&(), &self.0[..]) + .expect("hash should succeed for multiple inputs"), } } } From 3c961eb417f4b8b3890c605b744ae7e2a6fd0040 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 22 Jan 2026 20:54:15 +0530 Subject: [PATCH 17/19] fix: cleanup reviews --- provekit/common/Cargo.toml | 1 - provekit/common/src/witness/mod.rs | 2 +- provekit/common/src/witness/scheduling/mod.rs | 6 +- .../common/src/witness/scheduling/splitter.rs | 33 ++++++--- .../common/src/witness/witness_builder.rs | 16 ++-- provekit/prover/src/whir_r1cs.rs | 4 +- .../r1cs-compiler/src/noir_proof_scheme.rs | 2 +- recursive-verifier/app/circuit/circuit.go | 6 +- recursive-verifier/app/utilities/utilities.go | 73 ------------------- 9 files changed, 41 insertions(+), 102 deletions(-) diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index 34b43012e..92faae9c6 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -38,7 +38,6 @@ ruint.workspace = true serde.workspace = true serde_json.workspace = true tracing.workspace = true -sha2.workspace = true zerocopy.workspace = true zeroize.workspace = true zstd.workspace = true diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index a5e9a8c4f..183c116ff 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -20,7 +20,7 @@ pub use { binops::{BINOP_ATOMIC_BITS, BINOP_BITS, NUM_DIGITS}, digits::{decompose_into_digits, DigitalDecompositionWitnesses}, ram::{SpiceMemoryOperation, SpiceWitnesses}, - scheduling::{Layer, LayerType, LayeredWitnessBuilders, SplitWitnessBuilders}, + scheduling::{Layer, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders}, witness_builder::{ ConstantTerm, ProductLinearTerm, SumTerm, WitnessBuilder, WitnessCoefficient, }, diff --git a/provekit/common/src/witness/scheduling/mod.rs b/provekit/common/src/witness/scheduling/mod.rs index f2529253e..28d80c11b 100644 --- a/provekit/common/src/witness/scheduling/mod.rs +++ b/provekit/common/src/witness/scheduling/mod.rs @@ -9,8 +9,10 @@ mod scheduler; mod splitter; pub use { - dependency::DependencyInfo, remapper::WitnessIndexRemapper, scheduler::LayerScheduler, - splitter::WitnessSplitter, + dependency::DependencyInfo, + remapper::WitnessIndexRemapper, + scheduler::LayerScheduler, + splitter::{SplitError, WitnessSplitter}, }; /// Type of operations contained in a layer. diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index 94ffffe97..dc5107a6d 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -1,8 +1,23 @@ use { crate::witness::{scheduling::DependencyInfo, WitnessBuilder}, - std::collections::{HashSet, VecDeque}, + std::{ + collections::{HashSet, VecDeque}, + fmt, + }, }; +/// Error returned when witness splitting validation fails. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SplitError; + +impl fmt::Display for SplitError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "error in splitting witnesses into w1 and w2") + } +} + +impl std::error::Error for SplitError {} + /// Analyzes witness builder dependencies and splits them into w1/w2 groups. /// /// Uses backward reachability from challenge consumers (lookup builders) to @@ -29,7 +44,7 @@ impl<'a> WitnessSplitter<'a> { pub fn split_builders( &self, acir_public_inputs_indices_set: HashSet, - ) -> (Vec, Vec) { + ) -> Result<(Vec, Vec), SplitError> { let builder_count = self.witness_builders.len(); // Step 1: Find all Challenge builders @@ -46,8 +61,8 @@ impl<'a> WitnessSplitter<'a> { let w1_indices = self.rearrange_w1( (0..builder_count).collect(), &acir_public_inputs_indices_set, - ); - return (w1_indices, Vec::new()); + )?; + return Ok((w1_indices, Vec::new())); } // Step 2: Forward DFS from challenges to find mandatory_w2 @@ -182,10 +197,10 @@ impl<'a> WitnessSplitter<'a> { let mut w1_indices: Vec = w1_set.into_iter().collect(); let mut w2_indices: Vec = w2_set.into_iter().collect(); - w1_indices = self.rearrange_w1(w1_indices, &acir_public_inputs_indices_set); + w1_indices = self.rearrange_w1(w1_indices, &acir_public_inputs_indices_set)?; w2_indices.sort_unstable(); - (w1_indices, w2_indices) + Ok((w1_indices, w2_indices)) } /// Rearranges w1 builder indices into a canonical order: @@ -196,7 +211,7 @@ impl<'a> WitnessSplitter<'a> { &self, w1_indices: Vec, acir_public_inputs_indices_set: &HashSet, - ) -> Vec { + ) -> Result, SplitError> { let mut public_input_builder_indices = Vec::new(); let mut rest_indices = Vec::new(); @@ -206,7 +221,7 @@ impl<'a> WitnessSplitter<'a> { let w1_indices_set = w1_indices.iter().copied().collect::>(); for &idx in acir_public_inputs_indices_set.iter() { if !w1_indices_set.contains(&(idx as usize)) { - panic!("Public input {} is not in w1_indices", idx); + return Err(SplitError); } } @@ -229,6 +244,6 @@ impl<'a> WitnessSplitter<'a> { let mut new_w1_indices = vec![0]; new_w1_indices.extend(public_input_builder_indices); new_w1_indices.extend(rest_indices); - new_w1_indices + Ok(new_w1_indices) } } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 9567eac7a..601321a70 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -6,8 +6,8 @@ use { digits::DigitalDecompositionWitnesses, ram::SpiceWitnesses, scheduling::{ - LayerScheduler, LayeredWitnessBuilders, SplitWitnessBuilders, WitnessIndexRemapper, - WitnessSplitter, + LayerScheduler, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders, + WitnessIndexRemapper, WitnessSplitter, }, ConstantOrR1CSWitness, }, @@ -175,9 +175,9 @@ impl WitnessBuilder { r1cs: R1CS, witness_map: Vec>, acir_public_inputs_indices_set: HashSet, - ) -> (SplitWitnessBuilders, R1CS, Vec>, usize) { + ) -> Result<(SplitWitnessBuilders, R1CS, Vec>, usize), SplitError> { if witness_builders.is_empty() { - return ( + return Ok(( SplitWitnessBuilders { w1_layers: LayeredWitnessBuilders { layers: Vec::new() }, w2_layers: LayeredWitnessBuilders { layers: Vec::new() }, @@ -186,12 +186,12 @@ impl WitnessBuilder { r1cs, witness_map, 0, - ); + )); } // Step 1: Analyze dependencies and split into w1/w2 let splitter = WitnessSplitter::new(witness_builders); - let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set); + let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set)?; // Step 2: Extract w1 and w2 builders in order let w1_builders: Vec = w1_indices @@ -245,7 +245,7 @@ impl WitnessBuilder { .filter(|b| matches!(b, WitnessBuilder::Challenge(_))) .count(); - ( + Ok(( SplitWitnessBuilders { w1_layers, w2_layers, @@ -254,6 +254,6 @@ impl WitnessBuilder { remapped_r1cs, remapped_witness_map, num_challenges, - ) + )) } } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index fadfa0bda..06fe0fdd6 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -155,6 +155,7 @@ impl WhirR1CSProver for WhirR1CSScheme { // Compute weights from R1CS matrices let alphas = calculate_external_row_of_r1cs_matrices(alpha, r1cs); + let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); if is_single { // Single commitment path @@ -171,8 +172,6 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums, g_sums))?; - // VERIFY the size given by self.m - let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { // If there are no public inputs, the hint is unused by the verifier and can be // assigned an arbitrary value. @@ -238,7 +237,6 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums_1, g_sums_1))?; merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; - let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { let public_f_sum = FieldElement::zero(); let public_g_sum = FieldElement::zero(); diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 0cde2f1bb..40df4bfe3 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -72,7 +72,7 @@ impl NoirProofSchemeBuilder for NoirProofScheme { r1cs, witness_map, acir_public_inputs_indices_set, - ); + )?; info!( "Witness split: w1 size = {}, w2 size = {}", split_witness_builders.w1_size, diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index fcc1e2165..dc34ece57 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -170,7 +170,6 @@ func (circuit *Circuit) Define(api frontend.API) error { bz = api.Add(circuit.WitnessClaimedEvaluations[0][1], circuit.WitnessClaimedEvaluations[1][1]) cz = api.Add(circuit.WitnessClaimedEvaluations[0][2], circuit.WitnessClaimedEvaluations[1][2]) } else { - log.Println("Single Mode") extendedLinearStatementEvals := extendLinearStatement(circuit, [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, circuit.PubWitnessEvaluations) // Single commitment mode @@ -219,7 +218,7 @@ func (circuit *Circuit) Define(api frontend.API) error { } } - // Geomteric weights for public inputs + // Geometric weights for public inputs if !circuit.PublicInputs.IsEmpty() { publicWeightEval := computePublicWeightEvaluation( api, circuit.PublicInputs, whirFoldingRandomness, @@ -390,7 +389,6 @@ func verifyCircuit( Values: make([]frontend.Variable, len(publicInputs.Values)), } - log.Println("publicInputs", publicInputs) circuit := Circuit{ IO: []byte(cfg.IOPattern), @@ -534,7 +532,7 @@ func verifyCircuit( witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) publicWitness, err := witness.Public() if err != nil { - log.Printf("Failed witess,Public(): %v", err) + log.Printf("Failed witness, Public(): %v", err) return err } diff --git a/recursive-verifier/app/utilities/utilities.go b/recursive-verifier/app/utilities/utilities.go index 7222133bb..691af27dd 100644 --- a/recursive-verifier/app/utilities/utilities.go +++ b/recursive-verifier/app/utilities/utilities.go @@ -1,8 +1,6 @@ package utilities import ( - "crypto/sha256" - "encoding/binary" "encoding/hex" "encoding/json" "fmt" @@ -62,77 +60,6 @@ func IndexOf(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { return nil } -// HashPublicInputsHint is a hint function that computes SHA-256 hash of public inputs -// matching the Rust PublicInputs::hash() implementation. -// It takes public input values, converts them to BigInt, extracts limbs, hashes them, -// and returns the hash as a field element. -func HashPublicInputsHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { - if len(outputs) != 1 { - return fmt.Errorf("expecting one output") - } - - if len(inputs) == 0 { - outputs[0] = big.NewInt(0) - return nil - } - - hasher := sha256.New() - - // Process each public input value - for _, input := range inputs { - // Convert field element to BigInt (it's already a BigInt, but ensure it's in range) - value := new(big.Int).Set(input) - - // Extract limbs (u64 values) from BigInt - // Field elements are represented as 4 u64 limbs in little-endian - limbs := make([]uint64, 4) - temp := new(big.Int).Set(value) - limbs[0] = temp.Uint64() // Least significant limb - temp.Rsh(temp, 64) - limbs[1] = temp.Uint64() - temp.Rsh(temp, 64) - limbs[2] = temp.Uint64() - temp.Rsh(temp, 64) - limbs[3] = temp.Uint64() // Most significant limb - - // Hash each limb as little-endian bytes (8 bytes per limb) - for _, limb := range limbs { - limbBytes := make([]byte, 8) - binary.LittleEndian.PutUint64(limbBytes, limb) - hasher.Write(limbBytes) - } - } - - // Get the hash result (32 bytes) - hashResult := hasher.Sum(nil) - - // Convert hash result to field element by splitting into 4 u64 limbs - // Each chunk of 8 bytes becomes a u64 (little-endian) - limbs := make([]uint64, 4) - for i := 0; i < 4; i++ { - start := i * 8 - end := start + 8 - limbs[i] = binary.LittleEndian.Uint64(hashResult[start:end]) - } - - // Reconstruct field element from limbs - result := new(big.Int).SetUint64(limbs[0]) - temp := new(big.Int).SetUint64(limbs[1]) - result.Add(result, temp.Lsh(temp, 64)) - temp.SetUint64(limbs[2]) - result.Add(result, temp.Lsh(temp, 128)) - temp.SetUint64(limbs[3]) - result.Add(result, temp.Lsh(temp, 192)) - - // Apply modulus to ensure result is in field range - modulus := new(big.Int) - modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) - result.Mod(result, modulus) - - outputs[0] = result - return nil -} - func Reverse[T any](s []T) []T { res := make([]T, len(s)) copy(res, s) From 65cd065645213f178b4615a8eac8ae07ae74679a Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 22 Jan 2026 21:50:00 +0530 Subject: [PATCH 18/19] fix: iotranscript --- provekit/common/src/whir_r1cs.rs | 7 ++-- .../r1cs-compiler/src/noir_proof_scheme.rs | 2 + provekit/r1cs-compiler/src/whir_r1cs.rs | 5 ++- provekit/verifier/src/whir_r1cs.rs | 39 +++++++------------ 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 702a25c74..dc6ed3615 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -22,6 +22,7 @@ pub struct WhirR1CSScheme { pub m_0: usize, pub a_num_terms: usize, pub num_challenges: usize, + pub has_public_inputs: bool, pub whir_witness: WhirConfig, pub whir_for_hiding_spartan: WhirConfig, } @@ -38,7 +39,7 @@ impl WhirR1CSScheme { // statement_2 has 3 constraints = 3, total = 7 let num_witnesses = 2; let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; - let num_statement_constraints = 7; + let num_statement_constraints = if self.has_public_inputs { 7 } else { 6 }; let num_constraints_total = num_ood_constraints + num_statement_constraints; io = io @@ -49,9 +50,9 @@ impl WhirR1CSScheme { .commit_statement(&self.whir_for_hiding_spartan) .add_zk_sumcheck_polynomials(self.m_0) .add_whir_proof(&self.whir_for_hiding_spartan) + .add_public_inputs() .hint("claimed_evaluations_1") .hint("claimed_evaluations_2") - .add_public_inputs() .hint("public_weights_evaluations") .add_whir_batch_proof(&self.whir_witness, num_witnesses, num_constraints_total); } else { @@ -61,8 +62,8 @@ impl WhirR1CSScheme { .commit_statement(&self.whir_for_hiding_spartan) .add_zk_sumcheck_polynomials(self.m_0) .add_whir_proof(&self.whir_for_hiding_spartan) - .hint("claimed_evaluations") .add_public_inputs() + .hint("claimed_evaluations") .hint("public_weights_evaluations") .add_whir_proof(&self.whir_witness); } diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 40df4bfe3..b89a87847 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -65,6 +65,7 @@ impl NoirProofSchemeBuilder for NoirProofScheme { let acir_public_inputs_indices_set: HashSet = main.public_inputs().indices().iter().cloned().collect(); + let has_public_inputs = !acir_public_inputs_indices_set.is_empty(); // Split witness builders and remap indices for sound challenge generation let (split_witness_builders, remapped_r1cs, remapped_witness_map, num_challenges) = WitnessBuilder::split_and_prepare_layers( @@ -91,6 +92,7 @@ impl NoirProofSchemeBuilder for NoirProofScheme { &remapped_r1cs, split_witness_builders.w1_size, num_challenges, + has_public_inputs, ); Ok(Self { diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs b/provekit/r1cs-compiler/src/whir_r1cs.rs index 4604be4a7..a1f1b3b91 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -17,13 +17,13 @@ const MIN_WHIR_NUM_VARIABLES: usize = 12; const MIN_SUMCHECK_NUM_VARIABLES: usize = 1; pub trait WhirR1CSSchemeBuilder { - fn new_for_r1cs(r1cs: &R1CS, w1_size: usize, num_challenges: usize) -> Self; + fn new_for_r1cs(r1cs: &R1CS, w1_size: usize, num_challenges: usize, has_public_inputs: bool) -> Self; fn new_whir_config_for_size(num_variables: usize, batch_size: usize) -> WhirConfig; } impl WhirR1CSSchemeBuilder for WhirR1CSScheme { - fn new_for_r1cs(r1cs: &R1CS, w1_size: usize, num_challenges: usize) -> Self { + fn new_for_r1cs(r1cs: &R1CS, w1_size: usize, num_challenges: usize, has_public_inputs: bool) -> Self { let total_witnesses = r1cs.num_witnesses(); assert!( w1_size <= total_witnesses, @@ -49,6 +49,7 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { next_power_of_two(4 * m_0) + 1, 2, ), + has_public_inputs, } } diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 17e102fb9..068fd1d0a 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -56,6 +56,19 @@ impl WhirR1CSVerifier for WhirR1CSScheme { run_sumcheck_verifier(&mut arthur, self.m_0, &self.whir_for_hiding_spartan) .context("while verifying sumcheck")?; + // Verify public inputs hash + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + // Read hints and verify WHIR proof let (az_at_alpha, bz_at_alpha, cz_at_alpha) = if let Some(parsed_commitment_2) = parsed_commitment_2 { @@ -79,19 +92,6 @@ impl WhirR1CSVerifier for WhirR1CSScheme { &whir_sums_2, ); - let mut public_inputs_hash_buf = [FieldElement::zero()]; - arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; - let expected_public_inputs_hash = public_inputs.hash(); - ensure!( - public_inputs_hash_buf[0] == expected_public_inputs_hash, - "Public inputs hash mismatch: expected {:?}, got {:?}", - expected_public_inputs_hash, - public_inputs_hash_buf[0] - ); - - let mut public_weights_vector_random_buf = [FieldElement::zero()]; - arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur .hint() .context("failed to read WHIR public weights query answer")?; @@ -130,19 +130,6 @@ impl WhirR1CSVerifier for WhirR1CSScheme { &whir_sums, ); - let mut public_inputs_hash_buf = [FieldElement::zero()]; - arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; - let expected_public_inputs_hash = public_inputs.hash(); - ensure!( - public_inputs_hash_buf[0] == expected_public_inputs_hash, - "Public inputs hash mismatch: expected {:?}, got {:?}", - expected_public_inputs_hash, - public_inputs_hash_buf[0] - ); - - let mut public_weights_vector_random_buf = [FieldElement::zero()]; - arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur .hint() .context("failed to read WHIR public weights query answer")?; From 278a7ff7e717b0f555a4a825cd1da2b7f392310c Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 22 Jan 2026 22:33:21 +0530 Subject: [PATCH 19/19] chore: fmt --- provekit/r1cs-compiler/src/whir_r1cs.rs | 14 ++++++++++++-- recursive-verifier/app/circuit/circuit.go | 1 - 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs b/provekit/r1cs-compiler/src/whir_r1cs.rs index a1f1b3b91..385b11c98 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -17,13 +17,23 @@ const MIN_WHIR_NUM_VARIABLES: usize = 12; const MIN_SUMCHECK_NUM_VARIABLES: usize = 1; pub trait WhirR1CSSchemeBuilder { - fn new_for_r1cs(r1cs: &R1CS, w1_size: usize, num_challenges: usize, has_public_inputs: bool) -> Self; + fn new_for_r1cs( + r1cs: &R1CS, + w1_size: usize, + num_challenges: usize, + has_public_inputs: bool, + ) -> Self; fn new_whir_config_for_size(num_variables: usize, batch_size: usize) -> WhirConfig; } impl WhirR1CSSchemeBuilder for WhirR1CSScheme { - fn new_for_r1cs(r1cs: &R1CS, w1_size: usize, num_challenges: usize, has_public_inputs: bool) -> Self { + fn new_for_r1cs( + r1cs: &R1CS, + w1_size: usize, + num_challenges: usize, + has_public_inputs: bool, + ) -> Self { let total_witnesses = r1cs.num_witnesses(); assert!( w1_size <= total_witnesses, diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index dc34ece57..c543d20ed 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -389,7 +389,6 @@ func verifyCircuit( Values: make([]frontend.Variable, len(publicInputs.Values)), } - circuit := Circuit{ IO: []byte(cfg.IOPattern), Transcript: contTranscript,