-
Notifications
You must be signed in to change notification settings - Fork 0
Misc optimizations #545
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Misc optimizations #545
Changes from all commits
aa95fd2
6a003a2
d43ae4e
901a716
bccd4d2
c1dd556
3de2694
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,7 +1,42 @@ | ||||||
| /// In-place bit-reverse permutation algorithm. Requires input length to be a power of two. | ||||||
| pub fn in_place_bit_reverse_permute<E>(input: &mut [E]) { | ||||||
| for i in 0..input.len() { | ||||||
| let bit_reversed_index = reverse_index(i, input.len() as u64); | ||||||
| pub fn in_place_bit_reverse_permute<E: Send>(input: &mut [E]) { | ||||||
| let n = input.len(); | ||||||
| #[cfg(feature = "parallel")] | ||||||
| { | ||||||
| // Pair-parallel swap: each pair (i, br(i)) with i < br(i) is independent of all | ||||||
| // other pairs (disjoint indices), so threads can swap concurrently provided they | ||||||
| // never touch the same memory location. `if br > i` selects exactly one owner | ||||||
| // per pair, so no two threads ever write the same slot. | ||||||
| const PARALLEL_BITREV_THRESHOLD: usize = 1 << 14; | ||||||
| if n >= PARALLEL_BITREV_THRESHOLD { | ||||||
| use rayon::prelude::*; | ||||||
| struct SendPtr<E>(*mut E); | ||||||
| impl<E> Copy for SendPtr<E> {} | ||||||
| impl<E> Clone for SendPtr<E> { | ||||||
| fn clone(&self) -> Self { | ||||||
| *self | ||||||
| } | ||||||
| } | ||||||
| unsafe impl<E> Send for SendPtr<E> {} | ||||||
| unsafe impl<E> Sync for SendPtr<E> {} | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Low —
Suggested change
|
||||||
| let ptr = SendPtr(input.as_mut_ptr()); | ||||||
| (0..n).into_par_iter().for_each(|i| { | ||||||
| let br = reverse_index(i, n as u64); | ||||||
| if br > i { | ||||||
| // SAFETY: (i, br) uniquely identifies this pair (smaller index is owner), | ||||||
| // so no two threads race on the same `ptr.0.add(k)` slot. Both indices | ||||||
| // are in-bounds since i < n and br < n. | ||||||
| let p = ptr; | ||||||
| unsafe { | ||||||
| core::ptr::swap(p.0.add(i), p.0.add(br)); | ||||||
| } | ||||||
| } | ||||||
| }); | ||||||
| return; | ||||||
| } | ||||||
| } | ||||||
| for i in 0..n { | ||||||
| let bit_reversed_index = reverse_index(i, n as u64); | ||||||
| if bit_reversed_index > i { | ||||||
| input.swap(i, bit_reversed_index); | ||||||
| } | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,6 +80,37 @@ impl<E: IsField> Polynomial<FieldElement<E>> { | |
| evaluate_fft_cpu::<F, E>(&coeffs) | ||
| } | ||
|
|
||
| /// Same as `evaluate_fft` but returns the evaluations in bit-reversed order, | ||
| /// skipping the final natural-order permutation. Use when the consumer expects | ||
| /// bit-reversed input (e.g. FRI commit phase, which pairs consecutive values as | ||
| /// {f(x), f(-x)}). | ||
| pub fn evaluate_fft_bit_reversed<F: IsFFTField + IsSubFieldOf<E>>( | ||
| poly: &Polynomial<FieldElement<E>>, | ||
| blowup_factor: usize, | ||
| domain_size: Option<usize>, | ||
| ) -> Result<Vec<FieldElement<E>>, FFTError> | ||
| where | ||
| E: Send + Sync, | ||
| { | ||
| let domain_size = domain_size.unwrap_or(0); | ||
| let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor; | ||
| if len.trailing_zeros() as u64 > F::TWO_ADICITY { | ||
| return Err(FFTError::DomainSizeError(len.trailing_zeros() as usize)); | ||
| } | ||
| if poly.coefficients().is_empty() { | ||
| return Ok(vec![FieldElement::zero(); len]); | ||
| } | ||
|
|
||
| let mut coeffs = poly.coefficients().to_vec(); | ||
| coeffs.resize(len, FieldElement::zero()); | ||
|
|
||
| let order = len.trailing_zeros() as u64; | ||
| let layer_twiddles = | ||
| LayerTwiddles::<F>::new(order).ok_or(FFTError::DomainSizeError(order as usize))?; | ||
| dispatch_fft(&mut coeffs, &layer_twiddles)?; | ||
| Ok(coeffs) | ||
| } | ||
|
Comment on lines
+87
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Low — No test for the new public API
#[test]
fn evaluate_fft_bit_reversed_matches_evaluate_fft_permuted() {
use crate::fft::cpu::bit_reversing::in_place_bit_reverse_permute;
let coeffs: Vec<FE> = (0u64..8).map(FE::from).collect();
let poly = Polynomial::new(&coeffs);
let mut expected = Polynomial::evaluate_fft::<F>(&poly, 2, None).unwrap();
in_place_bit_reverse_permute(&mut expected);
let got = Polynomial::evaluate_fft_bit_reversed::<F>(&poly, 2, None).unwrap();
assert_eq!(got, expected);
}
Comment on lines
+87
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Low — Duplicate setup code with
Consider a private helper: fn evaluate_fft_raw<F: IsFFTField + IsSubFieldOf<E>, E: IsField + Send + Sync>(
poly: &Polynomial<FieldElement<E>>,
blowup_factor: usize,
domain_size: Option<usize>,
) -> Result<Vec<FieldElement<E>>, FFTError> {
// shared setup + dispatch_fft, no permutation
}Then |
||
|
|
||
| /// Returns `N` evaluations with an offset of this polynomial using FFT over a domain in a subfield F of E | ||
| /// (so the results are P(w^i), with w being a primitive root of unity). | ||
| /// `N = max(self.coeff_len(), domain_size).next_power_of_two() * blowup_factor`. | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -51,7 +51,29 @@ impl<F: IsField> FieldElement<F> { | |||||
| /// Computes the multiplicative inverses of a slice of field elements | ||||||
| /// The algorithm just performs one inversion and several multiplications and should be used | ||||||
| /// when wanting to invert several elements together | ||||||
| pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> { | ||||||
| pub fn inplace_batch_inverse(numbers: &mut [Self]) -> Result<(), FieldError> | ||||||
| where | ||||||
| Self: Send + Sync, | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Medium — Unconditional API-breaking bound The Consider splitting:
Suggested change
…and adding the |
||||||
| { | ||||||
| #[cfg(feature = "parallel")] | ||||||
| { | ||||||
| // Montgomery batch inverse has a serial prefix-product dependency, but | ||||||
| // chunks are independent — each chunk inverts its own elements without | ||||||
| // needing values from other chunks. Trade K-1 extra field inversions | ||||||
| // (negligible vs ~2N mults per chunk) for K-way parallelism. | ||||||
| const PARALLEL_BATCH_INV_THRESHOLD: usize = 1 << 16; | ||||||
| if numbers.len() >= PARALLEL_BATCH_INV_THRESHOLD { | ||||||
| use rayon::prelude::*; | ||||||
| let chunk_size = numbers.len().div_ceil(rayon::current_num_threads().max(1)); | ||||||
| return numbers | ||||||
| .par_chunks_mut(chunk_size) | ||||||
| .try_for_each(Self::inplace_batch_inverse_sequential); | ||||||
| } | ||||||
| } | ||||||
| Self::inplace_batch_inverse_sequential(numbers) | ||||||
| } | ||||||
|
|
||||||
| fn inplace_batch_inverse_sequential(numbers: &mut [Self]) -> Result<(), FieldError> { | ||||||
| if numbers.is_empty() { | ||||||
| return Ok(()); | ||||||
| } | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Medium — Missing power-of-two guard before unsafe parallel swap
The SAFETY argument for
core::ptr::swaprelies onreverse_indexbeing a bijection on[0, n), which only holds whennis a power of two. The sequential path just produces wrong output if the contract is broken; the parallel path invokes undefined behaviour (a data race) because two threads could swap the same element concurrently.A
debug_assert!catches violations in debug builds at zero release cost: