Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 76 additions & 139 deletions ml-kem/src/algebra.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use array::{Array, typenum::U256};
use core::ops::{Add, Mul, Sub};
use core::ops::{Add, Mul};
use module_lattice::{algebra::Field, util::Truncate};
use sha3::digest::XofReader;
use subtle::{Choice, ConstantTimeEq};
Expand All @@ -19,7 +19,13 @@ module_lattice::define_field!(BaseField, Integer, u32, u64, 3329);
/// An element of GF(q).
pub type FieldElement = module_lattice::algebra::Elem<BaseField>;

// Algorithm 11. BaseCaseMultiply
/// An element of the ring `R_q`, i.e., a polynomial over `Z_q` of degree 255
pub type Polynomial = module_lattice::algebra::Polynomial<BaseField>;

/// A vector of polynomials of length `K`.
pub type PolynomialVector<K> = module_lattice::algebra::Vector<BaseField, K>;

// Algorithm 12. BaseCaseMultiply
//
// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of
// modular reductions, since these are the expensive operation.
Expand All @@ -43,90 +49,29 @@ fn base_case_multiply(
(FieldElement::new(c0), FieldElement::new(c1))
}

/// An element of the ring `R_q`, i.e., a polynomial over `Z_q` of degree 255
#[derive(Clone, Copy, Default, Debug, PartialEq)]
pub struct Polynomial(pub Array<FieldElement, U256>);

impl Add<&Polynomial> for &Polynomial {
type Output = Polynomial;

fn add(self, rhs: &Polynomial) -> Polynomial {
Polynomial(
self.0
.iter()
.zip(rhs.0.iter())
.map(|(&x, &y)| x + y)
.collect(),
)
}
}

impl Sub<&Polynomial> for &Polynomial {
type Output = Polynomial;

fn sub(self, rhs: &Polynomial) -> Polynomial {
Polynomial(
self.0
.iter()
.zip(rhs.0.iter())
.map(|(&x, &y)| x - y)
.collect(),
)
}
}

impl Mul<&Polynomial> for FieldElement {
type Output = Polynomial;

fn mul(self, rhs: &Polynomial) -> Polynomial {
Polynomial(rhs.0.iter().map(|&x| self * x).collect())
}
}

impl Polynomial {
// Algorithm 7. SamplePolyCBD_eta(B)
//
// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in
// ByteDecode. We decode the PRF output into integers with eta bits, then use
// `count_ones` to perform the summation described in the algorithm.
pub fn sample_cbd<Eta>(B: &PrfOutput<Eta>) -> Self
where
Eta: CbdSamplingSize,
{
let vals: Polynomial = Encode::<Eta::SampleSize>::decode(B);
Self(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect())
}
}

/// A vector of polynomials of length `k`
#[derive(Clone, Default, Debug, PartialEq)]
pub struct PolynomialVector<K: ArraySize>(pub Array<Polynomial, K>);

impl<K: ArraySize> Add<PolynomialVector<K>> for PolynomialVector<K> {
type Output = PolynomialVector<K>;

fn add(self, rhs: PolynomialVector<K>) -> PolynomialVector<K> {
PolynomialVector(
self.0
.iter()
.zip(rhs.0.iter())
.map(|(x, y)| x + y)
.collect(),
)
}
// Algorithm 8. SamplePolyCBD_eta(B)
//
// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in
// ByteDecode. We decode the PRF output into integers with eta bits, then use
// `count_ones` to perform the summation described in the algorithm.
pub(crate) fn sample_poly_cbd<Eta>(B: &PrfOutput<Eta>) -> Polynomial
where
Eta: CbdSamplingSize,
{
let vals: Polynomial = Encode::<Eta::SampleSize>::decode(B);
Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect())
}

impl<K: ArraySize> PolynomialVector<K> {
pub fn sample_cbd<Eta>(sigma: &B32, start_n: u8) -> Self
where
Eta: CbdSamplingSize,
{
Self(Array::from_fn(|i| {
let N = start_n + u8::truncate(i);
let prf_output = PRF::<Eta>(sigma, N);
Polynomial::sample_cbd::<Eta>(&prf_output)
}))
}
pub(crate) fn sample_poly_vec_cbd<Eta, K>(sigma: &B32, start_n: u8) -> PolynomialVector<K>
where
Eta: CbdSamplingSize,
K: ArraySize,
{
PolynomialVector::new(Array::from_fn(|i| {
let N = start_n + u8::truncate(i);
let prf_output = PRF::<Eta>(sigma, N);
sample_poly_cbd::<Eta>(&prf_output)
}))
}

/// An element of the ring `T_q`, i.e., a tuple of 128 elements of the direct sum components of `T_q`.
Expand Down Expand Up @@ -162,7 +107,7 @@ impl Add<&NttPolynomial> for &NttPolynomial {
}
}

