diff --git a/ssh-key/src/certificate.rs b/ssh-key/src/certificate.rs index b28afea63..ca24ca0a5 100644 --- a/ssh-key/src/certificate.rs +++ b/ssh-key/src/certificate.rs @@ -5,12 +5,14 @@ mod cert_type; mod field; mod options_map; mod signing_key; +mod unix_time; pub use self::{ builder::Builder, cert_type::CertType, field::Field, options_map::OptionsMap, signing_key::SigningKey, }; +use self::unix_time::UnixTime; use crate::{ checked::CheckedSum, decode::Decode, @@ -37,11 +39,7 @@ use { use serde::{de, ser, Deserialize, Serialize}; #[cfg(feature = "std")] -use std::{ - fs, - path::Path, - time::{Duration, SystemTime, UNIX_EPOCH}, -}; +use std::{fs, path::Path, time::SystemTime}; /// OpenSSH certificate as specified in [PROTOCOL.certkeys]. /// @@ -151,11 +149,11 @@ pub struct Certificate { /// Valid principals. valid_principals: Vec, - /// Valid after (Unix time). - valid_after: u64, + /// Valid after. + valid_after: UnixTime, - /// Valid before (Unix time). - valid_before: u64, + /// Valid before. + valid_before: UnixTime, /// Critical options. critical_options: OptionsMap, @@ -313,30 +311,26 @@ impl Certificate { /// Valid after (Unix time). pub fn valid_after(&self) -> u64 { - self.valid_after + self.valid_after.into() } /// Valid before (Unix time). pub fn valid_before(&self) -> u64 { - self.valid_before + self.valid_before.into() } /// Valid after (system time). #[cfg(feature = "std")] #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn valid_after_time(&self) -> SystemTime { - UNIX_EPOCH - .checked_add(Duration::from_secs(self.valid_after)) - .expect("time overflow") + self.valid_after.into() } /// Valid before (system time). #[cfg(feature = "std")] #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn valid_before_time(&self) -> SystemTime { - UNIX_EPOCH - .checked_add(Duration::from_secs(self.valid_before)) - .expect("time overflow") + self.valid_before.into() } /// The critical options section of the certificate specifies zero or more @@ -384,12 +378,7 @@ impl Certificate { where I: IntoIterator, { - let unix_time = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|_| Error::CertificateValidation)? - .as_secs(); - - self.validate_at(unix_time, ca_fingerprints) + self.validate_at(UnixTime::now()?.into(), ca_fingerprints) } /// Perform certificate validation. @@ -435,6 +424,8 @@ impl Certificate { return Err(Error::CertificateValidation); } + let unix_timestamp = UnixTime::new(unix_timestamp)?; + // From PROTOCOL.certkeys: // // "valid after" and "valid before" specify a validity period for the @@ -503,8 +494,8 @@ impl Decode for Certificate { cert_type: CertType::decode(reader)?, key_id: String::decode(reader)?, valid_principals: Vec::decode(reader)?, - valid_after: u64::decode(reader)?, - valid_before: u64::decode(reader)?, + valid_after: UnixTime::decode(reader)?, + valid_before: UnixTime::decode(reader)?, critical_options: OptionsMap::decode(reader)?, extensions: OptionsMap::decode(reader)?, reserved: Vec::decode(reader)?, diff --git a/ssh-key/src/certificate/builder.rs b/ssh-key/src/certificate/builder.rs index f0b66a92e..5774abb33 100644 --- a/ssh-key/src/certificate/builder.rs +++ b/ssh-key/src/certificate/builder.rs @@ -1,6 +1,6 @@ //! OpenSSH certificate builder. -use super::{CertType, Certificate, Field, OptionsMap, SigningKey}; +use super::{unix_time::UnixTime, CertType, Certificate, Field, OptionsMap, SigningKey}; use crate::{public, Result, Signature}; use alloc::{string::String, vec::Vec}; @@ -8,7 +8,7 @@ use alloc::{string::String, vec::Vec}; use rand_core::{CryptoRng, RngCore}; #[cfg(feature = "std")] -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::SystemTime; #[cfg(doc)] use crate::PrivateKey; @@ -84,8 +84,8 @@ pub struct Builder { cert_type: Option, key_id: Option, valid_principals: Option>, - valid_after: u64, - valid_before: u64, + valid_after: UnixTime, + valid_before: UnixTime, critical_options: OptionsMap, extensions: OptionsMap, comment: Option, @@ -105,6 +105,11 @@ impl Builder { valid_after: u64, valid_before: u64, ) -> Self { + // TODO(tarcieri): return a `Result` instead of using `expect` + // Breaking change; needs to be done in the next release + let valid_after = UnixTime::new(valid_after).expect("valid_after time overflow"); + let valid_before = UnixTime::new(valid_before).expect("valid_before time overflow"); + Self { nonce: nonce.into(), public_key: public_key.into(), @@ -130,21 +135,23 @@ impl Builder { valid_after: SystemTime, valid_before: SystemTime, ) -> Result { - let valid_after = valid_after - .duration_since(UNIX_EPOCH) - .map_err(|_| Field::ValidAfter.invalid_error())? - .as_secs(); + let valid_after = + UnixTime::try_from(valid_after).map_err(|_| Field::ValidAfter.invalid_error())?; - let valid_before = valid_before - .duration_since(UNIX_EPOCH) - .map_err(|_| Field::ValidBefore.invalid_error())? - .as_secs(); + let valid_before = + UnixTime::try_from(valid_before).map_err(|_| Field::ValidBefore.invalid_error())?; + // TODO(tarcieri): move this check into `Builder::new` if valid_before < valid_after { return Err(Field::ValidBefore.invalid_error()); } - Ok(Self::new(nonce, public_key, valid_before, valid_after)) + Ok(Self::new( + nonce, + public_key, + valid_before.into(), + valid_after.into(), + )) } /// Create a new certificate builder, generating a random nonce using the @@ -304,7 +311,7 @@ impl Builder { #[cfg(all(debug_assertions, feature = "fingerprint"))] cert.validate_at( - cert.valid_after, + cert.valid_after.into(), &[cert.signature_key.fingerprint(Default::default())], )?; diff --git a/ssh-key/src/certificate/unix_time.rs b/ssh-key/src/certificate/unix_time.rs new file mode 100644 index 000000000..dfa7d00f7 --- /dev/null +++ b/ssh-key/src/certificate/unix_time.rs @@ -0,0 +1,128 @@ +//! Unix timestamps. + +use crate::{decode::Decode, encode::Encode, reader::Reader, writer::Writer, Error, Result}; +use core::fmt; +use core::fmt::Formatter; + +#[cfg(feature = "std")] +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +/// Maximum number of seconds since the Unix epoch allowed. +pub const MAX_SECS: u64 = i64::MAX as u64; + +/// Unix timestamps as used in OpenSSH certificates. +#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)] +pub(super) struct UnixTime { + /// Number of seconds since the Unix epoch + secs: u64, + + /// System time corresponding to this Unix timestamp + #[cfg(feature = "std")] + time: SystemTime, +} + +impl UnixTime { + /// Create a new Unix timestamp. + /// + /// `secs` is the number of seconds since the Unix epoch and must be less + /// than or equal to `i64::MAX`. + #[cfg(not(feature = "std"))] + pub fn new(secs: u64) -> Result { + if secs <= MAX_SECS { + Ok(Self { secs }) + } else { + Err(Error::Time) + } + } + + /// Create a new Unix timestamp. + /// + /// This version requires `std` and ensures there's a valid `SystemTime` + /// representation with an infallible conversion (which also improves the + /// `Debug` output) + #[cfg(feature = "std")] + pub fn new(secs: u64) -> Result { + if secs > MAX_SECS { + return Err(Error::Time); + } + + match UNIX_EPOCH.checked_add(Duration::from_secs(secs)) { + Some(time) => Ok(Self { secs, time }), + None => Err(Error::Time), + } + } + + /// Get the current time as a Unix timestamp. + #[cfg(all(feature = "std", feature = "fingerprint"))] + pub fn now() -> Result { + SystemTime::now().try_into() + } +} + +impl Decode for UnixTime { + fn decode(reader: &mut impl Reader) -> Result { + u64::decode(reader)?.try_into() + } +} + +impl Encode for UnixTime { + fn encoded_len(&self) -> Result { + self.secs.encoded_len() + } + + fn encode(&self, writer: &mut impl Writer) -> Result<()> { + self.secs.encode(writer) + } +} + +impl From for u64 { + fn from(unix_time: UnixTime) -> u64 { + unix_time.secs + } +} + +#[cfg(feature = "std")] +impl From for SystemTime { + fn from(unix_time: UnixTime) -> SystemTime { + unix_time.time + } +} + +impl TryFrom for UnixTime { + type Error = Error; + + fn try_from(unix_secs: u64) -> Result { + Self::new(unix_secs) + } +} + +#[cfg(feature = "std")] +impl TryFrom for UnixTime { + type Error = Error; + + fn try_from(time: SystemTime) -> Result { + Self::new(time.duration_since(UNIX_EPOCH)?.as_secs()) + } +} + +impl fmt::Debug for UnixTime { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.secs) + } +} + +#[cfg(test)] +mod tests { + use super::{UnixTime, MAX_SECS}; + use crate::Error; + + #[test] + fn new_with_max_secs() { + assert!(UnixTime::new(MAX_SECS).is_ok()); + } + + #[test] + fn new_over_max_secs_returns_error() { + assert_eq!(UnixTime::new(MAX_SECS + 1), Err(Error::Time)); + } +} diff --git a/ssh-key/src/error.rs b/ssh-key/src/error.rs index accf2712c..5bab7ca3f 100644 --- a/ssh-key/src/error.rs +++ b/ssh-key/src/error.rs @@ -64,6 +64,9 @@ pub enum Error { /// Public key does not match private key. PublicKey, + /// Invalid timestamp (e.g. in a certificate) + Time, + /// Unexpected trailing data at end of message. TrailingData { /// Number of bytes of remaining data at end of message. @@ -74,26 +77,27 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Error::Algorithm => f.write_str("unknown or unsupported algorithm"), + Error::Algorithm => write!(f, "unknown or unsupported algorithm"), Error::Base64(err) => write!(f, "Base64 encoding error: {}", err), #[cfg(feature = "alloc")] Error::CertificateFieldInvalid(field) => { write!(f, "certificate field invalid: {}", field) } - Error::CertificateValidation => f.write_str("certificate validation failed"), - Error::CharacterEncoding => f.write_str("character encoding invalid"), - Error::Crypto => f.write_str("cryptographic error"), - Error::Decrypted => f.write_str("private key is already decrypted"), + Error::CertificateValidation => write!(f, "certificate validation failed"), + Error::CharacterEncoding => write!(f, "character encoding invalid"), + Error::Crypto => write!(f, "cryptographic error"), + Error::Decrypted => write!(f, "private key is already decrypted"), #[cfg(feature = "ecdsa")] Error::Ecdsa(err) => write!(f, "ECDSA encoding error: {}", err), - Error::Encrypted => f.write_str("private key is encrypted"), - Error::FormatEncoding => f.write_str("format encoding error"), + Error::Encrypted => write!(f, "private key is encrypted"), + Error::FormatEncoding => write!(f, "format encoding error"), #[cfg(feature = "std")] Error::Io(err) => write!(f, "I/O error: {}", std::io::Error::from(*err)), - Error::Length => f.write_str("length invalid"), - Error::Overflow => f.write_str("internal overflow error"), + Error::Length => write!(f, "length invalid"), + Error::Overflow => write!(f, "internal overflow error"), Error::Pem(err) => write!(f, "{}", err), - Error::PublicKey => f.write_str("public key is incorrect"), + Error::PublicKey => write!(f, "public key is incorrect"), + Error::Time => write!(f, "invalid time"), Error::TrailingData { remaining } => write!( f, "unexpected trailing data at end of message ({} bytes)", @@ -103,9 +107,6 @@ impl fmt::Display for Error { } } -#[cfg(feature = "std")] -impl std::error::Error for Error {} - impl From for Error { fn from(err: base64ct::Error) -> Error { Error::Base64(err) @@ -189,3 +190,13 @@ impl From for Error { Error::Io(err.kind()) } } + +#[cfg(feature = "std")] +impl From for Error { + fn from(_: std::time::SystemTimeError) -> Error { + Error::Time + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {}