From 8403ccded92f5fcbecbc5cde64d3d0720b83432e Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Thu, 29 Jan 2026 14:18:27 -0700 Subject: [PATCH] ml-kem: use `NttVector` from `module-lattice` Continues replacing the types in `algebra.rs` with generic versions from the `module-lattice` crate. --- ml-kem/src/algebra.rs | 124 +++++++++------------------------- module-lattice/src/algebra.rs | 18 +++-- 2 files changed, 46 insertions(+), 96 deletions(-) diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 79b568d..b19a5bd 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -1,20 +1,16 @@ use array::{Array, typenum::U256}; -use core::ops::{Add, Mul}; +use core::ops::Mul; use module_lattice::{ algebra::{Field, MultiplyNtt}, util::Truncate, }; use sha3::digest::XofReader; -use subtle::{Choice, ConstantTimeEq}; use crate::B32; use crate::crypto::{PRF, PrfOutput, XOF}; use crate::encode::Encode; use crate::param::{ArraySize, CbdSamplingSize}; -#[cfg(feature = "zeroize")] -use zeroize::Zeroize; - module_lattice::define_field!(BaseField, u16, u32, u64, 3329); pub type Int = ::Int; @@ -31,6 +27,9 @@ pub type Vector = module_lattice::algebra::Vector; /// An element of the ring `T_q` i.e. a tuple of 128 elements of the direct sum components of `T_q`. pub type NttPolynomial = module_lattice::algebra::NttPolynomial; +/// A vector of K NTT-domain polynomials. +pub type NttVector = module_lattice::algebra::NttVector; + /// Algorithm 7: `SampleNTT(B)` pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { struct FieldElementReader<'a> { @@ -154,7 +153,7 @@ impl Ntt for Vector { type Output = NttVector; fn ntt(&self) -> NttVector { - NttVector(self.0.iter().map(Ntt::ntt).collect()) + NttVector::new(self.0.iter().map(Ntt::ntt).collect()) } } @@ -191,6 +190,14 @@ impl NttInverse for NttPolynomial { } } +impl NttInverse for NttVector { + type Output = Vector; + + fn ntt_inverse(&self) -> Vector { + Vector::new(self.0.iter().map(NttInverse::ntt_inverse).collect()) + } +} + /// Algorithm 11: `MultiplyNTTs` impl MultiplyNtt for BaseField { fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial { @@ -294,78 +301,6 @@ const GAMMA: [Elem; 128] = { gamma }; -/// A vector of K NTT-domain polynomials -#[derive(Clone, Default, Debug)] -pub struct NttVector(pub Array); - -impl NttVector { - pub fn sample_uniform(rho: &B32, i: usize, transpose: bool) -> Self { - Self(Array::from_fn(|j| { - let (i, j) = if transpose { (j, i) } else { (i, j) }; - let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i)); - sample_ntt(&mut xof) - })) - } -} - -impl ConstantTimeEq for NttVector { - fn ct_eq(&self, other: &Self) -> Choice { - self.0.ct_eq(&other.0) - } -} - -impl Eq for NttVector {} -impl PartialEq for NttVector { - fn eq(&self, other: &Self) -> bool { - // Impl `PartialEq` in constant-time, in case this value contains a secret - self.0.ct_eq(&other.0).into() - } -} - -#[cfg(feature = "zeroize")] -impl Zeroize for NttVector -where - K: ArraySize, -{ - fn zeroize(&mut self) { - for poly in &mut self.0 { - poly.zeroize(); - } - } -} - -impl Add<&NttVector> for &NttVector { - type Output = NttVector; - - fn add(self, rhs: &NttVector) -> NttVector { - NttVector( - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(x, y)| x + y) - .collect(), - ) - } -} - -impl Mul<&NttVector> for &NttVector { - type Output = NttPolynomial; - - fn mul(self, rhs: &NttVector) -> NttPolynomial { - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(x, y)| x * y) - .fold(NttPolynomial::default(), |x, y| &x + &y) - } -} - -impl NttVector { - pub fn ntt_inverse(&self) -> Vector { - Vector::new(self.0.iter().map(NttInverse::ntt_inverse).collect()) - } -} - /// A K x K matrix of NTT-domain polynomials. Each vector represents a row of the matrix, so that /// multiplying on the right just requires iteration. #[derive(Clone, Default, Debug, PartialEq)] @@ -375,20 +310,24 @@ impl Mul<&NttVector> for &NttMatrix { type Output = NttVector; fn mul(self, rhs: &NttVector) -> NttVector { - NttVector(self.0.iter().map(|x| x * rhs).collect()) + NttVector::new(self.0.iter().map(|x| x * rhs).collect()) } } impl NttMatrix { pub fn sample_uniform(rho: &B32, transpose: bool) -> Self { Self(Array::from_fn(|i| { - NttVector::sample_uniform(rho, i, transpose) + NttVector::new(Array::from_fn(|j| { + let (i, j) = if transpose { (j, i) } else { (i, j) }; + let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i)); + sample_ntt(&mut xof) + })) })) } pub fn transpose(&self) -> Self { Self(Array::from_fn(|i| { - NttVector(Array::from_fn(|j| self.0[j].0[i].clone())) + NttVector::new(Array::from_fn(|j| self.0[j].0[i].clone())) })) } } @@ -462,9 +401,9 @@ mod test { #[test] fn ntt_vector() { // Verify vector addition - let v1: NttVector = NttVector(Array([const_ntt(1), const_ntt(1), const_ntt(1)])); - let v2: NttVector = NttVector(Array([const_ntt(2), const_ntt(2), const_ntt(2)])); - let v3: NttVector = NttVector(Array([const_ntt(3), const_ntt(3), const_ntt(3)])); + let v1: NttVector = NttVector::new(Array([const_ntt(1), const_ntt(1), const_ntt(1)])); + let v2: NttVector = NttVector::new(Array([const_ntt(2), const_ntt(2), const_ntt(2)])); + let v3: NttVector = NttVector::new(Array([const_ntt(3), const_ntt(3), const_ntt(3)])); assert_eq!((&v1 + &v2), v3); // Verify dot product @@ -477,19 +416,20 @@ mod test { fn ntt_matrix() { // Verify matrix multiplication by a vector let a: NttMatrix = NttMatrix(Array([ - NttVector(Array([const_ntt(1), const_ntt(2), const_ntt(3)])), - NttVector(Array([const_ntt(4), const_ntt(5), const_ntt(6)])), - NttVector(Array([const_ntt(7), const_ntt(8), const_ntt(9)])), + NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)])), + NttVector::new(Array([const_ntt(4), const_ntt(5), const_ntt(6)])), + NttVector::new(Array([const_ntt(7), const_ntt(8), const_ntt(9)])), ])); - let v_in: NttVector = NttVector(Array([const_ntt(1), const_ntt(2), const_ntt(3)])); - let v_out: NttVector = NttVector(Array([const_ntt(14), const_ntt(32), const_ntt(50)])); + let v_in: NttVector = NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)])); + let v_out: NttVector = + NttVector::new(Array([const_ntt(14), const_ntt(32), const_ntt(50)])); assert_eq!(&a * &v_in, v_out); // Verify transpose let aT = NttMatrix(Array([ - NttVector(Array([const_ntt(1), const_ntt(4), const_ntt(7)])), - NttVector(Array([const_ntt(2), const_ntt(5), const_ntt(8)])), - NttVector(Array([const_ntt(3), const_ntt(6), const_ntt(9)])), + NttVector::new(Array([const_ntt(1), const_ntt(4), const_ntt(7)])), + NttVector::new(Array([const_ntt(2), const_ntt(5), const_ntt(8)])), + NttVector::new(Array([const_ntt(3), const_ntt(6), const_ntt(9)])), ])); assert_eq!(a.transpose(), aT); } diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs index 74e4602..f54b8bd 100644 --- a/module-lattice/src/algebra.rs +++ b/module-lattice/src/algebra.rs @@ -39,7 +39,7 @@ pub trait Field: Copy + Default + Debug + PartialEq { #[macro_export] macro_rules! define_field { ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => { - #[derive(Copy, Clone, Default, Debug, PartialEq)] + #[derive(Copy, Clone, Default, Debug, Eq, PartialEq)] pub struct $field; impl $crate::algebra::Field for $field { @@ -76,7 +76,7 @@ macro_rules! define_field { /// integer values remain in the field, and that the reductions are done efficiently. For /// addition and subtraction, a simple conditional subtraction is used; for multiplication, /// Barrett reduction. -#[derive(Copy, Clone, Default, Debug, PartialEq)] +#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)] pub struct Elem(pub F::Int); impl Elem { @@ -282,7 +282,7 @@ impl Neg for &Vector { /// subtracted, negated, and multiplied by scalars. /// We do not define multiplication of NTT polynomials here. We also do not define the /// mappings between normal polynomials and NTT polynomials (i.e., between `R_q` and `T_q`). -#[derive(Clone, Default, Debug, PartialEq)] +#[derive(Clone, Default, Debug, Eq, PartialEq)] pub struct NttPolynomial(pub Array, U256>); impl NttPolynomial { @@ -387,7 +387,7 @@ where /// added and subtracted. If multiplication is defined for NTT polynomials, then NTT vectors /// can be multiplied by NTT polynomials, and "multiplied" with each other to produce a dot /// product. -#[derive(Clone, Default, Debug, PartialEq)] +#[derive(Clone, Default, Debug, Eq, PartialEq)] pub struct NttVector(pub Array, K>); impl NttVector { @@ -396,6 +396,16 @@ impl NttVector { } } +#[cfg(feature = "subtle")] +impl ConstantTimeEq for NttVector +where + F::Int: ConstantTimeEq, +{ + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + #[cfg(feature = "zeroize")] impl Zeroize for NttVector where