Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion crypto/math/benches/goldilocks_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,47 @@ fn bench_inv(c: &mut Criterion) {
group.finish();
}

/// Benchmarks `NativeFE::inplace_batch_inverse` on batches of pseudo-random
/// elements at three sizes chosen around the parallel-dispatch threshold.
///
/// Each batch is generated from a fixed seed so runs are reproducible.
fn bench_inplace_batch_inverse(c: &mut Criterion) {
    let mut group = c.benchmark_group("goldilocks_inplace_batch_inverse");

    // 2^14: below the parallel threshold (sequential path).
    // 2^16: at the parallel threshold.
    // 2^20: well above the threshold (parallel path with full chunking).
    const BATCH_SIZES: [usize; 3] = [1 << 14, 1 << 16, 1 << 20];

    for size in BATCH_SIZES {
        // Re-seed per size so every batch is deterministic on its own.
        let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(9001);
        let data: Vec<NativeFE> = (0..size)
            // Replace a raw 0 draw with 1: a zero input would make the batch
            // inverse fail and the `unwrap` below panic.
            .map(|_| NativeFE::from(rng.next_u64().max(1)))
            .collect();

        group.bench_with_input(BenchmarkId::new("native", size), &data, |b, data| {
            b.iter_batched(
                // Setup (untimed): the inversion is in place, so each
                // iteration needs its own fresh copy of the input.
                || data.clone(),
                |mut buf| {
                    NativeFE::inplace_batch_inverse(black_box(&mut buf)).unwrap();
                    // Hand the buffer back to Criterion instead of dropping it
                    // here: routine outputs are dropped after the measurement
                    // ends, so freeing a multi-megabyte Vec does not count
                    // towards the reported time.
                    buf
                },
                criterion::BatchSize::LargeInput,
            )
        });
    }
    group.finish();
}

// Register every Goldilocks field benchmark in a single Criterion group and
// generate the benchmark binary's entry point from it.
criterion_group!(
    benches,
    bench_add,
    bench_sub,
    bench_mul,
    bench_square,
    bench_inv,
    bench_inplace_batch_inverse
);
criterion_main!(benches);
32 changes: 31 additions & 1 deletion crypto/math/src/field/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,38 @@ impl<F: IsField> FieldElement<F> {
// Source: https://en.wikipedia.org/wiki/Modular_multiplicative_inverse#Multiple_inverses
/// Computes the multiplicative inverses of a slice of field elements, in place.
///
/// The algorithm just performs one inversion and several multiplications and should be
/// used when wanting to invert several elements together.
///
/// With the `parallel` feature enabled, slices of at least `1 << 16` elements are split
/// into at most one chunk per rayon worker thread, and each chunk is inverted
/// independently with the sequential routine. Shorter slices (and builds without the
/// feature) go straight to the sequential routine.
///
/// # Errors
///
/// On `Err(InvZeroError)` the input slice is left unchanged (all-or-nothing).
/// The parallel path enforces this with a zero pre-scan; the sequential
/// path checks before any mutation.
pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> {
    #[cfg(feature = "parallel")]
    {
        // Montgomery batch inverse has a serial prefix-product dependency, but
        // chunks are independent — each chunk inverts its own elements without
        // needing values from other chunks. Trade K-1 extra field inversions
        // (negligible vs ~2N mults per chunk) for K-way parallelism.
        const PARALLEL_BATCH_INV_THRESHOLD: usize = 1 << 16;
        if numbers.len() >= PARALLEL_BATCH_INV_THRESHOLD {
            use rayon::prelude::*;
            // Pre-scan for zeros so the mutation step is all-or-nothing.
            // Without this, a chunk containing zero would return Err while
            // sibling chunks may have already overwritten their elements.
            let zero = Self::zero();
            if numbers.par_iter().any(|x| x == &zero) {
                return Err(FieldError::InvZeroError);
            }
            // Ceiling division: the slice is non-empty here (it is at least
            // the threshold long), so `chunk_size` is always >= 1.
            let chunk_size = numbers.len().div_ceil(rayon::current_num_threads().max(1));
            return numbers
                .par_chunks_mut(chunk_size)
                .try_for_each(Self::inplace_batch_inverse_sequential);
        }
    }
    Self::inplace_batch_inverse_sequential(numbers)
}

fn inplace_batch_inverse_sequential(numbers: &mut [Self]) -> Result<(), FieldError> {
if numbers.is_empty() {
return Ok(());
}
Expand Down
48 changes: 48 additions & 0 deletions crypto/math/src/tests/field_element_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,54 @@ mod tests {
}
}

#[cfg(all(feature = "alloc", feature = "parallel"))]
#[test]
fn test_inplace_batch_inverse_parallel_path() {
    use rayon::ThreadPoolBuilder;

    // The slice has to be longer than the parallel threshold (1 << 16),
    // otherwise the parallel branch is never taken. The extra 17 elements
    // keep the length from dividing evenly into chunks.
    let len = (1 << 16) + 17;
    let originals: Vec<Gfe> = (1..=len as u64).map(Gfe::from).collect();

    // Inverts a fresh copy of the input and checks every element against its
    // original: x * inv(x) must equal one.
    let check_roundtrip = |thread_count: usize| {
        let mut inverted = originals.clone();
        FieldElement::inplace_batch_inverse(&mut inverted).unwrap();
        for (idx, inv) in inverted.into_iter().enumerate() {
            assert_eq!(
                inv * originals[idx],
                Gfe::one(),
                "x * inv(x) != 1 with {} threads at index {}",
                thread_count,
                idx
            );
        }
    };

    // Run under several pool sizes so different chunkings are covered.
    for thread_count in [1, 2, 4, 8] {
        ThreadPoolBuilder::new()
            .num_threads(thread_count)
            .build()
            .unwrap()
            .install(|| check_roundtrip(thread_count));
    }
}

#[cfg(all(feature = "alloc", feature = "parallel"))]
#[test]
fn test_inplace_batch_inverse_parallel_zero_returns_err_without_mutation() {
    // All-or-nothing contract: with a single zero planted in an input large
    // enough for the parallel branch, the call must fail and must not touch
    // any element of the slice.
    let len = (1 << 16) + 1;
    let zero_at = len / 2;

    let mut values: Vec<Gfe> = (1..=len as u64).map(Gfe::from).collect();
    values[zero_at] = Gfe::zero();
    let before = values.clone();

    let outcome = FieldElement::inplace_batch_inverse(&mut values);

    assert!(outcome.is_err());
    assert_eq!(values, before, "input was partially mutated on Err");
}

// Tests for BigUint conversion using Goldilocks field.
#[test]
fn test_reduced_biguint_conversion_goldilocks() {
Expand Down
Loading