Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions ml-kem/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ alloc = ["pkcs8?/alloc"]
getrandom = ["kem/getrandom"]
pem = ["pkcs8/pem"]
pkcs8 = ["dep:const-oid", "dep:pkcs8"]
zeroize = ["dep:zeroize"]
zeroize = ["module-lattice/zeroize", "dep:zeroize"]
hazmat = []

[dependencies]
array = { package = "hybrid-array", version = "0.4.4", features = ["extra-sizes", "subtle"] }
module-lattice = "0.1.0-pre.0"
module-lattice = { version = "0.1.0-pre.0", features = ["subtle"] }
kem = "0.3.0-rc.2"
rand_core = "0.10.0-rc-6"
sha3 = { version = "0.11.0-rc.3", default-features = false }
Expand Down
164 changes: 56 additions & 108 deletions ml-kem/src/algebra.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use array::{Array, typenum::U256};
use core::ops::{Add, Mul, Sub};
use module_lattice::util::Truncate;
use module_lattice::{
algebra::{Elem, Field},
util::Truncate,
};
use sha3::digest::XofReader;
use subtle::{Choice, ConstantTimeEq};

Expand All @@ -14,88 +17,33 @@ use zeroize::Zeroize;

pub type Integer = u16;

/// An element of GF(q). Although `q` is only 16 bits wide, we use a wider uint type to so that we
/// can defer modular reductions.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct FieldElement(pub Integer);
module_lattice::define_field!(BaseField, Integer, u32, u64, 3329);

impl ConstantTimeEq for FieldElement {
fn ct_eq(&self, other: &Self) -> Choice {
self.0.ct_eq(&other.0)
}
}
/// An element of GF(q).
pub type FieldElement = Elem<BaseField>;

#[cfg(feature = "zeroize")]
impl Zeroize for FieldElement {
fn zeroize(&mut self) {
self.0.zeroize();
}
}

impl FieldElement {
pub const Q: Integer = 3329;
pub const Q32: u32 = Self::Q as u32;
pub const Q64: u64 = Self::Q as u64;
const BARRETT_SHIFT: usize = 24;
#[allow(clippy::integer_division_remainder_used)]
const BARRETT_MULTIPLIER: u64 = (1 << Self::BARRETT_SHIFT) / Self::Q64;

// A fast modular reduction for small numbers `x < 2*q`
fn small_reduce(x: u16) -> u16 {
if x < Self::Q { x } else { x - Self::Q }
}

fn barrett_reduce(x: u32) -> u16 {
let product = u64::from(x) * Self::BARRETT_MULTIPLIER;
let quotient: u32 = Truncate::truncate(product >> Self::BARRETT_SHIFT);
let remainder = x - quotient * Self::Q32;
Self::small_reduce(Truncate::truncate(remainder))
}

// Algorithm 11. BaseCaseMultiply
//
// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of
// modular reductions, since these are the expensive operation.
fn base_case_multiply(a0: Self, a1: Self, b0: Self, b1: Self, i: usize) -> (Self, Self) {
let a0 = u32::from(a0.0);
let a1 = u32::from(a1.0);
let b0 = u32::from(b0.0);
let b1 = u32::from(b1.0);
let g = u32::from(GAMMA[i].0);

let b1g = u32::from(Self::barrett_reduce(b1 * g));

let c0 = Self::barrett_reduce(a0 * b0 + a1 * b1g);
let c1 = Self::barrett_reduce(a0 * b1 + a1 * b0);
(Self(c0), Self(c1))
}
}

impl Add<FieldElement> for FieldElement {
type Output = Self;

fn add(self, rhs: Self) -> Self {
Self(Self::small_reduce(self.0 + rhs.0))
}
}

impl Sub<FieldElement> for FieldElement {
type Output = Self;

fn sub(self, rhs: Self) -> Self {
// Guard against underflow if `rhs` is too large
Self(Self::small_reduce(self.0 + Self::Q - rhs.0))
}
}

