From f1020f89d101d1f3e12ee62482d7e52c90bf5525 Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Wed, 28 Jan 2026 19:58:15 -0700 Subject: [PATCH] ml-kem: use `Polynomial(Vector)` from `module-lattice` Replaces the `Polynomial` and `PolynomialVector` types defined in `ml-kem` with shared implementations with `ml-dsa` from the `module-lattice` crate. --- ml-kem/src/algebra.rs | 215 ++++++++++++---------------------- ml-kem/src/encode.rs | 8 +- ml-kem/src/pke.rs | 23 ++-- module-lattice/src/algebra.rs | 8 +- 4 files changed, 100 insertions(+), 154 deletions(-) diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 4576862..751c609 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -1,5 +1,5 @@ use array::{Array, typenum::U256}; -use core::ops::{Add, Mul, Sub}; +use core::ops::{Add, Mul}; use module_lattice::{algebra::Field, util::Truncate}; use sha3::digest::XofReader; use subtle::{Choice, ConstantTimeEq}; @@ -19,7 +19,13 @@ module_lattice::define_field!(BaseField, Integer, u32, u64, 3329); /// An element of GF(q). pub type FieldElement = module_lattice::algebra::Elem; -// Algorithm 11. BaseCaseMultiply +/// An element of the ring `R_q`, i.e., a polynomial over `Z_q` of degree 255 +pub type Polynomial = module_lattice::algebra::Polynomial; + +/// A vector of polynomials of length `K`. +pub type PolynomialVector = module_lattice::algebra::Vector; + +// Algorithm 12. BaseCaseMultiply // // This is a hot loop. We promote to u64 so that we can do the absolute minimum number of // modular reductions, since these are the expensive operation. @@ -43,90 +49,29 @@ fn base_case_multiply( (FieldElement::new(c0), FieldElement::new(c1)) } -/// An element of the ring `R_q`, i.e., a polynomial over `Z_q` of degree 255 -#[derive(Clone, Copy, Default, Debug, PartialEq)] -pub struct Polynomial(pub Array); - -impl Add<&Polynomial> for &Polynomial { - type Output = Polynomial; - - fn add(self, rhs: &Polynomial) -> Polynomial { - Polynomial( - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(&x, &y)| x + y) - .collect(), - ) - } -} - -impl Sub<&Polynomial> for &Polynomial { - type Output = Polynomial; - - fn sub(self, rhs: &Polynomial) -> Polynomial { - Polynomial( - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(&x, &y)| x - y) - .collect(), - ) - } -} - -impl Mul<&Polynomial> for FieldElement { - type Output = Polynomial; - - fn mul(self, rhs: &Polynomial) -> Polynomial { - Polynomial(rhs.0.iter().map(|&x| self * x).collect()) - } -} - -impl Polynomial { - // Algorithm 7. SamplePolyCBD_eta(B) - // - // To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in - // ByteDecode. We decode the PRF output into integers with eta bits, then use - // `count_ones` to perform the summation described in the algorithm. - pub fn sample_cbd(B: &PrfOutput) -> Self - where - Eta: CbdSamplingSize, - { - let vals: Polynomial = Encode::::decode(B); - Self(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect()) - } -} - -/// A vector of polynomials of length `k` -#[derive(Clone, Default, Debug, PartialEq)] -pub struct PolynomialVector(pub Array); - -impl Add> for PolynomialVector { - type Output = PolynomialVector; - - fn add(self, rhs: PolynomialVector) -> PolynomialVector { - PolynomialVector( - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(x, y)| x + y) - .collect(), - ) - } +// Algorithm 8. SamplePolyCBD_eta(B) +// +// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in +// ByteDecode. We decode the PRF output into integers with eta bits, then use +// `count_ones` to perform the summation described in the algorithm. +pub(crate) fn sample_poly_cbd(B: &PrfOutput) -> Polynomial +where + Eta: CbdSamplingSize, +{ + let vals: Polynomial = Encode::::decode(B); + Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect()) } -impl PolynomialVector { - pub fn sample_cbd(sigma: &B32, start_n: u8) -> Self - where - Eta: CbdSamplingSize, - { - Self(Array::from_fn(|i| { - let N = start_n + u8::truncate(i); - let prf_output = PRF::(sigma, N); - Polynomial::sample_cbd::(&prf_output) - })) - } +pub(crate) fn sample_poly_vec_cbd(sigma: &B32, start_n: u8) -> PolynomialVector +where + Eta: CbdSamplingSize, + K: ArraySize, +{ + PolynomialVector::new(Array::from_fn(|i| { + let N = start_n + u8::truncate(i); + let prf_output = PRF::(sigma, N); + sample_poly_cbd::(&prf_output) + })) } /// An element of the ring `T_q`, i.e., a tuple of 128 elements of the direct sum components of `T_q`. @@ -162,7 +107,7 @@ impl Add<&NttPolynomial> for &NttPolynomial { } } -// Algorithm 6. SampleNTT (lines 4-13) +// Algorithm 7. SampleNTT (lines 4-13) struct FieldElementReader<'a> { xof: &'a mut dyn XofReader, data: [u8; 96], @@ -219,7 +164,7 @@ impl<'a> FieldElementReader<'a> { } impl NttPolynomial { - // Algorithm 6 SampleNTT(B) + // Algorithm 7 SampleNTT(B) pub fn sample_uniform(B: &mut impl XofReader) -> Self { let mut reader = FieldElementReader::new(B); Self(Array::from_fn(|_| reader.next())) @@ -288,7 +233,7 @@ const GAMMA: [FieldElement; 128] = { gamma }; -// Algorithm 10. MuliplyNTTs +// Algorithm 11. MuliplyNTTs impl Mul<&NttPolynomial> for &NttPolynomial { type Output = NttPolynomial; @@ -324,30 +269,32 @@ impl From for Array { } } -// Algorithm 8. NTT -impl Polynomial { - pub fn ntt(&self) -> NttPolynomial { - let mut k = 1; +// Algorithm 9. NTT +pub(crate) fn ntt(poly: &Polynomial) -> NttPolynomial { + let mut k = 1; - let mut f = self.0; - for len in [128, 64, 32, 16, 8, 4, 2] { - for start in (0..256).step_by(2 * len) { - let zeta = ZETA_POW_BITREV[k]; - k += 1; + let mut f = poly.0; + for len in [128, 64, 32, 16, 8, 4, 2] { + for start in (0..256).step_by(2 * len) { + let zeta = ZETA_POW_BITREV[k]; + k += 1; - for j in start..(start + len) { - let t = zeta * f[j + len]; - f[j + len] = f[j] - t; - f[j] = f[j] + t; - } + for j in start..(start + len) { + let t = zeta * f[j + len]; + f[j + len] = f[j] - t; + f[j] = f[j] + t; } } - - f.into() } + + f.into() +} + +pub(crate) fn ntt_vector(poly: &PolynomialVector) -> NttVector { + NttVector(poly.0.iter().map(ntt).collect()) } -// Algorithm 9. NTT^{-1} +// Algorithm 10. NTT^{-1} impl NttPolynomial { pub fn ntt_inverse(&self) -> Polynomial { let mut f: Array = self.0.clone(); @@ -366,7 +313,7 @@ impl NttPolynomial { } } - FieldElement::new(3303) * &Polynomial(f) + FieldElement::new(3303) * &Polynomial::new(f) } } @@ -436,15 +383,9 @@ impl Mul<&NttVector> for &NttVector { } } -impl PolynomialVector { - pub fn ntt(&self) -> NttVector { - NttVector(self.0.iter().map(Polynomial::ntt).collect()) - } -} - impl NttVector { pub fn ntt_inverse(&self) -> PolynomialVector { - PolynomialVector(self.0.iter().map(NttPolynomial::ntt_inverse).collect()) + PolynomialVector::new(self.0.iter().map(NttPolynomial::ntt_inverse).collect()) } } @@ -482,39 +423,35 @@ mod test { use module_lattice::util::Flatten; // Multiplication in R_q, modulo X^256 + 1 - impl Mul<&Polynomial> for &Polynomial { - type Output = Polynomial; - - fn mul(self, rhs: &Polynomial) -> Self::Output { - let mut out = Self::Output::default(); - for (i, x) in self.0.iter().enumerate() { - for (j, y) in rhs.0.iter().enumerate() { - let (sign, index) = if i + j < 256 { - (FieldElement::new(1), i + j) - } else { - (FieldElement::new(BaseField::Q - 1), i + j - 256) - }; - - out.0[index] = out.0[index] + (sign * *x * *y); - } + fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial { + let mut out = Polynomial::default(); + for (i, x) in lhs.0.iter().enumerate() { + for (j, y) in rhs.0.iter().enumerate() { + let (sign, index) = if i + j < 256 { + (FieldElement::new(1), i + j) + } else { + (FieldElement::new(BaseField::Q - 1), i + j - 256) + }; + + out.0[index] = out.0[index] + (sign * *x * *y); } - out } + out } // A polynomial with only a scalar component, to make simple test cases fn const_ntt(x: Integer) -> NttPolynomial { let mut p = Polynomial::default(); p.0[0] = FieldElement::new(x); - p.ntt() + super::ntt(&p) } #[test] #[allow(clippy::cast_possible_truncation)] fn polynomial_ops() { - let f = Polynomial(Array::from_fn(|i| FieldElement::new(i as Integer))); - let g = Polynomial(Array::from_fn(|i| FieldElement::new(2 * i as Integer))); - let sum = Polynomial(Array::from_fn(|i| FieldElement::new(3 * i as Integer))); + let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer))); + let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer))); + let sum = Polynomial::new(Array::from_fn(|i| FieldElement::new(3 * i as Integer))); assert_eq!((&f + &g), sum); assert_eq!((&sum - &g), f); assert_eq!(FieldElement::new(3) * &f, sum); @@ -523,10 +460,10 @@ mod test { #[test] #[allow(clippy::cast_possible_truncation, clippy::similar_names)] fn ntt() { - let f = Polynomial(Array::from_fn(|i| FieldElement::new(i as Integer))); - let g = Polynomial(Array::from_fn(|i| FieldElement::new(2 * i as Integer))); - let f_hat = f.ntt(); - let g_hat = g.ntt(); + let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer))); + let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer))); + let f_hat = super::ntt(&f); + let g_hat = super::ntt(&g); // Verify that NTT and NTT^-1 are actually inverses let f_unhat = f_hat.ntt_inverse(); @@ -539,7 +476,7 @@ mod test { assert_eq!(fg, fg_unhat); // Verify that NTT is a homomorphism with regard to multiplication - let fg = &f * &g; + let fg = poly_mul(&f, &g); let f_hat_g_hat = &f_hat * &g_hat; let fg_unhat = f_hat_g_hat.ntt_inverse(); assert_eq!(fg, fg_unhat); @@ -683,13 +620,13 @@ mod test { // Eta = 2 let sigma = B32::default(); let prf_output = PRF::(&sigma, 0); - let sample = Polynomial::sample_cbd::(&prf_output).0; + let sample = super::sample_poly_cbd::(&prf_output).0; test_sample(&sample, &CBD2); // Eta = 3 let sigma = B32::default(); let prf_output = PRF::(&sigma, 0); - let sample = Polynomial::sample_cbd::(&prf_output).0; + let sample = super::sample_poly_cbd::(&prf_output).0; test_sample(&sample, &CBD3); } } diff --git a/ml-kem/src/encode.rs b/ml-kem/src/encode.rs index f566751..2758ff8 100644 --- a/ml-kem/src/encode.rs +++ b/ml-kem/src/encode.rs @@ -275,7 +275,7 @@ pub(crate) mod test { #[test] fn vector_codec() { - let poly = Polynomial( + let poly = Polynomial::new( Array::<_, U8>([ FieldElement::new(0), FieldElement::new(1), @@ -290,17 +290,17 @@ pub(crate) mod test { ); // The required vector sizes are 2, 3, and 4. - let decoded: PolynomialVector = PolynomialVector(Array([poly, poly])); + let decoded: PolynomialVector = PolynomialVector::new(Array([poly, poly])); let encoded: EncodedPolynomialVector = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); vector_codec_known_answer_test::>(&decoded, &encoded); - let decoded: PolynomialVector = PolynomialVector(Array([poly, poly, poly])); + let decoded: PolynomialVector = PolynomialVector::new(Array([poly, poly, poly])); let encoded: EncodedPolynomialVector = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); vector_codec_known_answer_test::>(&decoded, &encoded); - let decoded: PolynomialVector = PolynomialVector(Array([poly, poly, poly, poly])); + let decoded: PolynomialVector = PolynomialVector::new(Array([poly, poly, poly, poly])); let encoded: EncodedPolynomialVector = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); vector_codec_known_answer_test::>(&decoded, &encoded); diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index c53a01b..79b70a3 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -1,5 +1,8 @@ use crate::B32; -use crate::algebra::{NttMatrix, NttVector, Polynomial, PolynomialVector}; +use crate::algebra::{ + NttMatrix, NttVector, Polynomial, PolynomialVector, ntt_vector, sample_poly_cbd, + sample_poly_vec_cbd, +}; use crate::compress::Compress; use crate::crypto::{G, PRF}; use crate::encode::Encode; @@ -64,12 +67,12 @@ where // Sample pseudo-random matrix and vectors let A_hat: NttMatrix = NttMatrix::sample_uniform(&rho, false); - let s: PolynomialVector = PolynomialVector::sample_cbd::(&sigma, 0); - let e: PolynomialVector = PolynomialVector::sample_cbd::(&sigma, P::K::U8); + let s: PolynomialVector = sample_poly_vec_cbd::(&sigma, 0); + let e: PolynomialVector = sample_poly_vec_cbd::(&sigma, P::K::U8); // NTT the vectors - let s_hat = s.ntt(); - let e_hat = e.ntt(); + let s_hat = ntt_vector(&s); + let e_hat = ntt_vector(&e); // Compute the public value let t_hat = &(&A_hat * &s_hat) + &e_hat; @@ -91,7 +94,7 @@ where let mut v: Polynomial = Encode::::decode(c2); v.decompress::(); - let u_hat = u.ntt(); + let u_hat = ntt_vector(&u); let sTu = (&self.s_hat * &u_hat).ntt_inverse(); let mut w = &v - &sTu; Encode::::encode(w.compress::()) @@ -127,14 +130,14 @@ where /// Encrypt the specified message for the holder of the corresponding decryption key, using the /// provided randomness, according the `K-PKE.Encrypt` procedure. pub fn encrypt(&self, message: &B32, randomness: &B32) -> EncodedCiphertext

{ - let r = PolynomialVector::::sample_cbd::(randomness, 0); - let e1 = PolynomialVector::::sample_cbd::(randomness, P::K::U8); + let r = sample_poly_vec_cbd::(randomness, 0); + let e1 = sample_poly_vec_cbd::(randomness, P::K::U8); let prf_output = PRF::(randomness, 2 * P::K::U8); - let e2: Polynomial = Polynomial::sample_cbd::(&prf_output); + let e2: Polynomial = sample_poly_cbd::(&prf_output); let A_hat_t = NttMatrix::::sample_uniform(&self.rho, true); - let r_hat: NttVector = r.ntt(); + let r_hat: NttVector = ntt_vector(&r); let ATr: PolynomialVector = (&A_hat_t * &r_hat).ntt_inverse(); let mut u = ATr + e1; diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs index fb7c02d..83a61f9 100644 --- a/module-lattice/src/algebra.rs +++ b/module-lattice/src/algebra.rs @@ -143,7 +143,7 @@ impl Mul> for Elem { /// A `Polynomial` is a member of the ring `R_q = Z_q[X] / (X^256)` of degree-256 polynomials /// over the finite field with prime order `q`. Polynomials can be added, subtracted, negated, /// and multiplied by field elements. We do not define multiplication of polynomials here. -#[derive(Clone, Default, Debug, PartialEq)] +#[derive(Clone, Copy, Default, Debug, PartialEq)] pub struct Polynomial(pub Array, U256>); impl Polynomial { @@ -227,6 +227,12 @@ where } } +impl Add> for Vector { + type Output = Vector; + fn add(self, rhs: Vector) -> Vector { + Add::add(&self, &rhs) + } +} impl Add<&Vector> for &Vector { type Output = Vector;