diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 751c609..51535a8 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -25,28 +25,68 @@ pub type Polynomial = module_lattice::algebra::Polynomial; /// A vector of polynomials of length `K`. pub type PolynomialVector = module_lattice::algebra::Vector; -// Algorithm 12. BaseCaseMultiply -// -// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of -// modular reductions, since these are the expensive operation. -fn base_case_multiply( - a0: FieldElement, - a1: FieldElement, - b0: FieldElement, - b1: FieldElement, - i: usize, -) -> (FieldElement, FieldElement) { - let a0 = u32::from(a0.0); - let a1 = u32::from(a1.0); - let b0 = u32::from(b0.0); - let b1 = u32::from(b1.0); - let g = u32::from(GAMMA[i].0); +/// An element of the ring `T_q` i.e. a tuple of 128 elements of the direct sum components of `T_q`. +pub type NttPolynomial = module_lattice::algebra::NttPolynomial; + +// Algorithm 7 SampleNTT(B) +pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { + struct FieldElementReader<'a> { + xof: &'a mut dyn XofReader, + data: [u8; 96], + start: usize, + next: Option, + } + + impl<'a> FieldElementReader<'a> { + fn new(xof: &'a mut impl XofReader) -> Self { + let mut out = Self { + xof, + data: [0u8; 96], + start: 0, + next: None, + }; + + // Fill the buffer + out.xof.read(&mut out.data); + + out + } - let b1g = u32::from(BaseField::barrett_reduce(b1 * g)); + fn next(&mut self) -> FieldElement { + if let Some(val) = self.next { + self.next = None; + return FieldElement::new(val); + } - let c0 = BaseField::barrett_reduce(a0 * b0 + a1 * b1g); - let c1 = BaseField::barrett_reduce(a0 * b1 + a1 * b0); - (FieldElement::new(c0), FieldElement::new(c1)) + loop { + if self.start == self.data.len() { + self.xof.read(&mut self.data); + self.start = 0; + } + + let end = self.start + 3; + let b = &self.data[self.start..end]; + self.start = end; + + let d1 = Integer::from(b[0]) + ((Integer::from(b[1]) & 0xf) << 8); + let d2 = (Integer::from(b[1]) >> 4) + ((Integer::from(b[2]) as Integer) << 4); + + if d1 < BaseField::Q { + if d2 < BaseField::Q { + self.next = Some(d2); + } + return FieldElement::new(d1); + } + + if d2 < BaseField::Q { + return FieldElement::new(d2); + } + } + } + } + + let mut reader = FieldElementReader::new(B); + NttPolynomial::new(Array::from_fn(|_| reader.next())) } // Algorithm 8. SamplePolyCBD_eta(B) @@ -62,113 +102,107 @@ where Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect()) } -pub(crate) fn sample_poly_vec_cbd(sigma: &B32, start_n: u8) -> PolynomialVector -where - Eta: CbdSamplingSize, - K: ArraySize, -{ - PolynomialVector::new(Array::from_fn(|i| { - let N = start_n + u8::truncate(i); - let prf_output = PRF::(sigma, N); - sample_poly_cbd::(&prf_output) - })) -} +// Algorithm 9. NTT +pub(crate) fn ntt(poly: &Polynomial) -> NttPolynomial { + let mut k = 1; -/// An element of the ring `T_q`, i.e., a tuple of 128 elements of the direct sum components of `T_q`. -#[derive(Clone, Default, Debug, PartialEq)] -pub struct NttPolynomial(pub Array); + let mut f = poly.0; + for len in [128, 64, 32, 16, 8, 4, 2] { + for start in (0..256).step_by(2 * len) { + let zeta = ZETA_POW_BITREV[k]; + k += 1; -impl ConstantTimeEq for NttPolynomial { - fn ct_eq(&self, other: &Self) -> Choice { - self.0.ct_eq(&other.0) + for j in start..(start + len) { + let t = zeta * f[j + len]; + f[j + len] = f[j] - t; + f[j] = f[j] + t; + } + } } + + f.into() } -#[cfg(feature = "zeroize")] -impl Zeroize for NttPolynomial { - fn zeroize(&mut self) { - for fe in &mut self.0 { - fe.zeroize(); - } - } +pub(crate) fn ntt_vector(poly: &PolynomialVector) -> NttVector { + NttVector(poly.0.iter().map(ntt).collect()) } -impl Add<&NttPolynomial> for &NttPolynomial { - type Output = NttPolynomial; +// Algorithm 10. NTT^{-1} +pub(crate) fn ntt_inverse(poly: &NttPolynomial) -> Polynomial { + let mut f: Array = poly.0.clone(); - fn add(self, rhs: &NttPolynomial) -> NttPolynomial { - NttPolynomial( - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(&x, &y)| x + y) - .collect(), - ) + let mut k = 127; + for len in [2, 4, 8, 16, 32, 64, 128] { + for start in (0..256).step_by(2 * len) { + let zeta = ZETA_POW_BITREV[k]; + k -= 1; + + for j in start..(start + len) { + let t = f[j]; + f[j] = t + f[j + len]; + f[j + len] = zeta * (f[j + len] - t); + } + } } -} -// Algorithm 7. SampleNTT (lines 4-13) -struct FieldElementReader<'a> { - xof: &'a mut dyn XofReader, - data: [u8; 96], - start: usize, - next: Option, + FieldElement::new(3303) * &Polynomial::new(f) } -impl<'a> FieldElementReader<'a> { - fn new(xof: &'a mut impl XofReader) -> Self { - let mut out = Self { - xof, - data: [0u8; 96], - start: 0, - next: None, - }; +// Algorithm 11. MultiplyNTTs +fn multiply_ntts(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial { + let mut out = NttPolynomial::new(Array::default()); - // Fill the buffer - out.xof.read(&mut out.data); + for i in 0..128 { + let (c0, c1) = base_case_multiply( + lhs.0[2 * i], + lhs.0[2 * i + 1], + rhs.0[2 * i], + rhs.0[2 * i + 1], + i, + ); - out + out.0[2 * i] = c0; + out.0[2 * i + 1] = c1; } - fn next(&mut self) -> FieldElement { - if let Some(val) = self.next { - self.next = None; - return FieldElement::new(val); - } - - loop { - if self.start == self.data.len() { - self.xof.read(&mut self.data); - self.start = 0; - } - - let end = self.start + 3; - let b = &self.data[self.start..end]; - self.start = end; + out +} - let d1 = Integer::from(b[0]) + ((Integer::from(b[1]) & 0xf) << 8); - let d2 = (Integer::from(b[1]) >> 4) + ((Integer::from(b[2]) as Integer) << 4); +// Algorithm 12. BaseCaseMultiply +// +// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of +// modular reductions, since these are the expensive operation. +#[inline] +fn base_case_multiply( + a0: FieldElement, + a1: FieldElement, + b0: FieldElement, + b1: FieldElement, + i: usize, +) -> (FieldElement, FieldElement) { + let a0 = u32::from(a0.0); + let a1 = u32::from(a1.0); + let b0 = u32::from(b0.0); + let b1 = u32::from(b1.0); + let g = u32::from(GAMMA[i].0); - if d1 < BaseField::Q { - if d2 < BaseField::Q { - self.next = Some(d2); - } - return FieldElement::new(d1); - } + let b1g = u32::from(BaseField::barrett_reduce(b1 * g)); - if d2 < BaseField::Q { - return FieldElement::new(d2); - } - } - } + let c0 = BaseField::barrett_reduce(a0 * b0 + a1 * b1g); + let c1 = BaseField::barrett_reduce(a0 * b1 + a1 * b0); + (FieldElement::new(c0), FieldElement::new(c1)) } -impl NttPolynomial { - // Algorithm 7 SampleNTT(B) - pub fn sample_uniform(B: &mut impl XofReader) -> Self { - let mut reader = FieldElementReader::new(B); - Self(Array::from_fn(|_| reader.next())) - } +pub(crate) fn sample_poly_vec_cbd(sigma: &B32, start_n: u8) -> PolynomialVector +where + Eta: CbdSamplingSize, + K: ArraySize, +{ + PolynomialVector::new(Array::from_fn(|i| { + let N = start_n + u8::truncate(i); + let prf_output = PRF::(sigma, N); + sample_poly_cbd::(&prf_output) + })) } // Since the powers of zeta used in the NTT and MultiplyNTTs are fixed, we use pre-computed tables @@ -233,90 +267,6 @@ const GAMMA: [FieldElement; 128] = { gamma }; -// Algorithm 11. MuliplyNTTs -impl Mul<&NttPolynomial> for &NttPolynomial { - type Output = NttPolynomial; - - fn mul(self, rhs: &NttPolynomial) -> NttPolynomial { - let mut out = NttPolynomial(Array::default()); - - for i in 0..128 { - let (c0, c1) = base_case_multiply( - self.0[2 * i], - self.0[2 * i + 1], - rhs.0[2 * i], - rhs.0[2 * i + 1], - i, - ); - - out.0[2 * i] = c0; - out.0[2 * i + 1] = c1; - } - - out - } -} - -impl From> for NttPolynomial { - fn from(f: Array) -> NttPolynomial { - NttPolynomial(f) - } -} - -impl From for Array { - fn from(f_hat: NttPolynomial) -> Array { - f_hat.0 - } -} - -// Algorithm 9. NTT -pub(crate) fn ntt(poly: &Polynomial) -> NttPolynomial { - let mut k = 1; - - let mut f = poly.0; - for len in [128, 64, 32, 16, 8, 4, 2] { - for start in (0..256).step_by(2 * len) { - let zeta = ZETA_POW_BITREV[k]; - k += 1; - - for j in start..(start + len) { - let t = zeta * f[j + len]; - f[j + len] = f[j] - t; - f[j] = f[j] + t; - } - } - } - - f.into() -} - -pub(crate) fn ntt_vector(poly: &PolynomialVector) -> NttVector { - NttVector(poly.0.iter().map(ntt).collect()) -} - -// Algorithm 10. NTT^{-1} -impl NttPolynomial { - pub fn ntt_inverse(&self) -> Polynomial { - let mut f: Array = self.0.clone(); - - let mut k = 127; - for len in [2, 4, 8, 16, 32, 64, 128] { - for start in (0..256).step_by(2 * len) { - let zeta = ZETA_POW_BITREV[k]; - k -= 1; - - for j in start..(start + len) { - let t = f[j]; - f[j] = t + f[j + len]; - f[j + len] = zeta * (f[j + len] - t); - } - } - } - - FieldElement::new(3303) * &Polynomial::new(f) - } -} - /// A vector of K NTT-domain polynomials #[derive(Clone, Default, Debug)] pub struct NttVector(pub Array); @@ -326,7 +276,7 @@ impl NttVector { Self(Array::from_fn(|j| { let (i, j) = if transpose { (j, i) } else { (i, j) }; let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i)); - NttPolynomial::sample_uniform(&mut xof) + sample_ntt(&mut xof) })) } } @@ -378,14 +328,14 @@ impl Mul<&NttVector> for &NttVector { self.0 .iter() .zip(rhs.0.iter()) - .map(|(x, y)| x * y) + .map(|(x, y)| multiply_ntts(x, y)) .fold(NttPolynomial::default(), |x, y| &x + &y) } } impl NttVector { pub fn ntt_inverse(&self) -> PolynomialVector { - PolynomialVector::new(self.0.iter().map(NttPolynomial::ntt_inverse).collect()) + PolynomialVector::new(self.0.iter().map(ntt_inverse).collect()) } } @@ -466,19 +416,19 @@ mod test { let g_hat = super::ntt(&g); // Verify that NTT and NTT^-1 are actually inverses - let f_unhat = f_hat.ntt_inverse(); + let f_unhat = ntt_inverse(&f_hat); assert_eq!(f, f_unhat); // Verify that NTT is a homomorphism with regard to addition let fg = &f + &g; let f_hat_g_hat = &f_hat + &g_hat; - let fg_unhat = f_hat_g_hat.ntt_inverse(); + let fg_unhat = ntt_inverse(&f_hat_g_hat); assert_eq!(fg, fg_unhat); // Verify that NTT is a homomorphism with regard to multiplication let fg = poly_mul(&f, &g); - let f_hat_g_hat = &f_hat * &g_hat; - let fg_unhat = f_hat_g_hat.ntt_inverse(); + let f_hat_g_hat = multiply_ntts(&f_hat, &g_hat); + let fg_unhat = ntt_inverse(&f_hat_g_hat); assert_eq!(fg, fg_unhat); } @@ -609,7 +559,7 @@ mod test { let rho = B32::default(); let sample: Array, U8> = Array::from_fn(|i| { let mut xof = XOF(&rho, 0, i as u8); - NttPolynomial::sample_uniform(&mut xof).into() + sample_ntt(&mut xof).into() }); test_sample(&sample.flatten(), &UNIFORM); diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index 79b70a3..0373fbc 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -1,6 +1,6 @@ use crate::B32; use crate::algebra::{ - NttMatrix, NttVector, Polynomial, PolynomialVector, ntt_vector, sample_poly_cbd, + NttMatrix, NttVector, Polynomial, PolynomialVector, ntt_inverse, ntt_vector, sample_poly_cbd, sample_poly_vec_cbd, }; use crate::compress::Compress; @@ -95,7 +95,7 @@ where v.decompress::(); let u_hat = ntt_vector(&u); - let sTu = (&self.s_hat * &u_hat).ntt_inverse(); + let sTu = ntt_inverse(&(&self.s_hat * &u_hat)); let mut w = &v - &sTu; Encode::::encode(w.compress::()) } @@ -144,7 +144,7 @@ where let mut mu: Polynomial = Encode::::decode(message); mu.decompress::(); - let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse(); + let tTr: Polynomial = ntt_inverse(&(&self.t_hat * &r_hat)); let mut v = &(&tTr + &e2) + μ let c1 = Encode::::encode(u.compress::()); diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs index 83a61f9..c91d3d1 100644 --- a/module-lattice/src/algebra.rs +++ b/module-lattice/src/algebra.rs @@ -291,16 +291,6 @@ impl NttPolynomial { } } -#[cfg(feature = "zeroize")] -impl Zeroize for NttPolynomial -where - F::Int: Zeroize, -{ - fn zeroize(&mut self) { - self.0.zeroize(); - } -} - impl Add<&NttPolynomial> for &NttPolynomial { type Output = NttPolynomial; @@ -360,6 +350,38 @@ impl Neg for &NttPolynomial { } } +impl From, U256>> for NttPolynomial { + fn from(f: Array, U256>) -> NttPolynomial { + NttPolynomial(f) + } +} + +impl From> for Array, U256> { + fn from(f_hat: NttPolynomial) -> Array, U256> { + f_hat.0 + } +} + +#[cfg(feature = "subtle")] +impl ConstantTimeEq for NttPolynomial +where + F::Int: ConstantTimeEq, +{ + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +#[cfg(feature = "zeroize")] +impl Zeroize for NttPolynomial +where + F::Int: Zeroize, +{ + fn zeroize(&mut self) { + self.0.zeroize(); + } +} + /// An `NttVector` is a vector of polynomials from `T_q` of length `K`. NTT vectors can be /// added and subtracted. If multiplication is defined for NTT polynomials, then NTT vectors /// can be multiplied by NTT polynomials, and "multiplied" with each other to produce a dot