diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs
index e0b8da4..c9421ad 100644
--- a/ml-kem/src/pke.rs
+++ b/ml-kem/src/pke.rs
@@ -155,16 +155,50 @@ where
P::concat_ek(t_hat, self.rho.clone())
}
- /// Parse an encryption key from a byte array `(t_hat || rho)`
- // TODO(tarcieri): validate decoded keys
- #[allow(clippy::unnecessary_wraps)]
+ /// Parse an encryption key from a byte array `(t_hat || rho)`.
+ ///
+ /// # Errors
+ /// Returns [`Error`] in the event that the key fails the encapsulation key checks specified in
+ /// FIPS 203 §7.2.
pub fn from_bytes(enc: &EncodedEncryptionKey
) -> Result {
let (t_hat, rho) = P::split_ek(enc);
let t_hat = P::decode_u12(t_hat);
- Ok(Self {
+ let ret = Self {
t_hat,
rho: rho.clone(),
- })
+ };
+
+ // Check the candidate encapsulation key is valid using the method specified in FIPS 203
+ // §7.2 ML-KEM Encapsulation:
+ //
+ // > Encapsulation key check. To check a candidate encapsulation key `ek`, perform the
+ // > following:
+ // >
+ // > 1. (Type check) If `ek` is not an array of bytes of length 384𝑘+32 for the value of 𝑘
+ // > specified by the relevant parameter set, then input checking failed.
+ // > 2. (Modulus check) Perform the computation:
+ // >
+ // > test ← ByteEncode₁₂(ByteDecode₁₂(ek[0:384𝑘]))
+ // >
+ // > (see Section 4.2.1). If `test ≠ ek[0∶384𝑘]`, then input checking failed. This
+ // > check ensures that the integers encoded in the public key are in the valid range
+ // > `[0,q-1]`.
+ // >
+ // > If both checks pass, then `ML-KEM.Encaps` can be run with input `ek`. It is important
+ // > to note that this checking process does not guarantee that ek is a properly produced
+ // > output of `ML-KEM.KeyGen`.
+ // >
+ // > `ML-KEM.Encaps` shall not be run with an encapsulation key that has not been checked as
+ // > above.
+ //
+ // #1 is performed by the `EncodedEncryptionKey` type, and the following check vicariously
+ // performs #2 by encoding the integer-mod-q array using our implementation of ByteEncode₁₂
+ // and comparing the resulting serialization to see if it round-trips.
+ if &ret.as_bytes() == enc {
+ Ok(ret)
+ } else {
+ Err(Error)
+ }
}
}
@@ -221,4 +255,12 @@ mod test {
codec_test::();
codec_test::();
}
+
+ #[test]
+ fn reject_invalid_encryption_keys() {
+ // Create an invalid key: all bytes set to 0xFF
+ // When decoded as 12-bit coefficients, this produces values of 0xFFF = 4095 > 3329
+ let invalid_key = [0xFF; 1184];
+ assert!(EncryptionKey::::from_bytes(&invalid_key.into()).is_err());
+ }
}