Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 6 additions & 103 deletions crypto/stark/src/constraints/boundary.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
use itertools::Itertools;
use math::{
field::{element::FieldElement, traits::IsField},
polynomial::Polynomial,
};
use math::field::{element::FieldElement, traits::IsField};

#[derive(Debug)]
/// Represents a boundary constraint that must hold in an execution
/// trace:
/// Represents a boundary constraint that must hold in an execution trace:
/// * col: The column of the trace where the constraint must hold
/// * step: The step (or row) of the trace where the constraint must hold
/// * value: The value the constraint must have in that column and step
#[derive(Debug)]
pub struct BoundaryConstraint<F: IsField> {
pub col: usize,
pub step: usize,
Expand All @@ -36,7 +31,7 @@ impl<F: IsField> BoundaryConstraint<F> {
}
}

/// Used for creating boundary constraints for a trace with only one column
/// Boundary constraint for a trace with a single main column.
pub fn new_simple_main(step: usize, value: FieldElement<F>) -> Self {
Self {
col: 0,
Expand All @@ -45,109 +40,17 @@ impl<F: IsField> BoundaryConstraint<F> {
is_aux: false,
}
}

/// Used for creating boundary constraints for a trace with only one column
pub fn new_simple_aux(step: usize, value: FieldElement<F>) -> Self {
Self {
col: 0,
step,
value,
is_aux: true,
}
}
}

/// Data structure that stores all the boundary constraints that must
/// hold for the execution trace
/// All the boundary constraints that must hold for an execution trace.
#[derive(Default, Debug)]
pub struct BoundaryConstraints<F: IsField> {
pub constraints: Vec<BoundaryConstraint<F>>,
}

impl<F: IsField> BoundaryConstraints<F> {
/// To instantiate from a vector of BoundaryConstraint elements
/// Instantiate from a vector of `BoundaryConstraint` elements.
pub fn from_constraints(constraints: Vec<BoundaryConstraint<F>>) -> Self {
Self { constraints }
}

/// Returns all the steps where boundary conditions exist for the given column
pub fn steps(&self, col: usize) -> Vec<usize> {
self.constraints
.iter()
.filter(|v| v.col == col)
.map(|c| c.step)
.collect()
}

pub fn steps_for_boundary(&self) -> Vec<usize> {
self.constraints
.iter()
.unique_by(|elem| elem.step)
.map(|v| v.step)
.collect()
}

pub fn cols_for_boundary(&self) -> Vec<usize> {
self.constraints
.iter()
.unique_by(|elem| elem.col)
.map(|v| v.col)
.collect()
}

/// Given the primitive root of some domain, returns the domain values corresponding
/// to the steps where the boundary conditions hold. This is useful when interpolating
/// the boundary conditions, since we must know the x values
pub fn generate_roots_of_unity(
&self,
primitive_root: &FieldElement<F>,
cols_trace: &[usize],
) -> Vec<Vec<FieldElement<F>>> {
cols_trace
.iter()
.map(|i| {
self.steps(*i)
.into_iter()
.map(|s| primitive_root.pow(s))
.collect::<Vec<FieldElement<F>>>()
})
.collect::<Vec<Vec<FieldElement<F>>>>()
}

/// For every trace column, give all the values the trace must be equal to in
/// the steps where the boundary constraints hold
pub fn values(&self, cols_trace: &[usize]) -> Vec<Vec<FieldElement<F>>> {
cols_trace
.iter()
.map(|i| {
self.constraints
.iter()
.filter(|c| c.col == *i)
.map(|c| c.value.clone())
.collect()
})
.collect()
}

/// Computes the zerofier of the boundary quotient. The result is the
/// multiplication of each binomial that evaluates to zero in the domain
/// values where the boundary constraints must hold.
///
/// Example: If there are boundary conditions in the third and fifth steps,
/// then the zerofier will be (x - w^3) * (x - w^5)
pub fn compute_zerofier(
&self,
primitive_root: &FieldElement<F>,
col: usize,
) -> Polynomial<FieldElement<F>> {
self.steps(col).into_iter().fold(
Polynomial::new_monomial(FieldElement::<F>::one(), 0),
|zerofier, step| {
let binomial =
Polynomial::new(&[-primitive_root.pow(step), FieldElement::<F>::one()]);
// TODO: Implement the MulAssign trait for Polynomials?
zerofier * binomial
},
)
}
}
171 changes: 71 additions & 100 deletions crypto/stark/src/constraints/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,69 @@ where
let num_aux_cols = lde_trace.num_aux_cols();
let num_offsets = offsets.len();

