Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 32 additions & 92 deletions ml-kem/src/algebra.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
use array::{Array, typenum::U256};
use core::ops::{Add, Mul};
use core::ops::Mul;
use module_lattice::{
algebra::{Field, MultiplyNtt},
util::Truncate,
};
use sha3::digest::XofReader;
use subtle::{Choice, ConstantTimeEq};

use crate::B32;
use crate::crypto::{PRF, PrfOutput, XOF};
use crate::encode::Encode;
use crate::param::{ArraySize, CbdSamplingSize};

#[cfg(feature = "zeroize")]
use zeroize::Zeroize;

module_lattice::define_field!(BaseField, u16, u32, u64, 3329);

pub type Int = <BaseField as Field>::Int;
Expand All @@ -31,6 +27,9 @@ pub type Vector<K> = module_lattice::algebra::Vector<BaseField, K>;
/// An element of the ring `T_q` i.e. a tuple of 128 elements of the direct sum components of `T_q`.
pub type NttPolynomial = module_lattice::algebra::NttPolynomial<BaseField>;

/// A vector of K NTT-domain polynomials.
pub type NttVector<K> = module_lattice::algebra::NttVector<BaseField, K>;

/// Algorithm 7: `SampleNTT(B)`
pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
struct FieldElementReader<'a> {
Expand Down Expand Up @@ -154,7 +153,7 @@ impl<K: ArraySize> Ntt for Vector<K> {
type Output = NttVector<K>;

fn ntt(&self) -> NttVector<K> {
NttVector(self.0.iter().map(Ntt::ntt).collect())
NttVector::new(self.0.iter().map(Ntt::ntt).collect())
}
}

Expand Down Expand Up @@ -191,6 +190,14 @@ impl NttInverse for NttPolynomial {
}
}

impl<K: ArraySize> NttInverse for NttVector<K> {
type Output = Vector<K>;

fn ntt_inverse(&self) -> Vector<K> {
Vector::new(self.0.iter().map(NttInverse::ntt_inverse).collect())
}
}

/// Algorithm 11: `MultiplyNTTs`
impl MultiplyNtt for BaseField {
fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial {
Expand Down Expand Up @@ -294,78 +301,6 @@ const GAMMA: [Elem; 128] = {
gamma
};

/// A vector of K NTT-domain polynomials
#[derive(Clone, Default, Debug)]
pub struct NttVector<K: ArraySize>(pub Array<NttPolynomial, K>);

impl<K: ArraySize> NttVector<K> {
pub fn sample_uniform(rho: &B32, i: usize, transpose: bool) -> Self {
Self(Array::from_fn(|j| {
let (i, j) = if transpose { (j, i) } else { (i, j) };
let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i));
sample_ntt(&mut xof)
}))
}
}

impl<K: ArraySize> ConstantTimeEq for NttVector<K> {
fn ct_eq(&self, other: &Self) -> Choice {
self.0.ct_eq(&other.0)
}
}

impl<K: ArraySize> Eq for NttVector<K> {}
impl<K: ArraySize> PartialEq for NttVector<K> {
fn eq(&self, other: &Self) -> bool {
// Impl `PartialEq` in constant-time, in case this value contains a secret
self.0.ct_eq(&other.0).into()
}
}

#[cfg(feature = "zeroize")]
impl<K> Zeroize for NttVector<K>
where
K: ArraySize,
{
fn zeroize(&mut self) {
for poly in &mut self.0 {
poly.zeroize();
}
}
}

impl<K: ArraySize> Add<&NttVector<K>> for &NttVector<K> {
type Output = NttVector<K>;

fn add(self, rhs: &NttVector<K>) -> NttVector<K> {
NttVector(
self.0
.iter()
.zip(rhs.0.iter())
.map(|(x, y)| x + y)
.collect(),
)
}
}

impl<K: ArraySize> Mul<&NttVector<K>> for &NttVector<K> {
type Output = NttPolynomial;

fn mul(self, rhs: &NttVector<K>) -> NttPolynomial {
self.0
.iter()
.zip(rhs.0.iter())
.map(|(x, y)| x * y)
.fold(NttPolynomial::default(), |x, y| &x + &y)
}
}

impl<K: ArraySize> NttVector<K> {
pub fn ntt_inverse(&self) -> Vector<K> {
Vector::new(self.0.iter().map(NttInverse::ntt_inverse).collect())
}
}

/// A K x K matrix of NTT-domain polynomials. Each vector represents a row of the matrix, so that
/// multiplying on the right just requires iteration.
#[derive(Clone, Default, Debug, PartialEq)]
Expand All @@ -375,20 +310,24 @@ impl<K: ArraySize> Mul<&NttVector<K>> for &NttMatrix<K> {
type Output = NttVector<K>;

fn mul(self, rhs: &NttVector<K>) -> NttVector<K> {
NttVector(self.0.iter().map(|x| x * rhs).collect())
NttVector::new(self.0.iter().map(|x| x * rhs).collect())
}
}

impl<K: ArraySize> NttMatrix<K> {
pub fn sample_uniform(rho: &B32, transpose: bool) -> Self {
Self(Array::from_fn(|i| {
NttVector::sample_uniform(rho, i, transpose)
NttVector::new(Array::from_fn(|j| {
let (i, j) = if transpose { (j, i) } else { (i, j) };
let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i));
sample_ntt(&mut xof)
}))
}))
}

