From e3593c7fa47851e51a0f0dba68bab361ff449eb9 Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Thu, 29 Jan 2026 15:00:38 -0700 Subject: [PATCH] ml-kem: use `module_lattice::encode` Uses the shared implementation of encoding functionality provided by the `module-lattice` crate. This completely replaces the previous `ml_kem::encode` module, which has been removed, and its tests transferred to `module-lattice`. --- Cargo.lock | 1 + ml-kem/Cargo.toml | 2 +- ml-kem/src/algebra.rs | 2 +- ml-kem/src/encode.rs | 305 -------------------------------- ml-kem/src/lib.rs | 6 +- ml-kem/src/param.rs | 91 ++-------- ml-kem/src/pke.rs | 2 +- module-lattice/Cargo.toml | 9 +- module-lattice/src/algebra.rs | 2 +- module-lattice/src/encode.rs | 10 +- module-lattice/src/lib.rs | 6 - module-lattice/src/util.rs | 14 +- module-lattice/tests/algebra.rs | 12 +- module-lattice/tests/encode.rs | 179 ++++++++++++++++++- 14 files changed, 217 insertions(+), 424 deletions(-) delete mode 100644 ml-kem/src/encode.rs diff --git a/Cargo.lock b/Cargo.lock index 14d4374..116d85d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -791,6 +791,7 @@ dependencies = [ name = "module-lattice" version = "0.1.0-pre.0" dependencies = [ + "getrandom", "hybrid-array", "num-traits", "subtle", diff --git a/ml-kem/Cargo.toml b/ml-kem/Cargo.toml index e56cccc..9dbcdfc 100644 --- a/ml-kem/Cargo.toml +++ b/ml-kem/Cargo.toml @@ -25,7 +25,7 @@ pkcs8 = ["dep:const-oid", "dep:pkcs8"] zeroize = ["module-lattice/zeroize", "dep:zeroize"] [dependencies] -array = { package = "hybrid-array", version = "0.4.4", features = ["extra-sizes", "subtle"] } +array = { version = "0.4.4", package = "hybrid-array", features = ["extra-sizes", "subtle"] } module-lattice = { version = "0.1.0-pre.0", features = ["subtle"] } kem = "0.3.0-rc.2" rand_core = "0.10.0-rc-6" diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 0090a69..9754347 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -1,13 +1,13 @@ use array::{Array, typenum::U256}; use module_lattice::{ algebra::{Field, MultiplyNtt}, + encode::Encode, util::Truncate, }; use sha3::digest::XofReader; use crate::B32; use crate::crypto::{PRF, PrfOutput, XOF}; -use crate::encode::Encode; use crate::param::{ArraySize, CbdSamplingSize}; module_lattice::define_field!(BaseField, u16, u32, u64, 3329); diff --git a/ml-kem/src/encode.rs b/ml-kem/src/encode.rs deleted file mode 100644 index 7664d2e..0000000 --- a/ml-kem/src/encode.rs +++ /dev/null @@ -1,305 +0,0 @@ -use crate::{ - algebra::{BaseField, Elem, Int, NttPolynomial, NttVector, Polynomial, Vector}, - param::{ArraySize, EncodedPolynomial, EncodingSize, VectorEncodingSize}, -}; -use array::{ - Array, - typenum::{U256, Unsigned}, -}; -use module_lattice::{algebra::Field, util::Truncate}; - -type DecodedValue = Array; - -// Algorithm 4 ByteEncode_d(F) -// -// Note: This algorithm performs compression as well as encoding. -fn byte_encode(vals: &DecodedValue) -> EncodedPolynomial { - let val_step = D::ValueStep::USIZE; - let byte_step = D::ByteStep::USIZE; - - let mut bytes = EncodedPolynomial::::default(); - - let vc = vals.chunks(val_step); - let bc = bytes.chunks_mut(byte_step); - for (v, b) in vc.zip(bc) { - let mut x = 0u128; - for (j, vj) in v.iter().enumerate() { - x |= u128::from(vj.0) << (D::USIZE * j); - } - - let xb = x.to_le_bytes(); - b.copy_from_slice(&xb[..byte_step]); - } - - bytes -} - -// Algorithm 5 ByteDecode_d(F) -// -// Note: This function performs decompression as well as decoding. -fn byte_decode(bytes: &EncodedPolynomial) -> DecodedValue { - let val_step = D::ValueStep::USIZE; - let byte_step = D::ByteStep::USIZE; - let mask = (1 << D::USIZE) - 1; - - let mut vals = DecodedValue::default(); - - let vc = vals.chunks_mut(val_step); - let bc = bytes.chunks(byte_step); - for (v, b) in vc.zip(bc) { - let mut xb = [0u8; 16]; - xb[..byte_step].copy_from_slice(b); - - let x = u128::from_le_bytes(xb); - for (j, vj) in v.iter_mut().enumerate() { - let val: Int = Truncate::truncate(x >> (D::USIZE * j)); - vj.0 = val & mask; - - if D::USIZE == 12 { - vj.0 %= BaseField::Q; - } - } - } - - vals -} - -pub trait Encode { - type EncodedSize: ArraySize; - fn encode(&self) -> Array; - fn decode(enc: &Array) -> Self; -} - -impl Encode for Polynomial { - type EncodedSize = D::EncodedPolynomialSize; - - fn encode(&self) -> Array { - byte_encode::(&self.0) - } - - fn decode(enc: &Array) -> Self { - Self(byte_decode::(enc)) - } -} - -impl Encode for Vector -where - K: ArraySize, - D: VectorEncodingSize, -{ - type EncodedSize = D::EncodedPolynomialVectorSize; - - fn encode(&self) -> Array { - let polys = self.0.iter().map(|x| Encode::::encode(x)).collect(); - >::flatten(polys) - } - - fn decode(enc: &Array) -> Self { - let unfold = >::unflatten(enc); - Self( - unfold - .iter() - .map(|&x| >::decode(x)) - .collect(), - ) - } -} - -impl Encode for NttPolynomial { - type EncodedSize = D::EncodedPolynomialSize; - - fn encode(&self) -> Array { - byte_encode::(&self.0) - } - - fn decode(enc: &Array) -> Self { - Self(byte_decode::(enc)) - } -} - -impl Encode for NttVector -where - D: VectorEncodingSize, - K: ArraySize, -{ - type EncodedSize = D::EncodedPolynomialVectorSize; - - fn encode(&self) -> Array { - let polys = self.0.iter().map(|x| Encode::::encode(x)).collect(); - >::flatten(polys) - } - - fn decode(enc: &Array) -> Self { - let unfold = >::unflatten(enc); - Self( - unfold - .iter() - .map(|&x| >::decode(x)) - .collect(), - ) - } -} - -#[cfg(test)] -pub(crate) mod test { - use super::*; - use crate::param::EncodedPolynomialVector; - use array::typenum::{ - U1, U2, U3, U4, U5, U6, U8, U10, U11, U12, marker_traits::Zero, operator_aliases::Mod, - }; - use core::{fmt::Debug, ops::Rem}; - use getrandom::SysRng; - use module_lattice::algebra::Field; - use rand_core::{Rng, UnwrapErr}; - - // A helper trait to construct larger arrays by repeating smaller ones - trait Repeat { - fn repeat(&self) -> Array; - } - - impl Repeat for Array - where - N: ArraySize, - T: Clone, - D: ArraySize + Rem, - Mod: Zero, - { - #[allow(clippy::integer_division_remainder_used)] - fn repeat(&self) -> Array { - Array::from_fn(|i| self[i % N::USIZE].clone()) - } - } - - #[allow(clippy::integer_division_remainder_used)] - fn byte_codec_test(decoded: &DecodedValue, encoded: &EncodedPolynomial) - where - D: EncodingSize, - { - // Test known answer - let actual_encoded = byte_encode::(decoded); - assert_eq!(&actual_encoded, encoded); - - let actual_decoded = byte_decode::(encoded); - assert_eq!(&actual_decoded, decoded); - - // Test random decode/encode and encode/decode round trips - let mut rng = UnwrapErr(SysRng); - let decoded = Array::::from_fn(|_| (rng.next_u32() & 0xFFFF) as Int); - let m = match D::USIZE { - 12 => BaseField::Q, - d => (1 as Int) << d, - }; - let decoded = decoded.iter().map(|x| Elem::new(x % m)).collect(); - - let actual_encoded = byte_encode::(&decoded); - let actual_decoded = byte_decode::(&actual_encoded); - assert_eq!(actual_decoded, decoded); - - let actual_reencoded = byte_encode::(&decoded); - assert_eq!(actual_reencoded, actual_encoded); - } - - #[test] - fn byte_codec() { - // The 1-bit can only represent decoded values equal to 0 or 1. - let decoded: DecodedValue = Array::<_, U2>([Elem::new(0), Elem::new(1)]).repeat(); - let encoded: EncodedPolynomial = Array([0xaa; 32]); - byte_codec_test::(&decoded, &encoded); - - // For other codec widths, we use a standard sequence - let decoded: DecodedValue = Array::<_, U8>([ - Elem::new(0), - Elem::new(1), - Elem::new(2), - Elem::new(3), - Elem::new(4), - Elem::new(5), - Elem::new(6), - Elem::new(7), - ]) - .repeat(); - - let encoded: EncodedPolynomial = Array::<_, U4>([0x10, 0x32, 0x54, 0x76]).repeat(); - byte_codec_test::(&decoded, &encoded); - - let encoded: EncodedPolynomial = - Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); - byte_codec_test::(&decoded, &encoded); - - let encoded: EncodedPolynomial = - Array::<_, U6>([0x40, 0x20, 0x0c, 0x44, 0x61, 0x1c]).repeat(); - byte_codec_test::(&decoded, &encoded); - - let encoded: EncodedPolynomial = - Array::<_, U10>([0x00, 0x04, 0x20, 0xc0, 0x00, 0x04, 0x14, 0x60, 0xc0, 0x01]).repeat(); - byte_codec_test::(&decoded, &encoded); - - let encoded: EncodedPolynomial = Array::<_, U11>([ - 0x00, 0x08, 0x80, 0x00, 0x06, 0x40, 0x80, 0x02, 0x18, 0xe0, 0x00, - ]) - .repeat(); - byte_codec_test::(&decoded, &encoded); - - let encoded: EncodedPolynomial = Array::<_, U12>([ - 0x00, 0x10, 0x00, 0x02, 0x30, 0x00, 0x04, 0x50, 0x00, 0x06, 0x70, 0x00, - ]) - .repeat(); - byte_codec_test::(&decoded, &encoded); - } - - #[allow(clippy::integer_division_remainder_used)] - #[test] - fn byte_codec_12_mod() { - // DecodeBytes_12 is required to reduce mod q - let encoded: EncodedPolynomial = Array([0xff; 384]); - let decoded: DecodedValue = Array([Elem::new(0xfff % BaseField::Q); 256]); - - let actual_decoded = byte_decode::(&encoded); - assert_eq!(actual_decoded, decoded); - } - - fn vector_codec_known_answer_test(decoded: &T, encoded: &Array) - where - D: EncodingSize, - T: Encode + PartialEq + Debug, - { - let actual_encoded = decoded.encode(); - assert_eq!(&actual_encoded, encoded); - - let actual_decoded: T = Encode::decode(encoded); - assert_eq!(&actual_decoded, decoded); - } - - #[test] - fn vector_codec() { - let poly = Polynomial::new( - Array::<_, U8>([ - Elem::new(0), - Elem::new(1), - Elem::new(2), - Elem::new(3), - Elem::new(4), - Elem::new(5), - Elem::new(6), - Elem::new(7), - ]) - .repeat(), - ); - - // The required vector sizes are 2, 3, and 4. - let decoded: Vector = Vector::new(Array([poly, poly])); - let encoded: EncodedPolynomialVector = - Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); - vector_codec_known_answer_test::>(&decoded, &encoded); - - let decoded: Vector = Vector::new(Array([poly, poly, poly])); - let encoded: EncodedPolynomialVector = - Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); - vector_codec_known_answer_test::>(&decoded, &encoded); - - let decoded: Vector = Vector::new(Array([poly, poly, poly, poly])); - let encoded: EncodedPolynomialVector = - Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); - vector_codec_known_answer_test::>(&decoded, &encoded); - } -} diff --git a/ml-kem/src/lib.rs b/ml-kem/src/lib.rs index 15298dd..d79956f 100644 --- a/ml-kem/src/lib.rs +++ b/ml-kem/src/lib.rs @@ -57,9 +57,6 @@ mod crypto; /// Section 4.2.1. Conversion and Compression Algorithms, Compression and decompression mod compress; -/// Section 4.2.1. Conversion and Compression Algorithms, Encoding and decoding -mod encode; - /// Section 5. The K-PKE Component Scheme mod pke; @@ -78,7 +75,8 @@ pub use array; pub use ml_kem_512::MlKem512Params; pub use ml_kem_768::MlKem768Params; pub use ml_kem_1024::MlKem1024Params; -pub use param::{ArraySize, ExpandedDecapsulationKey, ParameterSet}; +pub use module_lattice::encode::ArraySize; +pub use param::{ExpandedDecapsulationKey, ParameterSet}; pub use traits::*; use array::{ diff --git a/ml-kem/src/param.rs b/ml-kem/src/param.rs index d22a3c7..f475f8d 100644 --- a/ml-kem/src/param.rs +++ b/ml-kem/src/param.rs @@ -10,96 +10,31 @@ //! know any details about object sizes. For example, `VectorEncodingSize::flatten` needs to know //! that the size of an encoded vector is `K` times the size of an encoded polynomial. +pub(crate) use module_lattice::encode::{ + ArraySize, Encode, EncodedPolynomial, EncodedPolynomialSize, EncodedVectorSize, EncodingSize, + VectorEncodingSize, +}; + use crate::{ B32, algebra::{BaseField, Elem, NttVector}, - encode::Encode, }; use array::{ Array, typenum::{ - Const, ToUInt, U0, U2, U3, U4, U6, U8, U12, U16, U32, U64, U384, - operator_aliases::{Gcf, Prod, Quot, Sum}, - type_operators::Gcd, + Const, ToUInt, U0, U2, U3, U4, U6, U12, U16, U32, U64, U384, + operator_aliases::{Prod, Sum}, }, }; use core::{ fmt::Debug, ops::{Add, Div, Mul, Rem, Sub}, }; -use module_lattice::{ - algebra::Field, - util::{Flatten, Unflatten}, -}; +use module_lattice::algebra::Field; #[cfg(doc)] use crate::Seed; -/// An array length with other useful properties -pub trait ArraySize: array::ArraySize + PartialEq + Debug {} - -impl ArraySize for T where T: array::ArraySize + PartialEq + Debug {} - -/// An integer that can be used as a length for encoded values. -pub trait EncodingSize: ArraySize { - type EncodedPolynomialSize: ArraySize; - type ValueStep: ArraySize; - type ByteStep: ArraySize; -} - -type EncodingUnit = Quot, Gcf>; - -pub type EncodedPolynomialSize = ::EncodedPolynomialSize; -pub type EncodedPolynomial = Array>; - -impl EncodingSize for D -where - D: ArraySize + Mul + Gcd + Mul, - Prod: ArraySize, - Prod: Div>, - EncodingUnit: Div + Div, - Quot, D>: ArraySize, - Quot, U8>: ArraySize, -{ - type EncodedPolynomialSize = Prod; - type ValueStep = Quot, D>; - type ByteStep = Quot, U8>; -} - -/// An integer that can describe encoded vectors. -pub trait VectorEncodingSize: EncodingSize -where - K: ArraySize, -{ - type EncodedPolynomialVectorSize: ArraySize; - - fn flatten(polys: Array, K>) -> EncodedPolynomialVector; - fn unflatten(vec: &EncodedPolynomialVector) -> Array<&EncodedPolynomial, K>; -} - -pub type EncodedPolynomialVectorSize = - >::EncodedPolynomialVectorSize; -pub type EncodedPolynomialVector = Array>; - -impl VectorEncodingSize for D -where - D: EncodingSize, - K: ArraySize, - D::EncodedPolynomialSize: Mul, - Prod: - ArraySize + Div + Rem, -{ - type EncodedPolynomialVectorSize = Prod; - - fn flatten(polys: Array, K>) -> EncodedPolynomialVector { - polys.flatten() - } - - fn unflatten(vec: &EncodedPolynomialVector) -> Array<&EncodedPolynomial, K> { - vec.unflatten() - } -} - /// An integer that describes a bit length to be used in CBD sampling pub trait CbdSamplingSize: ArraySize { type SampleSize: EncodingSize; @@ -173,7 +108,7 @@ pub trait ParameterSet: Default + Clone + Debug + PartialEq { type Dv: EncodingSize; } -type EncodedUSize

= EncodedPolynomialVectorSize<

::Du,

::K>; +type EncodedUSize

= EncodedVectorSize<

::Du,

::K>; type EncodedVSize

= EncodedPolynomialSize<

::Dv>; type EncodedU

= Array>; @@ -208,11 +143,11 @@ where EncodedUSize

: Add>, Sum, EncodedVSize

>: ArraySize + Sub, Output = EncodedVSize

>, - EncodedPolynomialVectorSize: Add, - Sum, U32>: - ArraySize + Sub, Output = U32>, + EncodedVectorSize: Add, + Sum, U32>: + ArraySize + Sub, Output = U32>, { - type NttVectorSize = EncodedPolynomialVectorSize; + type NttVectorSize = EncodedVectorSize; type EncryptionKeySize = Sum; type CiphertextSize = Sum, EncodedVSize

>; diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index cb0202e..0278d2f 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -5,10 +5,10 @@ use crate::algebra::{ }; use crate::compress::Compress; use crate::crypto::{G, PRF}; -use crate::encode::Encode; use crate::param::{EncodedCiphertext, EncodedDecryptionKey, EncodedEncryptionKey, PkeParams}; use array::typenum::{U1, Unsigned}; use kem::InvalidKey; +use module_lattice::encode::Encode; use subtle::{Choice, ConstantTimeEq}; #[cfg(feature = "zeroize")] diff --git a/module-lattice/Cargo.toml b/module-lattice/Cargo.toml index d224a4c..d923ef5 100644 --- a/module-lattice/Cargo.toml +++ b/module-lattice/Cargo.toml @@ -16,13 +16,16 @@ categories = ["cryptography", "no-std"] keywords = ["crypto", "kyber", "lattice", "post-quantum"] [dependencies] -hybrid-array = { version = "0.4", features = ["extra-sizes"] } +array = { version = "0.4", package = "hybrid-array", features = ["extra-sizes"] } num-traits = { version = "0.2", default-features = false } # optional dependencies subtle = { version = "2", optional = true, default-features = false } zeroize = { version = "1.8.1", optional = true, default-features = false } +[dev-dependencies] +getrandom = { version = "0.4.0-rc.1", features = ["sys_rng"] } + [features] -subtle = ["dep:subtle"] -zeroize = ["hybrid-array/zeroize", "dep:zeroize"] +subtle = ["dep:subtle", "array/subtle"] +zeroize = ["array/zeroize", "dep:zeroize"] diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs index f54b8bd..a649da2 100644 --- a/module-lattice/src/algebra.rs +++ b/module-lattice/src/algebra.rs @@ -1,8 +1,8 @@ use super::util::Truncate; +use array::{Array, ArraySize, typenum::U256}; use core::fmt::Debug; use core::ops::{Add, Mul, Neg, Sub}; -use hybrid_array::{Array, ArraySize, typenum::U256}; use num_traits::PrimInt; #[cfg(feature = "subtle")] diff --git a/module-lattice/src/encode.rs b/module-lattice/src/encode.rs index 2de525b..b6a2048 100644 --- a/module-lattice/src/encode.rs +++ b/module-lattice/src/encode.rs @@ -1,18 +1,18 @@ -use core::fmt::Debug; -use core::ops::{Div, Mul, Rem}; -use hybrid_array::{ +use array::{ Array, typenum::{Gcd, Gcf, Prod, Quot, U0, U8, U32, U256, Unsigned}, }; +use core::fmt::Debug; +use core::ops::{Div, Mul, Rem}; use num_traits::One; use super::algebra::{Elem, Field, NttPolynomial, NttVector, Polynomial, Vector}; use super::util::{Flatten, Truncate, Unflatten}; /// An array length with other useful properties -pub trait ArraySize: hybrid_array::ArraySize + PartialEq + Debug {} +pub trait ArraySize: array::ArraySize + PartialEq + Debug {} -impl ArraySize for T where T: hybrid_array::ArraySize + PartialEq + Debug {} +impl ArraySize for T where T: array::ArraySize + PartialEq + Debug {} /// An integer that can describe encoded polynomials. pub trait EncodingSize: ArraySize { diff --git a/module-lattice/src/lib.rs b/module-lattice/src/lib.rs index 874d1e1..4c8c30f 100644 --- a/module-lattice/src/lib.rs +++ b/module-lattice/src/lib.rs @@ -9,12 +9,6 @@ #![warn(clippy::pedantic)] // Be pedantic by default #![warn(clippy::integer_division_remainder_used)] // Be judicious about using `/` and `%` -// XXX(RLB) There are no unit tests in this crate right now, because the algebra and encode/decode -// routines all require a field, and the concrete field definitions are down in the dependent -// modules. Maybe we should pull the field definitions up into this module so that we can verify -// that everything works. That might also let us make private some of the tools used to build -// things up. - /// Linear algebra with degree-256 polynomials over a prime-order field, vectors of such /// polynomials, and NTT polynomials / vectors pub mod algebra; diff --git a/module-lattice/src/util.rs b/module-lattice/src/util.rs index e31312b..7afe59f 100644 --- a/module-lattice/src/util.rs +++ b/module-lattice/src/util.rs @@ -1,10 +1,12 @@ -use core::mem::ManuallyDrop; -use core::ops::{Div, Mul, Rem}; -use core::ptr; -use hybrid_array::{ +use array::{ Array, ArraySize, typenum::{Prod, Quot, U0, Unsigned}, }; +use core::{ + mem::ManuallyDrop, + ops::{Div, Mul, Rem}, + ptr, +}; /// Safely truncate an unsigned integer value to shorter representation pub trait Truncate { @@ -115,9 +117,9 @@ where #[cfg(test)] mod test { use super::*; - use hybrid_array::{ + use array::{ Array, - typenum::{U2, U5}, + sizes::{U2, U5}, }; #[test] diff --git a/module-lattice/tests/algebra.rs b/module-lattice/tests/algebra.rs index 2c0ce3f..1897447 100644 --- a/module-lattice/tests/algebra.rs +++ b/module-lattice/tests/algebra.rs @@ -1,6 +1,6 @@ //! Tests for the `algebra` module. -use hybrid_array::typenum::U2; +use array::typenum::U2; use module_lattice::algebra::{ Elem, Field, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector, }; @@ -318,9 +318,9 @@ fn ntt_polynomial_scalar_multiplication() { #[test] fn ntt_polynomial_from_array() { - use hybrid_array::Array; + use array::Array; - let coeffs: Array, hybrid_array::typenum::U256> = + let coeffs: Array, array::typenum::U256> = core::array::from_fn(|i| Elem::new((i % 3329) as u16)).into(); let p: NttPolynomial = coeffs.into(); @@ -328,7 +328,7 @@ fn ntt_polynomial_from_array() { assert_eq!(p.0[1].0, 1); // Convert back - let arr: Array, hybrid_array::typenum::U256> = p.into(); + let arr: Array, array::typenum::U256> = p.into(); assert_eq!(arr[0].0, coeffs[0].0); } @@ -435,8 +435,8 @@ fn ntt_matrix_equality() { #[test] fn ntt_polynomial_into_array() { - use hybrid_array::Array; - use hybrid_array::typenum::U256; + use array::Array; + use array::typenum::U256; let p = make_test_ntt_polynomial::(100); diff --git a/module-lattice/tests/encode.rs b/module-lattice/tests/encode.rs index 31c30a4..0a5028d 100644 --- a/module-lattice/tests/encode.rs +++ b/module-lattice/tests/encode.rs @@ -1,16 +1,139 @@ //! Tests for the `encode` module. -use hybrid_array::typenum::{U1, U4, U10, U12}; -use module_lattice::algebra::{Elem, NttPolynomial, NttVector, Polynomial, Vector}; -use module_lattice::encode::{Encode, byte_decode, byte_encode}; +#![allow(clippy::integer_division_remainder_used)] + +use array::sizes::U3; +use array::typenum::{Mod, Zero}; +use array::{ + Array, + sizes::{U1, U2, U4, U5, U6, U8, U10, U11, U12, U256}, +}; +use getrandom::{ + SysRng, + rand_core::{Rng, UnwrapErr}, +}; +use module_lattice::encode::EncodedVector; +use module_lattice::{ + algebra::{Elem, Field, NttPolynomial, NttVector, Polynomial, Vector}, + encode::{ArraySize, Encode, EncodedPolynomial, EncodingSize, byte_decode, byte_encode}, +}; +use std::fmt::Debug; +use std::ops::Rem; // Field used by ML-KEM. module_lattice::define_field!(KyberField, u16, u32, u64, 3329); +type Int = u16; +type DecodedValue = module_lattice::encode::DecodedValue; + +/// A helper trait to construct larger arrays by repeating smaller ones +trait Repeat { + fn repeat(&self) -> Array; +} + +impl Repeat for Array +where + N: ArraySize, + T: Clone, + D: ArraySize + Rem, + Mod: Zero, +{ + #[allow(clippy::integer_division_remainder_used)] + fn repeat(&self) -> Array { + Array::from_fn(|i| self[i % N::USIZE].clone()) + } +} + // ======================================== -// byte_encode / byte_decode round-trip tests +// byte_encode / byte_decode tests // ======================================== +#[allow(clippy::integer_division_remainder_used)] +fn byte_codec_test(decoded: &DecodedValue, encoded: &EncodedPolynomial) +where + D: EncodingSize, +{ + // Test known answer + let actual_encoded = byte_encode::(decoded); + assert_eq!(&actual_encoded, encoded); + + let actual_decoded = byte_decode::(encoded); + assert_eq!(&actual_decoded, decoded); + + // Test random decode/encode and encode/decode round trips + let mut rng = UnwrapErr(SysRng); + let decoded = Array::::from_fn(|_| (rng.next_u32() & 0xFFFF) as Int); + let m = match D::USIZE { + 12 => KyberField::Q, + d => (1 as Int) << d, + }; + let decoded = decoded.iter().map(|x| Elem::new(x % m)).collect(); + + let actual_encoded = byte_encode::(&decoded); + let actual_decoded = byte_decode::(&actual_encoded); + assert_eq!(actual_decoded, decoded); + + let actual_reencoded = byte_encode::(&decoded); + assert_eq!(actual_reencoded, actual_encoded); +} + +#[test] +fn byte_codec() { + // The 1-bit can only represent decoded values equal to 0 or 1. + let decoded: DecodedValue = Array::<_, U2>([Elem::new(0), Elem::new(1)]).repeat(); + let encoded: EncodedPolynomial = Array([0xaa; 32]); + byte_codec_test::(&decoded, &encoded); + + // For other codec widths, we use a standard sequence + let decoded: DecodedValue = Array::<_, U8>([ + Elem::new(0), + Elem::new(1), + Elem::new(2), + Elem::new(3), + Elem::new(4), + Elem::new(5), + Elem::new(6), + Elem::new(7), + ]) + .repeat(); + + let encoded: EncodedPolynomial = Array::<_, U4>([0x10, 0x32, 0x54, 0x76]).repeat(); + byte_codec_test::(&decoded, &encoded); + + let encoded: EncodedPolynomial = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); + byte_codec_test::(&decoded, &encoded); + + let encoded: EncodedPolynomial = + Array::<_, U6>([0x40, 0x20, 0x0c, 0x44, 0x61, 0x1c]).repeat(); + byte_codec_test::(&decoded, &encoded); + + let encoded: EncodedPolynomial = + Array::<_, U10>([0x00, 0x04, 0x20, 0xc0, 0x00, 0x04, 0x14, 0x60, 0xc0, 0x01]).repeat(); + byte_codec_test::(&decoded, &encoded); + + let encoded: EncodedPolynomial = Array::<_, U11>([ + 0x00, 0x08, 0x80, 0x00, 0x06, 0x40, 0x80, 0x02, 0x18, 0xe0, 0x00, + ]) + .repeat(); + byte_codec_test::(&decoded, &encoded); + + let encoded: EncodedPolynomial = Array::<_, U12>([ + 0x00, 0x10, 0x00, 0x02, 0x30, 0x00, 0x04, 0x50, 0x00, 0x06, 0x70, 0x00, + ]) + .repeat(); + byte_codec_test::(&decoded, &encoded); +} + +#[test] +fn byte_codec_12_mod() { + // DecodeBytes_12 is required to reduce mod q + let encoded: EncodedPolynomial = Array([0xff; 384]); + let decoded: DecodedValue = Array([Elem::new(0xfff % KyberField::Q); 256]); + + let actual_decoded = byte_decode::(&encoded); + assert_eq!(actual_decoded, decoded); +} + #[test] fn byte_encode_decode_d1_roundtrip() { // D=1: Single bit encoding @@ -136,9 +259,51 @@ fn polynomial_encode_decode_d12() { // Vector encoding tests // ======================================== +fn vector_codec_known_answer_test(decoded: &T, encoded: &Array) +where + D: EncodingSize, + T: Encode + PartialEq + Debug, +{ + let actual_encoded = decoded.encode(); + assert_eq!(&actual_encoded, encoded); + + let actual_decoded: T = Encode::decode(encoded); + assert_eq!(&actual_decoded, decoded); +} + +#[test] +fn vector_codec() { + let poly = Polynomial::new( + Array::<_, U8>([ + Elem::new(0), + Elem::new(1), + Elem::new(2), + Elem::new(3), + Elem::new(4), + Elem::new(5), + Elem::new(6), + Elem::new(7), + ]) + .repeat(), + ); + + // The required vector sizes are 2, 3, and 4. + let decoded: Vector = Vector::new(Array([poly, poly])); + let encoded: EncodedVector = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); + vector_codec_known_answer_test::>(&decoded, &encoded); + + let decoded: Vector = Vector::new(Array([poly, poly, poly])); + let encoded: EncodedVector = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); + vector_codec_known_answer_test::>(&decoded, &encoded); + + let decoded: Vector = Vector::new(Array([poly, poly, poly, poly])); + let encoded: EncodedVector = Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat(); + vector_codec_known_answer_test::>(&decoded, &encoded); +} + #[test] fn vector_encode_decode_roundtrip() { - use hybrid_array::typenum::U2; + use array::typenum::U2; let coeffs1: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 3) as u16 % 16)); let coeffs2: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 5) as u16 % 16)); @@ -186,7 +351,7 @@ fn ntt_polynomial_encode_decode_d12() { #[test] fn ntt_vector_encode_decode_roundtrip() { - use hybrid_array::typenum::U2; + use array::typenum::U2; let coeffs1: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 3) as u16 % 16)); let coeffs2: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 5) as u16 % 16)); @@ -227,7 +392,7 @@ fn encoded_polynomial_size_d12() { #[test] fn encoded_vector_size() { - use hybrid_array::typenum::U3; + use array::typenum::U3; // D=4, K=3: 128 bytes per polynomial * 3 = 384 bytes let coeffs = [Elem::::new(0); 256];