Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion ml-kem/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ pub type Seed = Array<u8, U64>;
/// cipher with a 128-bit key.
pub mod ml_kem_512 {
use super::{Debug, ParameterSet, U2, U3, U4, U10, kem};
use crate::param;

/// `MlKem512` is the parameter set for security category 1, corresponding to key search on a
/// block cipher with a 128-bit key.
Expand All @@ -119,12 +120,21 @@ pub mod ml_kem_512 {
/// An ML-KEM-512 `EncapsulationKey` provides the ability to encapsulate a shared key so that it
/// can only be decapsulated by the holder of the corresponding decapsulation key.
pub type EncapsulationKey = kem::EncapsulationKey<MlKem512Params>;

/// Encoded ML-KEM-512 ciphertexts.
pub type EncodedCiphertext = param::EncodedCiphertext<MlKem512Params>;

/// Legacy expanded decapsulation keys. Prefer seeds instead.
#[doc(hidden)]
#[deprecated(since = "0.3.0", note = "use `Seed` instead")]
pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey<MlKem512Params>;
}

/// ML-KEM-768 is the parameter set for security category 3, corresponding to key search on a block
/// cipher with a 192-bit key.
pub mod ml_kem_768 {
use super::{Debug, ParameterSet, U2, U3, U4, U10, kem};
use crate::param;

/// `MlKem768` is the parameter set for security category 3, corresponding to key search on a
/// block cipher with a 192-bit key.
Expand All @@ -146,12 +156,20 @@ pub mod ml_kem_768 {
/// An ML-KEM-768 `EncapsulationKey` provides the ability to encapsulate a shared key so that it
/// can only be decapsulated by the holder of the corresponding decapsulation key.
pub type EncapsulationKey = kem::EncapsulationKey<MlKem768Params>;

/// Encoded ML-KEM-512 ciphertexts.
pub type EncodedCiphertext = param::EncodedCiphertext<MlKem768Params>;

/// Legacy expanded decapsulation keys. Prefer seeds instead.
#[doc(hidden)]
#[deprecated(since = "0.3.0", note = "use `Seed` instead")]
pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey<MlKem768Params>;
}

/// ML-KEM-1024 is the parameter set for security category 5, corresponding to key search on a block
/// cipher with a 256-bit key.
pub mod ml_kem_1024 {
use super::{Debug, ParameterSet, U2, U4, U5, U11, kem};
use super::{Debug, ParameterSet, U2, U4, U5, U11, kem, param};

/// `MlKem1024` is the parameter set for security category 5, corresponding to key search on a
/// block cipher with a 256-bit key.
Expand All @@ -173,6 +191,14 @@ pub mod ml_kem_1024 {
/// An ML-KEM-1024 `EncapsulationKey` provides the ability to encapsulate a shared key so that
/// it can only be decapsulated by the holder of the corresponding decapsulation key.
pub type EncapsulationKey = kem::EncapsulationKey<MlKem1024Params>;

/// Encoded ML-KEM-512 ciphertexts.
pub type EncodedCiphertext = param::EncodedCiphertext<MlKem1024Params>;

/// Legacy expanded decapsulation keys. Prefer seeds instead.
#[doc(hidden)]
#[deprecated(since = "0.3.0", note = "use `Seed` instead")]
pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey<MlKem1024Params>;
}

/// An ML-KEM-512 `DecapsulationKey` which provides the ability to generate a new key pair, and
Expand Down
114 changes: 93 additions & 21 deletions ml-kem/tests/wycheproof.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Test against the Wycheproof test vectors.

use array::{Array, ArraySize};
use ml_kem::{
EncodedSizeUser, KemCore, MlKem512, MlKem768, MlKem1024,
kem::{KeyExport, TryKeyInit},
kem::{Decapsulate, KeyExport, TryKeyInit},
};
use serde::Deserialize;
use std::fs::File;
Expand Down Expand Up @@ -43,7 +44,6 @@ struct Test {
dk: Option<String>,
#[cfg(feature = "hazmat")]
m: Option<String>,
#[cfg(feature = "hazmat")]
c: Option<String>,
#[cfg(feature = "hazmat")]
#[serde(default, rename(deserialize = "K"))]
Expand Down Expand Up @@ -73,13 +73,16 @@ macro_rules! load_json_file {
}};
}

fn decode_optional_hex(opt: &Option<String>, field: &str) -> Vec<u8> {
match opt {
Some(h) => {
hex::decode(h).unwrap_or_else(|e| panic!("invalid hex for field '{field}': {e}"))
}
None => panic!("missing field: {field}"),
}
fn decode_optional_hex<U: ArraySize>(opt: &Option<String>, field: &str) -> Option<Array<u8, U>> {
opt.as_ref().and_then(|h| {
let vec = hex::decode(h).unwrap_or_else(|e| panic!("invalid hex for field '{field}': {e}"));
vec.as_slice().try_into().ok()
})
}

fn decode_expected_hex<U: ArraySize>(opt: &Option<String>, field: &str) -> Array<u8, U> {
decode_optional_hex(opt, field)
.unwrap_or_else(|| panic!("missing or incorrect length field: {field}"))
}

macro_rules! mlkem_keygen_seed_test {
Expand All @@ -101,11 +104,11 @@ macro_rules! mlkem_keygen_seed_test {
test.comment.as_ref().unwrap(),
&test.result
);
let test_seed = decode_optional_hex(&test.seed, "seed");
let test_dk = decode_optional_hex(&test.dk, "dk");
let test_seed = decode_expected_hex(&test.seed, "seed");
let test_dk = decode_expected_hex(&test.dk, "dk");

let (dk, ek) = $kem::from_seed(test_seed.as_slice().try_into().unwrap());
assert_eq!(test_dk.as_slice(), dk.to_encoded_bytes().as_slice());
let (dk, ek) = $kem::from_seed(test_seed);
assert_eq!(test_dk, dk.to_encoded_bytes());
assert_eq!(test.ek.as_slice(), ek.to_bytes().as_slice());
}
}
Expand Down Expand Up @@ -143,21 +146,74 @@ macro_rules! mlkem_encaps_test {

#[cfg(feature = "hazmat")]
{
let test_m = decode_optional_hex(&test.m, "m");
let test_m = test_m.as_slice().try_into().unwrap();
let (c, k) = ek.encapsulate_deterministic(test_m);

let test_c = decode_optional_hex(&test.c, "c");
let test_k = decode_optional_hex(&test.k, "K");
assert_eq!(test_c.as_slice(), c.as_slice());
assert_eq!(test_k.as_slice(), k.as_slice());
let test_m = decode_expected_hex(&test.m, "m");
let (c, k) = ek.encapsulate_deterministic(&test_m);

let test_c = decode_expected_hex(&test.c, "c");
let test_k = decode_expected_hex(&test.k, "K");
assert_eq!(test_c, c);
assert_eq!(test_k, k);
}
}
}
}
};
}

