diff --git a/ntt/Cargo.toml b/ntt/Cargo.toml index 79ad4d64a..1411eab31 100644 --- a/ntt/Cargo.toml +++ b/ntt/Cargo.toml @@ -8,6 +8,9 @@ license.workspace = true homepage.workspace = true repository.workspace = true +[features] +force_b51 = [] + [dependencies] ark-bn254.workspace = true ark-ff.workspace = true diff --git a/ntt/src/b51_interleaved.rs b/ntt/src/b51_interleaved.rs index d6fcef158..876ac649d 100644 --- a/ntt/src/b51_interleaved.rs +++ b/ntt/src/b51_interleaved.rs @@ -1,17 +1,22 @@ use { - crate::{define_ntt, extend_roots_table}, - ark_bn254::Fr, bn254_multiplier::{ - constants::{self, U64_P_MULTIPLES}, - rne, - utils::{self, addv, div_p_2b, subtraction_reduce}, + constants::U64_P_MULTIPLES, + utils::{self, div_p_2b, subtraction_reduce}, }, rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}, +}; +#[cfg(not(kani))] +use { + crate::{define_ntt, extend_roots_table}, + ark_bn254::Fr, + bn254_multiplier::{constants, rne, utils::addv}, std::mem, }; +#[cfg(not(kani))] define_ntt!(interleaved_ntt_nr, [u64; 4], b51_kernel); +#[cfg(not(kani))] pub fn ntt_nr_b51(values: &mut [Fr], codeword_size: usize, num_groups: usize) { let new_root = extend_roots_table(codeword_size); // SAFETY: `Fr` is `#[repr(transparent)]` over `BigInt<4>`, which is @@ -22,6 +27,7 @@ pub fn ntt_nr_b51(values: &mut [Fr], codeword_size: usize, num_groups: usize) { canonicalize_b51(raw); } +#[cfg(not(kani))] #[inline(always)] fn b51_kernel(even: &mut [u64; 4], odd: &mut [u64; 4], omega: &[u64; 4]) { // rne multiplier will takes any value times

