From fd282d057fb9606a438972adc1db1f7072ad4270 Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Thu, 29 Jan 2026 12:44:29 -0700 Subject: [PATCH] ml-kem: add Wycheproof `mlkem_*_decaps_test` These tests mostly cover length handling which isn't particularly helpful because it's something the caller (or a trait impl) has to do. That said, it includes some nice changes to the test machinery, and we're now set up for further future decapsulation tests. --- ml-kem/src/lib.rs | 28 ++++++++- ml-kem/tests/wycheproof.rs | 114 ++++++++++++++++++++++++++++++------- 2 files changed, 120 insertions(+), 22 deletions(-) diff --git a/ml-kem/src/lib.rs b/ml-kem/src/lib.rs index 9d92b51..15298dd 100644 --- a/ml-kem/src/lib.rs +++ b/ml-kem/src/lib.rs @@ -98,6 +98,7 @@ pub type Seed = Array; /// cipher with a 128-bit key. pub mod ml_kem_512 { use super::{Debug, ParameterSet, U2, U3, U4, U10, kem}; + use crate::param; /// `MlKem512` is the parameter set for security category 1, corresponding to key search on a /// block cipher with a 128-bit key. @@ -119,12 +120,21 @@ pub mod ml_kem_512 { /// An ML-KEM-512 `EncapsulationKey` provides the ability to encapsulate a shared key so that it /// can only be decapsulated by the holder of the corresponding decapsulation key. pub type EncapsulationKey = kem::EncapsulationKey; + + /// Encoded ML-KEM-512 ciphertexts. + pub type EncodedCiphertext = param::EncodedCiphertext; + + /// Legacy expanded decapsulation keys. Prefer seeds instead. + #[doc(hidden)] + #[deprecated(since = "0.3.0", note = "use `Seed` instead")] + pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey; } /// ML-KEM-768 is the parameter set for security category 3, corresponding to key search on a block /// cipher with a 192-bit key. pub mod ml_kem_768 { use super::{Debug, ParameterSet, U2, U3, U4, U10, kem}; + use crate::param; /// `MlKem768` is the parameter set for security category 3, corresponding to key search on a /// block cipher with a 192-bit key. @@ -146,12 +156,20 @@ pub mod ml_kem_768 { /// An ML-KEM-768 `EncapsulationKey` provides the ability to encapsulate a shared key so that it /// can only be decapsulated by the holder of the corresponding decapsulation key. pub type EncapsulationKey = kem::EncapsulationKey; + + /// Encoded ML-KEM-512 ciphertexts. + pub type EncodedCiphertext = param::EncodedCiphertext; + + /// Legacy expanded decapsulation keys. Prefer seeds instead. + #[doc(hidden)] + #[deprecated(since = "0.3.0", note = "use `Seed` instead")] + pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey; } /// ML-KEM-1024 is the parameter set for security category 5, corresponding to key search on a block /// cipher with a 256-bit key. pub mod ml_kem_1024 { - use super::{Debug, ParameterSet, U2, U4, U5, U11, kem}; + use super::{Debug, ParameterSet, U2, U4, U5, U11, kem, param}; /// `MlKem1024` is the parameter set for security category 5, corresponding to key search on a /// block cipher with a 256-bit key. @@ -173,6 +191,14 @@ pub mod ml_kem_1024 { /// An ML-KEM-1024 `EncapsulationKey` provides the ability to encapsulate a shared key so that /// it can only be decapsulated by the holder of the corresponding decapsulation key. pub type EncapsulationKey = kem::EncapsulationKey; + + /// Encoded ML-KEM-512 ciphertexts. + pub type EncodedCiphertext = param::EncodedCiphertext; + + /// Legacy expanded decapsulation keys. Prefer seeds instead. + #[doc(hidden)] + #[deprecated(since = "0.3.0", note = "use `Seed` instead")] + pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey; } /// An ML-KEM-512 `DecapsulationKey` which provides the ability to generate a new key pair, and diff --git a/ml-kem/tests/wycheproof.rs b/ml-kem/tests/wycheproof.rs index 5c69fdd..46c4cc9 100644 --- a/ml-kem/tests/wycheproof.rs +++ b/ml-kem/tests/wycheproof.rs @@ -1,8 +1,9 @@ //! Test against the Wycheproof test vectors. +use array::{Array, ArraySize}; use ml_kem::{ EncodedSizeUser, KemCore, MlKem512, MlKem768, MlKem1024, - kem::{KeyExport, TryKeyInit}, + kem::{Decapsulate, KeyExport, TryKeyInit}, }; use serde::Deserialize; use std::fs::File; @@ -43,7 +44,6 @@ struct Test { dk: Option, #[cfg(feature = "hazmat")] m: Option, - #[cfg(feature = "hazmat")] c: Option, #[cfg(feature = "hazmat")] #[serde(default, rename(deserialize = "K"))] @@ -73,13 +73,16 @@ macro_rules! load_json_file { }}; } -fn decode_optional_hex(opt: &Option, field: &str) -> Vec { - match opt { - Some(h) => { - hex::decode(h).unwrap_or_else(|e| panic!("invalid hex for field '{field}': {e}")) - } - None => panic!("missing field: {field}"), - } +fn decode_optional_hex(opt: &Option, field: &str) -> Option> { + opt.as_ref().and_then(|h| { + let vec = hex::decode(h).unwrap_or_else(|e| panic!("invalid hex for field '{field}': {e}")); + vec.as_slice().try_into().ok() + }) +} + +fn decode_expected_hex(opt: &Option, field: &str) -> Array { + decode_optional_hex(opt, field) + .unwrap_or_else(|| panic!("missing or incorrect length field: {field}")) } macro_rules! mlkem_keygen_seed_test { @@ -101,11 +104,11 @@ macro_rules! mlkem_keygen_seed_test { test.comment.as_ref().unwrap(), &test.result ); - let test_seed = decode_optional_hex(&test.seed, "seed"); - let test_dk = decode_optional_hex(&test.dk, "dk"); + let test_seed = decode_expected_hex(&test.seed, "seed"); + let test_dk = decode_expected_hex(&test.dk, "dk"); - let (dk, ek) = $kem::from_seed(test_seed.as_slice().try_into().unwrap()); - assert_eq!(test_dk.as_slice(), dk.to_encoded_bytes().as_slice()); + let (dk, ek) = $kem::from_seed(test_seed); + assert_eq!(test_dk, dk.to_encoded_bytes()); assert_eq!(test.ek.as_slice(), ek.to_bytes().as_slice()); } } @@ -143,14 +146,13 @@ macro_rules! mlkem_encaps_test { #[cfg(feature = "hazmat")] { - let test_m = decode_optional_hex(&test.m, "m"); - let test_m = test_m.as_slice().try_into().unwrap(); - let (c, k) = ek.encapsulate_deterministic(test_m); - - let test_c = decode_optional_hex(&test.c, "c"); - let test_k = decode_optional_hex(&test.k, "K"); - assert_eq!(test_c.as_slice(), c.as_slice()); - assert_eq!(test_k.as_slice(), k.as_slice()); + let test_m = decode_expected_hex(&test.m, "m"); + let (c, k) = ek.encapsulate_deterministic(&test_m); + + let test_c = decode_expected_hex(&test.c, "c"); + let test_k = decode_expected_hex(&test.k, "K"); + assert_eq!(test_c, c); + assert_eq!(test_k, k); } } } @@ -158,6 +160,60 @@ macro_rules! mlkem_encaps_test { }; } +macro_rules! mlkem_decaps_test { + ($name:ident, $json_file:expr, $kem_module:ident) => { + #[test] + fn $name() { + let tests = load_json_file!($json_file); + + for group in tests.groups { + println!( + "Parameter set: {} ({} v{})\n", + &group.parameter_set, &group.source.name, &group.source.version + ); + + for test in &group.tests { + println!("Test #{} ({:?})", test.id, &test.result); + + #[allow(deprecated)] + use ml_kem::$kem_module::{ + DecapsulationKey, EncodedCiphertext, ExpandedDecapsulationKey, + }; + + #[allow(deprecated)] + let test_dk: ExpandedDecapsulationKey = + match decode_optional_hex(&test.dk, "dk") { + Some(dk) => dk, + None => { + if test.result == ExpectedResult::Invalid { + continue; + } else { + panic!("failed to decode expanded decapsulation key!") + } + } + }; + + #[allow(deprecated)] + let dk = DecapsulationKey::from_expanded(&test_dk).expect("should be valid"); + + let test_c: EncodedCiphertext = match decode_optional_hex(&test.c, "c") { + Some(dk) => dk, + None => { + if test.result == ExpectedResult::Invalid { + continue; + } else { + panic!("failed to decode ciphertext!") + } + } + }; + + let _ss = dk.decapsulate(&test_c); + } + } + } + }; +} + mlkem_keygen_seed_test!( mlkem_512_keygen_seed_test, "mlkem_512_keygen_seed_test.json", @@ -174,6 +230,22 @@ mlkem_keygen_seed_test!( MlKem1024 ); +mlkem_decaps_test!( + mlkem_512_semi_expanded_decaps_test, + "mlkem_512_semi_expanded_decaps_test.json", + ml_kem_512 +); +mlkem_decaps_test!( + mlkem_768_semi_expanded_decaps_test, + "mlkem_768_semi_expanded_decaps_test.json", + ml_kem_768 +); +mlkem_decaps_test!( + mlkem_1024_semi_expanded_decaps_test, + "mlkem_1024_semi_expanded_decaps_test.json", + ml_kem_1024 +); + mlkem_encaps_test!( mlkem_512_encaps_test, "mlkem_512_encaps_test.json",