diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index b9c55fe..0090a69 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -1,5 +1,4 @@ use array::{Array, typenum::U256}; -use core::ops::Mul; use module_lattice::{ algebra::{Field, MultiplyNtt}, util::Truncate, @@ -30,6 +29,10 @@ pub type NttPolynomial = module_lattice::algebra::NttPolynomial; /// A vector of K NTT-domain polynomials. pub type NttVector = module_lattice::algebra::NttVector; +/// A K x K matrix of NTT-domain polynomials. Each vector represents a row of the matrix, so that +/// multiplying on the right just requires iteration. +pub type NttMatrix = module_lattice::algebra::NttMatrix; + /// Algorithm 7: `SampleNTT(B)` pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { struct FieldElementReader<'a> { @@ -91,6 +94,16 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { NttPolynomial::new(Array::from_fn(|_| reader.next())) } +pub(crate) fn matrix_sample_ntt(rho: &B32, transpose: bool) -> NttMatrix { + NttMatrix::new(Array::from_fn(|i| { + NttVector::new(Array::from_fn(|j| { + let (i, j) = if transpose { (j, i) } else { (i, j) }; + let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i)); + sample_ntt(&mut xof) + })) + })) +} + /// Algorithm 8: `SamplePolyCBD_eta(B)` /// /// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in @@ -301,43 +314,19 @@ const GAMMA: [Elem; 128] = { gamma }; -/// A K x K matrix of NTT-domain polynomials. Each vector represents a row of the matrix, so that -/// multiplying on the right just requires iteration. -#[derive(Clone, Default, Debug, PartialEq)] -pub struct NttMatrix(Array, K>); - -impl Mul<&NttVector> for &NttMatrix { - type Output = NttVector; - - fn mul(self, rhs: &NttVector) -> NttVector { - NttVector::new(self.0.iter().map(|x| x * rhs).collect()) - } -} - -impl NttMatrix { - pub fn sample_uniform(rho: &B32, transpose: bool) -> Self { - Self(Array::from_fn(|i| { - NttVector::new(Array::from_fn(|j| { - let (i, j) = if transpose { (j, i) } else { (i, j) }; - let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i)); - sample_ntt(&mut xof) - })) - })) - } - - pub fn transpose(&self) -> Self { - Self(Array::from_fn(|i| { - NttVector::new(Array::from_fn(|j| self.0[j].0[i].clone())) - })) - } -} - #[cfg(test)] mod test { use super::*; use array::typenum::{U2, U3, U8}; use module_lattice::util::Flatten; + /// A polynomial with only a scalar component, to make simple test cases + fn const_ntt(x: Int) -> NttPolynomial { + let mut p = Polynomial::default(); + p.0[0] = Elem::new(x); + p.ntt() + } + /// Multiplication in `R_q`, modulo X^256 + 1 fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial { let mut out = Polynomial::default(); @@ -355,11 +344,11 @@ mod test { out } - // A polynomial with only a scalar component, to make simple test cases - fn const_ntt(x: Int) -> NttPolynomial { - let mut p = Polynomial::default(); - p.0[0] = Elem::new(x); - p.ntt() + /// Transpose `NttMatrix` + fn matrix_transpose(matrix: &NttMatrix) -> NttMatrix { + NttMatrix::new(Array::from_fn(|i| { + NttVector::new(Array::from_fn(|j| matrix.0[j].0[i].clone())) + })) } #[test] @@ -420,7 +409,7 @@ mod test { #[test] fn ntt_matrix() { // Verify matrix multiplication by a vector - let a: NttMatrix = NttMatrix(Array([ + let a: NttMatrix = NttMatrix::new(Array([ NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)])), NttVector::new(Array([const_ntt(4), const_ntt(5), const_ntt(6)])), NttVector::new(Array([const_ntt(7), const_ntt(8), const_ntt(9)])), @@ -431,12 +420,12 @@ mod test { assert_eq!(&a * &v_in, v_out); // Verify transpose - let aT = NttMatrix(Array([ + let aT = NttMatrix::new(Array([ NttVector::new(Array([const_ntt(1), const_ntt(4), const_ntt(7)])), NttVector::new(Array([const_ntt(2), const_ntt(5), const_ntt(8)])), NttVector::new(Array([const_ntt(3), const_ntt(6), const_ntt(9)])), ])); - assert_eq!(a.transpose(), aT); + assert_eq!(matrix_transpose(&a), aT); } // To verify the accuracy of sampling, we use a theorem related to the law of large numbers, diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index a849b55..cb0202e 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -1,6 +1,7 @@ use crate::B32; use crate::algebra::{ - Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, sample_poly_cbd, sample_poly_vec_cbd, + Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, matrix_sample_ntt, sample_poly_cbd, + sample_poly_vec_cbd, }; use crate::compress::Compress; use crate::crypto::{G, PRF}; @@ -65,7 +66,7 @@ where let (rho, sigma) = G(&[&d[..], &[k]]); // Sample pseudo-random matrix and vectors - let A_hat: NttMatrix = NttMatrix::sample_uniform(&rho, false); + let A_hat: NttMatrix = matrix_sample_ntt(&rho, false); let s: Vector = sample_poly_vec_cbd::(&sigma, 0); let e: Vector = sample_poly_vec_cbd::(&sigma, P::K::U8); @@ -135,7 +136,7 @@ where let prf_output = PRF::(randomness, 2 * P::K::U8); let e2: Polynomial = sample_poly_cbd::(&prf_output); - let A_hat_t = NttMatrix::::sample_uniform(&self.rho, true); + let A_hat_t: NttMatrix = matrix_sample_ntt(&self.rho, true); let r_hat: NttVector = r.ntt(); let ATr: Vector = (&A_hat_t * &r_hat).ntt_inverse(); let mut u = ATr + e1;