Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 39 additions & 45 deletions ml-kem/src/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ use crate::param::{ArraySize, CbdSamplingSize};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;

pub type Integer = u16;
module_lattice::define_field!(BaseField, u16, u32, u64, 3329);

module_lattice::define_field!(BaseField, Integer, u32, u64, 3329);
pub type Int = <BaseField as Field>::Int;

/// An element of GF(q).
pub type FieldElement = module_lattice::algebra::Elem<BaseField>;
pub type Elem = module_lattice::algebra::Elem<BaseField>;

/// An element of the ring `R_q`, i.e., a polynomial over `Z_q` of degree 255
pub type Polynomial = module_lattice::algebra::Polynomial<BaseField>;

/// A vector of polynomials of length `K`.
pub type PolynomialVector<K> = module_lattice::algebra::Vector<BaseField, K>;
pub type Vector<K> = module_lattice::algebra::Vector<BaseField, K>;

/// An element of the ring `T_q` i.e. a tuple of 128 elements of the direct sum components of `T_q`.
pub type NttPolynomial = module_lattice::algebra::NttPolynomial<BaseField>;
Expand All @@ -37,7 +37,7 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
xof: &'a mut dyn XofReader,
data: [u8; 96],
start: usize,
next: Option<Integer>,
next: Option<Int>,
}

impl<'a> FieldElementReader<'a> {
Expand All @@ -55,10 +55,10 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
out
}

fn next(&mut self) -> FieldElement {
fn next(&mut self) -> Elem {
if let Some(val) = self.next {
self.next = None;
return FieldElement::new(val);
return Elem::new(val);
}

loop {
Expand All @@ -71,18 +71,18 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
let b = &self.data[self.start..end];
self.start = end;

let d1 = Integer::from(b[0]) + ((Integer::from(b[1]) & 0xf) << 8);
let d2 = (Integer::from(b[1]) >> 4) + ((Integer::from(b[2]) as Integer) << 4);
let d1 = Int::from(b[0]) + ((Int::from(b[1]) & 0xf) << 8);
let d2 = (Int::from(b[1]) >> 4) + ((Int::from(b[2]) as Int) << 4);

if d1 < BaseField::Q {
if d2 < BaseField::Q {
self.next = Some(d2);
}
return FieldElement::new(d1);
return Elem::new(d1);
}

if d2 < BaseField::Q {
return FieldElement::new(d2);
return Elem::new(d2);
}
}
}
Expand All @@ -105,12 +105,12 @@ where
Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect())
}

