From bf371853239d749a6fb999a9a6ef0905bbbfd43a Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Thu, 22 Jan 2026 17:52:11 -0700 Subject: [PATCH] Add `kem::{TryKeyInit + KeyExport}` impls for encapsulation keys Also adds a `KeySizeUser` impl. This is a companion PR to RustCrypto/traits#2215 which added a bound to the `Encapsulate` trait for `TryKeyInit + KeyExport` from `crypto-common`. These traits both have a supertrait bound on `KeySizeUser`, which defines an `ArraySize` for a fixed-size key. The other two provide common traits for fallible decoding and encoding respectively, where the former uses the common `InvalidKey` type also defined in the `crypto-common` crate. This was one big missing gap for generic KEM use. Some traits we need aren't being re-exported from `kem` and it doesn't do a re-export of `crypto-common` so this has a few TODOs to follow up on that. We need to get this landed first though. --- Cargo.lock | 16 ++++---- Cargo.toml | 3 ++ dhkem/Cargo.toml | 3 ++ dhkem/src/ecdh_kem.rs | 75 ++++++++++++++++++++++++++++++++++++- dhkem/src/x25519_kem.rs | 46 +++++++++++++++++++++-- ml-kem/Cargo.toml | 3 ++ ml-kem/benches/mlkem.rs | 12 +++--- ml-kem/src/kem.rs | 63 +++++++++++++++++++++++-------- ml-kem/src/lib.rs | 6 ++- ml-kem/src/pkcs8.rs | 2 +- ml-kem/src/traits.rs | 4 +- ml-kem/tests/encap-decap.rs | 4 +- ml-kem/tests/key-gen.rs | 14 +++++-- ml-kem/tests/pkcs8.rs | 2 +- x-wing/src/lib.rs | 52 ++++++++++++++----------- 15 files changed, 237 insertions(+), 68 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d5f8fce..5e2d276 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -295,9 +295,9 @@ dependencies = [ [[package]] name = "crypto-common" -version = "0.2.0-rc.11" +version = "0.2.0-rc.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d2bcc93d5cde6659e8649fc412894417ebc14dee54cfc6ee439c683a4a58342" +checksum = "a6dcdb44f2c3ee25689ca12a4c19e664fd09f97aeae0bc5043b2dbab6389e308" dependencies = [ "getrandom", "hybrid-array", @@ -355,6 +355,7 @@ dependencies = [ name = "dhkem" version = "0.1.0-pre.0" dependencies = [ + "crypto-common", "elliptic-curve", "getrandom", "hex-literal", @@ -608,9 +609,9 @@ dependencies = [ [[package]] name = "hybrid-array" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f471e0a81b2f90ffc0cb2f951ae04da57de8baa46fa99112b062a5173a5088d0" +checksum = "b41fb3dc24fe72c2e3a4685eed55917c2fb228851257f4a8f2d985da9443c3e5" dependencies = [ "subtle", "typenum", @@ -683,8 +684,7 @@ dependencies = [ [[package]] name = "kem" version = "0.4.0-rc.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e364ee5b2ff87ea53e759732e09cea035602f5c35449088f2d7d05591680272" +source = "git+https://github.com/RustCrypto/traits#4e195b4fc5062a9ccd0080070d2b05202b2b3108" dependencies = [ "crypto-common", "rand_core", @@ -717,6 +717,7 @@ version = "0.3.0-pre.3" dependencies = [ "const-oid", "criterion", + "crypto-common", "getrandom", "hex", "hex-literal", @@ -1114,8 +1115,7 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sec1" version = "0.8.0-rc.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b54617aeb7e34ace1a4b72ba79bb6297e48285dc0cce064dc063ddcbf538996" +source = "git+https://github.com/RustCrypto/formats#eb63b10272698f9055b6197c09076dd69acfdc59" dependencies = [ "base16ct", "ctutils", diff --git a/Cargo.toml b/Cargo.toml index c01e6f8..a403574 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,6 @@ debug = true [patch.crates-io] ml-kem = { path = "./ml-kem" } + +kem = { git = "https://github.com/RustCrypto/traits" } +sec1 = { git = "https://github.com/RustCrypto/formats" } diff --git a/dhkem/Cargo.toml b/dhkem/Cargo.toml index cf36c8e..cfc4516 100644 --- a/dhkem/Cargo.toml +++ b/dhkem/Cargo.toml @@ -17,6 +17,9 @@ readme = "README.md" kem = "0.4.0-rc.4" rand_core = "0.10.0-rc-5" +# TODO(tarcieri): remove this and get these from `kem` +common = { package = "crypto-common", version = "0.2.0-rc.12" } + # optional dependencies elliptic-curve = { version = "0.14.0-rc.23", optional = true, default-features = false } k256 = { version = "0.14.0-rc.5", optional = true, default-features = false, features = ["arithmetic"] } diff --git a/dhkem/src/ecdh_kem.rs b/dhkem/src/ecdh_kem.rs index 415fd6e..106faf6 100644 --- a/dhkem/src/ecdh_kem.rs +++ b/dhkem/src/ecdh_kem.rs @@ -3,10 +3,18 @@ use crate::{DhDecapsulator, DhEncapsulator, DhKem}; use core::{convert::Infallible, marker::PhantomData}; use elliptic_curve::{ - CurveArithmetic, Generate, PublicKey, + AffinePoint, + CurveArithmetic, + FieldBytesSize, + Generate, + PublicKey, + common::InvalidKey, // TODO(tarcieri): get from `kem` crate ecdh::{EphemeralSecret, SharedSecret}, + sec1::{ + FromEncodedPoint, ModulusSize, ToEncodedPoint, UncompressedPoint, UncompressedPointSize, + }, }; -use kem::{Decapsulate, Encapsulate}; +use kem::{Decapsulate, Encapsulate, KeyExport, KeySizeUser, TryKeyInit}; use rand_core::{CryptoRng, TryCryptoRng}; /// Generic Elliptic Curve Diffie-Hellman KEM adapter compatible with curves implemented using @@ -15,9 +23,68 @@ use rand_core::{CryptoRng, TryCryptoRng}; /// Implements a KEM interface that internally uses ECDH. pub struct EcdhKem(PhantomData); +/// From [RFC9810 §7.1.1]: `SerializePublicKey` and `DeserializePublicKey`: +/// +/// > For P-256, P-384, and P-521, the SerializePublicKey() function of the +/// > KEM performs the uncompressed Elliptic-Curve-Point-to-Octet-String +/// > conversion according to [SECG]. +/// +/// [RFC9810 §7.1.1]: https://datatracker.ietf.org/doc/html/rfc9180#name-serializepublickey-and-dese +/// [SECG]: https://www.secg.org/sec1-v2.pdf +impl KeySizeUser for DhEncapsulator> +where + C: CurveArithmetic, + FieldBytesSize: ModulusSize, +{ + type KeySize = UncompressedPointSize; +} + +/// From [RFC9810 §7.1.1]: `SerializePublicKey` and `DeserializePublicKey`: +/// +/// > DeserializePublicKey() performs the uncompressed +/// > Octet-String-to-Elliptic-Curve-Point conversion. +/// +/// [RFC9810 §7.1.1]: https://datatracker.ietf.org/doc/html/rfc9180#name-serializepublickey-and-dese +impl TryKeyInit for DhEncapsulator> +where + C: CurveArithmetic, + FieldBytesSize: ModulusSize, + AffinePoint: FromEncodedPoint + ToEncodedPoint, +{ + fn new(encapsulation_key: &UncompressedPoint) -> Result { + PublicKey::::from_sec1_bytes(encapsulation_key) + .map(Into::into) + .map_err(|_| InvalidKey) + } +} + +/// From [RFC9810 §7.1.1]: `SerializePublicKey` and `DeserializePublicKey`: +/// +/// > For P-256, P-384, and P-521, the SerializePublicKey() function of the +/// > KEM performs the uncompressed Elliptic-Curve-Point-to-Octet-String +/// > conversion according to [SECG]. +/// +/// [RFC9810 §7.1.1]: https://datatracker.ietf.org/doc/html/rfc9180#name-serializepublickey-and-dese +/// [SECG]: https://www.secg.org/sec1-v2.pdf +impl KeyExport for DhEncapsulator> +where + C: CurveArithmetic, + FieldBytesSize: ModulusSize, + AffinePoint: FromEncodedPoint + ToEncodedPoint, +{ + fn to_bytes(&self) -> UncompressedPoint { + // TODO(tarcieri): use `ToEncodedPoint::to_uncompressed_point` (RustCrypto/traits#2221) + let mut ret = UncompressedPoint::::default(); + ret.copy_from_slice(self.0.to_encoded_point(false).as_bytes()); + ret + } +} + impl Encapsulate, SharedSecret> for DhEncapsulator> where C: CurveArithmetic, + FieldBytesSize: ModulusSize, + AffinePoint: FromEncodedPoint + ToEncodedPoint, { type Error = Infallible; @@ -38,6 +105,8 @@ where impl Decapsulate, SharedSecret> for DhDecapsulator> where C: CurveArithmetic, + FieldBytesSize: ModulusSize, + AffinePoint: FromEncodedPoint + ToEncodedPoint, { type Encapsulator = DhEncapsulator>; type Error = Infallible; @@ -56,6 +125,8 @@ where impl DhKem for EcdhKem where C: CurveArithmetic, + FieldBytesSize: ModulusSize, + AffinePoint: FromEncodedPoint + ToEncodedPoint, { type DecapsulatingKey = DhDecapsulator>; type EncapsulatingKey = DhEncapsulator>; diff --git a/dhkem/src/x25519_kem.rs b/dhkem/src/x25519_kem.rs index 0f53d70..e7982d9 100644 --- a/dhkem/src/x25519_kem.rs +++ b/dhkem/src/x25519_kem.rs @@ -1,14 +1,54 @@ use crate::{DhDecapsulator, DhEncapsulator, DhKem}; use core::convert::Infallible; -use kem::{Decapsulate, Encapsulate}; +use kem::{Decapsulate, Encapsulate, KeyExport, KeySizeUser, TryKeyInit, consts::U32}; use rand_core::{CryptoRng, TryCryptoRng, UnwrapErr}; use x25519::{PublicKey, ReusableSecret, SharedSecret}; +// TODO(tarcieri): get these from `kem` +use common::{InvalidKey, Key}; + /// X22519 Diffie-Hellman KEM adapter. /// /// Implements a KEM interface that internally uses X25519 ECDH. pub struct X25519Kem; +/// From [RFC9810 §7.1.1]: `SerializePublicKey` and `DeserializePublicKey`: +/// +/// > For X25519 and X448, the SerializePublicKey() and +/// > DeserializePublicKey() functions are the identity function, since +/// > these curves already use fixed-length byte strings for public keys. +/// +/// [RFC9810 §7.1.1]: https://datatracker.ietf.org/doc/html/rfc9180#name-serializepublickey-and-dese +impl KeySizeUser for DhEncapsulator { + type KeySize = U32; +} + +/// From [RFC9810 §7.1.1]: `SerializePublicKey` and `DeserializePublicKey`: +/// +/// > For X25519 and X448, the SerializePublicKey() and +/// > DeserializePublicKey() functions are the identity function, since +/// > these curves already use fixed-length byte strings for public keys. +/// +/// [RFC9810 §7.1.1]: https://datatracker.ietf.org/doc/html/rfc9180#name-serializepublickey-and-dese +impl TryKeyInit for DhEncapsulator { + fn new(encapsulation_key: &Key) -> Result { + Ok(Self(PublicKey::from(encapsulation_key.0))) + } +} + +/// From [RFC9810 §7.1.1]: `SerializePublicKey` and `DeserializePublicKey`: +/// +/// > For X25519 and X448, the SerializePublicKey() and +/// > DeserializePublicKey() functions are the identity function, since +/// > these curves already use fixed-length byte strings for public keys. +/// +/// [RFC9810 §7.1.1]: https://datatracker.ietf.org/doc/html/rfc9180#name-serializepublickey-and-dese +impl KeyExport for DhEncapsulator { + fn to_bytes(&self) -> Key { + self.0.to_bytes().into() + } +} + impl Encapsulate for DhEncapsulator { type Error = Infallible; @@ -30,9 +70,7 @@ impl Decapsulate for DhDecapsulator { type Error = Infallible; fn decapsulate(&self, encapsulated_key: &PublicKey) -> Result { - let ss = self.0.diffie_hellman(encapsulated_key); - - Ok(ss) + Ok(self.0.diffie_hellman(encapsulated_key)) } fn encapsulator(&self) -> DhEncapsulator { diff --git a/ml-kem/Cargo.toml b/ml-kem/Cargo.toml index d12d5d7..d8e2c73 100644 --- a/ml-kem/Cargo.toml +++ b/ml-kem/Cargo.toml @@ -31,6 +31,9 @@ rand_core = "0.10.0-rc-5" sha3 = { version = "0.11.0-rc.3", default-features = false } subtle = { version = "2", default-features = false } +# TODO(tarcieri): remove this and get these from `kem` +common = { package = "crypto-common", version = "0.2.0-rc.12" } + # optional dependencies const-oid = { version = "0.10.1", optional = true, default-features = false, features = ["db"] } pkcs8 = { version = "0.11.0-rc.8", optional = true, default-features = false } diff --git a/ml-kem/benches/mlkem.rs b/ml-kem/benches/mlkem.rs index 98d2ee8..177413f 100644 --- a/ml-kem/benches/mlkem.rs +++ b/ml-kem/benches/mlkem.rs @@ -11,15 +11,15 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("keygen", |b| { b.iter(|| { let dk = ml_kem_768::DecapsulationKey::generate_from_rng(&mut rng); - let _dk_bytes = dk.to_bytes(); - let _ek_bytes = dk.encapsulator().to_bytes(); + let _dk_bytes = dk.to_encoded_bytes(); + let _ek_bytes = dk.encapsulator().to_encoded_bytes(); }) }); let dk = ml_kem_768::DecapsulationKey::generate_from_rng(&mut rng); - let dk_bytes = dk.to_bytes(); - let ek_bytes = dk.encapsulator().to_bytes(); - let ek = ml_kem_768::EncapsulationKey::from_bytes(&ek_bytes).unwrap(); + let dk_bytes = dk.to_encoded_bytes(); + let ek_bytes = dk.encapsulator().to_encoded_bytes(); + let ek = ml_kem_768::EncapsulationKey::from_encoded_bytes(&ek_bytes).unwrap(); // Encapsulation c.bench_function("encapsulate", |b| { @@ -28,7 +28,7 @@ fn criterion_benchmark(c: &mut Criterion) { let (ct, _ss) = ek.encapsulate_with_rng(&mut rng).unwrap(); // Decapsulation - let dk = ::DecapsulationKey::from_bytes(&dk_bytes).unwrap(); + let dk = ::DecapsulationKey::from_encoded_bytes(&dk_bytes).unwrap(); c.bench_function("decapsulate", |b| { b.iter(|| { diff --git a/ml-kem/src/kem.rs b/ml-kem/src/kem.rs index 58e1001..c2b520e 100644 --- a/ml-kem/src/kem.rs +++ b/ml-kem/src/kem.rs @@ -1,7 +1,7 @@ -use core::{convert::Infallible, marker::PhantomData}; -use hybrid_array::typenum::{U32, U64}; -use rand_core::{CryptoRng, TryCryptoRng, TryRngCore}; -use subtle::{ConditionallySelectable, ConstantTimeEq}; +//! + +// Re-export traits from the `kem` crate +pub use ::kem::{Decapsulate, Encapsulate, Generate, KeyExport, KeySizeUser, TryKeyInit}; use crate::{ Encoded, EncodedSizeUser, Error, Seed, @@ -13,13 +13,17 @@ use crate::{ pke::{DecryptionKey, EncryptionKey}, util::B32, }; +use core::{convert::Infallible, marker::PhantomData}; +use hybrid_array::typenum::{U32, U64}; +use rand_core::{CryptoRng, TryCryptoRng, TryRngCore}; +use subtle::{ConditionallySelectable, ConstantTimeEq}; + +// TODO(tarcieri): get these from `kem` +use common::{InvalidKey, Key, KeyInit}; #[cfg(feature = "zeroize")] use zeroize::{Zeroize, ZeroizeOnDrop}; -// Re-export traits from the `kem` crate -pub use ::kem::{Decapsulate, Encapsulate, Generate, KeyInit, KeySizeUser}; - /// A shared key resulting from an ML-KEM transaction pub(crate) type SharedKey = B32; @@ -76,14 +80,14 @@ where { type EncodedSize = DecapsulationKeySize

