From 2896c6ba075bf682cd10b6d17dc7f784bbf25777 Mon Sep 17 00:00:00 2001 From: jotabulacios Date: Fri, 24 Apr 2026 16:46:10 -0300 Subject: [PATCH 1/2] Parallelize inplace_batch_inverse --- crypto/math/src/field/element.rs | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/crypto/math/src/field/element.rs b/crypto/math/src/field/element.rs index 9c2ac3258..e34ec0fb7 100644 --- a/crypto/math/src/field/element.rs +++ b/crypto/math/src/field/element.rs @@ -51,7 +51,29 @@ impl FieldElement { /// Computes the multiplicative inverses of a slice of field elements /// The algorithm just performs one inversion and several multiplications and should be used /// when wanting to invert several elements together - pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> { + pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> + where + Self: Send + Sync, + { + #[cfg(feature = "parallel")] + { + // Montgomery batch inverse has a serial prefix-product dependency, but + // chunks are independent — each chunk inverts its own elements without + // needing values from other chunks. Trade K-1 extra field inversions + // (negligible vs ~2N mults per chunk) for K-way parallelism. 
+ const PARALLEL_BATCH_INV_THRESHOLD: usize = 1 << 16; + if numbers.len() >= PARALLEL_BATCH_INV_THRESHOLD { + use rayon::prelude::*; + let chunk_size = numbers.len().div_ceil(rayon::current_num_threads().max(1)); + return numbers + .par_chunks_mut(chunk_size) + .try_for_each(Self::inplace_batch_inverse_sequential); + } + } + Self::inplace_batch_inverse_sequential(numbers) + } + + fn inplace_batch_inverse_sequential(numbers: &mut [Self]) -> Result<(), FieldError> { if numbers.is_empty() { return Ok(()); } From 396dbdeb7b1d13de10f0290638404c2047fb8f94 Mon Sep 17 00:00:00 2001 From: jotabulacios Date: Tue, 28 Apr 2026 09:23:26 -0300 Subject: [PATCH 2/2] address comments --- crypto/math/benches/goldilocks_benchmark.rs | 37 ++++++++++++++- crypto/math/src/field/element.rs | 18 ++++++-- crypto/math/src/tests/field_element_tests.rs | 48 ++++++++++++++++++++ 3 files changed, 97 insertions(+), 6 deletions(-) diff --git a/crypto/math/benches/goldilocks_benchmark.rs b/crypto/math/benches/goldilocks_benchmark.rs index a7518205b..7ef2e240d 100644 --- a/crypto/math/benches/goldilocks_benchmark.rs +++ b/crypto/math/benches/goldilocks_benchmark.rs @@ -114,12 +114,47 @@ fn bench_inv(c: &mut Criterion) { group.finish(); } +fn bench_inplace_batch_inverse(c: &mut Criterion) { + let mut group = c.benchmark_group("goldilocks_inplace_batch_inverse"); + + // 2^14: below the parallel threshold (sequential path). + // 2^16: at the parallel threshold. + // 2^20: well above the threshold (parallel path with full chunking).
+ const BATCH_SIZES: [usize; 3] = [1 << 14, 1 << 16, 1 << 20]; + + for size in BATCH_SIZES { + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(9001); + let data: Vec<NativeFE> = (0..size) + .map(|_| { + let mut v = rng.next_u64(); + if v == 0 { + v = 1; + } + NativeFE::from(v) + }) + .collect(); + + group.bench_with_input(BenchmarkId::new("native", size), &data, |b, data| { + b.iter_batched( + || data.clone(), + |mut buf| { + NativeFE::inplace_batch_inverse(black_box(&mut buf)).unwrap(); + black_box(buf); + }, + criterion::BatchSize::LargeInput, + ) + }); + } + group.finish(); +} + criterion_group!( benches, bench_add, bench_sub, bench_mul, bench_square, - bench_inv + bench_inv, + bench_inplace_batch_inverse ); criterion_main!(benches); diff --git a/crypto/math/src/field/element.rs b/crypto/math/src/field/element.rs index e34ec0fb7..a94886763 100644 --- a/crypto/math/src/field/element.rs +++ b/crypto/math/src/field/element.rs @@ -50,11 +50,12 @@ impl FieldElement { // Source: https://en.wikipedia.org/wiki/Modular_multiplicative_inverse#Multiple_inverses /// Computes the multiplicative inverses of a slice of field elements /// The algorithm just performs one inversion and several multiplications and should be used - /// when wanting to invert several elements together - pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> - where - Self: Send + Sync, - { + /// when wanting to invert several elements together. + /// + /// On `Err(InvZeroError)` the input slice is left unchanged (all-or-nothing). + /// The parallel path enforces this with a zero pre-scan; the sequential + /// path checks before any mutation.
+ pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> { #[cfg(feature = "parallel")] { // Montgomery batch inverse has a serial prefix-product dependency, but @@ -64,6 +65,13 @@ impl FieldElement { const PARALLEL_BATCH_INV_THRESHOLD: usize = 1 << 16; if numbers.len() >= PARALLEL_BATCH_INV_THRESHOLD { use rayon::prelude::*; + // Pre-scan for zeros so the mutation step is all-or-nothing. + // Without this, a chunk containing zero would return Err while + // sibling chunks may have already overwritten their elements. + let zero = Self::zero(); + if numbers.par_iter().any(|x| x == &zero) { + return Err(FieldError::InvZeroError); + } let chunk_size = numbers.len().div_ceil(rayon::current_num_threads().max(1)); return numbers .par_chunks_mut(chunk_size) diff --git a/crypto/math/src/tests/field_element_tests.rs b/crypto/math/src/tests/field_element_tests.rs index 5054e9fff..cbe7c26df 100644 --- a/crypto/math/src/tests/field_element_tests.rs +++ b/crypto/math/src/tests/field_element_tests.rs @@ -111,6 +111,54 @@ mod tests { } } + #[cfg(all(feature = "alloc", feature = "parallel"))] + #[test] + fn test_inplace_batch_inverse_parallel_path() { + // Slice size must exceed the parallel threshold (1 << 16) so the + // parallel branch is actually exercised. Test under several thread + // counts to catch chunking bugs. 
+ use rayon::ThreadPoolBuilder; + + let n = (1 << 16) + 17; + let input: Vec<Gfe> = (1..=n as u64).map(Gfe::from).collect(); + + for num_threads in [1, 2, 4, 8] { + let pool = ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + .unwrap(); + + pool.install(|| { + let mut inverses = input.clone(); + FieldElement::inplace_batch_inverse(&mut inverses).unwrap(); + for (i, x) in inverses.into_iter().enumerate() { + assert_eq!( + x * input[i], + Gfe::one(), + "x * inv(x) != 1 with {} threads at index {}", + num_threads, + i + ); + } + }); + } + } + + #[cfg(all(feature = "alloc", feature = "parallel"))] + #[test] + fn test_inplace_batch_inverse_parallel_zero_returns_err_without_mutation() { + // A zero in the slice must produce InvZeroError and leave the input + // unchanged (all-or-nothing semantics, matching the sequential path). + let n = (1 << 16) + 1; + let mut input: Vec<Gfe> = (1..=n as u64).map(Gfe::from).collect(); + input[n / 2] = Gfe::zero(); + let snapshot = input.clone(); + + let result = FieldElement::inplace_batch_inverse(&mut input); + assert!(result.is_err()); + assert_eq!(input, snapshot, "input was partially mutated on Err"); + } + // Tests for BigUint conversion using Goldilocks field. #[test] fn test_reduced_biguint_conversion_goldilocks() {