pub(crate) fn sample_poly_vec_cbd<Eta, K>(sigma: &B32, start_n: u8) -> PolynomialVector<K>
pub(crate) fn sample_poly_vec_cbd<Eta, K>(sigma: &B32, start_n: u8) -> Vector<K>
where
Eta: CbdSamplingSize,
K: ArraySize,
{
PolynomialVector::new(Array::from_fn(|i| {
Vector::new(Array::from_fn(|i| {
let N = start_n + u8::truncate(i);
let prf_output = PRF::<Eta>(sigma, N);
sample_poly_cbd::<Eta>(&prf_output)
Expand Down Expand Up @@ -150,7 +150,7 @@ impl Ntt for Polynomial {
}
}

impl<K: ArraySize> Ntt for PolynomialVector<K> {
impl<K: ArraySize> Ntt for Vector<K> {
type Output = NttVector<K>;

fn ntt(&self) -> NttVector<K> {
Expand All @@ -171,7 +171,7 @@ impl NttInverse for NttPolynomial {
type Output = Polynomial;

fn ntt_inverse(&self) -> Polynomial {
let mut f: Array<FieldElement, U256> = self.0.clone();
let mut f: Array<Elem, U256> = self.0.clone();

let mut k = 127;
for len in [2, 4, 8, 16, 32, 64, 128] {
Expand All @@ -187,7 +187,7 @@ impl NttInverse for NttPolynomial {
}
}

FieldElement::new(3303) * &Polynomial::new(f)
Elem::new(3303) * &Polynomial::new(f)
}
}

Expand Down Expand Up @@ -218,13 +218,7 @@ impl MultiplyNtt for BaseField {
/// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of
/// modular reductions, since these are the expensive operation.
#[inline]
fn base_case_multiply(
a0: FieldElement,
a1: FieldElement,
b0: FieldElement,
b1: FieldElement,
i: usize,
) -> (FieldElement, FieldElement) {
fn base_case_multiply(a0: Elem, a1: Elem, b0: Elem, b1: Elem, i: usize) -> (Elem, Elem) {
let a0 = u32::from(a0.0);
let a1 = u32::from(a1.0);
let b0 = u32::from(b0.0);
Expand All @@ -235,7 +229,7 @@ fn base_case_multiply(

let c0 = BaseField::barrett_reduce(a0 * b0 + a1 * b1g);
let c1 = BaseField::barrett_reduce(a0 * b1 + a1 * b0);
(FieldElement::new(c0), FieldElement::new(c1))
(Elem::new(c0), Elem::new(c1))
}

/// Since the powers of zeta used in the `NTT` and `MultiplyNTTs` are fixed, we use pre-computed
Expand All @@ -251,7 +245,7 @@ fn base_case_multiply(
/// The values computed here match those provided in Appendix A of FIPS 203.
/// `ZETA_POW_BITREV` corresponds to the first table, and `GAMMA` to the second table.
#[allow(clippy::cast_possible_truncation)]
const ZETA_POW_BITREV: [FieldElement; 128] = {
const ZETA_POW_BITREV: [Elem; 128] = {
const ZETA: u64 = 17;
#[allow(clippy::integer_division_remainder_used)]
const fn bitrev7(x: usize) -> usize {
Expand All @@ -265,18 +259,18 @@ const ZETA_POW_BITREV: [FieldElement; 128] = {
}

// Compute the powers of zeta
let mut pow = [FieldElement::new(0); 128];
let mut pow = [Elem::new(0); 128];
let mut i = 0;
let mut curr = 1u64;
#[allow(clippy::integer_division_remainder_used)]
while i < 128 {
pow[i] = FieldElement::new(curr as u16);
pow[i] = Elem::new(curr as u16);
i += 1;
curr = (curr * ZETA) % BaseField::QLL;
}

// Reorder the powers according to bitrev7
let mut pow_bitrev = [FieldElement::new(0); 128];
let mut pow_bitrev = [Elem::new(0); 128];
let mut i = 0;
while i < 128 {
pow_bitrev[i] = pow[bitrev7(i)];
Expand All @@ -286,15 +280,15 @@ const ZETA_POW_BITREV: [FieldElement; 128] = {
};

#[allow(clippy::cast_possible_truncation)]
const GAMMA: [FieldElement; 128] = {
const GAMMA: [Elem; 128] = {
const ZETA: u64 = 17;
let mut gamma = [FieldElement::new(0); 128];
let mut gamma = [Elem::new(0); 128];
let mut i = 0;
while i < 128 {
let zpr = ZETA_POW_BITREV[i].0 as u64;
#[allow(clippy::integer_division_remainder_used)]
let g = (zpr * zpr * ZETA) % BaseField::QLL;
gamma[i] = FieldElement::new(g as u16);
gamma[i] = Elem::new(g as u16);
i += 1;
}
gamma
Expand Down Expand Up @@ -367,8 +361,8 @@ impl<K: ArraySize> Mul<&NttVector<K>> for &NttVector<K> {
}

impl<K: ArraySize> NttVector<K> {
pub fn ntt_inverse(&self) -> PolynomialVector<K> {
PolynomialVector::new(self.0.iter().map(NttInverse::ntt_inverse).collect())
pub fn ntt_inverse(&self) -> Vector<K> {
Vector::new(self.0.iter().map(NttInverse::ntt_inverse).collect())
}
}

Expand Down Expand Up @@ -411,9 +405,9 @@ mod test {
for (i, x) in lhs.0.iter().enumerate() {
for (j, y) in rhs.0.iter().enumerate() {
let (sign, index) = if i + j < 256 {
(FieldElement::new(1), i + j)
(Elem::new(1), i + j)
} else {
(FieldElement::new(BaseField::Q - 1), i + j - 256)
(Elem::new(BaseField::Q - 1), i + j - 256)
};

out.0[index] = out.0[index] + (sign * *x * *y);
Expand All @@ -423,28 +417,28 @@ mod test {
}

// A polynomial with only a scalar component, to make simple test cases
fn const_ntt(x: Integer) -> NttPolynomial {
fn const_ntt(x: Int) -> NttPolynomial {
let mut p = Polynomial::default();
p.0[0] = FieldElement::new(x);
p.0[0] = Elem::new(x);
p.ntt()
}

#[test]
#[allow(clippy::cast_possible_truncation)]
fn polynomial_ops() {
let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer)));
let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer)));
let sum = Polynomial::new(Array::from_fn(|i| FieldElement::new(3 * i as Integer)));
let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
let g = Polynomial::new(Array::from_fn(|i| Elem::new(2 * i as Int)));
let sum = Polynomial::new(Array::from_fn(|i| Elem::new(3 * i as Int)));
assert_eq!((&f + &g), sum);
assert_eq!((&sum - &g), f);
assert_eq!(FieldElement::new(3) * &f, sum);
assert_eq!(Elem::new(3) * &f, sum);
}

#[test]
#[allow(clippy::cast_possible_truncation, clippy::similar_names)]
fn ntt() {
let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer)));
let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer)));
let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
let g = Polynomial::new(Array::from_fn(|i| Elem::new(2 * i as Int)));
let f_hat = f.ntt();
let g_hat = g.ntt();

Expand Down Expand Up @@ -564,7 +558,7 @@ mod test {
}

#[allow(clippy::cast_precision_loss, clippy::large_stack_arrays)]
fn test_sample(sample: &[FieldElement], ref_dist: &Distribution) {
fn test_sample(sample: &[Elem], ref_dist: &Distribution) {
// Verify data and compute the empirical distribution
let mut sample_dist: Distribution = [0.0; Q_SIZE];
let bump: f64 = 1.0 / (sample.len() as f64);
Expand All @@ -590,7 +584,7 @@ mod test {
// Since Q ~= 2^11 and 256 == 2^8, we need 2^3 == 8 runs of 256 to get out of the bad
// regime and get a meaningful measurement.
let rho = B32::default();
let sample: Array<Array<FieldElement, U256>, U8> = Array::from_fn(|i| {
let sample: Array<Array<Elem, U256>, U8> = Array::from_fn(|i| {
let mut xof = XOF(&rho, 0, i as u8);
sample_ntt(&mut xof).into()
});
Expand Down
18 changes: 9 additions & 9 deletions ml-kem/src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::algebra::{BaseField, FieldElement, Integer, Polynomial, PolynomialVector};
use crate::algebra::{BaseField, Elem, Int, Polynomial, Vector};
use crate::param::{ArraySize, EncodingSize};
use module_lattice::{algebra::Field, util::Truncate};

// A convenience trait to allow us to associate some constants with a typenum
pub trait CompressionFactor: EncodingSize {
const POW2_HALF: u32;
const MASK: Integer;
const MASK: Int;
const DIV_SHIFT: usize;
const DIV_MUL: u64;
}
Expand All @@ -15,7 +15,7 @@ where
T: EncodingSize,
{
const POW2_HALF: u32 = 1 << (T::USIZE - 1);
const MASK: Integer = ((1 as Integer) << T::USIZE) - 1;
const MASK: Int = ((1 as Int) << T::USIZE) - 1;
const DIV_SHIFT: usize = 34;
#[allow(clippy::integer_division_remainder_used)]
const DIV_MUL: u64 = (1 << T::DIV_SHIFT) / BaseField::QLL;
Expand All @@ -27,7 +27,7 @@ pub trait Compress {
fn decompress<D: CompressionFactor>(&mut self) -> &Self;
}

impl Compress for FieldElement {
impl Compress for Elem {
// Equation 4.5: Compress_d(x) = round((2^d / q) x)
//
// Here and in decompression, we leverage the following facts:
Expand Down Expand Up @@ -68,7 +68,7 @@ impl Compress for Polynomial {
}
}

impl<K: ArraySize> Compress for PolynomialVector<K> {
impl<K: ArraySize> Compress for Vector<K> {
fn compress<D: CompressionFactor>(&mut self) -> &Self {
for x in &mut self.0 {
x.compress::<D>();
Expand Down Expand Up @@ -111,7 +111,7 @@ pub(crate) mod test {
let error_threshold = i32::from(Ratio::new(BaseField::Q, 1 << D::USIZE).to_integer());

for x in 0..BaseField::Q {
let mut y = FieldElement::new(x);
let mut y = Elem::new(x);
y.compress::<D>();
y.decompress::<D>();

Expand All @@ -131,7 +131,7 @@ pub(crate) mod test {

fn decompression_compression_equality<D: CompressionFactor>() {
for x in 0..(1 << D::USIZE) {
let mut y = FieldElement::new(x);
let mut y = Elem::new(x);
y.decompress::<D>();
y.compress::<D>();

Expand All @@ -142,7 +142,7 @@ pub(crate) mod test {
fn decompress_KAT<D: CompressionFactor>() {
for y in 0..(1 << D::USIZE) {
let x_expected = rational_decompress::<D>(y);
let mut x_actual = FieldElement::new(y);
let mut x_actual = Elem::new(y);
x_actual.decompress::<D>();

assert_eq!(x_expected, x_actual.0);
Expand All @@ -152,7 +152,7 @@ pub(crate) mod test {
fn compress_KAT<D: CompressionFactor>() {
for x in 0..BaseField::Q {
let y_expected = rational_compress::<D>(x);
let mut y_actual = FieldElement::new(x);
let mut y_actual = Elem::new(x);
y_actual.compress::<D>();

assert_eq!(y_expected, y_actual.0, "for x: {}, D: {}", x, D::USIZE);
Expand Down
Loading