; - fn from_bytes(expanded: &Encoded) -> Result { + fn from_encoded_bytes(expanded: &Encoded) -> Result { #[allow(deprecated)] Self::from_expanded(expanded) } - fn to_bytes(&self) -> Encoded { + fn to_encoded_bytes(&self) -> Encoded { let dk_pke = self.dk_pke.to_bytes(); - let ek = self.ek.to_bytes(); + let ek = self.ek.to_encoded_bytes(); P::concat_dk(dk_pke, ek, self.ek.h.clone(), self.z.clone()) } } @@ -256,11 +260,38 @@ where { type EncodedSize = EncapsulationKeySize

; - fn from_bytes(enc: &Encoded) -> Result { + fn from_encoded_bytes(enc: &Encoded) -> Result { Ok(Self::new(EncryptionKey::from_bytes(enc)?)) } - fn to_bytes(&self) -> Encoded { + fn to_encoded_bytes(&self) -> Encoded { + self.ek_pke.to_bytes() + } +} + +impl

KeySizeUser for EncapsulationKey

+where + P: KemParams, +{ + type KeySize = EncapsulationKeySize

; +} + +impl

TryKeyInit for EncapsulationKey

+where + P: KemParams, +{ + fn new(encapsulation_key: &Key) -> Result { + EncryptionKey::from_bytes(encapsulation_key) + .map(Self::new) + .map_err(|_| InvalidKey) + } +} + +impl

KeyExport for EncapsulationKey

+where + P: KemParams, +{ + fn to_bytes(&self) -> Key { self.ek_pke.to_bytes() } } @@ -367,12 +398,12 @@ mod test { let dk_original = DecapsulationKey::

::generate_from_rng(&mut rng); let ek_original = dk_original.encapsulation_key().clone(); - let dk_encoded = dk_original.to_bytes(); - let dk_decoded = DecapsulationKey::from_bytes(&dk_encoded).unwrap(); + let dk_encoded = dk_original.to_encoded_bytes(); + let dk_decoded = DecapsulationKey::from_encoded_bytes(&dk_encoded).unwrap(); assert_eq!(dk_original, dk_decoded); - let ek_encoded = ek_original.to_bytes(); - let ek_decoded = EncapsulationKey::from_bytes(&ek_encoded).unwrap(); + let ek_encoded = ek_original.to_encoded_bytes(); + let ek_decoded = EncapsulationKey::from_encoded_bytes(&ek_encoded).unwrap(); assert_eq!(ek_original, ek_decoded); } diff --git a/ml-kem/src/lib.rs b/ml-kem/src/lib.rs index da4f1eb..576f19f 100644 --- a/ml-kem/src/lib.rs +++ b/ml-kem/src/lib.rs @@ -26,8 +26,9 @@ //! //! use ml_kem::{ //! ml_kem_768::DecapsulationKey, -//! kem::{Decapsulate, Encapsulate, Generate, KeyInit} +//! kem::{Decapsulate, Encapsulate, Generate} //! }; +//! use common::KeyInit; // TODO(tarcieri): fix this! //! //! // Generate a decapsulation/encapsulation keypair //! let dk = DecapsulationKey::generate(); @@ -95,6 +96,9 @@ pub use ml_kem_1024::MlKem1024Params; pub use param::{ArraySize, ExpandedDecapsulationKey, ParameterSet}; pub use traits::*; +// TODO(tarcieri): get rid of this! +pub use common; + /// ML-KEM seeds are decapsulation (private) keys, which are consistently 64-bytes across all /// security levels, and are the preferred serialization for representing such keys. pub type Seed = Array; diff --git a/ml-kem/src/pkcs8.rs b/ml-kem/src/pkcs8.rs index 1c6c440..26054da 100644 --- a/ml-kem/src/pkcs8.rs +++ b/ml-kem/src/pkcs8.rs @@ -101,7 +101,7 @@ where /// Serialize the given `EncapsulationKey` into DER format. /// Returns a `Document` which wraps the DER document in case of success. fn to_public_key_der(&self) -> spki::Result { - let public_key = self.to_bytes(); + let public_key = self.to_encoded_bytes(); let subject_public_key = BitStringRef::new(0, &public_key)?; ::pkcs8::SubjectPublicKeyInfo { diff --git a/ml-kem/src/traits.rs b/ml-kem/src/traits.rs index 69e8a95..3f95623 100644 --- a/ml-kem/src/traits.rs +++ b/ml-kem/src/traits.rs @@ -18,10 +18,10 @@ pub trait EncodedSizeUser: Sized { /// /// # Errors /// - If the object failed to decode successfully - fn from_bytes(enc: &Encoded) -> Result; + fn from_encoded_bytes(enc: &Encoded) -> Result; /// Serialize an object to its encoded form - fn to_bytes(&self) -> Encoded; + fn to_encoded_bytes(&self) -> Encoded; } /// A byte array encoding a value the indicated size diff --git a/ml-kem/tests/encap-decap.rs b/ml-kem/tests/encap-decap.rs index ab38a22..6b7101a 100644 --- a/ml-kem/tests/encap-decap.rs +++ b/ml-kem/tests/encap-decap.rs @@ -42,7 +42,7 @@ where { let m = Array::try_from(tc.m.as_slice()).unwrap(); let ek_bytes = Encoded::::try_from(tc.ek.as_slice()).unwrap(); - let ek = K::EncapsulationKey::from_bytes(&ek_bytes).unwrap(); + let ek = K::EncapsulationKey::from_encoded_bytes(&ek_bytes).unwrap(); let (c, k) = ek.encapsulate_deterministic(&m).unwrap(); @@ -62,7 +62,7 @@ fn verify_decap_group(tg: &acvp::DecapTestGroup) { fn verify_decap(tc: &acvp::DecapTestCase, dk_slice: &[u8]) { let dk_bytes = Encoded::::try_from(dk_slice).unwrap(); - let dk = K::DecapsulationKey::from_bytes(&dk_bytes).unwrap(); + let dk = K::DecapsulationKey::from_encoded_bytes(&dk_bytes).unwrap(); let c = Ciphertext::::try_from(tc.c.as_slice()).unwrap(); let k = dk.decapsulate(&c).unwrap(); diff --git a/ml-kem/tests/key-gen.rs b/ml-kem/tests/key-gen.rs index 2e83580..73b5140 100644 --- a/ml-kem/tests/key-gen.rs +++ b/ml-kem/tests/key-gen.rs @@ -37,12 +37,18 @@ fn verify(tc: &acvp::TestCase) { let (dk, ek) = K::from_seed(d.concat(z)); // Verify correctness via serialization - assert_eq!(dk.to_bytes().as_slice(), tc.dk.as_slice()); - assert_eq!(ek.to_bytes().as_slice(), tc.ek.as_slice()); + assert_eq!(dk.to_encoded_bytes().as_slice(), tc.dk.as_slice()); + assert_eq!(ek.to_encoded_bytes().as_slice(), tc.ek.as_slice()); // Verify correctness via deserialization - assert_eq!(dk, K::DecapsulationKey::from_bytes(&dk_bytes).unwrap()); - assert_eq!(ek, K::EncapsulationKey::from_bytes(&ek_bytes).unwrap()); + assert_eq!( + dk, + K::DecapsulationKey::from_encoded_bytes(&dk_bytes).unwrap() + ); + assert_eq!( + ek, + K::EncapsulationKey::from_encoded_bytes(&ek_bytes).unwrap() + ); } mod acvp { diff --git a/ml-kem/tests/pkcs8.rs b/ml-kem/tests/pkcs8.rs index 6d21071..700a653 100644 --- a/ml-kem/tests/pkcs8.rs +++ b/ml-kem/tests/pkcs8.rs @@ -37,7 +37,7 @@ where // verify that original encapsulation key corresponds to deserialized encapsulation key let pub_key = parsed.decode_msg::().unwrap(); assert_eq!( - encaps_key.to_bytes().as_slice(), + encaps_key.to_encoded_bytes().as_slice(), pub_key.subject_public_key.as_bytes().unwrap() ); } diff --git a/x-wing/src/lib.rs b/x-wing/src/lib.rs index f8bf1c2..1afb9b0 100644 --- a/x-wing/src/lib.rs +++ b/x-wing/src/lib.rs @@ -25,12 +25,13 @@ //! assert_eq!(ss_sender, ss_receiver); //! ``` -pub use kem::{self, Decapsulate, Encapsulate, Generate}; +pub use kem::{self, Decapsulate, Encapsulate, Generate, KeyExport, KeySizeUser, TryKeyInit}; use core::convert::Infallible; use ml_kem::{ B32, EncodedSizeUser, Error, KemCore, MlKem768, MlKem768Params, array::{ArrayN, typenum::consts::U32}, + common::{InvalidKey, Key, KeyInit, array::sizes::U1216}, }; use rand_core::{CryptoRng, TryCryptoRng, TryRngCore}; use sha3::{ @@ -96,33 +97,42 @@ impl Encapsulate for EncapsulationKey { } } -impl EncapsulationKey { - /// Convert the key to the following format: - /// ML-KEM-768 public key(1184 bytes) || X25519 public key(32 bytes). - #[must_use] - pub fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_SIZE] { - let mut buffer = [0u8; ENCAPSULATION_KEY_SIZE]; - buffer[0..1184].copy_from_slice(&self.pk_m.to_bytes()); - buffer[1184..1216].copy_from_slice(self.pk_x.as_bytes()); - buffer - } +impl KeySizeUser for EncapsulationKey { + type KeySize = U1216; } -impl TryFrom<&[u8; ENCAPSULATION_KEY_SIZE]> for EncapsulationKey { - type Error = ml_kem::Error; +impl KeyExport for EncapsulationKey { + fn to_bytes(&self) -> Key { + let mut key_bytes = Key::::default(); + let (m, x) = key_bytes.split_at_mut(1184); + m.copy_from_slice(&self.pk_m.to_encoded_bytes()); + x.copy_from_slice(self.pk_x.as_bytes()); + key_bytes + } +} - fn try_from(value: &[u8; ENCAPSULATION_KEY_SIZE]) -> Result { +impl TryKeyInit for EncapsulationKey { + fn new(key_bytes: &Key) -> Result { let mut pk_m = [0; 1184]; - pk_m.copy_from_slice(&value[0..1184]); - let pk_m = MlKem768EncapsulationKey::from_bytes(&pk_m.into())?; + pk_m.copy_from_slice(&key_bytes[0..1184]); + let pk_m = + MlKem768EncapsulationKey::from_encoded_bytes(&pk_m.into()).map_err(|_| InvalidKey)?; let mut pk_x = [0; 32]; - pk_x.copy_from_slice(&value[1184..]); + pk_x.copy_from_slice(&key_bytes[1184..]); let pk_x = PublicKey::from(pk_x); Ok(EncapsulationKey { pk_m, pk_x }) } } +impl TryFrom<&[u8]> for EncapsulationKey { + type Error = InvalidKey; + + fn try_from(key_bytes: &[u8]) -> Result { + Self::new_from_slice(key_bytes) + } +} + /// X-Wing decapsulation key or private key. #[derive(Clone)] #[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))] @@ -153,11 +163,11 @@ impl Decapsulate for DecapsulationKey { } } -impl ::kem::KeySizeUser for DecapsulationKey { +impl KeySizeUser for DecapsulationKey { type KeySize = U32; } -impl ::kem::KeyInit for DecapsulationKey { +impl KeyInit for DecapsulationKey { fn new(key: &ArrayN) -> Self { Self { sk: key.0 } } @@ -368,7 +378,7 @@ mod tests { let (sk, pk) = generate_key_pair_from_rng(&mut seed); assert_eq!(sk.as_bytes(), &test_vector.sk); - assert_eq!(&pk.to_bytes(), test_vector.pk.as_slice()); + assert_eq!(&*pk.to_bytes(), test_vector.pk.as_slice()); let mut eseed = SeedRng::new(test_vector.eseed); let (ct, ss) = pk.encapsulate_with_rng(&mut eseed).unwrap(); @@ -404,7 +414,7 @@ mod tests { let pk_bytes = pk.to_bytes(); let sk_b = DecapsulationKey::from(*sk_bytes); - let pk_b = EncapsulationKey::try_from(&pk_bytes).unwrap(); + let pk_b = EncapsulationKey::new(&pk_bytes).unwrap(); assert!(sk == sk_b); assert!(pk == pk_b);