Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 60 additions & 5 deletions crypto/stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,34 @@ type AirTracePair<'a, Field, FieldExtension, PI> = (
&'a PI,
);

#[cfg(test)]
pub(crate) mod domain_cache_stats {
use std::cell::Cell;

thread_local! {
static COUNTS: Cell<(usize, usize)> = const { Cell::new((0, 0)) };
}

pub(crate) fn reset() {
COUNTS.with(|c| c.set((0, 0)));
}

pub(crate) fn get() -> (usize, usize) {
COUNTS.with(Cell::get)
}

pub(crate) fn record(was_hit: bool) {
COUNTS.with(|c| {
let (hits, misses) = c.get();
c.set(if was_hit {
(hits + 1, misses)
} else {
(hits, misses + 1)
});
});
}
}

/// A default STARK prover implementing `IsStarkProver`.
pub struct Prover<
Field: IsSubFieldOf<FieldExtension> + IsFFTField + Send + Sync,
Expand Down Expand Up @@ -645,8 +673,8 @@ pub trait IsStarkProver<
fn run_debug_checks(
air_trace_pairs: &[AirTracePair<'_, Field, FieldExtension, PI>],
commitments: &[Round1Commitments<Field, FieldExtension>],
domains: &[Domain<Field>],
twiddle_caches: &[LdeTwiddles<Field>],
domains: &[Arc<Domain<Field>>],
twiddle_caches: &[Arc<LdeTwiddles<Field>>],
) where
FieldElement<Field>: AsBytes,
FieldElement<FieldExtension>: AsBytes,
Expand Down Expand Up @@ -1523,17 +1551,44 @@ pub trait IsStarkProver<
#[cfg(feature = "instruments")]
let phase_start = Instant::now();

// Deduplicate Domain + LdeTwiddles by (trace_length, blowup_factor, coset_offset).
// Many tables share the same domain size (e.g., 7+ tables at 2^20).
// Without dedup, each creates its own Domain (~24 MB) and LdeTwiddles (~32 MB).
type DomainEntry<F> = (Arc<Domain<F>>, Arc<LdeTwiddles<F>>);
let mut domain_cache: std::collections::HashMap<(usize, usize, u64), DomainEntry<Field>> =
std::collections::HashMap::new();

let mut domains = Vec::with_capacity(num_airs);
let mut twiddle_caches: Vec<LdeTwiddles<Field>> = Vec::with_capacity(num_airs);
let mut twiddle_caches: Vec<Arc<LdeTwiddles<Field>>> = Vec::with_capacity(num_airs);

for (air, trace, _pub_inputs) in &*air_trace_pairs {
let trace_length = trace.num_rows();
let domain = new_domain(*air, trace_length);
let twiddles = LdeTwiddles::new(&domain);
let blowup = air.options().blowup_factor as usize;
let coset_offset = air.options().coset_offset;
let key = (trace_length, blowup, coset_offset);
Comment thread
MauroToscano marked this conversation as resolved.

#[cfg(test)]
let was_hit = domain_cache.contains_key(&key);
Comment thread
MauroToscano marked this conversation as resolved.

let (domain, twiddles) = domain_cache
.entry(key)
.or_insert_with(|| {
let d = new_domain(*air, trace_length);
let t = LdeTwiddles::new(&d);
(Arc::new(d), Arc::new(t))
})
.clone();
Comment thread
MauroToscano marked this conversation as resolved.

#[cfg(test)]
domain_cache_stats::record(was_hit);

domains.push(domain);
twiddle_caches.push(twiddles);
}
// Free the HashMap (which holds extra strong Arc references) before the
// long proving rounds begin. `domains` and `twiddle_caches` already hold
// the only surviving Arcs we care about.
drop(domain_cache);

let k = table_parallelism().min(num_airs).max(1);

Expand Down
155 changes: 153 additions & 2 deletions crypto/stark/src/tests/prover_tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use crypto::fiat_shamir::default_transcript::DefaultTranscript;

use crate::{
domain::Domain,
examples::{quadratic_air::QuadraticAIR, simple_fibonacci},
examples::{
quadratic_air::QuadraticAIR,
simple_fibonacci::{self, FibonacciAIR, FibonacciPublicInputs},
},
proof::options::ProofOptions,
prover::{IsStarkProver, Prover, evaluate_polynomial_on_lde_domain},
prover::{IsStarkProver, Prover, domain_cache_stats, evaluate_polynomial_on_lde_domain},
trace::{LDETraceTable, get_trace_evaluations, get_trace_evaluations_from_lde},
traits::AIR,
verifier::{IsStarkVerifier, Verifier},
};
use math::{
field::{element::FieldElement, goldilocks::GoldilocksField, traits::IsFFTField},
Expand Down Expand Up @@ -233,3 +239,148 @@ fn test_decompose_and_extend_d2_matches_original() {
assert_eq!(new_result[1][i], original[1][i], "H₁ mismatch at index {i}");
}
}

/// Test that the domain cache 3-tuple key `(trace_length, blowup, coset_offset)` correctly
/// distinguishes AIRs that share the same `(trace_length, blowup)` but differ in
/// `coset_offset`. Both AIRs must get their own `Domain` and the resulting proofs must
/// verify successfully.
#[test_log::test]
fn test_multi_prove_mixed_coset_offsets() {
let proof_options_3 = ProofOptions {
blowup_factor: 2,
fri_number_of_queries: 3,
coset_offset: 3,
grinding_factor: 1,
};
let proof_options_7 = ProofOptions {
blowup_factor: 2,
fri_number_of_queries: 3,
coset_offset: 7,
grinding_factor: 1,
};

// Both AIRs have the same trace length and blowup, but different coset offsets.
let mut trace_1 = simple_fibonacci::fibonacci_trace([Felt::from(1), Felt::from(1)], 8);
let mut trace_2 = simple_fibonacci::fibonacci_trace([Felt::from(1), Felt::from(1)], 8);

let pub_inputs = FibonacciPublicInputs {
a0: Felt::one(),
a1: Felt::one(),
};

let air_1 = FibonacciAIR::<GoldilocksField>::new(&proof_options_3);
let air_2 = FibonacciAIR::<GoldilocksField>::new(&proof_options_7);

let air_trace_pairs: Vec<(
&dyn AIR<
Field = GoldilocksField,
FieldExtension = GoldilocksField,
PublicInputs = FibonacciPublicInputs<GoldilocksField>,
>,
&mut _,
&_,
)> = vec![
(&air_1, &mut trace_1, &pub_inputs),
(&air_2, &mut trace_2, &pub_inputs),
];

let multi_proof = Prover::multi_prove(
air_trace_pairs,
&mut DefaultTranscript::<GoldilocksField>::new(&[]),
)
.expect("proving should succeed");

let airs: Vec<
&dyn AIR<
Field = GoldilocksField,
FieldExtension = GoldilocksField,
PublicInputs = FibonacciPublicInputs<GoldilocksField>,
>,
> = vec![&air_1, &air_2];

assert!(
Verifier::multi_verify(
&airs,
&multi_proof,
&mut DefaultTranscript::<GoldilocksField>::new(&[]),
&FieldElement::zero(),
),
"verification should succeed when AIRs share (trace_length, blowup) but differ in coset_offset"
);
}

/// Test that the domain cache deduplicates when multiple AIRs share all three key fields
/// `(trace_length, blowup, coset_offset)`. Asserts exactly one `Domain`/`LdeTwiddles`
/// construction for N identical AIRs and that the resulting proof still verifies.
#[test_log::test]
fn test_multi_prove_dedups_shared_domain_params() {
domain_cache_stats::reset();

let proof_options = ProofOptions {
blowup_factor: 2,
fri_number_of_queries: 3,
coset_offset: 3,
grinding_factor: 1,
};

let mut trace_1 = simple_fibonacci::fibonacci_trace([Felt::from(1), Felt::from(1)], 8);
let mut trace_2 = simple_fibonacci::fibonacci_trace([Felt::from(1), Felt::from(1)], 8);
let mut trace_3 = simple_fibonacci::fibonacci_trace([Felt::from(1), Felt::from(1)], 8);

let pub_inputs = FibonacciPublicInputs {
a0: Felt::one(),
a1: Felt::one(),
};

let air_1 = FibonacciAIR::<GoldilocksField>::new(&proof_options);
let air_2 = FibonacciAIR::<GoldilocksField>::new(&proof_options);
let air_3 = FibonacciAIR::<GoldilocksField>::new(&proof_options);

let air_trace_pairs: Vec<(
&dyn AIR<
Field = GoldilocksField,
FieldExtension = GoldilocksField,
PublicInputs = FibonacciPublicInputs<GoldilocksField>,
>,
&mut _,
&_,
)> = vec![
(&air_1, &mut trace_1, &pub_inputs),
(&air_2, &mut trace_2, &pub_inputs),
(&air_3, &mut trace_3, &pub_inputs),
];

let multi_proof = Prover::multi_prove(
air_trace_pairs,
&mut DefaultTranscript::<GoldilocksField>::new(&[]),
)
.expect("proving should succeed");

let (hits, misses) = domain_cache_stats::get();
assert_eq!(
misses, 1,
"only one Domain/LdeTwiddles must be constructed for 3 AIRs sharing domain params"
);
assert_eq!(
hits, 2,
"remaining 2 AIRs must hit the cache instead of reconstructing"
);

let airs: Vec<
&dyn AIR<
Field = GoldilocksField,
FieldExtension = GoldilocksField,
PublicInputs = FibonacciPublicInputs<GoldilocksField>,
>,
> = vec![&air_1, &air_2, &air_3];

assert!(
Verifier::multi_verify(
&airs,
&multi_proof,
&mut DefaultTranscript::<GoldilocksField>::new(&[]),
&FieldElement::zero(),
),
"verification should succeed when AIRs share all domain parameters"
);
}
Loading