diff --git a/Cargo.toml b/Cargo.toml index 661cd8e9..6ad34e8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,9 @@ byteorder = "1.3.1" thiserror = "1.0.11" subtle = "2.0.0" simple_asn1 = "0.4" -pem = { version = "0.7", optional = true } +pem = { version = "0.8", optional = true } +digest = { version = "0.9.0", features = ["std"] } +sha2 = "0.9.0" [dependencies.zeroize] version = "1.1.0" @@ -36,13 +38,13 @@ default-features = false features = ["std", "derive"] [dev-dependencies] -base64 = "0.11.0" -sha-1 = "0.8.1" -sha2 = "0.8.0" +base64 = "0.12.0" hex = "0.4.0" serde_test = "1.0.89" rand_xorshift = "0.2.0" -pem = "0.7" +pem = "0.8" +sha-1 = "0.9.0" +sha3 = "0.9.0" [[bench]] name = "key" diff --git a/README.md b/README.md index c88eb99d..e3ab26bc 100644 --- a/README.md +++ b/README.md @@ -9,23 +9,21 @@ A portable RSA implementation in pure Rust. ## Example ```rust -extern crate rsa; -extern crate rand; - use rsa::{PublicKey, RSAPrivateKey, PaddingScheme}; use rand::rngs::OsRng; let mut rng = OsRng; let bits = 2048; -let key = RSAPrivateKey::new(&mut rng, bits).expect("failed to generate a key"); +let priv_key = RSAPrivateKey::new(&mut rng, bits).expect("failed to generate a key"); +let pub_key = RSAPublicKey::from(&private_key); // Encrypt let data = b"hello world"; -let enc_data = key.encrypt(&mut rng, PaddingScheme::PKCS1v15, &data[..]).expect("failed to encrypt"); +let enc_data = pub_key.encrypt(&mut rng, PaddingScheme::new_pkcs1v15(), &data[..]).expect("failed to encrypt"); assert_ne!(&data[..], &enc_data[..]); // Decrypt -let dec_data = key.decrypt(PaddingScheme::PKCS1v15, &enc_data).expect("failed to decrypt"); +let dec_data = priv_key.decrypt(PaddingScheme::new_pkcs1v15(), &enc_data).expect("failed to decrypt"); assert_eq!(&data[..], &dec_data[..]); ``` @@ -41,8 +39,8 @@ There will be three phases before `1.0` :ship: can be released. - [x] PKCS1v1.5: Encryption & Decryption :white_check_mark: - [x] PKCS1v1.5: Sign & Verify :white_check_mark: - [ ] PKCS1v1.5 (session key): Encryption & Decryption - - [ ] OAEP: Encryption & Decryption - - [ ] PSS: Sign & Verify + - [x] OAEP: Encryption & Decryption + - [x] PSS: Sign & Verify - [x] Key import & export 2. :rocket: Make it fast - [x] Benchmarks :white_check_mark: diff --git a/benches/key.rs b/benches/key.rs index 5afde12d..68271aa8 100644 --- a/benches/key.rs +++ b/benches/key.rs @@ -6,9 +6,7 @@ use base64; use num_bigint::BigUint; use num_traits::{FromPrimitive, Num}; use rand::{rngs::StdRng, SeedableRng}; -use rsa::hash::Hashes; -use rsa::padding::PaddingScheme; -use rsa::RSAPrivateKey; +use rsa::{Hash, PaddingScheme, RSAPrivateKey}; use sha2::{Digest, Sha256}; use test::Bencher; @@ -33,7 +31,9 @@ fn bench_rsa_2048_pkcsv1_decrypt(b: &mut Bencher) { let x = base64::decode(DECRYPT_VAL).unwrap(); b.iter(|| { - let res = priv_key.decrypt(PaddingScheme::PKCS1v15, &x).unwrap(); + let res = priv_key + .decrypt(PaddingScheme::new_pkcs1v15_encrypt(), &x) + .unwrap(); test::black_box(res); }); } @@ -48,8 +48,7 @@ fn bench_rsa_2048_pkcsv1_sign_blinded(b: &mut Bencher) { let res = priv_key .sign_blinded( &mut rng, - PaddingScheme::PKCS1v15, - Some(&Hashes::SHA2_256), + PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA2_256)), &digest, ) .unwrap(); diff --git a/src/algorithms.rs b/src/algorithms.rs index a02e884b..02b4d846 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -1,3 +1,4 @@ +use digest::DynDigest; use num_bigint::traits::ModInverse; use num_bigint::{BigUint, RandPrime}; use num_traits::{FromPrimitive, One, Zero}; @@ -110,3 +111,44 @@ pub fn generate_multi_prime_key( primes, )) } + +/// Mask generation function. +/// +/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 +pub fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) { + let mut counter = [0u8; 4]; + let mut i = 0; + + const MAX_LEN: u64 = std::u32::MAX as u64 + 1; + assert!(out.len() as u64 <= MAX_LEN); + + while i < out.len() { + let mut digest_input = vec![0u8; seed.len() + 4]; + digest_input[0..seed.len()].copy_from_slice(seed); + digest_input[seed.len()..].copy_from_slice(&counter); + + digest.update(digest_input.as_slice()); + let digest_output = &*digest.finalize_reset(); + let mut j = 0; + loop { + if j >= digest_output.len() || i >= out.len() { + break; + } + + out[i] ^= digest_output[j]; + j += 1; + i += 1; + } + inc_counter(&mut counter); + } +} + +fn inc_counter(counter: &mut [u8; 4]) { + for i in (0..4).rev() { + counter[i] = counter[i].wrapping_add(1); + if counter[i] != 0 { + // No overflow + return; + } + } +} diff --git a/src/errors.rs b/src/errors.rs index 949e39bb..18d66a2c 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -35,4 +35,6 @@ pub enum Error { ParseError { reason: String }, #[error("internal error")] Internal, + #[error("label too long")] + LabelTooLong, } diff --git a/src/hash.rs b/src/hash.rs index 6af2ae5b..aa753187 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -1,16 +1,6 @@ -/// A generic trait that exposes the information that is needed for a hash function to be -/// used in `sign` and `verify.`. -pub trait Hash { - /// Returns the length in bytes of a digest. - fn size(&self) -> usize; - - /// Returns the ASN1 DER prefix for the the hash function. - fn asn1_prefix(&self) -> Vec; -} - /// A list of provided hashes, implementing `Hash`. #[derive(Debug, Clone, Copy)] -pub enum Hashes { +pub enum Hash { MD5, SHA1, SHA2_224, @@ -24,67 +14,69 @@ pub enum Hashes { RIPEMD160, } -impl Hash for Hashes { - fn size(&self) -> usize { +impl Hash { + /// Returns the length in bytes of a digest. + pub fn size(&self) -> usize { match *self { - Hashes::MD5 => 16, - Hashes::SHA1 => 20, - Hashes::SHA2_224 => 28, - Hashes::SHA2_256 => 32, - Hashes::SHA2_384 => 48, - Hashes::SHA2_512 => 64, - Hashes::SHA3_256 => 32, - Hashes::SHA3_384 => 48, - Hashes::SHA3_512 => 64, - Hashes::MD5SHA1 => 36, - Hashes::RIPEMD160 => 20, + Hash::MD5 => 16, + Hash::SHA1 => 20, + Hash::SHA2_224 => 28, + Hash::SHA2_256 => 32, + Hash::SHA2_384 => 48, + Hash::SHA2_512 => 64, + Hash::SHA3_256 => 32, + Hash::SHA3_384 => 48, + Hash::SHA3_512 => 64, + Hash::MD5SHA1 => 36, + Hash::RIPEMD160 => 20, } } - fn asn1_prefix(&self) -> Vec { + /// Returns the ASN1 DER prefix for the the hash function. + pub fn asn1_prefix(&self) -> &'static [u8] { match *self { - Hashes::MD5 => vec![ + Hash::MD5 => &[ 0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x05, 0x05, 0x00, 0x04, 0x10, ], - Hashes::SHA1 => vec![ + Hash::SHA1 => &[ 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, ], - Hashes::SHA2_224 => vec![ + Hash::SHA2_224 => &[ 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, 0x00, 0x04, 0x1c, ], - Hashes::SHA2_256 => vec![ + Hash::SHA2_256 => &[ 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20, ], - Hashes::SHA2_384 => vec![ + Hash::SHA2_384 => &[ 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30, ], - Hashes::SHA2_512 => vec![ + Hash::SHA2_512 => &[ 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40, ], // A special TLS case which doesn't use an ASN1 prefix - Hashes::MD5SHA1 => Vec::new(), - Hashes::RIPEMD160 => vec![ + Hash::MD5SHA1 => &[], + Hash::RIPEMD160 => &[ 0x30, 0x20, 0x30, 0x08, 0x06, 0x06, 0x28, 0xcf, 0x06, 0x03, 0x00, 0x31, 0x04, 0x14, ], - Hashes::SHA3_256 => vec![ + Hash::SHA3_256 => &[ 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x08, 0x05, 0x00, 0x04, 0x20, ], - Hashes::SHA3_384 => vec![ + Hash::SHA3_384 => &[ 30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x08, 0x05, 0x00, 0x04, 0x20, ], - Hashes::SHA3_512 => vec![ + Hash::SHA3_512 => &[ 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x0a, 0x05, 0x00, 0x04, 0x40, ], diff --git a/src/key.rs b/src/key.rs index b6f610a1..91a9da4e 100644 --- a/src/key.rs +++ b/src/key.rs @@ -5,14 +5,15 @@ use num_traits::{FromPrimitive, One}; use rand::{rngs::ThreadRng, Rng}; #[cfg(feature = "serde")] use serde_crate::{Deserialize, Serialize}; +use std::ops::Deref; use zeroize::Zeroize; use crate::algorithms::generate_multi_prime_key; use crate::errors::{Error, Result}; -use crate::hash::Hash; + use crate::padding::PaddingScheme; -use crate::pkcs1v15; use crate::raw::{DecryptionPrimitive, EncryptionPrimitive}; +use crate::{oaep, pkcs1v15, pss}; lazy_static! { static ref MIN_PUB_EXPONENT: BigUint = BigUint::from_u64(2).unwrap(); @@ -22,8 +23,10 @@ lazy_static! { pub trait PublicKeyParts { /// Returns the modulus of the key. fn n(&self) -> &BigUint; + /// Returns the public exponent of the key. fn e(&self) -> &BigUint; + /// Returns the modulus size in bytes. Raw signatures and ciphertexts for /// or by this public key will have the same size. fn size(&self) -> usize { @@ -53,10 +56,8 @@ pub struct RSAPublicKey { serde(crate = "serde_crate") )] pub struct RSAPrivateKey { - /// Modulus - n: BigUint, - /// Public exponent - e: BigUint, + /// Public components of the private key. + pubkey_components: RSAPublicKey, /// Private exponent d: BigUint, /// Prime factors of N, contains >= 2 elements. @@ -69,7 +70,9 @@ pub struct RSAPrivateKey { impl PartialEq for RSAPrivateKey { #[inline] fn eq(&self, other: &RSAPrivateKey) -> bool { - self.n == other.n && self.e == other.e && self.d == other.d && self.primes == other.primes + self.pubkey_components == other.pubkey_components + && self.d == other.d + && self.primes == other.primes } } @@ -94,6 +97,13 @@ impl Drop for RSAPrivateKey { } } +impl Deref for RSAPrivateKey { + type Target = RSAPublicKey; + fn deref(&self) -> &RSAPublicKey { + &self.pubkey_components + } +} + #[derive(Debug, Clone)] pub(crate) struct PrecomputedValues { /// D mod (P-1) @@ -163,13 +173,7 @@ pub trait PublicKey: EncryptionPrimitive + PublicKeyParts { /// `hashed`must be the result of hashing the input using the hashing function /// passed in through `hash`. /// If the message is valid `Ok(())` is returned, otherwiese an `Err` indicating failure. - fn verify( - &self, - padding: PaddingScheme, - hash: Option<&H>, - hashed: &[u8], - sig: &[u8], - ) -> Result<()>; + fn verify(&self, padding: PaddingScheme, hashed: &[u8], sig: &[u8]) -> Result<()>; } impl PublicKeyParts for RSAPublicKey { @@ -185,22 +189,20 @@ impl PublicKeyParts for RSAPublicKey { impl PublicKey for RSAPublicKey { fn encrypt(&self, rng: &mut R, padding: PaddingScheme, msg: &[u8]) -> Result> { match padding { - PaddingScheme::PKCS1v15 => pkcs1v15::encrypt(rng, self, msg), - PaddingScheme::OAEP => unimplemented!("not yet implemented"), + PaddingScheme::PKCS1v15Encrypt => pkcs1v15::encrypt(rng, self, msg), + PaddingScheme::OAEP { mut digest, label } => { + oaep::encrypt(rng, self, msg, &mut *digest, label) + } _ => Err(Error::InvalidPaddingScheme), } } - fn verify( - &self, - padding: PaddingScheme, - hash: Option<&H>, - hashed: &[u8], - sig: &[u8], - ) -> Result<()> { + fn verify(&self, padding: PaddingScheme, hashed: &[u8], sig: &[u8]) -> Result<()> { match padding { - PaddingScheme::PKCS1v15 => pkcs1v15::verify(self, hash, hashed, sig), - PaddingScheme::PSS => unimplemented!("not yet implemented"), + PaddingScheme::PKCS1v15Sign { ref hash } => { + pkcs1v15::verify(self, hash.as_ref(), hashed, sig) + } + PaddingScheme::PSS { mut digest, .. } => pss::verify(self, hashed, sig, &mut *digest), _ => Err(Error::InvalidPaddingScheme), } } @@ -285,10 +287,12 @@ impl RSAPublicKey { } impl<'a> PublicKeyParts for &'a RSAPublicKey { + /// Returns the modulus of the key. fn n(&self) -> &BigUint { &self.n } + /// Returns the public exponent of the key. fn e(&self) -> &BigUint { &self.e } @@ -299,14 +303,8 @@ impl<'a> PublicKey for &'a RSAPublicKey { (*self).encrypt(rng, padding, msg) } - fn verify( - &self, - padding: PaddingScheme, - hash: Option<&H>, - hashed: &[u8], - sig: &[u8], - ) -> Result<()> { - (*self).verify(padding, hash, hashed, sig) + fn verify(&self, padding: PaddingScheme, hashed: &[u8], sig: &[u8]) -> Result<()> { + (*self).verify(padding, hashed, sig) } } @@ -348,8 +346,7 @@ impl RSAPrivateKey { primes: Vec, ) -> RSAPrivateKey { let mut k = RSAPrivateKey { - n, - e, + pubkey_components: RSAPublicKey { n, e }, d, primes, precomputed: None, @@ -544,13 +541,18 @@ impl RSAPrivateKey { pub fn decrypt(&self, padding: PaddingScheme, ciphertext: &[u8]) -> Result> { match padding { // need to pass any Rng as the type arg, so the type checker is happy, it is not actually used for anything - PaddingScheme::PKCS1v15 => pkcs1v15::decrypt::(None, self, ciphertext), - PaddingScheme::OAEP => unimplemented!("not yet implemented"), + PaddingScheme::PKCS1v15Encrypt => { + pkcs1v15::decrypt::(None, self, ciphertext) + } + PaddingScheme::OAEP { mut digest, label } => { + oaep::decrypt::(None, self, ciphertext, &mut *digest, label) + } _ => Err(Error::InvalidPaddingScheme), } } /// Decrypt the given message. + /// /// Uses `rng` to blind the decryption process. pub fn decrypt_blinded( &self, @@ -559,38 +561,61 @@ impl RSAPrivateKey { ciphertext: &[u8], ) -> Result> { match padding { - PaddingScheme::PKCS1v15 => pkcs1v15::decrypt(Some(rng), self, ciphertext), - PaddingScheme::OAEP => unimplemented!("not yet implemented"), + PaddingScheme::PKCS1v15Encrypt => pkcs1v15::decrypt(Some(rng), self, ciphertext), + PaddingScheme::OAEP { mut digest, label } => { + oaep::decrypt(Some(rng), self, ciphertext, &mut *digest, label) + } _ => Err(Error::InvalidPaddingScheme), } } /// Sign the given digest. - pub fn sign( - &self, - padding: PaddingScheme, - hash: Option<&H>, - digest: &[u8], - ) -> Result> { + pub fn sign(&self, padding: PaddingScheme, digest_in: &[u8]) -> Result> { match padding { - PaddingScheme::PKCS1v15 => pkcs1v15::sign::(None, self, hash, digest), - PaddingScheme::PSS => unimplemented!("not yet implemented"), + PaddingScheme::PKCS1v15Sign { ref hash } => { + pkcs1v15::sign::(None, self, hash.as_ref(), digest_in) + } + PaddingScheme::PSS { + mut salt_rng, + mut digest, + salt_len, + } => pss::sign::<_, ThreadRng, _>( + &mut *salt_rng, + None, + self, + digest_in, + salt_len, + &mut *digest, + ), _ => Err(Error::InvalidPaddingScheme), } } /// Sign the given digest. + /// /// Use `rng` for blinding. - pub fn sign_blinded( + pub fn sign_blinded( &self, rng: &mut R, padding: PaddingScheme, - hash: Option<&H>, - digest: &[u8], + digest_in: &[u8], ) -> Result> { match padding { - PaddingScheme::PKCS1v15 => pkcs1v15::sign(Some(rng), self, hash, digest), - PaddingScheme::PSS => unimplemented!("not yet implemented"), + PaddingScheme::PKCS1v15Sign { ref hash } => { + pkcs1v15::sign(Some(rng), self, hash.as_ref(), digest_in) + } + PaddingScheme::PSS { + mut salt_rng, + mut digest, + salt_len, + } => pss::sign::<_, R, _>( + &mut *salt_rng, + Some(rng), + self, + digest_in, + salt_len, + &mut *digest, + ), _ => Err(Error::InvalidPaddingScheme), } } @@ -614,14 +639,21 @@ pub fn check_public(public_key: &impl PublicKeyParts) -> Result<()> { mod tests { use super::*; use crate::internals; + + use digest::{Digest, DynDigest}; use num_traits::{FromPrimitive, ToPrimitive}; - use rand::{rngs::ThreadRng, thread_rng}; + use rand::{distributions::Alphanumeric, rngs::ThreadRng, thread_rng}; + use sha1::Sha1; + use sha2::{Sha224, Sha256, Sha384, Sha512}; + use sha3::{Sha3_256, Sha3_384, Sha3_512}; #[test] fn test_from_into() { let private_key = RSAPrivateKey { - n: BigUint::from_u64(100).unwrap(), - e: BigUint::from_u64(200).unwrap(), + pubkey_components: RSAPublicKey { + n: BigUint::from_u64(100).unwrap(), + e: BigUint::from_u64(200).unwrap(), + }, d: BigUint::from_u64(123).unwrap(), primes: vec![], precomputed: None, @@ -729,7 +761,12 @@ mod tests { let priv_tokens = [ Token::Struct { name: "RSAPrivateKey", - len: 4, + len: 3, + }, + Token::Str("pubkey_components"), + Token::Struct { + name: "RSAPublicKey", + len: 2, }, Token::Str("n"), Token::Seq { len: Some(2) }, @@ -740,6 +777,7 @@ mod tests { Token::Seq { len: Some(1) }, Token::U32(65537), Token::SeqEnd, + Token::StructEnd, Token::Str("d"), Token::Seq { len: Some(2) }, Token::U32(298985985), @@ -797,4 +835,127 @@ mod tests { primes.iter().map(|p| BigUint::from_bytes_be(p)).collect(), ); } + + fn get_private_key() -> RSAPrivateKey { + // -----BEGIN RSA PRIVATE KEY----- + // MIIEpAIBAAKCAQEA05e4TZikwmE47RtpWoEG6tkdVTvwYEG2LT/cUKBB4iK49FKW + // icG4LF5xVU9d1p+i9LYVjPDb61eBGg/DJ+HyjnT+dNO8Fmweq9wbi1e5NMqL5bAL + // TymXW8yZrK9BW1m7KKZ4K7QaLDwpdrPBjbre9i8AxrsiZkAJUJbAzGDSL+fvmH11 + // xqgbENlr8pICivEQ3HzBu8Q9Iq2rN5oM1dgHjMeA/1zWIJ3qNMkiz3hPdxfkKNdb + // WuyP8w5fAUFRB2bi4KuNRzyE6HELK5gifD2wlTN600UvGeK5v7zN2BSKv2d2+lUn + // debnWVbkUimuWpxGlJurHmIvDkj1ZSSoTtNIOwIDAQABAoIBAQDE5wxokWLJTGYI + // KBkbUrTYOSEV30hqmtvoMeRY1zlYMg3Bt1VFbpNwHpcC12+wuS+Q4B0f4kgVMoH+ + // eaqXY6kvrmnY1+zRRN4p+hNb0U+Vc+NJ5FAx47dpgvWDADgmxVLomjl8Gga9IWNI + // hjDZLowrtkPXq+9wDaldaFyUFImkb1S1MW9itdLDp/G70TTLNzU6RGg/3J2V02RY + // 3iL2xEBX/nSgpDbEMI9z9NpC81xHrBanE41IOvyR5B3DoRJzguDA9RGbAiG0/GOd + // a5w4F3pt6bUm69iMONeYLAf5ig79h31Qiq4nW5RpFcAuLhEG0XXXTsZ3f16A0SwF + // PZx74eNBAoGBAPgnu/OkGHfHzFmuv0LtSynDLe/LjtloY9WwkKBaiTDdYkohydz5 + // g4Vo/foN9luEYqXyrJE9bFb5dVMr2OePsHvUBcqZpIS89Z8Bm73cs5M/K85wYwC0 + // 97EQEgxd+QGBWQZ8NdowYaVshjWlK1QnOzEnG0MR8Hld9gIeY1XhpC5hAoGBANpI + // F84Aid028q3mo/9BDHPsNL8bT2vaOEMb/t4RzvH39u+nDl+AY6Ox9uFylv+xX+76 + // CRKgMluNH9ZaVZ5xe1uWHsNFBy4OxSA9A0QdKa9NZAVKBFB0EM8dp457YRnZCexm + // 5q1iW/mVsnmks8W+fYlc18W5xMSX/ecwkW/NtOQbAoGAHabpz4AhKFbodSLrWbzv + // CUt4NroVFKdjnoodjfujfwJFF2SYMV5jN9LG3lVCxca43ulzc1tqka33Nfv8TBcg + // WHuKQZ5ASVgm5VwU1wgDMSoQOve07MWy/yZTccTc1zA0ihDXgn3bfR/NnaVh2wlh + // CkuI92eyW1494hztc7qlmqECgYEA1zenyOQ9ChDIW/ABGIahaZamNxsNRrDFMl3j + // AD+cxHSRU59qC32CQH8ShRy/huHzTaPX2DZ9EEln76fnrS4Ey7uLH0rrFl1XvT6K + // /timJgLvMEvXTx/xBtUdRN2fUqXtI9odbSyCtOYFL+zVl44HJq2UzY4pVRDrNcxs + // SUkQJqsCgYBSaNfPBzR5rrstLtTdZrjImRW1LRQeDEky9WsMDtCTYUGJTsTSfVO8 + // hkU82MpbRVBFIYx+GWIJwcZRcC7OCQoV48vMJllxMAAjqG/p00rVJ+nvA7et/nNu + // BoB0er/UmDm4Ly/97EO9A0PKMOE5YbMq9s3t3RlWcsdrU7dvw+p2+A== + // -----END RSA PRIVATE KEY----- + + RSAPrivateKey::from_components( + BigUint::parse_bytes(b"00d397b84d98a4c26138ed1b695a8106ead91d553bf06041b62d3fdc50a041e222b8f4529689c1b82c5e71554f5dd69fa2f4b6158cf0dbeb57811a0fc327e1f28e74fe74d3bc166c1eabdc1b8b57b934ca8be5b00b4f29975bcc99acaf415b59bb28a6782bb41a2c3c2976b3c18dbadef62f00c6bb226640095096c0cc60d22fe7ef987d75c6a81b10d96bf292028af110dc7cc1bbc43d22adab379a0cd5d8078cc780ff5cd6209dea34c922cf784f7717e428d75b5aec8ff30e5f0141510766e2e0ab8d473c84e8710b2b98227c3db095337ad3452f19e2b9bfbccdd8148abf6776fa552775e6e75956e45229ae5a9c46949bab1e622f0e48f56524a84ed3483b", 16).unwrap(), + BigUint::from_u64(65537).unwrap(), + BigUint::parse_bytes(b"00c4e70c689162c94c660828191b52b4d8392115df486a9adbe831e458d73958320dc1b755456e93701e9702d76fb0b92f90e01d1fe248153281fe79aa9763a92fae69d8d7ecd144de29fa135bd14f9573e349e45031e3b76982f583003826c552e89a397c1a06bd2163488630d92e8c2bb643d7abef700da95d685c941489a46f54b5316f62b5d2c3a7f1bbd134cb37353a44683fdc9d95d36458de22f6c44057fe74a0a436c4308f73f4da42f35c47ac16a7138d483afc91e41dc3a1127382e0c0f5119b0221b4fc639d6b9c38177a6de9b526ebd88c38d7982c07f98a0efd877d508aae275b946915c02e2e1106d175d74ec6777f5e80d12c053d9c7be1e341", 16).unwrap(), + vec![ + BigUint::parse_bytes(b"00f827bbf3a41877c7cc59aebf42ed4b29c32defcb8ed96863d5b090a05a8930dd624a21c9dcf9838568fdfa0df65b8462a5f2ac913d6c56f975532bd8e78fb07bd405ca99a484bcf59f019bbddcb3933f2bce706300b4f7b110120c5df9018159067c35da3061a56c8635a52b54273b31271b4311f0795df6021e6355e1a42e61",16).unwrap(), + BigUint::parse_bytes(b"00da4817ce0089dd36f2ade6a3ff410c73ec34bf1b4f6bda38431bfede11cef1f7f6efa70e5f8063a3b1f6e17296ffb15feefa0912a0325b8d1fd65a559e717b5b961ec345072e0ec5203d03441d29af4d64054a04507410cf1da78e7b6119d909ec66e6ad625bf995b279a4b3c5be7d895cd7c5b9c4c497fde730916fcdb4e41b", 16).unwrap() + ], + ) + } + + #[test] + fn test_encrypt_decrypt_oaep() { + let priv_key = get_private_key(); + do_test_encrypt_decrypt_oaep::(&priv_key); + do_test_encrypt_decrypt_oaep::(&priv_key); + do_test_encrypt_decrypt_oaep::(&priv_key); + do_test_encrypt_decrypt_oaep::(&priv_key); + do_test_encrypt_decrypt_oaep::(&priv_key); + do_test_encrypt_decrypt_oaep::(&priv_key); + do_test_encrypt_decrypt_oaep::(&priv_key); + do_test_encrypt_decrypt_oaep::(&priv_key); + } + + fn do_test_encrypt_decrypt_oaep(prk: &RSAPrivateKey) { + let mut rng = thread_rng(); + + let k = prk.size(); + + for i in 1..8 { + let mut input: Vec = (0..i * 8).map(|_| rng.gen()).collect(); + if input.len() > k - 11 { + input = input[0..k - 11].to_vec(); + } + let has_label: bool = rng.gen(); + let label: Option = if has_label { + Some(rng.sample_iter(&Alphanumeric).take(30).collect()) + } else { + None + }; + + let pub_key: RSAPublicKey = prk.into(); + + let ciphertext = if let Some(ref label) = label { + let padding = PaddingScheme::new_oaep_with_label::(label); + pub_key.encrypt(&mut rng, padding, &input).unwrap() + } else { + let padding = PaddingScheme::new_oaep::(); + pub_key.encrypt(&mut rng, padding, &input).unwrap() + }; + + assert_ne!(input, ciphertext); + let blind: bool = rng.gen(); + + let padding = if let Some(ref label) = label { + PaddingScheme::new_oaep_with_label::(label) + } else { + PaddingScheme::new_oaep::() + }; + + let plaintext = if blind { + prk.decrypt(padding, &ciphertext).unwrap() + } else { + prk.decrypt_blinded(&mut rng, padding, &ciphertext).unwrap() + }; + + assert_eq!(input, plaintext); + } + } + + #[test] + fn test_decrypt_oaep_invalid_hash() { + let mut rng = thread_rng(); + let priv_key = get_private_key(); + let pub_key: RSAPublicKey = (&priv_key).into(); + let ciphertext = pub_key + .encrypt( + &mut rng, + PaddingScheme::new_oaep::(), + "a_plain_text".as_bytes(), + ) + .unwrap(); + assert!( + priv_key + .decrypt_blinded( + &mut rng, + PaddingScheme::new_oaep_with_label::("label"), + &ciphertext, + ) + .is_err(), + "decrypt should have failed on hash verification" + ); + } } diff --git a/src/lib.rs b/src/lib.rs index 395fb0a5..dc3c8484 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,8 @@ //! //! # Usage //! +//! Using PKCS1v15. //! ``` -//! extern crate rsa; -//! extern crate rand; -//! //! use rsa::{PublicKey, RSAPrivateKey, RSAPublicKey, PaddingScheme}; //! use rand::rngs::OsRng; //! @@ -18,23 +16,41 @@ //! //! // Encrypt //! let data = b"hello world"; -//! let enc_data = public_key.encrypt(&mut rng, PaddingScheme::PKCS1v15, &data[..]).expect("failed to encrypt"); +//! let padding = PaddingScheme::new_pkcs1v15_encrypt(); +//! let enc_data = public_key.encrypt(&mut rng, padding, &data[..]).expect("failed to encrypt"); //! assert_ne!(&data[..], &enc_data[..]); //! //! // Decrypt -//! let dec_data = private_key.decrypt(PaddingScheme::PKCS1v15, &enc_data).expect("failed to decrypt"); +//! let padding = PaddingScheme::new_pkcs1v15_encrypt(); +//! let dec_data = private_key.decrypt(padding, &enc_data).expect("failed to decrypt"); //! assert_eq!(&data[..], &dec_data[..]); //! ``` //! +//! Using OAEP. +//! ``` +//! use rsa::{PublicKey, RSAPrivateKey, RSAPublicKey, PaddingScheme}; +//! use rand::rngs::OsRng; +//! +//! let mut rng = OsRng; +//! let bits = 2048; +//! let private_key = RSAPrivateKey::new(&mut rng, bits).expect("failed to generate a key"); +//! let public_key = RSAPublicKey::from(&private_key); +//! +//! // Encrypt +//! let data = b"hello world"; +//! let padding = PaddingScheme::new_oaep::(); +//! let enc_data = public_key.encrypt(&mut rng, padding, &data[..]).expect("failed to encrypt"); +//! assert_ne!(&data[..], &enc_data[..]); +//! +//! // Decrypt +//! let padding = PaddingScheme::new_oaep::(); +//! let dec_data = private_key.decrypt(padding, &enc_data).expect("failed to decrypt"); +//! assert_eq!(&data[..], &dec_data[..]); +//! ``` #[macro_use] extern crate lazy_static; -extern crate num_iter; -extern crate rand; -extern crate subtle; -extern crate zeroize; - #[cfg(feature = "serde")] extern crate serde_crate; @@ -63,11 +79,14 @@ pub mod padding; pub use pem; mod key; +mod oaep; mod parse; mod pkcs1v15; +mod pss; mod raw; -pub use self::key::{PublicKey, RSAPrivateKey, RSAPublicKey}; +pub use self::hash::Hash; +pub use self::key::{PublicKey, PublicKeyParts, RSAPrivateKey, RSAPublicKey}; pub use self::padding::PaddingScheme; // Optionally expose internals if requested via feature-flag. diff --git a/src/oaep.rs b/src/oaep.rs new file mode 100644 index 00000000..00fe6392 --- /dev/null +++ b/src/oaep.rs @@ -0,0 +1,152 @@ +use rand::Rng; + +use digest::DynDigest; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; + +use crate::algorithms::mgf1_xor; +use crate::errors::{Error, Result}; +use crate::key::{self, PrivateKey, PublicKey}; + +// 2**61 -1 (pow is not const yet) +// TODO: This is the maximum for SHA-1, unclear from the RFC what the values are for other hashing functions. +const MAX_LABEL_LEN: u64 = 2_305_843_009_213_693_951; + +/// Encrypts the given message with RSA and the padding +/// scheme from PKCS#1 OAEP. The message must be no longer than the +/// length of the public modulus minus (2+ 2*hash.size()). +#[inline] +pub fn encrypt( + rng: &mut R, + pub_key: &K, + msg: &[u8], + digest: &mut dyn DynDigest, + label: Option, +) -> Result> { + key::check_public(pub_key)?; + + let k = pub_key.size(); + + let h_size = digest.output_size(); + + if msg.len() + 2 * h_size + 2 > k { + return Err(Error::MessageTooLong); + } + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + let mut em = vec![0u8; k]; + + let (_, payload) = em.split_at_mut(1); + let (seed, db) = payload.split_at_mut(h_size); + rng.fill(seed); + + // Data block DB = pHash || PS || 01 || M + let db_len = k - h_size - 1; + + digest.update(label.as_bytes()); + let p_hash = digest.finalize_reset(); + db[0..h_size].copy_from_slice(&*p_hash); + db[db_len - msg.len() - 1] = 1; + db[db_len - msg.len()..].copy_from_slice(msg); + + mgf1_xor(db, digest, seed); + mgf1_xor(seed, digest, db); + + pub_key.raw_encryption_primitive(&em, pub_key.size()) +} + +/// Decrypts a plaintext using RSA and the padding scheme from pkcs1# OAEP +/// If an `rng` is passed, it uses RSA blinding to avoid timing side-channel attacks. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. See +/// `decrypt_session_key` for a way of solving this problem. +#[inline] +pub fn decrypt( + rng: Option<&mut R>, + priv_key: &SK, + ciphertext: &[u8], + digest: &mut dyn DynDigest, + label: Option, +) -> Result> { + key::check_public(priv_key)?; + + let res = decrypt_inner(rng, priv_key, ciphertext, digest, label)?; + if res.is_none().into() { + return Err(Error::Decryption); + } + + let (out, index) = res.unwrap(); + + Ok(out[index as usize..].to_vec()) +} + +/// Decrypts ciphertext using `priv_key` and blinds the operation if +/// `rng` is given. It returns one or zero in valid that indicates whether the +/// plaintext was correctly structured. +#[inline] +fn decrypt_inner( + rng: Option<&mut R>, + priv_key: &SK, + ciphertext: &[u8], + digest: &mut dyn DynDigest, + label: Option, +) -> Result, u32)>> { + let k = priv_key.size(); + if k < 11 { + return Err(Error::Decryption); + } + + let h_size = digest.output_size(); + + if ciphertext.len() != k || k < h_size * 2 + 2 { + return Err(Error::Decryption); + } + + let mut em = priv_key.raw_decryption_primitive(rng, ciphertext, priv_key.size())?; + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + digest.update(label.as_bytes()); + + let expected_p_hash = &*digest.finalize_reset(); + + let first_byte_is_zero = em[0].ct_eq(&0u8); + + let (_, payload) = em.split_at_mut(1); + let (seed, db) = payload.split_at_mut(h_size); + + mgf1_xor(seed, digest, db); + mgf1_xor(db, digest, seed); + + let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash); + + // The remainder of the plaintext must be zero or more 0x00, followed + // by 0x01, followed by the message. + // looking_for_index: 1 if we are still looking for the 0x01 + // index: the offset of the first 0x01 byte + // zero_before_one: 1 if we saw a non-zero byte before the 1 + let mut looking_for_index = Choice::from(1u8); + let mut index = 0u32; + let mut nonzero_before_one = Choice::from(0u8); + + for (i, el) in db.iter().skip(h_size).enumerate() { + let equals0 = el.ct_eq(&0u8); + let equals1 = el.ct_eq(&1u8); + index.conditional_assign(&(i as u32), looking_for_index & equals1); + looking_for_index &= !equals1; + nonzero_before_one |= looking_for_index & !equals0; + } + + let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index; + + Ok(CtOption::new((em, index + 2 + (h_size * 2) as u32), valid)) +} diff --git a/src/padding.rs b/src/padding.rs index 4525ecb2..6b5e2ccf 100644 --- a/src/padding.rs +++ b/src/padding.rs @@ -1,7 +1,87 @@ +use std::fmt; + +use digest::{Digest, DynDigest}; +use rand::RngCore; + +use crate::hash::Hash; + /// Available padding schemes. -#[derive(Debug, Clone, Copy)] pub enum PaddingScheme { - PKCS1v15, - OAEP, - PSS, + /// Encryption and Decryption using PKCS1v15 padding. + PKCS1v15Encrypt, + /// Sign and Verify using PKCS1v15 padding. + PKCS1v15Sign { hash: Option }, + /// Encryption and Decryption using OAEP padding. + OAEP { + digest: Box, + label: Option, + }, + /// Sign and Verify using PSS padding. + PSS { + salt_rng: Box, + digest: Box, + salt_len: Option, + }, +} + +impl fmt::Debug for PaddingScheme { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PaddingScheme::PKCS1v15Encrypt => write!(f, "PaddingScheme::PKCS1v15Encrypt"), + PaddingScheme::PKCS1v15Sign { ref hash } => { + write!(f, "PaddingScheme::PKCS1v15Sign({:?})", hash) + } + PaddingScheme::OAEP { ref label, .. } => { + // TODO: How to print the digest name? + write!(f, "PaddingScheme::OAEP({:?})", label) + } + PaddingScheme::PSS { ref salt_len, .. } => { + // TODO: How to print the digest name? + write!(f, "PaddingScheme::PSS(salt_len: {:?})", salt_len) + } + } + } +} + +impl PaddingScheme { + pub fn new_pkcs1v15_encrypt() -> Self { + PaddingScheme::PKCS1v15Encrypt + } + + pub fn new_pkcs1v15_sign(hash: Option) -> Self { + PaddingScheme::PKCS1v15Sign { hash } + } + + pub fn new_oaep() -> Self { + PaddingScheme::OAEP { + digest: Box::new(T::new()), + label: None, + } + } + + pub fn new_oaep_with_label>(label: S) -> Self { + PaddingScheme::OAEP { + digest: Box::new(T::new()), + label: Some(label.as_ref().to_string()), + } + } + + pub fn new_pss(rng: S) -> Self { + PaddingScheme::PSS { + salt_rng: Box::new(rng), + digest: Box::new(T::new()), + salt_len: None, + } + } + + pub fn new_pss_with_salt( + rng: S, + len: usize, + ) -> Self { + PaddingScheme::PSS { + salt_rng: Box::new(rng), + digest: Box::new(T::new()), + salt_len: Some(len), + } + } } diff --git a/src/parse.rs b/src/parse.rs index 17f9ffe5..e7db888e 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -23,13 +23,13 @@ impl TryFrom for RSAPrivateKey { /// Expects one of the following `pem` headers: /// - `-----BEGIN PRIVATE KEY-----` /// - `-----BEGIN RSA PRIVATE KEY-----` - /// + /// /// # Example - /// + /// /// ``` /// use std::convert::TryFrom; /// use rsa::RSAPrivateKey; - /// + /// /// # // openssl genrsa -out tiny_key.pem 512 /// let file_content = r#" /// -----BEGIN RSA PRIVATE KEY----- @@ -42,7 +42,7 @@ impl TryFrom for RSAPrivateKey { /// JdwDGF7Kanex70KAacmOlw3vfx6XWT+2PH6Qh8tLug== /// -----END RSA PRIVATE KEY----- /// "#; - /// + /// /// let pem = rsa::pem::parse(file_content).expect("failed to parse pem file"); /// let private_key = RSAPrivateKey::try_from(pem).expect("failed to parse key"); /// ``` @@ -66,13 +66,13 @@ impl TryFrom for RSAPublicKey { /// Expects one of the following `pem` headers: /// - `-----BEGIN PUBLIC KEY-----` /// - `-----BEGIN RSA PUBLIC KEY-----` - /// + /// /// # Example - /// + /// /// ``` /// use std::convert::TryFrom; /// use rsa::RSAPublicKey; - /// + /// /// # // openssl rsa -in tiny_key.pem -outform PEM -pubout -out tiny_key.pub.pem /// let file_content = r#" /// -----BEGIN PUBLIC KEY----- @@ -80,7 +80,7 @@ impl TryFrom for RSAPublicKey { /// si2oPAUmNw2Z/qb2Sr/BEBoWpagFf8Gl1K4PRipJSudDl6N/Vdb2CYkCAwEAAQ== /// -----END PUBLIC KEY----- /// "#; - /// + /// /// let pem = rsa::pem::parse(file_content).expect("failed to parse pem file"); /// let public_key = RSAPublicKey::try_from(pem).expect("failed to parse key"); /// ``` @@ -404,11 +404,15 @@ VY8J0wvbOtL9NjCMy6zz1zQ+N7oJ9mdhNwIDAQAB let clear_text = "Hello, World!"; let encrypted = public_key - .encrypt(rng, PaddingScheme::PKCS1v15, clear_text.as_bytes()) + .encrypt( + rng, + PaddingScheme::new_pkcs1v15_encrypt(), + clear_text.as_bytes(), + ) .expect("encrypt failed"); let decrypted = private_key - .decrypt(PaddingScheme::PKCS1v15, &encrypted) + .decrypt(PaddingScheme::new_pkcs1v15_encrypt(), &encrypted) .expect("decrypt failed"); assert_eq!( diff --git a/src/pkcs1v15.rs b/src/pkcs1v15.rs index d0a2dc1d..5e4ad496 100644 --- a/src/pkcs1v15.rs +++ b/src/pkcs1v15.rs @@ -9,7 +9,7 @@ use crate::key::{self, PrivateKey, PublicKey}; // scheme from PKCS#1 v1.5. The message must be no longer than the // length of the public modulus minus 11 bytes. #[inline] -pub fn encrypt(rng: &mut R, pub_key: &K, msg: &[u8]) -> Result> { +pub fn encrypt(rng: &mut R, pub_key: &PK, msg: &[u8]) -> Result> { key::check_public(pub_key)?; let k = pub_key.size(); @@ -24,7 +24,7 @@ pub fn encrypt(rng: &mut R, pub_key: &K, msg: &[u8]) -> Re em[k - msg.len() - 1] = 0; em[k - msg.len()..].copy_from_slice(msg); - pub_key.raw_encryption_primitive(&em) + pub_key.raw_encryption_primitive(&em, pub_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5. @@ -65,10 +65,10 @@ pub fn decrypt( // messages to signatures and identify the signed messages. As ever, // signatures provide authenticity, not confidentiality. #[inline] -pub fn sign( +pub fn sign( rng: Option<&mut R>, priv_key: &SK, - hash: Option<&H>, + hash: Option<&Hash>, hashed: &[u8], ) -> Result> { let (hash_len, prefix) = hash_info(hash, hashed.len())?; @@ -87,14 +87,14 @@ pub fn sign( em[k - t_len..k - hash_len].copy_from_slice(&prefix); em[k - hash_len..k].copy_from_slice(hashed); - priv_key.raw_decryption_primitive(rng, &em) + priv_key.raw_decryption_primitive(rng, &em, priv_key.size()) } /// Verifies an RSA PKCS#1 v1.5 signature. #[inline] -pub fn verify( - pub_key: &K, - hash: Option<&H>, +pub fn verify( + pub_key: &PK, + hash: Option<&Hash>, hashed: &[u8], sig: &[u8], ) -> Result<()> { @@ -106,7 +106,7 @@ pub fn verify( return Err(Error::Verification); } - let em = pub_key.raw_encryption_primitive(sig)?; + let em = pub_key.raw_encryption_primitive(sig, pub_key.size())?; // EM = 0x00 || 0x01 || PS || 0x00 || T let mut ok = em[0].ct_eq(&0u8); @@ -127,7 +127,7 @@ pub fn verify( } #[inline] -fn hash_info(hash: Option<&H>, digest_len: usize) -> Result<(usize, Vec)> { +fn hash_info(hash: Option<&Hash>, digest_len: usize) -> Result<(usize, &'static [u8])> { match hash { Some(hash) => { let hash_len = hash.size(); @@ -138,7 +138,7 @@ fn hash_info(hash: Option<&H>, digest_len: usize) -> Result<(usize, Vec Ok((hash_len, hash.asn1_prefix())) } // this means the data is signed directly - None => Ok((digest_len, Vec::new())), + None => Ok((digest_len, &[])), } } @@ -159,7 +159,7 @@ fn decrypt_inner( return Err(Error::Decryption); } - let em = priv_key.raw_decryption_primitive(rng, ciphertext)?; + let em = priv_key.raw_decryption_primitive(rng, ciphertext, priv_key.size())?; let first_byte_is_zero = em[0].ct_eq(&0u8); let second_byte_is_two = em[1].ct_eq(&2u8); @@ -218,9 +218,7 @@ mod tests { use rand::thread_rng; use sha1::{Digest, Sha1}; - use crate::hash::Hashes; - use crate::key::{PublicKeyParts, RSAPrivateKey, RSAPublicKey}; - use crate::padding::PaddingScheme; + use crate::{Hash, PaddingScheme, PublicKey, PublicKeyParts, RSAPrivateKey, RSAPublicKey}; #[test] fn test_non_zero_bytes() { @@ -277,7 +275,10 @@ mod tests { for test in &tests { let out = priv_key - .decrypt(PaddingScheme::PKCS1v15, &base64::decode(test[0]).unwrap()) + .decrypt( + PaddingScheme::new_pkcs1v15_encrypt(), + &base64::decode(test[0]).unwrap(), + ) .unwrap(); assert_eq!(out, test[1].as_bytes()); } @@ -318,7 +319,7 @@ mod tests { let expected = hex::decode(test[1]).unwrap(); let out = priv_key - .sign(PaddingScheme::PKCS1v15, Some(&Hashes::SHA1), &digest) + .sign(PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA1)), &digest) .unwrap(); assert_ne!(out, digest); assert_eq!(out, expected); @@ -327,8 +328,7 @@ mod tests { let out2 = priv_key .sign_blinded( &mut rng, - PaddingScheme::PKCS1v15, - Some(&Hashes::SHA1), + PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA1)), &digest, ) .unwrap(); @@ -350,7 +350,11 @@ mod tests { let sig = hex::decode(test[1]).unwrap(); pub_key - .verify(PaddingScheme::PKCS1v15, Some(&Hashes::SHA1), &digest, &sig) + .verify( + PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA1)), + &digest, + &sig, + ) .expect("failed to verify"); } } @@ -362,13 +366,13 @@ mod tests { let priv_key = get_private_key(); let sig = priv_key - .sign::(PaddingScheme::PKCS1v15, None, msg) + .sign(PaddingScheme::new_pkcs1v15_sign(None), msg) .unwrap(); assert_eq!(expected_sig, sig); let pub_key: RSAPublicKey = priv_key.into(); pub_key - .verify::(PaddingScheme::PKCS1v15, None, msg, &sig) + .verify(PaddingScheme::new_pkcs1v15_sign(None), msg, &sig) .expect("failed to verify"); } } diff --git a/src/pss.rs b/src/pss.rs new file mode 100644 index 00000000..76bbcc9d --- /dev/null +++ b/src/pss.rs @@ -0,0 +1,316 @@ +use std::vec::Vec; + +use digest::DynDigest; +use rand::{Rng, RngCore}; +use subtle::ConstantTimeEq; + +use crate::algorithms::mgf1_xor; +use crate::errors::{Error, Result}; +use crate::key::{PrivateKey, PublicKey}; + +pub fn verify( + pub_key: &PK, + hashed: &[u8], + sig: &[u8], + digest: &mut dyn DynDigest, +) -> Result<()> { + if sig.len() != pub_key.size() { + return Err(Error::Verification); + } + + let em_bits = pub_key.n().bits() - 1; + let em_len = (em_bits + 7) / 8; + let mut em = pub_key.raw_encryption_primitive(sig, em_len)?; + + emsa_pss_verify(hashed, &mut em, em_bits, None, digest) +} + +/// SignPSS calculates the signature of hashed using RSASSA-PSS [1]. +/// Note that hashed must be the result of hashing the input message using the +/// given hash function. The opts argument may be nil, in which case sensible +/// defaults are used. +pub fn sign( + rng: &mut T, + blind_rng: Option<&mut S>, + priv_key: &SK, + hashed: &[u8], + salt_len: Option, + digest: &mut dyn DynDigest, +) -> Result> { + let salt_len = salt_len.unwrap_or_else(|| priv_key.size() - 2 - digest.output_size()); + + let mut salt = vec![0; salt_len]; + rng.fill(&mut salt[..]); + + sign_pss_with_salt(blind_rng, priv_key, hashed, &salt, digest) +} + +/// signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt. +/// Note that hashed must be the result of hashing the input message using the +/// given hash function. salt is a random sequence of bytes whose length will be +/// later used to verify the signature. +fn sign_pss_with_salt( + blind_rng: Option<&mut T>, + priv_key: &SK, + hashed: &[u8], + salt: &[u8], + digest: &mut dyn DynDigest, +) -> Result> { + let em_bits = priv_key.n().bits() - 1; + let em = emsa_pss_encode(hashed, em_bits, salt, digest)?; + + priv_key.raw_decryption_primitive(blind_rng, &em, priv_key.size()) +} + +fn emsa_pss_encode( + m_hash: &[u8], + em_bits: usize, + salt: &[u8], + hash: &mut dyn DynDigest, +) -> Result> { + // See [1], section 9.1.1 + let h_len = hash.output_size(); + let s_len = salt.len(); + let em_len = (em_bits + 7) / 8; + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + if m_hash.len() != h_len { + return Err(Error::InputNotHashed); + } + + // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. + if em_len < h_len + s_len + 2 { + // TODO: Key size too small + return Err(Error::Internal); + } + + let mut em = vec![0; em_len]; + + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..(em_len - 1) - db.len()]; + + // 4. Generate a random octet string salt of length s_len; if s_len = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; + // + // M' is an octet string of length 8 + h_len + s_len with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length h_len. + let prefix = [0u8; 8]; + + hash.update(&prefix); + hash.update(m_hash); + hash.update(salt); + + let hashed = hash.finalize_reset(); + h.copy_from_slice(&hashed); + + // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + db[em_len - s_len - h_len - 2] = 0x01; + db[em_len - s_len - h_len - 1..].copy_from_slice(salt); + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + mgf1_xor(db, hash, &h); + + // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB to zero. + db[0] &= 0xFF >> (8 * em_len - em_bits); + + // 12. Let EM = maskedDB || H || 0xbc. + em[em_len - 1] = 0xBC; + + Ok(em) +} + +fn emsa_pss_verify( + m_hash: &[u8], + em: &mut [u8], + em_bits: usize, + s_len: Option, + hash: &mut dyn DynDigest, +) -> Result<()> { + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" + // and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen + let h_len = hash.output_size(); + if m_hash.len() != h_len { + return Err(Error::Verification); + } + + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. + let em_len = em.len(); //(em_bits + 7) / 8; + if em_len < h_len + s_len.unwrap_or_default() + 2 { + return Err(Error::Verification); + } + + // 4. If the rightmost octet of EM does not have hexadecimal value + // 0xbc, output "inconsistent" and stop. + if em[em.len() - 1] != 0xBC { + return Err(Error::Verification); + } + + // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and + // let H be the next hLen octets. + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..h_len]; + + // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB are not all equal to zero, output "inconsistent" and + // stop. + if db[0] & (0xFF << /*uint*/(8 - (8 * em_len - em_bits))) != 0 { + return Err(Error::Verification); + } + + // 7. Let dbMask = MGF(H, em_len - h_len - 1) + // + // 8. Let DB = maskedDB \xor dbMask + mgf1_xor(db, hash, &*h); + + // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB + // to zero. + db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); + + let s_len = match s_len { + None => (0..=em_len - (h_len + 2)) + .rev() + .try_fold(None, |state, i| match (state, db[em_len - h_len - i - 2]) { + (Some(i), _) => Ok(Some(i)), + (_, 1) => Ok(Some(i)), + (_, 0) => Ok(None), + _ => Err(Error::Verification), + })? + .ok_or(Error::Verification)?, + Some(s_len) => { + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero + // or if the octet at position emLen - hLen - sLen - 1 (the leftmost + // position is "position 1") does not have hexadecimal value 0x01, + // output "inconsistent" and stop. + let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); + if zeroes.iter().any(|e| *e != 0x00) || rest[0] != 0x01 { + return Err(Error::Verification); + } + + s_len + } + }; + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 13. Let H' = Hash(M'), an octet string of length hLen. + let prefix = [0u8; 8]; + + hash.update(&prefix[..]); + hash.update(m_hash); + hash.update(salt); + let h0 = hash.finalize_reset(); + + // 14. If H = H', output "consistent." Otherwise, output "inconsistent." + if h0.ct_eq(h).into() { + Ok(()) + } else { + Err(Error::Verification) + } +} + +#[cfg(test)] +mod test { + use crate::{PaddingScheme, PublicKey, RSAPrivateKey, RSAPublicKey}; + + use num_bigint::BigUint; + use num_traits::{FromPrimitive, Num}; + use rand::thread_rng; + use sha1::{Digest, Sha1}; + + fn get_private_key() -> RSAPrivateKey { + // In order to generate new test vectors you'll need the PEM form of this key: + // -----BEGIN RSA PRIVATE KEY----- + // MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0 + // fd7Ai2KW5ToIwzFofvJcS/STa6HA5gQenRUCAwEAAQJBAIq9amn00aS0h/CrjXqu + // /ThglAXJmZhOMPVn4eiu7/ROixi9sex436MaVeMqSNf7Ex9a8fRNfWss7Sqd9eWu + // RTUCIQDasvGASLqmjeffBNLTXV2A5g4t+kLVCpsEIZAycV5GswIhANEPLmax0ME/ + // EO+ZJ79TJKN5yiGBRsv5yvx5UiHxajEXAiAhAol5N4EUyq6I9w1rYdhPMGpLfk7A + // IU2snfRJ6Nq2CQIgFrPsWRCkV+gOYcajD17rEqmuLrdIRexpg8N1DOSXoJ8CIGlS + // tAboUGBxTDq3ZroNism3DaMIbKPyYrAqhKov1h5V + // -----END RSA PRIVATE KEY----- + + RSAPrivateKey::from_components( + BigUint::from_str_radix("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077", 10).unwrap(), + BigUint::from_u64(65537).unwrap(), + BigUint::from_str_radix("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861", 10).unwrap(), + vec![ + BigUint::from_str_radix("98920366548084643601728869055592650835572950932266967461790948584315647051443",10).unwrap(), + BigUint::from_str_radix("94560208308847015747498523884063394671606671904944666360068158221458669711639", 10).unwrap() + ], + ) + } + + #[test] + fn test_verify_pss() { + let priv_key = get_private_key(); + + let tests = [[ + "test\n", "6f86f26b14372b2279f79fb6807c49889835c204f71e38249b4c5601462da8ae30f26ffdd9c13f1c75eee172bebe7b7c89f2f1526c722833b9737d6c172a962f" + ]]; + let pub_key: RSAPublicKey = priv_key.into(); + + for test in &tests { + let digest = Sha1::digest(test[0].as_bytes()).to_vec(); + let sig = hex::decode(test[1]).unwrap(); + + pub_key + .verify( + PaddingScheme::new_pss::(thread_rng()), + &digest, + &sig, + ) + .expect("failed to verify"); + } + } + + #[test] + fn test_sign_and_verify_roundtrip() { + let priv_key = get_private_key(); + + let tests = ["test\n"]; + + for test in &tests { + let digest = Sha1::digest(test.as_bytes()).to_vec(); + let sig = priv_key + .sign_blinded( + &mut thread_rng(), + PaddingScheme::new_pss::(thread_rng()), + &digest, + ) + .expect("failed to sign"); + + priv_key + .verify( + PaddingScheme::new_pss::(thread_rng()), + &digest, + &sig, + ) + .expect("failed to verify"); + } + } +} diff --git a/src/raw.rs b/src/raw.rs index 496b7a7c..9bf55e96 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -2,13 +2,13 @@ use num_bigint::BigUint; use rand::Rng; use zeroize::Zeroize; -use crate::errors::Result; +use crate::errors::{Error, Result}; use crate::internals; -use crate::key::{PublicKeyParts, RSAPrivateKey, RSAPublicKey}; +use crate::key::{RSAPrivateKey, RSAPublicKey}; pub trait EncryptionPrimitive { /// Do NOT use directly! Only for implementors. - fn raw_encryption_primitive(&self, plaintext: &[u8]) -> Result>; + fn raw_encryption_primitive(&self, plaintext: &[u8], pad_size: usize) -> Result>; } pub trait DecryptionPrimitive { @@ -17,15 +17,20 @@ pub trait DecryptionPrimitive { &self, rng: Option<&mut R>, ciphertext: &[u8], + pad_size: usize, ) -> Result>; } impl EncryptionPrimitive for RSAPublicKey { - fn raw_encryption_primitive(&self, plaintext: &[u8]) -> Result> { + fn raw_encryption_primitive(&self, plaintext: &[u8], pad_size: usize) -> Result> { let mut m = BigUint::from_bytes_be(plaintext); let mut c = internals::encrypt(self, &m); let mut c_bytes = c.to_bytes_be(); - let ciphertext = internals::left_pad(&c_bytes, self.size()); + let ciphertext = internals::left_pad(&c_bytes, pad_size); + + if pad_size < ciphertext.len() { + return Err(Error::Verification); + } // clear out tmp values m.zeroize(); @@ -37,8 +42,8 @@ impl EncryptionPrimitive for RSAPublicKey { } impl<'a> EncryptionPrimitive for &'a RSAPublicKey { - fn raw_encryption_primitive(&self, plaintext: &[u8]) -> Result> { - (*self).raw_encryption_primitive(plaintext) + fn raw_encryption_primitive(&self, plaintext: &[u8], pad_size: usize) -> Result> { + (*self).raw_encryption_primitive(plaintext, pad_size) } } @@ -47,11 +52,12 @@ impl DecryptionPrimitive for RSAPrivateKey { &self, rng: Option<&mut R>, ciphertext: &[u8], + pad_size: usize, ) -> Result> { let mut c = BigUint::from_bytes_be(ciphertext); let mut m = internals::decrypt_and_check(rng, self, &c)?; let mut m_bytes = m.to_bytes_be(); - let plaintext = internals::left_pad(&m_bytes, self.size()); + let plaintext = internals::left_pad(&m_bytes, pad_size); // clear tmp values c.zeroize(); @@ -67,7 +73,8 @@ impl<'a> DecryptionPrimitive for &'a RSAPrivateKey { &self, rng: Option<&mut R>, ciphertext: &[u8], + pad_size: usize, ) -> Result> { - (*self).raw_decryption_primitive(rng, ciphertext) + (*self).raw_decryption_primitive(rng, ciphertext, pad_size) } }