From 64de82c5fb8023939aa27da1906d040c602f972c Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Thu, 29 Jan 2026 11:36:01 -0700 Subject: [PATCH] ml-kem: add Wycheproof `mlkem_*_encaps_test` Tests decoding of encapsulation keys, and that they are able to generate the correct ciphertext and shared secret via the `EncapsulationKey::encapsulate_deterministic` API (when the `hazmat` feature is enabled) --- ml-kem/Cargo.toml | 2 +- ml-kem/tests/wycheproof.rs | 103 +++++++++++++++++++++++++++++++++---- 2 files changed, 95 insertions(+), 10 deletions(-) diff --git a/ml-kem/Cargo.toml b/ml-kem/Cargo.toml index 99a5bd7..e56cccc 100644 --- a/ml-kem/Cargo.toml +++ b/ml-kem/Cargo.toml @@ -19,10 +19,10 @@ exclude = ["tests/key-gen.rs", "tests/key-gen.json", "tests/encap-decap.rs", "te alloc = ["pkcs8?/alloc"] getrandom = ["kem/getrandom"] +hazmat = [] pem = ["pkcs8/pem"] pkcs8 = ["dep:const-oid", "dep:pkcs8"] zeroize = ["module-lattice/zeroize", "dep:zeroize"] -hazmat = [] [dependencies] array = { package = "hybrid-array", version = "0.4.4", features = ["extra-sizes", "subtle"] } diff --git a/ml-kem/tests/wycheproof.rs b/ml-kem/tests/wycheproof.rs index 22ca823..5c69fdd 100644 --- a/ml-kem/tests/wycheproof.rs +++ b/ml-kem/tests/wycheproof.rs @@ -1,6 +1,9 @@ //! Test against the Wycheproof test vectors. -use ml_kem::{EncodedSizeUser, KemCore, MlKem512, MlKem768, MlKem1024, kem::KeyExport}; +use ml_kem::{ + EncodedSizeUser, KemCore, MlKem512, MlKem768, MlKem1024, + kem::{KeyExport, TryKeyInit}, +}; use serde::Deserialize; use std::fs::File; @@ -33,13 +36,18 @@ struct Source { struct Test { #[serde(rename(deserialize = "tcId"))] id: usize, - comment: String, - #[serde(with = "hex::serde")] - seed: Vec, + comment: Option, + seed: Option, #[serde(default, with = "hex::serde")] ek: Vec, - #[serde(with = "hex::serde")] - dk: Vec, + dk: Option, + #[cfg(feature = "hazmat")] + m: Option, + #[cfg(feature = "hazmat")] + c: Option, + #[cfg(feature = "hazmat")] + #[serde(default, rename(deserialize = "K"))] + k: Option, result: ExpectedResult, } @@ -65,6 +73,15 @@ macro_rules! load_json_file { }}; } +fn decode_optional_hex(opt: &Option, field: &str) -> Vec { + match opt { + Some(h) => { + hex::decode(h).unwrap_or_else(|e| panic!("invalid hex for field '{field}': {e}")) + } + None => panic!("missing field: {field}"), + } +} + macro_rules! mlkem_keygen_seed_test { ($name:ident, $json_file:expr, $kem:ident) => { #[test] @@ -78,10 +95,17 @@ macro_rules! mlkem_keygen_seed_test { ); for test in &group.tests { - println!("Test #{}: {} ({:?})", test.id, &test.comment, &test.result); + println!( + "Test #{}: {} ({:?})", + test.id, + test.comment.as_ref().unwrap(), + &test.result + ); + let test_seed = decode_optional_hex(&test.seed, "seed"); + let test_dk = decode_optional_hex(&test.dk, "dk"); - let (dk, ek) = $kem::from_seed(test.seed.as_slice().try_into().unwrap()); - assert_eq!(test.dk.as_slice(), dk.to_encoded_bytes().as_slice()); + let (dk, ek) = $kem::from_seed(test_seed.as_slice().try_into().unwrap()); + assert_eq!(test_dk.as_slice(), dk.to_encoded_bytes().as_slice()); assert_eq!(test.ek.as_slice(), ek.to_bytes().as_slice()); } } @@ -89,6 +113,51 @@ macro_rules! mlkem_keygen_seed_test { }; } +macro_rules! mlkem_encaps_test { + ($name:ident, $json_file:expr, $kem_module:ident) => { + #[test] + fn $name() { + let tests = load_json_file!($json_file); + + for group in tests.groups { + println!( + "Parameter set: {} ({} v{})\n", + &group.parameter_set, &group.source.name, &group.source.version + ); + + for test in &group.tests { + println!("Test #{} ({:?})", test.id, &test.result); + + use ml_kem::$kem_module::EncapsulationKey; + let ek_result = EncapsulationKey::new_from_slice(&test.ek); + + #[cfg_attr(not(feature = "hazmat"), allow(unused_variables))] + let ek = match test.result { + ExpectedResult::Valid => ek_result.expect("should be valid"), + ExpectedResult::Invalid => { + assert!(ek_result.is_err()); + continue; + } + other => todo!("{:?}", other), + }; + + #[cfg(feature = "hazmat")] + { + let test_m = decode_optional_hex(&test.m, "m"); + let test_m = test_m.as_slice().try_into().unwrap(); + let (c, k) = ek.encapsulate_deterministic(test_m); + + let test_c = decode_optional_hex(&test.c, "c"); + let test_k = decode_optional_hex(&test.k, "K"); + assert_eq!(test_c.as_slice(), c.as_slice()); + assert_eq!(test_k.as_slice(), k.as_slice()); + } + } + } + } + }; +} + mlkem_keygen_seed_test!( mlkem_512_keygen_seed_test, "mlkem_512_keygen_seed_test.json", @@ -104,3 +173,19 @@ mlkem_keygen_seed_test!( "mlkem_1024_keygen_seed_test.json", MlKem1024 ); + +mlkem_encaps_test!( + mlkem_512_encaps_test, + "mlkem_512_encaps_test.json", + ml_kem_512 +); +mlkem_encaps_test!( + mlkem_768_encaps_test, + "mlkem_768_encaps_test.json", + ml_kem_768 +); +mlkem_encaps_test!( + mlkem_1024_encaps_test, + "mlkem_1024_encaps_test.json", + ml_kem_1024 +);