Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions ml-kem/src/algebra.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use array::{Array, typenum::U256};
use core::ops::Mul;
use module_lattice::{
algebra::{Field, MultiplyNtt},
util::Truncate,
Expand Down Expand Up @@ -30,6 +29,10 @@ pub type NttPolynomial = module_lattice::algebra::NttPolynomial<BaseField>;
/// A vector of K NTT-domain polynomials.
pub type NttVector<K> = module_lattice::algebra::NttVector<BaseField, K>;

/// A K x K matrix of NTT-domain polynomials. Each vector represents a row of the matrix, so that
/// multiplying on the right just requires iteration.
pub type NttMatrix<K> = module_lattice::algebra::NttMatrix<BaseField, K, K>;

/// Algorithm 7: `SampleNTT(B)`
pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
struct FieldElementReader<'a> {
Expand Down Expand Up @@ -91,6 +94,16 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
NttPolynomial::new(Array::from_fn(|_| reader.next()))
}

pub(crate) fn matrix_sample_ntt<K: ArraySize>(rho: &B32, transpose: bool) -> NttMatrix<K> {
NttMatrix::new(Array::from_fn(|i| {
NttVector::new(Array::from_fn(|j| {
let (i, j) = if transpose { (j, i) } else { (i, j) };
let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i));
sample_ntt(&mut xof)
}))
}))
}

/// Algorithm 8: `SamplePolyCBD_eta(B)`
///
/// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in
Expand Down Expand Up @@ -301,43 +314,19 @@ const GAMMA: [Elem; 128] = {
gamma
};

/// A K x K matrix of NTT-domain polynomials. Each vector represents a row of the matrix, so that
/// multiplying on the right just requires iteration.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct NttMatrix<K: ArraySize>(Array<NttVector<K>, K>);

impl<K: ArraySize> Mul<&NttVector<K>> for &NttMatrix<K> {
type Output = NttVector<K>;

fn mul(self, rhs: &NttVector<K>) -> NttVector<K> {
NttVector::new(self.0.iter().map(|x| x * rhs).collect())
}
}

impl<K: ArraySize> NttMatrix<K> {
pub fn sample_uniform(rho: &B32, transpose: bool) -> Self {
Self(Array::from_fn(|i| {
NttVector::new(Array::from_fn(|j| {
let (i, j) = if transpose { (j, i) } else { (i, j) };
let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i));
sample_ntt(&mut xof)
}))
}))
}

pub fn transpose(&self) -> Self {
Self(Array::from_fn(|i| {
NttVector::new(Array::from_fn(|j| self.0[j].0[i].clone()))
}))
}
}

#[cfg(test)]
mod test {
use super::*;
use array::typenum::{U2, U3, U8};
use module_lattice::util::Flatten;

/// A polynomial with only a scalar component, to make simple test cases
fn const_ntt(x: Int) -> NttPolynomial {
let mut p = Polynomial::default();
p.0[0] = Elem::new(x);
p.ntt()
}

/// Multiplication in `R_q`, modulo X^256 + 1
fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial {
let mut out = Polynomial::default();
Expand All @@ -355,11 +344,11 @@ mod test {
out
}

// A polynomial with only a scalar component, to make simple test cases
fn const_ntt(x: Int) -> NttPolynomial {
let mut p = Polynomial::default();
p.0[0] = Elem::new(x);
p.ntt()
/// Transpose `NttMatrix`
fn matrix_transpose<K: ArraySize>(matrix: &NttMatrix<K>) -> NttMatrix<K> {
NttMatrix::new(Array::from_fn(|i| {
NttVector::new(Array::from_fn(|j| matrix.0[j].0[i].clone()))
}))
}

#[test]
Expand Down Expand Up @@ -420,7 +409,7 @@ mod test {
#[test]
fn ntt_matrix() {
// Verify matrix multiplication by a vector
let a: NttMatrix<U3> = NttMatrix(Array([
let a: NttMatrix<U3> = NttMatrix::new(Array([
NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)])),
NttVector::new(Array([const_ntt(4), const_ntt(5), const_ntt(6)])),
NttVector::new(Array([const_ntt(7), const_ntt(8), const_ntt(9)])),
Expand All @@ -431,12 +420,12 @@ mod test {
assert_eq!(&a * &v_in, v_out);

// Verify transpose
let aT = NttMatrix(Array([
let aT = NttMatrix::new(Array([
NttVector::new(Array([const_ntt(1), const_ntt(4), const_ntt(7)])),
NttVector::new(Array([const_ntt(2), const_ntt(5), const_ntt(8)])),
NttVector::new(Array([const_ntt(3), const_ntt(6), const_ntt(9)])),
]));
assert_eq!(a.transpose(), aT);
assert_eq!(matrix_transpose(&a), aT);
}

// To verify the accuracy of sampling, we use a theorem related to the law of large numbers,
Expand Down
7 changes: 4 additions & 3 deletions ml-kem/src/pke.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::B32;
use crate::algebra::{
Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, sample_poly_cbd, sample_poly_vec_cbd,
Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, matrix_sample_ntt, sample_poly_cbd,
sample_poly_vec_cbd,
};
use crate::compress::Compress;
use crate::crypto::{G, PRF};
Expand Down Expand Up @@ -65,7 +66,7 @@ where
let (rho, sigma) = G(&[&d[..], &[k]]);

// Sample pseudo-random matrix and vectors
let A_hat: NttMatrix<P::K> = NttMatrix::sample_uniform(&rho, false);
let A_hat: NttMatrix<P::K> = matrix_sample_ntt(&rho, false);
let s: Vector<P::K> = sample_poly_vec_cbd::<P::Eta1, P::K>(&sigma, 0);
let e: Vector<P::K> = sample_poly_vec_cbd::<P::Eta1, P::K>(&sigma, P::K::U8);

Expand Down Expand Up @@ -135,7 +136,7 @@ where
let prf_output = PRF::<P::Eta2>(randomness, 2 * P::K::U8);
let e2: Polynomial = sample_poly_cbd::<P::Eta2>(&prf_output);

let A_hat_t = NttMatrix::<P::K>::sample_uniform(&self.rho, true);
let A_hat_t: NttMatrix<P::K> = matrix_sample_ntt(&self.rho, true);
let r_hat: NttVector<P::K> = r.ntt();
let ATr: Vector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
let mut u = ATr + e1;
Expand Down