// Per-row evaluation, shared by the parallel and sequential paths below:
// fill the frame, evaluate transition constraints, accumulate with zerofiers.
let eval_row = |i: usize,
boundary: FieldElement<FieldExtension>,
transition_buf: &mut [FieldElement<FieldExtension>],
base_buf: &mut [FieldElement<Field>],
periodic_buf: &mut [FieldElement<Field>],
frame: &mut Frame<Field, FieldExtension>|
-> FieldElement<FieldExtension> {
frame.fill_from_lde(lde_trace, i, offsets);

for (j, col) in lde_periodic_columns.iter().enumerate() {
periodic_buf[j] = col[i].clone();
}

let ctx = TransitionEvaluationContext::new_prover(
frame,
periodic_buf,
rap_challenges,
&logup_alpha_powers,
logup_table_offset,
&packing_shifts,
);
air.compute_transition_prover(&ctx, base_buf, transition_buf);

let acc_transition = if is_uniform {
// All constraints share one zerofier: factor it out of the sum.
let z = zerofier_data.get_uniform(i);
// F×E inner product for base constraints (3 muls per term)
let mut sum = base_buf
.iter()
.zip(&transition_coefficients[..num_base])
.fold(FieldElement::zero(), |acc, (eval, beta)| acc + eval * beta);
// E×E for extension constraints (9 muls per term)
sum = transition_buf[num_base..]
.iter()
.zip(&transition_coefficients[num_base..])
.fold(sum, |acc, (eval, beta)| acc + eval * beta);
z * &sum
} else {
let mut sum = base_buf
.iter()
.enumerate()
.zip(&transition_coefficients[..num_base])
.fold(FieldElement::zero(), |acc, ((c_idx, eval), beta)| {
acc + zerofier_data.get(c_idx, i) * eval * beta
});
sum = transition_buf[num_base..]
.iter()
.enumerate()
.zip(&transition_coefficients[num_base..])
.fold(sum, |acc, ((j, eval), beta)| {
acc + zerofier_data.get(num_base + j, i) * eval * beta
});
sum
};

acc_transition + boundary
};

#[cfg(feature = "parallel")]
{
let evaluations_t: Vec<_> = boundary_evaluation
boundary_evaluation
.into_par_iter()
.enumerate()
.map_init(
Expand All @@ -94,59 +154,10 @@ where
)
},
|(transition_buf, base_buf, periodic_buf, frame), (i, boundary)| {
frame.fill_from_lde(lde_trace, i, offsets);

for (j, col) in lde_periodic_columns.iter().enumerate() {
periodic_buf[j] = col[i].clone();
}

let ctx = TransitionEvaluationContext::new_prover(
frame,
periodic_buf,
rap_challenges,
&logup_alpha_powers,
logup_table_offset,
&packing_shifts,
);
air.compute_transition_prover(&ctx, base_buf, transition_buf);

let acc_transition = if is_uniform {
// All constraints share one zerofier: factor it out of the sum.
let z = zerofier_data.get_uniform(i);
// F×E inner product for base constraints (3 muls per term)
let mut sum = base_buf
.iter()
.zip(&transition_coefficients[..num_base])
.fold(FieldElement::zero(), |acc, (eval, beta)| acc + eval * beta);
// E×E for extension constraints (9 muls per term)
sum = transition_buf[num_base..]
.iter()
.zip(&transition_coefficients[num_base..])
.fold(sum, |acc, (eval, beta)| acc + eval * beta);
z * &sum
} else {
let mut sum = base_buf
.iter()
.enumerate()
.zip(&transition_coefficients[..num_base])
.fold(FieldElement::zero(), |acc, ((c_idx, eval), beta)| {
acc + zerofier_data.get(c_idx, i) * eval * beta
});
sum = transition_buf[num_base..]
.iter()
.enumerate()
.zip(&transition_coefficients[num_base..])
.fold(sum, |acc, ((j, eval), beta)| {
acc + zerofier_data.get(num_base + j, i) * eval * beta
});
sum
};

acc_transition + boundary
eval_row(i, boundary, transition_buf, base_buf, periodic_buf, frame)
},
)
.collect();
evaluations_t
.collect()
}

