diff --git a/Cargo.lock b/Cargo.lock index ec5ec04..d372aa4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -776,6 +776,7 @@ dependencies = [ "hex-literal", "hybrid-array", "kem", + "module-lattice", "num-rational", "pkcs8", "rand_core", diff --git a/Cargo.toml b/Cargo.toml index 4d326c7..254da3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,4 @@ debug = true [patch.crates-io] ml-kem = { path = "./ml-kem" } +module-lattice = { path = "./module-lattice" } diff --git a/ml-kem/Cargo.toml b/ml-kem/Cargo.toml index b4e3117..de4a1fd 100644 --- a/ml-kem/Cargo.toml +++ b/ml-kem/Cargo.toml @@ -26,6 +26,7 @@ hazmat = [] [dependencies] array = { package = "hybrid-array", version = "0.4.4", features = ["extra-sizes", "subtle"] } +module-lattice = "0.1.0-pre.0" kem = "0.3.0-rc.2" rand_core = "0.10.0-rc-6" sha3 = { version = "0.11.0-rc.3", default-features = false } diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 7080a3b..cb657af 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -1,12 +1,13 @@ use array::{Array, typenum::U256}; use core::ops::{Add, Mul, Sub}; +use module_lattice::util::Truncate; use sha3::digest::XofReader; use subtle::{Choice, ConstantTimeEq}; use crate::crypto::{PRF, PrfOutput, XOF}; use crate::encode::Encode; use crate::param::{ArraySize, CbdSamplingSize}; -use crate::util::{B32, Truncate}; +use crate::util::B32; #[cfg(feature = "zeroize")] use zeroize::Zeroize; @@ -46,9 +47,9 @@ impl FieldElement { fn barrett_reduce(x: u32) -> u16 { let product = u64::from(x) * Self::BARRETT_MULTIPLIER; - let quotient = (product >> Self::BARRETT_SHIFT).truncate(); + let quotient: u32 = Truncate::truncate(product >> Self::BARRETT_SHIFT); let remainder = x - quotient * Self::Q32; - Self::small_reduce(remainder.truncate()) + Self::small_reduce(Truncate::truncate(remainder)) } // Algorithm 11. BaseCaseMultiply @@ -176,7 +177,7 @@ impl PolynomialVector { Eta: CbdSamplingSize, { Self(Array::from_fn(|i| { - let N = start_n + i.truncate(); + let N = start_n + u8::truncate(i); let prf_output = PRF::(sigma, N); Polynomial::sample_cbd::(&prf_output) })) @@ -432,7 +433,7 @@ impl NttVector { pub fn sample_uniform(rho: &B32, i: usize, transpose: bool) -> Self { Self(Array::from_fn(|j| { let (i, j) = if transpose { (j, i) } else { (i, j) }; - let mut xof = XOF(rho, j.truncate(), i.truncate()); + let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i)); NttPolynomial::sample_uniform(&mut xof) })) } diff --git a/ml-kem/src/compress.rs b/ml-kem/src/compress.rs index 21a1ae3..4716bbe 100644 --- a/ml-kem/src/compress.rs +++ b/ml-kem/src/compress.rs @@ -1,6 +1,6 @@ use crate::algebra::{FieldElement, Integer, Polynomial, PolynomialVector}; use crate::param::{ArraySize, EncodingSize}; -use crate::util::Truncate; +use module_lattice::util::Truncate; // A convenience trait to allow us to associate some constants with a typenum pub trait CompressionFactor: EncodingSize { @@ -37,8 +37,8 @@ impl Compress for FieldElement { fn compress(&mut self) -> &Self { const Q_HALF: u64 = (FieldElement::Q64 + 1) >> 1; let x = u64::from(self.0); - let y = ((((x << D::USIZE) + Q_HALF) * D::DIV_MUL) >> D::DIV_SHIFT).truncate(); - self.0 = y.truncate() & D::MASK; + let y = (((x << D::USIZE) + Q_HALF) * D::DIV_MUL) >> D::DIV_SHIFT; + self.0 = u16::truncate(y) & D::MASK; self } @@ -46,7 +46,7 @@ impl Compress for FieldElement { fn decompress(&mut self) -> &Self { let x = u32::from(self.0); let y = ((x * FieldElement::Q32) + D::POW2_HALF) >> D::USIZE; - self.0 = y.truncate(); + self.0 = Truncate::truncate(y); self } } diff --git a/ml-kem/src/encode.rs b/ml-kem/src/encode.rs index 663912d..26769ed 100644 --- a/ml-kem/src/encode.rs +++ b/ml-kem/src/encode.rs @@ -2,12 +2,12 @@ use array::{ Array, typenum::{U256, Unsigned}, }; +use module_lattice::util::Truncate; use crate::algebra::{ FieldElement, Integer, NttPolynomial, NttVector, Polynomial, PolynomialVector, }; use crate::param::{ArraySize, EncodedPolynomial, EncodingSize, VectorEncodingSize}; -use crate::util::Truncate; type DecodedValue = Array; @@ -53,7 +53,7 @@ fn byte_decode(bytes: &EncodedPolynomial) -> DecodedValue { let x = u128::from_le_bytes(xb); for (j, vj) in v.iter_mut().enumerate() { - let val: Integer = (x >> (D::USIZE * j)).truncate(); + let val: Integer = Truncate::truncate(x >> (D::USIZE * j)); vj.0 = val & mask; if D::USIZE == 12 { diff --git a/ml-kem/src/util.rs b/ml-kem/src/util.rs index 26f22a7..0b325d3 100644 --- a/ml-kem/src/util.rs +++ b/ml-kem/src/util.rs @@ -12,31 +12,6 @@ use core::ptr; /// A 32-byte array, defined here for brevity because it is used several times pub type B32 = Array; -/// Safely truncate an unsigned integer value to shorter representation -pub trait Truncate { - fn truncate(self) -> T; -} - -macro_rules! define_truncate { - ($from:ident, $to:ident) => { - impl Truncate<$to> for $from { - fn truncate(self) -> $to { - // This line is marked unsafe because the `unwrap_unchecked` call is UB when its - // `self` argument is `Err`. It never will be, because we explicitly zeroize the - // high-order bits before converting. We could have used `unwrap()`, but chose to - // avoid the possibility of panic. - unsafe { (self & $from::from($to::MAX)).try_into().unwrap_unchecked() } - } - } - }; -} - -define_truncate!(u32, u16); -define_truncate!(u64, u32); -define_truncate!(usize, u8); -define_truncate!(u128, u16); -define_truncate!(u128, u8); - /// Defines a sequence of sequences that can be merged into a bigger overall seequence pub trait Flatten { type OutputSize: ArraySize; diff --git a/module-lattice/src/util.rs b/module-lattice/src/util.rs index 6488c15..65a36bb 100644 --- a/module-lattice/src/util.rs +++ b/module-lattice/src/util.rs @@ -29,6 +29,7 @@ macro_rules! define_truncate { } define_truncate!(u32, u16); +define_truncate!(u64, u16); define_truncate!(u64, u32); define_truncate!(u128, u8); define_truncate!(u128, u16);