Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ntt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ license.workspace = true
homepage.workspace = true
repository.workspace = true

[features]
force_b51 = []

[dependencies]
ark-bn254.workspace = true
ark-ff.workspace = true
Expand Down
107 changes: 90 additions & 17 deletions ntt/src/b51_interleaved.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use {
crate::{define_ntt, extend_roots_table},
ark_bn254::Fr,
bn254_multiplier::{
constants::{self, U64_P_MULTIPLES},
rne,
utils::{self, addv, div_p_2b, subtraction_reduce},
constants::U64_P_MULTIPLES,
utils::{self, div_p_2b, subtraction_reduce},
},
rayon::iter::{IntoParallelRefMutIterator, ParallelIterator},
};
#[cfg(not(kani))]
use {
crate::{define_ntt, extend_roots_table},
ark_bn254::Fr,
bn254_multiplier::{constants, rne, utils::addv},
std::mem,
};

#[cfg(not(kani))]
define_ntt!(interleaved_ntt_nr, [u64; 4], b51_kernel);

#[cfg(not(kani))]
pub fn ntt_nr_b51(values: &mut [Fr], codeword_size: usize, num_groups: usize) {
let new_root = extend_roots_table(codeword_size);
// SAFETY: `Fr` is `#[repr(transparent)]` over `BigInt<4>`, which is
Expand All @@ -22,6 +27,7 @@ pub fn ntt_nr_b51(values: &mut [Fr], codeword_size: usize, num_groups: usize) {
canonicalize_b51(raw);
}

#[cfg(not(kani))]
#[inline(always)]
fn b51_kernel(even: &mut [u64; 4], odd: &mut [u64; 4], omega: &[u64; 4]) {
// rne multiplier will takes any value times <p to a value less than 3p.
Expand All @@ -38,27 +44,31 @@ fn b51_kernel(even: &mut [u64; 4], odd: &mut [u64; 4], omega: &[u64; 4]) {
(*even, *odd) = (l, r);
}

#[inline(always)]
fn canonicalize_b51_element(elem: [u64; 4]) -> [u64; 4] {
let reduced = subtraction_reduce(div_p_2b, elem);
let tentative = utils::sub(reduced, U64_P_MULTIPLES[1]);
if tentative[3] >> 63 == 1 {
reduced
} else {
tentative
}
}

/// Fit values within [0,p). Necessary to be compatible with Ark
fn canonicalize_b51(values: &mut [[u64; 4]]) {
// After the kernel the values are within [0,2.3p]. We use subtraction reduce to
// get the value within 2p such that we can use a conditional subtract.
values.par_iter_mut().for_each(|elem| {
let reduced = subtraction_reduce(div_p_2b, *elem);
let tentative = utils::sub(reduced, U64_P_MULTIPLES[1]);
*elem = if tentative[3] >> 63 == 1 {
reduced
} else {
tentative
};
});
values
.par_iter_mut()
.for_each(|elem| *elem = canonicalize_b51_element(*elem));
}

#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
use {
crate::{ark_interleaved::ntt_nr_ark, b51_interleaved::ntt_nr_b51},
ark_bn254::Fr,
ark_ff::BigInt,
ark_ff::{BigInt, PrimeField},
bn254_multiplier::constants::U64_P_MULTIPLES,
proptest::{collection, prelude::*},
};

Expand All @@ -83,5 +93,68 @@ mod tests {

prop_assert_eq!(b51_out, ark_out);
}

#[test]
fn b51_matches_ark_interleaved(
codeword_log2 in 3_usize..=12,
num_groups_log2 in 0_usize..=3,
) {
let codeword_size = 1 << codeword_log2;
let num_groups = 1 << num_groups_log2;
let total = codeword_size * num_groups;
let mut rng = ark_std::test_rng();
let values: Vec<Fr> = (0..total).map(|_| <Fr as ark_ff::UniformRand>::rand(&mut rng)).collect();

let mut b51_out = values.clone();
ntt_nr_b51(&mut b51_out, codeword_size, num_groups);
let mut ark_out = values;
ntt_nr_ark(&mut ark_out, codeword_size, num_groups);

prop_assert_eq!(b51_out, ark_out);
}

// Samples raw [u64;4] directly so inputs can cover the full [0, 3p)
// range the kernel invariant allows (Fr::rand only covers [0, p)).
#[test]
fn canonicalize_b51_is_canonical(
raw in proptest::array::uniform4(0u64..),
) {
use bn254_multiplier::utils;
let below_3p = utils::sub(raw, U64_P_MULTIPLES[3])[3] >> 63 == 1;
prop_assume!(below_3p);

let mut buf = vec![raw];
super::canonicalize_b51(&mut buf);

let bi = BigInt(buf[0]);
prop_assert!(Fr::from_bigint(bi).is_some(),
"canonicalize_b51 left value ≥ p: {:?}", buf[0]);
}
}
}