#[cfg(not(feature = "parallel"))]
Expand All @@ -161,54 +172,14 @@ where
.into_iter()
.enumerate()
.map(|(i, boundary)| {
frame.fill_from_lde(lde_trace, i, offsets);

for (j, col) in lde_periodic_columns.iter().enumerate() {
periodic_buf[j] = col[i].clone();
}

let ctx = TransitionEvaluationContext::new_prover(
&frame,
&periodic_buf,
rap_challenges,
&logup_alpha_powers,
logup_table_offset,
&packing_shifts,
);
air.compute_transition_prover(&ctx, &mut base_buf, &mut transition_buf);

let acc_transition = if is_uniform {
let z = zerofier_data.get_uniform(i);
// F×E inner product for base constraints (3 muls per term)
let mut sum = base_buf
.iter()
.zip(&transition_coefficients[..num_base])
.fold(FieldElement::zero(), |acc, (eval, beta)| acc + eval * beta);
// E×E for extension constraints (9 muls per term)
sum = transition_buf[num_base..]
.iter()
.zip(&transition_coefficients[num_base..])
.fold(sum, |acc, (eval, beta)| acc + eval * beta);
z * &sum
} else {
let mut sum = base_buf
.iter()
.enumerate()
.zip(&transition_coefficients[..num_base])
.fold(FieldElement::zero(), |acc, ((c_idx, eval), beta)| {
acc + zerofier_data.get(c_idx, i) * eval * beta
});
sum = transition_buf[num_base..]
.iter()
.enumerate()
.zip(&transition_coefficients[num_base..])
.fold(sum, |acc, ((j, eval), beta)| {
acc + zerofier_data.get(num_base + j, i) * eval * beta
});
sum
};

acc_transition + boundary
eval_row(
i,
boundary,
&mut transition_buf,
&mut base_buf,
&mut periodic_buf,
&mut frame,
)
})
.collect()
}
Expand Down
12 changes: 7 additions & 5 deletions crypto/stark/src/constraints/transition.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::Div;
use core::ops::Div;

use crate::domain::Domain;
use crate::prover::evaluate_polynomial_on_lde_domain;
Expand Down Expand Up @@ -205,7 +205,7 @@ where
.cycle()
.take(end_exemption_evaluations.len());

std::iter::zip(cycled_evaluations, end_exemption_evaluations)
core::iter::zip(cycled_evaluations, end_exemption_evaluations)
.map(|(eval, exemption_eval)| eval * exemption_eval)
.collect()

Expand Down Expand Up @@ -246,7 +246,7 @@ where
.cycle()
.take(end_exemption_evaluations.len());

std::iter::zip(cycled_evaluations, end_exemption_evaluations)
core::iter::zip(cycled_evaluations, end_exemption_evaluations)
.map(|(eval, exemption_eval)| eval * exemption_eval)
.collect()
}
Expand Down Expand Up @@ -276,8 +276,10 @@ where
let denominator = -trace_primitive_root
.pow(self.offset() * trace_length / self.period())
+ z.pow(trace_length / self.period());
// The denominator isn't zero because z is sampled outside the set of primitive roots.
return unsafe { numerator.div(denominator).unwrap_unchecked() }
// The denominator is non-zero: z is sampled outside the set of primitive roots.
return numerator
.div(denominator)
.expect("zerofier denominator is non-zero: z is sampled out-of-domain")
* end_exemptions_poly.evaluate(z);
}

Expand Down
Loading
Loading