Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions ml-kem/src/pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,50 @@ where
P::concat_ek(t_hat, self.rho.clone())
}

/// Parse an encryption key from a byte array `(t_hat || rho)`
// TODO(tarcieri): validate decoded keys
#[allow(clippy::unnecessary_wraps)]
/// Parse an encryption key from a byte array `(t_hat || rho)`.
///
/// # Errors
/// Returns [`Error`] in the event that the key fails the encapsulation key checks specified in
/// FIPS 203 §7.2.
pub fn from_bytes(enc: &EncodedEncryptionKey<P>) -> Result<Self, Error> {
let (t_hat, rho) = P::split_ek(enc);
let t_hat = P::decode_u12(t_hat);
Ok(Self {
let ret = Self {
t_hat,
rho: rho.clone(),
})
};

// Check the candidate encapsulation key is valid using the method specified in FIPS 203
// §7.2 ML-KEM Encapsulation:
//
// > Encapsulation key check. To check a candidate encapsulation key `ek`, perform the
// > following:
// >
// > 1. (Type check) If `ek` is not an array of bytes of length 384𝑘+32 for the value of 𝑘
// > specified by the relevant parameter set, then input checking failed.
// > 2. (Modulus check) Perform the computation:
// >
// > test ← ByteEncode₁₂(ByteDecode₁₂(ek[0:384𝑘]))
// >
// > (see Section 4.2.1). If `test ≠ ek[0∶384𝑘]`, then input checking failed. This
// > check ensures that the integers encoded in the public key are in the valid range
// > `[0,q-1]`.
// >
// > If both checks pass, then `ML-KEM.Encaps` can be run with input `ek`. It is important
// > to note that this checking process does not guarantee that ek is a properly produced
// > output of `ML-KEM.KeyGen`.
// >
// > `ML-KEM.Encaps` shall not be run with an encapsulation key that has not been checked as
// > above.
//
// #1 is performed by the `EncodedEncryptionKey` type, and the following check vicariously
// performs #2 by encoding the integer-mod-q array using our implementation of ByteEncode₁₂
// and comparing the resulting serialization to see if it round-trips.
if &ret.as_bytes() == enc {
Ok(ret)
} else {
Err(Error)
}
}
}

Expand Down Expand Up @@ -221,4 +255,12 @@ mod test {
codec_test::<MlKem768Params>();
codec_test::<MlKem1024Params>();
}

#[test]
fn reject_invalid_encryption_keys() {
// Create an invalid key: all bytes set to 0xFF
// When decoded as 12-bit coefficients, this produces values of 0xFFF = 4095 > 3329
let invalid_key = [0xFF; 1184];
assert!(EncryptionKey::<MlKem768Params>::from_bytes(&invalid_key.into()).is_err());
}
}