// Algorithm 6. SampleNTT (lines 4-13)
// Algorithm 7. SampleNTT (lines 4-13)
struct FieldElementReader<'a> {
xof: &'a mut dyn XofReader,
data: [u8; 96],
Expand Down Expand Up @@ -219,7 +164,7 @@ impl<'a> FieldElementReader<'a> {
}

impl NttPolynomial {
// Algorithm 6 SampleNTT(B)
// Algorithm 7 SampleNTT(B)
pub fn sample_uniform(B: &mut impl XofReader) -> Self {
let mut reader = FieldElementReader::new(B);
Self(Array::from_fn(|_| reader.next()))
Expand Down Expand Up @@ -288,7 +233,7 @@ const GAMMA: [FieldElement; 128] = {
gamma
};

// Algorithm 10. MuliplyNTTs
// Algorithm 11. MuliplyNTTs
impl Mul<&NttPolynomial> for &NttPolynomial {
type Output = NttPolynomial;

Expand Down Expand Up @@ -324,30 +269,32 @@ impl From<NttPolynomial> for Array<FieldElement, U256> {
}
}

// Algorithm 8. NTT
impl Polynomial {
pub fn ntt(&self) -> NttPolynomial {
let mut k = 1;
// Algorithm 9. NTT
pub(crate) fn ntt(poly: &Polynomial) -> NttPolynomial {
let mut k = 1;

let mut f = self.0;
for len in [128, 64, 32, 16, 8, 4, 2] {
for start in (0..256).step_by(2 * len) {
let zeta = ZETA_POW_BITREV[k];
k += 1;
let mut f = poly.0;
for len in [128, 64, 32, 16, 8, 4, 2] {
for start in (0..256).step_by(2 * len) {
let zeta = ZETA_POW_BITREV[k];
k += 1;

for j in start..(start + len) {
let t = zeta * f[j + len];
f[j + len] = f[j] - t;
f[j] = f[j] + t;
}
for j in start..(start + len) {
let t = zeta * f[j + len];
f[j + len] = f[j] - t;
f[j] = f[j] + t;
}
}

f.into()
}

f.into()
}

pub(crate) fn ntt_vector<K: ArraySize>(poly: &PolynomialVector<K>) -> NttVector<K> {
NttVector(poly.0.iter().map(ntt).collect())
}

// Algorithm 9. NTT^{-1}
// Algorithm 10. NTT^{-1}
impl NttPolynomial {
pub fn ntt_inverse(&self) -> Polynomial {
let mut f: Array<FieldElement, U256> = self.0.clone();
Expand All @@ -366,7 +313,7 @@ impl NttPolynomial {
}
}

FieldElement::new(3303) * &Polynomial(f)
FieldElement::new(3303) * &Polynomial::new(f)
}
}

Expand Down Expand Up @@ -436,15 +383,9 @@ impl<K: ArraySize> Mul<&NttVector<K>> for &NttVector<K> {
}
}

impl<K: ArraySize> PolynomialVector<K> {
pub fn ntt(&self) -> NttVector<K> {
NttVector(self.0.iter().map(Polynomial::ntt).collect())
}
}

impl<K: ArraySize> NttVector<K> {
pub fn ntt_inverse(&self) -> PolynomialVector<K> {
PolynomialVector(self.0.iter().map(NttPolynomial::ntt_inverse).collect())
PolynomialVector::new(self.0.iter().map(NttPolynomial::ntt_inverse).collect())
}
}

Expand Down Expand Up @@ -482,39 +423,35 @@ mod test {
use module_lattice::util::Flatten;

// Multiplication in R_q, modulo X^256 + 1
impl Mul<&Polynomial> for &Polynomial {
type Output = Polynomial;

fn mul(self, rhs: &Polynomial) -> Self::Output {
let mut out = Self::Output::default();
for (i, x) in self.0.iter().enumerate() {
for (j, y) in rhs.0.iter().enumerate() {
let (sign, index) = if i + j < 256 {
(FieldElement::new(1), i + j)
} else {
(FieldElement::new(BaseField::Q - 1), i + j - 256)
};

out.0[index] = out.0[index] + (sign * *x * *y);
}
fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial {
let mut out = Polynomial::default();
for (i, x) in lhs.0.iter().enumerate() {
for (j, y) in rhs.0.iter().enumerate() {
let (sign, index) = if i + j < 256 {
(FieldElement::new(1), i + j)
} else {
(FieldElement::new(BaseField::Q - 1), i + j - 256)
};

out.0[index] = out.0[index] + (sign * *x * *y);
}
out
}
out
}

// A polynomial with only a scalar component, to make simple test cases
fn const_ntt(x: Integer) -> NttPolynomial {
let mut p = Polynomial::default();
p.0[0] = FieldElement::new(x);
p.ntt()
super::ntt(&p)
}

#[test]
#[allow(clippy::cast_possible_truncation)]
fn polynomial_ops() {
let f = Polynomial(Array::from_fn(|i| FieldElement::new(i as Integer)));
let g = Polynomial(Array::from_fn(|i| FieldElement::new(2 * i as Integer)));
let sum = Polynomial(Array::from_fn(|i| FieldElement::new(3 * i as Integer)));
let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer)));
let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer)));
let sum = Polynomial::new(Array::from_fn(|i| FieldElement::new(3 * i as Integer)));
assert_eq!((&f + &g), sum);
assert_eq!((&sum - &g), f);
assert_eq!(FieldElement::new(3) * &f, sum);
Expand All @@ -523,10 +460,10 @@ mod test {
#[test]
#[allow(clippy::cast_possible_truncation, clippy::similar_names)]
fn ntt() {
let f = Polynomial(Array::from_fn(|i| FieldElement::new(i as Integer)));
let g = Polynomial(Array::from_fn(|i| FieldElement::new(2 * i as Integer)));
let f_hat = f.ntt();
let g_hat = g.ntt();
let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer)));
let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer)));
let f_hat = super::ntt(&f);
let g_hat = super::ntt(&g);

