diff --git a/base64ct/src/decoder.rs b/base64ct/src/decoder.rs index 9a02ae184..3a5cf34b7 100644 --- a/base64ct/src/decoder.rs +++ b/base64ct/src/decoder.rs @@ -8,6 +8,12 @@ use crate::{ }; use core::{cmp, marker::PhantomData}; +#[cfg(feature = "alloc")] +use {alloc::vec::Vec, core::iter}; + +#[cfg(feature = "std")] +use std::io; + #[cfg(docsrs)] use crate::{Base64, Base64Unpadded}; @@ -165,6 +171,27 @@ impl<'i, E: Variant> Decoder<'i, E> { Ok(out) } + /// Decode all remaining Base64 data, placing the result into `buf`. + /// + /// If successful, this function will return the total number of bytes + /// decoded into `buf`. + #[cfg(feature = "alloc")] + #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] + pub fn decode_to_end<'o>(&mut self, buf: &'o mut Vec) -> Result<&'o [u8], Error> { + let start_len = buf.len(); + let decoded_len = self.decoded_len(); + let total_len = start_len + decoded_len; + + if total_len > buf.capacity() { + buf.reserve(total_len - buf.capacity()); + } + + // Append `decoded_len` zeroes to the vector + buf.extend(iter::repeat(0).take(decoded_len)); + self.decode(&mut buf[start_len..])?; + Ok(&buf[start_len..]) + } + /// Get the length of the remaining data after Base64 decoding. /// /// Decreases every time data is decoded. @@ -230,6 +257,29 @@ impl<'i, E: Variant> Decoder<'i, E> { } } +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl<'i, E: Variant> io::Read for Decoder<'i, E> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let slice = match buf.get_mut(..self.decoded_len()) { + Some(bytes) => bytes, + None => buf, + }; + + self.decode(slice)?; + Ok(slice.len()) + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + Ok(self.decode_to_end(buf)?.len()) + } + + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.decode(buf)?; + Ok(()) + } +} + /// Base64 decode buffer for a 1-block input. /// /// This handles a partially decoded block of data, i.e. data which has been @@ -501,6 +551,9 @@ impl<'i> Iterator for LineReader<'i> { mod tests { use crate::{test_vectors::*, variant::Variant, Base64, Base64Unpadded, Decoder}; + #[cfg(feature = "std")] + use {alloc::vec::Vec, std::io::Read}; + #[test] fn decode_padded() { decode_test(PADDED_BIN, || { @@ -530,6 +583,19 @@ mod tests { }) } + #[cfg(feature = "std")] + #[test] + fn read_multiline_padded() { + let mut decoder = + Decoder::::new_wrapped(MULTILINE_PADDED_BASE64.as_bytes(), 70).unwrap(); + + let mut buf = Vec::new(); + let len = decoder.read_to_end(&mut buf).unwrap(); + + assert_eq!(len, MULTILINE_PADDED_BIN.len()); + assert_eq!(buf.as_slice(), MULTILINE_PADDED_BIN); + } + /// Core functionality of a decoding test fn decode_test<'a, F, V>(expected: &[u8], f: F) where diff --git a/base64ct/src/encoder.rs b/base64ct/src/encoder.rs index 1267d06fa..c8ef38bf3 100644 --- a/base64ct/src/encoder.rs +++ b/base64ct/src/encoder.rs @@ -7,6 +7,9 @@ use crate::{ }; use core::{cmp, marker::PhantomData, str}; +#[cfg(docsrs)] +use crate::{Base64, Base64Unpadded}; + /// Stateful Base64 encoder with support for buffered, incremental encoding. /// /// The `E` type parameter can be any type which impls [`Encoding`] such as diff --git a/base64ct/src/errors.rs b/base64ct/src/errors.rs index 1b43a8d4e..0ea417173 100644 --- a/base64ct/src/errors.rs +++ b/base64ct/src/errors.rs @@ -73,4 +73,14 @@ impl From for Error { } #[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl From for std::io::Error { + fn from(err: Error) -> std::io::Error { + // TODO(tarcieri): better customize `ErrorKind`? + std::io::Error::new(std::io::ErrorKind::InvalidData, err) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl std::error::Error for Error {}