From 81993861195af6fbb3921f06caa7daeefbccb996 Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Tue, 3 Feb 2026 07:41:35 -0700 Subject: [PATCH] module-lattice: remove `Flatten` and `Unflatten` They're now in `hybrid-array` as of v0.4.7 --- Cargo.lock | 4 +- ml-kem/Cargo.toml | 2 +- ml-kem/src/algebra.rs | 8 +- ml-kem/src/compress.rs | 2 +- module-lattice/Cargo.toml | 2 +- module-lattice/src/algebra.rs | 4 +- module-lattice/src/encoding.rs | 4 +- module-lattice/src/lib.rs | 2 +- module-lattice/src/truncate.rs | 31 ++++++ module-lattice/src/utils.rs | 177 --------------------------------- 10 files changed, 46 insertions(+), 190 deletions(-) create mode 100644 module-lattice/src/truncate.rs delete mode 100644 module-lattice/src/utils.rs diff --git a/Cargo.lock b/Cargo.lock index 45ad68e..53dc225 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -642,9 +642,9 @@ dependencies = [ [[package]] name = "hybrid-array" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b41fb3dc24fe72c2e3a4685eed55917c2fb228851257f4a8f2d985da9443c3e5" +checksum = "e1b229d73f5803b562cc26e4da0396c8610a4ee209f4fac8fa4f8d709166dc45" dependencies = [ "subtle", "typenum", diff --git a/ml-kem/Cargo.toml b/ml-kem/Cargo.toml index ee322de..cd80e7e 100644 --- a/ml-kem/Cargo.toml +++ b/ml-kem/Cargo.toml @@ -25,7 +25,7 @@ pkcs8 = ["dep:const-oid", "dep:pkcs8"] zeroize = ["module-lattice/zeroize", "dep:zeroize"] [dependencies] -array = { version = "0.4.4", package = "hybrid-array", features = ["extra-sizes", "subtle"] } +array = { version = "0.4.7", package = "hybrid-array", features = ["extra-sizes", "subtle"] } module-lattice = { version = "0.1.0-rc.0", features = ["subtle"] } kem = "0.3.0-rc.3" rand_core = "0.10.0-rc-6" diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 3237d84..afa6a8c 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -2,7 +2,7 @@ use array::{Array, typenum::U256}; use module_lattice::{ algebra::{Field, MultiplyNtt}, encoding::Encode, - utils::Truncate, + truncate::Truncate, }; use sha3::digest::XofReader; @@ -319,8 +319,10 @@ mod test { Array, ArraySize, B32, BaseField, Elem, Field, Int, Ntt, NttInverse, NttMatrix, NttPolynomial, NttVector, PRF, Polynomial, U256, XOF, }; - use array::typenum::{U2, U3, U8}; - use module_lattice::utils::Flatten; + use array::{ + Flatten, + typenum::{U2, U3, U8}, + }; /// A polynomial with only a scalar component, to make simple test cases fn const_ntt(x: Int) -> NttPolynomial { diff --git a/ml-kem/src/compress.rs b/ml-kem/src/compress.rs index 5c79b13..97617c9 100644 --- a/ml-kem/src/compress.rs +++ b/ml-kem/src/compress.rs @@ -1,6 +1,6 @@ use crate::algebra::{BaseField, Elem, Int, Polynomial, Vector}; use crate::param::{ArraySize, EncodingSize}; -use module_lattice::{algebra::Field, utils::Truncate}; +use module_lattice::{algebra::Field, truncate::Truncate}; // A convenience trait to allow us to associate some constants with a typenum pub(crate) trait CompressionFactor: EncodingSize { diff --git a/module-lattice/Cargo.toml b/module-lattice/Cargo.toml index c802229..3429ef2 100644 --- a/module-lattice/Cargo.toml +++ b/module-lattice/Cargo.toml @@ -16,7 +16,7 @@ categories = ["cryptography", "no-std"] keywords = ["crypto", "kyber", "lattice", "post-quantum"] [dependencies] -array = { version = "0.4", package = "hybrid-array", features = ["extra-sizes"] } +array = { version = "0.4.7", package = "hybrid-array", features = ["extra-sizes"] } num-traits = { version = "0.2", default-features = false } # optional dependencies diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs index 547502b..b762996 100644 --- a/module-lattice/src/algebra.rs +++ b/module-lattice/src/algebra.rs @@ -1,4 +1,4 @@ -use super::utils::Truncate; +use super::truncate::Truncate; use array::{Array, ArraySize, typenum::U256}; use core::fmt::Debug; @@ -80,7 +80,7 @@ macro_rules! define_field { let product = x * Self::BARRETT_MULTIPLIER; let quotient = product >> Self::BARRETT_SHIFT; let remainder = x - quotient * Self::QLL; - Self::small_reduce($crate::utils::Truncate::truncate(remainder)) + Self::small_reduce($crate::truncate::Truncate::truncate(remainder)) } } }; diff --git a/module-lattice/src/encoding.rs b/module-lattice/src/encoding.rs index d649a96..90f2f17 100644 --- a/module-lattice/src/encoding.rs +++ b/module-lattice/src/encoding.rs @@ -1,5 +1,5 @@ use array::{ - Array, + Array, Flatten, Unflatten, typenum::{Gcd, Gcf, Prod, Quot, U0, U8, U32, U256, Unsigned}, }; use core::fmt::Debug; @@ -7,7 +7,7 @@ use core::ops::{Div, Mul, Rem}; use num_traits::One; use super::algebra::{Elem, Field, NttPolynomial, NttVector, Polynomial, Vector}; -use super::utils::{Flatten, Truncate, Unflatten}; +use super::truncate::Truncate; /// An array length with other useful properties pub trait ArraySize: array::ArraySize + PartialEq + Debug {} diff --git a/module-lattice/src/lib.rs b/module-lattice/src/lib.rs index 4027f06..b6363fa 100644 --- a/module-lattice/src/lib.rs +++ b/module-lattice/src/lib.rs @@ -18,4 +18,4 @@ pub mod encoding; /// Utility functions such as truncating integers, flattening arrays of arrays, and unflattening /// arrays into arrays of arrays. -pub mod utils; +pub mod truncate; diff --git a/module-lattice/src/truncate.rs b/module-lattice/src/truncate.rs new file mode 100644 index 0000000..41a03e3 --- /dev/null +++ b/module-lattice/src/truncate.rs @@ -0,0 +1,31 @@ +/// Safely truncate an unsigned integer value to shorter representation +pub trait Truncate { + /// Truncate value to the width of `Self`. + fn truncate(x: T) -> Self; +} + +macro_rules! define_truncate { + ($from:ident, $to:ident) => { + impl Truncate<$from> for $to { + // Truncation should always function as intended here: + // - we ensure `$to` is small enough to infallibly convert to `$from` via the + // `$from::from($to::MAX)` conversion, which will fail if not widening. + // - we are deliberately masking to the smaller size, i.e. truncation is intentional + // (though that's not enough for `clippy` for some reason). Arguably the truncation + // of the `as` cast is sufficient, but this makes it explicit + #[allow(clippy::cast_possible_truncation)] + fn truncate(x: $from) -> $to { + (x & $from::from($to::MAX)) as $to + } + } + }; +} + +define_truncate!(u32, u16); +define_truncate!(u64, u16); +define_truncate!(u64, u32); +define_truncate!(u128, u8); +define_truncate!(u128, u16); +define_truncate!(u128, u32); +define_truncate!(usize, u8); +define_truncate!(usize, u16); diff --git a/module-lattice/src/utils.rs b/module-lattice/src/utils.rs deleted file mode 100644 index bfcccc1..0000000 --- a/module-lattice/src/utils.rs +++ /dev/null @@ -1,177 +0,0 @@ -use array::{ - Array, ArraySize, - typenum::{Prod, Quot, U0, Unsigned}, -}; -use core::{ - mem::ManuallyDrop, - ops::{Div, Mul, Rem}, - ptr, -}; - -/// Safely truncate an unsigned integer value to shorter representation -pub trait Truncate { - /// Truncate value to the width of `Self`. - fn truncate(x: T) -> Self; -} - -macro_rules! define_truncate { - ($from:ident, $to:ident) => { - impl Truncate<$from> for $to { - // Truncation should always function as intended here: - // - we ensure `$to` is small enough to infallibly convert to `$from` via the - // `$from::from($to::MAX)` conversion, which will fail if not widening. - // - we are deliberately masking to the smaller size, i.e. truncation is intentional - // (though that's not enough for `clippy` for some reason). Arguably the truncation - // of the `as` cast is sufficient, but this makes it explicit - #[allow(clippy::cast_possible_truncation)] - fn truncate(x: $from) -> $to { - (x & $from::from($to::MAX)) as $to - } - } - }; -} - -define_truncate!(u32, u16); -define_truncate!(u64, u16); -define_truncate!(u64, u32); -define_truncate!(u128, u8); -define_truncate!(u128, u16); -define_truncate!(u128, u32); -define_truncate!(usize, u8); -define_truncate!(usize, u16); - -/// Defines a sequence of sequences that can be merged into a bigger overall sequence. -pub trait Flatten { - /// Size of the output array. - type OutputSize: ArraySize; - - /// Flatten array. - fn flatten(self) -> Array; -} - -impl Flatten> for Array, N> -where - N: ArraySize, - M: ArraySize + Mul, - Prod: ArraySize, -{ - type OutputSize = Prod; - - fn flatten(self) -> Array { - let whole = ManuallyDrop::new(self); - - // SAFETY: this is the reverse transmute between [T; K*N] and [[T; K], M], which is guaranteed - // to be safe by the Rust memory layout of these types. - #[allow(unsafe_code)] - unsafe { - ptr::read(whole.as_ptr().cast()) - } - } -} - -/// Defines a sequence that can be split into a sequence of smaller sequences of uniform size. -pub trait Unflatten -where - M: ArraySize, -{ - /// Part of the array we're decomposing into. - type Part; - - /// Unflatten array into `Self::Part` chunks. - fn unflatten(self) -> Array; -} - -impl Unflatten for Array -where - N: ArraySize + Div + Rem, - M: ArraySize, - Quot: ArraySize, -{ - type Part = Array>; - - fn unflatten(self) -> Array { - let part_size = Quot::::USIZE; - let whole = ManuallyDrop::new(self); - - // SAFETY: this is doing the same thing as what is done in `Array::split`. - // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to - // be safe by the Rust memory layout of these types. - #[allow(unsafe_code)] - Array::from_fn(|i| unsafe { - let offset = i.checked_mul(part_size).expect("overflow"); - ptr::read(whole.as_ptr().add(offset).cast()) - }) - } -} - -impl<'a, T, N, M> Unflatten for &'a Array -where - N: ArraySize + Div + Rem, - M: ArraySize, - Quot: ArraySize, -{ - type Part = &'a Array>; - - fn unflatten(self) -> Array { - let part_size = Quot::::USIZE; - let mut ptr: *const T = self.as_ptr(); - - // SAFETY: this is doing the same thing as what is done in `Array::split`. - // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to - // be safe by the Rust memory layout of these types. - #[allow(unsafe_code)] - Array::from_fn(|_i| unsafe { - let part = &*(ptr.cast()); - ptr = ptr.add(part_size); - part - }) - } -} - -#[cfg(test)] -mod test { - use super::*; - use array::{ - Array, - sizes::{U2, U5}, - }; - - #[test] - fn flatten() { - let flat: Array = Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); - let unflat2: Array, _> = Array([ - Array([1, 2]), - Array([3, 4]), - Array([5, 6]), - Array([7, 8]), - Array([9, 10]), - ]); - let unflat5: Array, _> = - Array([Array([1, 2, 3, 4, 5]), Array([6, 7, 8, 9, 10])]); - - // Flatten - let actual = unflat2.flatten(); - assert_eq!(flat, actual); - - let actual = unflat5.flatten(); - assert_eq!(flat, actual); - - // Unflatten - let actual: Array, U5> = flat.unflatten(); - assert_eq!(unflat2, actual); - - let actual: Array, U2> = flat.unflatten(); - assert_eq!(unflat5, actual); - - // Unflatten on references - let actual: Array<&Array, U5> = (&flat).unflatten(); - for (i, part) in actual.iter().enumerate() { - assert_eq!(&unflat2[i], *part); - } - - let actual: Array<&Array, U2> = (&flat).unflatten(); - for (i, part) in actual.iter().enumerate() { - assert_eq!(&unflat5[i], *part); - } - } -}