pub fn transpose(&self) -> Self {
Self(Array::from_fn(|i| {
NttVector(Array::from_fn(|j| self.0[j].0[i].clone()))
NttVector::new(Array::from_fn(|j| self.0[j].0[i].clone()))
}))
}
}
Expand Down Expand Up @@ -462,9 +401,9 @@ mod test {
#[test]
fn ntt_vector() {
// Verify vector addition
let v1: NttVector<U3> = NttVector(Array([const_ntt(1), const_ntt(1), const_ntt(1)]));
let v2: NttVector<U3> = NttVector(Array([const_ntt(2), const_ntt(2), const_ntt(2)]));
let v3: NttVector<U3> = NttVector(Array([const_ntt(3), const_ntt(3), const_ntt(3)]));
let v1: NttVector<U3> = NttVector::new(Array([const_ntt(1), const_ntt(1), const_ntt(1)]));
let v2: NttVector<U3> = NttVector::new(Array([const_ntt(2), const_ntt(2), const_ntt(2)]));
let v3: NttVector<U3> = NttVector::new(Array([const_ntt(3), const_ntt(3), const_ntt(3)]));
assert_eq!((&v1 + &v2), v3);

// Verify dot product
Expand All @@ -477,19 +416,20 @@ mod test {
fn ntt_matrix() {
// Verify matrix multiplication by a vector
let a: NttMatrix<U3> = NttMatrix(Array([
NttVector(Array([const_ntt(1), const_ntt(2), const_ntt(3)])),
NttVector(Array([const_ntt(4), const_ntt(5), const_ntt(6)])),
NttVector(Array([const_ntt(7), const_ntt(8), const_ntt(9)])),
NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)])),
NttVector::new(Array([const_ntt(4), const_ntt(5), const_ntt(6)])),
NttVector::new(Array([const_ntt(7), const_ntt(8), const_ntt(9)])),
]));
let v_in: NttVector<U3> = NttVector(Array([const_ntt(1), const_ntt(2), const_ntt(3)]));
let v_out: NttVector<U3> = NttVector(Array([const_ntt(14), const_ntt(32), const_ntt(50)]));
let v_in: NttVector<U3> = NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)]));
let v_out: NttVector<U3> =
NttVector::new(Array([const_ntt(14), const_ntt(32), const_ntt(50)]));
assert_eq!(&a * &v_in, v_out);

// Verify transpose
let aT = NttMatrix(Array([
NttVector(Array([const_ntt(1), const_ntt(4), const_ntt(7)])),
NttVector(Array([const_ntt(2), const_ntt(5), const_ntt(8)])),
NttVector(Array([const_ntt(3), const_ntt(6), const_ntt(9)])),
NttVector::new(Array([const_ntt(1), const_ntt(4), const_ntt(7)])),
NttVector::new(Array([const_ntt(2), const_ntt(5), const_ntt(8)])),
NttVector::new(Array([const_ntt(3), const_ntt(6), const_ntt(9)])),
]));
assert_eq!(a.transpose(), aT);
}
Expand Down
18 changes: 14 additions & 4 deletions module-lattice/src/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub trait Field: Copy + Default + Debug + PartialEq {
#[macro_export]
macro_rules! define_field {
($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
#[derive(Copy, Clone, Default, Debug, PartialEq)]
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
pub struct $field;

impl $crate::algebra::Field for $field {
Expand Down Expand Up @@ -76,7 +76,7 @@ macro_rules! define_field {
/// integer values remain in the field, and that the reductions are done efficiently. For
/// addition and subtraction, a simple conditional subtraction is used; for multiplication,
/// Barrett reduction.
#[derive(Copy, Clone, Default, Debug, PartialEq)]
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
pub struct Elem<F: Field>(pub F::Int);

impl<F: Field> Elem<F> {
Expand Down Expand Up @@ -282,7 +282,7 @@ impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
/// subtracted, negated, and multiplied by scalars.
/// We do not define multiplication of NTT polynomials here. We also do not define the
/// mappings between normal polynomials and NTT polynomials (i.e., between `R_q` and `T_q`).
#[derive(Clone, Default, Debug, PartialEq)]
#[derive(Clone, Default, Debug, Eq, PartialEq)]
pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);

impl<F: Field> NttPolynomial<F> {
Expand Down Expand Up @@ -387,7 +387,7 @@ where
/// added and subtracted. If multiplication is defined for NTT polynomials, then NTT vectors
/// can be multiplied by NTT polynomials, and "multiplied" with each other to produce a dot
/// product.
#[derive(Clone, Default, Debug, PartialEq)]
#[derive(Clone, Default, Debug, Eq, PartialEq)]
pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);

impl<F: Field, K: ArraySize> NttVector<F, K> {
Expand All @@ -396,6 +396,16 @@ impl<F: Field, K: ArraySize> NttVector<F, K> {
}
}

#[cfg(feature = "subtle")]
impl<F: Field, K: ArraySize> ConstantTimeEq for NttVector<F, K>
where
F::Int: ConstantTimeEq,
{
fn ct_eq(&self, other: &Self) -> Choice {
self.0.ct_eq(&other.0)
}
}

#[cfg(feature = "zeroize")]
impl<F: Field, K: ArraySize> Zeroize for NttVector<F, K>
where
Expand Down