diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 1e409b9..79b568d 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -15,18 +15,18 @@ use crate::param::{ArraySize, CbdSamplingSize}; #[cfg(feature = "zeroize")] use zeroize::Zeroize; -pub type Integer = u16; +module_lattice::define_field!(BaseField, u16, u32, u64, 3329); -module_lattice::define_field!(BaseField, Integer, u32, u64, 3329); +pub type Int = ::Int; /// An element of GF(q). -pub type FieldElement = module_lattice::algebra::Elem; +pub type Elem = module_lattice::algebra::Elem; /// An element of the ring `R_q`, i.e., a polynomial over `Z_q` of degree 255 pub type Polynomial = module_lattice::algebra::Polynomial; /// A vector of polynomials of length `K`. -pub type PolynomialVector = module_lattice::algebra::Vector; +pub type Vector = module_lattice::algebra::Vector; /// An element of the ring `T_q` i.e. a tuple of 128 elements of the direct sum components of `T_q`. pub type NttPolynomial = module_lattice::algebra::NttPolynomial; @@ -37,7 +37,7 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { xof: &'a mut dyn XofReader, data: [u8; 96], start: usize, - next: Option, + next: Option, } impl<'a> FieldElementReader<'a> { @@ -55,10 +55,10 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { out } - fn next(&mut self) -> FieldElement { + fn next(&mut self) -> Elem { if let Some(val) = self.next { self.next = None; - return FieldElement::new(val); + return Elem::new(val); } loop { @@ -71,18 +71,18 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { let b = &self.data[self.start..end]; self.start = end; - let d1 = Integer::from(b[0]) + ((Integer::from(b[1]) & 0xf) << 8); - let d2 = (Integer::from(b[1]) >> 4) + ((Integer::from(b[2]) as Integer) << 4); + let d1 = Int::from(b[0]) + ((Int::from(b[1]) & 0xf) << 8); + let d2 = (Int::from(b[1]) >> 4) + ((Int::from(b[2]) as Int) << 4); if d1 < BaseField::Q { if d2 < BaseField::Q { self.next = Some(d2); } - return FieldElement::new(d1); + return Elem::new(d1); } if d2 < BaseField::Q { - return FieldElement::new(d2); + return Elem::new(d2); } } } @@ -105,12 +105,12 @@ where Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect()) } -pub(crate) fn sample_poly_vec_cbd(sigma: &B32, start_n: u8) -> PolynomialVector +pub(crate) fn sample_poly_vec_cbd(sigma: &B32, start_n: u8) -> Vector where Eta: CbdSamplingSize, K: ArraySize, { - PolynomialVector::new(Array::from_fn(|i| { + Vector::new(Array::from_fn(|i| { let N = start_n + u8::truncate(i); let prf_output = PRF::(sigma, N); sample_poly_cbd::(&prf_output) @@ -150,7 +150,7 @@ impl Ntt for Polynomial { } } -impl Ntt for PolynomialVector { +impl Ntt for Vector { type Output = NttVector; fn ntt(&self) -> NttVector { @@ -171,7 +171,7 @@ impl NttInverse for NttPolynomial { type Output = Polynomial; fn ntt_inverse(&self) -> Polynomial { - let mut f: Array = self.0.clone(); + let mut f: Array = self.0.clone(); let mut k = 127; for len in [2, 4, 8, 16, 32, 64, 128] { @@ -187,7 +187,7 @@ impl NttInverse for NttPolynomial { } } - FieldElement::new(3303) * &Polynomial::new(f) + Elem::new(3303) * &Polynomial::new(f) } } @@ -218,13 +218,7 @@ impl MultiplyNtt for BaseField { /// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of /// modular reductions, since these are the expensive operation. #[inline] -fn base_case_multiply( - a0: FieldElement, - a1: FieldElement, - b0: FieldElement, - b1: FieldElement, - i: usize, -) -> (FieldElement, FieldElement) { +fn base_case_multiply(a0: Elem, a1: Elem, b0: Elem, b1: Elem, i: usize) -> (Elem, Elem) { let a0 = u32::from(a0.0); let a1 = u32::from(a1.0); let b0 = u32::from(b0.0); @@ -235,7 +229,7 @@ fn base_case_multiply( let c0 = BaseField::barrett_reduce(a0 * b0 + a1 * b1g); let c1 = BaseField::barrett_reduce(a0 * b1 + a1 * b0); - (FieldElement::new(c0), FieldElement::new(c1)) + (Elem::new(c0), Elem::new(c1)) } /// Since the powers of zeta used in the `NTT` and `MultiplyNTTs` are fixed, we use pre-computed @@ -251,7 +245,7 @@ fn base_case_multiply( /// The values computed here match those provided in Appendix A of FIPS 203. /// `ZETA_POW_BITREV` corresponds to the first table, and `GAMMA` to the second table. #[allow(clippy::cast_possible_truncation)] -const ZETA_POW_BITREV: [FieldElement; 128] = { +const ZETA_POW_BITREV: [Elem; 128] = { const ZETA: u64 = 17; #[allow(clippy::integer_division_remainder_used)] const fn bitrev7(x: usize) -> usize { @@ -265,18 +259,18 @@ const ZETA_POW_BITREV: [FieldElement; 128] = { } // Compute the powers of zeta - let mut pow = [FieldElement::new(0); 128]; + let mut pow = [Elem::new(0); 128]; let mut i = 0; let mut curr = 1u64; #[allow(clippy::integer_division_remainder_used)] while i < 128 { - pow[i] = FieldElement::new(curr as u16); + pow[i] = Elem::new(curr as u16); i += 1; curr = (curr * ZETA) % BaseField::QLL; } // Reorder the powers according to bitrev7 - let mut pow_bitrev = [FieldElement::new(0); 128]; + let mut pow_bitrev = [Elem::new(0); 128]; let mut i = 0; while i < 128 { pow_bitrev[i] = pow[bitrev7(i)]; @@ -286,15 +280,15 @@ const ZETA_POW_BITREV: [FieldElement; 128] = { }; #[allow(clippy::cast_possible_truncation)] -const GAMMA: [FieldElement; 128] = { +const GAMMA: [Elem; 128] = { const ZETA: u64 = 17; - let mut gamma = [FieldElement::new(0); 128]; + let mut gamma = [Elem::new(0); 128]; let mut i = 0; while i < 128 { let zpr = ZETA_POW_BITREV[i].0 as u64; #[allow(clippy::integer_division_remainder_used)] let g = (zpr * zpr * ZETA) % BaseField::QLL; - gamma[i] = FieldElement::new(g as u16); + gamma[i] = Elem::new(g as u16); i += 1; } gamma @@ -367,8 +361,8 @@ impl Mul<&NttVector> for &NttVector { } impl NttVector { - pub fn ntt_inverse(&self) -> PolynomialVector { - PolynomialVector::new(self.0.iter().map(NttInverse::ntt_inverse).collect()) + pub fn ntt_inverse(&self) -> Vector { + Vector::new(self.0.iter().map(NttInverse::ntt_inverse).collect()) } } @@ -411,9 +405,9 @@ mod test { for (i, x) in lhs.0.iter().enumerate() { for (j, y) in rhs.0.iter().enumerate() { let (sign, index) = if i + j < 256 { - (FieldElement::new(1), i + j) + (Elem::new(1), i + j) } else { - (FieldElement::new(BaseField::Q - 1), i + j - 256) + (Elem::new(BaseField::Q - 1), i + j - 256) }; out.0[index] = out.0[index] + (sign * *x * *y); @@ -423,28 +417,28 @@ mod test { } // A polynomial with only a scalar component, to make simple test cases - fn const_ntt(x: Integer) -> NttPolynomial { + fn const_ntt(x: Int) -> NttPolynomial { let mut p = Polynomial::default(); - p.0[0] = FieldElement::new(x); + p.0[0] = Elem::new(x); p.ntt() } #[test] #[allow(clippy::cast_possible_truncation)] fn polynomial_ops() { - let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer))); - let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer))); - let sum = Polynomial::new(Array::from_fn(|i| FieldElement::new(3 * i as Integer))); + let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int))); + let g = Polynomial::new(Array::from_fn(|i| Elem::new(2 * i as Int))); + let sum = Polynomial::new(Array::from_fn(|i| Elem::new(3 * i as Int))); assert_eq!((&f + &g), sum); assert_eq!((&sum - &g), f); - assert_eq!(FieldElement::new(3) * &f, sum); + assert_eq!(Elem::new(3) * &f, sum); } #[test] #[allow(clippy::cast_possible_truncation, clippy::similar_names)] fn ntt() { - let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer))); - let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer))); + let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int))); + let g = Polynomial::new(Array::from_fn(|i| Elem::new(2 * i as Int))); let f_hat = f.ntt(); let g_hat = g.ntt(); @@ -564,7 +558,7 @@ mod test { } #[allow(clippy::cast_precision_loss, clippy::large_stack_arrays)] - fn test_sample(sample: &[FieldElement], ref_dist: &Distribution) { + fn test_sample(sample: &[Elem], ref_dist: &Distribution) { // Verify data and compute the empirical distribution let mut sample_dist: Distribution = [0.0; Q_SIZE]; let bump: f64 = 1.0 / (sample.len() as f64); @@ -590,7 +584,7 @@ mod test { // Since Q ~= 2^11 and 256 == 2^8, we need 2^3 == 8 runs of 256 to get out of the bad // regime and get a meaningful measurement. let rho = B32::default(); - let sample: Array, U8> = Array::from_fn(|i| { + let sample: Array, U8> = Array::from_fn(|i| { let mut xof = XOF(&rho, 0, i as u8); sample_ntt(&mut xof).into() }); diff --git a/ml-kem/src/compress.rs b/ml-kem/src/compress.rs index e7859bc..41b5b6b 100644 --- a/ml-kem/src/compress.rs +++ b/ml-kem/src/compress.rs @@ -1,11 +1,11 @@ -use crate::algebra::{BaseField, FieldElement, Integer, Polynomial, PolynomialVector}; +use crate::algebra::{BaseField, Elem, Int, Polynomial, Vector}; use crate::param::{ArraySize, EncodingSize}; use module_lattice::{algebra::Field, util::Truncate}; // A convenience trait to allow us to associate some constants with a typenum pub trait CompressionFactor: EncodingSize { const POW2_HALF: u32; - const MASK: Integer; + const MASK: Int; const DIV_SHIFT: usize; const DIV_MUL: u64; } @@ -15,7 +15,7 @@ where T: EncodingSize, { const POW2_HALF: u32 = 1 << (T::USIZE - 1); - const MASK: Integer = ((1 as Integer) << T::USIZE) - 1; + const MASK: Int = ((1 as Int) << T::USIZE) - 1; const DIV_SHIFT: usize = 34; #[allow(clippy::integer_division_remainder_used)] const DIV_MUL: u64 = (1 << T::DIV_SHIFT) / BaseField::QLL; @@ -27,7 +27,7 @@ pub trait Compress { fn decompress(&mut self) -> &Self; } -impl Compress for FieldElement { +impl Compress for Elem { // Equation 4.5: Compress_d(x) = round((2^d / q) x) // // Here and in decompression, we leverage the following facts: @@ -68,7 +68,7 @@ impl Compress for Polynomial { } } -impl Compress for PolynomialVector { +impl Compress for Vector { fn compress(&mut self) -> &Self { for x in &mut self.0 { x.compress::(); @@ -111,7 +111,7 @@ pub(crate) mod test { let error_threshold = i32::from(Ratio::new(BaseField::Q, 1 << D::USIZE).to_integer()); for x in 0..BaseField::Q { - let mut y = FieldElement::new(x); + let mut y = Elem::new(x); y.compress::(); y.decompress::(); @@ -131,7 +131,7 @@ pub(crate) mod test { fn decompression_compression_equality() { for x in 0..(1 << D::USIZE) { - let mut y = FieldElement::new(x); + let mut y = Elem::new(x); y.decompress::(); y.compress::(); @@ -142,7 +142,7 @@ pub(crate) mod test { fn decompress_KAT() { for y in 0..(1 << D::USIZE) { let x_expected = rational_decompress::(y); - let mut x_actual = FieldElement::new(y); + let mut x_actual = Elem::new(y); x_actual.decompress::(); assert_eq!(x_expected, x_actual.0); @@ -152,7 +152,7 @@ pub(crate) mod test { fn compress_KAT() { for x in 0..BaseField::Q { let y_expected = rational_compress::(x); - let mut y_actual = FieldElement::new(x); + let mut y_actual = Elem::new(x); y_actual.compress::(); assert_eq!(y_expected, y_actual.0, "for x: {}, D: {}", x, D::USIZE); diff --git a/ml-kem/src/encode.rs b/ml-kem/src/encode.rs index 2758ff8..7664d2e 100644 --- a/ml-kem/src/encode.rs +++ b/ml-kem/src/encode.rs @@ -1,7 +1,5 @@ use crate::{ - algebra::{ - BaseField, FieldElement, Integer, NttPolynomial, NttVector, Polynomial, PolynomialVector, - }, + algebra::{BaseField, Elem, Int, NttPolynomial, NttVector, Polynomial, Vector}, param::{ArraySize, EncodedPolynomial, EncodingSize, VectorEncodingSize}, }; use array::{ @@ -10,7 +8,7 @@ use array::{ }; use module_lattice::{algebra::Field, util::Truncate}; -type DecodedValue = Array; +type DecodedValue = Array; // Algorithm 4 ByteEncode_d(F) // @@ -54,7 +52,7 @@ fn byte_decode(bytes: &EncodedPolynomial) -> DecodedValue { let x = u128::from_le_bytes(xb); for (j, vj) in v.iter_mut().enumerate() { - let val: Integer = Truncate::truncate(x >> (D::USIZE * j)); + let val: Int = Truncate::truncate(x >> (D::USIZE * j)); vj.0 = val & mask; if D::USIZE == 12 { @@ -84,7 +82,7 @@ impl Encode for Polynomial { } } -impl Encode for PolynomialVector +impl Encode for Vector where K: ArraySize, D: VectorEncodingSize, @@ -186,12 +184,12 @@ pub(crate) mod test { // Test random decode/encode and encode/decode round trips let mut rng = UnwrapErr(SysRng); - let decoded = Array::::from_fn(|_| (rng.next_u32() & 0xFFFF) as Integer); + let decoded = Array::::from_fn(|_| (rng.next_u32() & 0xFFFF) as Int); let m = match D::USIZE { 12 => BaseField::Q, - d => (1 as Integer) << d, + d => (1 as Int) << d, }; - let decoded = decoded.iter().map(|x| FieldElement::new(x % m)).collect(); + let decoded = decoded.iter().map(|x| Elem::new(x % m)).collect(); let actual_encoded = byte_encode::(&decoded); let actual_decoded = byte_decode::(&actual_encoded); @@ -204,21 +202,20 @@ pub(crate) mod test { #[test] fn byte_codec() { // The 1-bit can only represent decoded values equal to 0 or 1. - let decoded: DecodedValue = - Array::<_, U2>([FieldElement::new(0), FieldElement::new(1)]).repeat(); + let decoded: DecodedValue = Array::<_, U2>([Elem::new(0), Elem::new(1)]).repeat(); let encoded: EncodedPolynomial = Array([0xaa; 32]); byte_codec_test::(&decoded, &encoded); // For other codec widths, we use a standard sequence let decoded: DecodedValue = Array::<_, U8>([ - FieldElement::new(0), - FieldElement::new(1), - FieldElement::new(2), - FieldElement::new(3), - FieldElement::new(4), - FieldElement::new(5), - FieldElement::new(6), - FieldElement::new(7), + Elem::new(0), + Elem::new(1), + Elem::new(2), + Elem::new(3), + Elem::new(4), + Elem::new(5), + Elem::new(6), + Elem::new(7), ]) .repeat(); @@ -255,7 +252,7 @@ pub(crate) mod test { fn byte_codec_12_mod() { // DecodeBytes_12 is required to reduce mod q let encoded: EncodedPolynomial = Array([0xff; 384]); - let decoded: DecodedValue = Array([FieldElement::new(0xfff % BaseField::Q); 256]); + let decoded: DecodedValue = Array([Elem::new(0xfff % BaseField::Q); 256]); let actual_decoded = byte_decode::(&encoded); assert_eq!(actual_decoded, decoded); @@ -277,32 +274,32 @@ pub(crate) mod test { fn vector_codec() { let poly = Polynomial::new( Array::<_, U8>([ - FieldElement::new(0), - FieldElement::new(1), - FieldElement::new(2), - FieldElement::new(3), - FieldElement::new(4), - FieldElement::new(5), - FieldElement::new(6), - FieldElement::new(7), + Elem::new(0), + Elem::new(1), + Elem::new(2), + Elem::new(3), + Elem::new(4), + Elem::new(5), + Elem::new(6), + Elem::new(7), ]) .repeat(), ); // The required vector sizes are 2, 3, and 4. - let decoded: PolynomialVector = PolynomialVector::new(Array([poly, poly])); + let decoded: Vector = Vector::new(Array([poly, poly])); let encoded: EncodedPolynomialVector = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); - vector_codec_known_answer_test::>(&decoded, &encoded); + vector_codec_known_answer_test::>(&decoded, &encoded); - let decoded: PolynomialVector = PolynomialVector::new(Array([poly, poly, poly])); + let decoded: Vector = Vector::new(Array([poly, poly, poly])); let encoded: EncodedPolynomialVector = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); - vector_codec_known_answer_test::>(&decoded, &encoded); + vector_codec_known_answer_test::>(&decoded, &encoded); - let decoded: PolynomialVector = PolynomialVector::new(Array([poly, poly, poly, poly])); + let decoded: Vector = Vector::new(Array([poly, poly, poly, poly])); let encoded: EncodedPolynomialVector = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); - vector_codec_known_answer_test::>(&decoded, &encoded); + vector_codec_known_answer_test::>(&decoded, &encoded); } } diff --git a/ml-kem/src/param.rs b/ml-kem/src/param.rs index ff55d4b..d22a3c7 100644 --- a/ml-kem/src/param.rs +++ b/ml-kem/src/param.rs @@ -12,7 +12,7 @@ use crate::{ B32, - algebra::{BaseField, FieldElement, NttVector}, + algebra::{BaseField, Elem, NttVector}, encode::Encode, }; use array::{ @@ -104,7 +104,7 @@ where pub trait CbdSamplingSize: ArraySize { type SampleSize: EncodingSize; type OnesSize: ArraySize; - const ONES: Array; + const ONES: Array; } // To speed up CBD sampling, we pre-compute all the bit-manipulations: @@ -116,13 +116,13 @@ pub trait CbdSamplingSize: ArraySize { // We have to allow the use of `as` here because we can't use our nice Truncate trait, because // const functions don't support traits. #[allow(clippy::cast_possible_truncation)] -const fn ones_array() -> Array +const fn ones_array() -> Array where - U: ArraySize = [FieldElement; N]>, + U: ArraySize = [Elem; N]>, Const: ToUInt, { let max = 1 << B; - let mut out = [FieldElement::new(0); N]; + let mut out = [Elem::new(0); N]; let mut x = 0usize; while x < max { let mut y = 0usize; @@ -131,7 +131,7 @@ where let x_ones = x.count_ones() as u16; let y_ones = y.count_ones() as u16; let i = x + (y << B); - out[i] = FieldElement::new((x_ones + BaseField::Q - y_ones) % BaseField::Q); + out[i] = Elem::new((x_ones + BaseField::Q - y_ones) % BaseField::Q); y += 1; } @@ -143,13 +143,13 @@ where impl CbdSamplingSize for U2 { type SampleSize = U4; type OnesSize = U16; - const ONES: Array = ones_array::<2, 16, U16>(); + const ONES: Array = ones_array::<2, 16, U16>(); } impl CbdSamplingSize for U3 { type SampleSize = U6; type OnesSize = U64; - const ONES: Array = ones_array::<3, 64, U64>(); + const ONES: Array = ones_array::<3, 64, U64>(); } /// A `ParameterSet` captures the parameters that describe a particular instance of ML-KEM. There diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index f0ee1da..88007f2 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -1,7 +1,6 @@ use crate::B32; use crate::algebra::{ - Ntt, NttInverse, NttMatrix, NttVector, Polynomial, PolynomialVector, sample_poly_cbd, - sample_poly_vec_cbd, + Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, sample_poly_cbd, sample_poly_vec_cbd, }; use crate::compress::Compress; use crate::crypto::{G, PRF}; @@ -67,8 +66,8 @@ where // Sample pseudo-random matrix and vectors let A_hat: NttMatrix = NttMatrix::sample_uniform(&rho, false); - let s: PolynomialVector = sample_poly_vec_cbd::(&sigma, 0); - let e: PolynomialVector = sample_poly_vec_cbd::(&sigma, P::K::U8); + let s: Vector = sample_poly_vec_cbd::(&sigma, 0); + let e: Vector = sample_poly_vec_cbd::(&sigma, P::K::U8); // NTT the vectors let s_hat = s.ntt(); @@ -88,7 +87,7 @@ where pub fn decrypt(&self, ciphertext: &EncodedCiphertext

) -> B32 { let (c1, c2) = P::split_ct(ciphertext); - let mut u: PolynomialVector = Encode::::decode(c1); + let mut u: Vector = Encode::::decode(c1); u.decompress::(); let mut v: Polynomial = Encode::::decode(c2); @@ -138,7 +137,7 @@ where let A_hat_t = NttMatrix::::sample_uniform(&self.rho, true); let r_hat: NttVector = r.ntt(); - let ATr: PolynomialVector = (&A_hat_t * &r_hat).ntt_inverse(); + let ATr: Vector = (&A_hat_t * &r_hat).ntt_inverse(); let mut u = ATr + e1; let mut mu: Polynomial = Encode::::decode(message);