diff --git a/ml-kem/src/kem.rs b/ml-kem/src/kem.rs
index d361977..8f105d1 100644
--- a/ml-kem/src/kem.rs
+++ b/ml-kem/src/kem.rs
@@ -8,7 +8,7 @@ use crate::crypto::{G, H, J, rand};
use crate::param::{DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, KemParams};
use crate::pke::{DecryptionKey, EncryptionKey};
use crate::util::B32;
-use crate::{Encoded, EncodedSizeUser};
+use crate::{Encoded, EncodedSizeUser, Seed};
#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};
@@ -21,16 +21,28 @@ pub(crate) type SharedKey = B32;
/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
/// encapsulated shared key.
-#[derive(Clone, Debug, PartialEq)]
+#[derive(Clone, Debug)]
pub struct DecapsulationKey
where
P: KemParams,
{
dk_pke: DecryptionKey
,
ek: EncapsulationKey
,
+ d: Option,
z: B32,
}
+// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
+// keys initialized from the expanded form
+impl PartialEq for DecapsulationKey
+where
+ P: KemParams,
+{
+ fn eq(&self, other: &Self) -> bool {
+ self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
+ }
+}
+
#[cfg(feature = "zeroize")]
impl
Drop for DecapsulationKey
where
@@ -65,6 +77,7 @@ where
ek_pke,
h: h.clone(),
},
+ d: None,
z: z.clone(),
}
}
@@ -102,24 +115,50 @@ impl
DecapsulationKey
where
P: KemParams,
{
+ /// Create a [`DecapsulationKey`] instance from a 64-byte random seed value.
+ #[inline]
+ #[must_use]
+ pub fn from_seed(seed: Seed) -> Self {
+ let (d, z) = seed.split();
+ Self::generate_deterministic(d, z)
+ }
+
+ /// Serialize the [`Seed`] value: 64-bytes which can be used to reconstruct the
+ /// [`DecapsulationKey`].
+ ///
+ /// # ⚠️Warning!
+ ///
+ /// This value is key material. Please treat it with care.
+ ///
+ /// # Returns
+ /// - `Some` if the [`DecapsulationKey`] was initialized using `from_seed`, `generate`, or
+ /// `generate_deterministic`
+ /// - `None` if the [`DecapsulationKey`] was initialized from the expanded form.
+ #[inline]
+ pub fn to_seed(&self) -> Option {
+ self.d.map(|d| d.concat(self.z))
+ }
+
/// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
pub fn encapsulation_key(&self) -> &EncapsulationKey {
&self.ek
}
+ #[inline]
pub(crate) fn generate(rng: &mut R) -> Self {
let d: B32 = rand(rng);
let z: B32 = rand(rng);
- Self::generate_deterministic(&d, &z)
+ Self::generate_deterministic(d, z)
}
+ #[inline]
#[must_use]
#[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
- pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self {
- let (dk_pke, ek_pke) = DecryptionKey::generate(d);
+ pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
+ let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
let ek = EncapsulationKey::new(ek_pke);
- let z = z.clone();
- Self { dk_pke, ek, z }
+ let d = Some(d);
+ Self { dk_pke, ek, d, z }
}
}
@@ -224,10 +263,7 @@ where
}
#[cfg(feature = "deterministic")]
- fn generate_deterministic(
- d: &B32,
- z: &B32,
- ) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
+ fn generate_deterministic(d: B32, z: B32) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
let dk = Self::DecapsulationKey::generate_deterministic(d, z);
let ek = dk.encapsulation_key().clone();
(dk, ek)
@@ -239,6 +275,7 @@ mod test {
use super::*;
use crate::{MlKem512Params, MlKem768Params, MlKem1024Params};
use ::kem::{Decapsulate, Encapsulate};
+ use rand_core::TryRngCore;
fn round_trip_test()
where
@@ -261,7 +298,7 @@ mod test {
round_trip_test::();
}
- fn codec_test()
+ fn expanded_key_test
()
where
P: KemParams,
{
@@ -279,9 +316,29 @@ mod test {
}
#[test]
- fn codec() {
- codec_test::();
- codec_test::();
- codec_test::();
+ fn expanded_key() {
+ expanded_key_test::();
+ expanded_key_test::();
+ expanded_key_test::();
+ }
+
+ fn seed_test()
+ where
+ P: KemParams,
+ {
+ let mut rng = rand::rng();
+ let mut seed = Seed::default();
+ rng.try_fill_bytes(&mut seed).unwrap();
+
+ let dk = DecapsulationKey::
::from_seed(seed.clone());
+ let seed_encoded = dk.to_seed().unwrap();
+ assert_eq!(seed, seed_encoded);
+ }
+
+ #[test]
+ fn seed() {
+ seed_test::();
+ seed_test::();
+ seed_test::();
}
}
diff --git a/ml-kem/src/lib.rs b/ml-kem/src/lib.rs
index 24e04e8..e77ac68 100644
--- a/ml-kem/src/lib.rs
+++ b/ml-kem/src/lib.rs
@@ -69,7 +69,7 @@ use ::kem::{Decapsulate, Encapsulate};
use core::fmt::Debug;
use hybrid_array::{
Array,
- typenum::{U2, U3, U4, U5, U10, U11},
+ typenum::{U2, U3, U4, U5, U10, U11, U64},
};
use rand_core::CryptoRng;
@@ -80,6 +80,10 @@ pub use util::B32;
pub use param::{ArraySize, ParameterSet};
+/// ML-KEM seeds are decapsulation (private) keys, which are consistently 64-bytes across all
+/// security levels, and are the preferred serialization for representing such keys.
+pub type Seed = Array;
+
/// An object that knows what size it is
pub trait EncodedSizeUser {
/// The size of an encoded object
@@ -148,8 +152,7 @@ pub trait KemCore: Clone {
/// Generate a new (decapsulation, encapsulation) key pair deterministically
#[cfg(feature = "deterministic")]
- fn generate_deterministic(d: &B32, z: &B32)
- -> (Self::DecapsulationKey, Self::EncapsulationKey);
+ fn generate_deterministic(d: B32, z: B32) -> (Self::DecapsulationKey, Self::EncapsulationKey);
}
/// `MlKem512` is the parameter set for security category 1, corresponding to key search on a block
diff --git a/ml-kem/tests/key-gen.rs b/ml-kem/tests/key-gen.rs
index 3c855fd..37fa44a 100644
--- a/ml-kem/tests/key-gen.rs
+++ b/ml-kem/tests/key-gen.rs
@@ -34,7 +34,7 @@ fn verify(tc: &acvp::TestCase) {
let dk_bytes = Encoded::::try_from(tc.dk.as_slice()).unwrap();
let ek_bytes = Encoded::::try_from(tc.ek.as_slice()).unwrap();
- let (dk, ek) = K::generate_deterministic(&d, &z);
+ let (dk, ek) = K::generate_deterministic(d, z);
// Verify correctness via serialization
assert_eq!(dk.as_bytes().as_slice(), tc.dk.as_slice());
diff --git a/x-wing/src/lib.rs b/x-wing/src/lib.rs
index 597fb0d..201e2df 100644
--- a/x-wing/src/lib.rs
+++ b/x-wing/src/lib.rs
@@ -180,7 +180,7 @@ impl DecapsulationKey {
let d = read_from(&mut expanded).into();
let z = read_from(&mut expanded).into();
- let (sk_m, pk_m) = MlKem768::generate_deterministic(&d, &z);
+ let (sk_m, pk_m) = MlKem768::generate_deterministic(d, z);
let sk_x = read_from(&mut expanded);
let sk_x = StaticSecret::from(sk_x);