macro_rules! mlkem_decaps_test {
($name:ident, $json_file:expr, $kem_module:ident) => {
#[test]
fn $name() {
let tests = load_json_file!($json_file);

for group in tests.groups {
println!(
"Parameter set: {} ({} v{})\n",
&group.parameter_set, &group.source.name, &group.source.version
);

for test in &group.tests {
println!("Test #{} ({:?})", test.id, &test.result);

#[allow(deprecated)]
use ml_kem::$kem_module::{
DecapsulationKey, EncodedCiphertext, ExpandedDecapsulationKey,
};

#[allow(deprecated)]
let test_dk: ExpandedDecapsulationKey =
match decode_optional_hex(&test.dk, "dk") {
Some(dk) => dk,
None => {
if test.result == ExpectedResult::Invalid {
continue;
} else {
panic!("failed to decode expanded decapsulation key!")
}
}
};

#[allow(deprecated)]
let dk = DecapsulationKey::from_expanded(&test_dk).expect("should be valid");

let test_c: EncodedCiphertext = match decode_optional_hex(&test.c, "c") {
Some(dk) => dk,
None => {
if test.result == ExpectedResult::Invalid {
continue;
} else {
panic!("failed to decode ciphertext!")
}
}
};

let _ss = dk.decapsulate(&test_c);
}
}
}
};
}

mlkem_keygen_seed_test!(
mlkem_512_keygen_seed_test,
"mlkem_512_keygen_seed_test.json",
Expand All @@ -174,6 +230,22 @@ mlkem_keygen_seed_test!(
MlKem1024
);

mlkem_decaps_test!(
mlkem_512_semi_expanded_decaps_test,
"mlkem_512_semi_expanded_decaps_test.json",
ml_kem_512
);
mlkem_decaps_test!(
mlkem_768_semi_expanded_decaps_test,
"mlkem_768_semi_expanded_decaps_test.json",
ml_kem_768
);
mlkem_decaps_test!(
mlkem_1024_semi_expanded_decaps_test,
"mlkem_1024_semi_expanded_decaps_test.json",
ml_kem_1024
);

mlkem_encaps_test!(
mlkem_512_encaps_test,
"mlkem_512_encaps_test.json",
Expand Down