[u64; 4] { + let reduced = subtraction_reduce(div_p_2b, elem); + let tentative = utils::sub(reduced, U64_P_MULTIPLES[1]); + if tentative[3] >> 63 == 1 { + reduced + } else { + tentative + } +} + /// Fit values within [0,p). Necessary to be compatible with Ark fn canonicalize_b51(values: &mut [[u64; 4]]) { - // After the kernel the values are within [0,2.3p]. We use subtraction reduce to - // get the value within 2p such that we can use a conditional subtract. - values.par_iter_mut().for_each(|elem| { - let reduced = subtraction_reduce(div_p_2b, *elem); - let tentative = utils::sub(reduced, U64_P_MULTIPLES[1]); - *elem = if tentative[3] >> 63 == 1 { - reduced - } else { - tentative - }; - }); + values + .par_iter_mut() + .for_each(|elem| *elem = canonicalize_b51_element(*elem)); } #[cfg(all(test, not(target_arch = "wasm32")))] @@ -58,7 +67,8 @@ mod tests { use { crate::{ark_interleaved::ntt_nr_ark, b51_interleaved::ntt_nr_b51}, ark_bn254::Fr, - ark_ff::BigInt, + ark_ff::{BigInt, PrimeField}, + bn254_multiplier::constants::U64_P_MULTIPLES, proptest::{collection, prelude::*}, }; @@ -83,5 +93,68 @@ mod tests { prop_assert_eq!(b51_out, ark_out); } + + #[test] + fn b51_matches_ark_interleaved( + codeword_log2 in 3_usize..=12, + num_groups_log2 in 0_usize..=3, + ) { + let codeword_size = 1 << codeword_log2; + let num_groups = 1 << num_groups_log2; + let total = codeword_size * num_groups; + let mut rng = ark_std::test_rng(); + let values: Vec = (0..total).map(|_| ::rand(&mut rng)).collect(); + + let mut b51_out = values.clone(); + ntt_nr_b51(&mut b51_out, codeword_size, num_groups); + let mut ark_out = values; + ntt_nr_ark(&mut ark_out, codeword_size, num_groups); + + prop_assert_eq!(b51_out, ark_out); + } + + // Samples raw [u64;4] directly so inputs can cover the full [0, 3p) + // range the kernel invariant allows (Fr::rand only covers [0, p)). + #[test] + fn canonicalize_b51_is_canonical( + raw in proptest::array::uniform4(0u64..), + ) { + use bn254_multiplier::utils; + let below_3p = utils::sub(raw, U64_P_MULTIPLES[3])[3] >> 63 == 1; + prop_assume!(below_3p); + + let mut buf = vec![raw]; + super::canonicalize_b51(&mut buf); + + let bi = BigInt(buf[0]); + prop_assert!(Fr::from_bigint(bi).is_some(), + "canonicalize_b51 left value ≥ p: {:?}", buf[0]); + } + } +} + +#[cfg(kani)] +mod kani_proofs { + use { + super::canonicalize_b51_element, + bn254_multiplier::{constants::U64_P_MULTIPLES, utils::sub}, + }; + + fn le256(a: [u64; 4], b: [u64; 4]) -> bool { + for i in (0..4).rev() { + if a[i] != b[i] { + return a[i] < b[i]; + } + } + true + } + + #[kani::proof] + fn canonicalize_b51_produces_canonical() { + let elem: [u64; 4] = [kani::any(), kani::any(), kani::any(), kani::any()]; + kani::assume(le256(elem, U64_P_MULTIPLES[3])); + let result = canonicalize_b51_element(elem); + let diff = sub(result, U64_P_MULTIPLES[1]); + assert!(diff[3] >> 63 == 1, "result must be < p"); } } diff --git a/ntt/src/ntt.rs b/ntt/src/ntt.rs index 5ede429cd..ad203f7aa 100644 --- a/ntt/src/ntt.rs +++ b/ntt/src/ntt.rs @@ -111,8 +111,10 @@ static ENGINE: LazyLock> = LazyLock::new(|| RwLock::new(NTTEng /// * `values` - A mutable reference to an NTT container holding the /// coefficients to be transformed. pub fn ntt_nr(values: &mut [Fr], codeword_size: usize, num_groups: usize) { - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(not(target_arch = "wasm32"), not(feature = "force_b51")))] ntt_nr_ark(values, codeword_size, num_groups); + #[cfg(all(not(target_arch = "wasm32"), feature = "force_b51"))] + crate::b51_interleaved::ntt_nr_b51(values, codeword_size, num_groups); #[cfg(target_arch = "wasm32")] ntt_nr_b51(values, codeword_size, num_groups); } diff --git a/ntt/tests/cross_kernel_roundtrip.rs b/ntt/tests/cross_kernel_roundtrip.rs new file mode 100644 index 000000000..ad88b5f33 --- /dev/null +++ b/ntt/tests/cross_kernel_roundtrip.rs @@ -0,0 +1,82 @@ +//! Cross-kernel correctness: simulate "prove with one kernel, verify with the +//! other" at the NTT level. The production dispatch in `ntt_nr()` is +//! target-gated (`cfg(target_arch = "wasm32")`), so on native we can only run +//! one kernel. These tests explicitly drive both `ntt_nr_ark` and +//! `ntt_nr_b51` on the same input and assert byte-identical output — the +//! property on which "prover on wasm, verifier on native" compatibility +//! ultimately rests. + +#![cfg(not(target_arch = "wasm32"))] + +use { + ark_bn254::Fr, + ark_ff::UniformRand, + ntt::{ark_interleaved::ntt_nr_ark, b51_interleaved::ntt_nr_b51}, +}; + +fn make_values(codeword_log2: u32, num_groups_log2: u32) -> Vec { + let total = 1usize << (codeword_log2 + num_groups_log2); + let mut rng = ark_std::test_rng(); + (0..total).map(|_| Fr::rand(&mut rng)).collect() +} + +#[test] +fn ark_and_b51_agree_across_interleaving_strides() { + for codeword_log2 in [6u32, 10, 14, 16] { + for num_groups_log2 in 0u32..=5 { + let codeword_size = 1usize << codeword_log2; + let num_groups = 1usize << num_groups_log2; + let values = make_values(codeword_log2, num_groups_log2); + + let mut ark_out = values.clone(); + ntt_nr_ark(&mut ark_out, codeword_size, num_groups); + let mut b51_out = values; + ntt_nr_b51(&mut b51_out, codeword_size, num_groups); + + assert_eq!( + ark_out, b51_out, + "mismatch at codeword=2^{codeword_log2}, num_groups=2^{num_groups_log2}" + ); + } + } +} + +#[test] +fn b51_forward_ark_inverse_roundtrips() { + for codeword_log2 in [8u32, 12, 14] { + let codeword_size = 1usize << codeword_log2; + let values = make_values(codeword_log2, 0); + + let mut working = values.clone(); + ntt_nr_b51(&mut working, codeword_size, 1); + + // intt_rn expects reverse-bit-ordered evaluations → normal-order + // coefficients. ntt_nr produces reverse-bit-ordered evaluations, so we + // can feed directly. + ntt::intt_rn(&mut working); + + assert_eq!( + working, values, + "b51→ark roundtrip diverged at size 2^{codeword_log2}" + ); + } +} + + +#[test] +fn ark_forward_matches_b51_forward_then_canonical() { + for codeword_log2 in [8u32, 12, 14] { + let codeword_size = 1usize << codeword_log2; + let values = make_values(codeword_log2, 0); + + let mut ark_out = values.clone(); + ntt_nr_ark(&mut ark_out, codeword_size, 1); + + let mut b51_out = values; + ntt_nr_b51(&mut b51_out, codeword_size, 1); + + // b51 already canonicalizes internally via canonicalize_b51, so raw + // limbs should match without post-processing. + assert_eq!(ark_out, b51_out); + } +} diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index d39bc0535..ad1fca981 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -12,6 +12,7 @@ repository.workspace = true default = ["parallel"] parallel = [] provekit_ntt = [] +force_b51_ntt = ["ntt/force_b51"] [dependencies] # Workspace crates diff --git a/provekit/prover/Cargo.toml b/provekit/prover/Cargo.toml index 82f848326..8abf8b0ee 100644 --- a/provekit/prover/Cargo.toml +++ b/provekit/prover/Cargo.toml @@ -12,6 +12,7 @@ repository.workspace = true default = ["witness-generation", "parallel"] witness-generation = ["nargo", "bn254_blackbox_solver", "noir_artifact_cli"] parallel = ["provekit-common/parallel"] +force_b51_ntt = ["provekit-common/force_b51_ntt"] [dependencies] # Workspace crates diff --git a/skyscraper/bn254-multiplier/src/lib.rs b/skyscraper/bn254-multiplier/src/lib.rs index 3b4d99bfa..e7a2ce410 100644 --- a/skyscraper/bn254-multiplier/src/lib.rs +++ b/skyscraper/bn254-multiplier/src/lib.rs @@ -1,29 +1,32 @@ -#![feature(portable_simd)] +#![cfg_attr(not(kani), feature(portable_simd))] //#![no_std] This crate can technically be no_std. However this requires // replacing StdFloat.mul_add with intrinsics. -#[cfg(target_arch = "aarch64")] +#[cfg(all(target_arch = "aarch64", not(kani)))] mod aarch64; // These can be made to work on x86, // but for now it uses an ARM NEON intrinsic. -#[cfg(target_arch = "aarch64")] +#[cfg(all(target_arch = "aarch64", not(kani)))] pub mod rtz; pub mod constants; +#[cfg(not(kani))] pub mod rne; +#[cfg(not(kani))] mod scalar; pub mod utils; -#[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI +#[cfg(all(not(target_arch = "wasm32"), not(kani)))] mod test_utils; -#[cfg(target_arch = "aarch64")] +#[cfg(all(target_arch = "aarch64", not(kani)))] pub use crate::aarch64::{ montgomery_interleaved_3, montgomery_interleaved_4, montgomery_square_interleaved_3, montgomery_square_interleaved_4, montgomery_square_log_interleaved_3, montgomery_square_log_interleaved_4, }; +#[cfg(not(kani))] pub use crate::scalar::{scalar_mul, scalar_sqr}; const fn pow_2(n: u32) -> f64 { diff --git a/skyscraper/bn254-multiplier/src/rne/mod.rs b/skyscraper/bn254-multiplier/src/rne/mod.rs index a22af985a..d6507f2fa 100644 --- a/skyscraper/bn254-multiplier/src/rne/mod.rs +++ b/skyscraper/bn254-multiplier/src/rne/mod.rs @@ -28,3 +28,35 @@ pub mod mono; pub mod simd_utils; pub use {batched::*, constants::*, mono::*}; + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod output_bound_tests { + use { + super::mono, + crate::constants::U64_P_MULTIPLES, + proptest::{prelude::*, test_runner::Config}, + }; + + fn le256(a: [u64; 4], b: [u64; 4]) -> bool { + for i in (0..4).rev() { + if a[i] != b[i] { + return a[i] < b[i]; + } + } + true + } + + fn below_2_255() -> impl Strategy { + (0u64.., 0u64.., 0u64.., 0u64..(1u64 << 63)).prop_map(|(a, b, c, d)| [a, b, c, d]) + } + + proptest! { + #![proptest_config(Config { cases: 4096, .. Config::default() })] + #[test] + fn mono_mul_output_under_3p(a in below_2_255(), b in below_2_255()) { + let out = mono::mul(a, b); + prop_assert!(le256(out, U64_P_MULTIPLES[3]), + "mul({:?}, {:?}) = {:?} ≥ 3p", a, b, out); + } + } +} diff --git a/skyscraper/bn254-multiplier/src/utils.rs b/skyscraper/bn254-multiplier/src/utils.rs index 693cfcc03..3897a5ad1 100644 --- a/skyscraper/bn254-multiplier/src/utils.rs +++ b/skyscraper/bn254-multiplier/src/utils.rs @@ -234,4 +234,28 @@ mod proofs { assert!(x >= r); assert!(le256([0, 0, 0, x - r], U64_2P)); } + + /// `div_p_2b` output bound is `< 3p` weaker than `div_p_6b`/`div_p_32b`'s`< 2p`. + #[kani::proof] + fn div_p_2b_underapprox() { + use super::div_p_2b; + let x: u64 = kani::any(); + let q = div_p_2b(x); + let r = U64_P_MULTIPLES[q as usize][3]; + assert!(x >= r); + assert!(le256([0, 0, 0, x - r], U64_P_MULTIPLES[3])); + } + + /// Input `≤ 3p` tightens `subtraction_reduce(div_p_2b, ·)` output to `< 2p`. + /// Used by `canonicalize_b51`'s conditional-subtract step. + #[kani::proof] + fn subtraction_reduce_div_p_2b_tight_under_2_3p() { + use super::{div_p_2b, subtraction_reduce}; + use crate::constants::U64_P_MULTIPLES; + + let elem: [u64; 4] = [kani::any(), kani::any(), kani::any(), kani::any()]; + kani::assume(le256(elem, U64_P_MULTIPLES[3])); + let reduced = subtraction_reduce(div_p_2b, elem); + assert!(le256(reduced, U64_P_MULTIPLES[2])); + } } diff --git a/tooling/cli/Cargo.toml b/tooling/cli/Cargo.toml index 59369b84d..4a864d3c0 100644 --- a/tooling/cli/Cargo.toml +++ b/tooling/cli/Cargo.toml @@ -48,3 +48,4 @@ default = ["profiling-allocator"] profiling-allocator = [] jemalloc = ["profiling-allocator", "dep:tikv-jemallocator"] tracy = ["dep:tracing-tracy"] +force_b51_ntt = ["provekit-prover/force_b51_ntt"]