Skip to content
116 changes: 77 additions & 39 deletions src/uu/base32/src/base_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
use clap::{Arg, ArgAction, Command};
use std::ffi::OsString;
use std::fs::File;
use std::io::{self, BufReader, ErrorKind, Read, Write};
use std::io::{self, BufRead, BufReader, ErrorKind, Write};
use std::path::{Path, PathBuf};
use uucore::display::Quotable;
use uucore::encoding::{
Expand Down Expand Up @@ -146,20 +146,26 @@ pub fn base_app(about: String, usage: String) -> Command {
)
}

pub fn get_input(config: &Config) -> UResult<Box<dyn Read>> {
pub fn get_input(config: &Config) -> UResult<Box<dyn BufRead>> {
match &config.to_read {
Some(path_buf) => {
let file =
File::open(path_buf).map_err_context(|| path_buf.maybe_quote().to_string())?;
Ok(Box::new(BufReader::new(file)))
Ok(Box::new(BufReader::with_capacity(
DEFAULT_BUFFER_SIZE,
file,
)))
}
None => {
// Stdin is already buffered by the OS; wrap once more to reduce syscalls per read.
Ok(Box::new(BufReader::new(io::stdin())))
Ok(Box::new(BufReader::with_capacity(
DEFAULT_BUFFER_SIZE,
io::stdin(),
)))
}
}
}
pub fn handle_input<R: Read>(input: &mut R, format: Format, config: Config) -> UResult<()> {
pub fn handle_input<R: BufRead>(input: &mut R, format: Format, config: Config) -> UResult<()> {
// Always allow padding for Base64 to avoid a full pre-scan of the input.
let supports_fast_decode_and_encode =
get_supports_fast_decode_and_encode(format, config.decode, true);
Expand Down Expand Up @@ -292,11 +298,11 @@ pub fn get_supports_fast_decode_and_encode(
}

pub mod fast_encode {
use crate::base_common::{DEFAULT_BUFFER_SIZE, WRAP_DEFAULT};
use crate::base_common::WRAP_DEFAULT;
use std::{
cmp::min,
collections::VecDeque,
io::{self, Read, Write},
io::{self, BufRead, Write},
num::NonZeroUsize,
};
use uucore::{
Expand Down Expand Up @@ -519,7 +525,7 @@ pub mod fast_encode {
/// Remaining bytes are encoded and flushed at the end. I/O or encoding
/// failures are propagated via `UResult`.
pub fn fast_encode_stream(
input: &mut dyn Read,
input: &mut dyn BufRead,
output: &mut dyn Write,
supports_fast_decode_and_encode: &dyn SupportsFastDecodeAndEncode,
wrap: Option<usize>,
Expand All @@ -544,47 +550,79 @@ pub mod fast_encode {
};

// Buffers
let mut leftover_buffer = VecDeque::<u8>::new();
let mut encoded_buffer = VecDeque::<u8>::new();

let mut read_buffer = vec![0u8; encode_in_chunks_of_size.max(DEFAULT_BUFFER_SIZE)];
let mut leftover_buffer = Vec::<u8>::with_capacity(encode_in_chunks_of_size);

loop {
let read = input
.read(&mut read_buffer)
let read_buffer = input
.fill_buf()
.map_err(|err| USimpleError::new(1, super::format_read_error(err.kind())))?;
if read == 0 {
if read_buffer.is_empty() {
break;
}

leftover_buffer.extend(&read_buffer[..read]);
let mut consumed = 0;

while leftover_buffer.len() >= encode_in_chunks_of_size {
{
let contiguous = leftover_buffer.make_contiguous();
if !leftover_buffer.is_empty() {
let needed = encode_in_chunks_of_size - leftover_buffer.len();
let take = needed.min(read_buffer.len());
leftover_buffer.extend_from_slice(&read_buffer[..take]);
consumed += take;

if leftover_buffer.len() == encode_in_chunks_of_size {
encode_in_chunks_to_buffer(
supports_fast_decode_and_encode,
&contiguous[..encode_in_chunks_of_size],
leftover_buffer.as_slice(),
&mut encoded_buffer,
)?;
leftover_buffer.clear();

write_to_output(
&mut line_wrapping,
&mut encoded_buffer,
output,
false,
wrap == Some(0),
)?;
}
}

// Drop the data we just encoded
leftover_buffer.drain(..encode_in_chunks_of_size);
let remaining = &read_buffer[consumed..];
let full_chunk_bytes =
(remaining.len() / encode_in_chunks_of_size) * encode_in_chunks_of_size;

write_to_output(
&mut line_wrapping,
&mut encoded_buffer,
output,
false,
wrap == Some(0),
)?;
if full_chunk_bytes > 0 {
for chunk in remaining[..full_chunk_bytes].chunks_exact(encode_in_chunks_of_size) {
encode_in_chunks_to_buffer(
supports_fast_decode_and_encode,
chunk,
&mut encoded_buffer,
)?;
write_to_output(
&mut line_wrapping,
&mut encoded_buffer,
output,
false,
wrap == Some(0),
)?;
}
consumed += full_chunk_bytes;
}

if consumed < read_buffer.len() {
leftover_buffer.extend_from_slice(&read_buffer[consumed..]);
consumed = read_buffer.len();
}

input.consume(consumed);

// `leftover_buffer` should never exceed one partial chunk.
debug_assert!(leftover_buffer.len() < encode_in_chunks_of_size);
}

// Encode any remaining bytes and flush
supports_fast_decode_and_encode
.encode_to_vec_deque(leftover_buffer.make_contiguous(), &mut encoded_buffer)?;
.encode_to_vec_deque(&leftover_buffer, &mut encoded_buffer)?;

write_to_output(
&mut line_wrapping,
Expand All @@ -599,8 +637,7 @@ pub mod fast_encode {
}

pub mod fast_decode {
use crate::base_common::DEFAULT_BUFFER_SIZE;
use std::io::{self, Read, Write};
use std::io::{self, BufRead, Write};
use uucore::{
encoding::SupportsFastDecodeAndEncode,
error::{UResult, USimpleError},
Expand Down Expand Up @@ -630,7 +667,6 @@ pub mod fast_decode {
fn write_to_output(decoded_buffer: &mut Vec<u8>, output: &mut dyn Write) -> io::Result<()> {
// Write all data in `decoded_buffer` to `output`
output.write_all(decoded_buffer.as_slice())?;
output.flush()?;

decoded_buffer.clear();

Expand Down Expand Up @@ -764,7 +800,7 @@ pub mod fast_decode {
}

pub fn fast_decode_stream(
input: &mut dyn Read,
input: &mut dyn BufRead,
output: &mut dyn Write,
supports_fast_decode_and_encode: &dyn SupportsFastDecodeAndEncode,
ignore_garbage: bool,
Expand All @@ -783,17 +819,17 @@ pub mod fast_decode {

let mut buffer = Vec::with_capacity(decode_in_chunks_of_size);
let mut decoded_buffer = Vec::<u8>::new();
let mut read_buffer = [0u8; DEFAULT_BUFFER_SIZE];

loop {
let read = input
.read(&mut read_buffer)
let read_buffer = input
.fill_buf()
.map_err(|err| USimpleError::new(1, super::format_read_error(err.kind())))?;
if read == 0 {
let read_len = read_buffer.len();
if read_len == 0 {
break;
}

for &byte in &read_buffer[..read] {
for &byte in read_buffer {
if byte == b'\n' || byte == b'\r' {
continue;
}
Expand Down Expand Up @@ -845,6 +881,8 @@ pub mod fast_decode {
buffer.clear();
}
}

input.consume(read_len);
}

if supports_partial_decode {
Expand Down Expand Up @@ -902,7 +940,7 @@ fn format_read_error(kind: ErrorKind) -> String {

/// Determines if the input buffer contains any padding ('=') ignoring trailing whitespace.
#[cfg(test)]
fn read_and_has_padding<R: Read>(input: &mut R) -> UResult<(bool, Vec<u8>)> {
fn read_and_has_padding<R: std::io::Read>(input: &mut R) -> UResult<(bool, Vec<u8>)> {
let mut buf = Vec::new();
input
.read_to_end(&mut buf)
Expand Down
Loading