// Verify that NTT and NTT^-1 are actually inverses
let f_unhat = f_hat.ntt_inverse();
Expand All @@ -539,7 +476,7 @@ mod test {
assert_eq!(fg, fg_unhat);

// Verify that NTT is a homomorphism with regard to multiplication
let fg = &f * &g;
let fg = poly_mul(&f, &g);
let f_hat_g_hat = &f_hat * &g_hat;
let fg_unhat = f_hat_g_hat.ntt_inverse();
assert_eq!(fg, fg_unhat);
Expand Down Expand Up @@ -683,13 +620,13 @@ mod test {
// Eta = 2
let sigma = B32::default();
let prf_output = PRF::<U2>(&sigma, 0);
let sample = Polynomial::sample_cbd::<U2>(&prf_output).0;
let sample = super::sample_poly_cbd::<U2>(&prf_output).0;
test_sample(&sample, &CBD2);

// Eta = 3
let sigma = B32::default();
let prf_output = PRF::<U3>(&sigma, 0);
let sample = Polynomial::sample_cbd::<U3>(&prf_output).0;
let sample = super::sample_poly_cbd::<U3>(&prf_output).0;
test_sample(&sample, &CBD3);
}
}
8 changes: 4 additions & 4 deletions ml-kem/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ pub(crate) mod test {

#[test]
fn vector_codec() {
let poly = Polynomial(
let poly = Polynomial::new(
Array::<_, U8>([
FieldElement::new(0),
FieldElement::new(1),
Expand All @@ -290,17 +290,17 @@ pub(crate) mod test {
);

// The required vector sizes are 2, 3, and 4.
let decoded: PolynomialVector<U2> = PolynomialVector(Array([poly, poly]));
let decoded: PolynomialVector<U2> = PolynomialVector::new(Array([poly, poly]));
let encoded: EncodedPolynomialVector<U5, U2> =
Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
vector_codec_known_answer_test::<U5, PolynomialVector<U2>>(&decoded, &encoded);

let decoded: PolynomialVector<U3> = PolynomialVector(Array([poly, poly, poly]));
let decoded: PolynomialVector<U3> = PolynomialVector::new(Array([poly, poly, poly]));
let encoded: EncodedPolynomialVector<U5, U3> =
Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
vector_codec_known_answer_test::<U5, PolynomialVector<U3>>(&decoded, &encoded);

let decoded: PolynomialVector<U4> = PolynomialVector(Array([poly, poly, poly, poly]));
let decoded: PolynomialVector<U4> = PolynomialVector::new(Array([poly, poly, poly, poly]));
let encoded: EncodedPolynomialVector<U5, U4> =
Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
vector_codec_known_answer_test::<U5, PolynomialVector<U4>>(&decoded, &encoded);
Expand Down
Loading