Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 73 additions & 16 deletions ml-kem/src/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::crypto::{G, H, J, rand};
use crate::param::{DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, KemParams};
use crate::pke::{DecryptionKey, EncryptionKey};
use crate::util::B32;
use crate::{Encoded, EncodedSizeUser};
use crate::{Encoded, EncodedSizeUser, Seed};

#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};
Expand All @@ -21,16 +21,28 @@ pub(crate) type SharedKey = B32;

/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
/// encapsulated shared key.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub struct DecapsulationKey<P>
where
P: KemParams,
{
dk_pke: DecryptionKey<P>,
ek: EncapsulationKey<P>,
d: Option<B32>,
z: B32,
}

// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
// keys initialized from the expanded form
impl<P> PartialEq for DecapsulationKey<P>
where
P: KemParams,
{
fn eq(&self, other: &Self) -> bool {
self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
}
}

#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
Expand Down Expand Up @@ -65,6 +77,7 @@ where
ek_pke,
h: h.clone(),
},
d: None,
z: z.clone(),
}
}
Expand Down Expand Up @@ -102,24 +115,50 @@ impl<P> DecapsulationKey<P>
where
P: KemParams,
{
/// Create a [`DecapsulationKey`] instance from a 64-byte random seed value.
#[inline]
#[must_use]
pub fn from_seed(seed: Seed) -> Self {
let (d, z) = seed.split();
Self::generate_deterministic(d, z)
}

/// Serialize the [`Seed`] value: 64-bytes which can be used to reconstruct the
/// [`DecapsulationKey`].
///
/// # ⚠️Warning!
///
/// This value is key material. Please treat it with care.
///
/// # Returns
/// - `Some` if the [`DecapsulationKey`] was initialized using `from_seed`, `generate`, or
/// `generate_deterministic`
/// - `None` if the [`DecapsulationKey`] was initialized from the expanded form.
#[inline]
pub fn to_seed(&self) -> Option<Seed> {
self.d.map(|d| d.concat(self.z))
}

/// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
&self.ek
}

#[inline]
pub(crate) fn generate<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
let d: B32 = rand(rng);
let z: B32 = rand(rng);
Self::generate_deterministic(&d, &z)
Self::generate_deterministic(d, z)
}

#[inline]
#[must_use]
#[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self {
let (dk_pke, ek_pke) = DecryptionKey::generate(d);
pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
let ek = EncapsulationKey::new(ek_pke);
let z = z.clone();
Self { dk_pke, ek, z }
let d = Some(d);
Self { dk_pke, ek, d, z }
}
}

Expand Down Expand Up @@ -224,10 +263,7 @@ where
}

#[cfg(feature = "deterministic")]
fn generate_deterministic(
d: &B32,
z: &B32,
) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
fn generate_deterministic(d: B32, z: B32) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
let dk = Self::DecapsulationKey::generate_deterministic(d, z);
let ek = dk.encapsulation_key().clone();
(dk, ek)
Expand All @@ -239,6 +275,7 @@ mod test {
use super::*;
use crate::{MlKem512Params, MlKem768Params, MlKem1024Params};
use ::kem::{Decapsulate, Encapsulate};
use rand_core::TryRngCore;

fn round_trip_test<P>()
where
Expand All @@ -261,7 +298,7 @@ mod test {
round_trip_test::<MlKem1024Params>();
}

fn codec_test<P>()
fn expanded_key_test<P>()
where
P: KemParams,
{
Expand All @@ -279,9 +316,29 @@ mod test {
}

#[test]
fn codec() {
codec_test::<MlKem512Params>();
codec_test::<MlKem768Params>();
codec_test::<MlKem1024Params>();
fn expanded_key() {
expanded_key_test::<MlKem512Params>();
expanded_key_test::<MlKem768Params>();
expanded_key_test::<MlKem1024Params>();
}

fn seed_test<P>()
where
P: KemParams,
{
let mut rng = rand::rng();
let mut seed = Seed::default();
rng.try_fill_bytes(&mut seed).unwrap();

let dk = DecapsulationKey::<P>::from_seed(seed.clone());
let seed_encoded = dk.to_seed().unwrap();
assert_eq!(seed, seed_encoded);
}

#[test]
fn seed() {
seed_test::<MlKem512Params>();
seed_test::<MlKem768Params>();
seed_test::<MlKem1024Params>();
}
}
9 changes: 6 additions & 3 deletions ml-kem/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ use ::kem::{Decapsulate, Encapsulate};
use core::fmt::Debug;
use hybrid_array::{
Array,
typenum::{U2, U3, U4, U5, U10, U11},
typenum::{U2, U3, U4, U5, U10, U11, U64},
};
use rand_core::CryptoRng;

Expand All @@ -80,6 +80,10 @@ pub use util::B32;

pub use param::{ArraySize, ParameterSet};

/// ML-KEM seeds are decapsulation (private) keys, which are consistently 64-bytes across all
/// security levels, and are the preferred serialization for representing such keys.
pub type Seed = Array<u8, U64>;

/// An object that knows what size it is
pub trait EncodedSizeUser {
/// The size of an encoded object
Expand Down Expand Up @@ -148,8 +152,7 @@ pub trait KemCore: Clone {

/// Generate a new (decapsulation, encapsulation) key pair deterministically
#[cfg(feature = "deterministic")]
fn generate_deterministic(d: &B32, z: &B32)
-> (Self::DecapsulationKey, Self::EncapsulationKey);
fn generate_deterministic(d: B32, z: B32) -> (Self::DecapsulationKey, Self::EncapsulationKey);
}

/// `MlKem512` is the parameter set for security category 1, corresponding to key search on a block
Expand Down
2 changes: 1 addition & 1 deletion ml-kem/tests/key-gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ fn verify<K: KemCore>(tc: &acvp::TestCase) {
let dk_bytes = Encoded::<K::DecapsulationKey>::try_from(tc.dk.as_slice()).unwrap();
let ek_bytes = Encoded::<K::EncapsulationKey>::try_from(tc.ek.as_slice()).unwrap();

let (dk, ek) = K::generate_deterministic(&d, &z);
let (dk, ek) = K::generate_deterministic(d, z);

// Verify correctness via serialization
assert_eq!(dk.as_bytes().as_slice(), tc.dk.as_slice());
Expand Down
2 changes: 1 addition & 1 deletion x-wing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ impl DecapsulationKey {

let d = read_from(&mut expanded).into();
let z = read_from(&mut expanded).into();
let (sk_m, pk_m) = MlKem768::generate_deterministic(&d, &z);
let (sk_m, pk_m) = MlKem768::generate_deterministic(d, z);

let sk_x = read_from(&mut expanded);
let sk_x = StaticSecret::from(sk_x);
Expand Down