diff --git a/.github/workflows/end-to-end.yml b/.github/workflows/end-to-end.yml index a560f7341..a3569682d 100644 --- a/.github/workflows/end-to-end.yml +++ b/.github/workflows/end-to-end.yml @@ -57,48 +57,64 @@ jobs: - name: Clean stale benchmark artifacts working-directory: noir-examples/noir-passport/merkle_age_check run: | - rm -f ./benchmark-inputs/*.pkp ./benchmark-inputs/*.pkv ./benchmark-inputs/*.np + rm -f ./benchmark-inputs/*.pkp ./benchmark-inputs/*.pkv ./benchmark-inputs/*.np ./benchmark-inputs/*.sp echo "Cleaned stale benchmark artifacts" - - name: Prepare circuits + - name: Prepare circuits working-directory: noir-examples/noir-passport/merkle_age_check run: | for circuit in t_add_dsc_720 t_add_id_data_720 t_add_integrity_commit t_attest; do echo "Preparing $circuit" - cargo run --release --bin provekit-cli prepare ./target/$circuit.json \ - --pkp ./benchmark-inputs/$circuit-prover.pkp \ - --pkv ./benchmark-inputs/$circuit-verifier.pkv + ../../../target/release/provekit-cli prepare \ + ./target/$circuit.json \ + --pkp ./benchmark-inputs/$circuit.pkp \ + --pkv ./benchmark-inputs/$circuit.pkv \ + --spark \ + --spc ./benchmark-inputs/$circuit.spc echo "Prepared $circuit" done - - name: Generate proofs for all circuits + - name: Prove all circuits working-directory: noir-examples/noir-passport/merkle_age_check run: | for circuit in t_add_dsc_720 t_add_id_data_720 t_add_integrity_commit t_attest; do echo "Proving $circuit" - cargo run --release --bin provekit-cli prove \ - ./benchmark-inputs/$circuit-prover.pkp \ + ../../../target/release/provekit-cli prove \ + ./benchmark-inputs/$circuit.pkp \ ./benchmark-inputs/tbs_720/$circuit.toml \ - -o ./benchmark-inputs/$circuit-proof.np + -o ./benchmark-inputs/$circuit-proof.np \ + --spark-queries-dir ./benchmark-inputs/$circuit-spark echo "Proved $circuit" done - - name: Verify proofs for all circuits + - name: Generate SPARK proofs for all circuits working-directory: noir-examples/noir-passport/merkle_age_check run: | for circuit in t_add_dsc_720 t_add_id_data_720 t_add_integrity_commit t_attest; do - echo "Verifying $circuit" - cargo run --release --bin provekit-cli verify \ - ./benchmark-inputs/$circuit-verifier.pkv \ - ./benchmark-inputs/$circuit-proof.np - echo "Verified $circuit" + echo "SPARK proving $circuit" + ../../../target/release/provekit-cli prove-spark \ + ./benchmark-inputs/$circuit.pkp \ + --spark-dir ./benchmark-inputs/$circuit-spark + echo "SPARK proved $circuit" + done + + - name: Verify SPARK proofs for all circuits + working-directory: noir-examples/noir-passport/merkle_age_check + run: | + for circuit in t_add_dsc_720 t_add_id_data_720 t_add_integrity_commit t_attest; do + echo "SPARK verifying $circuit" + ../../../target/release/provekit-cli verify-spark \ + ./benchmark-inputs/$circuit-spark/spark_proof_0.sp \ + ./benchmark-inputs/$circuit.spc \ + ./benchmark-inputs/$circuit-spark/spark_query_0.json + echo "SPARK verified $circuit" done - name: Generate Gnark inputs working-directory: noir-examples/noir-passport/merkle_age_check run: | - cargo run --release --bin provekit-cli generate-gnark-inputs \ - ./benchmark-inputs/t_attest-verifier.pkv \ + ../../../target/release/provekit-cli generate-gnark-inputs \ + ./benchmark-inputs/t_attest.pkv \ ./benchmark-inputs/t_attest-proof.np diff --git a/.gitignore b/.gitignore index b3a38e13b..1a1ec8e41 100644 --- a/.gitignore +++ b/.gitignore @@ -14,11 +14,15 @@ *.pkp *.pkv *.np +*.sp +*.spc +spark_proofs/ params_for_recursive_verifier params artifacts/ spartan_vm_debug/ 
mavros_debug/ +mavros/ # Don't ignore benchmarking artifacts !tooling/provekit-bench/benches/* diff --git a/Cargo.lock b/Cargo.lock index 92ae8c401..1c957ccaa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4597,15 +4597,19 @@ dependencies = [ "argh", "ark-ff 0.5.0", "base64", + "bincode", "hex", + "mavros-artifacts", "noirc_abi", "postcard", "provekit-common", "provekit-gnark", "provekit-prover", "provekit-r1cs-compiler", + "provekit-spark", "provekit-verifier", "rayon", + "serde", "serde_json", "tikv-jemallocator", "tracing", @@ -4725,6 +4729,20 @@ dependencies = [ "whir", ] +[[package]] +name = "provekit-spark" +version = "0.1.0" +dependencies = [ + "anyhow", + "ark-ff 0.5.0", + "ark-std 0.5.0", + "provekit-common", + "rayon", + "serde", + "tracing", + "whir", +] + [[package]] name = "provekit-verifier" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 73d5ac541..de06d003b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "tooling/provekit-wasm", "tooling/verifier-server", "ntt", + "provekit/spark", "poseidon2", "playground/passport-input-gen", ] @@ -102,6 +103,7 @@ provekit-ffi = { path = "tooling/provekit-ffi" } provekit-gnark = { path = "tooling/provekit-gnark" } provekit-prover = { path = "provekit/prover", default-features = false } provekit-r1cs-compiler = { path = "provekit/r1cs-compiler" } +provekit-spark = { path = "provekit/spark" } provekit-verifier = { path = "provekit/verifier" } provekit-verifier-server = { path = "tooling/verifier-server" } provekit-wasm = { path = "tooling/provekit-wasm" } diff --git a/noir-examples/power/Nargo.toml b/noir-examples/power/Nargo.toml index 839ecb852..44082b657 100644 --- a/noir-examples/power/Nargo.toml +++ b/noir-examples/power/Nargo.toml @@ -1,5 +1,5 @@ [package] -name = "basic" +name = "power" type = "bin" authors = [""] compiler_version = ">=0.22.0" diff --git a/noir-examples/power/benchmark-inputs/power.spc b/noir-examples/power/benchmark-inputs/power.spc new file mode 100644 index 000000000..21d3c3a16 Binary files /dev/null and b/noir-examples/power/benchmark-inputs/power.spc differ diff --git a/noir-examples/power/src/main.nr b/noir-examples/power/src/main.nr index 67f8f4ead..eee84870f 100644 --- a/noir-examples/power/src/main.nr +++ b/noir-examples/power/src/main.nr @@ -1,6 +1,6 @@ fn main(mut x: Field, y: pub Field) { let mut r = 1; - for i in 0..10 { + for _ in 0..1000 { r *= x; } assert(r == y); diff --git a/playground/passport-input-gen/src/bin/passport_cli/main.rs b/playground/passport-input-gen/src/bin/passport_cli/main.rs index 0a21bb4fa..6820b0380 100644 --- a/playground/passport-input-gen/src/bin/passport_cli/main.rs +++ b/playground/passport-input-gen/src/bin/passport_cli/main.rs @@ -267,7 +267,7 @@ fn prove_circuit( .map_err(|e| anyhow::anyhow!("ABI parse error for {circuit_name}: {e}"))?; tee_println!(" [{circuit_name}] Generating proof..."); - let proof = prover + let (proof, _) = prover .prove(input_map) .with_context(|| format!("Proving {circuit_name}"))?; diff --git a/provekit/common/src/file/binary_format.rs b/provekit/common/src/file/binary_format.rs index 44ff55717..2aa22d7be 100644 --- a/provekit/common/src/file/binary_format.rs +++ b/provekit/common/src/file/binary_format.rs @@ -24,4 +24,10 @@ pub const NOIR_PROOF_SCHEME_FORMAT: [u8; 8] = *b"NrProScm"; pub const NOIR_PROOF_SCHEME_VERSION: (u16, u16) = (1, 2); pub const NOIR_PROOF_FORMAT: [u8; 8] = *b"NPSProof"; -pub const NOIR_PROOF_VERSION: (u16, u16) = (1, 1); +pub const NOIR_PROOF_VERSION: (u16, u16) = (1, 2); + +pub const 
SPARK_PROOF_FORMAT: [u8; 8] = *b"SparkPrf";
+pub const SPARK_PROOF_VERSION: (u16, u16) = (1, 0);
+
+pub const SPARK_SETUP_FORMAT: [u8; 8] = *b"SparkStp";
+pub const SPARK_SETUP_VERSION: (u16, u16) = (1, 0);
diff --git a/provekit/common/src/file/io/mod.rs b/provekit/common/src/file/io/mod.rs
index 049c984a7..964a8f17c 100644
--- a/provekit/common/src/file/io/mod.rs
+++ b/provekit/common/src/file/io/mod.rs
@@ -3,11 +3,12 @@
 mod buf_ext;
 mod counting_writer;
 mod json;
+pub use self::bin::Compression;
 use {
     self::{
         bin::{
             deserialize_from_bytes, read_bin, read_hash_config as read_hash_config_bin,
-            serialize_to_bytes, write_bin, Compression,
+            serialize_to_bytes, write_bin,
         },
         buf_ext::BufExt,
         counting_writer::CountingWriter,
@@ -29,7 +30,7 @@ pub trait FileFormat: Serialize + for<'a> Deserialize<'a> {
 }
 
 /// Helper trait to optionally extract hash config.
-pub(crate) trait MaybeHashAware {
+pub trait MaybeHashAware {
     fn maybe_hash_config(&self) -> Option<HashConfig>;
 }
diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs
index 3953207d8..add0fba84 100644
--- a/provekit/common/src/lib.rs
+++ b/provekit/common/src/lib.rs
@@ -1,7 +1,7 @@
 pub mod file;
 pub use file::binary_format;
 pub mod hash_config;
-mod interner;
+pub mod interner;
 mod mavros;
 mod noir_proof_scheme;
 pub mod ntt;
@@ -11,6 +11,7 @@
 pub mod prefix_covector;
 mod prover;
 mod r1cs;
 pub mod skyscraper;
+pub mod spark;
 pub mod sparse_matrix;
 mod transcript_sponge;
 pub mod u256_arith;
diff --git a/provekit/common/src/spark.rs b/provekit/common/src/spark.rs
new file mode 100644
index 000000000..0f25e714e
--- /dev/null
+++ b/provekit/common/src/spark.rs
@@ -0,0 +1,29 @@
+use {
+    crate::{utils::serde_ark, FieldElement},
+    serde::{Deserialize, Serialize},
+    sha3::{Digest, Sha3_256},
+};
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct Point {
+    #[serde(with = "serde_ark")]
+    pub row: Vec<FieldElement>,
+    #[serde(with = "serde_ark")]
+    pub col: Vec<FieldElement>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct R1CSSparkQuery {
+    pub point_to_evaluate: Point,
+    #[serde(with = "serde_ark")]
+    pub matrix_batching_randomness: FieldElement,
+    #[serde(with = "serde_ark")]
+    pub claimed_value: FieldElement,
+}
+
+impl R1CSSparkQuery {
+    pub fn hash_bytes(&self) -> [u8; 32] {
+        let bytes = postcard::to_allocvec(self).expect("serializing R1CSSparkQuery");
+        Sha3_256::digest(&bytes).into()
+    }
+}
diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs
index d703d8cd4..859d1d54e 100644
--- a/provekit/common/src/utils/sumcheck.rs
+++ b/provekit/common/src/utils/sumcheck.rs
@@ -152,6 +152,11 @@ fn eval_eq(
     }
 }
 
+/// Evaluates a quadratic polynomial on a value
+pub fn eval_quadratic_poly(poly: [FieldElement; 3], point: FieldElement) -> FieldElement {
+    poly[0] + point * (poly[1] + point * poly[2])
+}
+
 /// Evaluates a cubic polynomial on a value
 pub fn eval_cubic_poly(poly: [FieldElement; 4], point: FieldElement) -> FieldElement {
     poly[0] + point * (poly[1] + point * (poly[2] + point * poly[3]))
 }
diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs
index 28af714d9..87eb3a91f 100644
--- a/provekit/prover/src/lib.rs
+++ b/provekit/prover/src/lib.rs
@@ -8,8 +8,8 @@
 use {
     acir::native_types::{Witness, WitnessMap},
     anyhow::{Context, Result},
     provekit_common::{
-        utils::noir_to_native, FieldElement, NoirElement, NoirProof, NoirProver, Prover,
-        PublicInputs, TranscriptSponge,
+        spark::R1CSSparkQuery, utils::noir_to_native, FieldElement, NoirElement, NoirProof,
+        NoirProver,
Prover, PublicInputs, TranscriptSponge, }, std::mem::size_of, tracing::{debug, info_span, instrument}, @@ -40,14 +40,24 @@ pub use {ec_arith::ec_scalar_mul, r1cs::solve_witness_vec}; /// `prove` and `prove_with_toml` are native-only (cfg-gated out on wasm32). /// `prove_with_witness` is available on all targets. `MavrosProver` does not /// support `prove_with_witness` (errors at runtime). +/// +/// All methods return the `NoirProof` plus a `Vec` of SPARK +/// queries produced as a side output. Callers that don't need the queries +/// can discard with `let (proof, _) = ...`. pub trait Prove { #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] - fn prove(self, input_map: InputMap) -> Result; + fn prove(self, input_map: InputMap) -> Result<(NoirProof, Vec)>; #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] - fn prove_with_toml(self, prover_toml: impl AsRef) -> Result; + fn prove_with_toml( + self, + prover_toml: impl AsRef, + ) -> Result<(NoirProof, Vec)>; - fn prove_with_witness(self, witness: WitnessMap) -> Result; + fn prove_with_witness( + self, + witness: WitnessMap, + ) -> Result<(NoirProof, Vec)>; } #[instrument(skip_all)] @@ -85,14 +95,17 @@ fn generate_noir_witness( impl Prove for NoirProver { #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] #[instrument(skip_all)] - fn prove(mut self, input_map: InputMap) -> Result { + fn prove(mut self, input_map: InputMap) -> Result<(NoirProof, Vec)> { let witness = generate_noir_witness(&mut self, input_map)?; self.prove_with_witness(witness) } #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] #[instrument(skip_all)] - fn prove_with_toml(self, prover_toml: impl AsRef) -> Result { + fn prove_with_toml( + self, + prover_toml: impl AsRef, + ) -> Result<(NoirProof, Vec)> { let (input_map, _return_value) = read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?; self.prove(input_map) @@ -102,7 +115,7 @@ impl Prove for NoirProver { fn prove_with_witness( self, acir_witness_idx_to_value_map: WitnessMap, - ) -> Result { + ) -> Result<(NoirProof, Vec)> { provekit_common::register_ntt(); let mut public_input_indices = self.program.functions[0].public_inputs().indices(); @@ -255,22 +268,25 @@ impl Prove for NoirProver { .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) .collect::>>()?; - let whir_r1cs_proof = self + let (whir_r1cs_proof, r1cs_spark_queries) = self .whir_for_witness .prove_noir(merlin, r1cs, commitments, full_witness, &public_inputs) .context("While proving R1CS instance")?; - Ok(NoirProof { - public_inputs, - whir_r1cs_proof, - }) + Ok(( + NoirProof { + public_inputs, + whir_r1cs_proof, + }, + r1cs_spark_queries, + )) } } #[cfg(not(target_arch = "wasm32"))] impl Prove for MavrosProver { #[cfg(feature = "witness-generation")] - fn prove(mut self, input_map: InputMap) -> Result { + fn prove(mut self, input_map: InputMap) -> Result<(NoirProof, Vec)> { provekit_common::register_ntt(); let params = crate::input_utils::ordered_params_from_btreemap(&self.abi, &input_map)?; @@ -341,7 +357,7 @@ impl Prove for MavrosProver { vec![commitment_1] }; - let whir_r1cs_proof = self + let (whir_r1cs_proof, r1cs_spark_queries) = self .whir_for_witness .prove_mavros( merlin, @@ -354,15 +370,21 @@ impl Prove for MavrosProver { ) .context("While proving R1CS instance")?; - Ok(NoirProof { - public_inputs, - whir_r1cs_proof, - }) + Ok(( + NoirProof { + public_inputs, + whir_r1cs_proof, + }, + r1cs_spark_queries, + )) } 
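+
+    /// Reads the Noir input map from `Prover.toml`, using the file's parent
+    /// directory as the project root, then forwards to [`Self::prove`]; like
+    /// `prove`, it returns the proof together with the deferred SPARK
+    /// queries.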
#[cfg(feature = "witness-generation")] #[instrument(skip_all)] - fn prove_with_toml(self, prover_toml: impl AsRef) -> Result { + fn prove_with_toml( + self, + prover_toml: impl AsRef, + ) -> Result<(NoirProof, Vec)> { let project_path = prover_toml .as_ref() .parent() @@ -373,7 +395,10 @@ impl Prove for MavrosProver { self.prove(input_map) } - fn prove_with_witness(self, _witness: WitnessMap) -> Result { + fn prove_with_witness( + self, + _witness: WitnessMap, + ) -> Result<(NoirProof, Vec)> { Err(anyhow::anyhow!( "prove_with_witness is not supported for Mavros prover" )) @@ -382,7 +407,7 @@ impl Prove for MavrosProver { impl Prove for Prover { #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] - fn prove(self, input_map: InputMap) -> Result { + fn prove(self, input_map: InputMap) -> Result<(NoirProof, Vec)> { match self { Prover::Noir(p) => p.prove(input_map), Prover::Mavros(p) => p.prove(input_map), @@ -390,14 +415,20 @@ impl Prove for Prover { } #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] - fn prove_with_toml(self, prover_toml: impl AsRef) -> Result { + fn prove_with_toml( + self, + prover_toml: impl AsRef, + ) -> Result<(NoirProof, Vec)> { match self { Prover::Noir(p) => p.prove_with_toml(prover_toml), Prover::Mavros(p) => p.prove_with_toml(prover_toml), } } - fn prove_with_witness(self, witness: WitnessMap) -> Result { + fn prove_with_witness( + self, + witness: WitnessMap, + ) -> Result<(NoirProof, Vec)> { match self { Prover::Noir(p) => p.prove_with_witness(witness), #[cfg(not(target_arch = "wasm32"))] diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 199177e16..f1873c2b7 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -1,6 +1,6 @@ use { anyhow::{ensure, Result}, - ark_ff::UniformRand, + ark_ff::{AdditiveGroup, UniformRand}, ark_std::{One, Zero}, provekit_common::{ prefix_covector::{ @@ -8,12 +8,13 @@ use { compute_public_eval, expand_powers, make_challenge_weight, make_public_weight, OffsetCovector, }, + spark::{Point, R1CSSparkQuery}, utils::{ pad_to_power_of_two, sumcheck::{ calculate_evaluations_over_boolean_hypercube_for_eq, calculate_witness_bounds, - eval_cubic_poly, multiply_transposed_by_eq_alpha, sumcheck_fold_map_reduce, - transpose_r1cs_matrices, + eval_cubic_poly, eval_quadratic_poly, multiply_transposed_by_eq_alpha, + sumcheck_fold_map_reduce, transpose_r1cs_matrices, }, HALF, }, @@ -24,8 +25,9 @@ use { tracing::instrument, whir::{ algebra::{dot, linear_form::LinearForm}, - protocols::whir_zk::Witness as WhirZkWitness, + protocols::{whir::FinalClaim, whir_zk::Witness as WhirZkWitness}, transcript::{ProverState, VerifierMessage}, + utils::zip_strict, }, }; #[cfg(not(target_arch = "wasm32"))] @@ -62,7 +64,7 @@ pub trait WhirR1CSProver { commitments: Vec, full_witness: Vec, public_inputs: &PublicInputs, - ) -> Result; + ) -> Result<(WhirR1CSProof, Vec)>; #[cfg(not(target_arch = "wasm32"))] fn prove_mavros( @@ -74,7 +76,7 @@ pub trait WhirR1CSProver { witness_layout: WitnessLayout, constraints_layout: ConstraintsLayout, ad_binary: &[u64], - ) -> Result; + ) -> Result<(WhirR1CSProof, Vec)>; } impl WhirR1CSProver for WhirR1CSScheme { @@ -147,7 +149,7 @@ impl WhirR1CSProver for WhirR1CSScheme { commitments: Vec, full_witness: Vec, public_inputs: &PublicInputs, - ) -> Result { + ) -> Result<(WhirR1CSProof, Vec)> { ensure!(!commitments.is_empty(), "Need at least one commitment"); let (a, b, c) = calculate_witness_bounds(&r1cs, &full_witness); @@ -177,6 
+179,7 @@ impl WhirR1CSProver for WhirR1CSScheme { prove_from_alphas( self, merlin, + alpha, alphas, blinding_eval, blinding_offset, @@ -197,7 +200,7 @@ impl WhirR1CSProver for WhirR1CSScheme { witness_layout: WitnessLayout, constraints_layout: ConstraintsLayout, ad_binary: &[u64], - ) -> Result { + ) -> Result<(WhirR1CSProof, Vec)> { ensure!(!commitments.is_empty(), "Need at least one commitment"); let blinding = commitments[0] @@ -233,6 +236,7 @@ impl WhirR1CSProver for WhirR1CSScheme { prove_from_alphas( self, merlin, + alpha, alphas, blinding_eval, blinding_offset, @@ -247,13 +251,14 @@ impl WhirR1CSProver for WhirR1CSScheme { fn prove_from_alphas( scheme: &WhirR1CSScheme, mut merlin: ProverState, + alpha: Vec, alphas: [Vec; 3], blinding_eval: FieldElement, blinding_offset: usize, blinding_weights: Vec, commitments: Vec, public_inputs: &PublicInputs, -) -> Result { +) -> Result<(WhirR1CSProof, Vec)> { let public_inputs_hash = public_inputs.hash(scheme.hash_config); let public_inputs_len = public_inputs.len(); @@ -263,7 +268,7 @@ fn prove_from_alphas( let domain_size = 1usize << scheme.m; - if is_single { + let spark_queries: Vec = if is_single { // Single commitment path let commitment = commitments .into_iter() @@ -290,19 +295,52 @@ fn prove_from_alphas( let blinding_covector = OffsetCovector::new(blinding_weights, blinding_offset, domain_size); + let alpha_weight_data: Vec<_> = weights + .iter() + .map(|w| (w.vector().to_vec(), w.size())) + .collect(); + let mut boxed_weights: Vec>> = weights .into_iter() .map(|w| Box::new(w) as Box>) .collect(); boxed_weights.push(Box::new(blinding_covector)); - let _ = scheme.whir_witness.prove( + let public_offset = if public_inputs.is_empty() { 0 } else { 1 }; + + let final_claim = scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(commitment.polynomial.as_slice())], commitment.witness, boxed_weights, Cow::Borrowed(&evaluations), ); + + let rlc = zip_strict( + final_claim.rlc_coefficients[public_offset..(public_offset + 3)].iter(), + alpha_weight_data[public_offset..(public_offset + 3)].iter(), + ) + .map(|(&c, (vec, ds))| { + let w = PrefixCovector::new(vec.clone(), *ds); + c * w.mle_evaluate(&final_claim.evaluation_point) + }) + .sum::(); + + let claimed_batched_spark_value = if !public_inputs.is_empty() { + rlc / final_claim.rlc_coefficients[1] + } else { + rlc + }; + + let query = R1CSSparkQuery { + point_to_evaluate: Point { + row: alpha, + col: final_claim.evaluation_point, + }, + matrix_batching_randomness: final_claim.rlc_coefficients[1], + claimed_value: claimed_batched_spark_value, + }; + vec![query] } else { // Dual commitment path let mut commitments = commitments.into_iter(); @@ -314,6 +352,7 @@ fn prove_from_alphas( .expect("dual-commitment path requires second commitment"); let (alphas_1, alphas_2): (Vec<_>, Vec<_>) = alphas + .clone() .into_iter() .map(|mut v| { let v2 = v.split_off(scheme.w1_size); @@ -355,13 +394,16 @@ fn prove_from_alphas( None }; + let has_public = public_1.is_some(); + let public_offset_1 = if has_public { 1 } else { 0 }; + let WhirR1CSCommitment { witness: w1, polynomial: p1, .. 
} = c1; - { - let mut weights = build_prefix_covectors(scheme.m, alphas_1); + let (final_claim1, rlc1) = { + let mut weights = build_prefix_covectors(scheme.m, alphas_1.clone()); let mut evaluations: Vec = Vec::new(); if let Some(pe) = public_1 { weights.insert(0, make_public_weight(x, public_inputs_len, scheme.m)); @@ -373,20 +415,43 @@ fn prove_from_alphas( let blinding_covector = OffsetCovector::new(blinding_weights, blinding_offset, domain_size); + let alpha_weight_data_1: Vec<_> = weights[public_offset_1..public_offset_1 + 3] + .iter() + .map(|w| (w.vector().to_vec(), w.size())) + .collect(); + let mut boxed_weights: Vec>> = weights .into_iter() .map(|w| Box::new(w) as Box>) .collect(); boxed_weights.push(Box::new(blinding_covector)); - let _ = scheme.whir_witness.prove( + let final_claim1 = scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(p1.as_slice())], w1, boxed_weights, Cow::Borrowed(&evaluations), ); - } + + let rlc1_sum = zip_strict( + final_claim1.rlc_coefficients[public_offset_1..(public_offset_1 + 3)].iter(), + alpha_weight_data_1.iter(), + ) + .map(|(&c, (vec, ds))| { + let w = PrefixCovector::new(vec.clone(), *ds); + c * w.mle_evaluate(&final_claim1.evaluation_point) + }) + .sum::(); + + let claimed1 = if has_public { + rlc1_sum / final_claim1.rlc_coefficients[1] + } else { + rlc1_sum + }; + + (final_claim1, claimed1) + }; drop(p1); let WhirR1CSCommitment { @@ -394,10 +459,15 @@ fn prove_from_alphas( polynomial: p2, .. } = c2; - { - let weights = build_prefix_covectors(scheme.m, alphas_2); + let (final_claim2, rlc2) = { + let weights = build_prefix_covectors(scheme.m, alphas_2.clone()); let mut evaluations: Vec = evals_2; + let alpha_weight_data_2: Vec<_> = weights[0..3] + .iter() + .map(|w| (w.vector().to_vec(), w.size())) + .collect(); + let mut boxed_weights: Vec>> = weights .into_iter() .map(|w| Box::new(w) as Box>) @@ -409,23 +479,62 @@ fn prove_from_alphas( boxed_weights.push(Box::new(cw)); } - let _ = scheme.whir_witness.prove( + let final_claim2 = scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(p2.as_slice())], w2, boxed_weights, Cow::Borrowed(&evaluations), ); - } - } + + let rlc2_sum = zip_strict( + final_claim2.rlc_coefficients[0..3].iter(), + alpha_weight_data_2.iter(), + ) + .map(|(&c, (vec, ds))| { + let w = PrefixCovector::new(vec.clone(), *ds); + c * w.mle_evaluate(&final_claim2.evaluation_point) + }) + .sum::(); + + (final_claim2, rlc2_sum) + }; + + let mut col1 = final_claim1.evaluation_point.clone(); + col1.insert(0, FieldElement::zero()); + let query1 = R1CSSparkQuery { + point_to_evaluate: Point { + row: alpha.clone(), + col: col1, + }, + matrix_batching_randomness: final_claim1.rlc_coefficients[1], + claimed_value: rlc1, + }; + + let mut col2 = final_claim2.evaluation_point.clone(); + col2.insert(0, FieldElement::one()); + let query2 = R1CSSparkQuery { + point_to_evaluate: Point { + row: alpha, + col: col2, + }, + matrix_batching_randomness: final_claim2.rlc_coefficients[1], + claimed_value: rlc2, + }; + + vec![query1, query2] + }; let proof = merlin.proof(); - Ok(WhirR1CSProof { - narg_string: proof.narg_string, - hints: proof.hints, - #[cfg(debug_assertions)] - pattern: proof.pattern, - }) + Ok(( + WhirR1CSProof { + narg_string: proof.narg_string, + hints: proof.hints, + #[cfg(debug_assertions)] + pattern: proof.pattern, + }, + spark_queries, + )) } pub fn compute_blinding_coefficients_for_round( diff --git a/provekit/spark/Cargo.toml b/provekit/spark/Cargo.toml new file mode 100644 index 000000000..1f6b62f96 --- /dev/null +++ 
b/provekit/spark/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "provekit-spark"
+version = "0.1.0"
+edition.workspace = true
+rust-version.workspace = true
+authors.workspace = true
+license.workspace = true
+homepage.workspace = true
+repository.workspace = true
+
+[dependencies]
+provekit-common.workspace = true
+ark-ff.workspace = true
+ark-std.workspace = true
+anyhow.workspace = true
+serde.workspace = true
+whir.workspace = true
+tracing.workspace = true
+rayon.workspace = true
+
+[lints]
+workspace = true
diff --git a/provekit/spark/SPARK.md b/provekit/spark/SPARK.md
new file mode 100644
index 000000000..da8f17b26
--- /dev/null
+++ b/provekit/spark/SPARK.md
@@ -0,0 +1,90 @@
+# SPARK
+
+References for this implementation:
+- SPARK, introduced in Spartan: https://eprint.iacr.org/2019/550
+- Stronger security analysis of SPARK (the Lasso paper): https://people.cs.georgetown.edu/jthaler/Lasso-paper.pdf
+
+## Proposed prototype workflow
+1. Provekit prepare step
+   - Compiles the circuit and writes the prover/verifier artifacts (`.pkp`, `.pkv`).
+   - With `--spark`, also runs SPARK preprocessing once and writes the SPARK
+     setup transcript (`.spc`).
+
+2. Provekit prove step
+   - Runs the Provekit prover and obtains the Noir proof plus the deferred
+     matrix evaluations (SPARK queries).
+   - Writes each query as `spark_query_<i>.json` to `--spark-queries-dir`.
+
+3. Provekit prove-spark step
+   - Reads queries from `--spark-dir` and produces a SPARK proof per query
+     (`spark_proof_<i>.sp`, written back to the same directory).
+
+4. Provekit and SPARK verify step
+   - Verifies the Provekit and SPARK proofs.
+
+## Design decisions
+
+### Pack $A$, $B$, $C$ into one block matrix $Z$
+This is a result from Marcin (https://gist.github.com/kustosz/14b62de666f721ab855536e575891bd1).
+
+**The trick:**
+
+$$Z = \begin{bmatrix} A & B \\ 0 & C \end{bmatrix}$$
+
+Same total non-zeros, double the dimensions. Then for any $\beta$, $p$, and $q$:
+
+$$A(p,q) + \beta B(p,q) + \beta^2 C(p,q) = (1+\beta)^2 \cdot Z\!\left(\tfrac{\beta}{1+\beta}, p,\ \tfrac{\beta}{1+\beta}, q\right)$$
+
+To see why, write the MLE of $Z$ with one extra selector variable per axis,
+
+$$Z(r, p, c, q) = (1-r)(1-c)\,A(p,q) + (1-r)\,c\,B(p,q) + r\,c\,C(p,q),$$
+
+and substitute $r = c = \tfrac{\beta}{1+\beta}$, so that $1-r = \tfrac{1}{1+\beta}$; scaling by $(1+\beta)^2$ recovers the left-hand side.
+
+One matrix, one commitment, one opening.
+
+### Batching GPA and WHIR proofs
+
+- Combining GPA
+  - The products of hashes for the read sets and write sets of the row-wise and column-wise memory checks are combined into one GPA.
+  - The products of hashes for the init and final vectors are combined into one GPA (separately for the row-wise and col-wise memories). Possible optimization: if the matrix is guaranteed to have as many rows as columns, these could be combined into a single GPA as well.
+
+- WHIR batching
+
+| Batch config | What is batched |
+| --- | --- |
+| `num_terms_2batched` | E-values (`e_rx`, `e_ry`) are committed and opened together; opened once in the sumcheck and once in the rs/ws GPA. |
+| `num_terms_4batched` | Address/timestamp values for the row-wise and col-wise memory checks are committed and opened together. |
+
+### Split witness: two SPARK queries
+The current ZK WHIR doesn't support batching, which would allow the split witness commitment to be handled more directly. The current implementation therefore emits **two SPARK queries** for the dual-commitment path, one per split half.
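+
+For orientation, here is a minimal in-process sketch of workflow steps 2 and 3,
+using the APIs added in this PR (`Prove::prove` now returns the SPARK queries
+alongside the proof, and `provekit-spark` exposes the `SPARKProver` trait). It
+assumes a `SparkProverContext` has already been built for the circuit's matrix
+(construction lives in `setup.rs`/`types.rs`) and that the `witness-generation`
+feature is enabled; the fully qualified paths are best-effort, not verbatim.
+
+```rust
+// Sketch only: drive the Noir prover and the SPARK scheme from library code.
+use provekit_prover::Prove;
+use provekit_spark::{SPARKProver, SPARKProverScheme, SparkProverContext};
+
+fn prove_with_spark(
+    prover: provekit_common::Prover,
+    input_map: noirc_abi::InputMap,
+    scheme: &SPARKProverScheme,
+    ctx: &SparkProverContext,
+) -> anyhow::Result<()> {
+    // Step 2: Noir proof plus the deferred matrix evaluations (SPARK queries).
+    let (noir_proof, queries) = prover.prove(input_map)?;
+
+    // Step 3: one SPARK proof per deferred query.
+    for query in &queries {
+        let _spark_proof = scheme.prove(ctx, query)?;
+    }
+    let _ = noir_proof; // the CLI flow below persists these as .np / .sp files
+    Ok(())
+}
+```
+
+## Full workflow, shown on the `power` example circuit:
+
+```bash
+cargo build --release --bin provekit-cli
+
+cd noir-examples/power
+nargo compile
+
+# 1. Prepare the circuit (compiles and writes prover/verifier artifacts plus
+#    the SPARK setup transcript).
+cargo run --release --bin provekit-cli -- prepare ./target/power.json \
+  --pkp ./benchmark-inputs/power.pkp \
+  --pkv ./benchmark-inputs/power.pkv \
+  --spark \
+  --spc ./benchmark-inputs/power.spc
+
+# 2. Prove (generates the Noir proof and writes the SPARK queries to disk).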
+cargo run --release --bin provekit-cli -- prove ./benchmark-inputs/power.pkp ./Prover.toml \
+  -o ./benchmark-inputs/power-proof.np \
+  --spark-queries-dir ./spark_proofs
+
+# 3. Generate SPARK proofs for the queries written in step 2.
+cargo run --release --bin provekit-cli -- prove-spark ./benchmark-inputs/power.pkp \
+  --spark-dir ./spark_proofs
+
+# 4. Natively verify the Noir proof. Native verification evaluates the matrix
+#    MLEs directly; SPARK proofs are only needed by the recursive verifier.
+cargo run --release --bin provekit-cli -- verify ./benchmark-inputs/power.pkv ./benchmark-inputs/power-proof.np
+
+# 5. Verify a standalone SPARK proof. Needs the per-proof artifacts (.sp, .json)
+#    plus the SPARK setup transcript (.spc) emitted by `prepare --spark`.
+cargo run --release --bin provekit-cli -- verify-spark ./spark_proofs/spark_proof_0.sp ./benchmark-inputs/power.spc ./spark_proofs/spark_query_0.json
+
+# TODO: 6. Recursively verify the Noir proof and the SPARK proofs.
+```
\ No newline at end of file
diff --git a/provekit/spark/src/gpa.rs b/provekit/spark/src/gpa.rs
new file mode 100644
index 000000000..2715c6d66
--- /dev/null
+++ b/provekit/spark/src/gpa.rs
@@ -0,0 +1,433 @@
+use {
+    anyhow::{ensure, Context},
+    provekit_common::{
+        utils::{
+            next_power_of_two,
+            sumcheck::{
+                calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly,
+                sumcheck_fold_map_reduce,
+            },
+            HALF,
+        },
+        FieldElement, TranscriptSponge,
+    },
+    tracing::instrument,
+    whir::transcript::{ProverState, VerifierMessage, VerifierState},
+};
+
+#[instrument(skip_all)]
+pub fn run_gpa2(
+    merlin: &mut ProverState,
+    left: &[FieldElement],
+    right: &[FieldElement],
+) -> anyhow::Result<Vec<FieldElement>> {
+    let mut concatenated = left.to_vec();
+    concatenated.extend_from_slice(right);
+    let mut layers = calculate_binary_multiplication_tree(concatenated)?;
+
+    let mut drain = layers.drain(1..);
+
+    let first_layer = drain.next().context("GPA tree has fewer than 2 layers")?;
+    let (accumulated_randomness, mut sumcheck_claim) = add_line_to_transcript(merlin, first_layer);
+    let mut accumulated_randomness = accumulated_randomness.to_vec();
+
+    for layer in drain {
+        (sumcheck_claim, accumulated_randomness) =
+            run_gpa_sumcheck(merlin, layer, sumcheck_claim, accumulated_randomness)?;
+    }
+
+    Ok(accumulated_randomness)
+}
+
+#[instrument(skip_all)]
+pub fn run_gpa4(
+    merlin: &mut ProverState,
+    leaves: Vec<FieldElement>,
+) -> anyhow::Result<Vec<FieldElement>> {
+    let mut layers = calculate_binary_multiplication_tree(leaves)?;
+
+    let mut drain = layers.drain(2..);
+
+    let coeffs = drain.next().context("GPA tree has fewer than 3 layers")?;
+    let coeffs = [
+        coeffs[0],
+        coeffs[1] - coeffs[0],
+        coeffs[2] - coeffs[0],
+        coeffs[3] - coeffs[2] - coeffs[1] + coeffs[0],
+    ];
+
+    for c in &coeffs {
+        merlin.prover_message(c);
+    }
+
+    let r0: FieldElement = merlin.verifier_message();
+    let r1: FieldElement = merlin.verifier_message();
+    let mut accumulated_randomness = vec![r0, r1];
+
+    let mut sumcheck_claim = coeffs[0] + coeffs[1] * r1 + coeffs[2] * r0 + coeffs[3] * r0 * r1;
+
+    for layer in drain {
+        (sumcheck_claim, accumulated_randomness) =
+            run_gpa_sumcheck(merlin, layer, sumcheck_claim, accumulated_randomness)?;
+    }
+
+    Ok(accumulated_randomness)
+}
+
+fn calculate_binary_multiplication_tree(
+    array_to_prove: Vec<FieldElement>,
+) -> anyhow::Result<Vec<Vec<FieldElement>>> {
+    use rayon::prelude::*;
+
+    ensure!(
+        array_to_prove.len() == (1 << next_power_of_two(array_to_prove.len())),
+        "Input length must be power of two"
+    );
+
+    let
mut layers = vec![]; + let mut current_layer = array_to_prove; + + while current_layer.len() > 1 { + let next_layer: Vec = current_layer + .par_chunks_exact(2) + .map(|pair| pair[0] * pair[1]) + .collect(); + + layers.push(current_layer); + current_layer = next_layer; + } + + layers.push(current_layer); + layers.reverse(); + Ok(layers) +} + +fn add_line_to_transcript( + merlin: &mut ProverState, + arr: Vec, +) -> ([FieldElement; 1], FieldElement) { + let line_poly = [arr[0], arr[1] - arr[0]]; + + for c in line_poly.iter() { + merlin.prover_message(c); + } + + let challenge: FieldElement = merlin.verifier_message(); + + let next_claim = line_poly[0] + line_poly[1] * challenge; + + ([challenge], next_claim) +} + +fn run_gpa_sumcheck( + merlin: &mut ProverState, + layer: Vec, + mut sumcheck_claim: FieldElement, + accumulated_randomness: Vec, +) -> anyhow::Result<(FieldElement, Vec)> { + let (mut even_layer, mut odd_layer) = split_even_odd(layer); + + let mut eq_evaluations = calculate_evaluations_over_boolean_hypercube_for_eq( + &accumulated_randomness, + 1 << accumulated_randomness.len(), + ); + let mut challenge; + let mut round_randomness = Vec::::new(); + let mut fold = None; + + loop { + let [eval_at_0, eval_at_neg1, eval_at_inf_over_x3] = sumcheck_fold_map_reduce( + [&mut eq_evaluations, &mut even_layer, &mut odd_layer], + fold, + |[eq, v0, v1]| { + [ + eq.0 * v0.0 * v1.0, + (eq.0 + eq.0 - eq.1) * (v0.0 + v0.0 - v0.1) * (v1.0 + v1.0 - v1.1), + (eq.1 - eq.0) * (v0.1 - v0.0) * (v1.1 - v1.0), + ] + }, + ); + + if fold.is_some() { + eq_evaluations.truncate(eq_evaluations.len() / 2); + even_layer.truncate(even_layer.len() / 2); + odd_layer.truncate(odd_layer.len() / 2); + } + + let poly_coeffs = reconstruct_cubic_from_evaluations( + sumcheck_claim, + eval_at_0, + eval_at_neg1, + eval_at_inf_over_x3, + ); + + ensure!( + sumcheck_claim + == poly_coeffs[0] + + poly_coeffs[0] + + poly_coeffs[1] + + poly_coeffs[2] + + poly_coeffs[3], + "Sumcheck binding check failed" + ); + + for coeff in &poly_coeffs { + merlin.prover_message(coeff); + } + challenge = merlin.verifier_message(); + + fold = Some(challenge); + sumcheck_claim = eval_cubic_poly(poly_coeffs, challenge); + round_randomness.push(challenge); + + if eq_evaluations.len() <= 2 { + break; + } + } + + let final_v0 = even_layer[0] + (even_layer[1] - even_layer[0]) * challenge; + let final_v1 = odd_layer[0] + (odd_layer[1] - odd_layer[0]) * challenge; + let final_v2 = eq_evaluations[0] + (eq_evaluations[1] - eq_evaluations[0]) * challenge; + + ensure!( + sumcheck_claim == final_v0 * final_v1 * final_v2, + "GPA sumcheck claim mismatch" + ); + + let line_coeffs = [final_v0, final_v1 - final_v0]; + + for c in &line_coeffs { + merlin.prover_message(c); + } + + let line_challenge: FieldElement = merlin.verifier_message(); + let next_claim = line_coeffs[0] + line_coeffs[1] * line_challenge; + round_randomness.push(line_challenge); + + Ok((next_claim, round_randomness)) +} + +fn reconstruct_cubic_from_evaluations( + binding_value: FieldElement, + at_0: FieldElement, + at_neg1: FieldElement, + at_inf_over_x3: FieldElement, +) -> [FieldElement; 4] { + let mut coeffs = [FieldElement::from(0u64); 4]; + + coeffs[0] = at_0; + coeffs[2] = HALF * (binding_value + at_neg1 - at_0 - at_0 - at_0); + coeffs[3] = at_inf_over_x3; + coeffs[1] = binding_value - coeffs[0] - coeffs[0] - coeffs[3] - coeffs[2]; + + coeffs +} + +fn split_even_odd(input: Vec) -> (Vec, Vec) { + input + .chunks_exact(2) + .map(|chunk| (chunk[0], chunk[1])) + .unzip() +} + +pub struct 
GPASumcheckResult { + pub claimed_values: Vec, + pub last_sumcheck_value: FieldElement, + pub randomness: Vec, +} + +#[instrument(skip_all)] +pub fn gpa_sumcheck_verifier2( + arthur: &mut VerifierState<'_, TranscriptSponge>, + height_of_binary_tree: usize, +) -> anyhow::Result { + let mut prev_randomness; + let mut current_randomness = Vec::::new(); + + let claimed_0: FieldElement = arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?; + let claimed_1: FieldElement = arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?; + let claimed_values = [claimed_0, claimed_1]; + + let line_challenge: FieldElement = arthur.verifier_message(); + + let mut sumcheck_value = eval_line(&claimed_values, &line_challenge); + current_randomness.push(line_challenge); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + + for layer_idx in 1..height_of_binary_tree - 1 { + for _ in 0..layer_idx { + let cubic_coeffs: [FieldElement; 4] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + let sumcheck_challenge: FieldElement = arthur.verifier_message(); + + ensure!( + eval_cubic_poly(cubic_coeffs, FieldElement::from(0u64)) + + eval_cubic_poly(cubic_coeffs, FieldElement::from(1u64)) + == sumcheck_value, + "Sumcheck verification failed at layer {layer_idx}" + ); + + current_randomness.push(sumcheck_challenge); + sumcheck_value = eval_cubic_poly(cubic_coeffs, sumcheck_challenge); + } + + let line_coeffs: [FieldElement; 2] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + let line_challenge: FieldElement = arthur.verifier_message(); + + let expected_line_value = calculate_eq(&prev_randomness, ¤t_randomness) + * eval_line(&line_coeffs, &FieldElement::from(0u64)) + * eval_line(&line_coeffs, &FieldElement::from(1u64)); + ensure!( + expected_line_value == sumcheck_value, + "Line evaluation mismatch" + ); + + current_randomness.push(line_challenge); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + sumcheck_value = eval_line(&line_coeffs, &line_challenge); + } + + let claimed_values = [claimed_values[0], claimed_values[0] + claimed_values[1]].to_vec(); + + Ok(GPASumcheckResult { + claimed_values, + last_sumcheck_value: sumcheck_value, + randomness: prev_randomness, + }) +} + +#[instrument(skip_all)] +pub fn gpa_sumcheck_verifier4( + arthur: &mut VerifierState<'_, TranscriptSponge>, + height_of_binary_tree: usize, +) -> anyhow::Result { + let claimed_values: [FieldElement; 4] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + let r0: FieldElement = arthur.verifier_message(); + let r1: FieldElement = arthur.verifier_message(); + let mut prev_randomness = vec![r0, r1]; + let mut current_randomness = Vec::::new(); + + let mut sumcheck_value = claimed_values[0] + + claimed_values[1] * prev_randomness[1] + + claimed_values[2] * prev_randomness[0] + + claimed_values[3] * prev_randomness[0] * prev_randomness[1]; + + for layer_idx in 2..height_of_binary_tree - 1 { + for _ in 0..layer_idx { 
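+            // One round per bound variable: read the prover's cubic, check
+            // p(0) + p(1) against the running claim, then fold on the
+            // verifier's challenge.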
+ let cubic_coeffs: [FieldElement; 4] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + let sumcheck_challenge: FieldElement = arthur.verifier_message(); + + ensure!( + eval_cubic_poly(cubic_coeffs, FieldElement::from(0u64)) + + eval_cubic_poly(cubic_coeffs, FieldElement::from(1u64)) + == sumcheck_value, + "Sumcheck verification failed at layer {layer_idx}" + ); + + current_randomness.push(sumcheck_challenge); + sumcheck_value = eval_cubic_poly(cubic_coeffs, sumcheck_challenge); + } + + let line_coeffs: [FieldElement; 2] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + let line_challenge: FieldElement = arthur.verifier_message(); + + let expected_line_value = calculate_eq(&prev_randomness, ¤t_randomness) + * eval_line(&line_coeffs, &FieldElement::from(0u64)) + * eval_line(&line_coeffs, &FieldElement::from(1u64)); + ensure!( + expected_line_value == sumcheck_value, + "Line evaluation mismatch" + ); + + current_randomness.push(line_challenge); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + sumcheck_value = eval_line(&line_coeffs, &line_challenge); + } + + let claimed_values = [ + claimed_values[0], + claimed_values[0] + claimed_values[1], + claimed_values[0] + claimed_values[2], + claimed_values[0] + claimed_values[1] + claimed_values[2] + claimed_values[3], + ] + .to_vec(); + + Ok(GPASumcheckResult { + claimed_values, + last_sumcheck_value: sumcheck_value, + randomness: prev_randomness, + }) +} + +pub fn eval_line(poly: &[FieldElement], point: &FieldElement) -> FieldElement { + poly[0] + *point * poly[1] +} + +pub fn calculate_adr(randomness: &[FieldElement]) -> FieldElement { + randomness + .iter() + .rev() + .enumerate() + .fold(FieldElement::from(0u64), |acc, (i, &r)| { + acc + r * FieldElement::from(1u64 << i) + }) +} diff --git a/provekit/spark/src/lib.rs b/provekit/spark/src/lib.rs new file mode 100644 index 000000000..3f1af34ba --- /dev/null +++ b/provekit/spark/src/lib.rs @@ -0,0 +1,19 @@ +pub mod gpa; +pub mod memory; +pub mod prover; +pub mod setup; +pub mod sumcheck; +pub mod types; +pub mod utils; +pub mod verifier; + +pub use { + prover::{SPARKProver, SPARKScheme as SPARKProverScheme}, + setup::preprocess_spark, + types::{ + MatrixDimensions, SPARKProof, SPARKSetup, SPARKWHIRConfigs, SparkProverContext, + SparkWitnesses, + }, + utils::calculate_memory, + verifier::{SPARKScheme as SPARKVerifierScheme, SPARKVerifier}, +}; diff --git a/provekit/spark/src/memory.rs b/provekit/spark/src/memory.rs new file mode 100644 index 000000000..1a0bfb34e --- /dev/null +++ b/provekit/spark/src/memory.rs @@ -0,0 +1,159 @@ +use { + crate::{ + gpa::{calculate_adr, gpa_sumcheck_verifier2, run_gpa2}, + types::{Challenges, WhirWitness}, + }, + anyhow::{ensure, Result}, + ark_std::One, + provekit_common::{FieldElement, TranscriptSponge, WhirConfig}, + rayon::prelude::*, + std::borrow::Cow, + tracing::instrument, + whir::{ + algebra::{linear_form::MultilinearExtension, multilinear_extend}, + protocols::irs_commit::Commitment, + transcript::{ProverState, VerifierState}, + }, +}; + +pub struct AxisConfig<'a> { + pub eq_memory: &'a [FieldElement], + pub final_timestamp: &'a [FieldElement], + pub whir_config: &'a WhirConfig, +} + 
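+/// Proves the init/final grand products for one memory axis (rows or
+/// columns) of the offline memory check.
+///
+/// Each memory cell is hashed as `addr * gamma^2 + value * gamma + ts - tau`:
+/// the init vector takes `ts = 0` and the final vector takes the final
+/// timestamps. Paired with the read-set/write-set products from
+/// `prove_combined_rs_ws_product`, this lets the verifier check the multiset
+/// identity `init * WS == final * RS` (see `verify_axis`).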
+#[instrument(skip_all)] +pub fn prove_axis_init_final_product( + merlin: &mut ProverState, + config: AxisConfig<'_>, + final_ts_witness: &WhirWitness, + challenges: &Challenges, +) -> Result<()> { + let gamma = &challenges.gamma; + let tau = &challenges.tau; + let gamma_sq = *gamma * *gamma; + + let (init_vec, final_vec) = tracing::info_span!("build_init_final_vecs").in_scope(|| { + rayon::join( + || { + config + .eq_memory + .par_iter() + .enumerate() + .map(|(i, &v)| { + let a = FieldElement::from(i as u64); + a * gamma_sq + v * gamma - tau + }) + .collect::>() + }, + || { + config + .eq_memory + .par_iter() + .zip(config.final_timestamp.par_iter()) + .enumerate() + .map(|(i, (&v, &t))| { + let a = FieldElement::from(i as u64); + a * gamma_sq + v * gamma + t - tau + }) + .collect::>() + }, + ) + }); + + let gpa_randomness = run_gpa2(merlin, &init_vec, &final_vec)?; + let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); + + let final_ts_eval = multilinear_extend(config.final_timestamp, evaluation_randomness); + merlin.prover_hint_ark(&final_ts_eval); + + produce_whir_proof( + merlin, + evaluation_randomness, + &[config.final_timestamp], + config.whir_config, + final_ts_witness, + )?; + + Ok(()) +} + +#[instrument(skip_all)] +pub fn verify_axis( + arthur: &mut VerifierState<'_, TranscriptSponge>, + num_axis_items: usize, + whir_config: &WhirConfig, + finalts_commitment: Commitment, + init_mem_fn: impl Fn(&[FieldElement]) -> FieldElement, + tau: &FieldElement, + gamma: &FieldElement, + claimed_rs: &FieldElement, + claimed_ws: &FieldElement, +) -> Result<()> { + let gpa_result = gpa_sumcheck_verifier2( + arthur, + provekit_common::utils::next_power_of_two(num_axis_items) + 2, + )?; + + let claimed_init = gpa_result.claimed_values[0]; + let claimed_final = gpa_result.claimed_values[1]; + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + let gamma_sq = *gamma * *gamma; + + let init_adr = calculate_adr(evaluation_randomness); + let init_mem = init_mem_fn(evaluation_randomness); + let init_opening = init_adr * gamma_sq + init_mem * gamma - tau; + + let final_cntr: FieldElement = arthur + .prover_hint_ark() + .map_err(|e| anyhow::anyhow!("{e}"))?; + + let eval_weight = MultilinearExtension::new(evaluation_randomness.to_vec()); + let finalts_claim = whir_config + .verify(arthur, &[&finalts_commitment], &[final_cntr]) + .map_err(|e| anyhow::anyhow!("WHIR verify failed: {e}"))?; + finalts_claim + .verify([&eval_weight as &dyn whir::algebra::linear_form::LinearForm]) + .map_err(|e| anyhow::anyhow!("FinalClaim check failed for final timestamps: {e}"))?; + + let final_opening = init_adr * gamma_sq + init_mem * gamma + final_cntr - tau; + + let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) + + final_opening * last_randomness[0]; + + ensure!(evaluated_value == gpa_result.last_sumcheck_value); + + ensure!(claimed_init * claimed_ws == claimed_final * claimed_rs); + + Ok(()) +} + +#[instrument(skip_all)] +pub fn produce_whir_proof( + merlin: &mut ProverState, + evaluation_point: &[FieldElement], + vectors: &[&[FieldElement]], + config: &WhirConfig, + witness: &WhirWitness, +) -> Result<()> { + let lf = MultilinearExtension::new(evaluation_point.to_vec()); + + let evaluations: Vec = vectors + .iter() + .map(|v| multilinear_extend(v, evaluation_point)) + .collect(); + + _ = config.prove( + merlin, + vectors.iter().map(|v| Cow::Borrowed(*v)).collect(), + vec![Cow::Owned(witness.clone())], + vec![Box::new(lf) + as 
Box< + dyn whir::algebra::linear_form::LinearForm, + >], + Cow::Borrowed(&evaluations), + ); + + Ok(()) +} diff --git a/provekit/spark/src/prover.rs b/provekit/spark/src/prover.rs new file mode 100644 index 000000000..d132761b3 --- /dev/null +++ b/provekit/spark/src/prover.rs @@ -0,0 +1,430 @@ +use { + crate::{ + gpa::run_gpa4, + memory::{produce_whir_proof, prove_axis_init_final_product, AxisConfig}, + sumcheck::run_spark_sumcheck, + types::{ + Challenges, EValuesForMatrix, MatrixDimensions, Memory, SPARKProof, SPARKWHIRConfigs, + SparkMatrix, SparkProverContext, WhirWitness, + }, + utils::calculate_memory, + }, + anyhow::{ensure, Result}, + ark_ff::{Field, Zero}, + provekit_common::{ + spark::R1CSSparkQuery, utils::next_power_of_two, FieldElement, TranscriptSponge, + WhirConfig, WhirR1CSProof, + }, + rayon::{join, prelude::*}, + tracing::instrument, + whir::{ + algebra::multilinear_extend, + parameters::ProtocolParameters, + transcript::{DomainSeparator, ProverState, VerifierMessage}, + }, +}; + +pub trait SPARKProver { + fn prove( + &self, + spark_data: &SparkProverContext, + request: &R1CSSparkQuery, + ) -> Result; +} + +pub struct SPARKScheme { + pub whir_configs: SPARKWHIRConfigs, + pub matrix_dimensions: MatrixDimensions, +} + +pub fn new_whir_config_for_size(log_size: usize, batch_size: usize) -> WhirConfig { + let nv = log_size.max(4); + + let whir_params = ProtocolParameters { + unique_decoding: false, + initial_folding_factor: 3, + security_level: 128, + pow_bits: 10, + folding_factor: 3, + starting_log_inv_rate: 2, + batch_size, + hash_id: whir::hash::SHA2, + }; + + WhirConfig::new(1 << nv, &whir_params) +} + +impl SPARKScheme { + pub fn new_for_r1cs(r1cs: &provekit_common::R1CS) -> Self { + let num_rows = 2 * r1cs.num_constraints(); + let num_cols = 2 * r1cs.num_witnesses(); + let nonzero_terms = + r1cs.a().iter().count() + r1cs.b().iter().count() + r1cs.c().iter().count(); + + Self::new(num_rows, num_cols, nonzero_terms) + } + + pub fn new(num_rows: usize, num_cols: usize, nonzero_terms: usize) -> Self { + let padded_num_entries = 1 << next_power_of_two(nonzero_terms); + + let row_config = new_whir_config_for_size(next_power_of_two(num_rows), 1); + let col_config = new_whir_config_for_size(next_power_of_two(num_cols), 1); + let num_terms_1batched_config = + new_whir_config_for_size(next_power_of_two(padded_num_entries), 1); + let num_terms_2batched_config = + new_whir_config_for_size(next_power_of_two(padded_num_entries), 2); + let num_terms_4batched_config = + new_whir_config_for_size(next_power_of_two(padded_num_entries), 4); + + Self { + whir_configs: SPARKWHIRConfigs { + row: row_config, + col: col_config, + num_terms_1batched: num_terms_1batched_config, + num_terms_2batched: num_terms_2batched_config, + num_terms_4batched: num_terms_4batched_config, + }, + matrix_dimensions: MatrixDimensions { + num_rows, + num_cols, + nonzero_terms, + }, + } + } +} + +impl SPARKProver for SPARKScheme { + #[instrument(skip_all)] + fn prove( + &self, + spark_data: &SparkProverContext, + request: &R1CSSparkQuery, + ) -> Result { + ensure!( + !(FieldElement::ONE + request.matrix_batching_randomness).is_zero(), + "matrix_batching_randomness must not equal -1 (would zero the SPARK denominator)" + ); + + let padded_num_entries = spark_data.matrix.coo.val.len(); + + let mut merlin = ProverState::new( + &DomainSeparator::protocol(&self.whir_configs) + .session(&spark_data.setup.transcript.narg_string) + .instance(&request.hash_bytes()), + TranscriptSponge::default(), + ); + + let (memory, 
e_values) = compute_spark_data(request, spark_data, padded_num_entries); + + let claimed_value = (request.claimed_value + / (FieldElement::ONE + request.matrix_batching_randomness)) + / (FieldElement::ONE + request.matrix_batching_randomness); + + prove_spark( + &mut merlin, + spark_data, + &e_values, + claimed_value, + &memory, + &self.whir_configs, + )?; + + let proof = merlin.proof(); + Ok(SPARKProof(WhirR1CSProof { + narg_string: proof.narg_string, + hints: proof.hints, + #[cfg(debug_assertions)] + pattern: proof.pattern, + })) + } +} + +#[instrument(skip_all)] +fn compute_spark_data( + request: &R1CSSparkQuery, + spark_data: &SparkProverContext, + padded_num_entries: usize, +) -> (Memory, EValuesForMatrix) { + let memory = compute_memory(request); + let e_values = compute_e_values(spark_data, &memory, padded_num_entries); + (memory, e_values) +} + +#[instrument(skip_all)] +fn compute_memory(request: &R1CSSparkQuery) -> Memory { + calculate_memory( + request.matrix_batching_randomness + / (FieldElement::ONE + request.matrix_batching_randomness), + &request.point_to_evaluate.row, + &request.point_to_evaluate.col, + ) +} + +#[instrument(skip_all)] +fn compute_e_values( + spark_data: &SparkProverContext, + memory: &Memory, + padded_num_entries: usize, +) -> EValuesForMatrix { + let (e_rx, e_ry) = rayon::join( + || { + spark_data.matrix.coo.row[..padded_num_entries] + .par_iter() + .map(|&r| memory.eq_rx[r]) + .collect() + }, + || { + spark_data.matrix.coo.col[..padded_num_entries] + .par_iter() + .map(|&c| memory.eq_ry[c]) + .collect() + }, + ); + EValuesForMatrix { e_rx, e_ry } +} + +#[instrument(skip_all)] +fn prove_spark( + merlin: &mut ProverState, + data: &SparkProverContext, + e_values: &EValuesForMatrix, + claimed_value: FieldElement, + memory: &Memory, + whir_configs: &SPARKWHIRConfigs, +) -> Result<()> { + let e_values_witness = commit_e_values(merlin, whir_configs, e_values); + + sumcheck_and_its_proofs( + merlin, + &data.matrix, + e_values, + claimed_value, + &e_values_witness, + &data.witnesses.vals_witness, + whir_configs, + )?; + + memory_checking( + merlin, + data, + e_values, + &e_values_witness, + memory, + whir_configs, + )?; + + Ok(()) +} + +#[instrument(skip_all)] +fn memory_checking( + merlin: &mut ProverState, + data: &SparkProverContext, + e_values: &EValuesForMatrix, + e_values_witness: &WhirWitness, + memory: &Memory, + whir_configs: &SPARKWHIRConfigs, +) -> Result<()> { + let tau: FieldElement = merlin.verifier_message(); + let gamma: FieldElement = merlin.verifier_message(); + let challenges = Challenges { tau, gamma }; + + prove_combined_rs_ws_product( + merlin, + &data.matrix, + e_values, + e_values_witness, + &data.witnesses.rs_ws_witness, + whir_configs, + &challenges, + )?; + + prove_axis_init_final_product( + merlin, + AxisConfig { + eq_memory: &memory.eq_rx, + final_timestamp: &data.matrix.timestamps.final_row, + whir_config: &whir_configs.row, + }, + &data.witnesses.final_row_ts_witness, + &challenges, + )?; + + prove_axis_init_final_product( + merlin, + AxisConfig { + eq_memory: &memory.eq_ry, + final_timestamp: &data.matrix.timestamps.final_col, + whir_config: &whir_configs.col, + }, + &data.witnesses.final_col_ts_witness, + &challenges, + )?; + + Ok(()) +} + +#[instrument(skip_all)] +fn sumcheck_and_its_proofs( + merlin: &mut ProverState, + matrix: &SparkMatrix, + e_values: &EValuesForMatrix, + claimed_value: FieldElement, + e_values_witness: &WhirWitness, + vals_witness: &WhirWitness, + whir_configs: &SPARKWHIRConfigs, +) -> Result<()> { + let mles: 
[&[FieldElement]; 3] = [&matrix.coo.val, &e_values.e_rx, &e_values.e_ry]; + let (sumcheck_final_folds, folding_randomness) = + run_spark_sumcheck(merlin, mles, claimed_value)?; + + merlin.prover_hint_ark(&[ + sumcheck_final_folds[0], + sumcheck_final_folds[1], + sumcheck_final_folds[2], + ]); + + produce_whir_proof( + merlin, + &folding_randomness, + &[&e_values.e_rx, &e_values.e_ry], + &whir_configs.num_terms_2batched, + e_values_witness, + )?; + + produce_whir_proof( + merlin, + &folding_randomness, + &[&matrix.coo.val], + &whir_configs.num_terms_1batched, + vals_witness, + )?; + + Ok(()) +} + +#[instrument(skip_all)] +fn prove_combined_rs_ws_product( + merlin: &mut ProverState, + matrix: &SparkMatrix, + e_values: &EValuesForMatrix, + e_values_witness: &WhirWitness, + rs_ws_witness: &WhirWitness, + whir_configs: &SPARKWHIRConfigs, + challenges: &Challenges, +) -> Result<()> { + let gamma_sq = challenges.gamma * challenges.gamma; + let one = FieldElement::from(1u64); + + let row_field = &matrix.coo.row_field; + let col_field = &matrix.coo.col_field; + let n = row_field.len(); + let m = col_field.len(); + + let (row_pairs, col_pairs) = tracing::info_span!("build_rs_ws_pairs").in_scope(|| { + join( + || { + (0..n) + .into_par_iter() + .map(|i| { + let a = row_field[i]; + let v = e_values.e_rx[i]; + let t = matrix.timestamps.read_row[i]; + let base = a * gamma_sq + v * challenges.gamma + t - challenges.tau; + (base, base + one) + }) + .collect::>() + }, + || { + (0..m) + .into_par_iter() + .map(|i| { + let a = col_field[i]; + let v = e_values.e_ry[i]; + let t = matrix.timestamps.read_col[i]; + let base = a * gamma_sq + v * challenges.gamma + t - challenges.tau; + (base, base + one) + }) + .collect::>() + }, + ) + }); + let (row_rs_vec, row_ws_vec): (Vec<_>, Vec<_>) = row_pairs.into_iter().unzip(); + let (col_rs_vec, col_ws_vec): (Vec<_>, Vec<_>) = col_pairs.into_iter().unzip(); + + let mut gpa_leaves_flat = Vec::with_capacity(4 * row_rs_vec.len()); + let gpa_leaves = [row_rs_vec, row_ws_vec, col_rs_vec, col_ws_vec]; + gpa_leaves_flat.extend(gpa_leaves.into_iter().flatten()); + let gpa_randomness = run_gpa4(merlin, gpa_leaves_flat)?; + + let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(2); + + let ((row_address_eval, row_timestamp_eval), (col_address_eval, col_timestamp_eval)) = + tracing::info_span!("multilinear_extend_rs_ws").in_scope(|| { + join( + || { + join( + || multilinear_extend(row_field, evaluation_randomness), + || multilinear_extend(&matrix.timestamps.read_row, evaluation_randomness), + ) + }, + || { + join( + || multilinear_extend(col_field, evaluation_randomness), + || multilinear_extend(&matrix.timestamps.read_col, evaluation_randomness), + ) + }, + ) + }); + + merlin.prover_hint_ark(&row_address_eval); + merlin.prover_hint_ark(&row_timestamp_eval); + merlin.prover_hint_ark(&col_address_eval); + merlin.prover_hint_ark(&col_timestamp_eval); + + let rs_ws_vecs: [&[FieldElement]; 4] = [ + &matrix.coo.row_field, + &matrix.timestamps.read_row, + &matrix.coo.col_field, + &matrix.timestamps.read_col, + ]; + + produce_whir_proof( + merlin, + evaluation_randomness, + &rs_ws_vecs, + &whir_configs.num_terms_4batched, + rs_ws_witness, + )?; + + let (row_value_eval, col_value_eval) = tracing::info_span!("multilinear_extend_e_values") + .in_scope(|| { + join( + || multilinear_extend(&e_values.e_rx, evaluation_randomness), + || multilinear_extend(&e_values.e_ry, evaluation_randomness), + ) + }); + merlin.prover_hint_ark(&row_value_eval); + 
merlin.prover_hint_ark(&col_value_eval); + + produce_whir_proof( + merlin, + evaluation_randomness, + &[&e_values.e_rx, &e_values.e_ry], + &whir_configs.num_terms_2batched, + e_values_witness, + )?; + + Ok(()) +} + +#[instrument(skip_all)] +fn commit_e_values( + merlin: &mut ProverState, + whir_configs: &SPARKWHIRConfigs, + e_values: &EValuesForMatrix, +) -> WhirWitness { + whir_configs + .num_terms_2batched + .commit(merlin, &[&e_values.e_rx, &e_values.e_ry]) +} diff --git a/provekit/spark/src/setup.rs b/provekit/spark/src/setup.rs new file mode 100644 index 000000000..6ea49744d --- /dev/null +++ b/provekit/spark/src/setup.rs @@ -0,0 +1,113 @@ +use { + crate::{ + prover::SPARKScheme as SPARKProverScheme, + types::{SPARKSetup, SparkMatrix, SparkWitnesses}, + }, + anyhow::Result, + provekit_common::{FieldElement, TranscriptSponge, WhirR1CSProof}, + tracing::instrument, + whir::{ + protocols::irs_commit::Commitment, + transcript::{codecs::Empty, DomainSeparator, Proof, ProverState, VerifierState}, + }, +}; + +pub(crate) struct PrecomputedCommitments { + pub val: Commitment, + pub rsws: Commitment, + pub a_row_finalts: Commitment, + pub a_col_finalts: Commitment, +} + +#[instrument(skip_all)] +pub fn preprocess_spark(matrix: &SparkMatrix) -> (SPARKSetup, SparkWitnesses) { + let num_rows = matrix.timestamps.final_row.len(); + let num_cols = matrix.timestamps.final_col.len(); + let nonzero_terms = matrix.coo.val.len(); + let scheme = SPARKProverScheme::new(num_rows, num_cols, nonzero_terms); + + let ds = DomainSeparator::protocol(&scheme.whir_configs).instance(&Empty); + let mut merlin = ProverState::new(&ds, TranscriptSponge::default()); + + let vals_witness = scheme + .whir_configs + .num_terms_1batched + .commit(&mut merlin, &[&matrix.coo.val]); + let rs_ws_witness = scheme + .whir_configs + .num_terms_4batched + .commit(&mut merlin, &[ + &matrix.coo.row_field, + &matrix.timestamps.read_row, + &matrix.coo.col_field, + &matrix.timestamps.read_col, + ]); + let final_row_ts_witness = scheme + .whir_configs + .row + .commit(&mut merlin, &[&matrix.timestamps.final_row]); + let final_col_ts_witness = scheme + .whir_configs + .col + .commit(&mut merlin, &[&matrix.timestamps.final_col]); + + let proof = merlin.proof(); + let setup = SPARKSetup { + whir_params: scheme.whir_configs, + matrix_dimensions: scheme.matrix_dimensions, + transcript: WhirR1CSProof { + narg_string: proof.narg_string, + hints: proof.hints, + #[cfg(debug_assertions)] + pattern: proof.pattern, + }, + }; + let witnesses = SparkWitnesses { + vals_witness, + rs_ws_witness, + final_row_ts_witness, + final_col_ts_witness, + }; + (setup, witnesses) +} + +impl SPARKSetup { + pub(crate) fn extract_commitments(&self) -> Result { + let setup_ds = DomainSeparator::protocol(&self.whir_params).instance(&Empty); + let setup_proof = Proof { + narg_string: self.transcript.narg_string.clone(), + hints: self.transcript.hints.clone(), + #[cfg(debug_assertions)] + pattern: self.transcript.pattern.clone(), + }; + let mut side = VerifierState::new(&setup_ds, &setup_proof, TranscriptSponge::default()); + + let val = self + .whir_params + .num_terms_1batched + .receive_commitment(&mut side) + .map_err(|e| anyhow::anyhow!("Failed to reconstruct val commitment: {e}"))?; + let rsws = self + .whir_params + .num_terms_4batched + .receive_commitment(&mut side) + .map_err(|e| anyhow::anyhow!("Failed to reconstruct rsws commitment: {e}"))?; + let a_row_finalts = self + .whir_params + .row + .receive_commitment(&mut side) + .map_err(|e| anyhow::anyhow!("Failed 
diff --git a/provekit/spark/src/sumcheck.rs b/provekit/spark/src/sumcheck.rs
new file mode 100644
index 000000000..293ea8309
--- /dev/null
+++ b/provekit/spark/src/sumcheck.rs
@@ -0,0 +1,127 @@
+use {
+    anyhow::{ensure, Result},
+    ark_std::{One, Zero},
+    provekit_common::{
+        utils::{
+            sumcheck::{eval_cubic_poly, sumcheck_fold_map_reduce},
+            HALF,
+        },
+        FieldElement, TranscriptSponge,
+    },
+    tracing::instrument,
+    whir::transcript::{ProverState, VerifierMessage, VerifierState},
+};
+
+#[instrument(skip_all)]
+pub fn run_spark_sumcheck(
+    merlin: &mut ProverState,
+    mles: [&[FieldElement]; 3],
+    mut claimed_value: FieldElement,
+) -> Result<([FieldElement; 3], Vec<FieldElement>)> {
+    let mut sumcheck_randomness;
+    let mut sumcheck_randomness_accumulator = Vec::<FieldElement>::new();
+    let mut fold = None;
+
+    let mut m0 = mles[0].to_vec();
+    let mut m1 = mles[1].to_vec();
+    let mut m2 = mles[2].to_vec();
+
+    loop {
+        let [hhat_i_at_0, hhat_i_at_em1, hhat_i_at_inf_over_x_cube] =
+            sumcheck_fold_map_reduce([&mut m0, &mut m1, &mut m2], fold, |[m0, m1, m2]| {
+                [
+                    m0.0 * m1.0 * m2.0,
+                    (m0.0 + m0.0 - m0.1) * (m1.0 + m1.0 - m1.1) * (m2.0 + m2.0 - m2.1),
+                    (m0.1 - m0.0) * (m1.1 - m1.0) * (m2.1 - m2.0),
+                ]
+            });
+
+        if fold.is_some() {
+            m0.truncate(m0.len() / 2);
+            m1.truncate(m1.len() / 2);
+            m2.truncate(m2.len() / 2);
+        }
+
+        let mut hhat_i_coeffs = [FieldElement::zero(); 4];
+
+        hhat_i_coeffs[0] = hhat_i_at_0;
+        hhat_i_coeffs[2] =
+            HALF * (claimed_value + hhat_i_at_em1 - hhat_i_at_0 - hhat_i_at_0 - hhat_i_at_0);
+        hhat_i_coeffs[3] = hhat_i_at_inf_over_x_cube;
+        hhat_i_coeffs[1] = claimed_value
+            - hhat_i_coeffs[0]
+            - hhat_i_coeffs[0]
+            - hhat_i_coeffs[3]
+            - hhat_i_coeffs[2];
+
+        ensure!(
+            claimed_value
+                == hhat_i_coeffs[0]
+                    + hhat_i_coeffs[0]
+                    + hhat_i_coeffs[1]
+                    + hhat_i_coeffs[2]
+                    + hhat_i_coeffs[3],
+            "Sumcheck binding check failed"
+        );
+
+        for coeff in &hhat_i_coeffs {
+            merlin.prover_message(coeff);
+        }
+        sumcheck_randomness = merlin.verifier_message();
+        fold = Some(sumcheck_randomness);
+        claimed_value = eval_cubic_poly(hhat_i_coeffs, sumcheck_randomness);
+        sumcheck_randomness_accumulator.push(sumcheck_randomness);
+        if m0.len() <= 2 {
+            break;
+        }
+    }
+
+    let folded_v0 = m0[0] + (m0[1] - m0[0]) * sumcheck_randomness;
+    let folded_v1 = m1[0] + (m1[1] - m1[0]) * sumcheck_randomness;
+    let folded_v2 = m2[0] + (m2[1] - m2[0]) * sumcheck_randomness;
+
+    Ok((
+        [folded_v0, folded_v1, folded_v2],
+        sumcheck_randomness_accumulator,
+    ))
+}
+
+#[instrument(skip_all)]
+pub fn run_sumcheck_verifier_spark(
+    arthur: &mut VerifierState<'_, TranscriptSponge>,
+    variable_count: usize,
+    initial_sumcheck_val: FieldElement,
+) -> Result<(Vec<FieldElement>, FieldElement)> {
+    let mut saved_val_for_sumcheck_equality_assertion = initial_sumcheck_val;
+
+    let mut alpha = vec![FieldElement::zero(); variable_count];
+
+    for i in 0..variable_count {
+        let hhat_i: [FieldElement; 4] = [
+            arthur
+                .prover_message()
+                .map_err(|e| anyhow::anyhow!("{e}"))?,
+            arthur
+                .prover_message()
+                .map_err(|e| anyhow::anyhow!("{e}"))?,
+            arthur
+                .prover_message()
+                .map_err(|e| anyhow::anyhow!("{e}"))?,
+            arthur
+                .prover_message()
+                .map_err(|e| anyhow::anyhow!("{e}"))?,
+        ];
+        let alpha_i: FieldElement = arthur.verifier_message();
+        alpha[i] = alpha_i;
+
+        let hhat_i_at_zero = eval_cubic_poly(hhat_i, FieldElement::zero());
+        let hhat_i_at_one = eval_cubic_poly(hhat_i, FieldElement::one());
+        ensure!(
+            saved_val_for_sumcheck_equality_assertion == hhat_i_at_zero + hhat_i_at_one,
+            "Sumcheck equality check failed"
+        );
+        saved_val_for_sumcheck_equality_assertion = eval_cubic_poly(hhat_i, alpha_i);
+    }
+
+    Ok((alpha, saved_val_for_sumcheck_equality_assertion))
+}
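A note on why three accumulators suffice in `run_spark_sumcheck`: the running claim supplies the fourth equation for the cubic round polynomial. Writing

\[
\hat h_i(x) = c_0 + c_1 x + c_2 x^2 + c_3 x^3,
\qquad
S = \hat h_i(0) + \hat h_i(1) = 2c_0 + c_1 + c_2 + c_3,
\]

the prover has \(c_0 = \hat h_i(0)\), \(\hat h_i(-1) = c_0 - c_1 + c_2 - c_3\), and the leading coefficient \(c_3\) (the "at infinity over \(x^3\)" accumulator). Adding \(S\) and \(\hat h_i(-1)\) cancels \(c_1\) and \(c_3\):

\[
c_2 = \tfrac{1}{2}\bigl(S + \hat h_i(-1) - 3c_0\bigr),
\qquad
c_1 = S - 2c_0 - c_2 - c_3,
\]

which is exactly the `HALF * (claimed_value + hhat_i_at_em1 - 3 * hhat_i_at_0)` computation in the round loop above.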
diff --git a/provekit/spark/src/types.rs b/provekit/spark/src/types.rs
new file mode 100644
index 000000000..7a4bdee8b
--- /dev/null
+++ b/provekit/spark/src/types.rs
@@ -0,0 +1,125 @@
+use {
+    provekit_common::{
+        file::{
+            binary_format::{
+                SPARK_PROOF_FORMAT, SPARK_PROOF_VERSION, SPARK_SETUP_FORMAT, SPARK_SETUP_VERSION,
+            },
+            Compression, FileFormat, MaybeHashAware,
+        },
+        FieldElement, HashConfig, WhirConfig, WhirR1CSProof,
+    },
+    serde::{Deserialize, Serialize},
+    whir::protocols::irs_commit,
+};
+
+pub type WhirWitness = irs_commit::Witness;
+
+#[derive(Clone, Serialize, Deserialize)]
+pub struct SPARKSetup {
+    pub whir_params: SPARKWHIRConfigs,
+    pub matrix_dimensions: MatrixDimensions,
+    pub transcript: WhirR1CSProof,
+}
+
+impl FileFormat for SPARKSetup {
+    const FORMAT: [u8; 8] = SPARK_SETUP_FORMAT;
+    const EXTENSION: &'static str = "spc";
+    const VERSION: (u16, u16) = SPARK_SETUP_VERSION;
+    const COMPRESSION: Compression = Compression::Zstd;
+}
+
+impl MaybeHashAware for SPARKSetup {
+    fn maybe_hash_config(&self) -> Option<HashConfig> {
+        None
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+#[serde(transparent)]
+pub struct SPARKProof(pub WhirR1CSProof);
+
+impl FileFormat for SPARKProof {
+    const FORMAT: [u8; 8] = SPARK_PROOF_FORMAT;
+    const EXTENSION: &'static str = "sp";
+    const VERSION: (u16, u16) = SPARK_PROOF_VERSION;
+    const COMPRESSION: Compression = Compression::Zstd;
+}
+
+impl MaybeHashAware for SPARKProof {
+    fn maybe_hash_config(&self) -> Option<HashConfig> {
+        None
+    }
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct MatrixDimensions {
+    pub num_rows: usize,
+    pub num_cols: usize,
+    pub nonzero_terms: usize,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct SPARKWHIRConfigs {
+    pub row: WhirConfig,
+    pub col: WhirConfig,
+    pub num_terms_1batched: WhirConfig,
+    pub num_terms_2batched: WhirConfig,
+    pub num_terms_4batched: WhirConfig,
+}
+
+#[derive(Debug, Clone)]
+pub struct SparkMatrix {
+    pub coo: COOMatrix,
+    pub timestamps: TimeStamps,
+}
+
+#[derive(Debug, Clone)]
+pub struct COOMatrix {
+    pub row: Vec<usize>,
+    pub col: Vec<usize>,
+    pub row_field: Vec<FieldElement>,
+    pub col_field: Vec<FieldElement>,
+    pub val: Vec<FieldElement>,
+}
+
+#[derive(Debug, Clone)]
+pub struct TimeStamps {
+    pub read_row: Vec<FieldElement>,
+    pub read_col: Vec<FieldElement>,
+    pub final_row: Vec<FieldElement>,
+    pub final_col: Vec<FieldElement>,
+}
+
+#[derive(Clone)]
+pub struct SparkWitnesses {
+    pub vals_witness: WhirWitness,
+    pub rs_ws_witness: WhirWitness,
+    pub final_row_ts_witness: WhirWitness,
+    pub final_col_ts_witness: WhirWitness,
+}
+
+#[derive(Clone)]
+pub struct SparkProverContext {
+    pub matrix: SparkMatrix,
+    pub witnesses: SparkWitnesses,
+    pub setup: SPARKSetup,
+}
+
+#[derive(Debug, Clone)]
+pub struct Memory {
+    pub eq_rx: Vec<FieldElement>,
+    pub eq_ry: Vec<FieldElement>,
+}
+
+#[derive(Debug, Clone)]
+pub struct EValuesForMatrix {
+    pub e_rx: Vec<FieldElement>,
+    pub e_ry: Vec<FieldElement>,
+}
+
+/// Challenges drawn from the Fiat-Shamir transcript during proving.
+#[derive(Debug, Clone)]
+pub struct Challenges {
+    pub gamma: FieldElement,
+    pub tau: FieldElement,
+}
diff --git a/provekit/spark/src/utils.rs b/provekit/spark/src/utils.rs
new file mode 100644
index 000000000..ff122e2fa
--- /dev/null
+++ b/provekit/spark/src/utils.rs
@@ -0,0 +1,28 @@
+pub use crate::types::Memory;
+use provekit_common::{
+    utils::sumcheck::calculate_evaluations_over_boolean_hypercube_for_eq, FieldElement,
+};
+
+#[tracing::instrument(skip_all)]
+pub fn calculate_memory(
+    b: FieldElement,
+    point_row: &[FieldElement],
+    point_col: &[FieldElement],
+) -> Memory {
+    let row_point: Vec<_> = std::iter::once(b)
+        .chain(point_row.iter().copied())
+        .collect();
+    let col_point: Vec<_> = std::iter::once(b)
+        .chain(point_col.iter().copied())
+        .collect();
+    Memory {
+        eq_rx: calculate_evaluations_over_boolean_hypercube_for_eq(
+            &row_point,
+            1 << row_point.len(),
+        ),
+        eq_ry: calculate_evaluations_over_boolean_hypercube_for_eq(
+            &col_point,
+            1 << col_point.len(),
+        ),
+    }
+}
diff --git a/provekit/spark/src/verifier.rs b/provekit/spark/src/verifier.rs
new file mode 100644
index 000000000..1f26e7e16
--- /dev/null
+++ b/provekit/spark/src/verifier.rs
@@ -0,0 +1,241 @@
+use {
+    crate::{
+        gpa::gpa_sumcheck_verifier4,
+        memory::verify_axis,
+        setup::PrecomputedCommitments,
+        sumcheck::run_sumcheck_verifier_spark,
+        types::{MatrixDimensions, SPARKProof, SPARKSetup, SPARKWHIRConfigs},
+    },
+    anyhow::{ensure, Context, Result},
+    ark_ff::{Field, Zero},
+    provekit_common::{
+        spark::R1CSSparkQuery,
+        utils::{next_power_of_two, sumcheck::calculate_eq},
+        FieldElement, TranscriptSponge,
+    },
+    tracing::instrument,
+    whir::{
+        algebra::linear_form::MultilinearExtension,
+        transcript::{DomainSeparator, Proof, VerifierMessage, VerifierState},
+    },
+};
+
+pub trait SPARKVerifier {
+    fn verify(&self, proof: SPARKProof, setup: &SPARKSetup, request: &R1CSSparkQuery)
+        -> Result<()>;
+}
+
+pub struct SPARKScheme;
+
+impl SPARKVerifier for SPARKScheme {
+    #[instrument(skip_all)]
+    fn verify(
+        &self,
+        proof: SPARKProof,
+        setup: &SPARKSetup,
+        request: &R1CSSparkQuery,
+    ) -> Result<()> {
+        ensure!(
+            !(FieldElement::ONE + request.matrix_batching_randomness).is_zero(),
+            "matrix_batching_randomness must not equal -1 (would zero the SPARK denominator)"
+        );
+
+        let precomputed_commitments = setup.extract_commitments()?;
+
+        let whir_proof = Proof {
+            narg_string: proof.0.narg_string,
+            hints: proof.0.hints,
+            #[cfg(debug_assertions)]
+            pattern: proof.0.pattern,
+        };
+        let mut arthur = VerifierState::new(
+            &DomainSeparator::protocol(&setup.whir_params)
+                .session(&setup.transcript.narg_string)
+                .instance(&request.hash_bytes()),
+            &whir_proof,
+            TranscriptSponge::default(),
+        );
+
+        let claimed_value = (request.claimed_value
+            / (FieldElement::ONE + request.matrix_batching_randomness))
+            / (FieldElement::ONE + request.matrix_batching_randomness);
+
+        let mut new_request = request.clone();
+        let b1 = request.matrix_batching_randomness
+            / (FieldElement::ONE + request.matrix_batching_randomness);
+        new_request.point_to_evaluate.row = std::iter::once(b1)
+            .chain(new_request.point_to_evaluate.row.clone())
+            .collect();
+        new_request.point_to_evaluate.col = std::iter::once(b1)
+            .chain(new_request.point_to_evaluate.col.clone())
+            .collect();
+
+        verify_spark_single_matrix(
+            &setup.whir_params,
+            setup.matrix_dimensions.clone(),
+            &mut arthur,
+            &precomputed_commitments,
+            &new_request,
+            &claimed_value,
+        )
+    }
+}
+
+#[instrument(skip_all)]
+pub(crate) fn verify_spark_single_matrix(
+    whir_params: &SPARKWHIRConfigs,
+    matrix_dimensions: MatrixDimensions,
+    arthur: &mut VerifierState<'_, TranscriptSponge>,
+    precomputed_commitments: &PrecomputedCommitments,
+    request: &R1CSSparkQuery,
+    claimed_value: &FieldElement,
+) -> Result<()> {
+    let e_values_commitment = whir_params
+        .num_terms_2batched
+        .receive_commitment(arthur)
+        .map_err(|e| anyhow::anyhow!("Failed to receive e_values commitment: {e}"))?;
+
+    let (randomness, last_sumcheck_value) = run_sumcheck_verifier_spark(
+        arthur,
+        next_power_of_two(matrix_dimensions.nonzero_terms),
+        *claimed_value,
+    )
+    .context("While verifying SPARK sumcheck")?;
+    let eval_weight = MultilinearExtension::new(randomness);
+
+    let sumcheck_hints: [FieldElement; 3] = arthur
+        .prover_hint_ark()
+        .map_err(|e| anyhow::anyhow!("{e}"))?;
+
+    ensure!(last_sumcheck_value == sumcheck_hints[0] * sumcheck_hints[1] * sumcheck_hints[2]);
+
+    let e_values_claim = whir_params
+        .num_terms_2batched
+        .verify(arthur, &[&e_values_commitment], &[
+            sumcheck_hints[1],
+            sumcheck_hints[2],
+        ])
+        .map_err(|e| anyhow::anyhow!("WHIR verify failed for e_values (sumcheck): {e}"))?;
+    e_values_claim
+        .verify([&eval_weight as &dyn whir::algebra::linear_form::LinearForm])
+        .map_err(|e| anyhow::anyhow!("FinalClaim check failed for e_values: {e}"))?;
+
+    let val_claim = whir_params
+        .num_terms_1batched
+        .verify(
+            arthur,
+            &[&precomputed_commitments.val],
+            &[sumcheck_hints[0]],
+        )
+        .map_err(|e| anyhow::anyhow!("WHIR verify failed for val: {e}"))?;
+    val_claim
+        .verify([&eval_weight as &dyn whir::algebra::linear_form::LinearForm])
+        .map_err(|e| anyhow::anyhow!("FinalClaim check failed for val: {e}"))?;
+
+    let tau: FieldElement = arthur.verifier_message();
+    let gamma: FieldElement = arthur.verifier_message();
+
+    let gpa_result = gpa_sumcheck_verifier4(
+        arthur,
+        provekit_common::utils::next_power_of_two(matrix_dimensions.nonzero_terms) + 3,
+    )?;
+
+    let (combination_randomness, evaluation_randomness) = gpa_result.randomness.split_at(2);
+
+    let claimed_row_rs = gpa_result.claimed_values[0];
+    let claimed_row_ws = gpa_result.claimed_values[1];
+    let claimed_col_rs = gpa_result.claimed_values[2];
+    let claimed_col_ws = gpa_result.claimed_values[3];
+
+    let row_adr: FieldElement = arthur
+        .prover_hint_ark()
+        .map_err(|e| anyhow::anyhow!("{e}"))?;
+    let row_timestamp: FieldElement = arthur
+        .prover_hint_ark()
+        .map_err(|e| anyhow::anyhow!("{e}"))?;
+    let col_adr: FieldElement = arthur
+        .prover_hint_ark()
+        .map_err(|e| anyhow::anyhow!("{e}"))?;
+    let col_timestamp: FieldElement = arthur
+        .prover_hint_ark()
+        .map_err(|e| anyhow::anyhow!("{e}"))?;
+
+    let gpa_eval_weight = MultilinearExtension::new(evaluation_randomness.to_vec());
+    let gpa_eval_lf: &dyn whir::algebra::linear_form::LinearForm = &gpa_eval_weight;
+
+    let rsws_claim = whir_params
+        .num_terms_4batched
+        .verify(arthur, &[&precomputed_commitments.rsws], &[
+            row_adr,
+            row_timestamp,
+            col_adr,
+            col_timestamp,
+        ])
+        .map_err(|e| anyhow::anyhow!("WHIR verify failed for rsws: {e}"))?;
+    rsws_claim
+        .verify([gpa_eval_lf])
+        .map_err(|e| anyhow::anyhow!("FinalClaim check failed for rsws: {e}"))?;
+
+    let row_mem: FieldElement = arthur
+        .prover_hint_ark()
+        .map_err(|e| anyhow::anyhow!("{e}"))?;
+    let col_mem: FieldElement = arthur
+        .prover_hint_ark()
+        .map_err(|e| anyhow::anyhow!("{e}"))?;
+
+    let e_values_gpa_claim = whir_params
+        .num_terms_2batched
+        .verify(arthur, &[&e_values_commitment], &[row_mem, col_mem])
+        .map_err(|e| anyhow::anyhow!("WHIR verify failed for e_values (GPA): {e}"))?;
+    e_values_gpa_claim
+        .verify([gpa_eval_lf])
+        .map_err(|e| anyhow::anyhow!("FinalClaim check failed for e_values (GPA): {e}"))?;
+
+    let gamma_sq = gamma * gamma;
+
+    let row_rs_opening = row_adr * gamma_sq + row_mem * gamma + row_timestamp - tau;
+    let row_ws_opening =
+        row_adr * gamma_sq + row_mem * gamma + row_timestamp + FieldElement::from(1) - tau;
+    let col_rs_opening = col_adr * gamma_sq + col_mem * gamma + col_timestamp - tau;
+    let col_ws_opening =
+        col_adr * gamma_sq + col_mem * gamma + col_timestamp + FieldElement::from(1) - tau;
+
+    let evaluated_value = row_rs_opening
+        * (FieldElement::from(1) - combination_randomness[0])
+        * (FieldElement::from(1) - combination_randomness[1])
+        + row_ws_opening
+            * (FieldElement::from(1) - combination_randomness[0])
+            * combination_randomness[1]
+        + col_rs_opening
+            * combination_randomness[0]
+            * (FieldElement::from(1) - combination_randomness[1])
+        + col_ws_opening * combination_randomness[0] * combination_randomness[1];
+
+    ensure!(evaluated_value == gpa_result.last_sumcheck_value);
+
+    verify_axis(
+        arthur,
+        matrix_dimensions.num_rows,
+        &whir_params.row,
+        precomputed_commitments.a_row_finalts.clone(),
+        |eval_rand| calculate_eq(&request.point_to_evaluate.row, eval_rand),
+        &tau,
+        &gamma,
+        &claimed_row_rs,
+        &claimed_row_ws,
+    )?;
+
+    verify_axis(
+        arthur,
+        matrix_dimensions.num_cols,
+        &whir_params.col,
+        precomputed_commitments.a_col_finalts.clone(),
+        |eval_rand| calculate_eq(&request.point_to_evaluate.col, eval_rand),
+        &tau,
+        &gamma,
+        &claimed_col_rs,
+        &claimed_col_ws,
+    )?;
+
+    Ok(())
+}
diff --git a/tooling/cli/Cargo.toml b/tooling/cli/Cargo.toml
index 9c9deb208..e24e1e133 100644
--- a/tooling/cli/Cargo.toml
+++ b/tooling/cli/Cargo.toml
@@ -10,10 +10,12 @@ repository.workspace = true
 
 [dependencies]
 # Workspace crates
+mavros-artifacts.workspace = true
 provekit-common.workspace = true
 provekit-gnark.workspace = true
 provekit-prover = { workspace = true, features = ["witness-generation", "parallel"] }
 provekit-r1cs-compiler.workspace = true
+provekit-spark.workspace = true
 provekit-verifier.workspace = true
 
 # Noir language
@@ -25,11 +27,13 @@ ark-ff.workspace = true
 
 # 3rd party
 anyhow.workspace = true
+bincode.workspace = true
 argh.workspace = true
 base64.workspace = true
 hex.workspace = true
 postcard.workspace = true
 rayon.workspace = true
+serde.workspace = true
 serde_json.workspace = true
 tikv-jemallocator = { workspace = true, optional = true }
 tracing.workspace = true
diff --git a/tooling/cli/src/cmd/mod.rs b/tooling/cli/src/cmd/mod.rs
index faf7297cb..69749c2bf 100644
--- a/tooling/cli/src/cmd/mod.rs
+++ b/tooling/cli/src/cmd/mod.rs
@@ -1,10 +1,12 @@
 mod analyze_pkp;
 mod circuit_stats;
 mod generate_gnark_inputs;
-mod prepare;
+pub mod prepare;
 mod prove;
+mod prove_spark;
 mod show_inputs;
 mod verify;
+mod verify_spark;
 
 use {anyhow::Result, argh::FromArgs};
 
@@ -41,8 +43,10 @@ enum Commands {
     AnalyzePkp(analyze_pkp::Args),
     Prepare(prepare::Args),
     Prove(prove::Args),
+    ProveSpark(prove_spark::Args),
     CircuitStats(circuit_stats::Args),
     Verify(verify::Args),
+    VerifySpark(verify_spark::Args),
     GenerateGnarkInputs(generate_gnark_inputs::Args),
     ShowInputs(show_inputs::Args),
 }
@@ -59,8 +63,10 @@ impl Command for Commands {
             Self::AnalyzePkp(args) => args.run(),
             Self::Prepare(args) => args.run(),
             Self::Prove(args) => args.run(),
+            Self::ProveSpark(args) => args.run(),
             Self::CircuitStats(args) => args.run(),
             Self::Verify(args) => args.run(),
+            Self::VerifySpark(args) => args.run(),
             Self::GenerateGnarkInputs(args) => args.run(),
             Self::ShowInputs(args) => args.run(),
         }
     }
diff --git a/tooling/cli/src/cmd/prepare.rs b/tooling/cli/src/cmd/prepare.rs
index a069dab12..8f3c4b88b 100644
--- a/tooling/cli/src/cmd/prepare.rs
+++ b/tooling/cli/src/cmd/prepare.rs
@@ -2,14 +2,22 @@ use {
     super::Command,
     anyhow::{Context, Result},
     argh::FromArgs,
-    provekit_common::{file::write, HashConfig, Prover, Verifier},
+    mavros_artifacts::R1CS as MavrosR1CS,
+    provekit_common::{
+        file::write, utils::next_power_of_two, FieldElement, HashConfig, NoirProofScheme, Prover,
+        Verifier, R1CS,
+    },
     provekit_r1cs_compiler::{MavrosCompiler, NoirCompiler},
-    std::{path::PathBuf, str::FromStr},
-    tracing::instrument,
+    provekit_spark::types::{COOMatrix, SparkMatrix, TimeStamps},
+    std::{
+        path::{Path, PathBuf},
+        str::FromStr,
+    },
+    tracing::{info, instrument},
 };
 
 #[derive(PartialEq, Eq, Debug)]
-enum Compiler {
+pub enum Compiler {
     Noir,
     Mavros,
 }
@@ -65,6 +73,18 @@ pub struct Args {
     /// blake3, poseidon2)
     #[argh(option, long = "hash", default = "String::from(\"skyscraper\")")]
     hash: String,
+
+    /// also run SPARK preprocessing and write the SPARK setup transcript
+    #[argh(switch, long = "spark")]
+    spark: bool,
+
+    /// output path for the SPARK setup transcript (used with --spark)
+    #[argh(
+        option,
+        long = "spc",
+        default = "PathBuf::from(\"noir_proof_scheme.spc\")"
+    )]
+    spc_path: PathBuf,
 }
 
 impl Command for Args {
@@ -84,6 +104,15 @@ impl Command for Args {
             }
         };
 
+        if self.spark {
+            provekit_common::register_ntt();
+            let matrix = build_spark_matrix_for_scheme(&scheme, self.r1cs_path.as_deref())?;
+            let (setup, _witnesses) = provekit_spark::preprocess_spark(&matrix);
+            write(&setup, &self.spc_path)
+                .with_context(|| format!("writing SPARK setup to {:?}", self.spc_path))?;
+            info!("Wrote SPARK setup to {:?}", self.spc_path);
+        }
+
         let prover = Prover::from_noir_proof_scheme(scheme.clone());
         let verifier = Verifier::from_noir_proof_scheme(scheme);
@@ -92,3 +121,208 @@
         Ok(())
     }
 }
+
+pub fn build_spark_matrix_for_scheme(
+    scheme: &NoirProofScheme,
+    r1cs_path: Option<&Path>,
+) -> Result<SparkMatrix> {
+    let whir = match scheme {
+        NoirProofScheme::Noir(s) => s.whir_for_witness.clone(),
+        NoirProofScheme::Mavros(s) => s.whir_for_witness.clone(),
+    };
+    match scheme {
+        NoirProofScheme::Noir(noir) => build_spark_r1cs_noir(
+            &noir.r1cs,
+            whir.m_0,
+            whir.m,
+            whir.w1_size,
+            whir.num_challenges,
+        ),
+        NoirProofScheme::Mavros(_) => {
+            let r1cs_path =
+                r1cs_path.context("--r1cs is required for SPARK with the mavros compiler")?;
+            build_spark_r1cs_mavros(
+                r1cs_path,
+                whir.m_0,
+                whir.m,
+                whir.w1_size,
+                whir.num_challenges,
+            )
+        }
+    }
+}
+
+pub fn build_spark_r1cs_noir(
+    r1cs: &R1CS,
+    log_row: usize,
+    log_col: usize,
+    w1_size: usize,
+    num_challenges: usize,
+) -> Result<SparkMatrix> {
+    let is_single_commitment = num_challenges == 0;
+
+    let original_num_entries =
+        r1cs.a().iter().count() + r1cs.b().iter().count() + r1cs.c().iter().count();
+
+    let padded_num_entries = 1 << next_power_of_two(original_num_entries);
+    let to_fill = padded_num_entries - original_num_entries;
+
+    let row_cnt = 1 << log_row;
+    let col_cnt = if is_single_commitment {
+        1 << log_col
+    } else {
+        1 << (1 + log_col)
+    };
+
+    let col_witness_split_offset = |c: usize| -> usize {
+        if !is_single_commitment && (c >= w1_size) {
+            (1 << log_col) - w1_size
+        } else {
+            0
+        }
+    };
+
+    let (mut row, mut col, mut val) = (
+        Vec::with_capacity(padded_num_entries),
+        Vec::with_capacity(padded_num_entries),
+        Vec::with_capacity(padded_num_entries),
+    );
+
+    for (matrix, row_offset, col_offset) in [
+        (r1cs.a(), 0, 0),
+        (r1cs.b(), 0, col_cnt),
+        (r1cs.c(), row_cnt, col_cnt),
+    ] {
+        for ((r, c), v) in matrix.iter() {
+            row.push(r + row_offset);
+            col.push(c + col_offset + col_witness_split_offset(c));
+            val.push(v);
+        }
+    }
+    for _ in 0..to_fill {
+        row.push(0);
+        col.push(0);
+        val.push(FieldElement::from(0u64));
+    }
+
+    Ok(build_spark_matrix(row, col, val, 2 * row_cnt, 2 * col_cnt))
+}
+
+pub fn build_spark_r1cs_mavros(
+    r1cs_path: &Path,
+    log_row: usize,
+    log_col: usize,
+    w1_size: usize,
+    num_challenges: usize,
+) -> Result<SparkMatrix> {
+    let is_single_commitment = num_challenges == 0;
+
+    let r1cs_bytes = std::fs::read(r1cs_path).context("while reading R1CS file")?;
+    let r1cs: MavrosR1CS =
+        bincode::deserialize(&r1cs_bytes).context("while deserializing R1CS from bincode")?;
+
+    let row_cnt = 1 << log_row;
+    let col_cnt = if is_single_commitment {
+        1 << log_col
+    } else {
+        1 << (1 + log_col)
+    };
+
+    let col_witness_split_offset = |c: usize| -> usize {
+        if !is_single_commitment && (c >= w1_size) {
+            (1 << log_col) - w1_size
+        } else {
+            0
+        }
+    };
+
+    let original_num_entries: usize = r1cs
+        .constraints
+        .iter()
+        .map(|r1c| r1c.a.len() + r1c.b.len() + r1c.c.len())
+        .sum();
+
+    let padded_num_entries = 1 << next_power_of_two(original_num_entries);
+    let to_fill = padded_num_entries - original_num_entries;
+
+    let (mut row, mut col, mut val) = (
+        Vec::with_capacity(padded_num_entries),
+        Vec::with_capacity(padded_num_entries),
+        Vec::with_capacity(padded_num_entries),
+    );
+
+    for (i, r1c) in r1cs.constraints.iter().enumerate() {
+        for &(c, v) in &r1c.a {
+            row.push(i);
+            col.push(c + col_witness_split_offset(c));
+            val.push(v);
+        }
+        for &(c, v) in &r1c.b {
+            row.push(i);
+            col.push(c + col_cnt + col_witness_split_offset(c));
+            val.push(v);
+        }
+        for &(c, v) in &r1c.c {
+            row.push(i + row_cnt);
+            col.push(c + col_cnt + col_witness_split_offset(c));
+            val.push(v);
+        }
+    }
+
+    for _ in 0..to_fill {
+        row.push(0);
+        col.push(0);
+        val.push(FieldElement::from(0u64));
+    }
+
+    Ok(build_spark_matrix(row, col, val, 2 * row_cnt, 2 * col_cnt))
+}
+
+pub fn build_spark_matrix(
+    row: Vec<usize>,
+    col: Vec<usize>,
+    val: Vec<FieldElement>,
+    num_rows: usize,
+    num_cols: usize,
+) -> SparkMatrix {
+    let len = row.len();
+    let mut read_row_counters = vec![0usize; num_rows];
+    let mut read_col_counters = vec![0usize; num_cols];
+    let mut read_row = Vec::with_capacity(len);
+    let mut read_col = Vec::with_capacity(len);
+
+    for i in 0..len {
+        read_row.push(FieldElement::from(read_row_counters[row[i]] as u64));
+        read_row_counters[row[i]] += 1;
+        read_col.push(FieldElement::from(read_col_counters[col[i]] as u64));
+        read_col_counters[col[i]] += 1;
+    }
+
+    let final_row = read_row_counters
+        .iter()
+        .map(|&x| FieldElement::from(x as u64))
+        .collect();
+    let final_col = read_col_counters
+        .iter()
+        .map(|&x| FieldElement::from(x as u64))
+        .collect();
+
+    let row_field = row.iter().map(|&r| FieldElement::from(r as u64)).collect();
+    let col_field = col.iter().map(|&c| FieldElement::from(c as u64)).collect();
+
+    SparkMatrix {
+        coo: COOMatrix {
+            row,
+            col,
+            row_field,
+            col_field,
+            val,
+        },
+        timestamps: TimeStamps {
+            read_row,
+            read_col,
+            final_row,
+            final_col,
+        },
+    }
+}
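A toy, self-contained illustration (hypothetical values, not part of the patch) of the counter-based timestamps that `build_spark_matrix` above assigns: each access reads the current per-address counter, the implied write bumps it, and the final counters are the per-address access counts.

```rust
fn main() {
    let reads = [0usize, 1, 0, 0, 1]; // row (or column) indices in access order
    let mut counters = vec![0u64; 2]; // one counter per address
    let mut read_ts = Vec::new();     // timestamp observed by each access

    for &a in &reads {
        read_ts.push(counters[a]); // read the current timestamp...
        counters[a] += 1;          // ...and the paired write increments it
    }

    assert_eq!(read_ts, vec![0, 0, 1, 2, 1]); // read_row / read_col analogue
    assert_eq!(counters, vec![3, 2]);         // final_row / final_col analogue
}
```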
"PathBuf::from(\"./proof.np\")" )] proof_path: PathBuf, + + /// directory in which to write SPARK queries (default: ./spark_proofs) + #[argh( + option, + long = "spark-queries-dir", + default = "PathBuf::from(\"./spark_proofs\")" + )] + spark_queries_dir: PathBuf, } impl Command for Args { @@ -49,7 +57,7 @@ impl Command for Args { info!(constraints, witnesses, "Read Noir proof scheme"); // Generate the proof - let proof = prover + let (proof, spark_queries) = prover .prove_with_toml(&self.input_path) .context("While proving Noir program statement")?; @@ -66,6 +74,20 @@ impl Command for Args { .context("While verifying Noir proof")?; } + if !spark_queries.is_empty() { + std::fs::create_dir_all(&self.spark_queries_dir) + .with_context(|| format!("creating {:?}", self.spark_queries_dir))?; + for (index, query) in spark_queries.iter().enumerate() { + let query_path = self + .spark_queries_dir + .join(format!("spark_query_{index}.json")); + let query_file = std::fs::File::create(&query_path) + .with_context(|| format!("creating {query_path:?}"))?; + serde_json::to_writer_pretty(query_file, query).context("writing spark query")?; + info!("Wrote SPARK query to {query_path:?}"); + } + } + Ok(()) } } diff --git a/tooling/cli/src/cmd/prove_spark.rs b/tooling/cli/src/cmd/prove_spark.rs new file mode 100644 index 000000000..1331079d2 --- /dev/null +++ b/tooling/cli/src/cmd/prove_spark.rs @@ -0,0 +1,119 @@ +use { + super::{ + prepare::{build_spark_r1cs_mavros, build_spark_r1cs_noir}, + Command, + }, + anyhow::{Context, Result}, + argh::FromArgs, + provekit_common::{ + file::{read, write}, + spark::R1CSSparkQuery, + Prover, + }, + provekit_spark::{types::SparkMatrix, SPARKProver as _, SPARKProverScheme, SparkProverContext}, + std::{ + fs::File, + io::BufReader, + path::{Path, PathBuf}, + }, + tracing::{info, instrument}, +}; + +/// Generate SPARK proofs for the queries emitted by `prove`. +#[derive(FromArgs, PartialEq, Eq, Debug)] +#[argh(subcommand, name = "prove-spark")] +pub struct Args { + /// path to the prepared proof scheme + #[argh(positional)] + prover_path: PathBuf, + + /// directory containing `spark_query_.json` files; SPARK proofs are + /// written here as `spark_proof_.sp` (default: ./spark_proofs) + #[argh( + option, + long = "spark-dir", + default = "PathBuf::from(\"./spark_proofs\")" + )] + spark_dir: PathBuf, + + /// path to the R1CS file (required for the mavros compiler) + #[argh(option, long = "r1cs")] + r1cs_path: Option, +} + +impl Command for Args { + #[instrument(skip_all)] + fn run(&self) -> Result<()> { + provekit_common::register_ntt(); + + let prover: Prover = read(&self.prover_path).context("while reading Provekit Prover")?; + + let queries = collect_queries(&self.spark_dir)?; + if queries.is_empty() { + info!("No SPARK queries found in {:?}", self.spark_dir); + return Ok(()); + } + + // TODO: cache from `prepare --spark` instead of recomputing; blocked on + // serde for `WhirWitness` over `ark_bn254::Fr`. 
+ let spark_matrix = build_spark_matrix(&prover, self.r1cs_path.as_deref())?; + let (setup, witnesses) = provekit_spark::preprocess_spark(&spark_matrix); + let context = SparkProverContext { + matrix: spark_matrix, + witnesses, + setup, + }; + + let num_constraints = context.matrix.timestamps.final_row.len(); + let num_witnesses = context.matrix.timestamps.final_col.len(); + let num_nonzero = context.matrix.coo.val.len(); + + for (index, query) in queries.iter().enumerate() { + let scheme = SPARKProverScheme::new(num_constraints, num_witnesses, num_nonzero); + let spark_proof = scheme + .prove(&context, query) + .context("generating SPARK proof")?; + let proof_path = self.spark_dir.join(format!("spark_proof_{index}.sp")); + write(&spark_proof, &proof_path) + .with_context(|| format!("writing SPARK proof to {proof_path:?}"))?; + info!("Wrote SPARK proof to {proof_path:?}"); + } + + Ok(()) + } +} + +fn collect_queries(dir: &Path) -> Result> { + let mut out = Vec::new(); + for index in 0usize.. { + let path = dir.join(format!("spark_query_{index}.json")); + if !path.exists() { + break; + } + let file = File::open(&path).with_context(|| format!("opening {path:?}"))?; + let query: R1CSSparkQuery = serde_json::from_reader(BufReader::new(file)) + .with_context(|| format!("parsing {path:?}"))?; + out.push(query); + } + Ok(out) +} + +fn build_spark_matrix(prover: &Prover, r1cs_path: Option<&Path>) -> Result { + let whir = prover.whir_for_witness().clone(); + match prover { + Prover::Noir(p) => { + build_spark_r1cs_noir(&p.r1cs, whir.m_0, whir.m, whir.w1_size, whir.num_challenges) + } + Prover::Mavros(_) => { + let r1cs_path = r1cs_path + .context("--r1cs is required for SPARK proving with the mavros compiler")?; + build_spark_r1cs_mavros( + r1cs_path, + whir.m_0, + whir.m, + whir.w1_size, + whir.num_challenges, + ) + } + } +} diff --git a/tooling/cli/src/cmd/verify.rs b/tooling/cli/src/cmd/verify.rs index 213cc53df..f709a9cce 100644 --- a/tooling/cli/src/cmd/verify.rs +++ b/tooling/cli/src/cmd/verify.rs @@ -32,7 +32,6 @@ impl Command for Args { let mut verifier = verifier?; let proof = proof?; - // Verify the proof verifier .verify(&proof) .context("While verifying Noir proof")?; diff --git a/tooling/cli/src/cmd/verify_spark.rs b/tooling/cli/src/cmd/verify_spark.rs new file mode 100644 index 000000000..c9a8d4196 --- /dev/null +++ b/tooling/cli/src/cmd/verify_spark.rs @@ -0,0 +1,57 @@ +use { + super::Command, + anyhow::{Context, Result}, + argh::FromArgs, + provekit_common::{file::read, spark::R1CSSparkQuery}, + provekit_spark::{SPARKProof, SPARKSetup, SPARKVerifier, SPARKVerifierScheme}, + std::{fs::File, io::BufReader, path::PathBuf}, + tracing::instrument, +}; + +/// Verify a standalone SPARK proof against a saved R1CSSparkQuery. 
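The `verify-spark` command below overlaps its three file reads with nested `rayon::join`; a minimal sketch of the pattern, with stand-in closures in place of the real readers:

```rust
fn main() {
    // rayon::join runs both closures, potentially in parallel, and returns
    // their results as a tuple; nesting it yields three-way parallelism.
    let (proof, (setup, query)) = rayon::join(
        || "read proof",
        || rayon::join(|| "read setup", || "read query"),
    );
    assert_eq!((proof, setup, query), ("read proof", "read setup", "read query"));
}
```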
diff --git a/tooling/cli/src/cmd/verify_spark.rs b/tooling/cli/src/cmd/verify_spark.rs
new file mode 100644
index 000000000..c9a8d4196
--- /dev/null
+++ b/tooling/cli/src/cmd/verify_spark.rs
@@ -0,0 +1,57 @@
+use {
+    super::Command,
+    anyhow::{Context, Result},
+    argh::FromArgs,
+    provekit_common::{file::read, spark::R1CSSparkQuery},
+    provekit_spark::{SPARKProof, SPARKSetup, SPARKVerifier, SPARKVerifierScheme},
+    std::{fs::File, io::BufReader, path::PathBuf},
+    tracing::instrument,
+};
+
+/// Verify a standalone SPARK proof against a saved R1CSSparkQuery.
+#[derive(FromArgs, PartialEq, Eq, Debug)]
+#[argh(subcommand, name = "verify-spark")]
+pub struct Args {
+    /// path to the SPARK proof file (.sp or .json)
+    #[argh(positional)]
+    proof_path: PathBuf,
+
+    /// path to the SPARK setup transcript (.spc) produced by `prepare --spark`
+    #[argh(positional)]
+    setup_path: PathBuf,
+
+    /// path to the R1CSSparkQuery JSON file
+    #[argh(positional)]
+    query_path: PathBuf,
+}
+
+impl Command for Args {
+    #[instrument(skip_all)]
+    fn run(&self) -> Result<()> {
+        provekit_common::register_ntt();
+
+        let (proof, (setup, query)) = rayon::join(
+            || read::<SPARKProof>(&self.proof_path).context("while reading SPARK proof"),
+            || {
+                rayon::join(
+                    || read::<SPARKSetup>(&self.setup_path).context("while reading SPARK setup"),
+                    || read_query(&self.query_path).context("while reading SPARK query"),
+                )
+            },
+        );
+        let proof = proof?;
+        let setup = setup?;
+        let query = query?;
+
+        SPARKVerifierScheme
+            .verify(proof, &setup, &query)
+            .context("while verifying SPARK proof")?;
+
+        Ok(())
+    }
+}
+
+fn read_query(path: &PathBuf) -> Result<R1CSSparkQuery> {
+    let file = File::open(path).with_context(|| format!("opening {path:?}"))?;
+    serde_json::from_reader(BufReader::new(file)).context("parsing query JSON")
+}
diff --git a/tooling/provekit-bench/tests/compiler.rs b/tooling/provekit-bench/tests/compiler.rs
index d4481875b..3c09ca9f5 100644
--- a/tooling/provekit-bench/tests/compiler.rs
+++ b/tooling/provekit-bench/tests/compiler.rs
@@ -87,7 +87,7 @@ fn test_noir_compiler_with_hash_config(
     let prover = Prover::from_noir_proof_scheme(schema.clone());
     let mut verifier = Verifier::from_noir_proof_scheme(schema.clone());
 
-    let proof = prover
+    let (proof, _) = prover
         .prove_with_toml(&witness_file_path)
         .expect("While proving Noir program statement");
 
@@ -269,7 +269,7 @@ fn test_public_input_binding_exploit() {
     let mut verifier = Verifier::from_noir_proof_scheme(schema.clone());
 
     // Prove honestly (a=5, b=3 → result = (5+3)*(5-3) = 16)
-    let mut proof = prover
+    let (mut proof, _) = prover
         .prove_with_toml(&witness_file_path)
         .expect("While proving Noir program statement");
 
@@ -322,7 +322,7 @@ fn test_verifier_rejects_mismatched_hash_config() {
     let mut matching_verifier = Verifier::from_noir_proof_scheme(prover_schema);
     let mut mismatched_verifier = Verifier::from_noir_proof_scheme(verifier_schema);
 
-    let proof = prover
+    let (proof, _spark_queries) = prover
         .prove_with_toml(&witness_file_path)
         .expect("While proving Noir program statement");
 
diff --git a/tooling/provekit-ffi/src/ffi.rs b/tooling/provekit-ffi/src/ffi.rs
index e156f5d09..3dca9b539 100644
--- a/tooling/provekit-ffi/src/ffi.rs
+++ b/tooling/provekit-ffi/src/ffi.rs
@@ -555,7 +555,7 @@ pub unsafe extern "C" fn pk_prove_toml(
     // Clone is required: Prove::prove consumes self.
     // SAFETY: prover is guaranteed non-null and valid by caller contract.
     let fresh_prover = (*prover).prover.clone();
-    let proof = fresh_prover
+    let (proof, _) = fresh_prover
         .prove_with_toml(Path::new(&toml_path))
         .map_err(|e| {
             set_last_error(format!("{e:#}"));
@@ -624,7 +624,7 @@ pub unsafe extern "C" fn pk_prove_json(
     // Clone is required: Prove::prove consumes self.
     let fresh_prover = (*prover).prover.clone();
-    let proof = fresh_prover.prove(input_map).map_err(|e| {
+    let (proof, _) = fresh_prover.prove(input_map).map_err(|e| {
         set_last_error(format!("{e:#}"));
         PKStatus::ProofError
     })?;
diff --git a/tooling/provekit-wasm/src/prover.rs b/tooling/provekit-wasm/src/prover.rs
index e14d70a52..634d6bf19 100644
--- a/tooling/provekit-wasm/src/prover.rs
+++ b/tooling/provekit-wasm/src/prover.rs
@@ -117,6 +117,7 @@ impl Prover {
             .ok_or_else(|| JsError::new("Prover has been consumed by a previous prove() call"))?;
         inner
             .prove_with_witness(witness)
+            .map(|(proof, _)| proof)
             .context("Failed to generate proof")
             .map_err(|err| JsError::new(&format!("{err:#}")))
     }
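Taken together, the pieces above compose as follows. A hedged end-to-end sketch against the APIs this patch adds (names and re-exports as used by the CLI commands; a `SparkMatrix` and a saved `R1CSSparkQuery` are assumed to be in hand):

```rust
use anyhow::Result;
use provekit_common::spark::R1CSSparkQuery;
use provekit_spark::{
    types::SparkMatrix, SPARKProver as _, SPARKProverScheme, SPARKVerifier as _,
    SPARKVerifierScheme, SparkProverContext,
};

fn spark_roundtrip(matrix: SparkMatrix, query: &R1CSSparkQuery) -> Result<()> {
    provekit_common::register_ntt(); // as the CLI commands do before any WHIR work

    // One-time preprocessing: commits to val, the rs/ws columns and the final
    // timestamps, and records the commitment transcript in the setup.
    let (setup, witnesses) = provekit_spark::preprocess_spark(&matrix);

    let (rows, cols, nnz) = (
        matrix.timestamps.final_row.len(),
        matrix.timestamps.final_col.len(),
        matrix.coo.val.len(),
    );
    let context = SparkProverContext { matrix, witnesses, setup: setup.clone() };

    // Per-query proving, mirroring the `prove-spark` loop...
    let proof = SPARKProverScheme::new(rows, cols, nnz).prove(&context, query)?;

    // ...and verification against the preprocessed setup, as `verify-spark` does.
    SPARKVerifierScheme.verify(proof, &setup, query)
}
```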