diff --git a/ml-kem/tests/wycheproof.rs b/ml-kem/tests/wycheproof.rs index 46c4cc9..1fe95bd 100644 --- a/ml-kem/tests/wycheproof.rs +++ b/ml-kem/tests/wycheproof.rs @@ -45,7 +45,6 @@ struct Test { #[cfg(feature = "hazmat")] m: Option, c: Option, - #[cfg(feature = "hazmat")] #[serde(default, rename(deserialize = "K"))] k: Option, result: ExpectedResult, @@ -85,6 +84,48 @@ fn decode_expected_hex(opt: &Option, field: &str) -> Array .unwrap_or_else(|| panic!("missing or incorrect length field: {field}")) } +macro_rules! mlkem_test { + ($name:ident, $json_file:expr, $kem:ident, $kem_module:ident) => { + #[test] + fn $name() { + let tests = load_json_file!($json_file); + + for group in tests.groups { + println!( + "Parameter set: {} ({} v{})\n", + &group.parameter_set, &group.source.name, &group.source.version + ); + + for test in &group.tests { + println!("Test #{}: {:?}", test.id, &test.result); + let test_seed = match decode_optional_hex(&test.seed, "seed") { + Some(seed) => seed, + None => { + assert_eq!(test.result, ExpectedResult::Invalid); + continue; + } + }; + + let (dk, ek) = $kem::from_seed(test_seed); + assert_eq!(test.ek.as_slice(), ek.to_bytes().as_slice()); + + use ml_kem::$kem_module::EncodedCiphertext; + let test_c: EncodedCiphertext = match decode_optional_hex(&test.c, "c") { + Some(dk) => dk, + None => { + assert_eq!(test.result, ExpectedResult::Invalid); + continue; + } + }; + let test_k = decode_expected_hex(&test.k, "K"); + let decrypted_k = dk.decapsulate(&test_c); + assert_eq!(test_k, decrypted_k); + } + } + } + }; +} + macro_rules! mlkem_keygen_seed_test { ($name:ident, $json_file:expr, $kem:ident) => { #[test] @@ -199,11 +240,8 @@ macro_rules! mlkem_decaps_test { let test_c: EncodedCiphertext = match decode_optional_hex(&test.c, "c") { Some(dk) => dk, None => { - if test.result == ExpectedResult::Invalid { - continue; - } else { - panic!("failed to decode ciphertext!") - } + assert_eq!(test.result, ExpectedResult::Invalid); + continue; } }; @@ -214,6 +252,15 @@ macro_rules! mlkem_decaps_test { }; } +mlkem_test!(mlkem_512_test, "mlkem_512_test.json", MlKem512, ml_kem_512); +mlkem_test!(mlkem_768_test, "mlkem_768_test.json", MlKem768, ml_kem_768); +mlkem_test!( + mlkem_1024_test, + "mlkem_1024_test.json", + MlKem1024, + ml_kem_1024 +); + mlkem_keygen_seed_test!( mlkem_512_keygen_seed_test, "mlkem_512_keygen_seed_test.json",