impl Mul<FieldElement> for FieldElement {
type Output = FieldElement;

fn mul(self, rhs: FieldElement) -> FieldElement {
let x = u32::from(self.0);
let y = u32::from(rhs.0);
Self(Self::barrett_reduce(x * y))
}
// Algorithm 11. BaseCaseMultiply
//
// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of
// modular reductions, since these are the expensive operation.
fn base_case_multiply(
a0: FieldElement,
a1: FieldElement,
b0: FieldElement,
b1: FieldElement,
i: usize,
) -> (FieldElement, FieldElement) {
let a0 = u32::from(a0.0);
let a1 = u32::from(a1.0);
let b0 = u32::from(b0.0);
let b1 = u32::from(b1.0);
let g = u32::from(GAMMA[i].0);

let b1g = u32::from(BaseField::barrett_reduce(b1 * g));

let c0 = BaseField::barrett_reduce(a0 * b0 + a1 * b1g);
let c1 = BaseField::barrett_reduce(a0 * b1 + a1 * b0);
(Elem(c0), Elem(c1))
}

/// An element of the ring `R_q`, i.e., a polynomial over `Z_q` of degree 255
Expand Down Expand Up @@ -243,7 +191,7 @@ impl<'a> FieldElementReader<'a> {
fn next(&mut self) -> FieldElement {
if let Some(val) = self.next {
self.next = None;
return FieldElement(val);
return Elem(val);
}

loop {
Expand All @@ -259,15 +207,15 @@ impl<'a> FieldElementReader<'a> {
let d1 = Integer::from(b[0]) + ((Integer::from(b[1]) & 0xf) << 8);
let d2 = (Integer::from(b[1]) >> 4) + ((Integer::from(b[2]) as Integer) << 4);

if d1 < FieldElement::Q {
if d2 < FieldElement::Q {
if d1 < BaseField::Q {
if d2 < BaseField::Q {
self.next = Some(d2);
}
return FieldElement(d1);
return Elem(d1);
}

if d2 < FieldElement::Q {
return FieldElement(d2);
if d2 < BaseField::Q {
return Elem(d2);
}
}
}
Expand Down Expand Up @@ -308,18 +256,18 @@ const ZETA_POW_BITREV: [FieldElement; 128] = {
}

// Compute the powers of zeta
let mut pow = [FieldElement(0); 128];
let mut pow = [Elem(0); 128];
let mut i = 0;
let mut curr = 1u64;
#[allow(clippy::integer_division_remainder_used)]
while i < 128 {
pow[i] = FieldElement(curr as u16);
pow[i] = Elem(curr as u16);
i += 1;
curr = (curr * ZETA) % FieldElement::Q64;
curr = (curr * ZETA) % BaseField::QLL;
}

// Reorder the powers according to bitrev7
let mut pow_bitrev = [FieldElement(0); 128];
let mut pow_bitrev = [Elem(0); 128];
let mut i = 0;
while i < 128 {
pow_bitrev[i] = pow[bitrev7(i)];
Expand All @@ -331,13 +279,13 @@ const ZETA_POW_BITREV: [FieldElement; 128] = {
#[allow(clippy::cast_possible_truncation)]
const GAMMA: [FieldElement; 128] = {
const ZETA: u64 = 17;
let mut gamma = [FieldElement(0); 128];
let mut gamma = [Elem(0); 128];
let mut i = 0;
while i < 128 {
let zpr = ZETA_POW_BITREV[i].0 as u64;
#[allow(clippy::integer_division_remainder_used)]
let g = (zpr * zpr * ZETA) % FieldElement::Q64;
gamma[i] = FieldElement(g as u16);
let g = (zpr * zpr * ZETA) % BaseField::QLL;
gamma[i] = Elem(g as u16);
i += 1;
}
gamma
Expand All @@ -351,7 +299,7 @@ impl Mul<&NttPolynomial> for &NttPolynomial {
let mut out = NttPolynomial(Array::default());

for i in 0..128 {
let (c0, c1) = FieldElement::base_case_multiply(
let (c0, c1) = base_case_multiply(
self.0[2 * i],
self.0[2 * i + 1],
rhs.0[2 * i],
Expand Down Expand Up @@ -421,7 +369,7 @@ impl NttPolynomial {
}
}

FieldElement(3303) * &Polynomial(f)
Elem(3303) * &Polynomial(f)
}
}

Expand Down Expand Up @@ -545,9 +493,9 @@ mod test {
for (i, x) in self.0.iter().enumerate() {
for (j, y) in rhs.0.iter().enumerate() {
let (sign, index) = if i + j < 256 {
(FieldElement(1), i + j)
(Elem(1), i + j)
} else {
(FieldElement(FieldElement::Q - 1), i + j - 256)
(Elem(BaseField::Q - 1), i + j - 256)
};

out.0[index] = out.0[index] + (sign * *x * *y);
Expand All @@ -560,26 +508,26 @@ mod test {
// A polynomial with only a scalar component, to make simple test cases
fn const_ntt(x: Integer) -> NttPolynomial {
let mut p = Polynomial::default();
p.0[0] = FieldElement(x);
p.0[0] = Elem(x);
p.ntt()
}

#[test]
#[allow(clippy::cast_possible_truncation)]
fn polynomial_ops() {
let f = Polynomial(Array::from_fn(|i| FieldElement(i as Integer)));
let g = Polynomial(Array::from_fn(|i| FieldElement(2 * i as Integer)));
let sum = Polynomial(Array::from_fn(|i| FieldElement(3 * i as Integer)));
let f = Polynomial(Array::from_fn(|i| Elem(i as Integer)));
let g = Polynomial(Array::from_fn(|i| Elem(2 * i as Integer)));
let sum = Polynomial(Array::from_fn(|i| Elem(3 * i as Integer)));
assert_eq!((&f + &g), sum);
assert_eq!((&sum - &g), f);
assert_eq!(FieldElement(3) * &f, sum);
assert_eq!(Elem(3) * &f, sum);
}

#[test]
#[allow(clippy::cast_possible_truncation, clippy::similar_names)]
fn ntt() {
let f = Polynomial(Array::from_fn(|i| FieldElement(i as Integer)));
let g = Polynomial(Array::from_fn(|i| FieldElement(2 * i as Integer)));
let f = Polynomial(Array::from_fn(|i| Elem(i as Integer)));
let g = Polynomial(Array::from_fn(|i| Elem(2 * i as Integer)));
let f_hat = f.ntt();
let g_hat = g.ntt();

Expand Down Expand Up @@ -668,7 +616,7 @@ mod test {
//
// for k in $-\eta, \ldots, \eta$. The cases of interest here are \eta = 2, 3.
type Distribution = [f64; Q_SIZE];
const Q_SIZE: usize = FieldElement::Q as usize;
const Q_SIZE: usize = BaseField::Q as usize;
static CBD2: Distribution = {
let mut dist = [0.0; Q_SIZE];
dist[Q_SIZE - 2] = 1.0 / 16.0;
Expand All @@ -689,7 +637,7 @@ mod test {
dist[3] = 1.0 / 64.0;
dist
};
static UNIFORM: Distribution = [1.0 / (FieldElement::Q as f64); Q_SIZE];
static UNIFORM: Distribution = [1.0 / (BaseField::Q as f64); Q_SIZE];

fn kl_divergence(p: &Distribution, q: &Distribution) -> f64 {
p.iter()
Expand All @@ -704,7 +652,7 @@ mod test {
let mut sample_dist: Distribution = [0.0; Q_SIZE];
let bump: f64 = 1.0 / (sample.len() as f64);
for x in sample {
assert!(x.0 < FieldElement::Q);
assert!(x.0 < BaseField::Q);
assert!(ref_dist[x.0 as usize] > 0.0);

sample_dist[x.0 as usize] += bump;
Expand Down
Loading