diff --git a/der/src/reader.rs b/der/src/reader.rs index 845bb8c1f..a57e290d0 100644 --- a/der/src/reader.rs +++ b/der/src/reader.rs @@ -4,6 +4,9 @@ pub(crate) mod pem; pub(crate) mod slice; +#[cfg(feature = "pem")] +mod position; + use crate::{ Decode, DecodeValue, Encode, EncodingRules, Error, ErrorKind, FixedTag, Header, Length, Tag, TagMode, TagNumber, asn1::ContextSpecific, diff --git a/der/src/reader/pem.rs b/der/src/reader/pem.rs index cfbe89e46..83316244d 100644 --- a/der/src/reader/pem.rs +++ b/der/src/reader/pem.rs @@ -1,7 +1,7 @@ //! Streaming PEM reader. -use super::Reader; -use crate::{EncodingRules, Error, ErrorKind, Length}; +use super::{Reader, position::Position}; +use crate::{EncodingRules, Error, ErrorKind, Length, Result}; use pem_rfc7468::Decoder; /// `Reader` type which decodes PEM on-the-fly. @@ -14,11 +14,8 @@ pub struct PemReader<'i> { /// Encoding rules to apply when decoding the input. encoding_rules: EncodingRules, - /// Input length (in bytes after Base64 decoding). - input_len: Length, - - /// Position in the input buffer (in bytes after Base64 decoding). - position: Length, + /// Position tracker. + position: Position, } #[cfg(feature = "pem")] @@ -26,15 +23,14 @@ impl<'i> PemReader<'i> { /// Create a new PEM reader which decodes data on-the-fly. /// /// Uses the default 64-character line wrapping. - pub fn new(pem: &'i [u8]) -> crate::Result { + pub fn new(pem: &'i [u8]) -> Result { let decoder = Decoder::new(pem)?; let input_len = Length::try_from(decoder.remaining_len())?; Ok(Self { decoder, encoding_rules: EncodingRules::default(), - input_len, - position: Length::ZERO, + position: Position::new(input_len), }) } @@ -52,52 +48,37 @@ impl<'i> Reader<'i> for PemReader<'i> { } fn input_len(&self) -> Length { - self.input_len + self.position.input_len() } - fn peek_into(&self, buf: &mut [u8]) -> crate::Result<()> { + fn peek_into(&self, buf: &mut [u8]) -> Result<()> { self.clone().read_into(buf)?; Ok(()) } fn position(&self) -> Length { - self.position + self.position.current() } - fn read_nested(&mut self, len: Length, f: F) -> Result + fn read_nested(&mut self, len: Length, f: F) -> core::result::Result where - F: FnOnce(&mut Self) -> Result, + F: FnOnce(&mut Self) -> core::result::Result, E: From, { - let nested_input_len = (self.position + len)?; - if nested_input_len > self.input_len { - return Err(Error::incomplete(self.input_len).into()); - } - - let orig_input_len = self.input_len; - self.input_len = nested_input_len; + let resumption = self.position.split_nested(len)?; let ret = f(self); - self.input_len = orig_input_len; + self.position.resume_nested(resumption); ret } - fn read_slice(&mut self, _len: Length) -> crate::Result<&'i [u8]> { + fn read_slice(&mut self, _len: Length) -> Result<&'i [u8]> { // Can't borrow from PEM because it requires decoding Err(ErrorKind::Reader.into()) } - fn read_into<'o>(&mut self, buf: &'o mut [u8]) -> crate::Result<&'o [u8]> { - let new_position = (self.position + buf.len())?; - if new_position > self.input_len { - return Err(ErrorKind::Incomplete { - expected_len: new_position, - actual_len: self.input_len, - } - .at(self.position)); - } - + fn read_into<'o>(&mut self, buf: &'o mut [u8]) -> Result<&'o [u8]> { + self.position.advance(Length::try_from(buf.len())?)?; self.decoder.decode(buf)?; - self.position = new_position; Ok(buf) } } diff --git a/der/src/reader/position.rs b/der/src/reader/position.rs new file mode 100644 index 000000000..e35a8a04f --- /dev/null +++ b/der/src/reader/position.rs @@ -0,0 +1,150 @@ +//! Position tracking for processing nested input messages using only the stack. + +use crate::{Error, ErrorKind, Length, Result}; + +/// State tracker for the current position in the input. +#[derive(Clone, Debug)] +pub(super) struct Position { + /// Input length (in bytes after Base64 decoding). + input_len: Length, + + /// Position in the input buffer (in bytes after Base64 decoding). + position: Length, +} + +impl Position { + /// Create a new position tracker with the given overall length. + pub(super) fn new(input_len: Length) -> Self { + Self { + input_len, + position: Length::ZERO, + } + } + + /// Get the input length. + pub(super) fn input_len(&self) -> Length { + self.input_len + } + + /// Get the current position. + pub(super) fn current(&self) -> Length { + self.position + } + + /// Advance the current position by the given amount. + /// + /// # Returns + /// + /// The new current position. + pub(super) fn advance(&mut self, amount: Length) -> Result { + let new_position = (self.position + amount)?; + + if new_position > self.input_len { + return Err(ErrorKind::Incomplete { + expected_len: new_position, + actual_len: self.input_len, + } + .at(self.position)); + } + + self.position = new_position; + Ok(new_position) + } + + /// Split a nested position tracker of the given size. + /// + /// # Returns + /// + /// A [`Resumption`] value which can be used to continue parsing the outer message. + pub(super) fn split_nested(&mut self, len: Length) -> Result { + let nested_input_len = (self.position + len)?; + + if nested_input_len > self.input_len { + return Err(Error::incomplete(self.input_len)); + } + + let resumption = Resumption { + input_len: self.input_len, + }; + self.input_len = nested_input_len; + Ok(resumption) + } + + /// Resume processing the rest of a message after processing a nested inner portion. + pub(super) fn resume_nested(&mut self, resumption: Resumption) { + self.input_len = resumption.input_len; + } +} + +/// Resumption state needed to continue processing a message after handling a nested inner portion. +#[derive(Debug)] +pub(super) struct Resumption { + /// Outer input length. + input_len: Length, +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::Position; + use crate::{ErrorKind, Length}; + + const EXAMPLE_LEN: Length = match Length::new_usize(42) { + Ok(len) => len, + Err(_) => panic!("invalid example len"), + }; + + #[test] + fn initial_state() { + let pos = Position::new(EXAMPLE_LEN); + assert_eq!(pos.input_len(), EXAMPLE_LEN); + assert_eq!(pos.current(), Length::ZERO); + } + + #[test] + fn advance() { + let mut pos = Position::new(EXAMPLE_LEN); + + // advance 1 byte: success + let new_pos = pos.advance(Length::ONE).unwrap(); + assert_eq!(new_pos, Length::ONE); + assert_eq!(pos.current(), Length::ONE); + + // advance to end: success + let end_pos = pos.advance((EXAMPLE_LEN - Length::ONE).unwrap()).unwrap(); + assert_eq!(end_pos, EXAMPLE_LEN); + assert_eq!(pos.current(), EXAMPLE_LEN); + + // advance one byte past end: error + let err = pos.advance(Length::ONE).unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::Incomplete { .. })); + } + + #[test] + fn nested() { + let mut pos = Position::new(EXAMPLE_LEN); + + // split first byte + let resumption = pos.split_nested(Length::ONE).unwrap(); + assert_eq!(pos.current(), Length::ZERO); + assert_eq!(pos.input_len(), Length::ONE); + + // advance one byte + assert_eq!(pos.advance(Length::ONE).unwrap(), Length::ONE); + + // can't advance two bytes + let err = pos.advance(Length::ONE).unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::Incomplete { .. })); + + // resume processing the rest of the message + // TODO(tarcieri): should we fail here if we previously failed reading a nested message? + pos.resume_nested(resumption); + + assert_eq!(pos.current(), Length::ONE); + assert_eq!(pos.input_len(), EXAMPLE_LEN); + + // try to split one byte past end: error + let err = pos.split_nested(EXAMPLE_LEN).unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::Incomplete { .. })); + } +}