#[cfg(kani)]
mod kani_proofs {
use {
super::canonicalize_b51_element,
bn254_multiplier::{constants::U64_P_MULTIPLES, utils::sub},
};

fn le256(a: [u64; 4], b: [u64; 4]) -> bool {
for i in (0..4).rev() {
if a[i] != b[i] {
return a[i] < b[i];
}
}
true
}

#[kani::proof]
fn canonicalize_b51_produces_canonical() {
let elem: [u64; 4] = [kani::any(), kani::any(), kani::any(), kani::any()];
kani::assume(le256(elem, U64_P_MULTIPLES[3]));
let result = canonicalize_b51_element(elem);
let diff = sub(result, U64_P_MULTIPLES[1]);
assert!(diff[3] >> 63 == 1, "result must be < p");
}
}
4 changes: 3 additions & 1 deletion ntt/src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ static ENGINE: LazyLock<RwLock<NTTEngine>> = LazyLock::new(|| RwLock::new(NTTEng
/// * `values` - A mutable reference to an NTT container holding the
/// coefficients to be transformed.
pub fn ntt_nr(values: &mut [Fr], codeword_size: usize, num_groups: usize) {
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(not(target_arch = "wasm32"), not(feature = "force_b51")))]
ntt_nr_ark(values, codeword_size, num_groups);
#[cfg(all(not(target_arch = "wasm32"), feature = "force_b51"))]
crate::b51_interleaved::ntt_nr_b51(values, codeword_size, num_groups);
#[cfg(target_arch = "wasm32")]
ntt_nr_b51(values, codeword_size, num_groups);
}
Expand Down
82 changes: 82 additions & 0 deletions ntt/tests/cross_kernel_roundtrip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//! Cross-kernel correctness: simulate "prove with one kernel, verify with the
//! other" at the NTT level. The production dispatch in `ntt_nr()` is
//! target-gated (`cfg(target_arch = "wasm32")`), so on native we can only run
//! one kernel. These tests explicitly drive both `ntt_nr_ark` and
//! `ntt_nr_b51` on the same input and assert byte-identical output — the
//! property on which "prover on wasm, verifier on native" compatibility
//! ultimately rests.

#![cfg(not(target_arch = "wasm32"))]

use {
ark_bn254::Fr,
ark_ff::UniformRand,
ntt::{ark_interleaved::ntt_nr_ark, b51_interleaved::ntt_nr_b51},
};

fn make_values(codeword_log2: u32, num_groups_log2: u32) -> Vec<Fr> {
let total = 1usize << (codeword_log2 + num_groups_log2);
let mut rng = ark_std::test_rng();
(0..total).map(|_| Fr::rand(&mut rng)).collect()
}

#[test]
fn ark_and_b51_agree_across_interleaving_strides() {
for codeword_log2 in [6u32, 10, 14, 16] {
for num_groups_log2 in 0u32..=5 {
let codeword_size = 1usize << codeword_log2;
let num_groups = 1usize << num_groups_log2;
let values = make_values(codeword_log2, num_groups_log2);

let mut ark_out = values.clone();
ntt_nr_ark(&mut ark_out, codeword_size, num_groups);
let mut b51_out = values;
ntt_nr_b51(&mut b51_out, codeword_size, num_groups);

assert_eq!(
ark_out, b51_out,
"mismatch at codeword=2^{codeword_log2}, num_groups=2^{num_groups_log2}"
);
}
}
}

#[test]
fn b51_forward_ark_inverse_roundtrips() {
for codeword_log2 in [8u32, 12, 14] {
let codeword_size = 1usize << codeword_log2;
let values = make_values(codeword_log2, 0);

let mut working = values.clone();
ntt_nr_b51(&mut working, codeword_size, 1);

// intt_rn expects reverse-bit-ordered evaluations → normal-order
// coefficients. ntt_nr produces reverse-bit-ordered evaluations, so we
// can feed directly.
ntt::intt_rn(&mut working);

assert_eq!(
working, values,
"b51→ark roundtrip diverged at size 2^{codeword_log2}"
);
}
}


#[test]
fn ark_forward_matches_b51_forward_then_canonical() {
for codeword_log2 in [8u32, 12, 14] {
let codeword_size = 1usize << codeword_log2;
let values = make_values(codeword_log2, 0);

let mut ark_out = values.clone();
ntt_nr_ark(&mut ark_out, codeword_size, 1);

let mut b51_out = values;
ntt_nr_b51(&mut b51_out, codeword_size, 1);

// b51 already canonicalizes internally via canonicalize_b51, so raw
// limbs should match without post-processing.
assert_eq!(ark_out, b51_out);
}
}
1 change: 1 addition & 0 deletions provekit/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ repository.workspace = true
default = ["parallel"]
parallel = []
provekit_ntt = []
force_b51_ntt = ["ntt/force_b51"]

[dependencies]
# Workspace crates
Expand Down
1 change: 1 addition & 0 deletions provekit/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ repository.workspace = true
default = ["witness-generation", "parallel"]
witness-generation = ["nargo", "bn254_blackbox_solver", "noir_artifact_cli"]
parallel = ["provekit-common/parallel"]
force_b51_ntt = ["provekit-common/force_b51_ntt"]

[dependencies]
# Workspace crates
Expand Down
13 changes: 8 additions & 5 deletions skyscraper/bn254-multiplier/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
#![feature(portable_simd)]
#![cfg_attr(not(kani), feature(portable_simd))]
//#![no_std] This crate can technically be no_std. However this requires
// replacing StdFloat.mul_add with intrinsics.

#[cfg(target_arch = "aarch64")]
#[cfg(all(target_arch = "aarch64", not(kani)))]
mod aarch64;

// These can be made to work on x86,
// but for now it uses an ARM NEON intrinsic.
#[cfg(target_arch = "aarch64")]
#[cfg(all(target_arch = "aarch64", not(kani)))]
pub mod rtz;

pub mod constants;
#[cfg(not(kani))]
pub mod rne;
#[cfg(not(kani))]
mod scalar;
pub mod utils;

#[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI
#[cfg(all(not(target_arch = "wasm32"), not(kani)))]
mod test_utils;

#[cfg(target_arch = "aarch64")]
#[cfg(all(target_arch = "aarch64", not(kani)))]
pub use crate::aarch64::{
montgomery_interleaved_3, montgomery_interleaved_4, montgomery_square_interleaved_3,
montgomery_square_interleaved_4, montgomery_square_log_interleaved_3,
montgomery_square_log_interleaved_4,
};
#[cfg(not(kani))]
pub use crate::scalar::{scalar_mul, scalar_sqr};

const fn pow_2(n: u32) -> f64 {
Expand Down
32 changes: 32 additions & 0 deletions skyscraper/bn254-multiplier/src/rne/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,35 @@ pub mod mono;
pub mod simd_utils;

pub use {batched::*, constants::*, mono::*};

#[cfg(all(test, not(target_arch = "wasm32")))]
mod output_bound_tests {
use {
super::mono,
crate::constants::U64_P_MULTIPLES,
proptest::{prelude::*, test_runner::Config},
};

fn le256(a: [u64; 4], b: [u64; 4]) -> bool {
for i in (0..4).rev() {
if a[i] != b[i] {
return a[i] < b[i];
}
}
true
}

fn below_2_255() -> impl Strategy<Value = [u64; 4]> {
(0u64.., 0u64.., 0u64.., 0u64..(1u64 << 63)).prop_map(|(a, b, c, d)| [a, b, c, d])
}

proptest! {
#![proptest_config(Config { cases: 4096, .. Config::default() })]
#[test]
fn mono_mul_output_under_3p(a in below_2_255(), b in below_2_255()) {
let out = mono::mul(a, b);
prop_assert!(le256(out, U64_P_MULTIPLES[3]),
"mul({:?}, {:?}) = {:?} ≥ 3p", a, b, out);
}
}
}
24 changes: 24 additions & 0 deletions skyscraper/bn254-multiplier/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,28 @@ mod proofs {
assert!(x >= r);
assert!(le256([0, 0, 0, x - r], U64_2P));
}

/// `div_p_2b` output bound is `< 3p` weaker than `div_p_6b`/`div_p_32b`'s`< 2p`.
#[kani::proof]
fn div_p_2b_underapprox() {
use super::div_p_2b;
let x: u64 = kani::any();
let q = div_p_2b(x);
let r = U64_P_MULTIPLES[q as usize][3];
assert!(x >= r);
assert!(le256([0, 0, 0, x - r], U64_P_MULTIPLES[3]));
}

/// Input `≤ 3p` tightens `subtraction_reduce(div_p_2b, ·)` output to `< 2p`.
/// Used by `canonicalize_b51`'s conditional-subtract step.
#[kani::proof]
fn subtraction_reduce_div_p_2b_tight_under_2_3p() {
use super::{div_p_2b, subtraction_reduce};
use crate::constants::U64_P_MULTIPLES;

let elem: [u64; 4] = [kani::any(), kani::any(), kani::any(), kani::any()];
kani::assume(le256(elem, U64_P_MULTIPLES[3]));
let reduced = subtraction_reduce(div_p_2b, elem);
assert!(le256(reduced, U64_P_MULTIPLES[2]));
}
}
1 change: 1 addition & 0 deletions tooling/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ default = ["profiling-allocator"]
profiling-allocator = []
jemalloc = ["profiling-allocator", "dep:tikv-jemallocator"]
tracy = ["dep:tracing-tracy"]
force_b51_ntt = ["provekit-prover/force_b51_ntt"]
Loading