diff --git a/ml-kem/src/kem.rs b/ml-kem/src/kem.rs index 711059e..4ffc41a 100644 --- a/ml-kem/src/kem.rs +++ b/ml-kem/src/kem.rs @@ -271,9 +271,8 @@ where (dk, ek) } - #[cfg(feature = "deterministic")] - fn generate_deterministic(d: B32, z: B32) -> (Self::DecapsulationKey, Self::EncapsulationKey) { - let dk = Self::DecapsulationKey::generate_deterministic(d, z); + fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey) { + let dk = Self::DecapsulationKey::from_seed(seed); let ek = dk.encapsulation_key().clone(); (dk, ek) } diff --git a/ml-kem/src/traits.rs b/ml-kem/src/traits.rs index 0e8c6fb..6131c13 100644 --- a/ml-kem/src/traits.rs +++ b/ml-kem/src/traits.rs @@ -1,11 +1,13 @@ -use crate::{ArraySize, Ciphertext, SharedKey}; +//! Trait definitions + +use crate::{ArraySize, Ciphertext, Seed, SharedKey}; use core::fmt::Debug; use hybrid_array::Array; use kem::{Decapsulate, Encapsulate}; use rand_core::CryptoRng; #[cfg(feature = "deterministic")] -use crate::util::B32; +use crate::B32; /// An object that knows what size it is pub trait EncodedSizeUser { @@ -52,28 +54,18 @@ pub trait KemCore: Clone { + PartialEq; /// An encapsulation key for this KEM - #[cfg(not(feature = "deterministic"))] - type EncapsulationKey: Encapsulate, SharedKey> - + EncodedSizeUser - + Clone - + Debug - + PartialEq; - - /// An encapsulation key for this KEM - #[cfg(feature = "deterministic")] type EncapsulationKey: Encapsulate, SharedKey> - + EncapsulateDeterministic, SharedKey> + EncodedSizeUser + Clone + Debug + PartialEq; - /// Generate a new (decapsulation, encapsulation) key pair + /// Generate a new (decapsulation, encapsulation) key pair. fn generate( rng: &mut R, ) -> (Self::DecapsulationKey, Self::EncapsulationKey); - /// Generate a new (decapsulation, encapsulation) key pair deterministically - #[cfg(feature = "deterministic")] - fn generate_deterministic(d: B32, z: B32) -> (Self::DecapsulationKey, Self::EncapsulationKey); + /// Generate a new (decapsulation, encapsulation) key pair deterministically from the given + /// uniformly random seed value. + fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey); } diff --git a/ml-kem/tests/encap-decap.rs b/ml-kem/tests/encap-decap.rs index 64d0247..2a9e98c 100644 --- a/ml-kem/tests/encap-decap.rs +++ b/ml-kem/tests/encap-decap.rs @@ -35,7 +35,11 @@ fn verify_encap_group(tg: &acvp::EncapTestGroup) { } } -fn verify_encap(tc: &acvp::EncapTestCase) { +fn verify_encap(tc: &acvp::EncapTestCase) +where + K: KemCore, + K::EncapsulationKey: EncapsulateDeterministic, SharedKey>, +{ let m = Array::try_from(tc.m.as_slice()).unwrap(); let ek_bytes = Encoded::::try_from(tc.ek.as_slice()).unwrap(); let ek = K::EncapsulationKey::from_bytes(&ek_bytes); diff --git a/ml-kem/tests/key-gen.rs b/ml-kem/tests/key-gen.rs index 37fa44a..d60801f 100644 --- a/ml-kem/tests/key-gen.rs +++ b/ml-kem/tests/key-gen.rs @@ -29,12 +29,12 @@ fn acvp_key_gen() { fn verify(tc: &acvp::TestCase) { // Import test data into the relevant array structures - let d = Array::try_from(tc.d.as_slice()).unwrap(); - let z = Array::try_from(tc.z.as_slice()).unwrap(); + let d: B32 = Array::try_from(tc.d.as_slice()).unwrap(); + let z: B32 = Array::try_from(tc.z.as_slice()).unwrap(); let dk_bytes = Encoded::::try_from(tc.dk.as_slice()).unwrap(); let ek_bytes = Encoded::::try_from(tc.ek.as_slice()).unwrap(); - let (dk, ek) = K::generate_deterministic(d, z); + let (dk, ek) = K::from_seed(d.concat(z)); // Verify correctness via serialization assert_eq!(dk.as_bytes().as_slice(), tc.dk.as_slice()); diff --git a/x-wing/src/lib.rs b/x-wing/src/lib.rs index 201e2df..6ced38c 100644 --- a/x-wing/src/lib.rs +++ b/x-wing/src/lib.rs @@ -178,9 +178,8 @@ impl DecapsulationKey { shaker.update(&self.sk); let mut expanded: Shake256Reader = shaker.finalize_xof(); - let d = read_from(&mut expanded).into(); - let z = read_from(&mut expanded).into(); - let (sk_m, pk_m) = MlKem768::generate_deterministic(d, z); + let seed = read_from(&mut expanded).into(); + let (sk_m, pk_m) = MlKem768::from_seed(seed); let sk_x = read_from(&mut expanded); let sk_x = StaticSecret::from(sk_x);